Upload 14 files
- DCMoE.py +561 -0
- README (1).md +216 -0
- config.json +4 -0
- deepspeed_utils.py +124 -0
- model-00001-of-00003.safetensors +2 -2
- modeling.py +1182 -0
- special_tokens_map.json +78 -1
- tokenizer_config.json +102 -2
- utils.py +491 -0
- video_preprocessor_config (1).json +43 -0
- vocab.json +0 -0
DCMoE.py
ADDED
@@ -0,0 +1,561 @@
import copy
import os
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
import deepspeed
from deepspeed import comm as dist
from deepspeed.utils import groups, log_dist
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.moe.sharded_moe import FIRST_ALLTOALL_TIMER, MOE_TIMER, SECOND_ALLTOALL_TIMER, _AllToAll, einsum, gumbel_rsample
from transformers.activations import ACT2FN


def compress_matrix(A: torch.Tensor, mask: torch.Tensor, force_dim: int = None, allow_larger_dim=None) -> torch.Tensor:
    if A.shape[:2] != mask.shape:
        raise ValueError("First two dimensions of A and mask must match.")
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if not ((mask == 0) | (mask == 1)).all():
        raise ValueError(
            f"mask must only contain 0s and 1s. dtype: {mask.dtype}. "
            f"Invalid elements found at indices: {((mask != 0) & (mask != 1)).nonzero().tolist()} "  # Get indices of elements not 0 AND not 1
            f"with corresponding values: {mask[((mask != 0) & (mask != 1))].tolist()}. "  # Get the values at those indices
            f"\nOriginal mask (showing up to first 20 elements if large):\n{mask.flatten()[:20]}{'...' if mask.numel() > 20 else ''}"
        )

    S, E = mask.shape
    trailing_dims_shape = A.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = A.device

    ones_per_column = mask.sum(dim=0)
    X = ones_per_column.max().item() if force_dim is None else force_dim

    if X == 0:
        return torch.empty((0, E, *trailing_dims_shape), dtype=A.dtype, device=device)

    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    view_shape_for_indices = (S, E, *((1,) * num_trailing_dims))
    expanded_indices = sorted_row_indices_2d.view(view_shape_for_indices).expand_as(A)

    A_gathered = torch.gather(A, 0, expanded_indices)

    if X <= A_gathered.shape[0]:
        B_candidate = A_gathered[:X, ...]
    elif allow_larger_dim or allow_larger_dim is None:
        if allow_larger_dim is None:
            print(f"[Warning compress_matrix] Target dimension X ({X}) is larger than "
                  f"A's original row count S ({S}). Padding B_candidate with zeros.")
        B_candidate = A_gathered
        zeros_shape = [X - A_gathered.shape[0]] + list(B_candidate.shape[1:])
        B_candidate = torch.cat((B_candidate, torch.zeros(zeros_shape, dtype=B_candidate.dtype, device=B_candidate.device)), dim=0)  # Shape (X_target_dim, E, ...)
    else:
        raise AssertionError(
            f"Target dimension X ({X}) is larger than A's original row count S ({S}) "
            f"and allow_larger_dim is False. Padding is disallowed."
        )
    row_indices_for_B = torch.arange(X, device=device).unsqueeze(1)
    b_mask_2d = row_indices_for_B < ones_per_column.unsqueeze(0)
    view_shape_for_b_mask = (X, E, *((1,) * num_trailing_dims))
    B = B_candidate * b_mask_2d.view(view_shape_for_b_mask).to(A.dtype)

    return B


def decompress_matrix(B: torch.Tensor, mask: torch.Tensor, allow_larger_dim=None) -> torch.Tensor:
    if B.shape[1] != mask.shape[1]:
        raise ValueError("B's second dimension and mask's second dimension (E) must match.")
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if not ((mask == 0) | (mask == 1)).all():
        raise ValueError("mask must only contain 0s and 1s.")

    S, E = mask.shape
    X = B.shape[0]
    trailing_dims_shape = B.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = B.device

    if X == 0:
        return torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
    if X <= S:
        pass
    elif allow_larger_dim or allow_larger_dim is None:
        if allow_larger_dim is None:
            print(f"[Warning decompress_matrix] Input B.shape[0] ({X}) is larger than "
                  f"target A's row count S ({S}). Truncating B to its first {S} rows.")
        B = B[:S, ...]
        X = S
    else:
        raise AssertionError(
            f"Input B.shape[0] ({X}) is larger than target A's row count S ({S}) "
            f"and allow_larger_dim is False. Truncation is disallowed."
        )

    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    target_A_row_indices_2d = sorted_row_indices_2d[:X, :]
    A_reconstructed = torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
    view_shape_for_target_indices = (X, E, *((1,) * num_trailing_dims))
    expanded_target_indices = target_A_row_indices_2d.view(view_shape_for_target_indices).expand_as(B)
    A_reconstructed.scatter_(dim=0, index=expanded_target_indices, src=B)

    return A_reconstructed

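# ---------------------------------------------------------------------------
# Illustrative note (added by the editor, not part of the original upload):
# how the two helpers above are intended to round-trip. Shapes are hypothetical.
#
#   A    : (S, E, d)  token-major activations, one slot per (token, expert)
#   mask : (S, E)     0/1 routing mask; column e contains k_e ones
#   B  = compress_matrix(A, mask)      # (X, E, d), X = max_e k_e, rows packed
#   A2 = decompress_matrix(B, mask)    # (S, E, d), zeros wherever mask == 0
#
# A2 equals A at every position where mask == 1, which is what the MoE layer
# below relies on when it packs tokens to expert capacity and later unpacks
# the expert outputs back into sequence order.
# ---------------------------------------------------------------------------
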
class AudioSharedExpertMLP(nn.Module):
    """
    Shared expert MLP for UniMoE-Audio model.
    Handles common audio feature transformations across all tokens.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.shared_intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class AudioDynamicExpertMLP(nn.Module):
    """
    Dynamic expert MLP for UniMoE-Audio model.
    Specialized for adaptive audio feature processing based on content.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.dynamic_intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))


class AudioNullExpertMLP(nn.Module):
    """
    Null expert MLP for UniMoE-Audio model.
    Returns zero output for tokens that don't require expert processing.
    """
    def __init__(self, config):
        super().__init__()

    def forward(self, hidden_state):
        return torch.zeros_like(hidden_state, dtype=hidden_state.dtype, device=hidden_state.device)

def audio_sparse_expert_mixer(scores, top_k, jitter_eps, training):
    """
    Sparse expert mixing function for UniMoE-Audio.
    Implements adaptive expert selection with noise injection for training.
    """
    masked_scores = scores
    multiplier_list = []
    selected_experts_list = []

    for _ in range(top_k):
        with torch.no_grad():
            mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
            factor = scores.abs().clamp(min=mask_logits_threshold.abs())
            mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)

        masked_gates = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))

        selected_experts = max_ind

        masked_gates = torch.softmax(masked_gates, dim=-1)
        multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)

        multiplier = multiplier_o

        masked_scores = torch.scatter(
            masked_scores,
            -1,
            selected_experts,
            float("-inf"),
        )

        multiplier_list.append(multiplier)
        selected_experts_list.append(selected_experts)

    multiplier = torch.concat(multiplier_list, dim=-1)
    selected_experts = torch.concat(selected_experts_list, dim=-1)
    return (
        multiplier,
        selected_experts,
    )


def audio_dynamic_expert_selection(logits, top_p):
    """
    Dynamic expert selection for UniMoE-Audio based on cumulative probability threshold.
    Adapts the number of experts based on audio content complexity.
    """
    dynamic_scores = torch.softmax(logits, dim=-1)
    dynamic_scores_sorted, _ = torch.sort(dynamic_scores, dim=-1, descending=True)
    dynamic_scores_cumsum = dynamic_scores_sorted.cumsum(dim=-1)
    dynamic_top_k = (~(dynamic_scores_cumsum >= top_p)).sum(dim=-1)
    dynamic_top_k = dynamic_top_k + 1
    return dynamic_top_k

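# Worked example for audio_dynamic_expert_selection (added note; the numbers
# are hypothetical):
#   softmax(logits) = [0.55, 0.25, 0.12, 0.08], top_p = 0.7
#   sorted cumsum   = [0.55, 0.80, 0.92, 1.00]
#   entries with cumsum < 0.7 -> 1 of them, plus 1 -> dynamic_top_k = 2
# i.e. this token is routed to 2 dynamic experts; a flatter router
# distribution activates more experts, a peaked one fewer.
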
def _audio_expert_capacity(num_tokens, num_experts, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    """Calculate expert capacity for UniMoE-Audio based on token distribution and capacity factor."""
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity


def calculate_audio_global_routing_weight(
    expert_mask: torch.Tensor,
    full_router_logits: torch.Tensor,
    mlp_dynamic_expert_num: int,
    routing_weights: torch.Tensor,
):
    """
    Calculate global routing weights for UniMoE-Audio combining dynamic and fixed expert weights.
    Optimized for audio generation tasks.
    """
    global_weight = torch.softmax(full_router_logits.masked_fill(expert_mask == 0, float("-inf")), dim=-1)
    global_dynamic_weight = global_weight[:, :mlp_dynamic_expert_num]
    global_fixed_weight = global_weight[:, mlp_dynamic_expert_num:]
    global_dynamic_weight = routing_weights * global_dynamic_weight.sum(-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1])
    global_weight = torch.cat((global_dynamic_weight, global_fixed_weight), dim=-1)
    return global_weight

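# Reader's note (added, not in the original file): the block below wires the
# pieces above together. For each token it (1) computes router logits with a
# single linear gate, (2) picks a per-token expert count via Top-P
# (audio_dynamic_expert_selection) or a fixed Top-K fallback, (3) builds
# expert_mask / routing_weights with audio_sparse_expert_mixer, (4) optionally
# drops assignments that exceed expert capacity, and (5) combines the
# dynamic-expert outputs (dispatched through UniMoEAudioMoE) with the
# always-on shared experts via calculate_audio_global_routing_weight.
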
class UniMoEAudioSparseMoeBlock(nn.Module):
    """
    UniMoE-Audio Sparse Mixture of Experts block with dynamic routing and expert selection.
    Optimized for audio generation tasks with efficient sparse operations and capacity management.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.mlp_dynamic_expert_num = config.mlp_dynamic_expert_num + config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_real_expert_num = config.mlp_dynamic_expert_num
        self.mlp_dynamic_null_expert_num = config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_top_p = config.mlp_dynamic_top_p
        self.mlp_dynamic_top_k = config.mlp_dynamic_top_k
        self.mlp_fixed_expert_num = config.mlp_fixed_expert_num
        self.num_experts = self.mlp_dynamic_expert_num + self.mlp_fixed_expert_num

        if self.mlp_dynamic_top_p == 0:
            print(f"mlp_dynamic_top_p is 0, will use mlp_dynamic_top_k={self.mlp_dynamic_top_k} instead !!!")

        self.ignore_differentiable_router = config.ignore_differentiable_router
        if self.ignore_differentiable_router:
            print("ignore_differentiable_router is True, will not use router_logits !!!")

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.fixed_real_moe = nn.ModuleList([AudioSharedExpertMLP(config) for _ in range(self.mlp_fixed_expert_num)])
        self.dynamic_real_moe = UniMoEAudioMoE(config, AudioDynamicExpertMLP(config), self.mlp_dynamic_real_expert_num, config.ep_size)

        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise

        self.min_capacity = config.min_capacity
        self.capacity_factor = config.capacity_factor
        self.token_drop = config.token_drop
        self.drop_policy = config.drop_policy

        self.avg_hidden_states_last = config.avg_hidden_states_last
        self.drop_token_num_print = config.drop_token_num_print
        self.fp32_gate = config.fp32_gate

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, aux_balance_weight: torch.Tensor = None):
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        original_hidden_states = hidden_states

        if self.training and self.fp32_gate:
            hidden_states = hidden_states.float()

        if self.training and self.input_jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)

        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.training and self.fp32_gate:
            full_router_logits = torch.nn.functional.linear(hidden_states, weight=self.gate.weight.float(), bias=None)
        else:
            full_router_logits = self.gate(hidden_states)
        dynamic_router_logits = full_router_logits[:, : self.mlp_dynamic_expert_num]

        if self.mlp_dynamic_top_p != 0:
            dynamic_top_k = audio_dynamic_expert_selection(dynamic_router_logits, self.mlp_dynamic_top_p)
        else:
            dynamic_top_k = torch.full((dynamic_router_logits.shape[0],), self.mlp_dynamic_top_k, dtype=torch.int, device=dynamic_router_logits.device)

        expert_mask = torch.zeros((batch_size * sequence_length, self.num_experts), dtype=torch.int, device=hidden_states.device)

        routing_weights = torch.zeros((batch_size * sequence_length, self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
        for top_k in range(1, self.mlp_dynamic_expert_num + 1):
            group_idx = torch.nonzero(dynamic_top_k == top_k, as_tuple=True)[0]
            if len(group_idx) == 0:
                continue

            dynamic_group_logits = dynamic_router_logits[group_idx]
            group_routing_weights, group_selected_experts = audio_sparse_expert_mixer(
                dynamic_group_logits,
                top_k=top_k,
                jitter_eps=self.router_jitter_noise,
                training=self.training and not self.ignore_differentiable_router,
            )

            group_expert_mask = torch.nn.functional.one_hot(group_selected_experts, num_classes=self.num_experts)
            group_expert_mask = group_expert_mask.sum(dim=1)

            group_weight = torch.zeros((len(group_idx), self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
            group_weight.scatter_(dim=-1, index=group_selected_experts, src=group_routing_weights)
            routing_weights.index_add_(0, group_idx, group_weight)

            expert_mask.index_add_(0, group_idx, group_expert_mask.to(expert_mask.dtype))

        routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)

        if attention_mask is not None:
            attention_mask = attention_mask.to(expert_mask.dtype).view(-1).unsqueeze(-1).expand(-1, self.num_experts)
            expert_mask = expert_mask * attention_mask

        if self.mlp_dynamic_expert_num < self.num_experts:
            expert_mask[:, self.mlp_dynamic_expert_num :] = 1

        aux_loss = audio_load_balancing_loss_func(
            expert_mask=expert_mask,
            mlp_dynamic_expert_num=self.mlp_dynamic_expert_num,
            global_weight=None,
            full_router_logits=full_router_logits,
            routing_weights=routing_weights,
            aux_balance_weight=aux_balance_weight,
        )

        if self.token_drop:
            expert_mask_dtype = expert_mask.dtype
            capacity = _audio_expert_capacity(batch_size * sequence_length, self.mlp_dynamic_expert_num, torch.tensor(self.capacity_factor), torch.tensor(self.min_capacity))
            if self.drop_policy == "probs":
                if capacity > dynamic_router_logits.shape[0]:
                    print(f"[warning] token capacity({capacity}) > token num({dynamic_router_logits.shape[0]}), setting capacity=token num")
                    capacity = dynamic_router_logits.shape[0]
                dynamic_expert_mask = expert_mask[:, : self.mlp_dynamic_expert_num].bool()
                token_drop_router_logits = torch.masked_fill(dynamic_router_logits, ~dynamic_expert_mask, torch.finfo(dynamic_router_logits.dtype).min)
                capacity_probs, capacity_indices = torch.topk(token_drop_router_logits, k=capacity, dim=0, sorted=False)
                capacity_mask = torch.zeros_like(expert_mask).scatter(0, capacity_indices, 1)
                capacity_mask[:, self.mlp_dynamic_expert_num :] = 1
                expert_mask = torch.logical_and(expert_mask, capacity_mask)

                ori_token_num = dynamic_expert_mask.sum().item()
                cur_token_num = expert_mask[:, : self.mlp_dynamic_expert_num].sum().item()
                if self.drop_token_num_print and ("RANK" not in os.environ or int(os.environ["RANK"]) == 0):
                    print(f"drop {ori_token_num - cur_token_num} tokens from total {ori_token_num} tokens")

            elif self.drop_policy == "position":
                locations = torch.cumsum(expert_mask, dim=0) - 1
                expert_mask *= torch.lt(locations, capacity)
            else:
                raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
            expert_mask = expert_mask.to(expert_mask_dtype)

        routing_weights = routing_weights.masked_fill(~(expert_mask[:, : self.mlp_dynamic_expert_num].bool()), 0.0)
        routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)

        if self.mlp_dynamic_expert_num < self.num_experts:
            global_weight = calculate_audio_global_routing_weight(expert_mask, full_router_logits, self.mlp_dynamic_expert_num, routing_weights)
        else:
            global_weight = routing_weights

        hidden_states = original_hidden_states.view(-1, hidden_dim)

        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
        global_weight = global_weight.to(hidden_states.dtype)

        current_hidden_states = self.dynamic_real_moe(hidden_states, expert_mask=expert_mask[:, : self.mlp_dynamic_real_expert_num], router_weight=global_weight[:, : self.mlp_dynamic_real_expert_num])
        final_hidden_states = final_hidden_states + current_hidden_states

        for expert_idx in range(self.mlp_fixed_expert_num):
            expert_layer = self.fixed_real_moe[expert_idx]

            current_state = hidden_states
            current_global_weight = global_weight[:, self.mlp_dynamic_expert_num + expert_idx].unsqueeze(-1)
            current_hidden_states = expert_layer(current_state) * current_global_weight

            final_hidden_states = final_hidden_states + current_hidden_states

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        if not self.training and self.avg_hidden_states_last:
            dist.all_reduce(final_hidden_states, op=dist.ReduceOp.AVG, group=self.dynamic_real_moe.deepspeed_moe.ep_group)

        return final_hidden_states, full_router_logits, dynamic_top_k, expert_mask, global_weight, aux_loss

def audio_load_balancing_loss_func(
    expert_mask: torch.Tensor,
    mlp_dynamic_expert_num: int,
    global_weight: Optional[torch.Tensor] = None,
    full_router_logits: Optional[torch.Tensor] = None,
    routing_weights: Optional[torch.Tensor] = None,
    aux_balance_weight: Optional[torch.Tensor] = None,
) -> float:
    """Calculate load balancing loss for UniMoE-Audio expert routing to encourage balanced usage."""
    min_dtype = torch.finfo(full_router_logits.dtype).min
    global_weight = full_router_logits.masked_fill(expert_mask == 0, min_dtype)
    global_weight = global_weight[:, :mlp_dynamic_expert_num]
    global_weight = torch.softmax(global_weight, dim=-1)
    expert_mask = expert_mask[:, :mlp_dynamic_expert_num]

    num_experts = expert_mask.shape[-1]
    if aux_balance_weight is None:
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(global_weight, dim=0)
    else:
        batch_size, sequence_length = aux_balance_weight.shape
        num_hidden_layers = global_weight.shape[0] // (batch_size * sequence_length)
        expert_attention_mask = aux_balance_weight[None, :, :, None].expand((num_hidden_layers, batch_size, sequence_length, num_experts)).reshape(-1, num_experts).to(global_weight.device)
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
        router_prob_per_expert = torch.sum(global_weight * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)

    return overall_loss * num_experts

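# Added note: with f_e = mean fraction of assignments that land on dynamic
# expert e (from expert_mask) and p_e = mean router probability for e (from
# the masked softmax above), the value returned is
#     loss = num_experts * sum_e f_e * p_e,
# i.e. the usual Switch-style auxiliary balance loss restricted to the dynamic
# experts; aux_balance_weight, when given, simply turns both means into
# weighted averages over tokens.
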
class AudioExperts(deepspeed.moe.experts.Experts):
    """Custom Audio experts class extending DeepSpeed MoE experts with additional functionality."""

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        super(deepspeed.moe.experts.Experts, self).__init__()

        self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        for expert in self.deepspeed_experts:
            for name, param in expert.named_parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.deepspeed_experts):
            out = expert(chunk)
            if type(out) is tuple:
                out = out[0]
            expert_outputs += [out]

        expert_output = torch.cat(expert_outputs, dim=1)
        return expert_output

class AudioMOELayer(deepspeed.moe.sharded_moe.MOELayer):
    """Custom Audio MoE layer extending DeepSpeed MOELayer with matrix compression optimization."""

    def __init__(
        self,
        experts: nn.Module,
        ep_group_name,
        ep_size,
        num_local_experts: int,
        use_tutel: bool = False,
    ) -> None:
        super(deepspeed.moe.sharded_moe.MOELayer, self).__init__()

        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, hidden_states: Tensor, expert_mask: Tensor, router_weight: Tensor) -> Tensor:
        router_weight = router_weight * expert_mask

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).start()

        d_model = hidden_states.shape[-1]
        seq_len = hidden_states.shape[0]
        expert_num = expert_mask.shape[-1]
        capacity = expert_mask.sum(dim=0).max()
        if self.ep_group is not None:
            dist.all_reduce(capacity, op=dist.ReduceOp.MAX, group=self.ep_group)

        compres_hidden_states = hidden_states.unsqueeze(1).expand(seq_len, expert_num, d_model)
        compres_hidden_states = compress_matrix(compres_hidden_states, expert_mask, force_dim=capacity, allow_larger_dim=True)  # [C, expert_num, d_model]
        compres_expert_mask = compress_matrix(expert_mask, expert_mask, force_dim=capacity, allow_larger_dim=True)
        dispatched_input = einsum("ce,cem->ecm", compres_expert_mask, compres_hidden_states)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).start()

        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).stop()
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).stop()
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        expert_output = decompress_matrix(expert_output.transpose(0, 1), expert_mask, allow_larger_dim=True)
        combined_output = einsum("se,sem->sm", router_weight, expert_output)
        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).stop()
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)

        return combined_output

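# Added note on the einsum shapes used in AudioMOELayer.forward
# (s = tokens, e = routed experts, c = capacity rows after compress_matrix,
# m = d_model):
#   dispatch: "ce,cem->ecm"  packs each expert's selected tokens into a dense
#             (e, c, m) buffer before the first all-to-all;
#   combine:  "se,sem->sm"   multiplies each decompressed expert output by its
#             router weight and sums over experts after the second all-to-all.
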
class UniMoEAudioMoE(deepspeed.moe.layer.MoE):
    """Custom Audio MoE class extending DeepSpeed MoE with configuration and parallelism setup."""

    def __init__(self, config, expert, num_experts, ep_size, moe_name_prefix="ep_size"):
        super(deepspeed.moe.layer.MoE, self).__init__()
        self.enable_expert_tensor_parallelism = config.enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.num_experts = num_experts
        self.expert_group_name = f"{moe_name_prefix}_{self.ep_size}"
        self.num_local_experts = self.num_experts // self.ep_size
        log_dist(f"Creating MoE layer with num_experts: {self.num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}", [0])
        experts = AudioExperts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = AudioMOELayer(experts, self.expert_group_name, self.ep_size, self.num_local_experts)

    def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
        self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)

    def _create_process_groups(self, use_data_before_expert_parallel_=False):
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                groups._create_expert_and_data_parallel(self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
            else:
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

    def forward(self, *input_args, **input_kwargs):
        return self.deepspeed_moe(*input_args, **input_kwargs)
README (1).md
ADDED
@@ -0,0 +1,216 @@
---
license: mit
language:
- en
- zh
base_model:
- Qwen/Qwen2-0.5B
pipeline_tag: feature-extraction
library_name: sentence-transformers
tags:
- MoE
- Unified Generation
- Speech and Music
- Multi-modal
datasets:
---

<h1 align="center">UniMoE-Audio</h1>

**UniMoE-Audio** is a unified framework that seamlessly combines speech and music generation. Powered by a novel dynamic-capacity Mixture-of-Experts design, it adapts intelligently to input complexity, enabling high-fidelity voice and expressive music within a single model.

## Key Innovations

#### **Top-P Dynamic Routing Strategy**
We introduce a **Top-P routing strategy** that overcomes the limitations of conventional static Top-K routing (a code sketch follows the list below):

- **Dynamic Expert Allocation**: Instead of assigning a fixed number of experts to every token, our approach dynamically determines the number of experts based on token complexity
- **Resource Efficiency**: Simple tokens don't consume unnecessary resources, while complex tokens receive sufficient processing power
- **Performance Optimization**: Results in improved overall efficiency and performance

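For intuition, here is a minimal sketch of the Top-P selection rule in plain PyTorch. It mirrors the logic in `DCMoE.py` but the function and variable names below are illustrative, not part of the released API:

```python
import torch

def top_p_expert_count(router_logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Per-token number of experts needed for the sorted router probabilities to reach top_p."""
    probs = torch.softmax(router_logits, dim=-1)                  # (tokens, experts)
    sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # count entries strictly below the threshold, then include the one that crosses it
    return (cumulative < top_p).sum(dim=-1) + 1

logits = torch.randn(4, 8)               # 4 tokens, 8 dynamic experts
print(top_p_expert_count(logits, 0.7))   # e.g. tensor([2, 1, 3, 2]); varies per token
```

Peaked router distributions therefore activate few experts, while uncertain tokens recruit more.
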
+
#### **Three-Stage Training Curriculum**
|
| 32 |
+
We employ a comprehensive training approach to enable effective joint learning from imbalanced data:
|
| 33 |
+
|
| 34 |
+
1. **Independent Specialist Training** - Initial expert specialization
|
| 35 |
+
2. **Integration with Warm-up** - Gradual system integration
|
| 36 |
+
3. **Synergistic Joint Training** - Collaborative optimization
|
| 37 |
+
|
| 38 |
+
## Model Information
|
| 39 |
+
- **Base Model**: Qwen2.5-VL with MoE extensions
|
| 40 |
+
- **Audio Codec**: DAC (Descript Audio Codec) with 12 channels
|
| 41 |
+
- **Expert Configuration**: 8 dynamic experts + 2 shared experts
|
| 42 |
+
- **Audio Sampling Rate**: 16kHz
|
| 43 |
+
- Usage:
|
| 44 |
+
- Text-to-Speech (TTS)
|
| 45 |
+
- Speech-to-Text (STT)
|
| 46 |
+
- Music Generation
|
| 47 |
+
- GPU Requirements:
|
| 48 |
+
- Memory: 16GB+
|
| 49 |
+
- CUDA-enabled GPU
|
| 50 |
+
|
| 51 |
+
## Open-source Plan
|
| 52 |
+
- [☑️] Model Checkpoint
|
| 53 |
+
- [☑️] [UniMoE-Audio-preview](https://huggingface.co/foggyforest/UniMoE-Audio-preview)
|
| 54 |
+
- [☑️] Inference Code: [HITsz-TMG/UniMoE-Audio](https://github.com/HITsz-TMG/UMOE-Scaling-Unified-Multimodal-LLMs/tree/master/UniMoE-Audio)
|
| 55 |
+
- [☑️] Technical Report: [UniMoE-Audio: Unified Speech and Music Generation with Dynamic-Capacity MoE]()
|
| 56 |
+
|
| 57 |
+
## Evaluation
|
| 58 |
+
### Speech Synthesis
|
| 59 |
+

|
| 60 |
+
### Text to Music Generation
|
| 61 |
+

|
| 62 |
+
### Video-Text to Music Generation
|
| 63 |
+

|
| 64 |
+
|
| 65 |
+
## Requirements
|
| 66 |
+
We recommend using conda to install the environment.
|
| 67 |
+
```bash
|
| 68 |
+
conda env create -f configs/enviroment.yml # add -n for your name
|
| 69 |
+
conda activate unimoe-audio # default name
|
| 70 |
+
```
|
| 71 |
+
then install the torch packages
|
| 72 |
+
```bash
|
| 73 |
+
# Use the official index
|
| 74 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
|
| 75 |
+
|
| 76 |
+
# Use Tsinghua mirror source
|
| 77 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -i https://pypi.tuna.tsinghua.edu.cn/simple/ --extra-index-url https://download.pytorch.org/whl/cu121
|
| 78 |
+
|
| 79 |
+
# Use Alibaba Cloud mirror source
|
| 80 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -i https://mirrors.aliyun.com/pypi/simple/ --extra-index-url https://download.pytorch.org/whl/cu121
|
| 81 |
+
```
|
| 82 |
+
A `dac model` is also required to be downloaded in '/path/to/UniMoE-Audio/utils/dac_model'.
|
| 83 |
+
It will be automatically downloaded when running the first time.
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
## Usage
|
| 87 |
+
Please move to the `utils` folder to your working directory.
|
| 88 |
+
Then you can use the model like this:
|
| 89 |
+
|
| 90 |
+
```python
|
| 91 |
+
from modeling import UniMoEAudio
|
| 92 |
+
|
| 93 |
+
MODEL_NAME= "HIT-TMG/UniMoE-Audio-Preview"
|
| 94 |
+
|
| 95 |
+
# Load model
|
| 96 |
+
unimoe_audio = UniMoEAudio.from_pretrained(
|
| 97 |
+
MODEL_NAME,
|
| 98 |
+
cache_dir='./cache',
|
| 99 |
+
torch_dtype='bfloat16',
|
| 100 |
+
device_id=0
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### TTS Example:
|
| 106 |
+
```python
|
| 107 |
+
# TTS/Voice Cloning
|
| 108 |
+
target_text = "Target Text"
|
| 109 |
+
prompt_audio = "/path/to/your/prompt_audio.wav"
|
| 110 |
+
prompt_text = "Prompt Text"
|
| 111 |
+
|
| 112 |
+
# Encode prompt audio
|
| 113 |
+
prompt_codec = unimoe_audio.dac.encode(prompt_audio)
|
| 114 |
+
|
| 115 |
+
prompt_codec_input_ids = unimoe_audio._preprocess_codec(
|
| 116 |
+
codec=prompt_codec,
|
| 117 |
+
codec_delay_pattern=unimoe_audio.model.config.codec_delay_pattern,
|
| 118 |
+
codec_channels=unimoe_audio.model.num_channels,
|
| 119 |
+
codec_bos_value=unimoe_audio.model.config.codec_bos_value,
|
| 120 |
+
codec_eos_value=unimoe_audio.model.config.codec_eos_value,
|
| 121 |
+
codec_pad_value=unimoe_audio.model.config.codec_pad_value
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Construct prompt text
|
| 125 |
+
text_input, _, _ = unimoe_audio._prepare_prompt(task="speech", caption=target_text, prompt_text=prompt_text, prompt_codec_input_ids=prompt_codec_input_ids)
|
| 126 |
+
|
| 127 |
+
# Tokenize input text
|
| 128 |
+
source_input = unimoe_audio.tokenizer(text_input, add_special_tokens=False, return_tensors="pt", padding=True)
|
| 129 |
+
prompt_codec_input_ids = prompt_codec_input_ids.unsqueeze(0).expand(len(text_input), -1, -1).reshape(-1, prompt_codec_input_ids.shape[1])
|
| 130 |
+
|
| 131 |
+
#Speech Generation
|
| 132 |
+
unimoe_audio._generate_core(
|
| 133 |
+
source_input,
|
| 134 |
+
prompt_codec_input_ids,
|
| 135 |
+
save_name = "speech",
|
| 136 |
+
output_dir = "./",
|
| 137 |
+
cfg_scale = 1.0,
|
| 138 |
+
temperature = 1.0,
|
| 139 |
+
top_p = 1.0,
|
| 140 |
+
cfg_filter_top_k = 45,
|
| 141 |
+
eos_prob_mul_factor = 1.0,
|
| 142 |
+
do_sample = True,
|
| 143 |
+
debug_guidance_step = -1,
|
| 144 |
+
use_cache = True
|
| 145 |
+
)
|
| 146 |
+
```
|
| 147 |
+
### T2M Example:
|
| 148 |
+
```python
|
| 149 |
+
caption = "music deccription"
|
| 150 |
+
|
| 151 |
+
# Construct prompt text
|
| 152 |
+
text_input, _, _ = unimoe_audio._prepare_prompt(task="music", caption=caption)
|
| 153 |
+
|
| 154 |
+
# Tokenize input text
|
| 155 |
+
source_input = unimoe_audio.tokenizer(text_input, add_special_tokens=False, return_tensors="pt", padding=True)
|
| 156 |
+
|
| 157 |
+
#music generation with prompt text
|
| 158 |
+
unimoe_audio._generate_core(
|
| 159 |
+
source_input,
|
| 160 |
+
None,
|
| 161 |
+
save_name = "music",
|
| 162 |
+
output_dir = "./",
|
| 163 |
+
cfg_scale = 10.0,
|
| 164 |
+
temperature = 1.0,
|
| 165 |
+
top_p = 1.0,
|
| 166 |
+
cfg_filter_top_k = 45,
|
| 167 |
+
eos_prob_mul_factor = 0.6,
|
| 168 |
+
do_sample = True,
|
| 169 |
+
debug_guidance_step = -1,
|
| 170 |
+
use_cache = True
|
| 171 |
+
)
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### VT2M Example:
|
| 175 |
+
```python
|
| 176 |
+
# VT2M
|
| 177 |
+
caption = "music deccription"
|
| 178 |
+
prompt_video = "/path/to/your/video.mp4"
|
| 179 |
+
|
| 180 |
+
#prepare prompt
|
| 181 |
+
text_input, video_inputs, fps_inputs = unimoe_audio._prepare_prompt(task="music", caption=caption, video=prompt_video, fps=1, sampling_fps=1, max_frames=1)
|
| 182 |
+
|
| 183 |
+
#input processor
|
| 184 |
+
source_input = unimoe_audio.processor(
|
| 185 |
+
text=text_input,
|
| 186 |
+
images=None,
|
| 187 |
+
videos=video_inputs,
|
| 188 |
+
fps=fps_inputs,
|
| 189 |
+
padding=True,
|
| 190 |
+
return_tensors="pt",
|
| 191 |
+
do_resize=False
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
#music generation with prompt video
|
| 195 |
+
unimoe_audio._generate_core(
|
| 196 |
+
source_input,
|
| 197 |
+
None,
|
| 198 |
+
save_name = "video_music",
|
| 199 |
+
output_dir = "./",
|
| 200 |
+
rebuild_codec=None,
|
| 201 |
+
cfg_scale = 10.0,
|
| 202 |
+
temperature = 1.0,
|
| 203 |
+
top_p = 1.0,
|
| 204 |
+
cfg_filter_top_k = 45,
|
| 205 |
+
eos_prob_mul_factor = 0.6,
|
| 206 |
+
do_sample = True,
|
| 207 |
+
debug_guidance_step = -1,
|
| 208 |
+
use_cache = True
|
| 209 |
+
)
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
config.json
CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "UniAudioRVQQwen2_5VLMoEForConditionalGeneration"
   ],
+  "auto_map": {
+    "AutoConfig": "modeling.UniMoEAudioConfig",
+    "AutoModelForCausalLM": "modeling.UniMoEAudio"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
   "codec_bos_value": 1026,
deepspeed_utils.py
ADDED
@@ -0,0 +1,124 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import deepspeed
import torch
import torch.nn.functional as F
from deepspeed import comm as dist
from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
from torch import Tensor

try:
    # To enable Tutel MoE optimizations:
    #   python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
    from tutel import moe as tutel_moe

    TUTEL_INSTALLED = True
except:
    # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
    TUTEL_INSTALLED = False
    pass


# =============================================================================
# DeepSpeed MoE Inference Utilities
# =============================================================================

def _AllToAll_forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
    ctx.group = group
    input = input.contiguous()
    return input


def gate_forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
    d_model = input[0].shape[-1]
    reshaped_input = input[0].reshape(-1, d_model)

    if self.use_tutel:
        self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
        S, M = reshaped_input.size(0), reshaped_input.size(1)

        if not hasattr(self, "_tutel_dispatcher"):
            self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
        self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
        dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
    else:
        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
        dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)

    dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
    expert_output = self.experts(dispatched_input)
    expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, dispatched_input.shape[2], -1)

    if self.use_tutel:
        combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
    else:
        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)

    a = combined_output.reshape(input[0].size()[:-1] + (-1,))

    return a


def top2gating(
    logits: Tensor, capacity_factor: float, min_capacity: int, drop_tokens: bool = True, ep_group: Union[torch.distributed.ProcessGroup, None] = None, top2_2nd_expert_sampling: bool = True
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    gates = F.softmax(logits, dim=1)
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    if top2_2nd_expert_sampling:
        logits += gumbel_rsample(logits.shape, device=logits.device)

    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts
    exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)

    if drop_tokens:
        capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
        mask1 *= torch.lt(locations1, capacity)
        mask2 *= torch.lt(locations2, capacity)
    else:
        new_capacity = torch.max(exp_counts)
        capacity = new_capacity

    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)
    mask1_float = mask1.float()
    mask2_float = mask2.float()

    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s

    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts


# Apply the modifications to deepspeed
deepspeed.moe.sharded_moe.MOELayer.forward = gate_forward
deepspeed.moe.sharded_moe.top2gating = top2gating
deepspeed.moe.sharded_moe._AllToAll.forward = _AllToAll_forward
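# Added usage note (not in the original file): importing this module has side
# effects -- the three assignments above replace DeepSpeed's stock
# MOELayer.forward, top2gating and _AllToAll.forward with the local variants
# defined here (the patched _AllToAll.forward performs no cross-rank
# communication). It should therefore be imported before any MoE forward
# passes are run, e.g.:
#
#   import deepspeed_utils  # patch DeepSpeed MoE before using the model
#   from modeling import UniMoEAudio
#
# The import order shown is an assumption based on how the patching works,
# not an instruction from the model authors.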
model-00001-of-00003.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:254260c822c07d95dcd11f897c656eda8d08e5849832d4fd4f67c074c449b2fb
+size 4999916960
modeling.py
ADDED
@@ -0,0 +1,1182 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2-VL model."""

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
from transformers.configuration_utils import PretrainedConfig, layer_type_validation

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import (
    ModelOutput,
)
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
    Qwen2_5_VLVisionConfig,
    Qwen2_5_VLTextConfig,
    Qwen2_5_VLConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLAttention,
    Qwen2RMSNorm,
    Qwen2_5_VLRotaryEmbedding,
)
from DCMoE import UniMoEAudioSparseMoeBlock
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel

logger = logging.get_logger(__name__)

FAST_INIT = True
if FAST_INIT:
    logger.warning("Using FAST initialization for Grin Qwen2_vl!")

class Qwen2_5_VLMoETextConfig(Qwen2_5_VLTextConfig):
    model_type = "qwen2_5_vl_moe_text"

    def __init__(
        self,
        mlp_dynamic_expert_num=4,
        mlp_dynamic_null_expert_num=0,
        mlp_dynamic_top_p=0.7,
        mlp_dynamic_top_k=2,
        mlp_fixed_expert_num=2,
        dynamic_intermediate_size=8960,
        shared_intermediate_size=8960,
        ignore_differentiable_router=False,
        enable_expert_tensor_parallelism: bool = False,
        ep_size=1,
        fixed_ep_size=1,
        router_jitter_noise=0.01,
        input_jitter_noise=0.01,
        token_drop=False,
        drop_policy: str = "probs",
        min_capacity: int = 8,
        capacity_factor: float = 1.0,
        fp32_gate=True,
        avg_hidden_states_last=False,
        drop_token_num_print=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mlp_dynamic_expert_num = mlp_dynamic_expert_num
        self.mlp_dynamic_top_p = mlp_dynamic_top_p
        self.mlp_dynamic_top_k = mlp_dynamic_top_k
        self.mlp_fixed_expert_num = mlp_fixed_expert_num
        self.mlp_dynamic_null_expert_num = mlp_dynamic_null_expert_num
        self.dynamic_intermediate_size = dynamic_intermediate_size
        self.shared_intermediate_size = shared_intermediate_size
        self.ignore_differentiable_router = ignore_differentiable_router
        self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.fixed_ep_size = fixed_ep_size
        self.input_jitter_noise = input_jitter_noise
        self.router_jitter_noise = router_jitter_noise
        self.token_drop = token_drop
        self.drop_policy = drop_policy
        self.min_capacity = min_capacity
        self.capacity_factor = capacity_factor
        self.fp32_gate = fp32_gate
        self.avg_hidden_states_last = avg_hidden_states_last
        self.drop_token_num_print = drop_token_num_print

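# --- Illustrative sketch (not part of the original upload): one way the MoE text config above
# could be instantiated. The keyword values simply repeat the defaults declared in
# Qwen2_5_VLMoETextConfig; treating them as the shipped checkpoint's settings is an assumption.
def _example_moe_text_config() -> "Qwen2_5_VLMoETextConfig":
    return Qwen2_5_VLMoETextConfig(
        mlp_dynamic_expert_num=4,       # routable dynamic experts per MoE block
        mlp_dynamic_top_k=2,            # at most k dynamic experts per token
        mlp_dynamic_top_p=0.7,          # nucleus-style cut-off on router probabilities
        mlp_fixed_expert_num=2,         # always-on shared experts
        dynamic_intermediate_size=8960,
        shared_intermediate_size=8960,
    )
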
class UniMoEAudioConfig(PretrainedConfig):
    model_type = "uni_audio_rvq_qwen2_5vl_moe"
    sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLMoETextConfig}
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        text_config=None,
        vision_config=None,
        image_token_id=151655,
        video_token_id=151656,
        codec_vocab_size=1028,
        codec_delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],
        codec_channels=9,
        codec_eos_value=1024,
        codec_pad_value=1025,
        codec_bos_value=1026,
        codec_placeholder_value=None,
        **kwargs,
    ):
        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"](**kwargs)

        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.codec_vocab_size = codec_vocab_size
        self.codec_delay_pattern = codec_delay_pattern
        self.codec_channels = codec_channels
        self.codec_eos_value = codec_eos_value
        self.codec_pad_value = codec_pad_value
        self.codec_bos_value = codec_bos_value
        self.codec_placeholder_value = codec_placeholder_value

        super().__init__(**kwargs)

@dataclass
class MoEQwen2_5VLCausalLMOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None
    all_router_logits: Tuple = None
    all_router_top_k: Tuple = None
    all_router_expert_mask: Tuple = None
    all_router_weight: Tuple = None
    aux_balance_loss: torch.FloatTensor = None


@dataclass
class BaseModelOutputWithPast(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    all_router_logits: Tuple = None
    all_router_top_k: Tuple = None
    all_router_weight: Tuple = None
    all_router_expert_mask: Tuple = None
    all_aux_loss: Tuple = None

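# --- Illustrative sketch (not part of the original upload): assembling the composite config that
# ties the vision tower, the MoE text model and the RVQ codec settings together. The codec values
# repeat the defaults declared above; any resemblance to the uploaded config.json is an assumption.
def _example_uni_moe_audio_config() -> "UniMoEAudioConfig":
    return UniMoEAudioConfig(
        codec_channels=9,                                        # RVQ codebooks per audio frame
        codec_delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],   # per-channel decoding delay
        codec_eos_value=1024,
        codec_pad_value=1025,
        codec_bos_value=1026,
    )
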
class Qwen2_5_VLMoEDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen2_5_VLMoETextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
        self.mlp = UniMoEAudioSparseMoeBlock(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits_and_topk: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states, router_logits, router_top_k, router_expert_mask, router_weight, aux_loss = self.mlp(hidden_states, padding_token_mask)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if output_router_logits_and_topk:
            outputs += (router_logits,)
            outputs += (router_top_k,)
            outputs += (router_expert_mask,)
            outputs += (router_weight,)
            outputs += (aux_loss,)

        return outputs

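# Note on the tuple returned by Qwen2_5_VLMoEDecoderLayer.forward above: hidden_states always comes
# first, optionally followed by self_attn_weights, and, when output_router_logits_and_topk is set,
# by router_logits, router_top_k, router_expert_mask, router_weight and aux_loss in that order.
# Qwen2_5_VLMoETextModel.forward below relies on this ordering when it indexes layer_outputs[-5:]
# from the end, so the two must stay in sync.
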
class Qwen2_5_VLMoEPreTrainedModel(PreTrainedModel):
    config_class = UniMoEAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if FAST_INIT:
            if isinstance(module, UniMoEAudioSparseMoeBlock):
                module.gate.weight.data.normal_(mean=0.0, std=std)
                if module.gate.bias is not None:
                    module.gate.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
        else:
            if isinstance(module, (nn.Linear, nn.Conv3d)):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, Qwen2RMSNorm):
                module.weight.data.fill_(1.0)

class Qwen2_5_VLMoETextModel(Qwen2_5_VLMoEPreTrainedModel):
    config_class = Qwen2_5_VLMoETextConfig

    def __init__(self, config: Qwen2_5_VLMoETextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2_5_VLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits_and_topk: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        if not isinstance(causal_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits_and_topk else None
        all_router_top_k = () if output_router_logits_and_topk else None
        all_router_expert_mask = ()
        all_router_weight = ()
        all_aux_loss = ()
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                padding_token_mask=padding_token_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                output_router_logits_and_topk=output_router_logits_and_topk,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if output_router_logits_and_topk:
                all_router_logits += (layer_outputs[-5],)
                all_router_top_k += (layer_outputs[-4],)
                all_router_expert_mask += (layer_outputs[-3],)
                all_router_weight += (layer_outputs[-2],)
                all_aux_loss += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attns,
                    all_router_logits,
                    all_router_top_k,
                    all_router_expert_mask,
                    all_router_weight,
                    all_aux_loss]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            all_router_logits=all_router_logits,
            all_router_top_k=all_router_top_k,
            all_router_expert_mask=all_router_expert_mask,
            all_router_weight=all_router_weight,
            all_aux_loss=all_aux_loss,
        )

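# Note: the @dataclass BaseModelOutputWithPast defined earlier in this file shadows the
# transformers class of the same name, adding the per-layer router statistics
# (all_router_logits / all_router_top_k / all_router_expert_mask / all_router_weight) and the
# per-layer auxiliary balance losses that Qwen2_5_VLMoETextModel.forward collects above.
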
class UniMoEAudio(Qwen2_5_VLMoEPreTrainedModel):
    base_model_prefix = ""
    _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
    config_class = UniMoEAudioConfig
    _checkpoint_conversion_mapping = {
        "^visual": "visual",
        r"^model(?!\.(language_model|visual))": "language_model",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config, attn_implementation=config._attn_implementation)
        self.language_model = Qwen2_5_VLMoETextModel._from_config(config.text_config)
        self.rope_deltas = None
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.num_channels = config.codec_channels
        self.codec_vocab_size = config.codec_vocab_size
        self.codec_embed_tokens = nn.ModuleList(
            [nn.Embedding(self.codec_vocab_size, config.text_config.hidden_size) for embed_idx in range(self.num_channels)])
        self.codec_placeholder_value = config.codec_placeholder_value
        self.codec_head = nn.Linear(config.text_config.hidden_size, self.num_channels * self.codec_vocab_size, bias=False)
        self.post_init()

    @property
    def cur_aux_weight(self):
        if self.training_steps >= self.l_aux_weight_decay_steps:
            return self.min_l_aux_weight
        return self.l_aux_weight - (self.l_aux_weight - self.min_l_aux_weight) / self.l_aux_weight_decay_steps * self.training_steps

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model

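    # Note on cur_aux_weight above: it linearly anneals the auxiliary balance-loss weight from
    # l_aux_weight down to min_l_aux_weight over l_aux_weight_decay_steps, driven by training_steps
    # (which forward() increments). None of these attributes are initialised in __init__ here, so
    # the training script is presumably expected to set them before the first call with labels;
    # that reading is an inference from this file, not something stated in the upload.
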
    def get_rope_index(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id
        mrope_position_deltas = []
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = torch.ones_like(total_input_ids)
            position_ids = torch.ones(
                3,
                input_ids.shape[0],
                input_ids.shape[1],
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            image_index, video_index = 0, 0
            attention_mask = attention_mask.to(total_input_ids.device)
            for i, input_ids in enumerate(total_input_ids):
                input_ids = input_ids[attention_mask[i] == 1]
                image_nums, video_nums = 0, 0
                vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                vision_tokens = input_ids[vision_start_indices + 1]
                image_nums = (vision_tokens == image_token_id).sum()
                video_nums = (vision_tokens == video_token_id).sum()
                input_tokens = input_ids.tolist()
                llm_pos_ids_list: list = []
                st = 0
                remain_images, remain_videos = image_nums, video_nums
                for _ in range(image_nums + video_nums):
                    if image_token_id in input_tokens and remain_images > 0:
                        ed_image = input_tokens.index(image_token_id, st)
                    else:
                        ed_image = len(input_tokens) + 1
                    if video_token_id in input_tokens and remain_videos > 0:
                        ed_video = input_tokens.index(video_token_id, st)
                    else:
                        ed_video = len(input_tokens) + 1
                    if ed_image < ed_video:
                        t, h, w = (
                            image_grid_thw[image_index][0],
                            image_grid_thw[image_index][1],
                            image_grid_thw[image_index][2],
                        )
                        second_per_grid_t = 0
                        image_index += 1
                        remain_images -= 1
                        ed = ed_image

                    else:
                        t, h, w = (
                            video_grid_thw[video_index][0],
                            video_grid_thw[video_index][1],
                            video_grid_thw[video_index][2],
                        )
                        if second_per_grid_ts is not None:
                            second_per_grid_t = second_per_grid_ts[video_index]
                        else:
                            second_per_grid_t = 1.0
                        video_index += 1
                        remain_videos -= 1
                        ed = ed_video
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t.item(),
                        h.item() // spatial_merge_size,
                        w.item() // spatial_merge_size,
                    )
                    text_len = ed - st

                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                    range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                    expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
                    second_per_grid_t = torch.as_tensor(
                        second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
                    )

                    time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second

                    time_tensor_long = time_tensor.long()
                    t_index = time_tensor_long.flatten()

                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                    text_len = len(input_tokens) - st
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
                mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
                max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device)
                    .view(1, 1, -1)
                    .expand(3, input_ids.shape[0], -1)
                )
                mrope_position_deltas = torch.zeros(
                    [input_ids.shape[0], 1],
                    device=input_ids.device,
                    dtype=input_ids.dtype,
                )

            return position_ids, mrope_position_deltas

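    # Note on get_rope_index above: it builds the 3 x batch x seq_len "multimodal RoPE" position
    # ids. Text tokens advance the temporal / height / width rows together, while image and video
    # tokens receive a (t, h, w) grid offset by the length of the preceding text span; the returned
    # mrope_position_deltas record how far each sequence's largest position id runs ahead of its
    # raw length, so later decode steps can keep extending positions consistently.
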
    def get_video_features(self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None):
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        video_embeds = torch.split(video_embeds, split_sizes)
        return video_embeds

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        pixel_values = pixel_values.type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
        image_embeds = torch.split(image_embeds, split_sizes)
        return image_embeds

    def codec_embedding(self, codec_input_ids):
        x = None
        for i in range(self.num_channels):
            channel_tokens = codec_input_ids[..., i]
            channel_embed = self.codec_embed_tokens[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed
        return x

    def calculate_input_embedding(self, input_ids, codec_input_ids):
        inputs_embeds = self.language_model.embed_tokens(input_ids)
        if codec_input_ids is not None:
            codec_input_embeds = self.codec_embedding(codec_input_ids)

            codec_mask = (input_ids == self.codec_placeholder_value).unsqueeze(-1).expand_as(inputs_embeds)
            inputs_embeds = inputs_embeds.masked_scatter(codec_mask, codec_input_embeds)
        return inputs_embeds

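    # --- Illustrative sketch (not part of the original upload): the two helpers above fold the RVQ
    # codec stream into the text embedding space by summing one embedding table per codec channel
    # and masked-scattering the result over the <|AUDIO_PLACEHOLDER|> positions. The toy sizes
    # below are assumptions chosen only to make the shapes visible.
    @staticmethod
    def _illustrate_codec_embedding_sum():
        vocab, hidden, channels = 1028, 16, 3
        tables = nn.ModuleList(nn.Embedding(vocab, hidden) for _ in range(channels))
        codec_ids = torch.randint(0, vocab, (2, 5, channels))          # [batch, frames, channels]
        summed = sum(tables[c](codec_ids[..., c]) for c in range(channels))
        return summed                                                   # [2, 5, hidden]
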
    @can_return_tuple
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        codec_input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        codec_labels: Optional[torch.LongTensor] = None,
        padding_token_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits_and_topk: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, MoEQwen2_5VLCausalLMOutputWithPast]:
        return_dict = True
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if inputs_embeds is None:
            inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values, image_grid_thw)
            image_embeds = torch.cat(image_embeds, dim=0)

            if input_ids is None:
                image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                image_mask = image_mask.all(-1)
            else:
                image_mask = input_ids == self.config.image_token_id

            n_image_tokens = (image_mask).sum()
            image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            n_image_features = image_embeds.shape[0]
            if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
            video_embeds = torch.cat(video_embeds, dim=0)

            if input_ids is None:
                video_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
                video_mask = video_mask.all(-1)
            else:
                video_mask = input_ids == self.config.video_token_id

            n_video_tokens = (video_mask).sum()
            n_video_features = video_embeds.shape[0]
            video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
            if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )

            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if position_ids is None:
            attention_mask_tensor = (
                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
            )
            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
                attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
                attention_mask_tensor = (1.0 - attention_mask_tensor).int()
            prefill_compiled_stage = is_torchdynamo_compiling() and (
                (input_ids is not None and input_ids.shape[1] != 1)
                or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
            )
            prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
                (cache_position is not None and cache_position[0] == 0)
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            )
            if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask_tensor,
                )
                self.rope_deltas = rope_deltas

            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        if padding_token_mask is None:
            padding_token_mask = attention_mask.bool()

        outputs = self.language_model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            padding_token_mask=padding_token_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits_and_topk=output_router_logits_and_topk,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states).float()
        codec_logits = self.codec_head(hidden_states).float()
        codec_logits = codec_logits.view((logits.shape[0], logits.shape[1], self.num_channels, self.codec_vocab_size))

        loss = None
        if labels is not None:

            all_aux_loss = outputs.all_aux_loss if return_dict else outputs[-1]
            all_aux_loss = torch.mean(torch.cat([l.unsqueeze(0) for l in all_aux_loss], dim=0))
            aux_loss = self.cur_aux_weight * all_aux_loss
            self.training_steps += 1
            codec_loss = None

            if codec_labels is not None:
                for i in range(self.num_channels):
                    channel_logits = codec_logits[:, :, i].float()
                    channel_labels = codec_labels[:, :, i]
                    shift_channel_logits = channel_logits[..., :-1, :].contiguous()
                    shift_channel_labels = channel_labels[..., 1:].contiguous()

                    if i != 0 and (shift_channel_labels != -100).sum() == 0:
                        continue

                    loss_fct = CrossEntropyLoss()
                    shift_channel_logits = shift_channel_logits.view(-1, self.codec_vocab_size)
                    shift_channel_labels = shift_channel_labels.view(-1)
                    shift_channel_labels = shift_channel_labels.to(shift_channel_logits.device)
                    channel_loss = loss_fct(shift_channel_logits, shift_channel_labels)
                    codec_loss = channel_loss if codec_loss is None else codec_loss + channel_loss

            loss = codec_loss + aux_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MoEQwen2_5VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            all_router_logits=outputs.all_router_logits,
            all_router_top_k=outputs.all_router_top_k,
            all_router_expert_mask=outputs.all_router_expert_mask,
            all_router_weight=outputs.all_router_weight,
            aux_balance_loss=all_aux_loss,
        )

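    # Note on the training loss in forward() above: the total is the per-channel cross-entropy over
    # the codec logits (channels other than channel 0 are skipped when their shifted labels are all
    # -100) plus cur_aux_weight * mean(per-layer balance losses). As written, the final
    # `loss = codec_loss + aux_loss` assumes codec_labels is provided whenever labels is; with
    # labels but no codec_labels, codec_loss stays None.
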
    @staticmethod
    def _sample_next_token(
        logits_BCxV: torch.Tensor,
        temperature: float,
        top_p: float,
        top_k: int,
        audio_eos_value: int,
    ) -> torch.Tensor:
        if temperature == 0.0:
            return torch.argmax(logits_BCxV, dim=-1)

        logits_BCxV = logits_BCxV / temperature

        if audio_eos_value is not None and audio_eos_value >= 0:
            top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
            eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
            mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
            mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
            logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)

        if top_k is not None:
            _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
            mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
            mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
            logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

        if top_p < 1.0:
            probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
            sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
            cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

            sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
            sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1)
            sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0])

            indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
            indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
            logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

        final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

        sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
        sampled_indices_C = sampled_indices_BC.squeeze(-1)
        return sampled_indices_C

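    # --- Illustrative sketch (not part of the original upload): the nucleus (top-p) filter used in
    # _sample_next_token above, shown on its own. The roll plus zeroing of column 0 keeps the first
    # sorted token, so at least one candidate always survives the cut.
    @staticmethod
    def _illustrate_top_p_filter(logits: torch.Tensor, top_p: float = 0.95) -> torch.Tensor:
        probs = torch.softmax(logits, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        remove = torch.cumsum(sorted_probs, dim=-1) > top_p
        remove = torch.roll(remove, shifts=1, dims=-1)
        remove[..., 0] = False
        remove = torch.zeros_like(remove).scatter(dim=-1, index=sorted_idx, src=remove)
        return logits.masked_fill(remove, -torch.inf)
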
    def _decoder_step(
        self,
        tokens_Bx1xC: torch.Tensor,
        model_kwargs,
        cfg_scale: float,
        neg_input_size: int,
        temperature: float,
        top_p: float,
        top_k: int,
        do_sample=True,
        eos_prob_mul_factor=1.0,
        labels_Bx1xC=None,
        use_cache=True,
        enable_eos=True,
    ) -> torch.Tensor:
        B = tokens_Bx1xC.shape[0]
        audio_eos_value = self.config.codec_eos_value
        attention_mask = model_kwargs["attention_mask"]
        cache_position = model_kwargs["cache_position"]
        past_key_values = model_kwargs["past_key_values"]
        input_ids = model_kwargs["input_ids"]
        codec_input_ids = model_kwargs["codec_input_ids"]
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -tokens_Bx1xC.shape[1]:]
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        tokens_Bx1xC = tokens_Bx1xC.repeat_interleave(neg_input_size, dim=0)
        codec_input_ids = torch.cat((codec_input_ids, tokens_Bx1xC), dim=1) if codec_input_ids is not None else tokens_Bx1xC.clone()
        input_ids = torch.cat((input_ids, torch.ones(input_ids.shape[0], 1).to(input_ids) * self.codec_placeholder_value), dim=-1)

        if use_cache:
            codec_input_embeds = self.codec_embedding(tokens_Bx1xC)
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=codec_input_embeds,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                cache_position=cache_position,
            )

        else:
            batch_codec_input_ids = codec_input_ids.contiguous().view(-1, self.num_channels)

            inputs_embeds = self.calculate_input_embedding(input_ids, batch_codec_input_ids)
            outputs = self.language_model(
                input_ids=None,
                attention_mask=attention_mask,
                position_ids=attention_mask.long().cumsum(-1) - 1,
                past_key_values=None,
                inputs_embeds=inputs_embeds,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
                cache_position=None,
            )

        last_hidden_state = outputs.last_hidden_state
        codec_logits = self.codec_head(last_hidden_state).float()
        codec_logits = codec_logits.view((codec_logits.shape[0], codec_logits.shape[1], self.num_channels, self.codec_vocab_size))
        model_kwargs["past_key_values"] = outputs.past_key_values
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
        model_kwargs["input_ids"] = input_ids
        model_kwargs["codec_input_ids"] = codec_input_ids

        logits_Bx1xCxV = codec_logits[:, -1:].clone()
        logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
        logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, neg_input_size, *logits_last_2BxCxV.shape[1:])
        if cfg_scale is not None:
            cond_logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :]  # Shape [B, C, V]
            logits_BxCxV = cond_logits_BxCxV
            for ni in range(neg_input_size - 1):
                uncond_logits_BxCxV = logits_last_Bx2xCxV[:, ni, :, :]  # Shape [B, C, V]
                cfg_weight = cfg_scale[ni] if isinstance(cfg_scale, List) else cfg_scale
                logits_BxCxV = logits_BxCxV + cfg_weight * (cond_logits_BxCxV - uncond_logits_BxCxV)
        else:
            logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :]  # Shape [B, C, V]

        if enable_eos:
            logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like(
                logits_BxCxV[:, :, audio_eos_value + 1 :],
                fill_value=-torch.inf,
            )
            logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
                logits_BxCxV[:, 1:, audio_eos_value:],
                fill_value=-torch.inf,
            )
            logits_BxCxV[:, 0, audio_eos_value] *= torch.tensor(eos_prob_mul_factor, device=self.device)

        else:
            logits_BxCxV[:, :, audio_eos_value:] = torch.full_like(
                logits_BxCxV[:, :, audio_eos_value:],
                fill_value=-torch.inf,
            )

        flat_logits_BCxV = logits_BxCxV.reshape(B * self.num_channels, -1)
        if do_sample:
            pred_BC = self._sample_next_token(
                flat_logits_BCxV.float(),
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                audio_eos_value=audio_eos_value,
            )
        else:
            pred_BC = torch.argmax(flat_logits_BCxV, dim=1)

        pred_BxC = pred_BC.view(B, self.num_channels)

        return pred_BxC, model_kwargs

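    # --- Illustrative sketch (not part of the original upload): the classifier-free-guidance update
    # applied in _decoder_step. Each negative (unconditional) branch pushes the conditional logits
    # further along (cond - uncond); the default weight of 3.0 mirrors generate()'s cfg_scale but is
    # otherwise an arbitrary choice here.
    @staticmethod
    def _illustrate_cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, cfg_weight: float = 3.0) -> torch.Tensor:
        return cond + cfg_weight * (cond - uncond)
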
| 1013 |
+
def generate(
|
| 1014 |
+
self,
|
| 1015 |
+
input_ids,
|
| 1016 |
+
attention_mask,
|
| 1017 |
+
dec_output,
|
| 1018 |
+
max_tokens,
|
| 1019 |
+
min_tokens=None,
|
| 1020 |
+
codec_input_ids: Optional[torch.Tensor] = None,
|
| 1021 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 1022 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
| 1023 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 1024 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
| 1025 |
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
| 1026 |
+
neg_input_size = 2,
|
| 1027 |
+
cfg_scale = 3.0,
|
| 1028 |
+
temperature: float = 1.2,
|
| 1029 |
+
top_p: float = 0.95,
|
| 1030 |
+
cfg_filter_top_k: int = 45,
|
| 1031 |
+
eos_prob_mul_factor: float = 0.8,
|
| 1032 |
+
do_sample: bool = True,
|
| 1033 |
+
debug_guidance_step: int = 0,
|
| 1034 |
+
use_cache=True,
|
| 1035 |
+
):
|
| 1036 |
+
if codec_input_ids is not None:
|
| 1037 |
+
assert use_cache
|
| 1038 |
+
batch_size = input_ids.shape[0] // neg_input_size
|
| 1039 |
+
audio_eos_value = self.config.codec_eos_value
|
| 1040 |
+
audio_pad_value = self.config.codec_pad_value
|
| 1041 |
+
delay_pattern = self.config.codec_delay_pattern
|
| 1042 |
+
max_delay_pattern = max(delay_pattern)
|
| 1043 |
+
delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long)
|
| 1044 |
+
|
| 1045 |
+
dec_step = min(dec_output.prefill_steps) - 1
|
| 1046 |
+
|
| 1047 |
+
eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device)
|
| 1048 |
+
eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
|
| 1049 |
+
finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
|
| 1050 |
+
|
| 1051 |
+
bos_over = False
|
| 1052 |
+
model_kwargs = dict(attention_mask=attention_mask, use_cache=True)
|
| 1053 |
+
model_kwargs["past_key_values"] = DynamicCache()
|
| 1054 |
+
model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
|
| 1055 |
+
attention_mask = model_kwargs["attention_mask"]
|
| 1056 |
+
past_key_values = model_kwargs["past_key_values"]
|
| 1057 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1058 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1059 |
+
cache_position = torch.arange(0, input_ids.shape[-1], device=input_ids.device)
|
| 1060 |
+
inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)
|
| 1061 |
+
outputs = self.language_model(
|
| 1062 |
+
input_ids=None,
|
| 1063 |
+
attention_mask=attention_mask,
|
| 1064 |
+
position_ids=position_ids,
|
| 1065 |
+
past_key_values=past_key_values,
|
| 1066 |
+
inputs_embeds=inputs_embeds,
|
| 1067 |
+
pixel_values=pixel_values,
|
| 1068 |
+
pixel_values_videos=pixel_values_videos,
|
| 1069 |
+
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            cache_position=cache_position,
        )

        model_kwargs["input_ids"] = input_ids
        model_kwargs["codec_input_ids"] = None
        model_kwargs["labels"] = torch.ones_like(input_ids[neg_input_size - 1 :: neg_input_size]) * -100
        labels_Bx1xC = dec_output.get_labels_at(0)
        if labels_Bx1xC is not None:
            model_kwargs["codec_labels"] = (torch.ones_like(input_ids[neg_input_size - 1 :: neg_input_size]) * -100).unsqueeze(-1).expand(-1, -1, self.num_channels)
            assert (labels_Bx1xC != self.config.codec_bos_value).sum() == 0
            labels_Bx1xC = torch.full_like(labels_Bx1xC, -100)
            model_kwargs["codec_labels"] = torch.cat((model_kwargs["codec_labels"], labels_Bx1xC), dim=1)
        model_kwargs["past_key_values"] = outputs.past_key_values
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

        while dec_step < max_tokens:
            if (eos_countdown_Bx == 0).all():
                break

            current_step_idx = dec_step + 1
            tokens_Bx1xC = dec_output.get_tokens_at(dec_step)
            labels_Bx1xC = dec_output.get_labels_at(dec_step + 1)

            pred_BxC, model_kwargs = self._decoder_step(
                tokens_Bx1xC=tokens_Bx1xC,
                model_kwargs=model_kwargs,
                cfg_scale=cfg_scale,
                neg_input_size=neg_input_size,
                temperature=temperature,
                top_p=top_p,
                top_k=cfg_filter_top_k,
                do_sample=do_sample,
                eos_prob_mul_factor=eos_prob_mul_factor,
                labels_Bx1xC=labels_Bx1xC,
                use_cache=use_cache,
                enable_eos=(min_tokens is None or dec_step >= min_tokens),
            )
            if labels_Bx1xC is not None and (dec_step < debug_guidance_step or debug_guidance_step == -1):
                pred_BxC = labels_Bx1xC[:, 0]

            active_mask_Bx = eos_countdown_Bx != 0
            eos_trigger_Bx = torch.zeros_like(active_mask_Bx)
            if active_mask_Bx.any():
                is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value)
                is_max_len = current_step_idx >= max_tokens - max_delay_pattern
                eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len
            eos_detected_Bx |= eos_trigger_Bx
            start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0)
            if start_countdown_mask_Bx.any():
                eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern
                finished_step_Bx[start_countdown_mask_Bx] = current_step_idx

            padding_mask_Bx = eos_countdown_Bx > 0
            if padding_mask_Bx.any():
                pred_active_BxC = pred_BxC[padding_mask_Bx].clone()
                countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx]
                step_after_eos_Bx = max_delay_pattern - countdown_active_Bx
                step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1)
                delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
                eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_
                pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_
                pred_active_BxC[eos_mask_NxC] = audio_eos_value
                pred_active_BxC[pad_mask_NxC] = audio_pad_value
                pred_BxC[padding_mask_Bx] = pred_active_BxC
                eos_countdown_Bx[padding_mask_Bx] -= 1

            if not bos_over:
                bos_over = all(current_step_idx - prefill_step >= max_delay_pattern for prefill_step in dec_output.prefill_steps)

            dec_output.update_one(pred_BxC, current_step_idx, not bos_over)
            dec_step += 1

        final_step = dec_step + 1
        finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern
        prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device)
        lengths_Bx = finished_step_Bx - prefill_steps_tensor
        lengths_Bx = torch.clamp(lengths_Bx, min=0)
        max_len = lengths_Bx.max().item() + max_delay_pattern

        if max_len > 0:
            num_channels = self.num_channels
            generated_codes = torch.full(
                (batch_size, max_len, num_channels),
                fill_value=audio_pad_value,
                dtype=torch.long,
                device=self.device,
            )

            for i in range(batch_size):
                start_step = dec_output.prefill_steps[i]
                actual_len = lengths_Bx[i].item() + max_delay_pattern
                if actual_len > 0:
                    tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :]
                    generated_codes[i, :actual_len, :] = tokens_to_copy

            return generated_codes, lengths_Bx
        else:
            print("Warning: Nothing generated for any sequence in the batch.")
            return None, None


# AutoConfig.register("qwen2_5_vl_moe_text", Qwen2_5_VLMoETextConfig)
# AutoModelForCausalLM.register(Qwen2_5_VLMoETextConfig, Qwen2_5_VLMoETextModel)

# AutoConfig.register("uni_audio_rvq_qwen2_5vl_moe", UniMoEAudioConfig)
# AutoModelForCausalLM.register(UniMoEAudioConfig, UniMoEAudio)
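A minimal, self-contained sketch of the EOS countdown / delay-pattern padding step from the decoding loop above; the delay pattern, EOS id, and PAD id are assumed toy values (the real ones come from the model config):

    import torch

    delay_pattern_Cx = torch.tensor([0, 1, 2, 3])            # assumed per-channel delays
    max_delay_pattern = int(delay_pattern_Cx.max().item())
    audio_eos_value, audio_pad_value = 1025, 1026             # assumed ids

    # Two sequences currently in their EOS countdown, with 3 and 1 steps remaining.
    eos_countdown_Bx = torch.tensor([3, 1])
    pred_BxC = torch.zeros((2, 4), dtype=torch.long)          # dummy per-channel predictions

    # Same masking as in the loop: a channel is overwritten with EOS exactly when the number
    # of steps already spent in the countdown equals its delay, and with PAD after that.
    step_after_eos_Bx_ = (max_delay_pattern - eos_countdown_Bx).unsqueeze(1)
    delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
    pred_BxC[step_after_eos_Bx_ == delay_pattern_Cx_] = audio_eos_value
    pred_BxC[step_after_eos_Bx_ > delay_pattern_Cx_] = audio_pad_value
    print(pred_BxC)
    # tensor([[1025,    0,    0,    0],
    #         [1026, 1026, 1025,    0]])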
special_tokens_map.json
CHANGED
@@ -12,7 +12,84 @@
     "<|vision_end|>",
     "<|vision_pad|>",
     "<|image_pad|>",
-    "<|video_pad|>"
+    "<|video_pad|>",
+    {
+      "content": "<|AUDIO_PLACEHOLDER|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|AUDIO_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|AUDIO_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|SPEECH_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|SPEECH_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|VOICE_PROMPT_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|VOICE_PROMPT_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|SPEECH_PROMPT_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|SPEECH_PROMPT_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|MUSIC_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    },
+    {
+      "content": "<|MUSIC_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    }
   ],
   "eos_token": {
     "content": "<|im_end|>",
tokenizer_config.json
CHANGED
@@ -177,6 +177,94 @@
       "rstrip": false,
       "single_word": false,
       "special": false
+    },
+    "151665": {
+      "content": "<|AUDIO_PLACEHOLDER|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151666": {
+      "content": "<|AUDIO_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151667": {
+      "content": "<|AUDIO_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151668": {
+      "content": "<|SPEECH_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151669": {
+      "content": "<|SPEECH_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151670": {
+      "content": "<|VOICE_PROMPT_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151671": {
+      "content": "<|VOICE_PROMPT_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151672": {
+      "content": "<|SPEECH_PROMPT_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151673": {
+      "content": "<|SPEECH_PROMPT_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151674": {
+      "content": "<|MUSIC_START|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "151675": {
+      "content": "<|MUSIC_END|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
     }
   },
   "additional_special_tokens": [
@@ -192,15 +280,27 @@
     "<|vision_end|>",
     "<|vision_pad|>",
     "<|image_pad|>",
-    "<|video_pad|>"
+    "<|video_pad|>",
+    "<|AUDIO_PLACEHOLDER|>",
+    "<|AUDIO_START|>",
+    "<|AUDIO_END|>",
+    "<|SPEECH_START|>",
+    "<|SPEECH_END|>",
+    "<|VOICE_PROMPT_START|>",
+    "<|VOICE_PROMPT_END|>",
+    "<|SPEECH_PROMPT_START|>",
+    "<|SPEECH_PROMPT_END|>",
+    "<|MUSIC_START|>",
+    "<|MUSIC_END|>"
   ],
   "bos_token": null,
   "clean_up_tokenization_spaces": false,
   "eos_token": "<|im_end|>",
   "errors": "replace",
   "extra_special_tokens": {},
-  "model_max_length":
+  "model_max_length": 4096,
   "pad_token": "<|endoftext|>",
+  "padding_side": "right",
   "processor_class": "Qwen2_5_VLProcessor",
   "split_special_tokens": false,
   "tokenizer_class": "Qwen2Tokenizer",
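A quick sanity-check sketch for the tokenizer changes above, assuming the updated files are loaded from a local copy of this repository (the path is a placeholder); only standard transformers tokenizer calls are used:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("./UniMoE-Audio", trust_remote_code=True)

    audio_tokens = [
        "<|AUDIO_PLACEHOLDER|>", "<|AUDIO_START|>", "<|AUDIO_END|>",
        "<|SPEECH_START|>", "<|SPEECH_END|>",
        "<|VOICE_PROMPT_START|>", "<|VOICE_PROMPT_END|>",
        "<|SPEECH_PROMPT_START|>", "<|SPEECH_PROMPT_END|>",
        "<|MUSIC_START|>", "<|MUSIC_END|>",
    ]

    # Each added token should map to a single id (151665-151675 in this checkpoint)
    # and round-trip without being split by the BPE.
    for tok in audio_tokens:
        tok_id = tokenizer.convert_tokens_to_ids(tok)
        assert tokenizer.convert_ids_to_tokens(tok_id) == tok
        print(tok, tok_id)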
utils.py
ADDED
@@ -0,0 +1,491 @@
# -*- coding: utf-8 -*-
"""
UniMoE Audio Utilities Module
Author: UniMoE Audio Team
"""

import copy
import glob
import json
import math
import os
import re
import shutil
import sys
import time
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, TYPE_CHECKING, Callable

import dac
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import transformers
from audiotools import AudioSignal
from safetensors import safe_open
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, LogitsProcessor, LogitsProcessorList
from moviepy.video.io.VideoFileClip import VideoFileClip
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode
import torchvision

from qwen_vl_utils import smart_resize, process_vision_info

import deepspeed
from deepspeed import comm as dist
from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
from torch import Tensor

try:
    import torch_npu
    IS_CUDA = False
except:
    IS_CUDA = True

try:
    # To enable Tutel MoE optimizations:
    # python3 -m pip install --user --upgrade git+https://github.com/microsoft/[email protected]
    from tutel import moe as tutel_moe
    TUTEL_INSTALLED = True
except:
    # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
    TUTEL_INSTALLED = False
    pass


SYSTEM_MESSAGE = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"""
INPUT_FORMAT = """<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"""
AUDIO_START = "<|AUDIO_START|>"

DEFAULT_VIDEO_PROMPT = "<|vision_start|><|video_pad|><|vision_end|>{}"
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
VIDEO_TOTAL_PIXELS = 16 * 28 * 28
VIDEO_MIN_PIXELS = 16 * 28 * 28
VIDEO_MAX_PIXELS = 64 * 28 * 28
FRAME_FACTOR = 2

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_PREFIX_FORMAT = "<|IMAGE_PLACE_HOLDER|>"

# =============================================================================
# DAC Utilities
# =============================================================================

+
class Dac:
|
| 86 |
+
def __init__(self):
|
| 87 |
+
base_dir = os.path.dirname(__file__)
|
| 88 |
+
dac_model_dir = os.path.join(base_dir, "dac_model")
|
| 89 |
+
model_path = os.path.join(dac_model_dir, "weights_16khz.pth")
|
| 90 |
+
|
| 91 |
+
if not os.path.isfile(model_path):
|
| 92 |
+
print(f"DAC model not found at {model_path}, downloading...")
|
| 93 |
+
os.makedirs(dac_model_dir, exist_ok=True)
|
| 94 |
+
downloaded_path = dac.utils.download(model_type="16khz")
|
| 95 |
+
shutil.move(downloaded_path, model_path)
|
| 96 |
+
print(f"DAC model downloaded and saved to {model_path}")
|
| 97 |
+
|
| 98 |
+
env_path = os.environ.get("DAC_WEIGHTS")
|
| 99 |
+
candidates = []
|
| 100 |
+
if env_path:
|
| 101 |
+
candidates.append(env_path)
|
| 102 |
+
|
| 103 |
+
candidates.extend([
|
| 104 |
+
model_path,
|
| 105 |
+
os.path.join(base_dir, "weights_16khz.pth"),
|
| 106 |
+
os.path.join(os.getcwd(), "utils", "dac_model", "weights_16khz.pth"),
|
| 107 |
+
os.path.join(os.getcwd(), "dac_model", "weights_16khz.pth"),
|
| 108 |
+
])
|
| 109 |
+
|
| 110 |
+
final_model_path = next((p for p in candidates if p and os.path.isfile(p)), None)
|
| 111 |
+
if not final_model_path:
|
| 112 |
+
searched = "\n - " + "\n - ".join(candidates)
|
| 113 |
+
raise FileNotFoundError(
|
| 114 |
+
"DAC weights not found. Please place weights_16khz.pth in one of the following locations or set DAC_WEIGHTS to an absolute path:" + searched
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
self.model = dac.DAC.load(final_model_path)
|
| 118 |
+
self.resampler = dict()
|
| 119 |
+
if IS_CUDA:
|
| 120 |
+
self.model = self.model.to("cuda")
|
| 121 |
+
else:
|
| 122 |
+
self.model = self.model.to("npu")
|
| 123 |
+
|
| 124 |
+
def encode(self, audio_path):
|
| 125 |
+
signal = AudioSignal(audio_path)
|
| 126 |
+
if signal.audio_data.shape[1] == 2:
|
| 127 |
+
signal.audio_data = 0.5 * (signal.audio_data[:, :1, :] + signal.audio_data[:, 1:, :])
|
| 128 |
+
signal.to(self.model.device)
|
| 129 |
+
|
| 130 |
+
if signal.sample_rate != 16000:
|
| 131 |
+
if not str(signal.sample_rate) in self.resampler:
|
| 132 |
+
self.resampler[str(signal.sample_rate)] = torchaudio.transforms.Resample(signal.sample_rate, 16000)
|
| 133 |
+
if IS_CUDA:
|
| 134 |
+
self.resampler[str(signal.sample_rate)] = self.resampler[str(signal.sample_rate)].cuda()
|
| 135 |
+
else:
|
| 136 |
+
self.resampler[str(signal.sample_rate)] = self.resampler[str(signal.sample_rate)].npu()
|
| 137 |
+
|
| 138 |
+
signal.audio_data = self.resampler[str(signal.sample_rate)](signal.audio_data)
|
| 139 |
+
signal.sample_rate = 16000
|
| 140 |
+
|
| 141 |
+
x = self.model.preprocess(signal.audio_data.to(self.model.device), signal.sample_rate)
|
| 142 |
+
z, codes, latents, _, _ = self.model.encode(x)
|
| 143 |
+
|
| 144 |
+
codes = codes[0].clone().detach().transpose(0, 1)
|
| 145 |
+
assert codes.shape[1] == 12 and len(codes.shape) == 2
|
| 146 |
+
codes = codes.tolist()
|
| 147 |
+
|
| 148 |
+
return codes
|
| 149 |
+
|
| 150 |
+
def decode(self, codes, save_path, min_duration=None):
|
| 151 |
+
assert codes.shape[0] == 1 and codes.shape[1] == 12
|
| 152 |
+
z, _, _ = self.model.quantizer.from_codes(codes.to(self.model.device))
|
| 153 |
+
audio_out = self.model.decode(z)[0].detach().cpu()
|
| 154 |
+
|
| 155 |
+
sample_rate = 16000
|
| 156 |
+
duration = audio_out.size(1) / sample_rate
|
| 157 |
+
if min_duration is not None and duration < min_duration:
|
| 158 |
+
padding_duration = min_duration - duration
|
| 159 |
+
padding_samples = int(padding_duration * sample_rate)
|
| 160 |
+
padding = torch.zeros((audio_out.size(0), padding_samples), dtype=audio_out.dtype, device=audio_out.device)
|
| 161 |
+
audio_out = torch.cat((audio_out, padding), dim=1)
|
| 162 |
+
|
| 163 |
+
torchaudio.save(save_path, audio_out.detach().cpu(), sample_rate=16000, encoding="PCM_S", bits_per_sample=16)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
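# --- Illustrative usage sketch (added annotation, not part of the original upload) ---
# A hedged example of the Dac round trip: encode an audio file into 12-codebook RVQ
# codes, then decode the codes back to a 16 kHz waveform. The file paths are
# hypothetical placeholders.
def _demo_dac_roundtrip(in_path="example_in.wav", out_path="example_out.wav"):
    codec = Dac()
    codes = codec.encode(in_path)                                     # list of [12]-wide frames
    codes_1x12xT = torch.tensor(codes).transpose(0, 1).unsqueeze(0)   # (T, 12) -> (1, 12, T)
    codec.decode(codes_1x12xT, out_path)

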
def build_delay_indices(B: int, T: int, C: int, delay_pattern: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)

    t_idx_BxT = torch.broadcast_to(
        torch.arange(T, dtype=torch.int32)[None, :],
        [B, T],
    )
    t_idx_BxTx1 = t_idx_BxT[..., None]
    t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)

    b_idx_BxTxC = torch.broadcast_to(
        torch.arange(B, dtype=torch.int32).view(B, 1, 1),
        [B, T, C],
    )
    c_idx_BxTxC = torch.broadcast_to(
        torch.arange(C, dtype=torch.int32).view(1, 1, C),
        [B, T, C],
    )
    t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_clamped_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        dim=1,
    ).long()

    return t_idx_BxTxC, indices_BTCx3


def apply_audio_delay(audio_BxTxC: torch.Tensor, pad_value: int, bos_value: int, precomp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    device = audio_BxTxC.device
    t_idx_BxTxC, indices_BTCx3 = precomp
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
    mask_bos = t_idx_BxTxC < 0
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]

    bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)

    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))

    return result_BxTxC


def build_revert_indices(B: int, T: int, C: int, delay_pattern: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    device = None
    delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
    t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
    t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
    t_idx_BxTxC = torch.minimum(
        t_idx_BT1 + delay_arr.view(1, 1, C),
        torch.tensor(T - 1, device=device),
    )
    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
    indices_BTCx3 = torch.stack(
        [
            b_idx_BxTxC.reshape(-1),
            t_idx_BxTxC.reshape(-1),
            c_idx_BxTxC.reshape(-1),
        ],
        axis=1,
    ).long()

    return t_idx_BxTxC, indices_BTCx3


def revert_audio_delay(
    audio_BxTxC: torch.Tensor,
    pad_value: int,
    precomp: Tuple[torch.Tensor, torch.Tensor],
    T: int,
) -> torch.Tensor:
    t_idx_BxTxC, indices_BTCx3 = precomp
    device = audio_BxTxC.device
    t_idx_BxTxC = t_idx_BxTxC.to(device)
    indices_BTCx3 = indices_BTCx3.to(device)
    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())

    pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
    T_tensor = torch.tensor(T, device=device)

    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC


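# --- Illustrative round-trip sketch (added annotation, not part of the original upload) ---
# Applying the per-channel delay and then reverting it reproduces the original frames for
# the first T - max(delay_pattern) steps; the toy sizes below are assumptions.
def _demo_delay_roundtrip():
    B, T, C = 1, 6, 3
    delay_pattern = [0, 1, 2]
    audio_BxTxC = torch.arange(B * T * C).view(B, T, C)
    delayed = apply_audio_delay(
        audio_BxTxC, pad_value=-1, bos_value=-2,
        precomp=build_delay_indices(B, T, C, delay_pattern),
    )
    reverted = revert_audio_delay(
        delayed, pad_value=-1,
        precomp=build_revert_indices(B, T, C, delay_pattern), T=T,
    )
    valid = T - max(delay_pattern)
    assert torch.equal(reverted[:, :valid, :], audio_BxTxC[:, :valid, :])

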
def prepare_audio_prompt(model, audio_prompts: list[torch.Tensor]):
    num_channels = model.config.codec_channels
    audio_bos_value = model.config.codec_bos_value
    delay_pattern = model.config.codec_delay_pattern
    max_delay_pattern = max(delay_pattern)
    batch_size = len(audio_prompts)
    max_len = max(p.shape[0] if p is not None else 0 for p in audio_prompts) + max_delay_pattern + 1
    prefill_steps = []
    prefill = torch.full(
        (batch_size, max_len, num_channels),
        fill_value=-1,
        dtype=torch.int,
        device=model.device,
    )
    prefill[:, 0, :] = audio_bos_value
    for i in range(batch_size):
        prompt = audio_prompts[i]
        if prompt is not None:
            prompt = prompt.to(device=model.device, dtype=torch.int)
            prefill[i, 1 : prompt.shape[0] + 1, :] = prompt
            prefill_steps.append(prompt.shape[0] + 1)
        else:
            prefill_steps.append(1)

    delay_precomp = build_delay_indices(
        B=batch_size,
        T=max_len,
        C=num_channels,
        delay_pattern=delay_pattern,
    )

    delayed_batch = apply_audio_delay(
        audio_BxTxC=prefill,
        pad_value=-1,
        bos_value=audio_bos_value,
        precomp=delay_precomp,
    )

    return delayed_batch, prefill_steps


class DecoderOutput:
    def __init__(self, prefill, prefill_steps, device: torch.device, labels_prefill=None):
        self.generated_tokens = prefill
        self.prefill_steps = prefill_steps
        self.labels_prefill = labels_prefill
        self.device = device

    def get_tokens_at(self, step_from: int, step_to: int = None) -> torch.Tensor:
        if step_to is None:
            step_to = step_from + 1
        return self.generated_tokens[:, step_from:step_to, :].to(self.device)

    def get_labels_at(self, step_from: int, step_to: int = None) -> torch.Tensor:
        if step_to is None:
            step_to = step_from + 1
        if self.labels_prefill is None:
            return None
        return self.labels_prefill[:, step_from:step_to, :].to(self.device)

    def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
        dec_out = dec_out.to(self.generated_tokens.dtype).to(self.generated_tokens.device)
        if apply_mask:
            assert step < self.generated_tokens.shape[1]
            mask = self.generated_tokens[:, step, :] == -1
            self.generated_tokens[:, step, :] = torch.where(mask, dec_out, self.generated_tokens[:, step, :])
        else:
            assert step == self.generated_tokens.shape[1]
            self.generated_tokens = torch.cat((self.generated_tokens, dec_out[:, None, :]), dim=1)


def generate_output(model, generated_codes: torch.Tensor, lengths_Bx: torch.Tensor) -> list[np.ndarray]:
    num_channels = model.config.codec_channels
    batch_size = generated_codes.shape[0]
    seq_length = generated_codes.shape[1]
    delay_pattern = model.config.codec_delay_pattern
    audio_pad_value = model.config.codec_pad_value
    max_delay_pattern = max(delay_pattern)
    revert_precomp = build_revert_indices(
        B=batch_size,
        T=seq_length,
        C=num_channels,
        delay_pattern=delay_pattern,
    )
    codebook = revert_audio_delay(
        audio_BxTxC=generated_codes,
        pad_value=audio_pad_value,
        precomp=revert_precomp,
        T=seq_length,
    )[:, :-max_delay_pattern, :]

    audios = []
    for i in range(batch_size):
        audios.append(codebook[i, : lengths_Bx[i], :].cpu())

    return audios


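# --- Illustrative sketch of DecoderOutput bookkeeping (added annotation, not part of the original upload) ---
# With a toy (B=1, T=2, C=4) prefill buffer, update_one() either fills the masked (-1)
# slots of an existing step or appends a brand-new step at the end; the values are made up.
def _demo_decoder_output():
    prefill = torch.full((1, 2, 4), -1, dtype=torch.long)
    prefill[:, 0, :] = 1024                       # pretend the BOS row is already known
    out = DecoderOutput(prefill, prefill_steps=[1], device=torch.device("cpu"))
    out.update_one(torch.tensor([[7, 7, 7, 7]]), step=1, apply_mask=True)    # fill step 1 in place
    out.update_one(torch.tensor([[8, 8, 8, 8]]), step=2, apply_mask=False)   # append step 2
    assert out.generated_tokens.shape == (1, 3, 4)

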
def frame_process(images, **ele):
    images = [torchvision.transforms.functional.pil_to_tensor(img) for img in images]
    video = torch.stack(images, dim=0)

    # copy from fetch_video
    nframes, _, height, width = video.shape
    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
    max_pixels_supposed = ele.get("max_pixels", max_pixels)
    if max_pixels_supposed > max_pixels:
        print(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
    max_pixels = min(max_pixels_supposed, max_pixels)
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=IMAGE_FACTOR,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    video = transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    return video


def preprocess_codec(model, codec):
    """Preprocess codec tokens"""
    codec_token = torch.tensor(codec, dtype=torch.long)
    codec_token_len = codec_token.shape[0]
    max_delay_pattern = max(model.config.codec_delay_pattern)
    codec_input_ids = torch.zeros((codec_token_len + max_delay_pattern + 1, model.num_channels), dtype=torch.long)

    for c in range(model.num_channels):
        start = model.config.codec_delay_pattern[c] + 1
        codec_input_ids[:start, c] = model.config.codec_bos_value
        codec_input_ids[start : start + codec_token_len, c] = codec_token[:, c]
        codec_input_ids[start + codec_token_len :, c] = model.config.codec_pad_value
        if start + codec_token_len < codec_input_ids.shape[0]:
            codec_input_ids[start + codec_token_len, c] = model.config.codec_eos_value

    return codec_input_ids


def tts_preprocess(batch_caption, prompt_codec, prompt_text, device):

    text_input = []
    codec_input_ids = []
    for caption in batch_caption:
        prompt_caption = "<|SPEECH_PROMPT_START|>" + prompt_text + "<|SPEECH_PROMPT_END|>"
        prompt_caption += "<|VOICE_PROMPT_START|>" + "<|AUDIO_PLACEHOLDER|>" * prompt_codec.shape[0] + "<|VOICE_PROMPT_END|>"
        prompt_caption_fn = lambda x: prompt_caption + "<|SPEECH_START|>" + x + "<|SPEECH_END|>"

        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(f"<|SPEECH_PROMPT_START|>{prompt_text}<|SPEECH_PROMPT_END|><|VOICE_PROMPT_START|><|VOICE_PROMPT_END|><|SPEECH_START|>{caption}<|SPEECH_END|>") + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(prompt_caption_fn("")) + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(prompt_caption_fn(caption)) + AUDIO_START)
        codec_input_ids.append(prompt_codec.clone())
        codec_input_ids.append(prompt_codec.clone())

    codec_input_ids = torch.cat(codec_input_ids, dim=0).to(device)

    tts_generation_kwargs = {
        "codec_input_ids": codec_input_ids,
        "cfg_scale": [2, 3],
        "neg_input_size": 3,
    }

    return text_input, tts_generation_kwargs


def t2m_preprocess(batch_caption):

    text_input = []
    for caption in batch_caption:
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + "Low quality." + "<|MUSIC_END|>") + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + caption + "<|MUSIC_END|>") + AUDIO_START)

    t2m_generation_kwargs = {
        "cfg_scale": 10,
        "neg_input_size": 2,
    }

    return text_input, t2m_generation_kwargs


def v2m_preprocess(batch_caption, batch_video, fps=1):

    def extract_images_from_video(video_path, fps=1, max_frames=1):
        video = VideoFileClip(video_path)
        duration = video.duration

        # Extract frames from the video
        images = []
        for i, t in enumerate(range(0, math.ceil(duration * fps))):
            time_in_video = t / fps
            frame = video.get_frame(time_in_video)
            img = Image.fromarray(frame)
            images.append(img)

            if max_frames is not None and i >= max_frames - 1:
                break

        return images

    text_input = []
    video_inputs = []
    fps_inputs = []

    for caption, video in zip(batch_caption, batch_video):
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + "Low quality." + "<|MUSIC_END|>") + AUDIO_START)
        text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + caption + "<|MUSIC_END|>") + AUDIO_START)

        video_input = frame_process(
            extract_images_from_video(video, fps),
            fps=fps,
        )

        video_inputs.append(video_input)
        video_inputs.append(video_input)

        fps_inputs.append(fps)
        fps_inputs.append(fps)

    t2m_generation_kwargs = {
        "cfg_scale": 10,
        "neg_input_size": 2,
    }

    return text_input, video_inputs, fps_inputs, t2m_generation_kwargs
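# --- Illustrative end-to-end sketch (added annotation, not part of the original upload) ---
# A hedged example of preparing zero-shot TTS inputs: encode a reference clip with Dac,
# turn its codes into delayed codec ids via preprocess_codec, then build the CFG prompt
# variants per caption with tts_preprocess. `model` is assumed to be a loaded UniMoE-Audio
# model exposing config.codec_* fields; the file path and texts are placeholders.
def _demo_tts_inputs(model, device="cuda"):
    codec = Dac()
    prompt_codes = codec.encode("reference_voice.wav")       # [T, 12] code list
    prompt_codec = preprocess_codec(model, prompt_codes)     # delayed (T', C) LongTensor
    text_input, tts_kwargs = tts_preprocess(
        batch_caption=["Hello there, this is a synthesized voice."],
        prompt_codec=prompt_codec,
        prompt_text="transcript of the reference clip",
        device=device,
    )
    return text_input, tts_kwargs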
video_preprocessor_config (1).json
ADDED
@@ -0,0 +1,43 @@
{
  "crop_size": null,
  "data_format": "channels_first",
  "default_to_square": true,
  "device": null,
  "do_center_crop": null,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_pad": null,
  "do_rescale": true,
  "do_resize": true,
  "do_sample_frames": false,
  "fps": null,
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "input_data_format": null,
  "max_frames": 768,
  "max_pixels": 12845056,
  "merge_size": 2,
  "min_frames": 4,
  "min_pixels": 3136,
  "num_frames": null,
  "patch_size": 14,
  "processor_class": "Qwen2_5_VLProcessor",
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "longest_edge": 12845056,
    "shortest_edge": 3136
  },
  "size_divisor": null,
  "temporal_patch_size": 2,
  "video_metadata": null,
  "video_processor_type": "Qwen2VLVideoProcessor"
}
vocab.json
CHANGED
The diff for this file is too large to render. See raw diff.