marianna13 committed on
Commit 3af3aa0 · 0 Parent(s):

add HF support

Files changed (5)
  1. __init__.py +2 -0
  2. config.json +173 -0
  3. configuration_mammut.py +221 -0
  4. modeling_mammut.py +1338 -0
  5. pytorch_model.bin +3 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .configuration_mammut import MammutConfig
2
+ from .modeling_mammut import MammutModel
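
These two exports, together with the auto_map entries added in config.json below, are what let the checkpoint be loaded through the Auto classes with remote code enabled. A minimal loading sketch, assuming the files from this commit are hosted in a model repository (the repo id below is a placeholder, not part of this commit):

from transformers import AutoConfig, AutoModel

# "<user>/<repo>" is a placeholder for wherever this checkpoint is hosted.
config = AutoConfig.from_pretrained("<user>/<repo>", trust_remote_code=True)
model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)
print(type(config).__name__, type(model).__name__)  # MammutConfig MammutModel
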
config.json ADDED
@@ -0,0 +1,173 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "architectures": [
4
+ "MammutModel"
5
+ ],
6
+ "initializer_factor": 1.0,
7
+ "logit_scale_init_value": 4.6052,
8
+ "model_type": "mammut",
9
+ "projection_dim": 768,
10
+ "text_config": {
11
+ "_name_or_path": "",
12
+ "cross_attn_ratio": 2,
13
+ "does_full_decoding": true,
14
+ "add_cross_attention": false,
15
+ "architectures": null,
16
+ "attention_dropout": 0.0,
17
+ "bad_words_ids": null,
18
+ "begin_suppress_tokens": null,
19
+ "bos_token_id": 49406,
20
+ "chunk_size_feed_forward": 0,
21
+ "cross_attention_hidden_size": null,
22
+ "decoder_start_token_id": null,
23
+ "diversity_penalty": 0.0,
24
+ "do_sample": false,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 49407,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 49408,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 768,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.29.1",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "torch_dtype": "float32",
89
+ "transformers_version": null,
90
+ "vision_config": {
91
+ "_name_or_path": "",
92
+ "add_cross_attention": false,
93
+ "architectures": null,
94
+ "attention_dropout": 0.0,
95
+ "bad_words_ids": null,
96
+ "begin_suppress_tokens": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "early_stopping": false,
104
+ "encoder_no_repeat_ngram_size": 0,
105
+ "eos_token_id": null,
106
+ "exponential_decay_length_penalty": null,
107
+ "finetuning_task": null,
108
+ "forced_bos_token_id": null,
109
+ "forced_eos_token_id": null,
110
+ "hidden_act": "gelu",
111
+ "hidden_size": 1024,
112
+ "id2label": {
113
+ "0": "LABEL_0",
114
+ "1": "LABEL_1"
115
+ },
116
+ "image_size": 224,
117
+ "initializer_factor": 1.0,
118
+ "initializer_range": 0.02,
119
+ "intermediate_size": 4096,
120
+ "is_decoder": true,
121
+ "is_encoder_decoder": false,
122
+ "label2id": {
123
+ "LABEL_0": 0,
124
+ "LABEL_1": 1
125
+ },
126
+ "layer_norm_eps": 1e-05,
127
+ "length_penalty": 1.0,
128
+ "max_length": 20,
129
+ "min_length": 0,
130
+ "model_type": "clip_vision_model",
131
+ "no_repeat_ngram_size": 0,
132
+ "num_attention_heads": 16,
133
+ "num_beam_groups": 1,
134
+ "num_beams": 1,
135
+ "num_channels": 3,
136
+ "num_hidden_layers": 24,
137
+ "num_return_sequences": 1,
138
+ "output_attentions": false,
139
+ "output_hidden_states": false,
140
+ "output_scores": false,
141
+ "pad_token_id": null,
142
+ "patch_size": 14,
143
+ "prefix": null,
144
+ "problem_type": null,
145
+ "projection_dim": 768,
146
+ "pruned_heads": {},
147
+ "remove_invalid_values": false,
148
+ "repetition_penalty": 1.0,
149
+ "return_dict": true,
150
+ "return_dict_in_generate": false,
151
+ "sep_token_id": null,
152
+ "suppress_tokens": null,
153
+ "task_specific_params": null,
154
+ "temperature": 1.0,
155
+ "tf_legacy_loss": false,
156
+ "tie_encoder_decoder": false,
157
+ "tie_word_embeddings": true,
158
+ "tokenizer_class": null,
159
+ "top_k": 50,
160
+ "top_p": 1.0,
161
+ "torch_dtype": null,
162
+ "torchscript": false,
163
+ "transformers_version": "4.29.1",
164
+ "typical_p": 1.0,
165
+ "use_bfloat16": false,
166
+ "pool_type": "avg_all",
167
+ "final_ln_after_pool": true
168
+ },
169
+ "auto_map": {
170
+ "AutoConfig": "configuration_mammut.MammutConfig",
171
+ "AutoModel": "modeling_mammut.MammutModel"
172
+ }
173
+ }
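
The nested text_config/vision_config above (224x224 images, patch size 14, 77-token context, 49408-token vocabulary) are enough to exercise the CLIP-style feature extractors defined in modeling_mammut.py below. A minimal sketch with dummy inputs, assuming `model` was loaded as shown earlier; a real pipeline would use the checkpoint's tokenizer and image processor instead of random tensors:

import torch

# Dummy inputs matching this config: image_size=224, num_channels=3,
# max_position_embeddings=77, vocab_size=49408.
pixel_values = torch.randn(1, 3, 224, 224)
input_ids = torch.randint(0, 49408, (1, 77))

with torch.no_grad():
    image_embeds = model.get_image_features(pixel_values=pixel_values)  # (1, 768), L2-normalized
    text_embeds = model.get_text_features(input_ids=input_ids)          # (1, 768), L2-normalized

similarity = model.logit_scale.exp() * image_embeds @ text_embeds.T
print(similarity.shape)  # torch.Size([1, 1])
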
configuration_mammut.py ADDED
@@ -0,0 +1,221 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI, LAION team. All rights reserved.
3
+ #
4
+ # This code is based on the open_clip framework. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to the original MaMMUT model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """MaMMUT configuration."""
20
+
21
+
22
+ from transformers import (CLIPConfig, CLIPTextConfig, CLIPVisionConfig, PretrainedConfig, AutoConfig)
23
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
24
+ from transformers.utils import logging
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+
30
+
31
+ class MultimodalConfig(PretrainedConfig):
32
+
33
+ model_type = "mammut_text_model"
34
+
35
+ def __init__(
36
+ self,
37
+ mlp_ratio: int = 4,
38
+ dim_head: int = 64,
39
+ heads: int = 8,
40
+ n_queries: int = 256,
41
+ attn_pooler_heads: int = 8,
42
+ cross_attn_ratio: int = 1,
43
+ does_full_decoding: bool = False,
44
+ output_tokens: bool = False,
45
+ has_mlp: bool = True,
46
+ context_length: int = 77,
47
+ vocab_size: int = 49408,
48
+ hidden_size: int = 1024,
49
+ layers: int = 12,
50
+ batch_first: bool = True,
51
+ **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
52
+ ):
53
+ super().__init__()
54
+ self.mlp_ratio = mlp_ratio
55
+ self.dim_head = dim_head
56
+ self.heads = heads
57
+ self.n_queries = n_queries
58
+ self.attn_pooler_heads = attn_pooler_heads
59
+ self.cross_attn_ratio = cross_attn_ratio
60
+ self.does_full_decoding = does_full_decoding
61
+ self.output_tokens = output_tokens
62
+ self.has_mlp = has_mlp
63
+ self.context_length = context_length
64
+ self.vocab_size = vocab_size
65
+ self.width = hidden_size
66
+ self.layers = layers
67
+ self.batch_first = batch_first
68
+ for key, value in kwargs.items():
69
+ setattr(self, key, value)
70
+
71
+
72
+
73
+ class MammutTextConfig(MultimodalConfig, CLIPTextConfig):
74
+ model_type = "mammut_text_model"
75
+ base_config_key = "text_config"
76
+
77
+ def __init__(
78
+ self,
79
+ mlp_ratio: int = 4,
80
+ num_attention_heads: int = 8,
81
+ n_queries: int = 256,
82
+ attn_pooler_heads: int = 8,
83
+ cross_attn_ratio: int = 1,
84
+ does_full_decoding: bool = False,
85
+ output_tokens: bool = False,
86
+ has_mlp: bool = True,
87
+ max_position_embeddings: int = 77,
88
+ vocab_size: int = 49408,
89
+ num_hidden_layers: int = 12,
90
+ hidden_size: int = 1024,
91
+ attention_dropout: float = 0.0,
92
+ hidden_act: str = "gelu",
93
+ layer_norm_eps: float = 1e-5,
94
+ intermediate_size: Optional[int] = None,
95
+ initializer_factor: float = 0.02,
96
+ logit_scale_init_value: float = 2.6592,
97
+ **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
98
+ ):
99
+ super().__init__(
100
+ mlp_ratio=mlp_ratio,
101
+ num_attention_heads=num_attention_heads,
102
+ n_queries=n_queries,
103
+ attn_pooler_heads=attn_pooler_heads,
104
+ cross_attn_ratio=cross_attn_ratio,
105
+ does_full_decoding=does_full_decoding,
106
+ output_tokens=output_tokens,
107
+ has_mlp=has_mlp,
108
+ vocab_size=vocab_size,
109
+ hidden_size=hidden_size,
110
+ num_hidden_layers=num_hidden_layers,
111
+ attention_dropout=attention_dropout,
112
+ logit_scale_init_value=logit_scale_init_value,
113
+ max_position_embeddings=max_position_embeddings,
114
+ layer_norm_eps=layer_norm_eps,
115
+ intermediate_size=intermediate_size,
116
+ initializer_factor=initializer_factor,
117
+ hidden_act=hidden_act,
118
+ **kwargs
119
+ )
120
+
121
+
122
+ self.logit_scale_init_value = logit_scale_init_value
123
+ self.does_full_decoding = does_full_decoding
124
+ self.output_tokens = output_tokens
125
+ self.architectures = ["MammutTextModel"]
126
+ self.hidden_size = hidden_size
127
+ self.num_attention_heads = num_attention_heads
128
+
129
+ class MammutVisionConfig(CLIPVisionConfig):
130
+ model_type = "mammut_vision_model"
131
+ base_config_key = "vision_config"
132
+
133
+ def __init__(
134
+ self,
135
+ mlp_ratio: int = 4,
136
+ dim_head: int = 64,
137
+ num_attention_heads: int = 8,
138
+ n_queries: int = 256,
139
+ attn_pooler_heads: int = 8,
140
+ cross_attn_ratio: int = 1,
141
+ does_full_decoding: bool = False,
142
+ output_tokens: bool = False,
143
+ has_mlp: bool = True,
144
+ image_size: int = 224,
145
+ patch_size: int = 16,
146
+ width: int = 1024,
147
+ layers: int = 12,
148
+ **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
149
+ ):
150
+ super().__init__(
151
+ mlp_ratio=mlp_ratio,
152
+ dim_head=dim_head,
153
+ num_attention_heads=num_attention_heads,
154
+ n_queries=n_queries,
155
+ attn_pooler_heads=attn_pooler_heads,
156
+ cross_attn_ratio=cross_attn_ratio,
157
+ does_full_decoding=does_full_decoding,
158
+ output_tokens=output_tokens,
159
+ has_mlp=has_mlp,
160
+ image_size=image_size,
161
+ patch_size=patch_size,
162
+ width=width,
163
+ layers=layers,
164
+ **kwargs
165
+ )
166
+
167
+ self.num_attention_heads = num_attention_heads
168
+
169
+ class MammutConfig(CLIPConfig):
170
+ model_type = "mammut"
171
+
172
+ def __init__(
173
+ self,
174
+ mlp_ratio: int = 4,
175
+ dim_head: int = 64,
176
+ num_attention_heads: int = 8,
177
+ n_queries: int = 256,
178
+ attn_pooler_heads: int = 8,
179
+ cross_attn_ratio: int = 1,
180
+ does_full_decoding: bool = False,
181
+ output_tokens: bool = False,
182
+ has_mlp: bool = True,
183
+ text_config: Optional[MammutTextConfig] = None,
184
+ vision_config: Optional[MammutVisionConfig] = None,
185
+ projection_dim: int = 768,
186
+ logit_scale_init_value: float = 2.6592,
187
+ **kwargs: Union[int, float, str, bool, List[int], List[float], List[str], List[bool], Callable, Sequence[Union[int, float, str, bool]]]
188
+ ):
189
+ kwargs["architectures"] = ["MammutModel"]
190
+ super().__init__(
191
+ mlp_ratio=mlp_ratio,
192
+ dim_head=dim_head,
193
+ num_attention_heads=num_attention_heads,
194
+ n_queries=n_queries,
195
+ attn_pooler_heads=attn_pooler_heads,
196
+ cross_attn_ratio=cross_attn_ratio,
197
+ does_full_decoding=does_full_decoding,
198
+ output_tokens=output_tokens,
199
+ has_mlp=has_mlp,
200
+ **kwargs
201
+ )
202
+ self.text_config = MammutTextConfig(**text_config) if text_config is not None else MammutTextConfig()
203
+ self.vision_config = MammutVisionConfig(**vision_config) if vision_config is not None else MammutVisionConfig()
204
+ self.text_config.architectures = ["MammutTextModel"]
205
+ self.vision_config.architectures = ["MammutVisionModel"]
206
+ self.projection_dim = projection_dim
207
+ self.hidden_size = self.text_config.hidden_size
208
+ self.logit_scale_init_value = logit_scale_init_value
209
+ self.architectures = ["MammutModel"]
210
+
211
+ self.does_full_decoding = does_full_decoding
212
+ self.output_tokens = output_tokens
213
+
214
+ def _post_init(self):
215
+ if self.logit_scale_init_value is not None:
216
+ setattr(self.text_config, "logit_scale_init_value", self.logit_scale_init_value)
217
+
218
+ super()._post_init()
219
+
220
+
221
+ AutoConfig.register("mammut", MammutConfig)
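
For reference, a sketch of building the composite config directly in Python rather than reading config.json; the nested dicts echo a few fields from the file above and are wrapped into MammutTextConfig / MammutVisionConfig by MammutConfig.__init__ (assuming configuration_mammut.py from this commit is importable from the working directory):

from configuration_mammut import MammutConfig

config = MammutConfig(
    text_config={"hidden_size": 768, "num_hidden_layers": 12,
                 "does_full_decoding": True, "cross_attn_ratio": 2},
    vision_config={"hidden_size": 1024, "num_hidden_layers": 24,
                   "patch_size": 14, "pool_type": "avg_all"},
    projection_dim=768,
)
# The sub-config dicts become typed config objects, and the top-level
# hidden_size is copied from the text branch.
print(type(config.text_config).__name__)  # MammutTextConfig
print(config.hidden_size)                 # 768
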
modeling_mammut.py ADDED
@@ -0,0 +1,1338 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI, LAION team. All rights reserved.
3
+ #
4
+ # This code is based on the open_clip framework. It has been modified from its
5
+ # original forms to accommodate minor architectural differences compared
6
+ # to the original MaMMUT model.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+ """PyTorch MaMMUT model."""
20
+
21
+
22
+ from typing import Callable, List, Optional, Tuple, Union
23
+ import torch
24
+ from torch import nn
25
+ from torch.nn import functional as F
26
+ from .configuration_mammut import MammutTextConfig, MammutVisionConfig, MammutConfig
27
+ from transformers.models.clip.modeling_clip import (
28
+ CLIPAttention,
29
+ CLIPMLP,
30
+ CLIPEncoderLayer,
31
+ CLIPTextModel,
32
+ CLIPVisionModel,
33
+ CLIPVisionModelOutput,
34
+ CLIPVisionTransformer,
35
+ CLIPTextModelOutput,
36
+ CLIPOutput,
37
+ CLIPModel,
38
+ CLIPPreTrainedModel,
39
+ CLIPVisionEmbeddings,
40
+ CLIPEncoder,
41
+ eager_attention_forward
42
+ ) # noqa: E501
43
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
44
+ from transformers.generation import GenerateDecoderOnlyOutput
45
+ from dataclasses import dataclass
46
+ from typing import Optional, Tuple, Union
47
+ from transformers import AutoModel
48
+ import logging
49
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
50
+ from transformers import (
51
+ BeamSearchScorer,
52
+ LogitsProcessorList,
53
+ TopPLogitsWarper,
54
+ TopKLogitsWarper,
55
+ RepetitionPenaltyLogitsProcessor,
56
+ MinLengthLogitsProcessor,
57
+ MaxLengthCriteria,
58
+ StoppingCriteriaList
59
+ )
60
+
61
+
62
+
63
+ log = logging.getLogger(__name__)
64
+
65
+
66
+ class MammutCrossAttnLayer(nn.Module):
67
+ def __init__(self, config: MammutTextConfig):
68
+ super().__init__()
69
+ self.embed_dim = config.hidden_size
70
+ self.self_attn = MammutAttention(config)
71
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
72
+ self.mlp = CLIPMLP(config)
73
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
74
+ self.layer_norm1_kv = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
75
+
76
+ def forward(
77
+ self,
78
+ hidden_states: torch.Tensor,
79
+ k_x: Optional[torch.Tensor] = None,
80
+ v_x: Optional[torch.Tensor] = None,
81
+ attention_mask: Optional[torch.Tensor] = None,
82
+ causal_attention_mask: Optional[torch.Tensor] = None,
83
+ print0_hidden_states: bool = False,
84
+ ) -> torch.Tensor:
85
+ residual = hidden_states
86
+ hidden_states = self.layer_norm1(hidden_states)
87
+
88
+ if k_x is not None and v_x is not None:
89
+ k_x = self.layer_norm1_kv(k_x)
90
+ v_x = self.layer_norm1_kv(v_x)
91
+ hidden_states, attn_weights = self.self_attn(
92
+ hidden_states=hidden_states,
93
+ attention_mask=attention_mask,
94
+ causal_attention_mask=causal_attention_mask,
95
+ keys=k_x,
96
+ values=v_x,
97
+ print0_hidden_states=print0_hidden_states,
98
+ )
99
+
100
+ hidden_states = hidden_states.permute(1, 0, 2) # back to (batch_size, seq_length, embed_dim)
101
+
102
+
103
+ hidden_states = residual + hidden_states
104
+ residual = hidden_states
105
+ hidden_states = self.layer_norm2(hidden_states)
106
+ hidden_states = self.mlp(hidden_states)
107
+ hidden_states = residual + hidden_states
108
+ return hidden_states
109
+
110
+
111
+ class LayerScale(nn.Module):
112
+ def __init__(self, dim, init_values=1e-5, inplace=False):
113
+ super().__init__()
114
+ self.inplace = inplace
115
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
116
+
117
+ def forward(self, x):
118
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
119
+
120
+
121
+ class MammutAttention(CLIPAttention):
122
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
123
+
124
+ def __init__(self, config: Union[MammutTextConfig, MammutVisionConfig]):
125
+ super().__init__(config)
126
+ self.config = config
127
+ self.embed_dim = config.hidden_size
128
+ self.num_heads = config.num_attention_heads
129
+ self.head_dim = self.embed_dim // self.num_heads
130
+ if self.head_dim * self.num_heads != self.embed_dim:
131
+ raise ValueError(
132
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
133
+ f" {self.num_heads})."
134
+ )
135
+ self.scale = self.head_dim**-0.5
136
+ # self.scale = 1
137
+ self.dropout = config.attention_dropout
138
+ self.is_causal = False
139
+
140
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
141
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
142
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
143
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
144
+
145
+ self.training = False # default to inference mode; toggled by train()/eval()
146
+
147
+ def forward(
148
+ self,
149
+ hidden_states: torch.Tensor,
150
+ attention_mask: Optional[torch.Tensor] = None,
151
+ causal_attention_mask: Optional[torch.Tensor] = None,
152
+ output_attentions: Optional[bool] = False,
153
+ keys: Optional[torch.Tensor] = None,
154
+ values: Optional[torch.Tensor] = None,
155
+ print0_hidden_states: bool = False,
156
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
157
+
158
+ """Input shape: Batch x Time x Channel"""
159
+
160
+ batch_size, seq_length, embed_dim = hidden_states.shape
161
+
162
+ if keys is None and values is None:
163
+ keys = hidden_states
164
+ values = hidden_states
165
+
166
+ #TODO: CLIP attention interface
167
+ # keys = self.k_proj(keys)
168
+ # values = self.v_proj(values)
169
+
170
+ # if print0_hidden_states:
171
+ # # print("head_dim:", self.head_dim)
172
+ # print("query shape:", queries.shape)
173
+ # print("key shape:", keys.shape)
174
+ # print("value shape:", values.shape)
175
+
176
+ # queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
177
+ # keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
178
+ # values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
179
+
180
+
181
+ # CLIP text model uses both `causal_attention_mask` and `attention_mask`
182
+ # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
183
+ # if self.config._attn_implementation == "flash_attention_2":
184
+ # self.is_causal = causal_attention_mask is not None
185
+ # else:
186
+ # if attention_mask is not None and causal_attention_mask is not None:
187
+ # attention_mask = attention_mask + causal_attention_mask
188
+ # elif causal_attention_mask is not None:
189
+ # attention_mask = causal_attention_mask
190
+ # attention_interface: Callable = eager_attention_forward
191
+
192
+ # if self.config._attn_implementation != "eager":
193
+
194
+ # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
195
+
196
+
197
+ attn_output, attn_weights = F.multi_head_attention_forward(
198
+ query=hidden_states.permute(1, 0, 2), # (seq_length, batch_size, embed_dim)
199
+ key=keys.permute(1, 0, 2) if keys is not None else hidden_states.permute(1, 0, 2),
200
+ value=values.permute(1, 0, 2) if values is not None else hidden_states.permute(1, 0, 2),
201
+ embed_dim_to_check=embed_dim,
202
+ num_heads=self.num_heads,
203
+ in_proj_weight=torch.cat(
204
+ [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
205
+ ),
206
+ in_proj_bias=torch.cat(
207
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias], dim=0
208
+ ) if self.q_proj.bias is not None else None,
209
+ bias_k=None,
210
+ bias_v=None,
211
+ add_zero_attn=False,
212
+ attn_mask=attention_mask,
213
+ q_proj_weight=self.q_proj.weight,
214
+ k_proj_weight=self.k_proj.weight,
215
+ v_proj_weight=self.v_proj.weight,
216
+ is_causal=self.is_causal,
217
+ dropout_p=0.0 if not self.training else self.dropout,
218
+ out_proj_weight=self.out_proj.weight,
219
+ out_proj_bias=self.out_proj.bias,
220
+ training=self.training, # Use the training flag to control dropout
221
+ )
222
+
223
+
224
+ # attn_output, attn_weights = attention_interface(
225
+ # self,
226
+ # queries, # (seq_length, batch_size, embed_dim)
227
+ # keys,
228
+ # values,
229
+ # attention_mask,
230
+ # is_causal=self.is_causal,
231
+ # scaling=self.scale,
232
+ # dropout=0.0 if not self.training else self.dropout,
233
+ # output_attentions=output_attentions,
234
+ # )
235
+
236
+ # attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
237
+ # attn_output = self.out_proj(attn_output)
238
+
239
+ if not output_attentions:
240
+ attn_weights = None
241
+ return attn_output, attn_weights
242
+
243
+ class MammutEncoderLayer(CLIPEncoderLayer):
244
+ def __init__(self, config: MammutTextConfig, has_mlp: bool = True):
245
+ super().__init__(config)
246
+ self.embed_dim = config.hidden_size
247
+ self.self_attn = MammutAttention(config)
248
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
249
+ self.mlp = CLIPMLP(config) if has_mlp else None
250
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
251
+
252
+
253
+ def forward(
254
+ self,
255
+ hidden_states: torch.Tensor,
256
+ attention_mask: Optional[torch.Tensor] = None,
257
+ causal_attention_mask: Optional[torch.Tensor] = None,
258
+ output_attentions: Optional[bool] = False,
259
+ print_hidden_states: bool = False,
260
+ ) -> Tuple[torch.FloatTensor]:
261
+ """
262
+ Forward pass for the encoder layer.
263
+ Args:
264
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
265
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
266
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
267
+ causal_attention_mask (`torch.FloatTensor`, *optional*): causal attention mask of size
268
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
269
+ output_attentions (`bool`, *optional*):
270
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
271
+ returned tensors for more detail.
272
+ """
273
+
274
+ residual = hidden_states
275
+ hidden_states = self.layer_norm1(hidden_states)
276
+
277
+
278
+ hidden_states, attn_weights = self.self_attn(
279
+ hidden_states=hidden_states,
280
+ attention_mask=attention_mask,
281
+ causal_attention_mask=None,
282
+ output_attentions=output_attentions,
283
+ print0_hidden_states=print_hidden_states,
284
+ )
285
+
286
+ hidden_states = hidden_states.permute(1, 0, 2) # back to (batch_size, seq_length, embed_dim)
287
+
288
+
289
+ hidden_states = residual + hidden_states
290
+
291
+ residual = hidden_states
292
+ hidden_states = self.layer_norm2(hidden_states)
293
+
294
+ hidden_states = self.mlp(hidden_states) if self.mlp is not None else hidden_states
295
+ hidden_states = residual + hidden_states
296
+ return hidden_states
297
+
298
+
299
+ class MammutMultimodalEncoder(nn.Module):
300
+ does_full_decoding: torch.jit.Final[bool]
301
+
302
+ def __init__(
303
+ self,
304
+ config: MammutConfig,
305
+ ):
306
+
307
+ super().__init__()
308
+
309
+ self.config = config
310
+
311
+ self.n_cross_attn, _ = divmod(config.num_hidden_layers, config.cross_attn_ratio)
312
+ self.cross_step, _ = divmod(config.num_hidden_layers, self.n_cross_attn)
313
+ self.does_full_decoding = config.does_full_decoding
314
+ self.output_tokens = config.output_tokens
315
+ self.batch_first = config.batch_first
316
+ self.context_length = config.max_position_embeddings
317
+ self.layers = nn.ModuleList([])
318
+ self.cross_attn = nn.ModuleList([])
319
+ num_cross_attn = 0
320
+ for l_idx in range(config.num_hidden_layers):
321
+ _, r = divmod(l_idx, self.cross_step)
322
+ has_cross_attn = r == 0
323
+ layer = MammutEncoderLayer(config)
324
+ self.layers.append(layer)
325
+ if has_cross_attn:
326
+ num_cross_attn += 1
327
+ cross_attn_layer = MammutCrossAttnLayer(config)
328
+ self.cross_attn.append(cross_attn_layer)
329
+
330
+
331
+ def forward(
332
+ self,
333
+ text_embeds: torch.Tensor,
334
+ img_embeds: Optional[torch.Tensor] = None,
335
+ attention_mask: Optional[torch.Tensor] = None,
336
+ causal_attention_mask: Optional[torch.Tensor] = None,
337
+ output_attentions: Optional[bool] = None,
338
+ output_hidden_states: Optional[bool] = None,
339
+ ) -> Union[BaseModelOutput, Tuple[torch.Tensor]]:
340
+
341
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
342
+ output_hidden_states = (
343
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
344
+ )
345
+
346
+ encoder_states = () if output_hidden_states else None
347
+ all_attentions = () if output_attentions else None
348
+ hidden_states = text_embeds
349
+
350
+ seq_len = hidden_states.shape[1] if self.batch_first else hidden_states.shape[0]
351
+
352
+ if causal_attention_mask is None:
353
+ causal_attention_mask = self.build_causal_mask()
354
+ else:
355
+ causal_attention_mask = causal_attention_mask.to(dtype=hidden_states.dtype)
356
+
357
+ if attention_mask is None:
358
+ attention_mask = causal_attention_mask
359
+ else:
360
+ attention_mask = attention_mask + causal_attention_mask
361
+
362
+
363
+ if img_embeds is not None:
364
+ img_embeds = img_embeds.to(dtype=hidden_states.dtype)
365
+ k_x = img_embeds
366
+ v_x = img_embeds
367
+ else:
368
+ k_x = None
369
+ v_x = None
370
+
371
+ if img_embeds is not None:
372
+ attention_mask = attention_mask[:seq_len, :seq_len]
373
+
374
+ for i, layer in enumerate(self.layers):
375
+
376
+
377
+ cross_attn_idx, r = divmod(i, self.cross_step)
378
+
379
+ has_cross_attn = r == 0 and img_embeds is not None
380
+ if i == 0:
381
+ print_hidden_states = True
382
+ else:
383
+ print_hidden_states = False
384
+
385
+
386
+ hidden_states = layer(
387
+ hidden_states=hidden_states,
388
+ attention_mask=attention_mask if img_embeds is not None else None,
389
+ causal_attention_mask=None,
390
+ output_attentions=output_attentions,
391
+ print_hidden_states=print_hidden_states,
392
+ )
393
+
394
+ if has_cross_attn:
395
+ cross_attn = self.cross_attn[cross_attn_idx]
396
+
397
+
398
+ hidden_states = cross_attn(
399
+ hidden_states=hidden_states,
400
+ k_x=k_x,
401
+ v_x=v_x,
402
+ print0_hidden_states=(i == 0),
403
+ # attention_mask=attention_mask,
404
+ # causal_attention_mask=causal_attention_mask,
405
+ )
406
+
407
+
408
+ if output_hidden_states:
409
+ encoder_states = tuple(encoder_states)
410
+ if self.does_full_decoding:
411
+ encoder_states = encoder_states[:self.n_cross_attn + 1]
412
+ else:
413
+ encoder_states = encoder_states[:self.config.text_config.num_hidden_layers]
414
+ else:
415
+ encoder_states = None
416
+
417
+ return BaseModelOutput(
418
+ last_hidden_state=hidden_states,
419
+ hidden_states=encoder_states,
420
+ attentions=all_attentions,
421
+ )
422
+
423
+ def build_causal_mask(self):
424
+ # lazily create causal attention mask, with full attention between the tokens
425
+ # pytorch uses additive attention mask; fill with -inf
426
+ mask = torch.empty(self.context_length, self.context_length)
427
+ mask.fill_(float("-inf"))
428
+ mask.triu_(1) # zero out the lower diagonal
429
+ return mask
430
+
431
+
432
+ def build_attn_mask(self):
433
+ # lazily create causal attention mask, with full attention between the tokens
434
+ # pytorch uses additive attention mask; fill with -inf
435
+ mask = torch.empty(self.context_length, self.context_length)
436
+ mask.fill_(float("-inf"))
437
+ mask.triu_(1) # zero out the lower diagonal
438
+ return mask
439
+
440
+
441
+ @dataclass
442
+ class MammutPoolingOutput(BaseModelOutputWithPooling):
443
+ """
444
+ Base class for outputs of the Mammut model.
445
+ """
446
+
447
+ last_hidden_state: torch.FloatTensor = None
448
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
449
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
450
+ output_ids: Optional[torch.Tensor] = None
451
+ pooler_output: Optional[torch.FloatTensor] = None
452
+
453
+
454
+ class MammutMultimodalEmbeddings(nn.Module):
455
+ def __init__(self, config: MammutTextConfig):
456
+ super().__init__()
457
+ self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
458
+ self.position_embedding = nn.Embedding(
459
+ config.max_position_embeddings, config.hidden_size
460
+ )
461
+ self.register_buffer(
462
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
463
+ )
464
+
465
+
466
+ def forward(
467
+ self,
468
+ input_ids: Optional[torch.LongTensor] = None,
469
+ position_ids: Optional[torch.LongTensor] = None,
470
+ inputs_embeds: Optional[torch.FloatTensor] = None,
471
+ ) -> torch.Tensor:
472
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
473
+ max_position_embedding = self.position_embedding.weight.shape[0]
474
+
475
+ if seq_length > max_position_embedding:
476
+ raise ValueError(
477
+ f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
478
+ f"{seq_length} and max_position_embeddings: {max_position_embedding}"
479
+ )
480
+
481
+ if position_ids is None:
482
+ position_ids = self.position_ids[:, :seq_length]
483
+
484
+ if inputs_embeds is None:
485
+ inputs_embeds = self.token_embedding(input_ids)
486
+
487
+ position_embeddings = self.position_embedding(position_ids)
488
+ embeddings = inputs_embeds + position_embeddings
489
+
490
+ return embeddings
491
+
492
+
493
+ def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
494
+ if pool_type == 'first':
495
+ pooled, tokens = x[:, 0], x[:, 1:]
496
+ elif pool_type == 'last':
497
+ pooled, tokens = x[:, -1], x[:, :-1]
498
+ elif pool_type == 'argmax':
499
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
500
+ assert text is not None
501
+ pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
502
+ else:
503
+ pooled = tokens = x
504
+
505
+ return pooled, tokens
506
+
507
+
508
+ class MammutMultimodalTransformer(nn.Module):
509
+ def __init__(self, config: MammutTextConfig, output_tokens=True):
510
+ super().__init__()
511
+ self.config = config
512
+ embed_dim = config.hidden_size
513
+ self.encoder = MammutMultimodalEncoder(config)
514
+ self.text_projection = nn.Linear(
515
+ config.hidden_size, config.vocab_size, bias=False
516
+ ) if config.hidden_size is not None else None
517
+ self.final_layer_norm = nn.LayerNorm(
518
+ embed_dim, eps=config.layer_norm_eps
519
+ )
520
+
521
+ # self.init_weights()
522
+ self.does_full_decoding = config.does_full_decoding
523
+ self.context_length = config.context_length
524
+ self.vocab_size = config.vocab_size
525
+ width = config.hidden_size
526
+ self.batch_first = config.batch_first
527
+ self.has_mlp = config.has_mlp
528
+ self.cross_attn_ratio = config.cross_attn_ratio
529
+ self.cross_step = config.cross_attn_ratio
530
+ self.n_cross_attn = config.num_hidden_layers // config.cross_attn_ratio
531
+ vocab_size = config.vocab_size
532
+ self.output_tokens = output_tokens
533
+
534
+ if self.does_full_decoding:
535
+ self.num_pos = self.context_length
536
+ self.embeddings = MammutMultimodalEmbeddings(config)
537
+ else:
538
+ self.num_pos = None
539
+ self.embeddings = None
540
+
541
+ def init_weights(self):
542
+
543
+ self.final_layer_norm.weight.data.fill_(1.0)
544
+ self.final_layer_norm.bias.data.zero_()
545
+ log.info("MammutMultimodalTransformer weights initialized.")
546
+
547
+ def forward(
548
+ self,
549
+ img_embs: torch.Tensor,
550
+ text_embs: Optional[torch.Tensor] = None,
551
+ output_tokens: Optional[bool] = False,
552
+ output_attentions: Optional[bool] = None,
553
+ output_hidden_states: Optional[bool] = None,
554
+ position_ids: Optional[torch.LongTensor] = None,
555
+ ) -> Union[CLIPVisionModelOutput, CLIPTextModelOutput]:
556
+
557
+
558
+ if text_embs is not None:
559
+ if self.embeddings is not None:
560
+ # print("text_embs shape:", text_embs.shape)
561
+ text_embs = self.embeddings(
562
+ input_ids=text_embs,
563
+ position_ids=position_ids,
564
+ # inputs_embeds=img_embs if img_embs is not None else None,
565
+ )
566
+
567
+
568
+ if self.does_full_decoding:
569
+ text_embs = text_embs[:, :self.context_length, :]
570
+
571
+
572
+ text_embs = self.encoder(
573
+ text_embeds=text_embs,
574
+ img_embeds=img_embs,
575
+ attention_mask=None,
576
+ output_attentions=output_attentions,
577
+ output_hidden_states=output_hidden_states,
578
+ )
579
+
580
+ text_embs = text_embs.last_hidden_state
581
+
582
+ if self.does_full_decoding:
583
+ text_embs = text_embs[:, :self.context_length, :]
584
+ else:
585
+ text_embs = text_embs[:, 0, :]
586
+
587
+
588
+ if self.text_projection is not None:
589
+ output_ids = self.text_projection(text_embs)
590
+ else:
591
+ output_ids = text_embs
592
+
593
+ if output_tokens:
594
+ return MammutPoolingOutput(
595
+ last_hidden_state=text_embs, # Last hidden state is the text embeddings
596
+ hidden_states=None, # No hidden states in this implementation
597
+ attentions=None, # No attentions in this implementation
598
+ output_ids=output_ids, # Placeholder for output tokens
599
+ pooler_output=text_embs, # Pooler output is the text embeddings
600
+ )
601
+
602
+ return MammutPoolingOutput(
603
+ last_hidden_state=text_embs, # Last hidden state is the text embeddings
604
+ pooler_output=text_embs,
605
+ hidden_states=None, # No hidden states in this implementation
606
+ attentions=None, # No attentions in this implementation
607
+ )
608
+
609
+
610
+ def build_causal_mask(self, seq_len: Optional[int] = None, device: Optional[torch.device] = None) -> torch.Tensor:
611
+ if seq_len is None:
612
+ seq_len = self.context_length if self.does_full_decoding else self.config.context_length
613
+ if device is None:
614
+ device = torch.device("cpu")
615
+ mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).view(1, 1, seq_len, seq_len)
616
+ return mask
617
+
618
+ def build_attn_mask(self):
619
+ # lazily create causal attention mask, with full attention between the tokens
620
+ # pytorch uses additive attention mask; fill with -inf
621
+ mask = torch.empty(self.context_length, self.context_length)
622
+ mask.fill_(float("-inf"))
623
+ mask.triu_(1) # zero out the lower diagonal
624
+ return mask
625
+
626
+ class MammutMultimodalModel(CLIPTextModel):
627
+ """
628
+ Mammut multimodal model with text and vision encoders.
629
+ """
630
+
631
+ config_class = MammutTextConfig
632
+ base_model_prefix = "mammut_multimodal"
633
+
634
+ def __init__(self, config: MammutTextConfig):
635
+ super().__init__(config)
636
+ self.config = config.text_config
637
+ self.text_model = MammutMultimodalTransformer(config.text_config)
638
+ self.text_embed_dim = config.hidden_size
639
+ self.vision_embed_dim = config.vision_config.hidden_size
640
+ self.projection_dim = config.projection_dim
641
+
642
+ # Initialize weights and apply final processing
643
+ self.post_init()
644
+
645
+
646
+ def forward(
647
+ self,
648
+ input_ids: Optional[torch.Tensor] = None,
649
+ attention_mask: Optional[torch.Tensor] = None,
650
+ image_embs: Optional[torch.Tensor] = None,
651
+ output_attentions: Optional[bool] = None,
652
+ output_hidden_states: Optional[bool] = None,
653
+ output_tokens: Optional[bool] = None,
654
+ position_ids: Optional[torch.LongTensor] = None,
655
+ ) -> Union[MammutPoolingOutput, CLIPTextModelOutput]:
656
+
657
+ return self.text_model(
658
+ img_embs=image_embs,
659
+ text_embs=input_ids,
660
+ output_tokens=output_tokens,
661
+ output_attentions=output_attentions,
662
+ output_hidden_states=output_hidden_states,
663
+ position_ids=position_ids,
664
+ )
665
+
666
+
667
+ class MammutVisionTransformer(CLIPVisionTransformer):
668
+ """
669
+ Mammut Vision Transformer model.
670
+ Inherits from CLIPVisionTransformer and initializes the vision model.
671
+ """
672
+
673
+ config_class = MammutVisionConfig
674
+ base_model_prefix = "mammut_vision"
675
+
676
+ def __init__(self, config: MammutVisionConfig):
677
+ super().__init__(config)
678
+ self.config = config
679
+ embed_dim = config.hidden_size
680
+
681
+ self.embeddings = CLIPVisionEmbeddings(config)
682
+ self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
683
+ self.encoder = CLIPEncoder(config)
684
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
685
+ self.pool_type = config.pool_type
686
+
687
+
688
+ def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
689
+ if self.pool_type == 'avg':
690
+ pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
691
+ elif self.pool_type == 'tok':
692
+ pooled, tokens = x[:, 0], x[:, 1:]
693
+ elif self.pool_type == "avg_all":
694
+ pooled, tokens = x.mean(dim=1), x
695
+ else:
696
+ pooled = tokens = x
697
+
698
+ return pooled, tokens
699
+
700
+
701
+
702
+ def forward(
703
+ self,
704
+ pixel_values: Optional[torch.FloatTensor] = None,
705
+ output_attentions: Optional[bool] = None,
706
+ output_hidden_states: Optional[bool] = None,
707
+ interpolate_pos_encoding: Optional[bool] = False,
708
+ ) -> BaseModelOutputWithPooling:
709
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
710
+ output_hidden_states = (
711
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
712
+ )
713
+
714
+ if pixel_values is None:
715
+ raise ValueError("You have to specify pixel_values")
716
+
717
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
718
+ hidden_states = self.pre_layrnorm(hidden_states)
719
+
720
+ encoder_outputs: BaseModelOutput = self.encoder(
721
+ inputs_embeds=hidden_states,
722
+ output_attentions=output_attentions,
723
+ output_hidden_states=output_hidden_states,
724
+ )
725
+
726
+ last_hidden_state = encoder_outputs.last_hidden_state
727
+ pooled_output = last_hidden_state[:, 0, :]
728
+ if self.config.final_ln_after_pool:
729
+ pooled, _ = self._global_pool(last_hidden_state)
730
+ pooled_output = self.post_layernorm(pooled)
731
+ else:
732
+ pooled_output = self.post_layernorm(pooled_output)
733
+ pooled, _ = self._global_pool(pooled_output)
734
+ pooled_output = pooled
735
+
736
+ return BaseModelOutputWithPooling(
737
+ last_hidden_state=last_hidden_state,
738
+ pooler_output=pooled_output,
739
+ hidden_states=encoder_outputs.hidden_states,
740
+ attentions=encoder_outputs.attentions,
741
+ )
742
+
743
+ class MammutVisionModel(CLIPVisionModel):
744
+ """
745
+ Mammut Vision Model.
746
+ Inherits from CLIPVisionModel and initializes the vision model.
747
+ """
748
+
749
+ config_class = MammutVisionConfig
750
+ base_model_prefix = "mammut_vision"
751
+
752
+ def __init__(self, config: MammutVisionConfig):
753
+ super().__init__(config)
754
+ self.config = config
755
+ self.vision_model = MammutVisionTransformer(config)
756
+ self.post_init()
757
+
758
+
759
+ @dataclass
760
+ class MammutContrastiveOutput(CLIPOutput):
761
+ """
762
+ Output class for Mammut model in contrastive learning mode.
763
+ Contains contrastive output:
764
+ - loss: Loss value if return_loss is True.
765
+ - logits_per_text: Logits for text inputs.
766
+ - logits_per_image: Logits for image inputs.
767
+ - text_embeds: Text embeddings.
768
+ - image_embeds: Image embeddings.
769
+ """
770
+
771
+ loss: Optional[torch.FloatTensor] = None
772
+ logits_per_text: Optional[torch.FloatTensor] = None
773
+ logits_per_image: Optional[torch.FloatTensor] = None
774
+ text_embeds: Optional[torch.FloatTensor] = None
775
+ image_embeds: Optional[torch.FloatTensor] = None
776
+
777
+ @dataclass
778
+ class MammutCaptioningOutput(ModelOutput):
779
+ """
780
+ Output class for Mammut captioning part.
781
+ Contains:
782
+ - last_hidden_state: Last hidden state of the text model.
783
+ - pooler_output: Pooler output of the text model.
784
+ - hidden_states: Hidden states from the text model.
785
+ - attentions: Attention weights from the text model.
786
+ - output_ids: Output tokens from the text model.
787
+ """
788
+
789
+ last_hidden_state: torch.FloatTensor = None
790
+ pooler_output: Optional[torch.FloatTensor] = None
791
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
792
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
793
+ output_ids: Optional[torch.Tensor] = None
794
+
795
+ @dataclass
796
+ class MammutOutput(ModelOutput):
797
+ """
798
+ Output class for Mammut model.
799
+ Contains contrastive output:
800
+ - loss: Loss value if return_loss is True.
801
+ - logits_per_text: Logits for text inputs.
802
+ - logits_per_image: Logits for image inputs.
803
+ - text_embeds: Text embeddings.
804
+ - image_embeds: Image embeddings.
805
+
806
+ Captioning output:
807
+ - text_model_output: Output from the text model.
808
+ - output_ids: Output tokens from the text model.
809
+ """
810
+
811
+ loss: Optional[torch.FloatTensor] = None
812
+ logits_per_text: Optional[torch.FloatTensor] = None
813
+ logits_per_image: Optional[torch.FloatTensor] = None
814
+ text_embeds: Optional[torch.FloatTensor] = None
815
+ image_embeds: Optional[torch.FloatTensor] = None
816
+ text_model_output: Optional[MammutCaptioningOutput] = None
817
+ output_ids: Optional[torch.Tensor] = None
818
+
819
+ # @dataclass
820
+ # class MammutGenerationOutput(GenerateDecoderOnlyOutput)
821
+
822
+
823
+ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
824
+ """
825
+ This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
826
+ model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
827
+ """
828
+ square_tensor = torch.pow(tensor, 2)
829
+ sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
830
+ normed_tensor = torch.pow(sum_tensor, 0.5)
831
+ return normed_tensor
832
+
833
+ class MammutModel(CLIPPreTrainedModel):
834
+ """
835
+ Mammut model with text and vision encoders.
836
+ """
837
+
838
+ config_class = MammutConfig
839
+ base_model_prefix = "mammut"
840
+
841
+ def __init__(self, config: MammutConfig):
842
+ super().__init__(config)
843
+ self.config = config
844
+ self.text_model = MammutMultimodalTransformer(config.text_config, output_tokens=config.output_tokens)
845
+ vision_model = MammutVisionModel._from_config(config.vision_config)
846
+ self.vision_model = vision_model.vision_model
847
+ self.text_embed_dim = config.text_config.hidden_size
848
+ self.vision_embed_dim = config.vision_config.hidden_size
849
+ self.projection_dim = config.projection_dim
850
+ self.text_projection = self.text_model.text_projection
851
+ self.visual_projection = nn.Linear(
852
+ self.vision_embed_dim, self.projection_dim, bias=False
853
+ ) if self.projection_dim is not None else None
854
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
855
+
856
+
857
+ self.map_viz2txt_kv = nn.Parameter(torch.randn(
858
+ self.config.vision_config.width, self.config.text_config.width
859
+ ))
860
+
861
+ self.eos_token_id = self.config.text_config.eos_token_id
862
+ self.bos_token_id = self.config.text_config.bos_token_id
863
+ self.pad_token_id = self.config.text_config.pad_token_id
864
+ self.does_full_decoding = config.text_config.does_full_decoding
865
+ self.context_length = config.text_config.context_length
866
+ self.vocab_size = config.text_config.vocab_size
867
+ self.batch_first = config.text_config.batch_first
868
+
869
+
870
+ # Initialize weights and apply final processing
871
+ self.post_init()
872
+
873
+
874
+ def get_text_features(
875
+ self,
876
+ input_ids: Optional[torch.LongTensor] = None,
877
+ attention_mask: Optional[torch.Tensor] = None,
878
+ position_ids: Optional[torch.LongTensor] = None,
879
+ output_attentions: Optional[bool] = None,
880
+ output_hidden_states: Optional[bool] = None,
881
+ img_embs: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ """
884
+ Get text features from the Mammut model.
885
+ """
886
+
887
+ text_model_output = self.text_model(
888
+ img_embs=img_embs,
889
+ text_embs=input_ids,
890
+ position_ids=position_ids,
891
+ output_attentions=output_attentions,
892
+ output_hidden_states=output_hidden_states,
893
+ )
894
+
895
+ text_embeds = text_model_output.last_hidden_state
896
+ text_embeds = self.text_model.final_layer_norm(text_embeds)
897
+ text_embeds = text_embeds.mean(1)
898
+ text_embeds = F.normalize(text_embeds, dim=-1)
899
+ return text_embeds
900
+
901
+ def get_image_features(
902
+ self,
903
+ pixel_values: Optional[torch.FloatTensor] = None,
904
+ output_attentions: Optional[bool] = None,
905
+ output_hidden_states: Optional[bool] = None,
906
+ normalize: bool = True,
907
+ ) -> torch.FloatTensor:
908
+ """
909
+ Get image features from the Mammut model.
910
+ """
911
+
912
+ vision_outputs: CLIPVisionModelOutput = self.vision_model(
913
+ pixel_values=pixel_values,
914
+ output_attentions=output_attentions,
915
+ output_hidden_states=output_hidden_states,
916
+ )
917
+
918
+
919
+ image_embeds = vision_outputs.pooler_output
920
+ if self.visual_projection is not None:
921
+ image_embeds = self.visual_projection(image_embeds)
922
+
923
+ image_embeds = F.normalize(image_embeds, dim=-1) if normalize else image_embeds
924
+ return image_embeds
925
+
926
+ def _contrastive_forward(
927
+ self,
928
+ input_ids: Optional[torch.LongTensor] = None,
929
+ pixel_values: Optional[torch.FloatTensor] = None,
930
+ attention_mask: Optional[torch.Tensor] = None,
931
+ position_ids: Optional[torch.LongTensor] = None,
932
+ return_loss: Optional[bool] = None,
933
+ output_attentions: Optional[bool] = None,
934
+ output_hidden_states: Optional[bool] = None,
935
+ interpolate_pos_encoding: bool = False,
936
+ output_tokens: Optional[bool] = None,
937
+ contrastive: Optional[bool] = False,
938
+ ) -> MammutContrastiveOutput:
939
+ """
940
+ Forward pass for the Mammut model in contrastive learning mode.
941
+ - **Two-pass learning:** to unify contrastive and next-token
942
+ prediction, we unify unconditional representation learning with a token-conditioned next-token prediction objective.
943
+ - **First pass: contrastive task.** In the first pass, text features must not see image features (dual-encoder contrastive learner) but attend to all tokens at once to produce a sequence-level representation. Cross-attention and causal masking are disabled.
944
+ - **Second pass: captioning task.** The second pass enables cross-attention and causal masking to learn the caption generation task.
945
+
946
+ Return:
947
+ MammutContrastiveOutput: Contains contrastive output with logits, embeddings, and optional loss.
948
+ """
949
+
950
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
951
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
952
+ output_hidden_states = (
953
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
954
+ )
955
+
956
+ vision_outputs: CLIPVisionModelOutput = self.vision_model(
957
+ pixel_values=pixel_values,
958
+ output_attentions=output_attentions,
959
+ output_hidden_states=output_hidden_states,
960
+ interpolate_pos_encoding=interpolate_pos_encoding,
961
+ )
962
+
963
+ # text_model is MammutMultimodalTransformer, which handles text embeddings
964
+
965
+ text_outputs: MammutPoolingOutput = self.text_model(
966
+ img_embs=None, # No image embeddings in contrastive forward pass for text model
967
+ text_embs=input_ids,
968
+ output_tokens=output_tokens,
969
+ output_attentions=output_attentions,
970
+ output_hidden_states=output_hidden_states,
971
+ position_ids=position_ids,
972
+ )
973
+
974
+ image_embeds = vision_outputs.pooler_output
975
+ image_embeds = self.visual_projection(image_embeds)
976
+
977
+ text_embeds = text_outputs.pooler_output
978
+
979
+ pooled, tokens = text_global_pool(text_embeds, text=input_ids)
980
+
981
+ text_embeds = self.text_model.final_layer_norm(text_embeds)
982
+ text_embeds = text_embeds.mean(1)
983
+ tokens = self.text_projection(pooled)
984
+
985
+ # Normalize the embeddings
986
+ image_embeds = image_embeds / _get_vector_norm(image_embeds)
987
+ text_embeds = text_embeds / _get_vector_norm(text_embeds)
988
+
989
+ # cosine similarity as logits
990
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
991
+ logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)
992
+
993
+ logits_per_image = logits_per_text.t()
994
+
995
+ loss = None
996
+ return MammutContrastiveOutput(
997
+ loss=loss,
998
+ logits_per_text=logits_per_text,
999
+ logits_per_image=logits_per_image,
1000
+ text_embeds=text_embeds,
1001
+ image_embeds=image_embeds,
1002
+ )
1003
+
1004
+
1005
+ def _captioning_forward(
1006
+ self,
1007
+ input_ids: Optional[torch.LongTensor] = None,
1008
+ pixel_values: Optional[torch.FloatTensor] = None,
1009
+ image_embeds: Optional[torch.FloatTensor] = None,
1010
+ attention_mask: Optional[torch.Tensor] = None,
1011
+ position_ids: Optional[torch.LongTensor] = None,
1012
+ return_loss: Optional[bool] = None,
1013
+ output_attentions: Optional[bool] = None,
1014
+ output_hidden_states: Optional[bool] = None,
1015
+ interpolate_pos_encoding: bool = False,
1016
+ output_tokens: Optional[bool] = None,
1017
+ ) -> MammutCaptioningOutput:
1018
+ """
1019
+ Forward pass for the Mammut model in captioning mode.
1020
+
1021
+ Return:
1022
+ MammutCaptioningOutput: Contains captioning output with last hidden state, pooler output, hidden states, attentions, and output tokens.
1023
+ """
1024
+
1025
+ if pixel_values is None:
1026
+ raise ValueError("Pixel values must be provided for captioning.")
1027
+
1028
+ if input_ids is None:
1029
+ input_ids = torch.ones(
1030
+ (pixel_values.shape[0], self.context_length), dtype=torch.long, device=pixel_values.device
1031
+ ) * self.bos_token_id
1032
+
1033
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1034
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1035
+ output_hidden_states = (
1036
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1037
+ )
1038
+
1039
+ if image_embeds is None:
1040
+
1041
+ vision_outputs = self.vision_model(
1042
+ pixel_values=pixel_values,
1043
+ output_attentions=output_attentions,
1044
+ output_hidden_states=output_hidden_states,
1045
+ interpolate_pos_encoding=interpolate_pos_encoding,
1046
+ )
1047
+ image_embeds = vision_outputs.last_hidden_state
1048
+
1049
+
1050
+ image_embeds = image_embeds @ self.map_viz2txt_kv
1051
+
1052
+ text_model_output = self.text_model(
1053
+ img_embs=image_embeds, # Use image embeddings for captioning
1054
+ text_embs=input_ids,
1055
+ position_ids=position_ids,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ )
1059
+
1060
+ text_embeds = text_model_output.last_hidden_state
1061
+
1062
+ text_embeds = self.text_model.final_layer_norm(text_embeds)
1063
+ logits = self.text_projection(text_embeds)
1064
+
1065
+ if output_tokens:
1066
+
1067
+ return MammutCaptioningOutput(
1068
+ last_hidden_state=text_embeds,
1069
+ pooler_output=image_embeds, # Placeholder for pooler output
1070
+ output_ids=logits, # Output tokens from the text model
1071
+ )
1072
+
1073
+ return MammutCaptioningOutput(
1074
+ last_hidden_state=text_embeds,
1075
+ pooler_output=image_embeds, # Placeholder for pooler output
1076
+ output_ids=None, # No output tokens in this case
1077
+ )
1078
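+ # Note: the `output_ids` field above carries the raw vocabulary logits produced by
+ # `text_projection` (batch_size, sequence_length, vocab_size); `generate` below reads
+ # `output_ids[:, -1]` as the next-token logits.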
+
1079
+ def forward(
1080
+ self,
1081
+ input_ids: Optional[torch.LongTensor] = None,
1082
+ pixel_values: Optional[torch.FloatTensor] = None,
1083
+ attention_mask: Optional[torch.Tensor] = None,
1084
+ position_ids: Optional[torch.LongTensor] = None,
1085
+ return_loss: Optional[bool] = None,
1086
+ output_attentions: Optional[bool] = None,
1087
+ output_hidden_states: Optional[bool] = None,
1088
+ interpolate_pos_encoding: bool = False,
1089
+ output_tokens: Optional[bool] = False,
1090
+ contrastive_only: Optional[bool] = False,
1091
+ captioning_only: Optional[bool] = False,
1092
+ ) -> MammutOutput:
1093
+
1094
+ """
1095
+ Forward pass for the Mammut model.
1096
+ - **Two-pass learning:** to unify the contrastive and captioning objectives, the model combines unconditional sequence-level representation learning with token-conditioned next-token prediction.
1097
+ - **First pass: contrastive task.** Text features do not attend to image features (dual-encoder contrastive learning) but attend to all text tokens at once to produce a sequence-level representation; cross-attention and causal masking are disabled.
1098
+ - **Second pass: captioning task.** With cross-attention and causal masking enabled, the model learns the caption generation task.
1099
+ """
1100
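+ # Minimal usage sketch (illustrative; the input tensors are assumed to come from a
+ # tokenizer / image processor, and output_tokens must be enabled in the config):
+ #
+ #     outputs = model(input_ids=input_ids, pixel_values=pixel_values, output_tokens=True)
+ #     similarity = outputs.logits_per_image   # contrastive image-text similarity scores
+ #     caption_logits = outputs.output_ids     # per-position vocabulary logits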
+
1101
+ # Validate the inputs and resolve the output flags before running the two passes.
1102
+
1103
+ if pixel_values is None and input_ids is None:
1105
+ raise ValueError("Either pixel_values or input_ids must be provided.")
1107
+ if output_tokens is None:
1108
+ output_tokens = self.config.output_tokens
1109
+ if output_tokens and not self.config.output_tokens:
1110
+ raise ValueError("Output tokens are not enabled in the configuration.")
1111
+ if output_tokens and pixel_values is None:
1112
+ raise ValueError("Pixel values must be provided if output tokens are enabled.")
1113
+ if output_tokens and input_ids is None:
1114
+ # Only captioning
1115
+ captioning_only = True
1116
+
1117
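+ # First pass: contrastive task (text encoded without cross-attention or causal masking).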
+ if input_ids is not None and pixel_values is not None:
1118
+
1119
+ contrastive_output = self._contrastive_forward(
1120
+ input_ids=input_ids,
1121
+ pixel_values=pixel_values,
1122
+ output_attentions=output_attentions,
1123
+ output_hidden_states=output_hidden_states,
1124
+ interpolate_pos_encoding=interpolate_pos_encoding,
1125
+ )
1126
+ else:
1127
+ contrastive_output = MammutContrastiveOutput(
1128
+ loss=None,
1129
+ logits_per_text=None,
1130
+ logits_per_image=None,
1131
+ text_embeds=None,
1132
+ image_embeds=None,
1133
+ )
1134
+
1135
+ if contrastive_only:
1136
+ # If only contrastive output is needed, return it directly
1137
+ return MammutOutput(
1138
+ loss=contrastive_output.loss,
1139
+ logits_per_text=contrastive_output.logits_per_text,
1140
+ logits_per_image=contrastive_output.logits_per_image,
1141
+ text_embeds=contrastive_output.text_embeds,
1142
+ image_embeds=contrastive_output.image_embeds,
1143
+ )
1144
+
1145
+ if captioning_only:
1146
+ # If only captioning output is needed, return it directly
1147
+ text_model_output = self._captioning_forward(
1148
+ input_ids=input_ids,
1149
+ pixel_values=pixel_values,
1150
+ attention_mask=attention_mask,
1151
+ position_ids=position_ids,
1152
+ output_attentions=output_attentions,
1153
+ output_hidden_states=output_hidden_states,
1154
+ interpolate_pos_encoding=interpolate_pos_encoding,
1155
+ output_tokens=output_tokens,
1156
+ )
1157
+ return MammutOutput(
1158
+ loss=None, # No loss in captioning only mode
1159
+ logits_per_text=None, # No logits in captioning only mode
1160
+ logits_per_image=None, # No logits in captioning only mode
1161
+ text_embeds=text_model_output.last_hidden_state, # Use last hidden state as text embeddings
1162
+ image_embeds=None, # No image embeddings in captioning only mode
1163
+ text_model_output=text_model_output, # Output from the text model
1164
+ output_ids=text_model_output.output_ids, # Output tokens from the text model
1165
+ )
1166
+
1167
+ # If both contrastive and captioning outputs are needed, return both
1168
+ text_model_output = self._captioning_forward(
1169
+ input_ids=input_ids,
1170
+ pixel_values=pixel_values,
1171
+ attention_mask=attention_mask,
1172
+ position_ids=position_ids,
1173
+ output_attentions=output_attentions,
1174
+ output_hidden_states=output_hidden_states,
1175
+ interpolate_pos_encoding=interpolate_pos_encoding,
1176
+ output_tokens=output_tokens,
1177
+ )
1178
+ return MammutOutput(
1179
+ loss=contrastive_output.loss,
1180
+ logits_per_text=contrastive_output.logits_per_text,
1181
+ logits_per_image=contrastive_output.logits_per_image,
1182
+ text_embeds=contrastive_output.text_embeds,
1183
+ image_embeds=contrastive_output.image_embeds,
1184
+ text_model_output=text_model_output, # Output from the text model
1185
+ output_ids=text_model_output.output_ids, # Output tokens from the text model
1186
+ )
1187
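+ # A possible training objective (illustrative, not implemented here) combines the two
+ # passes: a contrastive loss on logits_per_text / logits_per_image plus a shifted
+ # next-token cross-entropy on the captioning logits, e.g.
+ #
+ #     caption_loss = F.cross_entropy(
+ #         caption_logits[:, :-1].reshape(-1, caption_logits.size(-1)),
+ #         input_ids[:, 1:].reshape(-1),
+ #     )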
+
1188
+ @torch.no_grad()
1189
+ def generate(
1190
+ self,
1191
+ input_ids: Optional[torch.LongTensor] = None,
1192
+ pixel_values: Optional[torch.FloatTensor] = None,
1193
+ attention_mask: Optional[torch.Tensor] = None,
1194
+ position_ids: Optional[torch.LongTensor] = None,
1195
+ max_new_tokens: int = 20,
1196
+ do_sample: bool = False,
1197
+ temperature: float = 1.0,
1198
+ repetition_penalty: float = 1.0,
1199
+ top_p: float = 0,
1200
+ top_k: int = 0,
1201
+ min_seq_len: int = 1,
1202
+ stopping_criteria=None,
1203
+ ) -> GenerateDecoderOnlyOutput:
1204
+ """
1205
+ Generate captions using the Mammut model.
1206
+
1207
+ Args:
1208
+ input_ids (torch.LongTensor, optional): Input token IDs for the text model.
1209
+ pixel_values (torch.FloatTensor, optional): Pixel values for the vision model.
1210
+ attention_mask (torch.Tensor, optional): Attention mask for the text model.
1211
+ position_ids (torch.LongTensor, optional): Position IDs for the text model.
1212
+ max_new_tokens (int): Maximum length of the generated sequence.
1213
+ do_sample (bool): Whether to sample from the distribution or take argmax.
1214
+ temperature (float): Temperature for sampling.
1215
+ repetition_penalty (float): Penalty for repetition in sampling.
1216
+ top_p (float): Top-p sampling parameter.
1217
+ top_k (int): Top-k sampling parameter.
1218
+ min_seq_len (int): Minimum sequence length for generation.
1219
+ stopping_criteria: Stopping criteria for generation.
1220
+ Returns:
1221
+ GenerateDecoderOnlyOutput: Contains the generated sequences and logits.
1222
+ """
1223
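+ # Illustrative usage (the tokenizer is assumed, not defined in this file):
+ #
+ #     out = model.generate(pixel_values=pixel_values, max_new_tokens=20)
+ #     captions = tokenizer.batch_decode(out.sequences, skip_special_tokens=True)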
+ # Autoregressive decoding conditioned on the image features (greedy or sampled).
1224
+
1225
+ if pixel_values is None:
1226
+ raise ValueError("Pixel values must be provided for generation.")
1227
+
1228
+ self.eval()
1229
+ device = pixel_values.device
1230
+ if input_ids is None:
1231
+ # Start every sequence from the BOS token when no text prompt is given.
1232
+ input_ids = torch.full(
1233
+ (pixel_values.shape[0], 1), self.bos_token_id, dtype=torch.long, device=device
1234
+ )
1240
+
1241
+ eos_token_id = self.eos_token_id if self.eos_token_id is not None else self.text_model.config.eos_token_id
1242
+
1243
+ logit_processor = LogitsProcessorList(
1244
+ [
1245
+ MinLengthLogitsProcessor(min_seq_len, eos_token_id),
1246
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
1247
+ ]
1248
+ )
1249
+
1250
+ # Build the warper list unconditionally so greedy decoding (or sampling with
+ # top_k == 0 and top_p == 0) does not hit an undefined name below.
1251
+ logit_warper = LogitsProcessorList()
1252
+ if do_sample:
1253
+ if top_k > 0:
1254
+ logit_warper.append(TopKLogitsWarper(top_k))
1255
+ if top_p > 0:
1256
+ logit_warper.append(TopPLogitsWarper(top_p))
1263
+ if stopping_criteria is None:
1264
+ stopping_criteria = [MaxLengthCriteria(max_new_tokens)]
1265
+
1266
+ stopping_criteria = StoppingCriteriaList(
1267
+ stopping_criteria
1268
+ )
1269
+
1270
+ out = input_ids
1271
+
1272
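+ # Encode the image once; the same visual features are reused as cross-attention
+ # context at every decoding step in the loop below.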
+ vision_outputs = self.vision_model(
1273
+ pixel_values=pixel_values
1274
+ )
1275
+ image_embeds = vision_outputs.last_hidden_state
1276
+ with torch.no_grad():
1277
+ while True:
1278
+
1279
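+ # Feed back only the most recent tokens so the decoder context stays bounded.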
+ x = out[:, -max_new_tokens:]
1280
+ # Get text features
1281
+ captioning_output = self._captioning_forward(
1282
+ input_ids=x,
1283
+ pixel_values=pixel_values,
1284
+ image_embeds=image_embeds,
1285
+ attention_mask=attention_mask,
1286
+ position_ids=position_ids,
1287
+ output_attentions=False,
1288
+ output_hidden_states=False,
1289
+ interpolate_pos_encoding=False,
1290
+ output_tokens=True, # We want the output tokens
1291
+ )
1292
+
1293
+
1294
+ output_ids = captioning_output.output_ids
1295
+
1296
+ # Get logits for the next token
1297
+ logits = output_ids[:, -1]
1298
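+ # Sequences whose last token is EOS or padding are treated as finished and are
+ # excluded from further sampling below.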
+ mask = (out[:, -1] == eos_token_id) | (out[:, -1] == self.pad_token_id)
1299
+
1300
+
1301
+ logits = logits[~mask, :]
1302
+
1303
+ filtered_logits = logit_processor(x[~mask, :], logits)
1304
+ filtered_logits = logit_warper(x[~mask, :], filtered_logits)
1305
+
1306
+
1307
+ # Sample or take the argmax of the logits. Build a full-batch next_token tensor
+ # (finished sequences receive the pad token) so it can be concatenated to `out`.
1308
+ cur_len = out.shape[1]
1309
+ next_token = torch.full((out.shape[0], 1), self.pad_token_id, device=device, dtype=torch.long)
1310
+ if cur_len >= max_new_tokens:
1311
+ next_token[~mask] = eos_token_id
1312
+ elif do_sample:
1313
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
1314
+ next_token[~mask] = torch.multinomial(probs, num_samples=1)
1315
+ else:
1316
+ next_token[~mask] = torch.argmax(filtered_logits, dim=-1, keepdim=True)
1317
+
1318
+ if mask.all():
1319
+ break
1320
+
1321
+ # Check if we have reached the end of the sequence or max length
1322
+ if (out.shape[1] >= max_new_tokens) or (next_token == eos_token_id).all():
1323
+ break
1324
+
1325
+
1326
+ # Append the next token to the output sequence
1327
+ out = torch.cat([out, next_token], dim=1)
1328
+
1329
+
1330
+ output_ids = out.long() if out.dtype != torch.long else out
1331
+
1332
+ # Decoding finished: return the generated sequences and the logits from the last step.
1333
+ return GenerateDecoderOnlyOutput(
1334
+ logits=logits,
1335
+ sequences=output_ids, # Output tokens from the text model
1336
+ )
1337
+
1338
+ AutoModel.register(MammutConfig, MammutModel)
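+ # With the registration above, the model can be loaded via the Auto classes when remote
+ # code is trusted (the repository id below is illustrative):
+ #
+ #     model = AutoModel.from_pretrained("org/mammut-checkpoint", trust_remote_code=True)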
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f2d284ded5f643a6976af9ab4fa9940e60fe825553b2101e63550fb2d5d6c88
3
+ size 2033381111