Commit 545fbb4 · Parent(s): a4ce5f5 · "working version"

modeling_siglip.py: CHANGED (+15 -20)
@@ -292,32 +292,25 @@ class SiglipVisionEmbeddings(nn.Module):
         batch_size = pixel_values.size(0)

         patch_embeds = self.patch_embedding(pixel_values)
-
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)

-        patches_to_select = patch_attention_mask.view(batch_size, -1)
-        max_num_patches = patches_to_select.sum(dim=-1).max()
-        embeddings = torch.zeros((batch_size, max_num_patches, patch_embeds.size(2)), device=patch_embeds.device, dtype=patch_embeds.dtype)
-        for b_idx, (p_embeds, p_to_select) in enumerate(zip(patch_embeds, patches_to_select)):
-            sub_p_embds = p_embeds[p_to_select]
-            embeddings[b_idx][:len(sub_p_embds)] = sub_p_embds
-
-        boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
         max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
         max_nb_patches_h, max_nb_patches_w = max_im_h//self.patch_size, max_im_w//self.patch_size
-
+        boundaries = torch.arange(1/self.num_patches_per_side, 1., 1/self.num_patches_per_side)
+        position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w,), fill_value=0)

         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[0].sum()
-            nb_patches_w = p_attn_mask[
+            nb_patches_h = p_attn_mask[:, 0].sum()
+            nb_patches_w = p_attn_mask[0].sum()

-            fractional_coords_h = torch.arange(0, 1, 1/nb_patches_h)
-            fractional_coords_w = torch.arange(0, 1, 1/nb_patches_w)
+            fractional_coords_h = torch.arange(0, 1-1e-6, 1/nb_patches_h)
+            fractional_coords_w = torch.arange(0, 1-1e-6, 1/nb_patches_w)

             bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
             bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

             pos_ids = (self.num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
-            position_ids[batch_idx][
+            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

         position_ids = position_ids.to(self.position_embedding.weight.device)
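This hunk drops the gather-into-a-padded-buffer logic and instead keeps the flat patch embeddings, computing per-image position ids by bucketing fractional patch coordinates onto the fixed positional-embedding grid. A minimal standalone sketch of that bucketing step follows; `num_patches_per_side`, `nb_patches_h`, and `nb_patches_w` are illustrative placeholder values, not the model's configuration.

import torch

# Illustrative values, not taken from any real config.
num_patches_per_side = 32            # side of the fixed positional-embedding grid
nb_patches_h, nb_patches_w = 14, 9   # valid patches of one (possibly padded) image

# Bucket boundaries split [0, 1) into `num_patches_per_side` equal bins.
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)

# Fractional coordinates of the valid patch rows/columns; the 1 - 1e-6 upper
# bound keeps floating-point rounding from yielding one coordinate too many.
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# Map each fractional coordinate to its bin on the fixed grid.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

# Combine row and column buckets into flat position ids, as in the diff.
pos_ids = (num_patches_per_side * bucket_coords_w[:, None] + bucket_coords_h[None, :]).flatten()
print(pos_ids.shape)  # torch.Size([126]) == nb_patches_w * nb_patches_h

Only the positions marked valid by `p_attn_mask.view(-1)` then receive these ids; padded positions keep the fill value of 0.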
@@ -1099,11 +1092,11 @@ class SiglipVisionTransformer(nn.Module):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+        batch_size = pixel_values.size(0)
         if pixel_attention_mask is None:
-            #
-
+            # assuming `pixel_attention_mask` is of size bs x h x w
+            pixel_attention_mask = torch.ones(size=(batch_size, pixel_values.size(2), pixel_values.size(3)), dtype=torch.bool, device=pixel_values.device)

-        batch_size = pixel_attention_mask.size(0) # assuming `pixel_attention_mask` is of size bs x h x w
         subgrids = pixel_attention_mask.unfold(dimension=1, size=self.config.patch_size, step=self.config.patch_size).unfold(dimension=2, size=self.config.patch_size, step=self.config.patch_size)
         patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
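To illustrate the masking step above: `unfold` carves the pixel-level mask into non-overlapping `patch_size` x `patch_size` subgrids, and a patch counts as valid if any pixel inside it is valid. The sizes below are made up for the example; only the two `unfold` calls and the reduction mirror the diff.

import torch

patch_size = 14               # illustrative; the real value comes from self.config.patch_size
batch_size, h, w = 2, 28, 42  # two images padded onto a 28 x 42 canvas

# Pixel-level validity mask: True where the padded image holds real content.
pixel_attention_mask = torch.ones(batch_size, h, w, dtype=torch.bool)
pixel_attention_mask[1, :, 28:] = False  # pretend the second image only fills 28 x 28

# Carve the mask into non-overlapping patch_size x patch_size subgrids.
subgrids = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
subgrids = subgrids.unfold(dimension=2, size=patch_size, step=patch_size)
print(subgrids.shape)  # torch.Size([2, 2, 3, 14, 14]): bs x nb_patches_h x nb_patches_w x patch x patch

# A patch is kept if any pixel inside it is valid.
patch_attention_mask = (subgrids.sum(dim=(-1, -2)) > 0).bool()
print(patch_attention_mask[1])
# tensor([[ True,  True, False],
#         [ True,  True, False]])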
@@ -1112,9 +1105,11 @@ class SiglipVisionTransformer(nn.Module):
             patch_attention_mask=patch_attention_mask
         )

+        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
+
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
-            attention_mask=patch_attention_mask
+            attention_mask=_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self.config._flash_attn_2_enabled else patch_attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
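For context on the new `attention_mask` argument: when flash attention is not enabled, the boolean (bs, seq_len) patch mask is expanded into an additive 4D mask before reaching the attention layers. The helper below is a hedged sketch of what that expansion amounts to, written from the usual convention (0.0 where attention is allowed, dtype minimum where it is masked); it is not the `transformers` implementation of `_prepare_4d_attention_mask` itself.

import torch

def expand_to_4d_additive_mask(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # (bs, seq_len) bool/int mask -> (bs, 1, seq_len, seq_len) additive mask:
    # 0.0 where attention is allowed, dtype minimum where it is masked out.
    bsz, src_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

patch_attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)
mask_4d = expand_to_4d_additive_mask(patch_attention_mask, torch.float32)
print(mask_4d.shape)     # torch.Size([1, 1, 4, 4])
print(mask_4d[0, 0, 0])  # tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38])

A flash-attention kernel consumes the 2D boolean mask directly, which is why the diff branches on `self.config._flash_attn_2_enabled`.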
@@ -1125,7 +1120,7 @@ class SiglipVisionTransformer(nn.Module):

         pooled_output = self.head(
             hidden_state=last_hidden_state,
-            attention_mask=patch_attention_mask
+            attention_mask=patch_attention_mask,
         )

         if not return_dict:
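The pooling head is not part of this diff, so the following is only a hypothetical illustration of why the patch mask is now forwarded to `self.head`: an attention-pooling head can exclude padded patches through `key_padding_mask`. The class name, constructor, and internals here are assumptions made for the sketch, not the actual head implementation.

import torch
import torch.nn as nn

class MaskedAttentionPool(nn.Module):
    # Hypothetical head: a learned probe attends over the patch states,
    # ignoring padded patches via key_padding_mask.
    def __init__(self, hidden_size: int, num_heads: int = 8):
        super().__init__()
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.size(0)
        probe = self.probe.expand(batch_size, -1, -1)
        # key_padding_mask expects True where positions should be IGNORED.
        pooled, _ = self.attention(probe, hidden_state, hidden_state,
                                   key_padding_mask=~attention_mask.bool())
        return pooled[:, 0]

head = MaskedAttentionPool(hidden_size=32)
states = torch.randn(2, 6, 32)
mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0, 0]])
print(head(states, mask).shape)  # torch.Size([2, 32])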