# coding=utf-8
# Copyright 2024 Google AI, LAION team. All rights reserved.
#
# This code is based on the open_clip framework. It has been modified from
# its original form to accommodate minor architectural differences compared
# to the original MaMMUT model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaMMUT configuration."""
from typing import Any, Optional

from transformers import (
    AutoConfig,
    CLIPConfig,
    CLIPTextConfig,
    CLIPVisionConfig,
    PretrainedConfig,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

class MultimodalConfig(PretrainedConfig):
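    """Shared multimodal (text-decoder) fields used by the MaMMUT configs below.

    Carries the open_clip-style multimodal settings (``mlp_ratio``, ``dim_head``,
    ``cross_attn_ratio``, ``does_full_decoding``, ...) on top of
    :class:`~transformers.PretrainedConfig`. Unrecognized keyword arguments are
    attached to the instance verbatim.
    """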
model_type = "mammut_text_model"
def __init__(
self,
mlp_ratio: int = 4,
dim_head: int = 64,
heads: int = 8,
n_queries: int = 256,
attn_pooler_heads: int = 8,
cross_attn_ratio: int = 1,
does_full_decoding: bool = False,
output_tokens: bool = False,
has_mlp: bool = True,
context_length: int = 77,
vocab_size: int = 49408,
hidden_size: int = 1024,
layers: int = 12,
batch_first: bool = True,
        **kwargs: Any
):
super().__init__()
self.mlp_ratio = mlp_ratio
self.dim_head = dim_head
self.heads = heads
self.n_queries = n_queries
self.attn_pooler_heads = attn_pooler_heads
self.cross_attn_ratio = cross_attn_ratio
self.does_full_decoding = does_full_decoding
self.output_tokens = output_tokens
self.has_mlp = has_mlp
self.context_length = context_length
self.vocab_size = vocab_size
self.width = hidden_size
self.layers = layers
self.batch_first = batch_first
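        # Attach any remaining open_clip-style fields to the instance verbatim,
        # so extra keys survive without a schema change.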
for key, value in kwargs.items():
setattr(self, key, value)


class MammutTextConfig(MultimodalConfig, CLIPTextConfig):
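    """Text-decoder configuration for MaMMUT.

    Combines the open_clip-style multimodal fields from :class:`MultimodalConfig`
    with the standard :class:`~transformers.CLIPTextConfig` arguments.
    """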
model_type = "mammut_text_model"
base_config_key = "text_config"
def __init__(
self,
mlp_ratio: int = 4,
num_attention_heads: int = 8,
n_queries: int = 256,
attn_pooler_heads: int = 8,
cross_attn_ratio: int = 1,
does_full_decoding: bool = False,
output_tokens: bool = False,
has_mlp: bool = True,
max_position_embeddings: int = 77,
vocab_size: int = 49408,
num_hidden_layers: int = 12,
hidden_size: int = 1024,
attention_dropout: float = 0.0,
hidden_act: str = "gelu",
layer_norm_eps: float = 1e-5,
intermediate_size: Optional[int] = None,
initializer_factor: float = 0.02,
logit_scale_init_value: float = 2.6592,
        **kwargs: Any
):
super().__init__(
mlp_ratio=mlp_ratio,
num_attention_heads=num_attention_heads,
n_queries=n_queries,
attn_pooler_heads=attn_pooler_heads,
cross_attn_ratio=cross_attn_ratio,
does_full_decoding=does_full_decoding,
output_tokens=output_tokens,
has_mlp=has_mlp,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
attention_dropout=attention_dropout,
logit_scale_init_value=logit_scale_init_value,
max_position_embeddings=max_position_embeddings,
layer_norm_eps=layer_norm_eps,
intermediate_size=intermediate_size,
initializer_factor=initializer_factor,
hidden_act=hidden_act,
**kwargs
)
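        # Re-assert values that the MRO (MultimodalConfig before CLIPTextConfig)
        # would otherwise leave at CLIPTextConfig defaults: MultimodalConfig
        # stores `hidden_size` only as `width` and initializes CLIPTextConfig
        # without arguments.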
self.logit_scale_init_value = logit_scale_init_value
self.does_full_decoding = does_full_decoding
self.output_tokens = output_tokens
self.architectures = ["MammutTextModel"]
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads


class MammutVisionConfig(CLIPVisionConfig):
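    """Vision-tower configuration for MaMMUT.

    A thin wrapper around :class:`~transformers.CLIPVisionConfig` that carries
    the extra open_clip-style fields (``mlp_ratio``, ``n_queries``, ...).
    """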
model_type = "mammut_vision_model"
base_config_key = "vision_config"
def __init__(
self,
mlp_ratio: int = 4,
dim_head: int = 64,
num_attention_heads: int = 8,
n_queries: int = 256,
attn_pooler_heads: int = 8,
cross_attn_ratio: int = 1,
does_full_decoding: bool = False,
output_tokens: bool = False,
has_mlp: bool = True,
image_size: int = 224,
patch_size: int = 16,
width: int = 1024,
layers: int = 12,
        **kwargs: Any
):
super().__init__(
mlp_ratio=mlp_ratio,
dim_head=dim_head,
num_attention_heads=num_attention_heads,
n_queries=n_queries,
attn_pooler_heads=attn_pooler_heads,
cross_attn_ratio=cross_attn_ratio,
does_full_decoding=does_full_decoding,
output_tokens=output_tokens,
has_mlp=has_mlp,
image_size=image_size,
patch_size=patch_size,
width=width,
layers=layers,
**kwargs
)
self.num_attention_heads = num_attention_heads


class MammutConfig(CLIPConfig):
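    """Top-level MaMMUT configuration.

    Holds a :class:`MammutTextConfig` and a :class:`MammutVisionConfig`,
    analogous to :class:`~transformers.CLIPConfig`. ``text_config`` and
    ``vision_config`` are accepted as plain dicts, as produced by
    ``PretrainedConfig.to_dict()``.
    """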
model_type = "mammut"
def __init__(
self,
mlp_ratio: int = 4,
dim_head: int = 64,
num_attention_heads: int = 8,
n_queries: int = 256,
attn_pooler_heads: int = 8,
cross_attn_ratio: int = 1,
does_full_decoding: bool = False,
output_tokens: bool = False,
has_mlp: bool = True,
        text_config: Optional[dict] = None,
        vision_config: Optional[dict] = None,
projection_dim: int = 768,
logit_scale_init_value: float = 2.6592,
        **kwargs: Any
):
kwargs["architectures"] = ["MammutModel"]
super().__init__(
mlp_ratio=mlp_ratio,
dim_head=dim_head,
num_attention_heads=num_attention_heads,
n_queries=n_queries,
attn_pooler_heads=attn_pooler_heads,
cross_attn_ratio=cross_attn_ratio,
does_full_decoding=does_full_decoding,
output_tokens=output_tokens,
has_mlp=has_mlp,
**kwargs
)
self.text_config = MammutTextConfig(**text_config) if text_config is not None else MammutTextConfig()
self.vision_config = MammutVisionConfig(**vision_config) if vision_config is not None else MammutVisionConfig()
self.text_config.architectures = ["MammutTextModel"]
self.vision_config.architectures = ["MammutVisionModel"]
self.projection_dim = projection_dim
self.hidden_size = self.text_config.hidden_size
self.logit_scale_init_value = logit_scale_init_value
self.architectures = ["MammutModel"]
self.does_full_decoding = does_full_decoding
self.output_tokens = output_tokens
    def _post_init(self):
        # Propagate the shared logit scale down to the text sub-config.
        if self.logit_scale_init_value is not None:
            self.text_config.logit_scale_init_value = self.logit_scale_init_value
        # Not every transformers release defines `_post_init` on
        # PretrainedConfig, so only delegate when the hook exists.
        parent_post_init = getattr(super(), "_post_init", None)
        if parent_post_init is not None:
            parent_post_init()


# Let AutoConfig.from_pretrained resolve model_type "mammut" to MammutConfig.
AutoConfig.register("mammut", MammutConfig)
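

# A minimal usage sketch (not part of the library): it assumes this module is
# importable and that the installed transformers version exposes the imports
# above. The values below are illustrative, not tied to any released checkpoint.
if __name__ == "__main__":
    config = MammutConfig(
        text_config={"hidden_size": 768, "num_hidden_layers": 12},
        vision_config={"image_size": 224, "patch_size": 16},
        projection_dim=768,
    )
    print(config.model_type)               # -> "mammut"
    print(config.text_config.hidden_size)  # -> 768

    # Round-trip through the standard config serialization path; because the
    # model type is registered with AutoConfig above, from_pretrained resolves
    # back to MammutConfig.
    config.save_pretrained("./mammut_config")
    reloaded = AutoConfig.from_pretrained("./mammut_config")
    assert type(reloaded) is MammutConfig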