import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import eva_vit
from .transformer import text_transformer

class CLIP(nn.Module):
    """Dual-encoder CLIP model pairing an EVA ViT image encoder with a text transformer."""

    def __init__(
        self,
        vision_model: str = 'eva_base_p16',
    ):
        super().__init__()
        # Look up the EVA ViT constructor by name, e.g. 'eva_base_p16'.
        self.visual = eva_vit.__dict__[vision_model]()
        self.text = text_transformer()
        # Learnable temperature, initialized to ln(1/0.07) as in the original CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_image(self, image, normalize: bool = False):
        features = self.visual(image)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        features = self.text(text)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self, image, text):
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        return image_features, text_features, self.logit_scale.exp()
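

# Minimal usage sketch (not part of the module API): builds the model and computes
# CLIP-style similarity logits from dummy inputs. The input shapes below (224x224
# images, 77-token context, ~49k vocab) are assumptions based on common CLIP
# configurations and are not confirmed by this file.
if __name__ == '__main__':
    model = CLIP(vision_model='eva_base_p16')
    model.eval()

    images = torch.randn(2, 3, 224, 224)      # assumed image resolution
    texts = torch.randint(0, 49408, (2, 77))  # assumed vocab size / context length

    with torch.no_grad():
        image_features, text_features, logit_scale = model(images, texts)
        # Features are already L2-normalized in forward(), so the matrix product
        # gives cosine similarities scaled by the learned temperature.
        logits_per_image = logit_scale * image_features @ text_features.t()

    print(logits_per_image.shape)  # torch.Size([2, 2])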