import math
from copy import deepcopy
from functools import partial
from collections.abc import Callable
from typing import Union, Tuple, Sequence, Optional, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from transformers.activations import ACT2FN, PytorchGELUTanh
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.utils import is_flash_attn_2_available, logging
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from transformers.generation import GenerationMixin

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
else:
    flash_attn_varlen_func = None

from .configuration_llava_uhd_v3 import LlavaUHDV3Config, LlavaUHDV3VisionConfig, LlavaUHDV3TextConfig

logger = logging.get_logger(__name__)


##### MoonViT part #####
def multihead_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
):
    """Multi-head attention using flash attention 2.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
        q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
            The first element should be 0 and the last element should be q.shape[0].
        k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
            The first element should be 0 and the last element should be k.shape[0].

    Returns:
        output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
            where dim = num_heads * head_dim
    """
    # Unified format legal check
    assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
    assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
    assert (
        k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
    ), "k_cu_seqlens must sum to k.shape[0]"
    assert q.dtype in [
        torch.bfloat16,
        torch.float16,
    ], f"unsupported dtype {q.dtype} for multihead attn"

    max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
    max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()

    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        q_cu_seqlens,
        k_cu_seqlens,
        max_seqlen_q,
        max_seqlen_k,
        causal=False,
    )
    attn_out = attn_out.flatten(start_dim=-2)
    return attn_out


def sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    q_cu_seqlens: Optional[torch.Tensor] = None,
    k_cu_seqlens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SDPA attention.

    Args:
        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
            or (tot_seqlens, num_heads, head_dim) if packing.
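
    Example (illustrative only; small random tensors packing two sequences of
    length 6 and 4, with 4 heads of dimension 16):
        >>> q = torch.randn(10, 4, 16)
        >>> k = torch.randn(10, 4, 16)
        >>> v = torch.randn(10, 4, 16)
        >>> cu_seqlens = torch.tensor([0, 6, 10])
        >>> sdpa_attention(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens).shape
        torch.Size([10, 64])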
""" seq_length = q.shape[0] attention_mask = torch.zeros( [1, seq_length, seq_length], device=q.device, dtype=torch.bool ) for i in range(1, len(q_cu_seqlens)): attention_mask[ ..., q_cu_seqlens[i - 1] : q_cu_seqlens[i], q_cu_seqlens[i - 1] : q_cu_seqlens[i], ] = True q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) return attn_output def eager_attention( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_cu_seqlens: Optional[torch.Tensor] = None, k_cu_seqlens: Optional[torch.Tensor] = None, ) -> torch.Tensor: seq_length = q.shape[0] attention_mask = torch.zeros( [1, seq_length, seq_length], device=q.device, dtype=torch.bool ) for i in range(1, len(q_cu_seqlens)): attention_mask[ ..., q_cu_seqlens[i - 1] : q_cu_seqlens[i], q_cu_seqlens[i - 1] : q_cu_seqlens[i], ] = True q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]) attn_weight += attention_mask attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype) attn_output = attn_weight @ v attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) return attn_output VL_VISION_ATTENTION_FUNCTIONS = { "flash_attention_2": multihead_attention, "sdpa": sdpa_attention, "eager": eager_attention, } def _apply_rope_input_validation(x, freqs_cis): assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape) assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape) assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape) assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype def apply_rope( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: """ Args: (The leading dimensions of all inputs should be the same) xq: query, tensor of shape (..., num_heads, head_dim) xk: key, tensor of shape (..., num_heads, head_dim) freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid. 
Returns: xq_out, xk_out: tensors of shape (..., num_heads, head_dim) """ _apply_rope_input_validation(xq, freqs_cis) _apply_rope_input_validation(xk, freqs_cis) freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2 # ..., num_heads, head_dim/2 xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2)) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim return xq_out.type_as(xq), xk_out.type_as(xk) class Learnable2DInterpPosEmb(nn.Module): def __init__( self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic" ) -> None: super().__init__() self.height = height self.width = width self.interpolation_mode = interpolation_mode self.weight = nn.Parameter(torch.empty(height, width, dim)) self.reset_parameters() def reset_parameters(self): nn.init.normal_(self.weight) def forward(self, x, grid_hws) -> torch.Tensor: pos_embs = [] for shape in grid_hws.tolist(): shape = [int(i) for i in shape] if shape == self.weight.shape[:-1]: pos_embs.append(self.weight.flatten(end_dim=1)) else: pos_embs.append( F.interpolate( self.weight.permute((2, 0, 1)).unsqueeze(0), size=shape, mode=self.interpolation_mode, ) .squeeze(0) .permute((1, 2, 0)) .flatten(end_dim=1) ) out = x + torch.cat(pos_embs) return out class MoonVisionPatchEmbed(nn.Module): def __init__( self, out_dim: int, in_dim: int = 3, patch_size: Union[int, Tuple[int, int]] = (14, 14), pos_emb_height: int = 14, pos_emb_width: int = 14, ): super().__init__() assert isinstance( patch_size, (int, Sequence) ), f"Invalid patch_size type: {type(patch_size)}" if isinstance(patch_size, int): patch_size = (patch_size, patch_size) assert ( len(patch_size) == 2 ), f"Expected patch_size to be a tuple of 2, got {patch_size}" self.patch_size = patch_size self.proj = nn.Conv2d( in_dim, out_dim, kernel_size=patch_size, stride=patch_size ) self.pos_emb = Learnable2DInterpPosEmb( height=pos_emb_height, width=pos_emb_width, dim=out_dim ) def forward(self, x, grid_hws) -> torch.Tensor: """ Args: x (L, Channels): input tensor grid_hws (N, 2): grid height and width Returns: (L, Cout) tensor """ x = self.proj(x).view(x.size(0), -1) # apply positional embedding x = self.pos_emb(x, grid_hws) return x class Rope2DPosEmb(nn.Module): """2D rotary position embedding with multi-resolution support. This class is intended to be used in the following way: 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis. 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration. 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation. The rope is shared across all attention layers and all heads. 
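
    Example (illustrative only; the head dimension and grid sizes below are made up):
        rope = Rope2DPosEmb(dim=64, max_height=512, max_width=512)
        grid_hws = torch.tensor([[4, 6], [3, 5]])  # two images, (height, width) in patches
        freqs_cis = rope.get_freqs_cis(grid_hws)   # complex64, shape (4*6 + 3*5, 64 // 2) = (39, 32)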
Refs: - RoFormer: https://arxiv.org/abs/2104.09864 - VisionLLaMA: https://arxiv.org/abs/2403.00522 - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py Args: dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed) max_height (int): the maximum height of the 2D grid max_width (int): the maximum width of the 2D grid theta_base (float): the base of the theta device (str): the device to store the precomputed cis """ def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000): super().__init__() self.dim = dim assert self.dim % 4 == 0, "dim must be divisible by 4" self.max_height = max_height self.max_width = max_width self.theta_base = theta_base self.freqs_cis = None def extra_repr(self): return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}" def _precompute_freqs_cis(self, down_scale_rate, device: torch.device) -> torch.Tensor: """Calculate the cis(freqs) for each position in the 2D grid. Return: complex tensor of shape (max_height, max_width, dim//2) and value: height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim)) weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4)) note: `cis` is a mathematical notation defined by cis x = cos x + i sin x, """ max_height = self.max_height // down_scale_rate max_width = self.max_width // down_scale_rate N = max_height * max_width flat_pos = torch.arange(0, N).float().to(device) x_pos = flat_pos % max_width y_pos = flat_pos // max_width dim_range = ( torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device) ) # C/4 freqs = 1.0 / (self.theta_base ** (dim_range / self.dim)) x_freqs = torch.outer(x_pos, freqs).float() # N, C/4 y_freqs = torch.outer(y_pos, freqs).float() # N, C/4 x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4 y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4 # N, C/4, 2 freqs_cis = torch.cat( [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1 ) # max_height, max_width, C/2 freqs_cis = freqs_cis.reshape(max_height, max_width, -1) return freqs_cis def get_freqs_cis(self, grid_hws: torch.Tensor, down_scale_rate=1, init_freqs=False) -> torch.Tensor: """ Args: grid_hws (torch.Tensor): grid height and width Returns: freqs_cis: tensor of shape (sum(t * height * width), dim//2) """ max_height = self.max_height // down_scale_rate max_width = self.max_width // down_scale_rate if self.freqs_cis is None or init_freqs: self.freqs_cis = self._precompute_freqs_cis(down_scale_rate, grid_hws.device) shapes = grid_hws.tolist() assert all( 1 <= h <= max_height and 1 <= w <= max_width for h, w in shapes ), ( shapes, max_height, max_width, ) freqs_cis = torch.cat( [self.freqs_cis[:int(h), :int(w)].reshape(-1, self.dim // 2) for h, w in shapes], dim=0, ) return freqs_cis class MLP2(nn.Module): """ Args: dims: [in_dim, hidden_dim, out_dim] bias: whether to use bias in linear layer. 
""" def __init__(self, dims: list[int], activation, bias=True): super().__init__() assert len(dims) == 3 self.fc0 = nn.Linear(dims[0], dims[1], bias=bias) self.fc1 = nn.Linear(dims[1], dims[2], bias=bias) self.activation = activation for m in [self.fc0, self.fc1]: nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features)) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc0(x) x = self.activation(x) return self.fc1(x) class PatchMergingLayer(nn.Module): def __init__(self, embed_dim, enable_merging=True, merging_method="avg_pooling", norm_layer=nn.LayerNorm): """ :param embed_dim: Transformer token 的嵌入维度 :param enable_merging: 是否启用 token 合并功能 :param merging_method: 选择 'mlp' 或 'avg_pooling' 作为合并方式 """ super().__init__() self.enable_merging = enable_merging self.merging_method = merging_method self.zero_init_fc = nn.Linear(embed_dim, embed_dim, bias=False) if self.merging_method == 'avg_pooling': pass elif self.merging_method == 'm_pooling': self.attn_layer = nn.Sequential( nn.Linear(embed_dim * 2, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim) ) self.num_head = 16 def forward(self, x, cu_seqlens, spatial_shapes): if not self.enable_merging: return x, cu_seqlens cu_seqlens_out = cu_seqlens.clone() # (N+1, ) feature_x = x x_i_list = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i-1].item() end_idx = cu_seqlens[i].item() x_i = x[start_idx:end_idx, :] h, w = spatial_shapes[i-1] x_i = x_i.view(int(h), int(w), -1) # (h, w, embed_dim) if self.merging_method == 'avg_pooling': x_i = rearrange(x_i, 'h w c -> c h w') x_i = F.avg_pool2d(x_i, kernel_size=2, stride=2) x_i = rearrange(x_i, 'c h w -> (h w) c') elif self.merging_method == 'm_pooling': x_i = rearrange(x_i, '(h p1) (w p2) c -> (h w) (p1 p2) c', p1=2, p2=2) pooled_x_i = x_i.mean(-2, keepdim=True).expand(-1, 4, -1) fused_x_i = torch.cat([x_i, pooled_x_i], dim=-1) attn_logits = self.attn_layer(fused_x_i) # multi-head attn attn_logits = rearrange(attn_logits, 'n s (m d) -> n m s d', m=self.num_head) attn_weights = F.softmax(attn_logits, dim=-2) attn_weights = rearrange(attn_weights, 'n m s d -> n s (m d)') # multi-head attn x_i = (x_i * attn_weights).sum(-2) x_i_list.append(x_i) cu_seqlens_out[i] = cu_seqlens_out[i-1] + x_i.shape[0] x = torch.cat(x_i_list, dim=0) # (L, embed_dim) return x, cu_seqlens_out, spatial_shapes//2, feature_x class MoonVitEncoderLayer(nn.Module): def __init__( self, layer_idx: int, num_heads: int, hidden_dim: int, mlp_dim: int, *, attn_implementation: str = "eager", activation=F.gelu, attn_bias: bool = False, enable_merging: bool = False, merging_method: str = "avg_pooling", merger_layer_index: List[int] = None, ): super().__init__() self.num_heads = num_heads self.hidden_dim = hidden_dim self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads self.attn_implementation = attn_implementation self.norm0 = nn.LayerNorm(hidden_dim) self.norm1 = nn.LayerNorm(hidden_dim) self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation) self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias) self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias) if merger_layer_index is not None and layer_idx in merger_layer_index: self.merger = PatchMergingLayer( embed_dim=hidden_dim, enable_merging=enable_merging, merging_method=merging_method, ) else: self.merger = None def attention_qkvpacked( self, x: torch.Tensor, cu_seqlens: torch.Tensor, rope_freqs_cis: Optional[torch.Tensor] = None, ): """ Args: x (torch.Tensor): 
(batch_size, seqlen, hidden_dim) cu_seqlens (torch.Tensor): """ xqkv = self.wqkv(x) qkv_shape = xqkv.size()[:-1] + ( 3, self.num_heads, self.hidden_size_per_attention_head, ) # xqkv: (batch_size, seqlen, 3, nheads, headdim) xqkv = xqkv.view(*qkv_shape) xq, xk, xv = torch.unbind(xqkv, dim=-3) xq, xk = apply_rope(xq, xk, rope_freqs_cis) attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation] attn_out = attn_func( xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens ) attn_out = self.wo(attn_out) return attn_out def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rope_freqs_cis: Union[torch.Tensor, None] = None, spatial_shapes: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set Returns: output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input """ residual = hidden_states hidden_states = self.norm0(hidden_states) attn_out = self.attention_qkvpacked( hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis ) hidden_states = residual + attn_out residual = hidden_states hidden_states = self.mlp(self.norm1(hidden_states)) hidden_states = residual + hidden_states if self.merger is not None: hidden_states, cu_seqlens, spatial_shapes, feature_x = self.merger( hidden_states, cu_seqlens, spatial_shapes ) outputs = (hidden_states, cu_seqlens, spatial_shapes, feature_x)# return the feature_x for later use else: outputs = (hidden_states, cu_seqlens) return outputs class MoonVitEncoder(nn.Module): def __init__( self, hidden_dim: int, num_layers: int, block_cfg: dict, ) -> None: super().__init__() self.blocks = nn.ModuleList( [MoonVitEncoderLayer(layer_idx=i, **block_cfg) for i in range(num_layers)] ) self.final_layernorm = nn.LayerNorm(hidden_dim) self.rope_2d = Rope2DPosEmb( block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512 ) def forward( self, hidden_states: torch.Tensor, grid_hws: torch.Tensor ) -> torch.Tensor: rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws) lengths = torch.cat( ( torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype), grid_hws[:, 0] * grid_hws[:, 1], ) ) cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32) down_scale_rate = 1 feature_x_list = [] for _, block in enumerate(self.blocks): layer_outputs = block( hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis, spatial_shapes=grid_hws ) if len(layer_outputs) > 2: down_scale_rate *= 2 hidden_states, cu_seqlens, grid_hws, feature_x = layer_outputs rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws, down_scale_rate=down_scale_rate) feature_x_list.append(feature_x) else: hidden_states, cu_seqlens = layer_outputs hidden_states = self.final_layernorm(hidden_states) return hidden_states, grid_hws ##### Qwen2 part ##### class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : 
x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`, *optional*): Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Qwen2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, 
bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) sliding_window = None if ( self.config.use_sliding_window and getattr(self.config, "sliding_window", None) is not None and self.layer_idx >= self.config.max_window_layers ): sliding_window = self.config.sliding_window attention_interface: Callable = eager_attention_forward if self.config.attn_implementation != "eager": if self.config.attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
) else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config.attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=sliding_window, # main diff with Llama **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Qwen2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Qwen2DecoderLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) return outputs class Qwen2RotaryEmbedding(nn.Module): def __init__(self, config, device=None): super().__init__() # BC: "rope_type" was originally "type" if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @torch.no_grad() @dynamic_rope_update # power user: used with advanced RoPE types 
(e.g. dynamic rope) def forward(self, x, position_ids): inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ##### LlavaUHDV3 part ##### class Qwen2vlPatchMerger(nn.Module): def __init__( self, embed_dim, image_embed_dim=1024, compression_factor=(2,2), norm_layer=partial(nn.LayerNorm, eps=1e-6) ): super().__init__() self.embed_dim = embed_dim self.image_embed_dim = image_embed_dim self.hidden_size = image_embed_dim * (compression_factor[0]*compression_factor[1]) self.nl = norm_layer(image_embed_dim) self.mlp = nn.Sequential( nn.Linear(self.hidden_size, self.hidden_size), nn.GELU(), nn.Linear(self.hidden_size, embed_dim), ) self.compression_factor = compression_factor def forward(self, x, tgt_size=(24,24), attn_mask=None): # x = x.to(torch.bfloat16) # dtype = x.dtype height, width = tgt_size if height * width != x.shape[1]: x = x[:, :int(height * width)] x = self.nl(x) x = x.permute(0, 2, 1).unflatten(-1, (int(height), int(width))) # b, dim, h, w batch_size, dim, height, width = x.shape # 计算输出空间的高度和宽度 # h_compressed = (height + self.compression_factor[0] - 1) // self.compression_factor[0] # w_compressed = (width + self.compression_factor[1] - 1) // self.compression_factor[1] unfolded = x.unfold(2, self.compression_factor[0], self.compression_factor[0]).unfold(3, self.compression_factor[1], self.compression_factor[1]) unfolded = unfolded.contiguous().view(batch_size, dim, -1, self.compression_factor[0] * self.compression_factor[1]) unfolded = unfolded.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, dim*self.compression_factor[0] * self.compression_factor[1]) compressed_x = self.mlp(unfolded) return compressed_x class LlavaUHDV3PretrainedModel(PreTrainedModel): config: LlavaUHDV3Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Qwen2DecoderLayer", "MoonViTEncoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _can_compile_fullgraph = True _supports_attention_backend = True def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) class LlavaUHDV3VisionTransformerPretrainedModel(LlavaUHDV3PretrainedModel): config: LlavaUHDV3VisionConfig _no_split_modules = ["MoonViTEncoderLayer"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) config = deepcopy(config) self.patch_size = config.patch_size self.patch_embed = MoonVisionPatchEmbed( out_dim=config.hidden_size, patch_size=config.patch_size, pos_emb_height=config.init_pos_emb_height, pos_emb_width=config.init_pos_emb_width, ) if hasattr(config, "merger_layer_index"): merger_layer_index = config.merger_layer_index merging_method = config.merging_method if merger_layer_index is not None: enable_merging = True merging_method = merging_method if merging_method is not None else "avg_pooling" else: enable_merging = False merging_method = None self.encoder = MoonVitEncoder( hidden_dim=config.hidden_size, num_layers=config.num_hidden_layers, 
block_cfg={ "num_heads": config.num_attention_heads, "hidden_dim": config.hidden_size, "mlp_dim": config.intermediate_size, "activation": PytorchGELUTanh(), "attn_bias": True, "attn_implementation": self.config.attn_implementation, "enable_merging": enable_merging, "merging_method": merging_method, "merger_layer_index": merger_layer_index, }, ) def forward( self, pixel_values: torch.Tensor, grid_hws: torch.Tensor ) -> torch.Tensor: """ Args: pixel_values (torch.Tensor): The input pixel values. grid_hws (torch.Tensor): The grid height and width. Returns: torch.Tensor: The output tokens. """ pixel_values = pixel_values.to(torch.bfloat16) hidden_states = self.patch_embed(pixel_values, grid_hws) image_features, grid_hws = self.encoder(hidden_states, grid_hws) output_features = [] offset = 0 for grid_hw in grid_hws: h, w = grid_hw num_tokens = int(h * w) output_features.append(image_features[offset: offset+num_tokens].unsqueeze(0)) offset += num_tokens assert offset == image_features.shape[0], \ f"Used {offset} tokens, but image_features has {image_features.shape[0]} tokens!" return output_features class LlavaUHDV3TextModel(LlavaUHDV3PretrainedModel): config: LlavaUHDV3TextConfig _no_split_modules = ["Qwen2DecoderLayer"] def __init__(self, config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def get_input_embeddings(self): return self.embed_tokens def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> BaseModelOutputWithPast: if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds position_embeddings = self.rotary_emb(hidden_states, position_ids) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if 
output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **kwargs), hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, cache_position, position_embeddings, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, ) hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, output_attentions: bool = False, ): if self.config.attn_implementation == "flash_attention_2": if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config.attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) and not output_attentions ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] # SlidingWindowCache or StaticCache if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
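        # Illustrative example (comment only, not executed): for batch_size=1,
        # sequence_length=3, target_length=3 and no padding, the helper below returns
        # an additive mask of shape (1, 1, 3, 3):
        #     [[[[  0., min, min],
        #        [  0.,  0., min],
        #        [  0.,  0.,  0.]]]]
        # where `min` is torch.finfo(dtype).min, so masked positions contribute
        # effectively zero probability after the softmax inside the attention kernel.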
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, dtype=dtype, device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], config=self.config, past_key_values=past_key_values, ) if ( self.config.attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int, config, past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. Args: attention_mask (`torch.Tensor`): A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): The device to place the 4D attention mask on. cache_position (`torch.Tensor`): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. config (`Qwen2Config`): The model's configuration class past_key_values (`Cache`): The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. 
causal_mask = attention_mask else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: sliding_attend_mask = torch.arange(target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask.bitwise_or_(sliding_attend_mask) causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( causal_mask.device ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) return causal_mask class LlavaUHDV3Model(LlavaUHDV3PretrainedModel): config_class = LlavaUHDV3Config def __init__(self, config): super().__init__(config) config.model_type = "llava_uhd_v3" config.rope_scaling = None self.visual = LlavaUHDV3VisionTransformerPretrainedModel._from_config(config.vision_config) self.language_model = LlavaUHDV3TextModel._from_config(config.text_config) self.projector = Qwen2vlPatchMerger( embed_dim=config.text_config.hidden_size, image_embed_dim=config.vision_config.hidden_size, compression_factor=(2, 2), ) self.rope_deltas = None # Initialize model layers here self.post_init() def get_image_features(self, pixel_values, grid_hws): down_smaple_ratio = 1 merger_layer_index = getattr(self.config.vision_config, "merger_layer_index", None) if merger_layer_index is not None: down_smaple_ratio = down_smaple_ratio * len(merger_layer_index)**2 image_features = self.visual(pixel_values, grid_hws) projected_image_feaures = [] for image_feature, grid_hw in zip(image_features, grid_hws): grid_hw = (grid_hw[0]//down_smaple_ratio, grid_hw[1]//down_smaple_ratio) projected_image_feature = self.projector(image_feature, tgt_size=grid_hw)[0] projected_image_feaures.append(projected_image_feature) return projected_image_feaures def prepare_inputs_labels_for_multimodal( self, input_ids, position_ids, attention_mask, past_key_values, labels, pixel_values, grid_hws ): if pixel_values is None or input_ids.shape[1] == 1: return input_ids, position_ids, attention_mask, past_key_values, None, labels image_features = self.get_image_features(pixel_values, grid_hws) _labels = labels _position_ids = position_ids _attention_mask = attention_mask if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) else: attention_mask = attention_mask.bool() if position_ids is None: position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) if labels is None: labels = torch.full_like(input_ids, -100) input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] labels = [cur_labels[cur_attention_mask] for 
cur_labels, cur_attention_mask in zip(labels, attention_mask)] new_input_embeds = [] new_labels = [] cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): num_images = (cur_input_ids == -200).sum() if num_images == 0: cur_image_features = image_features[cur_image_idx] cur_input_embeds_1 = self.language_model.embed_tokens(cur_input_ids) cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0][0:0]], dim=0) new_input_embeds.append(cur_input_embeds) new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = [-1] + torch.where(cur_input_ids == -200)[0].tolist() + [cur_input_ids.shape[0]] cur_input_ids_noim = [] cur_labels = labels[batch_idx] cur_labels_noim = [] for i in range(len(image_token_indices) - 1): cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]) cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]) split_sizes = [x.shape[0] for x in cur_labels_noim] cur_input_embeds = self.language_model.embed_tokens(torch.cat(cur_input_ids_noim)) cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) cur_new_input_embeds = [] cur_new_labels = [] for i in range(num_images + 1): cur_new_input_embeds.append(cur_input_embeds_no_im[i]) cur_new_labels.append(cur_labels_noim[i]) if i < num_images: try: cur_image_features = image_features[cur_image_idx] except IndexError: cur_image_features = image_features[cur_image_idx - 1] cur_image_idx += 1 cur_new_input_embeds.append(cur_image_features) cur_new_labels.append(torch.full((cur_image_features.shape[0],), -100, device=cur_labels.device, dtype=cur_labels.dtype)) cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds) cur_new_labels = torch.cat(cur_new_labels) new_input_embeds.append(cur_new_input_embeds) new_labels.append(cur_new_labels) tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", 4096) new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] new_labels = [x[:tokenizer_model_max_length] for x in new_labels] max_len = max(x.shape[0] for x in new_input_embeds) batch_size = len(new_input_embeds) new_input_embeds_padded = [] new_labels_padded = torch.full((batch_size, max_len), -100, dtype=new_labels[0].dtype, device=new_labels[0].device) attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): cur_len = cur_new_embed.shape[0] new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)) if cur_len > 0: new_labels_padded[i, :cur_len] = cur_new_labels attention_mask[i, :cur_len] = True position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) if _labels is None: new_labels = None else: new_labels = new_labels_padded if _attention_mask is None: attention_mask = None else: attention_mask = attention_mask.to(dtype=_attention_mask.dtype) if _position_ids is None: position_ids = None return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels def forward( self, input_ids = None, 
position_ids = None, attention_mask = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, pixel_values = None, grid_hws = None, return_dict = None, **kwargs, ): if inputs_embeds is None: input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal( input_ids, position_ids, attention_mask, past_key_values, labels, pixel_values, grid_hws ) output = self.language_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, **kwargs, ) if labels is not None: return output[0], labels return output class LlavaUHDV3ForCausalLM(LlavaUHDV3PretrainedModel, GenerationMixin): config_class = LlavaUHDV3Config _checkpoint_conversion_mapping = { "^visual": "model.visual", r"^model(?!\.(language_model|visual|projector))": "model.language_model", } # _tied_weights_keys = ["lm_head.weight", "model.language_model.embed_tokens.weight"] def __init__(self, config): super().__init__(config) self.model = LlavaUHDV3Model(config) self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.post_init() @property def language_model(self): return self.model.language_model @property def visual(self): return self.model.visual def get_input_embeddings(self): return self.language_model.embed_tokens def get_output_embeddings(self): return self.lm_head def forward(self, input_ids, labels=None, attention_mask=None, pixel_values=None, grid_hws=None, **kwargs): if labels is not None: outputs, labels = self.model(input_ids, labels=labels, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, **kwargs) else: outputs = self.model(input_ids, labels=labels, attention_mask=attention_mask, pixel_values=pixel_values, grid_hws=grid_hws, **kwargs) hidden_states = outputs.last_hidden_state slice_indices = slice(0, None) logits = self.lm_head(hidden_states[:,slice_indices,:]) loss = None if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @torch.no_grad() def generate( self, input_ids: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None, grid_hws: Optional[torch.Tensor] = None, **kwargs, ): position_ids = kwargs.pop("position_ids", None) attention_mask = kwargs.pop("attention_mask", None) if "inputs_embeds" in kwargs: raise NotImplementedError("`inputs_embeds` is not supported") if pixel_values is not None: input_ids, position_ids, attention_mask, _, inputs_embeds, _ = self.model.prepare_inputs_labels_for_multimodal( input_ids, position_ids, attention_mask, None, None, pixel_values, grid_hws ) else: inputs_embeds = self.model.language_model.embed_tokens(input_ids) return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): pixel_values = kwargs.pop("pixel_values", None) grid_hws = kwargs.pop("grid_hws", None) inputs = super().prepare_inputs_for_generation(input_ids, 
            past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if pixel_values is not None:
            inputs["pixel_values"] = pixel_values
        if grid_hws is not None:
            inputs["grid_hws"] = grid_hws
        return inputs


__all__ = ["LlavaUHDV3ForCausalLM", "LlavaUHDV3Model", "LlavaUHDV3PretrainedModel", "LlavaUHDV3TextModel"]
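

# Usage sketch (illustrative only). The repository path, the processor, and the exact
# tensor names it returns are assumptions made for illustration; only the
# `input_ids` / `pixel_values` / `grid_hws` arguments of `generate` come from this file.
#
#   from transformers import AutoModelForCausalLM, AutoProcessor
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "path/to/llava-uhd-v3", torch_dtype=torch.bfloat16, trust_remote_code=True
#   ).cuda()
#   processor = AutoProcessor.from_pretrained("path/to/llava-uhd-v3", trust_remote_code=True)
#   inputs = processor(text="Describe this image.", images=image, return_tensors="pt")
#   output_ids = model.generate(
#       input_ids=inputs["input_ids"].cuda(),
#       pixel_values=inputs["pixel_values"].cuda(),   # packed patch pixels
#       grid_hws=inputs["grid_hws"].cuda(),           # per-image (height, width) in patches
#       max_new_tokens=64,
#   )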