"""PyTorch MaMMUT model.""" |
|
|
|
|
|
from typing import Callable, List, Optional, Tuple, Union |
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from .configuration_mammut import MammutTextConfig, MammutVisionConfig, MammutConfig |
|
from transformers.models.clip.modeling_clip import ( |
|
CLIPAttention, |
|
CLIPMLP, |
|
CLIPEncoderLayer, |
|
CLIPTextModel, |
|
CLIPVisionModel, |
|
CLIPVisionModelOutput, |
|
CLIPVisionTransformer, |
|
CLIPTextModelOutput, |
|
CLIPOutput, |
|
CLIPModel, |
|
CLIPPreTrainedModel, |
|
CLIPVisionEmbeddings, |
|
CLIPEncoder, |
|
eager_attention_forward |
|
) |
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput |
|
from transformers.generation import GenerateDecoderOnlyOutput |
|
from dataclasses import dataclass |
|
|
from transformers import AutoModel |
|
import logging |
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
|
from transformers import ( |
|
BeamSearchScorer, |
|
LogitsProcessorList, |
|
TopPLogitsWarper, |
|
TopKLogitsWarper, |
|
RepetitionPenaltyLogitsProcessor, |
|
MinLengthLogitsProcessor, |
|
MaxLengthCriteria, |
|
StoppingCriteriaList |
|
) |
|
|
|
|
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
class MammutCrossAttnLayer(nn.Module): |
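    """Cross-attention block: text hidden states (queries) attend to image tokens (keys/values)."""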
|
def __init__(self, config: MammutTextConfig): |
|
super().__init__() |
|
self.embed_dim = config.hidden_size |
|
self.self_attn = MammutAttention(config) |
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
self.mlp = CLIPMLP(config) |
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
self.layer_norm1_kv = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
k_x: Optional[torch.Tensor] = None, |
|
v_x: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
print0_hidden_states: bool = False, |
|
) -> torch.Tensor: |
|
residual = hidden_states |
|
hidden_states = self.layer_norm1(hidden_states) |
|
|
|
if k_x is not None and v_x is not None: |
|
k_x = self.layer_norm1_kv(k_x) |
|
v_x = self.layer_norm1_kv(v_x) |
|
hidden_states, attn_weights = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
causal_attention_mask=causal_attention_mask, |
|
keys=k_x, |
|
values=v_x, |
|
print0_hidden_states=print0_hidden_states, |
|
) |
|
|
|
hidden_states = hidden_states.permute(1, 0, 2) |
|
|
|
|
|
hidden_states = residual + hidden_states |
|
residual = hidden_states |
|
hidden_states = self.layer_norm2(hidden_states) |
|
hidden_states = self.mlp(hidden_states) |
|
hidden_states = residual + hidden_states |
|
return hidden_states |
|
|
|
|
|
class LayerScale(nn.Module): |
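    """Learnable per-channel scaling (`gamma`) applied to its input."""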
|
def __init__(self, dim, init_values=1e-5, inplace=False): |
|
super().__init__() |
|
self.inplace = inplace |
|
self.gamma = nn.Parameter(init_values * torch.ones(dim)) |
|
|
|
def forward(self, x): |
|
return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
|
class MammutAttention(CLIPAttention): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__(self, config: Union[MammutTextConfig, MammutVisionConfig]): |
|
super().__init__(config) |
|
self.config = config |
|
self.embed_dim = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.embed_dim // self.num_heads |
|
if self.head_dim * self.num_heads != self.embed_dim: |
|
raise ValueError( |
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" |
|
f" {self.num_heads})." |
|
) |
|
self.scale = self.head_dim**-0.5 |
|
|
|
self.dropout = config.attention_dropout |
|
self.is_causal = False |
|
|
|
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
self.training = False |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = False, |
|
keys: Optional[torch.Tensor] = None, |
|
values: Optional[torch.Tensor] = None, |
|
print0_hidden_states: bool = False, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
|
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
batch_size, seq_length, embed_dim = hidden_states.shape |
|
|
|
        if keys is None and values is None:
            keys = hidden_states
            values = hidden_states
|
attn_output, attn_weights = F.multi_head_attention_forward( |
|
query=hidden_states.permute(1, 0, 2), |
|
key=keys.permute(1, 0, 2) if keys is not None else hidden_states.permute(1, 0, 2), |
|
value=values.permute(1, 0, 2) if values is not None else hidden_states.permute(1, 0, 2), |
|
embed_dim_to_check=embed_dim, |
|
num_heads=self.num_heads, |
|
in_proj_weight=torch.cat( |
|
[self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0 |
|
), |
|
in_proj_bias=torch.cat( |
|
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias], dim=0 |
|
) if self.q_proj.bias is not None else None, |
|
bias_k=None, |
|
bias_v=None, |
|
add_zero_attn=False, |
|
attn_mask=attention_mask, |
|
q_proj_weight=self.q_proj.weight, |
|
k_proj_weight=self.k_proj.weight, |
|
v_proj_weight=self.v_proj.weight, |
|
is_causal=self.is_causal, |
|
dropout_p=0.0 if not self.training else self.dropout, |
|
out_proj_weight=self.out_proj.weight, |
|
out_proj_bias=self.out_proj.bias, |
|
training=self.training, |
|
        )

        if not output_attentions:
|
attn_weights = None |
|
return attn_output, attn_weights |
|
|
|
class MammutEncoderLayer(CLIPEncoderLayer): |
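    """Pre-norm transformer block with an optional MLP (`has_mlp=False` skips the feed-forward sublayer)."""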
|
def __init__(self, config: MammutTextConfig, has_mlp: bool = True): |
|
super().__init__(config) |
|
self.embed_dim = config.hidden_size |
|
self.self_attn = MammutAttention(config) |
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
self.mlp = CLIPMLP(config) if has_mlp else None |
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
|
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = False, |
|
print_hidden_states: bool = False, |
|
    ) -> torch.Tensor:
|
""" |
|
Forward pass for the encoder layer. |
|
Args: |
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size |
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
causal_attention_mask (`torch.FloatTensor`, *optional*): causal attention mask of size |
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
|
output_attentions (`bool`, *optional*): |
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
returned tensors for more detail. |
|
""" |
|
|
|
residual = hidden_states |
|
hidden_states = self.layer_norm1(hidden_states) |
|
|
|
|
|
hidden_states, attn_weights = self.self_attn( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
causal_attention_mask=None, |
|
output_attentions=output_attentions, |
|
print0_hidden_states=print_hidden_states, |
|
) |
|
|
|
hidden_states = hidden_states.permute(1, 0, 2) |
|
|
|
|
|
hidden_states = residual + hidden_states |
|
|
|
residual = hidden_states |
|
hidden_states = self.layer_norm2(hidden_states) |
|
|
|
hidden_states = self.mlp(hidden_states) if self.mlp is not None else hidden_states |
|
hidden_states = residual + hidden_states |
|
return hidden_states |
|
|
|
|
|
class MammutMultimodalEncoder(nn.Module): |
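    """
    Stack of `num_hidden_layers` self-attention blocks with a cross-attention block
    interleaved roughly every `cross_attn_ratio` layers. When `img_embeds` is provided,
    the cross-attention blocks let text tokens attend to the image tokens; otherwise the
    stack behaves as a text-only encoder.
    """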
|
does_full_decoding: torch.jit.Final[bool] |
|
|
|
def __init__( |
|
self, |
|
config: MammutConfig, |
|
): |
|
|
|
super().__init__() |
|
|
|
self.config = config |
|
|
|
self.n_cross_attn, _ = divmod(config.num_hidden_layers, config.cross_attn_ratio) |
|
self.cross_step, _ = divmod(config.num_hidden_layers, self.n_cross_attn) |
|
self.does_full_decoding = config.does_full_decoding |
|
self.output_tokens = config.output_tokens |
|
self.batch_first = config.batch_first |
|
self.context_length = config.max_position_embeddings |
|
self.layers = nn.ModuleList([]) |
|
self.cross_attn = nn.ModuleList([]) |
|
num_cross_attn = 0 |
|
for l_idx in range(config.num_hidden_layers): |
|
_, r = divmod(l_idx, self.cross_step) |
|
has_cross_attn = r == 0 |
|
layer = MammutEncoderLayer(config) |
|
self.layers.append(layer) |
|
if has_cross_attn: |
|
num_cross_attn += 1 |
|
cross_attn_layer = MammutCrossAttnLayer(config) |
|
self.cross_attn.append(cross_attn_layer) |
|
|
|
|
|
def forward( |
|
self, |
|
text_embeds: torch.Tensor, |
|
img_embeds: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
) -> Union[BaseModelOutput, Tuple[torch.Tensor]]: |
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
encoder_states = () if output_hidden_states else None |
|
all_attentions = () if output_attentions else None |
|
hidden_states = text_embeds |
|
|
|
seq_len = hidden_states.shape[1] if self.batch_first else hidden_states.shape[0] |
|
|
|
if causal_attention_mask is None: |
|
causal_attention_mask = self.build_causal_mask() |
|
else: |
|
causal_attention_mask = causal_attention_mask.to(dtype=hidden_states.dtype) |
|
|
|
if attention_mask is None: |
|
attention_mask = causal_attention_mask |
|
else: |
|
attention_mask = attention_mask + causal_attention_mask |
|
|
|
|
|
if img_embeds is not None: |
|
img_embeds = img_embeds.to(dtype=hidden_states.dtype) |
|
k_x = img_embeds |
|
v_x = img_embeds |
|
else: |
|
k_x = None |
|
v_x = None |
|
|
|
if img_embeds is not None: |
|
attention_mask = attention_mask[:seq_len, :seq_len] |
|
|
|
for i, layer in enumerate(self.layers): |
|
|
|
|
|
cross_attn_idx, r = divmod(i, self.cross_step) |
|
|
|
has_cross_attn = r == 0 and img_embeds is not None |
|
if i == 0: |
|
print_hidden_states = True |
|
else: |
|
print_hidden_states = False |
|
|
|
|
|
hidden_states = layer( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask if img_embeds is not None else None, |
|
causal_attention_mask=None, |
|
output_attentions=output_attentions, |
|
print_hidden_states=print_hidden_states, |
|
) |
|
|
|
            if has_cross_attn:
                cross_attn = self.cross_attn[cross_attn_idx]
                hidden_states = cross_attn(
                    hidden_states=hidden_states,
                    k_x=k_x,
                    v_x=v_x,
                    print0_hidden_states=i == 0,
                )

            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
|
|
|
|
|
if output_hidden_states: |
|
encoder_states = tuple(encoder_states) |
|
if self.does_full_decoding: |
|
encoder_states = encoder_states[:self.n_cross_attn + 1] |
|
else: |
|
encoder_states = encoder_states[:self.config.text_config.num_hidden_layers] |
|
else: |
|
encoder_states = None |
|
|
|
return BaseModelOutput( |
|
last_hidden_state=hidden_states, |
|
hidden_states=encoder_states, |
|
attentions=all_attentions, |
|
) |
|
|
|
def build_causal_mask(self): |
|
|
|
|
|
mask = torch.empty(self.context_length, self.context_length) |
|
mask.fill_(float("-inf")) |
|
mask.triu_(1) |
|
return mask |
|
|
|
|
|
def build_attn_mask(self): |
|
|
|
|
|
mask = torch.empty(self.context_length, self.context_length) |
|
mask.fill_(float("-inf")) |
|
mask.triu_(1) |
|
return mask |
|
|
|
|
|
@dataclass |
|
class MammutPoolingOutput(BaseModelOutputWithPooling): |
|
""" |
|
Base class for outputs of the Mammut model. |
|
""" |
|
|
|
last_hidden_state: torch.FloatTensor = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
output_ids: Optional[torch.Tensor] = None |
|
pooler_output: Optional[torch.FloatTensor] = None |
|
|
|
|
|
class MammutMultimodalEmbeddings(nn.Module): |
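    """Token embeddings plus learned absolute position embeddings for the text decoder."""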
|
def __init__(self, config: MammutTextConfig): |
|
super().__init__() |
|
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size) |
|
self.position_embedding = nn.Embedding( |
|
config.max_position_embeddings, config.hidden_size |
|
) |
|
self.register_buffer( |
|
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False |
|
) |
|
|
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
) -> torch.Tensor: |
|
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
|
max_position_embedding = self.position_embedding.weight.shape[0] |
|
|
|
if seq_length > max_position_embedding: |
|
raise ValueError( |
|
f"Sequence length must be less than max_position_embeddings (got `sequence length`: " |
|
f"{seq_length} and max_position_embeddings: {max_position_embedding}" |
|
) |
|
|
|
if position_ids is None: |
|
position_ids = self.position_ids[:, :seq_length] |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.token_embedding(input_ids) |
|
|
|
position_embeddings = self.position_embedding(position_ids) |
|
embeddings = inputs_embeds + position_embeddings |
|
|
|
return embeddings |
|
|
|
|
|
def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): |
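    """Pool token features: 'first'/'last' select a token, 'argmax' picks the position of the
    largest token id in `text` (typically the EOS token); any other value returns the sequence unpooled."""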
|
if pool_type == 'first': |
|
pooled, tokens = x[:, 0], x[:, 1:] |
|
elif pool_type == 'last': |
|
pooled, tokens = x[:, -1], x[:, :-1] |
|
elif pool_type == 'argmax': |
|
|
|
assert text is not None |
|
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x |
|
else: |
|
pooled = tokens = x |
|
|
|
return pooled, tokens |
|
|
|
|
|
class MammutMultimodalTransformer(nn.Module): |
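    """
    Text tower of MaMMUT. Without image embeddings it acts as a contrastive text encoder;
    when image embeddings are passed they are used as cross-attention keys/values and the
    tower acts as a multimodal decoder over the vocabulary.
    """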
|
def __init__(self, config: MammutTextConfig, output_tokens=True): |
|
super().__init__() |
|
self.config = config |
|
embed_dim = config.hidden_size |
|
self.encoder = MammutMultimodalEncoder(config) |
|
self.text_projection = nn.Linear( |
|
config.hidden_size, config.vocab_size, bias=False |
|
) if config.hidden_size is not None else None |
|
self.final_layer_norm = nn.LayerNorm( |
|
embed_dim, eps=config.layer_norm_eps |
|
) |
|
|
|
|
|
self.does_full_decoding = config.does_full_decoding |
|
self.context_length = config.context_length |
|
self.vocab_size = config.vocab_size |
|
width = config.hidden_size |
|
self.batch_first = config.batch_first |
|
self.has_mlp = config.has_mlp |
|
self.cross_attn_ratio = config.cross_attn_ratio |
|
self.cross_step = config.cross_attn_ratio |
|
self.n_cross_attn = config.num_hidden_layers // config.cross_attn_ratio |
|
vocab_size = config.vocab_size |
|
self.output_tokens = output_tokens |
|
|
|
if self.does_full_decoding: |
|
self.num_pos = self.context_length |
|
self.embeddings = MammutMultimodalEmbeddings(config) |
|
else: |
|
self.num_pos = None |
|
self.embeddings = None |
|
|
|
def init_weights(self): |
|
|
|
self.final_layer_norm.weight.data.fill_(1.0) |
|
self.final_layer_norm.bias.data.zero_() |
|
log.info("MammutMultimodalTransformer weights initialized.") |
|
|
|
def forward( |
|
self, |
|
img_embs: torch.Tensor, |
|
text_embs: Optional[torch.Tensor] = None, |
|
output_tokens: Optional[bool] = False, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
    ) -> MammutPoolingOutput:
|
|
|
|
|
if text_embs is not None: |
|
if self.embeddings is not None: |
|
|
|
text_embs = self.embeddings( |
|
input_ids=text_embs, |
|
position_ids=position_ids, |
|
|
|
) |
|
|
|
|
|
if self.does_full_decoding: |
|
text_embs = text_embs[:, :self.context_length, :] |
|
|
|
|
|
text_embs = self.encoder( |
|
text_embeds=text_embs, |
|
img_embeds=img_embs, |
|
attention_mask=None, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
text_embs = text_embs.last_hidden_state |
|
|
|
if self.does_full_decoding: |
|
text_embs = text_embs[:, :self.context_length, :] |
|
else: |
|
text_embs = text_embs[:, 0, :] |
|
|
|
|
|
if self.text_projection is not None: |
|
output_ids = self.text_projection(text_embs) |
|
else: |
|
output_ids = text_embs |
|
|
|
if output_tokens: |
|
return MammutPoolingOutput( |
|
last_hidden_state=text_embs, |
|
hidden_states=None, |
|
attentions=None, |
|
output_ids=output_ids, |
|
pooler_output=text_embs, |
|
) |
|
|
|
return MammutPoolingOutput( |
|
last_hidden_state=text_embs, |
|
pooler_output=text_embs, |
|
hidden_states=None, |
|
attentions=None, |
|
) |
|
|
|
|
|
def build_causal_mask(self, seq_len: Optional[int] = None, device: Optional[torch.device] = None) -> torch.Tensor: |
|
if seq_len is None: |
|
seq_len = self.context_length if self.does_full_decoding else self.config.context_length |
|
if device is None: |
|
device = torch.device("cpu") |
|
mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).view(1, 1, seq_len, seq_len) |
|
return mask |
|
|
|
def build_attn_mask(self): |
|
|
|
|
|
mask = torch.empty(self.context_length, self.context_length) |
|
mask.fill_(float("-inf")) |
|
mask.triu_(1) |
|
return mask |
|
|
|
class MammutMultimodalModel(CLIPTextModel): |
|
""" |
|
Mammut multimodal model with text and vision encoders. |
|
""" |
|
|
|
config_class = MammutTextConfig |
|
base_model_prefix = "mammut_multimodal" |
|
|
|
    def __init__(self, config: MammutConfig):
|
super().__init__(config) |
|
self.config = config.text_config |
|
self.text_model = MammutMultimodalTransformer(config.text_config) |
|
self.text_embed_dim = config.hidden_size |
|
self.vision_embed_dim = config.vision_config.hidden_size |
|
self.projection_dim = config.projection_dim |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
image_embs: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
output_tokens: Optional[bool] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
) -> Union[MammutPoolingOutput, CLIPTextModelOutput]: |
|
|
|
return self.text_model( |
|
img_embs=image_embs, |
|
text_embs=input_ids, |
|
output_tokens=output_tokens, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
position_ids=position_ids, |
|
) |
|
|
|
|
|
class MammutVisionTransformer(CLIPVisionTransformer): |
|
""" |
|
Mammut Vision Transformer model. |
|
Inherits from CLIPVisionTransformer and initializes the vision model. |
|
""" |
|
|
|
config_class = MammutVisionConfig |
|
base_model_prefix = "mammut_vision" |
|
|
|
def __init__(self, config: MammutVisionConfig): |
|
super().__init__(config) |
|
self.config = config |
|
embed_dim = config.hidden_size |
|
|
|
self.embeddings = CLIPVisionEmbeddings(config) |
|
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
|
self.encoder = CLIPEncoder(config) |
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
|
self.pool_type = config.pool_type |
|
|
|
|
|
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if self.pool_type == 'avg': |
|
pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] |
|
elif self.pool_type == 'tok': |
|
pooled, tokens = x[:, 0], x[:, 1:] |
|
elif self.pool_type == "avg_all": |
|
pooled, tokens = x.mean(dim=1), x |
|
else: |
|
pooled = tokens = x |
|
|
|
return pooled, tokens |
|
|
|
|
|
|
|
def forward( |
|
self, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
interpolate_pos_encoding: Optional[bool] = False, |
|
) -> BaseModelOutputWithPooling: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
if pixel_values is None: |
|
raise ValueError("You have to specify pixel_values") |
|
|
|
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) |
|
hidden_states = self.pre_layrnorm(hidden_states) |
|
|
|
encoder_outputs: BaseModelOutput = self.encoder( |
|
inputs_embeds=hidden_states, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
last_hidden_state = encoder_outputs.last_hidden_state |
|
        if self.config.final_ln_after_pool:
            # Pool over tokens first, then apply the final LayerNorm to the pooled vector.
            pooled, _ = self._global_pool(last_hidden_state)
            pooled_output = self.post_layernorm(pooled)
        else:
            # Apply the final LayerNorm over the full token sequence, then pool.
            normed = self.post_layernorm(last_hidden_state)
            pooled_output, _ = self._global_pool(normed)
|
|
|
return BaseModelOutputWithPooling( |
|
last_hidden_state=last_hidden_state, |
|
pooler_output=pooled_output, |
|
hidden_states=encoder_outputs.hidden_states, |
|
attentions=encoder_outputs.attentions, |
|
) |
|
|
|
class MammutVisionModel(CLIPVisionModel): |
|
""" |
|
Mammut Vision Model. |
|
Inherits from CLIPVisionModel and initializes the vision model. |
|
""" |
|
|
|
config_class = MammutVisionConfig |
|
base_model_prefix = "mammut_vision" |
|
|
|
def __init__(self, config: MammutVisionConfig): |
|
super().__init__(config) |
|
self.config = config |
|
self.vision_model = MammutVisionTransformer(config) |
|
self.post_init() |
|
|
|
|
|
@dataclass |
|
class MammutContrastiveOutput(CLIPOutput): |
|
""" |
|
Output class for Mammut model in contrastive learning mode. |
|
Contains contrastive output: |
|
- loss: Loss value if return_loss is True. |
|
- logits_per_text: Logits for text inputs. |
|
- logits_per_image: Logits for image inputs. |
|
- text_embeds: Text embeddings. |
|
- image_embeds: Image embeddings. |
|
""" |
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
logits_per_text: Optional[torch.FloatTensor] = None |
|
logits_per_image: Optional[torch.FloatTensor] = None |
|
text_embeds: Optional[torch.FloatTensor] = None |
|
image_embeds: Optional[torch.FloatTensor] = None |
|
|
|
@dataclass |
|
class MammutCaptioningOutput(ModelOutput): |
|
""" |
|
Output class for Mammut captioning part. |
|
Contains: |
|
- last_hidden_state: Last hidden state of the text model. |
|
- pooler_output: Pooler output of the text model. |
|
- hidden_states: Hidden states from the text model. |
|
- attentions: Attention weights from the text model. |
|
- output_ids: Output tokens from the text model. |
|
""" |
|
|
|
last_hidden_state: torch.FloatTensor = None |
|
pooler_output: Optional[torch.FloatTensor] = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
output_ids: Optional[torch.Tensor] = None |
|
|
|
@dataclass |
|
class MammutOutput(ModelOutput): |
|
""" |
|
Output class for Mammut model. |
|
Contains contrastive output: |
|
- loss: Loss value if return_loss is True. |
|
- logits_per_text: Logits for text inputs. |
|
- logits_per_image: Logits for image inputs. |
|
- text_embeds: Text embeddings. |
|
- image_embeds: Image embeddings. |
|
|
|
Captioning output: |
|
- text_model_output: Output from the text model. |
|
- output_ids: Output tokens from the text model. |
|
""" |
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
logits_per_text: Optional[torch.FloatTensor] = None |
|
logits_per_image: Optional[torch.FloatTensor] = None |
|
text_embeds: Optional[torch.FloatTensor] = None |
|
image_embeds: Optional[torch.FloatTensor] = None |
|
text_model_output: Optional[MammutCaptioningOutput] = None |
|
output_ids: Optional[torch.Tensor] = None |
|
|
|
|
|
|
|
|
|
|
|
def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor: |
|
""" |
|
This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make |
|
model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566 |
|
""" |
|
square_tensor = torch.pow(tensor, 2) |
|
sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True) |
|
normed_tensor = torch.pow(sum_tensor, 0.5) |
|
return normed_tensor |
|
|
|
class MammutModel(CLIPPreTrainedModel): |
|
""" |
|
Mammut model with text and vision encoders. |
|
""" |
|
|
|
config_class = MammutConfig |
|
base_model_prefix = "mammut" |
|
|
|
def __init__(self, config: MammutConfig): |
|
super().__init__(config) |
|
self.config = config |
|
self.text_model = MammutMultimodalTransformer(config.text_config, output_tokens=config.output_tokens) |
|
vision_model = MammutVisionModel._from_config(config.vision_config) |
|
self.vision_model = vision_model.vision_model |
|
self.text_embed_dim = config.text_config.hidden_size |
|
self.vision_embed_dim = config.vision_config.hidden_size |
|
self.projection_dim = config.projection_dim |
|
self.text_projection = self.text_model.text_projection |
|
self.visual_projection = nn.Linear( |
|
self.vision_embed_dim, self.projection_dim, bias=False |
|
) if self.projection_dim is not None else None |
|
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) |
|
|
|
|
|
self.map_viz2txt_kv = nn.Parameter(torch.randn( |
|
self.config.vision_config.width, self.config.text_config.width |
|
)) |
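        # Projects vision-token features into the text hidden width so they can be used
        # as cross-attention keys/values (see `_captioning_forward`).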
|
|
|
self.eos_token_id = self.config.text_config.eos_token_id |
|
self.bos_token_id = self.config.text_config.bos_token_id |
|
self.pad_token_id = self.config.text_config.pad_token_id |
|
self.does_full_decoding = config.text_config.does_full_decoding |
|
self.context_length = config.text_config.context_length |
|
self.vocab_size = config.text_config.vocab_size |
|
self.batch_first = config.text_config.batch_first |
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_text_features( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
img_embs: Optional[torch.FloatTensor] = None, |
|
) -> torch.FloatTensor: |
|
""" |
|
Get text features from the Mammut model. |
|
""" |
|
|
|
text_model_output = self.text_model( |
|
img_embs=img_embs, |
|
text_embs=input_ids, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
text_embeds = text_model_output.last_hidden_state |
|
text_embeds = self.text_model.final_layer_norm(text_embeds) |
|
text_embeds = text_embeds.mean(1) |
|
text_embeds = F.normalize(text_embeds, dim=-1) |
|
return text_embeds |
|
|
|
def get_image_features( |
|
self, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
normalize: bool = True, |
|
) -> torch.FloatTensor: |
|
""" |
|
Get image features from the Mammut model. |
|
""" |
|
|
|
vision_outputs: CLIPVisionModelOutput = self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
|
|
image_embeds = vision_outputs.pooler_output |
|
if self.visual_projection is not None: |
|
image_embeds = self.visual_projection(image_embeds) |
|
|
|
image_embeds = F.normalize(image_embeds, dim=-1) if normalize else image_embeds |
|
return image_embeds |
|
|
|
def _contrastive_forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
return_loss: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
interpolate_pos_encoding: bool = False, |
|
output_tokens: Optional[bool] = None, |
|
contrastive: Optional[bool] = False, |
|
) -> MammutContrastiveOutput: |
|
""" |
|
Forward pass for the Mammut model in contrastive learning mode. |
|
        - **Two-pass learning:** to unify the contrastive and next-token prediction objectives,
          the model combines unconditional representation learning with a token-conditioned
          next-token prediction objective.
        - **First pass: contrastive task.** In the first pass, text features do not see image
          features (dual-encoder contrastive learner) but attend to all tokens at once to produce
          a sequence-level representation. Cross-attention and causal masking are disabled.
        - **Second pass: captioning task.** The second pass uses cross-attention and causal
          masking to learn the caption generation task.
|
|
|
Return: |
|
MammutContrastiveOutput: Contains contrastive output with logits, embeddings, and optional loss. |
|
""" |
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
vision_outputs: CLIPVisionModelOutput = self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
) |
|
|
|
|
|
|
|
text_outputs: MammutPoolingOutput = self.text_model( |
|
img_embs=None, |
|
text_embs=input_ids, |
|
output_tokens=output_tokens, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
position_ids=position_ids, |
|
) |
|
|
|
image_embeds = vision_outputs.pooler_output |
|
image_embeds = self.visual_projection(image_embeds) |
|
|
|
text_embeds = text_outputs.pooler_output |
|
|
|
pooled, tokens = text_global_pool(text_embeds, text=input_ids) |
|
|
|
text_embeds = self.text_model.final_layer_norm(text_embeds) |
|
text_embeds = text_embeds.mean(1) |
|
tokens = self.text_projection(pooled) |
|
|
|
|
|
image_embeds = image_embeds / _get_vector_norm(image_embeds) |
|
text_embeds = text_embeds / _get_vector_norm(text_embeds) |
|
|
|
|
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) |
|
logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device) |
|
|
|
logits_per_image = logits_per_text.t() |
|
|
|
loss = None |
|
return MammutContrastiveOutput( |
|
loss=loss, |
|
logits_per_text=logits_per_text, |
|
logits_per_image=logits_per_image, |
|
text_embeds=text_embeds, |
|
image_embeds=image_embeds, |
|
) |
|
|
|
|
|
def _captioning_forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
image_embeds: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
return_loss: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
interpolate_pos_encoding: bool = False, |
|
output_tokens: Optional[bool] = None, |
|
) -> MammutCaptioningOutput: |
|
""" |
|
Forward pass for the Mammut model in captioning mode. |
|
|
|
Return: |
|
MammutCaptioningOutput: Contains captioning output with last hidden state, pooler output, hidden states, attentions, and output tokens. |
|
""" |
|
|
|
if pixel_values is None: |
|
raise ValueError("Pixel values must be provided for captioning.") |
|
|
|
if input_ids is None: |
|
input_ids = torch.ones( |
|
(pixel_values.shape[0], self.context_length), dtype=torch.long, device=pixel_values.device |
|
) * self.bos_token_id |
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
|
|
if image_embeds is None: |
|
|
|
vision_outputs = self.vision_model( |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
) |
|
image_embeds = vision_outputs.last_hidden_state |
|
|
|
|
|
image_embeds = image_embeds @ self.map_viz2txt_kv |
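        # The projected image tokens are passed to the text tower as cross-attention keys/values.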
|
|
|
text_model_output = self.text_model( |
|
img_embs=image_embeds, |
|
text_embs=input_ids, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
text_embeds = text_model_output.last_hidden_state |
|
|
|
text_embeds = self.text_model.final_layer_norm(text_embeds) |
|
logits = self.text_projection(text_embeds) |
|
|
|
if output_tokens: |
|
|
|
return MammutCaptioningOutput( |
|
last_hidden_state=text_embeds, |
|
pooler_output=image_embeds, |
|
output_ids=logits, |
|
) |
|
|
|
return MammutCaptioningOutput( |
|
last_hidden_state=text_embeds, |
|
pooler_output=image_embeds, |
|
output_ids=None, |
|
) |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
return_loss: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
interpolate_pos_encoding: bool = False, |
|
output_tokens: Optional[bool] = False, |
|
contrastive_only: Optional[bool] = False, |
|
captioning_only: Optional[bool] = False, |
|
) -> MammutOutput: |
|
|
|
""" |
|
Forward pass for the Mammut model. |
|
        - **Two-pass learning:** to unify the contrastive and next-token prediction objectives,
          the model combines unconditional representation learning with a token-conditioned
          next-token prediction objective.
        - **First pass: contrastive task.** In the first pass, text features do not see image
          features (dual-encoder contrastive learner) but attend to all tokens at once to produce
          a sequence-level representation. Cross-attention and causal masking are disabled.
        - **Second pass: captioning task.** The second pass uses cross-attention and causal
          masking to learn the caption generation task.
|
""" |
|
|
|
|
|
|
|
|
|
|
|
if pixel_values is None and input_ids is None: |
|
raise ValueError("Pixel values or input IDs must be provided for captioning.") |
|
if output_tokens is None: |
|
output_tokens = self.config.output_tokens |
|
if output_tokens and not self.config.output_tokens: |
|
raise ValueError("Output tokens are not enabled in the configuration.") |
|
if output_tokens and pixel_values is None: |
|
raise ValueError("Pixel values must be provided if output tokens are enabled.") |
|
if output_tokens and input_ids is None: |
|
|
|
captioning_only = True |
|
|
|
if input_ids is not None and pixel_values is not None: |
|
|
|
contrastive_output = self._contrastive_forward( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
) |
|
else: |
|
contrastive_output = MammutContrastiveOutput( |
|
loss=None, |
|
logits_per_text=None, |
|
logits_per_image=None, |
|
text_embeds=None, |
|
image_embeds=None, |
|
) |
|
|
|
if contrastive_only: |
|
|
|
return MammutOutput( |
|
loss=contrastive_output.loss, |
|
logits_per_text=contrastive_output.logits_per_text, |
|
logits_per_image=contrastive_output.logits_per_image, |
|
text_embeds=contrastive_output.text_embeds, |
|
image_embeds=contrastive_output.image_embeds, |
|
) |
|
|
|
if captioning_only: |
|
|
|
text_model_output = self._captioning_forward( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
output_tokens=output_tokens, |
|
) |
|
return MammutOutput( |
|
loss=None, |
|
logits_per_text=None, |
|
logits_per_image=None, |
|
text_embeds=text_model_output.last_hidden_state, |
|
image_embeds=None, |
|
text_model_output=text_model_output, |
|
output_ids=text_model_output.output_ids, |
|
) |
|
|
|
|
|
text_model_output = self._captioning_forward( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
output_tokens=output_tokens, |
|
) |
|
return MammutOutput( |
|
loss=contrastive_output.loss, |
|
logits_per_text=contrastive_output.logits_per_text, |
|
logits_per_image=contrastive_output.logits_per_image, |
|
text_embeds=contrastive_output.text_embeds, |
|
image_embeds=contrastive_output.image_embeds, |
|
text_model_output=text_model_output, |
|
output_ids=text_model_output.output_ids, |
|
) |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
max_new_tokens: int = 20, |
|
do_sample: bool = False, |
|
temperature: float = 1.0, |
|
repetition_penalty: float = 1.0, |
|
top_p: float = 0, |
|
top_k: int = 0, |
|
min_seq_len: int = 1, |
|
stopping_criteria= None, |
|
) -> GenerateDecoderOnlyOutput: |
|
""" |
|
Generate captions using the Mammut model. |
|
|
|
Args: |
|
input_ids (torch.LongTensor, optional): Input token IDs for the text model. |
|
pixel_values (torch.FloatTensor, optional): Pixel values for the vision model. |
|
attention_mask (torch.Tensor, optional): Attention mask for the text model. |
|
position_ids (torch.LongTensor, optional): Position IDs for the text model. |
|
max_new_tokens (int): Maximum length of the generated sequence. |
|
do_sample (bool): Whether to sample from the distribution or take argmax. |
|
temperature (float): Temperature for sampling. |
|
repetition_penalty (float): Penalty for repetition in sampling. |
|
top_p (float): Top-p sampling parameter. |
|
top_k (int): Top-k sampling parameter. |
|
min_seq_len (int): Minimum sequence length for generation. |
|
stopping_criteria: Stopping criteria for generation. |
|
Returns: |
|
GenerateDecoderOnlyOutput: Contains the generated sequences and logits. |
|
""" |
|
|
|
|
|
if input_ids is None and pixel_values is None: |
|
raise ValueError("Input IDs or pixel values must be provided for generation.") |
|
if input_ids is None: |
|
input_ids = torch.ones( |
|
(pixel_values.shape[0], 1), dtype=torch.long, device=pixel_values.device |
|
) * self.bos_token_id |
|
if pixel_values is None: |
|
raise ValueError("Pixel values must be provided for generation.") |
|
|
|
self.eval() |
|
device = pixel_values.device if pixel_values is not None else input_ids.device |
|
|
|
|
eos_token_id = self.eos_token_id if self.eos_token_id is not None else self.text_model.config.eos_token_id |
|
|
|
logit_processor = LogitsProcessorList( |
|
[ |
|
MinLengthLogitsProcessor(min_seq_len, eos_token_id), |
|
RepetitionPenaltyLogitsProcessor(repetition_penalty), |
|
] |
|
) |
|
|
|
        # Default to an empty (no-op) warper list so the call below is safe when sampling
        # is disabled or no top-k / top-p filtering is requested.
        logit_warper = LogitsProcessorList([])
        if do_sample:
            if top_k > 0:
                logit_warper = LogitsProcessorList([TopKLogitsWarper(top_k)])
            if top_p > 0:
                logit_warper = LogitsProcessorList([TopPLogitsWarper(top_p)])
|
if stopping_criteria is None: |
|
stopping_criteria = [MaxLengthCriteria(max_new_tokens)] |
|
|
|
stopping_criteria = StoppingCriteriaList( |
|
stopping_criteria |
|
) |
|
|
|
out = input_ids |
|
|
|
vision_outputs = self.vision_model( |
|
pixel_values=pixel_values |
|
) |
|
image_embeds = vision_outputs.last_hidden_state |
|
with torch.no_grad(): |
|
while True: |
|
|
|
x = out[:, -max_new_tokens:] |
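                # No KV cache: each step re-runs the decoder on the current prefix
                # (truncated to the last `max_new_tokens` positions) together with the
                # precomputed image embeddings.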
|
|
|
captioning_output = self._captioning_forward( |
|
input_ids=x, |
|
pixel_values=pixel_values, |
|
image_embeds=image_embeds, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=False, |
|
output_hidden_states=False, |
|
interpolate_pos_encoding=False, |
|
output_tokens=True, |
|
) |
|
|
|
|
|
output_ids = captioning_output.output_ids |
|
|
|
|
|
logits = output_ids[:, -1] |
|
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == self.pad_token_id) |
|
|
|
|
|
logits = logits[~mask, :] |
|
|
|
filtered_logits = logit_processor(x[~mask, :], logits) |
|
filtered_logits = logit_warper(x[~mask, :], filtered_logits) |
|
|
|
|
|
|
|
cur_len = out.shape[1] |
|
|
|
if cur_len >= max_new_tokens: |
|
next_token = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id |
|
elif do_sample: |
|
probs = F.softmax(filtered_logits / temperature, dim=-1) |
|
next_token = torch.multinomial(probs, num_samples=1) |
|
else: |
|
next_token = torch.argmax(filtered_logits, dim=-1, keepdim=True) |
|
|
|
if mask.all(): |
|
break |
|
|
|
|
|
                if (out.shape[1] >= max_new_tokens) or (next_token == eos_token_id).all():
                    break

                # Keep the batch aligned: sequences that already finished receive the pad token.
                next_column = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * self.pad_token_id
                next_column[~mask] = next_token
                out = torch.cat([out, next_column], dim=1)
|
|
|
|
|
output_ids = out.long() if out.dtype != torch.long else out |
|
|
|
|
|
return GenerateDecoderOnlyOutput( |
|
logits=logits, |
|
sequences=output_ids, |
|
) |
|
|
|
AutoModel.register(MammutConfig, MammutModel) |
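

# Minimal usage sketch (illustrative only; assumes a valid `MammutConfig`, tokenized text
# `token_ids`, and a preprocessed image batch `images`):
#
#   config = MammutConfig()
#   model = MammutModel(config).eval()
#   out = model(input_ids=token_ids, pixel_values=images)            # contrastive + captioning outputs
#   gen = model.generate(pixel_values=images, max_new_tokens=20)     # greedy caption generation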
|
|