"""MaMMUT configuration.""" |
|
from typing import Callable, List, Optional, Sequence, Union

from transformers import AutoConfig, CLIPConfig, CLIPTextConfig, CLIPVisionConfig, PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)
|
|
class MultimodalConfig(PretrainedConfig):
    """Base configuration for the MaMMUT multimodal text decoder.

    Keyword arguments that are not recognized are stored verbatim as attributes on the config.
    """

    model_type = "mammut_text_model"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        context_length: int = 77,
        vocab_size: int = 49408,
        hidden_size: int = 1024,
        layers: int = 12,
        batch_first: bool = True,
        **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
    ):
        super().__init__()
        self.mlp_ratio = mlp_ratio
        self.dim_head = dim_head
        self.heads = heads
        self.n_queries = n_queries
        self.attn_pooler_heads = attn_pooler_heads
        self.cross_attn_ratio = cross_attn_ratio
        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens
        self.has_mlp = has_mlp
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = hidden_size
        self.layers = layers
        self.batch_first = batch_first
        # Store any remaining keyword arguments as plain attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)
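# Illustrative sketch (hypothetical values): keyword arguments that MultimodalConfig
# does not name explicitly are attached verbatim as attributes, and `hidden_size`
# is mirrored into the `width` attribute.
#
#     cfg = MultimodalConfig(hidden_size=768, my_custom_flag=True)
#     assert cfg.width == 768 and cfg.my_custom_flag is True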
|
|
class MammutTextConfig(MultimodalConfig, CLIPTextConfig):
    """Configuration for the MaMMUT text model.

    Combines the multimodal decoder options from `MultimodalConfig` with the
    CLIP-style text settings from `CLIPTextConfig`.
    """

    model_type = "mammut_text_model"
    base_config_key = "text_config"

    def __init__(
        self,
        mlp_ratio: int = 4,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        max_position_embeddings: int = 77,
        vocab_size: int = 49408,
        num_hidden_layers: int = 12,
        hidden_size: int = 1024,
        attention_dropout: float = 0.0,
        hidden_act: str = "gelu",
        layer_norm_eps: float = 1e-5,
        intermediate_size: Optional[int] = None,
        initializer_factor: float = 0.02,
        logit_scale_init_value: float = 2.6592,
        **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
    ):
        super().__init__(
            mlp_ratio=mlp_ratio,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            attention_dropout=attention_dropout,
            logit_scale_init_value=logit_scale_init_value,
            max_position_embeddings=max_position_embeddings,
            layer_norm_eps=layer_norm_eps,
            intermediate_size=intermediate_size,
            initializer_factor=initializer_factor,
            hidden_act=hidden_act,
            **kwargs
        )

        self.logit_scale_init_value = logit_scale_init_value
        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens
        self.architectures = ["MammutTextModel"]
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
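# Illustrative sketch (hypothetical values): `hidden_size` ends up both as
# `hidden_size` (set after the super() call) and as the `width` attribute
# inherited from MultimodalConfig.
#
#     text_cfg = MammutTextConfig(hidden_size=768, num_hidden_layers=12)
#     assert text_cfg.hidden_size == 768 and text_cfg.width == 768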
|
|
|
class MammutVisionConfig(CLIPVisionConfig):
    """Configuration for the MaMMUT vision tower.

    A `CLIPVisionConfig` extended with MaMMUT-specific options such as attention
    pooling, cross-attention ratio, and decoding flags.
    """

    model_type = "mammut_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        image_size: int = 224,
        patch_size: int = 16,
        width: int = 1024,
        layers: int = 12,
        **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
    ):
        super().__init__(
            mlp_ratio=mlp_ratio,
            dim_head=dim_head,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            image_size=image_size,
            patch_size=patch_size,
            width=width,
            layers=layers,
            **kwargs
        )

        self.num_attention_heads = num_attention_heads
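# Illustrative sketch (hypothetical values): CLIP-level arguments such as
# `image_size` and `patch_size` are handled by CLIPVisionConfig, while the
# MaMMUT-specific extras (e.g. `width`, `layers`) are kept as plain attributes.
#
#     vision_cfg = MammutVisionConfig(image_size=224, patch_size=16, width=1024)
#     assert vision_cfg.patch_size == 16 and vision_cfg.width == 1024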
|
|
|
class MammutConfig(CLIPConfig):
    """Top-level MaMMUT configuration.

    Holds a `MammutTextConfig` and a `MammutVisionConfig`, mirroring the
    structure of `CLIPConfig`.
    """

    model_type = "mammut"

    def __init__(
        self,
        mlp_ratio: int = 4,
        dim_head: int = 64,
        num_attention_heads: int = 8,
        n_queries: int = 256,
        attn_pooler_heads: int = 8,
        cross_attn_ratio: int = 1,
        does_full_decoding: bool = False,
        output_tokens: bool = False,
        has_mlp: bool = True,
        text_config: Optional[dict] = None,
        vision_config: Optional[dict] = None,
        projection_dim: int = 768,
        logit_scale_init_value: float = 2.6592,
        **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
    ):
        kwargs["architectures"] = ["MammutModel"]
        super().__init__(
            mlp_ratio=mlp_ratio,
            dim_head=dim_head,
            num_attention_heads=num_attention_heads,
            n_queries=n_queries,
            attn_pooler_heads=attn_pooler_heads,
            cross_attn_ratio=cross_attn_ratio,
            does_full_decoding=does_full_decoding,
            output_tokens=output_tokens,
            has_mlp=has_mlp,
            **kwargs
        )
        # The sub-configs are expected as plain dicts (for example when loading from a serialized config).
        self.text_config = MammutTextConfig(**text_config) if text_config is not None else MammutTextConfig()
        self.vision_config = MammutVisionConfig(**vision_config) if vision_config is not None else MammutVisionConfig()
        self.text_config.architectures = ["MammutTextModel"]
        self.vision_config.architectures = ["MammutVisionModel"]
        self.projection_dim = projection_dim
        self.hidden_size = self.text_config.hidden_size
        self.logit_scale_init_value = logit_scale_init_value
        self.architectures = ["MammutModel"]

        self.does_full_decoding = does_full_decoding
        self.output_tokens = output_tokens

    def _post_init(self):
        # Keep the text sub-config's logit scale in sync with the top-level value.
        if self.logit_scale_init_value is not None:
            setattr(self.text_config, "logit_scale_init_value", self.logit_scale_init_value)

        super()._post_init()
|
|
AutoConfig.register("mammut", MammutConfig) |
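# Illustrative end-to-end sketch (paths and values are hypothetical): compose a
# MaMMUT config from sub-config dicts, save it, and reload it through AutoConfig,
# which resolves the registered "mammut" model type back to MammutConfig.
#
#     config = MammutConfig(
#         text_config={"hidden_size": 768, "num_hidden_layers": 12},
#         vision_config={"image_size": 224, "patch_size": 16},
#         projection_dim=768,
#     )
#     config.save_pretrained("./mammut-checkpoint")
#     reloaded = AutoConfig.from_pretrained("./mammut-checkpoint")
#     assert reloaded.model_type == "mammut"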