Commit · a4ce5f5
1 Parent(s): 83cef5d

save changes to modeling_siglip

modeling_siglip.py  CHANGED  (+62 -12)
@@ -39,7 +39,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from
+from configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig


 logger = logging.get_logger(__name__)
@@ -283,16 +283,45 @@ class SiglipVisionEmbeddings(nn.Module):
             padding="valid",
         )

-        self.
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side ** 2
         self.num_positions = self.num_patches
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-
-        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
+        batch_size = pixel_values.size(0)

-
+        patch_embeds = self.patch_embedding(pixel_values)
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        patches_to_select = patch_attention_mask.view(batch_size, -1)
+        max_num_patches = patches_to_select.sum(dim=-1).max()
+        embeddings = torch.zeros((batch_size, max_num_patches, patch_embeds.size(2)), device=patch_embeds.device, dtype=patch_embeds.dtype)
+        for b_idx, (p_embeds, p_to_select) in enumerate(zip(patch_embeds, patches_to_select)):
+            sub_p_embds = p_embeds[p_to_select]
+            embeddings[b_idx][:len(sub_p_embds)] = sub_p_embds
+
+        boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
+        max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
+        max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
+        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
+
+        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
+            nb_patches_h = p_attn_mask[0].sum()
+            nb_patches_w = p_attn_mask[:, 0].sum()
+
+            fractional_coords_h = torch.arange(0, 1, 1/nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1, 1/nb_patches_w)
+
+            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
+            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
+
+            pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
+            position_ids[batch_idx][:len(pos_ids)] = pos_ids
+
+        position_ids = position_ids.to(self.position_embedding.weight.device)
+
+        embeddings = embeddings + self.position_embedding(position_ids)
         return embeddings


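The new embedding path above does two things: it scatters each sample's valid patch embeddings into a zero-padded (batch_size, max_num_patches, dim) tensor, and it assigns each valid patch a position id by bucketizing its fractional (row, column) coordinates onto the fixed num_patches_per_side grid. The snippet below replays just the bucketizing step outside the module; the sizes (a 32-position side and a 10x20-patch valid region) are made up for illustration and are not part of the commit.

```python
import torch

# Hypothetical sizes, not taken from the commit.
num_patches_per_side = 32            # side of the learned position grid
nb_patches_h, nb_patches_w = 10, 20  # valid (unpadded) patches of one sample

# Same boundaries as in the diff: 31 cut points splitting [0, 1) into 32 buckets.
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)

# Fractional coordinate of each valid patch along each axis, in [0, 1).
fractional_coords_h = torch.arange(0, 1, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1, 1 / nb_patches_w)

# Snap every fractional coordinate to one of the num_patches_per_side buckets,
# so a smaller image reuses the same learned position grid as a full-size one.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# One position id per valid patch, mirroring the expression in the diff.
pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
print(pos_ids.shape)  # torch.Size([200])
```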
@@ -675,7 +704,7 @@ class SiglipPreTrainedModel(PreTrainedModel):

     def _init_weights(self, module):
         """Initialize the weights"""
-
+
         if isinstance(module, SiglipVisionEmbeddings):
             width = (
                 self.config.vision_config.hidden_size
@@ -1055,6 +1084,7 @@ class SiglipVisionTransformer(nn.Module):
     def forward(
         self,
         pixel_values,
+        pixel_attention_mask: Optional[torch.BoolTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -1069,10 +1099,22 @@
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-
+        if pixel_attention_mask is None:
+            #TODO
+            pass
+
+        batch_size = pixel_attention_mask.size(0) # assuming `pixel_attention_mask` is of size bs x h x w
+        subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
+        patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
+
+        hidden_states = self.embeddings(
+            pixel_values=pixel_values,
+            patch_attention_mask=patch_attention_mask
+        )

         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
+            attention_mask=patch_attention_mask.view(batch_size, -1),
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
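The two `unfold` calls above cut the (batch_size, height, width) pixel mask into non-overlapping patch_size x patch_size sub-grids, and a patch counts as valid if any of its pixels is unmasked. A standalone sketch with hypothetical sizes (patch_size=16, one 64x48 mask whose top half is real content):

```python
import torch

patch_size = 16  # hypothetical; the model reads it from self.config.patch_size

# One 64x48 per-pixel mask where only the top half of the image is real content.
pixel_attention_mask = torch.zeros(1, 64, 48, dtype=torch.bool)
pixel_attention_mask[:, :32, :] = True

# Two unfolds -> (bs, h_patches, w_patches, patch_size, patch_size) sub-grids.
subgrids = (
    pixel_attention_mask
    .unfold(dimension=1, size=patch_size, step=patch_size)
    .unfold(dimension=2, size=patch_size, step=patch_size)
)
print(subgrids.shape)  # torch.Size([1, 4, 3, 16, 16])

# A patch is kept if at least one of its pixels is unmasked, as in the diff.
patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask.squeeze(0))
# tensor([[ True,  True,  True],
#         [ True,  True,  True],
#         [False, False, False],
#         [False, False, False]])
```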
@@ -1081,7 +1123,10 @@
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)

-        pooled_output = self.head(
+        pooled_output = self.head(
+            hidden_state=last_hidden_state,
+            attention_mask=patch_attention_mask.view(batch_size, -1)
+        )

         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
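With these changes the vision tower's forward expects a per-pixel `pixel_attention_mask` next to `pixel_values`; note that the `pixel_attention_mask is None` branch is still a TODO in this commit, so the mask is effectively required. A hedged usage sketch (the model construction is elided and all shapes are illustrative, not taken from the commit):

```python
import torch

# `model` is assumed to be a vision transformer built from this file's classes.
pixel_values = torch.randn(2, 3, 384, 384)

# True where there is real image content, False over padding (e.g. when images
# with different aspect ratios are padded to a common size before batching).
pixel_attention_mask = torch.ones(2, 384, 384, dtype=torch.bool)
pixel_attention_mask[1, :, 192:] = False  # second image only fills the left half

# outputs = model(
#     pixel_values=pixel_values,
#     pixel_attention_mask=pixel_attention_mask,
#     return_dict=True,
# )
# outputs.last_hidden_state -> (2, max_num_patches, hidden_size)
# pooled output             -> (2, hidden_size), attention-pooled over real patches only
```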
@@ -1105,11 +1150,16 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(config)

-    def forward(self, hidden_state):
+    def forward(self, hidden_state, attention_mask):
         batch_size = hidden_state.shape[0]
         probe = self.probe.repeat(batch_size, 1, 1)

-        hidden_state = self.attention(
+        hidden_state = self.attention(
+            query=probe,
+            key=hidden_state,
+            value=hidden_state,
+            key_padding_mask=~attention_mask
+        )[0]

         residual = hidden_state
         hidden_state = self.layernorm(hidden_state)
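In the pooling head, padding is handled through `key_padding_mask`, where True means "ignore this key"; hence the `~attention_mask` inversion in the diff. A minimal sketch of the same call pattern, assuming the head's attention module is an `nn.MultiheadAttention` constructed with `batch_first=True` (as in the upstream SigLIP pooling head); all sizes are made up:

```python
import torch
import torch.nn as nn

hidden_size, batch_size, seq_len = 32, 2, 5  # hypothetical sizes
attention = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)

probe = torch.randn(batch_size, 1, hidden_size)           # one learned query per sample
hidden_state = torch.randn(batch_size, seq_len, hidden_size)

# True = real patch, False = padding; key_padding_mask uses the opposite convention.
attention_mask = torch.tensor([[True, True, True, True, True],
                               [True, True, True, False, False]])

pooled, _ = attention(
    query=probe,
    key=hidden_state,
    value=hidden_state,
    key_padding_mask=~attention_mask,
)
print(pooled.shape)  # torch.Size([2, 1, 32]); padded patches contribute nothing
```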