fix issues with erf and xavier init
modeling_siglip.py  (+15 -9)
@@ -95,7 +95,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
 
     # Use inverse cdf transform for normal distribution to get truncated
     # standard normal
-    tensor.erfinv_()
+    if tensor.dtype == torch.bfloat16:
+        tensor = tensor.to(torch.float32)
+        tensor.erfinv_()
+        tensor = tensor.to(torch.bfloat16)
+    else:
+        tensor.erfinv_()
 
     # Transform to proper mean, std
     tensor.mul_(std * math.sqrt(2.0))
@@ -670,6 +675,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
 
     def _init_weights(self, module):
         """Initialize the weights"""
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
@@ -680,22 +686,22 @@ class SiglipPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             default_flax_embed_init(module.weight)
         elif isinstance(module, SiglipAttention):
-            nn.init.xavier_uniform_(module.q_proj.weight)
-            nn.init.xavier_uniform_(module.k_proj.weight)
-            nn.init.xavier_uniform_(module.v_proj.weight)
-            nn.init.xavier_uniform_(module.out_proj.weight)
+            nn.init.normal_(module.q_proj.weight)
+            nn.init.normal_(module.k_proj.weight)
+            nn.init.normal_(module.v_proj.weight)
+            nn.init.normal_(module.out_proj.weight)
             nn.init.zeros_(module.q_proj.bias)
             nn.init.zeros_(module.k_proj.bias)
             nn.init.zeros_(module.v_proj.bias)
             nn.init.zeros_(module.out_proj.bias)
         elif isinstance(module, SiglipMLP):
-            nn.init.xavier_uniform_(module.fc1.weight)
-            nn.init.xavier_uniform_(module.fc2.weight)
+            nn.init.normal_(module.fc1.weight)
+            nn.init.normal_(module.fc2.weight)
             nn.init.normal_(module.fc1.bias, std=1e-6)
             nn.init.normal_(module.fc2.bias, std=1e-6)
         elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
-            nn.init.xavier_uniform_(module.probe.data)
-            nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
+            nn.init.normal_(module.probe.data)
+            nn.init.normal_(module.attention.in_proj_weight.data)
             nn.init.zeros_(module.attention.in_proj_bias.data)
         elif isinstance(module, SiglipModel):
             logit_scale_init = torch.log(torch.tensor(1.0))
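
For context, a toy comparison (not from the repo) of the initializer the removed lines appear to use (Xavier/Glorot uniform, per the commit title) against the nn.init.normal_ calls added above; the 768-dimensional Linear layer is an arbitrary stand-in for the q/k/v/out projections:

import torch
import torch.nn as nn

proj = nn.Linear(768, 768)

# Xavier/Glorot uniform scales by fan-in/fan-out: std ~ sqrt(2 / (fan_in + fan_out)),
# roughly 0.036 for a 768x768 weight.
nn.init.xavier_uniform_(proj.weight)
print("xavier_uniform_ std:", proj.weight.std().item())

# nn.init.normal_ defaults to mean=0.0, std=1.0, so the new scheme draws
# unit-variance weights unless a std is passed explicitly.
nn.init.normal_(proj.weight)
print("normal_ std:", proj.weight.std().item())

# Bias initialization (zeros_, or normal_ with std=1e-6) is unchanged by the commit.
nn.init.zeros_(proj.bias)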