```python
import torch
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig


class CLIPVisionPerPatchModel(CLIPVisionModelWithProjection):
    """
    Like CLIPVisionModelWithProjection, but returns per-patch embeddings
    instead of the pooled CLS token.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        # Everything else (self.vision_model, self.visual_projection)
        # is set up by the parent class.

    def forward(self, pixel_values, **kwargs):
        # 1) Run the ViT backbone → last_hidden_state [B, n_tokens, hidden_size]
        outputs = self.vision_model(pixel_values, return_dict=True, **kwargs)
        hidden_states = outputs.last_hidden_state

        # 2) Project every patch token → [B, n_tokens, projection_dim]
        patch_embeds = self.visual_projection(hidden_states)

        # 3) L2-normalize each patch embedding and drop the batch dim
        #    (assumes a single image) → [n_tokens, projection_dim]
        patch_embeds = torch.nn.functional.normalize(patch_embeds, dim=-1)
        patch_embeds = patch_embeds.squeeze(0)
        return patch_embeds
```
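
For context, here is a minimal usage sketch. The `openai/clip-vit-base-patch32` checkpoint, the `CLIPImageProcessor`, and the dummy image are illustrative assumptions, not something prescribed by the snippet above.

```python
# Minimal usage sketch (assumes the class above is importable and a standard
# CLIP checkpoint; the checkpoint name is an illustrative choice).
import torch
from PIL import Image
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPVisionPerPatchModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

image = Image.new("RGB", (224, 224))  # stand-in for a real image
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    patch_embeds = model(inputs["pixel_values"])

# For ViT-B/32 at 224x224: 50 tokens (49 patches + 1 CLS), each projected
# and L2-normalized to projection_dim = 512.
print(patch_embeds.shape)  # torch.Size([50, 512])
```

Because every row is L2-normalized, patch-level similarities (patch-to-patch or patch-to-text) can be computed with a plain dot product.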