import logging import os import warnings import torch from torch import Tensor from torch import nn import torch.nn.functional as F from typing import Union, Tuple, Dict, Optional from einops import rearrange XFORMERS_AVAILABLE = False class Attention(nn.Module): def __init__( self, dim: int, num_heads: int = 8, qkv_bias: bool = True, proj_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0, norm_layer: nn.Module = nn.LayerNorm, qk_norm: bool = False, fused_attn: bool = True, # use F.scaled_dot_product_attention or not rope=None, ) -> None: super().__init__() assert dim % num_heads == 0, "dim should be divisible by num_heads" self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = self.head_dim**-0.5 self.fused_attn = fused_attn self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim, bias=proj_bias) self.proj_drop = nn.Dropout(proj_drop) self.rope = rope def forward(self, x: torch.Tensor, pos=None, attn_mask=None, past_key_values=None, use_cache=False ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]: B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) pos_k = pos if use_cache: k = k.unsqueeze(2) v = v.unsqueeze(2) if past_key_values is not None: past_k, past_v = past_key_values k = torch.cat([past_k, k], dim=2) v = torch.cat([past_v, v], dim=2) new_kv = (k, v) a, b, c, d, e = k.shape k = k.reshape(a, b, c*d, e) v = v.reshape(a, b, c*d, e) if pos_k is not None: #print(pos_k.shape) pos_k = pos_k.repeat(1, c, 1) #print(pos_k.shape) q, k = self.q_norm(q), self.k_norm(k) if self.rope is not None: q = self.rope(q, pos) k = self.rope(k, pos_k) if self.fused_attn: x = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.attn_drop.p if self.training else 0.0, ) else: q = q * self.scale attn = q @ k.transpose(-2, -1) # Mask if attn_mask is not None: assert attn_mask.shape[-2:] == (N, N), f"Expected mask shape [..., {N}, {N}], got {attn_mask.shape}" attn = attn + attn_mask attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) if use_cache: return x, new_kv return x class MemEffAttention(Attention): def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor: assert pos is None if not XFORMERS_AVAILABLE: if attn_bias is not None: raise AssertionError("xFormers is required for using nested tensors") return super().forward(x) B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) q, k, v = unbind(qkv, 2) x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) x = x.reshape([B, N, C]) x = self.proj(x) x = self.proj_drop(x) return x