# coding=utf-8
# Copyright 2024 Google AI, LAION team. All rights reserved.
#
# This code is based on open_clip framework. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to the original MaMMUT model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MaMMUT model."""
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from .configuration_mammut import MammutTextConfig, MammutVisionConfig, MammutConfig
from transformers.models.clip.modeling_clip import (
CLIPAttention,
CLIPMLP,
CLIPEncoderLayer,
CLIPTextModel,
CLIPVisionModel,
CLIPVisionModelOutput,
CLIPVisionTransformer,
CLIPTextModelOutput,
CLIPOutput,
CLIPModel,
CLIPPreTrainedModel,
CLIPVisionEmbeddings,
CLIPEncoder,
eager_attention_forward
) # noqa: E501
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
from transformers.generation import GenerateDecoderOnlyOutput
from dataclasses import dataclass
from transformers import AutoModel
import logging
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers import (
BeamSearchScorer,
LogitsProcessorList,
TopPLogitsWarper,
TopKLogitsWarper,
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StoppingCriteriaList
)
log = logging.getLogger(__name__)
class MammutCrossAttnLayer(nn.Module):
def __init__(self, config: MammutTextConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MammutAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.layer_norm1_kv = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
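        # When cross-attending, the keys/values come from the image tokens and get their own pre-LayerNorm.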
if k_x is not None and v_x is not None:
k_x = self.layer_norm1_kv(k_x)
v_x = self.layer_norm1_kv(v_x)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
keys=k_x,
values=v_x,
)
        hidden_states = hidden_states.permute(1, 0, 2)  # back to (batch_size, seq_length, embed_dim)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class MammutAttention(CLIPAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Union[MammutTextConfig, MammutVisionConfig]):
super().__init__(config)
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
# self.scale = 1
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.training = False  # default to eval-mode dropout; torch.nn.Module.train()/eval() toggles this flag
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
keys: Optional[torch.Tensor] = None,
values: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length, embed_dim = hidden_states.shape
if keys is None and values is None:
keys = hidden_states
values = hidden_states
#TODO: CLIP attention interface
# keys = self.k_proj(keys)
# values = self.v_proj(values)
# queries = queries.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# keys = keys.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# values = values.view(batch_size, seq_length, -1, self.head_dim).transpose(1, 2)
# CLIP text model uses both `causal_attention_mask` and `attention_mask`
# in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
# if self.config._attn_implementation == "flash_attention_2":
# self.is_causal = causal_attention_mask is not None
# else:
# if attention_mask is not None and causal_attention_mask is not None:
# attention_mask = attention_mask + causal_attention_mask
# elif causal_attention_mask is not None:
# attention_mask = causal_attention_mask
# attention_interface: Callable = eager_attention_forward
# if self.config._attn_implementation != "eager":
# attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
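        # The q/k/v projections, scaling, softmax and output projection are all delegated to PyTorch's
        # fused helper below; it expects and returns (seq_length, batch_size, embed_dim), which is why
        # callers permute the result back to batch-first.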
attn_output, attn_weights = F.multi_head_attention_forward(
query=hidden_states.permute(1, 0, 2), # (seq_length, batch_size, embed_dim)
key=keys.permute(1, 0, 2) if keys is not None else hidden_states.permute(1, 0, 2),
value=values.permute(1, 0, 2) if values is not None else hidden_states.permute(1, 0, 2),
embed_dim_to_check=embed_dim,
num_heads=self.num_heads,
in_proj_weight=torch.cat(
[self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
),
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias], dim=0
) if self.q_proj.bias is not None else None,
bias_k=None,
bias_v=None,
add_zero_attn=False,
attn_mask=attention_mask,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
is_causal=self.is_causal,
dropout_p=0.0 if not self.training else self.dropout,
out_proj_weight=self.out_proj.weight,
out_proj_bias=self.out_proj.bias,
training=self.training, # Use the training flag to control dropout
)
# attn_output, attn_weights = attention_interface(
# self,
# queries, # (seq_length, batch_size, embed_dim)
# keys,
# values,
# attention_mask,
# is_causal=self.is_causal,
# scaling=self.scale,
# dropout=0.0 if not self.training else self.dropout,
# output_attentions=output_attentions,
# )
# attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
# attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class MammutEncoderLayer(CLIPEncoderLayer):
def __init__(self, config: MammutTextConfig, has_mlp: bool = True):
super().__init__(config)
self.embed_dim = config.hidden_size
self.self_attn = MammutAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = CLIPMLP(config) if has_mlp else None
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Forward pass for the encoder layer.
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
causal_attention_mask (`torch.FloatTensor`, *optional*): causal attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=None,
output_attentions=output_attentions,
)
        hidden_states = hidden_states.permute(1, 0, 2)  # back to (batch_size, seq_length, embed_dim)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states) if self.mlp is not None else hidden_states
hidden_states = residual + hidden_states
return hidden_states
class MammutMultimodalEncoder(nn.Module):
does_full_decoding: torch.jit.Final[bool]
def __init__(
self,
        config: MammutTextConfig,
):
super().__init__()
self.config = config
self.n_cross_attn, _ = divmod(config.num_hidden_layers, config.cross_attn_ratio)
self.cross_step, _ = divmod(config.num_hidden_layers, self.n_cross_attn)
self.does_full_decoding = config.does_full_decoding
self.output_tokens = config.output_tokens
self.batch_first = config.batch_first
self.context_length = config.max_position_embeddings
self.layers = nn.ModuleList([])
self.cross_attn = nn.ModuleList([])
num_cross_attn = 0
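        # Interleave cross-attention: a MammutCrossAttnLayer is inserted after every layer whose index
        # is a multiple of `cross_step`, i.e. roughly one cross-attention block per `cross_attn_ratio`
        # self-attention layers.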
for l_idx in range(config.num_hidden_layers):
_, r = divmod(l_idx, self.cross_step)
has_cross_attn = r == 0
layer = MammutEncoderLayer(config)
self.layers.append(layer)
if has_cross_attn:
num_cross_attn += 1
cross_attn_layer = MammutCrossAttnLayer(config)
self.cross_attn.append(cross_attn_layer)
def forward(
self,
text_embeds: torch.Tensor,
img_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Union[BaseModelOutput, Tuple[torch.Tensor]]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = text_embeds
seq_len = hidden_states.shape[1] if self.batch_first else hidden_states.shape[0]
        if causal_attention_mask is None:
            causal_attention_mask = self.build_causal_mask().to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
        else:
            causal_attention_mask = causal_attention_mask.to(dtype=hidden_states.dtype)
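        # Masks are additive here (0 = attend, -inf = blocked), so padding and causal constraints
        # combine by summation; any caller-provided `attention_mask` is assumed to already be additive.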
if attention_mask is None:
attention_mask = causal_attention_mask
else:
attention_mask = attention_mask + causal_attention_mask
if img_embeds is not None:
img_embeds = img_embeds.to(dtype=hidden_states.dtype)
k_x = img_embeds
v_x = img_embeds
else:
k_x = None
v_x = None
if img_embeds is not None:
attention_mask = attention_mask[:seq_len, :seq_len]
for i, layer in enumerate(self.layers):
cross_attn_idx, r = divmod(i, self.cross_step)
has_cross_attn = r == 0 and img_embeds is not None
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask if img_embeds is not None else None,
causal_attention_mask=None,
output_attentions=output_attentions,
)
if has_cross_attn:
cross_attn = self.cross_attn[cross_attn_idx]
hidden_states = cross_attn(
hidden_states=hidden_states,
k_x=k_x,
v_x=v_x,
# attention_mask=attention_mask,
# causal_attention_mask=causal_attention_mask,
)
if output_hidden_states:
encoder_states = tuple(encoder_states)
if self.does_full_decoding:
encoder_states = encoder_states[:self.n_cross_attn + 1]
else:
                encoder_states = encoder_states[:self.config.num_hidden_layers]
else:
encoder_states = None
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
def build_causal_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
def build_attn_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@dataclass
class MammutPoolingOutput(BaseModelOutputWithPooling):
"""
Base class for outputs of the Mammut model.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
output_ids: Optional[torch.Tensor] = None
pooler_output: Optional[torch.FloatTensor] = None
class MammutMultimodalEmbeddings(nn.Module):
def __init__(self, config: MammutTextConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, config.hidden_size
)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
max_position_embedding = self.position_embedding.weight.shape[0]
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must not exceed max_position_embeddings (got `sequence_length`: "
                f"{seq_length} and `max_position_embeddings`: {max_position_embedding})."
            )
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'):
if pool_type == 'first':
pooled, tokens = x[:, 0], x[:, 1:]
elif pool_type == 'last':
pooled, tokens = x[:, -1], x[:, :-1]
elif pool_type == 'argmax':
# take features from the eot embedding (eot_token is the highest number in each sequence)
assert text is not None
pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
else:
pooled = tokens = x
return pooled, tokens
class MammutMultimodalTransformer(nn.Module):
def __init__(self, config: MammutTextConfig, output_tokens=True):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.encoder = MammutMultimodalEncoder(config)
self.text_projection = nn.Linear(
config.hidden_size, config.vocab_size, bias=False
) if config.hidden_size is not None else None
self.final_layer_norm = nn.LayerNorm(
embed_dim, eps=config.layer_norm_eps
)
# self.init_weights()
self.does_full_decoding = config.does_full_decoding
self.context_length = config.context_length
self.vocab_size = config.vocab_size
self.batch_first = config.batch_first
self.has_mlp = config.has_mlp
self.cross_attn_ratio = config.cross_attn_ratio
self.cross_step = config.cross_attn_ratio
self.n_cross_attn = config.num_hidden_layers // config.cross_attn_ratio
self.output_tokens = output_tokens
if self.does_full_decoding:
self.num_pos = self.context_length
self.embeddings = MammutMultimodalEmbeddings(config)
else:
self.num_pos = None
self.embeddings = None
def init_weights(self):
self.final_layer_norm.weight.data.fill_(1.0)
self.final_layer_norm.bias.data.zero_()
log.info("MammutMultimodalTransformer weights initialized.")
def forward(
self,
img_embs: torch.Tensor,
text_embs: Optional[torch.Tensor] = None,
output_tokens: Optional[bool] = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
    ) -> MammutPoolingOutput:
if text_embs is not None:
if self.embeddings is not None:
# print("text_embs shape:", text_embs.shape)
text_embs = self.embeddings(
input_ids=text_embs,
position_ids=position_ids,
# inputs_embeds=img_embs if img_embs is not None else None,
)
if self.does_full_decoding:
text_embs = text_embs[:, :self.context_length, :]
text_embs = self.encoder(
text_embeds=text_embs,
img_embeds=img_embs,
attention_mask=None,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_embs = text_embs.last_hidden_state
if self.does_full_decoding:
text_embs = text_embs[:, :self.context_length, :]
else:
text_embs = text_embs[:, 0, :]
if self.text_projection is not None:
output_ids = self.text_projection(text_embs)
else:
output_ids = text_embs
if output_tokens:
return MammutPoolingOutput(
last_hidden_state=text_embs, # Last hidden state is the text embeddings
hidden_states=None, # No hidden states in this implementation
attentions=None, # No attentions in this implementation
                output_ids=output_ids,  # projected vocabulary logits (or hidden states if no text projection)
pooler_output=text_embs, # Pooler output is the text embeddings
)
return MammutPoolingOutput(
last_hidden_state=text_embs, # Last hidden state is the text embeddings
pooler_output=text_embs,
hidden_states=None, # No hidden states in this implementation
attentions=None, # No attentions in this implementation
)
def build_causal_mask(self, seq_len: Optional[int] = None, device: Optional[torch.device] = None) -> torch.Tensor:
if seq_len is None:
seq_len = self.context_length if self.does_full_decoding else self.config.context_length
if device is None:
device = torch.device("cpu")
mask = torch.tril(torch.ones((seq_len, seq_len), device=device)).view(1, 1, seq_len, seq_len)
return mask
def build_attn_mask(self):
# lazily create causal attention mask, with full attention between the tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
class MammutMultimodalModel(CLIPTextModel):
"""
Mammut multimodal model with text and vision encoders.
"""
config_class = MammutTextConfig
base_model_prefix = "mammut_multimodal"
def __init__(self, config: MammutTextConfig):
super().__init__(config)
self.config = config.text_config
self.text_model = MammutMultimodalTransformer(config.text_config)
self.text_embed_dim = config.hidden_size
self.vision_embed_dim = config.vision_config.hidden_size
self.projection_dim = config.projection_dim
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_embs: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_tokens: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[MammutPoolingOutput, CLIPTextModelOutput]:
return self.text_model(
img_embs=image_embs,
text_embs=input_ids,
output_tokens=output_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
position_ids=position_ids,
)
class MammutVisionTransformer(CLIPVisionTransformer):
"""
Mammut Vision Transformer model.
Inherits from CLIPVisionTransformer and initializes the vision model.
"""
config_class = MammutVisionConfig
base_model_prefix = "mammut_vision"
def __init__(self, config: MammutVisionConfig):
super().__init__(config)
self.config = config
embed_dim = config.hidden_size
self.embeddings = CLIPVisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.pool_type = config.pool_type
def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.pool_type == 'avg':
pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]
elif self.pool_type == 'tok':
pooled, tokens = x[:, 0], x[:, 1:]
elif self.pool_type == "avg_all":
pooled, tokens = x.mean(dim=1), x
else:
pooled = tokens = x
return pooled, tokens
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
        last_hidden_state = encoder_outputs.last_hidden_state
        if self.config.final_ln_after_pool:
            pooled, _ = self._global_pool(last_hidden_state)
            pooled_output = self.post_layernorm(pooled)
        else:
            pooled_output = self.post_layernorm(last_hidden_state)
            pooled, _ = self._global_pool(pooled_output)
            pooled_output = pooled
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class MammutVisionModel(CLIPVisionModel):
"""
Mammut Vision Model.
Inherits from CLIPVisionModel and initializes the vision model.
"""
config_class = MammutVisionConfig
base_model_prefix = "mammut_vision"
def __init__(self, config: MammutVisionConfig):
super().__init__(config)
self.config = config
self.vision_model = MammutVisionTransformer(config)
self.post_init()
@dataclass
class MammutContrastiveOutput(CLIPOutput):
"""
Output class for Mammut model in contrastive learning mode.
Contains contrastive output:
- loss: Loss value if return_loss is True.
- logits_per_text: Logits for text inputs.
- logits_per_image: Logits for image inputs.
- text_embeds: Text embeddings.
- image_embeds: Image embeddings.
"""
loss: Optional[torch.FloatTensor] = None
logits_per_text: Optional[torch.FloatTensor] = None
logits_per_image: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
@dataclass
class MammutCaptioningOutput(ModelOutput):
"""
Output class for Mammut captioning part.
Contains:
- last_hidden_state: Last hidden state of the text model.
- pooler_output: Pooler output of the text model.
- hidden_states: Hidden states from the text model.
- attentions: Attention weights from the text model.
- output_ids: Output tokens from the text model.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
output_ids: Optional[torch.Tensor] = None
@dataclass
class MammutOutput(ModelOutput):
"""
Output class for Mammut model.
Contains contrastive output:
- loss: Loss value if return_loss is True.
- logits_per_text: Logits for text inputs.
- logits_per_image: Logits for image inputs.
- text_embeds: Text embeddings.
- image_embeds: Image embeddings.
Captioning output:
- text_model_output: Output from the text model.
- output_ids: Output tokens from the text model.
"""
loss: Optional[torch.FloatTensor] = None
logits_per_text: Optional[torch.FloatTensor] = None
logits_per_image: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: Optional[MammutCaptioningOutput] = None
output_ids: Optional[torch.Tensor] = None
# @dataclass
# class MammutGenerationOutput(GenerateDecoderOnlyOutput)
def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
"""
This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
"""
square_tensor = torch.pow(tensor, 2)
sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
normed_tensor = torch.pow(sum_tensor, 0.5)
return normed_tensor
class MammutModel(CLIPPreTrainedModel):
"""
Mammut model with text and vision encoders.
"""
config_class = MammutConfig
base_model_prefix = "mammut"
def __init__(self, config: MammutConfig):
super().__init__(config)
self.config = config
self.text_model = MammutMultimodalTransformer(config.text_config, output_tokens=config.output_tokens)
vision_model = MammutVisionModel._from_config(config.vision_config)
self.vision_model = vision_model.vision_model
self.text_embed_dim = config.text_config.hidden_size
self.vision_embed_dim = config.vision_config.hidden_size
self.projection_dim = config.projection_dim
self.text_projection = self.text_model.text_projection
self.visual_projection = nn.Linear(
self.vision_embed_dim, self.projection_dim, bias=False
) if self.projection_dim is not None else None
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
self.map_viz2txt_kv = nn.Parameter(torch.randn(
self.config.vision_config.width, self.config.text_config.width
))
self.eos_token_id = self.config.text_config.eos_token_id
self.bos_token_id = self.config.text_config.bos_token_id
self.pad_token_id = self.config.text_config.pad_token_id
self.does_full_decoding = config.text_config.does_full_decoding
self.context_length = config.text_config.context_length
self.vocab_size = config.text_config.vocab_size
self.batch_first = config.text_config.batch_first
# Initialize weights and apply final processing
self.post_init()
def get_text_features(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
img_embs: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Get text features from the Mammut model.
"""
text_model_output = self.text_model(
img_embs=img_embs,
text_embs=input_ids,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_embeds = text_model_output.last_hidden_state
text_embeds = self.text_model.final_layer_norm(text_embeds)
text_embeds = text_embeds.mean(1)
text_embeds = F.normalize(text_embeds, dim=-1)
return text_embeds
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
normalize: bool = True,
) -> torch.FloatTensor:
"""
Get image features from the Mammut model.
"""
vision_outputs: CLIPVisionModelOutput = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
image_embeds = vision_outputs.pooler_output
if self.visual_projection is not None:
image_embeds = self.visual_projection(image_embeds)
image_embeds = F.normalize(image_embeds, dim=-1) if normalize else image_embeds
return image_embeds
def _contrastive_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
output_tokens: Optional[bool] = None,
contrastive: Optional[bool] = False,
) -> MammutContrastiveOutput:
"""
Forward pass for the Mammut model in contrastive learning mode.
- **Two-pass learning:** to unify contrastive and next-token
prediction, we need to unify unconditional representation learning and token-conditioned next-token prediction objective.
- **First pass: contrastive task.** For the first pass, text features should not see image features (dual-encoder contrastive learner) but attend to all tokens at once to produce sequence-level representation. Cross-attention and causal masking is disabled.
- **Second pass: captioning task.** Using cross attention and causal masking learn caption generation task.
Return:
MammutContrastiveOutput: Contains contrastive output with logits, embeddings, and optional loss.
"""
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: CLIPVisionModelOutput = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
)
# text_model is MammutMultimodalTransformer, which handles text embeddings
text_outputs: MammutPoolingOutput = self.text_model(
img_embs=None, # No image embeddings in contrastive forward pass for text model
text_embs=input_ids,
output_tokens=output_tokens,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
position_ids=position_ids,
)
image_embeds = vision_outputs.pooler_output
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs.pooler_output
pooled, tokens = text_global_pool(text_embeds, text=input_ids)
text_embeds = self.text_model.final_layer_norm(text_embeds)
text_embeds = text_embeds.mean(1)
tokens = self.text_projection(pooled)
# Normalize the embeddings
image_embeds = image_embeds / _get_vector_norm(image_embeds)
text_embeds = text_embeds / _get_vector_norm(text_embeds)
# cosine similarity as logits
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
logits_per_text = logits_per_text * self.logit_scale.exp().to(text_embeds.device)
logits_per_image = logits_per_text.t()
loss = None
return MammutContrastiveOutput(
loss=loss,
logits_per_text=logits_per_text,
logits_per_image=logits_per_image,
text_embeds=text_embeds,
image_embeds=image_embeds,
)
def _captioning_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
output_tokens: Optional[bool] = None,
) -> MammutCaptioningOutput:
"""
Forward pass for the Mammut model in captioning mode.
Return:
MammutCaptioningOutput: Contains captioning output with last hidden state, pooler output, hidden states, attentions, and output tokens.
"""
if pixel_values is None:
raise ValueError("Pixel values must be provided for captioning.")
if input_ids is None:
input_ids = torch.ones(
(pixel_values.shape[0], self.context_length), dtype=torch.long, device=pixel_values.device
) * self.bos_token_id
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if image_embeds is None:
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
)
image_embeds = vision_outputs.last_hidden_state
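        # Project image token embeddings from the vision width to the text width so they can serve
        # as cross-attention keys/values in the multimodal text decoder.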
image_embeds = image_embeds @ self.map_viz2txt_kv
text_model_output = self.text_model(
img_embs=image_embeds, # Use image embeddings for captioning
text_embs=input_ids,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_embeds = text_model_output.last_hidden_state
text_embeds = self.text_model.final_layer_norm(text_embeds)
logits = self.text_projection(text_embeds)
if output_tokens:
return MammutCaptioningOutput(
last_hidden_state=text_embeds,
pooler_output=image_embeds, # Placeholder for pooler output
                output_ids=logits,  # per-position vocabulary logits from the text projection
)
return MammutCaptioningOutput(
last_hidden_state=text_embeds,
pooler_output=image_embeds, # Placeholder for pooler output
output_ids=None, # No output tokens in this case
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
output_tokens: Optional[bool] = False,
contrastive_only: Optional[bool] = False,
captioning_only: Optional[bool] = False,
) -> MammutOutput:
"""
Forward pass for the Mammut model.
- **Two-pass learning:** to unify contrastive and next-token prediction, we need to unify unconditional representation learning and token-conditioned next-token prediction objective.
- **First pass: contrastive task.** For the first pass, text features should not see image features (dual-encoder contrastive learner) but attend to all tokens at once to produce sequence-level representation. Cross-attention and causal masking is disabled.
- **Second pass: captioning task.** Using cross attention and causal masking learn caption generation task.
"""
# first pass: contrastive task
# second pass: captioning task
        if pixel_values is None and input_ids is None:
            raise ValueError("Either `pixel_values` or `input_ids` must be provided.")
if output_tokens is None:
output_tokens = self.config.output_tokens
if output_tokens and not self.config.output_tokens:
raise ValueError("Output tokens are not enabled in the configuration.")
if output_tokens and pixel_values is None:
raise ValueError("Pixel values must be provided if output tokens are enabled.")
if output_tokens and input_ids is None:
# Only captioning
captioning_only = True
if input_ids is not None and pixel_values is not None:
contrastive_output = self._contrastive_forward(
input_ids=input_ids,
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
)
else:
contrastive_output = MammutContrastiveOutput(
loss=None,
logits_per_text=None,
logits_per_image=None,
text_embeds=None,
image_embeds=None,
)
if contrastive_only:
# If only contrastive output is needed, return it directly
return MammutOutput(
loss=contrastive_output.loss,
logits_per_text=contrastive_output.logits_per_text,
logits_per_image=contrastive_output.logits_per_image,
text_embeds=contrastive_output.text_embeds,
image_embeds=contrastive_output.image_embeds,
)
if captioning_only:
# If only captioning output is needed, return it directly
text_model_output = self._captioning_forward(
input_ids=input_ids,
                pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
output_tokens=output_tokens,
)
return MammutOutput(
loss=None, # No loss in captioning only mode
logits_per_text=None, # No logits in captioning only mode
logits_per_image=None, # No logits in captioning only mode
text_embeds=text_model_output.last_hidden_state, # Use last hidden state as text embeddings
image_embeds=None, # No image embeddings in captioning only mode
text_model_output=text_model_output, # Output from the text model
output_ids=text_model_output.output_ids, # Output tokens from the text model
)
# If both contrastive and captioning outputs are needed, return both
text_model_output = self._captioning_forward(
input_ids=input_ids,
            pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
output_tokens=output_tokens,
)
return MammutOutput(
loss=contrastive_output.loss,
logits_per_text=contrastive_output.logits_per_text,
logits_per_image=contrastive_output.logits_per_image,
text_embeds=contrastive_output.text_embeds,
image_embeds=contrastive_output.image_embeds,
text_model_output=text_model_output, # Output from the text model
output_ids=text_model_output.output_ids, # Output tokens from the text model
)
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
max_new_tokens: int = 20,
do_sample: bool = False,
temperature: float = 1.0,
repetition_penalty: float = 1.0,
top_p: float = 0,
top_k: int = 0,
min_seq_len: int = 1,
stopping_criteria= None,
) -> GenerateDecoderOnlyOutput:
"""
Generate captions using the Mammut model.
Args:
input_ids (torch.LongTensor, optional): Input token IDs for the text model.
pixel_values (torch.FloatTensor, optional): Pixel values for the vision model.
attention_mask (torch.Tensor, optional): Attention mask for the text model.
position_ids (torch.LongTensor, optional): Position IDs for the text model.
max_new_tokens (int): Maximum length of the generated sequence.
do_sample (bool): Whether to sample from the distribution or take argmax.
temperature (float): Temperature for sampling.
repetition_penalty (float): Penalty for repetition in sampling.
top_p (float): Top-p sampling parameter.
top_k (int): Top-k sampling parameter.
min_seq_len (int): Minimum sequence length for generation.
stopping_criteria: Stopping criteria for generation.
Returns:
GenerateDecoderOnlyOutput: Contains the generated sequences and logits.
"""
if input_ids is None and pixel_values is None:
raise ValueError("Input IDs or pixel values must be provided for generation.")
if input_ids is None:
input_ids = torch.ones(
(pixel_values.shape[0], 1), dtype=torch.long, device=pixel_values.device
) * self.bos_token_id
if pixel_values is None:
raise ValueError("Pixel values must be provided for generation.")
self.eval()
device = pixel_values.device if pixel_values is not None else input_ids.device
eos_token_id = self.eos_token_id if self.eos_token_id is not None else self.text_model.config.eos_token_id
logit_processor = LogitsProcessorList(
[
MinLengthLogitsProcessor(min_seq_len, eos_token_id),
RepetitionPenaltyLogitsProcessor(repetition_penalty),
]
)
        # Warpers only modify the distribution when sampling; default to an empty list so the call
        # below is also well-defined for greedy decoding.
        logit_warper = LogitsProcessorList([])
        if do_sample:
            if top_k > 0:
                logit_warper.append(TopKLogitsWarper(top_k))
            if top_p > 0:
                logit_warper.append(TopPLogitsWarper(top_p))
if stopping_criteria is None:
stopping_criteria = [MaxLengthCriteria(max_new_tokens)]
stopping_criteria = StoppingCriteriaList(
stopping_criteria
)
out = input_ids
vision_outputs = self.vision_model(
pixel_values=pixel_values
)
image_embeds = vision_outputs.last_hidden_state
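        # The image is encoded once and its token embeddings are reused at every decoding step;
        # the text decoder re-processes the full prefix each step (no key/value caching).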
with torch.no_grad():
while True:
x = out[:, -max_new_tokens:]
# Get text features
captioning_output = self._captioning_forward(
input_ids=x,
pixel_values=pixel_values,
image_embeds=image_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=False,
output_hidden_states=False,
interpolate_pos_encoding=False,
output_tokens=True, # We want the output tokens
)
output_ids = captioning_output.output_ids
# Get logits for the next token
logits = output_ids[:, -1]
mask = (out[:, -1] == eos_token_id) | (out[:, -1] == self.pad_token_id)
logits = logits[~mask, :]
filtered_logits = logit_processor(x[~mask, :], logits)
filtered_logits = logit_warper(x[~mask, :], filtered_logits)
# Sample or take the argmax of the logits
cur_len = out.shape[1]
if cur_len >= max_new_tokens:
next_token = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
elif do_sample:
probs = F.softmax(filtered_logits / temperature, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(filtered_logits, dim=-1, keepdim=True)
if mask.all():
break
# Check if we have reached the end of the sequence or max length
if (out.shape[1] >= max_new_tokens) or (next_token == eos_token_id).all():
break
                # Append the next token; finished sequences receive the pad token so batch shapes stay aligned
                next_tokens = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * self.pad_token_id
                next_tokens[~mask] = next_token
                out = torch.cat([out, next_tokens], dim=1)
output_ids = out.long() if out.dtype != torch.long else out
return GenerateDecoderOnlyOutput(
logits=logits,
sequences=output_ids, # Output tokens from the text model
)
AutoModel.register(MammutConfig, MammutModel)
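
# Minimal usage sketch (runs only when this file is executed directly). It illustrates the intended
# inference flow under stated assumptions: "path/to/mammut-checkpoint" is a placeholder rather than a
# real repository id, "example.jpg" is any local RGB image, and a CLIP-style processor is assumed to
# be resolvable for the checkpoint via AutoProcessor.
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoProcessor  # assumed to resolve to a compatible image/text processor

    checkpoint = "path/to/mammut-checkpoint"  # placeholder, not a real repository id
    model = MammutModel.from_pretrained(checkpoint)
    processor = AutoProcessor.from_pretrained(checkpoint)

    image = Image.open("example.jpg")  # any RGB image on disk
    inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)

    # First pass (contrastive) plus second pass (captioning) in a single forward call.
    outputs = model(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"])
    print(outputs.logits_per_image)

    # Greedy caption generation conditioned on the image alone.
    generated = model.generate(pixel_values=inputs["pixel_values"], max_new_tokens=20)
    print(generated.sequences)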