import torch
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig
class CLIPVisionPerPatchModel(CLIPVisionModelWithProjection):
    """
    Like CLIPVisionModelWithProjection, but returns per-patch embeddings
    instead of the pooled CLS token.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        # everything else (self.vision_model, self.visual_projection)
        # is set up for you by the parent class

    def forward(self, pixel_values, **kwargs):
        # 1) run the ViT backbone → last_hidden_state [B, 1 + n_patches, hidden_size]
        #    (the CLS token is kept alongside the patch tokens)
        outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
        hidden_states = outputs.last_hidden_state

        # 2) project every token → [B, 1 + n_patches, projection_dim]
        patch_embeds = self.visual_projection(hidden_states)

        # 3) post-process: L2-normalize, then drop the batch dim (assumes B == 1)
        patch_embeds = torch.nn.functional.normalize(patch_embeds, dim=-1)
        patch_embeds = patch_embeds.squeeze()  # (tokens, projection_dim)
        return patch_embeds
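

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): load the class from a CLIP
# vision checkpoint and extract per-token embeddings for a single image.
# The checkpoint name "openai/clip-vit-base-patch32", the dummy image, and
# the use of CLIPImageProcessor are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import CLIPImageProcessor

    model = CLIPVisionPerPatchModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

    image = Image.new("RGB", (224, 224))  # stand-in for a real image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        embeds = model(pixel_values)

    # For ViT-B/32 at 224x224: 1 CLS token + 49 patch tokens → shape (50, 512)
    print(embeds.shape)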