from __future__ import annotations

from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit

from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND


def modulate(x, scale):
    return x * (1 + scale.unsqueeze(1))


class JointAttention(nn.Module):
    """Multi-head attention module."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
        operation_settings={},
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of query heads.
            n_kv_heads (Optional[int]): Number of key/value heads when using
                GQA, or None to use the same number as the query heads.
            qk_norm (bool): Whether to apply RMSNorm to the query and key
                projections.
        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        self.qkv = operation_settings.get("operations").Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.out = operation_settings.get("operations").Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
            self.k_norm = RMSNorm(self.head_dim, elementwise_affine=True, **operation_settings)
        else:
            self.q_norm = self.k_norm = nn.Identity()

    @staticmethod
    def apply_rotary_emb(
        x_in: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply rotary embeddings to the input tensor using the given frequency
        tensor.

        This function applies rotary embeddings to a query or key tensor
        'x_in' using the provided frequency tensor 'freqs_cis'. The input is
        reshaped into pairs of channels, each pair is rotated by the
        corresponding rotation entries in 'freqs_cis', and the result is
        returned with the original shape.

        Args:
            x_in (torch.Tensor): Query or key tensor to apply rotary
                embeddings to.
            freqs_cis (torch.Tensor): Precomputed rotation tensor for the
                rotary embedding.

        Returns:
            torch.Tensor: The input tensor with rotary embeddings applied.
        """
        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return t_out.reshape(*x_in.shape)
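
    # Shape sketch (illustrative): for x_in of shape (B, S, heads, head_dim)
    # and freqs_cis broadcastable to (B, S, 1, head_dim // 2, 2, 2) holding
    # per-position rotation matrices [[cos, -sin], [sin, cos]], each channel
    # pair (x0, x1) becomes (cos*x0 - sin*x1, sin*x0 + cos*x1), i.e. RoPE
    # applied in real arithmetic rather than with complex tensors.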

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, dim).
            x_mask (torch.Tensor): Attention mask passed to the attention
                kernel.
            freqs_cis (torch.Tensor): Precomputed rotary embedding
                frequencies.

        Returns:
            torch.Tensor: Attention output of shape (batch, seq_len, dim).
        """
        bsz, seqlen, _ = x.shape

        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq = self.q_norm(xq)
        xk = self.k_norm(xk)

        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

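        # Grouped-query attention: each key/value head serves
        # n_rep = n_heads // n_kv_heads query heads, so k and v are repeated
        # along the head axis before the fused masked attention call.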
        n_rep = self.n_local_heads // self.n_local_kv_heads
        if n_rep >= 1:
            xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
        output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)

        return self.out(output)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        operation_settings={},
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple
                of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier for hidden
                dimension. Defaults to None.
        """
        super().__init__()

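        # SwiGLU sizing: the nominal hidden_dim passed by the caller is
        # optionally rescaled by ffn_dim_multiplier and then rounded *up* to
        # the nearest multiple of `multiple_of`. Illustrative numbers only:
        # hidden_dim=4096, ffn_dim_multiplier=2/3, multiple_of=256
        # -> int(2730.67) = 2730 -> rounded up to 2816.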
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w2 = operation_settings.get("operations").Linear(
            hidden_dim,
            dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.w3 = operation_settings.get("operations").Linear(
            dim,
            hidden_dim,
            bias=False,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

    def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3

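    # SwiGLU: the full block computes w2(silu(w1(x)) * w3(x)).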
    def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))


class JointTransformerBlock(nn.Module):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        operation_settings={},
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or None for the same number as
                the query heads.
            multiple_of (int): Value the feed-forward hidden dimension is
                rounded up to a multiple of.
            ffn_dim_multiplier (float): Optional multiplier for the
                feed-forward hidden dimension.
            norm_eps (float): Epsilon used by the RMSNorm layers.
            qk_norm (bool): Whether to apply RMSNorm to queries and keys.
            modulation (bool): Whether the block is conditioned through adaLN
                modulation.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
            operation_settings=operation_settings,
        )
        self.layer_id = layer_id
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)

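        # "Sandwich" normalization: *_norm1 is applied before each sublayer
        # and *_norm2 to its output, before the (optionally gated) residual add.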
        self.modulation = modulation
        if modulation:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                operation_settings.get("operations").Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                    device=operation_settings.get("device"),
                    dtype=operation_settings.get("dtype"),
                ),
            )

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (torch.Tensor): Input tensor.
            x_mask (torch.Tensor): Attention mask for the input sequence.
            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
            adaln_input (Optional[torch.Tensor]): Conditioning vector for adaLN
                modulation; required when the block was built with
                modulation=True.

        Returns:
            torch.Tensor: Output tensor after applying attention and
                feedforward layers.
        """
        if self.modulation:
            assert adaln_input is not None
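            # adaLN: a single projection of the conditioning vector yields
            # per-sample scale/gate pairs for the attention and MLP branches.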
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    freqs_cis,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x


class FinalLayer(nn.Module):
    """
    The final layer of NextDiT.
    """

    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
        super().__init__()
        self.norm_final = operation_settings.get("operations").LayerNorm(
            hidden_size,
            elementwise_affine=False,
            eps=1e-6,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )
        self.linear = operation_settings.get("operations").Linear(
            hidden_size,
            patch_size * patch_size * out_channels,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operation_settings.get("operations").Linear(
                min(hidden_size, 1024),
                hidden_size,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

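    # Scale-only adaLN: the conditioning vector produces a per-sample scale
    # for the normalized tokens before the projection to
    # patch_size**2 * out_channels values per token.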
    def forward(self, x, c):
        scale = self.adaLN_modulation(c)
        x = modulate(self.norm_final(x), scale)
        x = self.linear(x)
        return x


class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = (16, 56, 56),
        axes_lens: List[int] = (1, 512, 512),
        image_model=None,
        device=None,
        dtype=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.dtype = dtype
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        self.x_embedder = operation_settings.get("operations").Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),
        )

        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, **operation_settings),
            operation_settings.get("operations").Linear(
                cap_feat_dim,
                dim,
                bias=True,
                device=operation_settings.get("device"),
                dtype=operation_settings.get("dtype"),
            ),
        )

        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    operation_settings=operation_settings,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = RMSNorm(dim, eps=norm_eps, elementwise_affine=True, **operation_settings)
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)

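        # The per-head channel budget is split across the three RoPE axes
        # (caption position, image row, image column), so axes_dims must sum
        # to head_dim.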
        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
        self.dim = dim
        self.n_heads = n_heads

    def unpatchify(
        self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
    ) -> List[torch.Tensor]:
        """
        x: (N, T, patch_size**2 * C)
        imgs: list of (C, H, W), or (N, C, H, W) if return_tensor is True
        """
        pH = pW = self.patch_size
        imgs = []
        for i in range(x.size(0)):
            H, W = img_size[i]
            begin = cap_size[i]
            end = begin + (H // pH) * (W // pW)
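            # Tokens for image i occupy positions [begin, end) of the joint
            # sequence; fold them back from (h*w, pH*pW*C) tokens into a
            # (C, H, W) image.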
            imgs.append(
                x[i][begin:end]
                .view(H // pH, W // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )

        if return_tensor:
            imgs = torch.stack(imgs, dim=0)
        return imgs

    def patchify_and_embed(
        self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
        bsz = len(x)
        pH = pW = self.patch_size
        device = x[0].device
        dtype = x[0].dtype

        if cap_mask is not None:
            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        else:
            l_effective_cap_len = [num_tokens] * bsz

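        # Convert a boolean/integer caption mask into an additive attention
        # bias: 0 where tokens are kept, -finfo(dtype).max where they are
        # padding.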
        if cap_mask is not None and not torch.is_floating_point(cap_mask):
            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max

        img_sizes = [(img.size(1), img.size(2)) for img in x]
        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]

        max_seq_len = max(
            (cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
        )
        max_cap_len = max(l_effective_cap_len)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

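        # Three RoPE axes per token: axis 0 counts caption tokens (and is held
        # at cap_len for image tokens), axes 1 and 2 are the image row and
        # column indices.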
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // pH, W // pW
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)

        cap_freqs_cis_shape = list(freqs_cis.shape)
        cap_freqs_cis_shape[1] = cap_feats.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]

        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

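        # Patchify: each (C, H, W) latent becomes a sequence of
        # (H // pH) * (W // pW) tokens of size pH * pW * C, padded out to
        # max_img_len, then projected to the model dimension and refined.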
        flat_x = []
        for i in range(bsz):
            img = x[i]
            C, H, W = img.size()
            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_x.append(img)
        x = flat_x
        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
        for i in range(bsz):
            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max

        padded_img_embed = self.x_embedder(padded_img_embed)
        padded_img_mask = padded_img_mask.unsqueeze(1)
        for layer in self.noise_refiner:
            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)

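        # Assemble the joint sequence per sample: refined caption tokens
        # first, then the refined image tokens, padded out to max_seq_len.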
        if cap_mask is not None:
            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
        else:
            mask = None

        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
        for i in range(bsz):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]

            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]

        return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis

    def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
        """
        Forward pass of NextDiT.

        x: (N, C, H, W) tensor of latent inputs
        timesteps: (N,) tensor of diffusion timesteps
        context: (N, L, cap_feat_dim) tensor of caption features
        """
        t = 1.0 - timesteps
        cap_feats = context
        cap_mask = attention_mask
        bs, c, h, w = x.shape
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))

        t = self.t_embedder(t, dtype=x.dtype)
        adaln_input = t

        cap_feats = self.cap_embedder(cap_feats)

        x_is_tensor = isinstance(x, torch.Tensor)
        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
        freqs_cis = freqs_cis.to(x.device)

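        # Joint transformer stack over the concatenated caption + image
        # sequence, conditioned on the timestep embedding via adaLN.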
        for layer in self.layers:
            x = layer(x, mask, freqs_cis, adaln_input)

        x = self.final_layer(x, adaln_input)
        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]

        return -x