foggyforest committed
Commit e576ca4 · verified · 1 Parent(s): 3ec670f

Upload 14 files
DCMoE.py ADDED
@@ -0,0 +1,561 @@
1
+ import copy
2
+ import os
3
+ from typing import Optional
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+ import deepspeed
8
+ from deepspeed import comm as dist
9
+ from deepspeed.utils import groups, log_dist
10
+ from deepspeed.utils.timer import SynchronizedWallClockTimer
11
+ from deepspeed.moe.sharded_moe import FIRST_ALLTOALL_TIMER, MOE_TIMER, SECOND_ALLTOALL_TIMER, _AllToAll, einsum, gumbel_rsample
12
+ from transformers.activations import ACT2FN
13
+
14
+ def compress_matrix(A: torch.Tensor, mask: torch.Tensor, force_dim: Optional[int] = None, allow_larger_dim: Optional[bool] = None) -> torch.Tensor:
15
+ if A.shape[:2] != mask.shape:
16
+ raise ValueError("First two dimensions of A and mask must match.")
17
+ if mask.ndim != 2:
18
+ raise ValueError("mask must be a 2D tensor.")
19
+ if not ((mask == 0) | (mask == 1)).all():
20
+ raise ValueError(
21
+ f"mask must only contain 0s and 1s. dtype: {mask.dtype}. "
22
+ f"Invalid elements found at indices: {((mask != 0) & (mask != 1)).nonzero().tolist()} " # Get indices of elements not 0 AND not 1
23
+ f"with corresponding values: {mask[((mask != 0) & (mask != 1))].tolist()}. " # Get the values at those indices
24
+ f"\nOriginal mask (showing up to first 20 elements if large):\n{mask.flatten()[:20]}{'...' if mask.numel() > 20 else ''}"
25
+ )
26
+
27
+ S, E = mask.shape
28
+ trailing_dims_shape = A.shape[2:]
29
+ num_trailing_dims = len(trailing_dims_shape)
30
+ device = A.device
31
+
32
+ ones_per_column = mask.sum(dim=0)
33
+ X = ones_per_column.max().item() if force_dim is None else force_dim
34
+
35
+ if X == 0:
36
+ return torch.empty((0, E, *trailing_dims_shape), dtype=A.dtype, device=device)
37
+
38
+ sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
39
+ view_shape_for_indices = (S, E, *((1,) * num_trailing_dims))
40
+ expanded_indices = sorted_row_indices_2d.view(view_shape_for_indices).expand_as(A)
41
+
42
+ A_gathered = torch.gather(A, 0, expanded_indices)
43
+
44
+ if X <= A_gathered.shape[0]:
45
+ B_candidate = A_gathered[:X, ...]
46
+ elif allow_larger_dim or allow_larger_dim is None:
47
+ if allow_larger_dim is None:
48
+ print(f"[Warning compress_matrix] Target dimension X ({X}) is larger than "
49
+ f"A's original row count S ({S}). Padding B_candidate with zeros.")
50
+ B_candidate = A_gathered
51
+ zeros_shape = [X - A_gathered.shape[0]] + list(B_candidate.shape[1:])
52
+ B_candidate = torch.cat((B_candidate, torch.zeros(zeros_shape, dtype=B_candidate.dtype, device=B_candidate.device)), dim=0) # Shape (X_target_dim, E, ...)
53
+ else:
54
+ raise AssertionError(
55
+ f"Target dimension X ({X}) is larger than A's original row count S ({S}) "
56
+ f"and allow_larger_dim is False. Padding is disallowed."
57
+ )
58
+ row_indices_for_B = torch.arange(X, device=device).unsqueeze(1)
59
+ b_mask_2d = row_indices_for_B < ones_per_column.unsqueeze(0)
60
+ view_shape_for_b_mask = (X, E, *((1,) * num_trailing_dims))
61
+ B = B_candidate * b_mask_2d.view(view_shape_for_b_mask).to(A.dtype)
62
+
63
+ return B
64
+
65
+
66
+ def decompress_matrix(B: torch.Tensor, mask: torch.Tensor, allow_larger_dim: Optional[bool] = None) -> torch.Tensor:
67
+ if B.shape[1] != mask.shape[1]:
68
+ raise ValueError("B's second dimension and mask's second dimension (E) must match.")
69
+ if mask.ndim != 2:
70
+ raise ValueError("mask must be a 2D tensor.")
71
+ if not ((mask == 0) | (mask == 1)).all():
72
+ raise ValueError("mask must only contain 0s and 1s.")
73
+
74
+ S, E = mask.shape
75
+ X = B.shape[0]
76
+ trailing_dims_shape = B.shape[2:]
77
+ num_trailing_dims = len(trailing_dims_shape)
78
+ device = B.device
79
+
80
+ if X == 0: return torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
81
+ if X <= S: pass
82
+ elif allow_larger_dim or allow_larger_dim is None:
83
+ if allow_larger_dim is None:
84
+ print(f"[Warning decompress_matrix] Input B.shape[0] ({X}) is larger than "
85
+ f"target A's row count S ({S}). Truncating B to its first {S} rows.")
86
+ B = B[:S, ...]
87
+ X = S
88
+ else:
89
+ raise AssertionError(
90
+ f"Input B.shape[0] ({X}) is larger than target A's row count S ({S}) "
91
+ f"and allow_larger_dim is False. Truncation is disallowed."
92
+ )
93
+
94
+ sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
95
+ target_A_row_indices_2d = sorted_row_indices_2d[:X, :]
96
+ A_reconstructed = torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
97
+ view_shape_for_target_indices = (X, E, *((1,) * num_trailing_dims))
98
+ expanded_target_indices = target_A_row_indices_2d.view(view_shape_for_target_indices).expand_as(B)
99
+ A_reconstructed.scatter_(dim=0, index=expanded_target_indices, src=B)
100
+
101
+ return A_reconstructed
102
+
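+ # Usage sketch (illustrative comment, not part of the original API): for a 0/1
+ # token-to-expert mask, a compress/decompress round trip reproduces the masked input,
+ # because both functions sort rows by the same mask:
+ #   A = torch.randn(6, 4, 8)              # [S, E, d_model]
+ #   mask = (torch.rand(6, 4) > 0.5).int() # which tokens go to which expert
+ #   B = compress_matrix(A, mask)          # [max tokens per expert, E, d_model]
+ #   A_back = decompress_matrix(B, mask)   # equals A * mask.unsqueeze(-1)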
103
+
104
+
105
+ class AudioSharedExpertMLP(nn.Module):
106
+ """
107
+ Shared expert MLP for UniMoE-Audio model.
108
+ Handles common audio feature transformations across all tokens.
109
+ """
110
+ def __init__(self, config):
111
+ super().__init__()
112
+ self.hidden_size = config.hidden_size
113
+ self.intermediate_size = config.shared_intermediate_size
114
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
115
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
116
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
117
+ self.act_fn = ACT2FN[config.hidden_act]
118
+
119
+ def forward(self, hidden_state):
120
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
121
+
122
+
123
+ class AudioDynamicExpertMLP(nn.Module):
124
+ """
125
+ Dynamic expert MLP for UniMoE-Audio model.
126
+ Specialized for adaptive audio feature processing based on content.
127
+ """
128
+ def __init__(self, config):
129
+ super().__init__()
130
+ self.hidden_size = config.hidden_size
131
+ self.intermediate_size = config.dynamic_intermediate_size
132
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
133
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
134
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
135
+ self.act_fn = ACT2FN[config.hidden_act]
136
+
137
+ def forward(self, hidden_state):
138
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
139
+
140
+
141
+ class AudioNullExpertMLP(nn.Module):
142
+ """
143
+ Null expert MLP for UniMoE-Audio model.
144
+ Returns zero output for tokens that don't require expert processing.
145
+ """
146
+ def __init__(self, config):
147
+ super().__init__()
148
+
149
+ def forward(self, hidden_state):
150
+ return torch.zeros_like(hidden_state, dtype=hidden_state.dtype, device=hidden_state.device)
151
+
152
+
153
+ def audio_sparse_expert_mixer(scores, top_k, jitter_eps, training):
154
+ """
155
+ Sparse expert mixing function for UniMoE-Audio.
156
+ Implements adaptive expert selection with noise injection for training.
157
+ """
158
+ masked_scores = scores
159
+ multiplier_list = []
160
+ selected_experts_list = []
161
+
162
+ for _ in range(top_k):
163
+ with torch.no_grad():
164
+ mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
165
+ factor = scores.abs().clamp(min=mask_logits_threshold.abs())
166
+ mask_logits_threshold = ((mask_logits_threshold - scores) / factor) > (2 * jitter_eps)
167
+
168
+ masked_gates = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
169
+
170
+ selected_experts = max_ind
171
+
172
+ masked_gates = torch.softmax(masked_gates, dim=-1)
173
+ multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)
174
+
175
+ multiplier = multiplier_o
176
+
177
+ masked_scores = torch.scatter(
178
+ masked_scores,
179
+ -1,
180
+ selected_experts,
181
+ float("-inf"),
182
+ )
183
+
184
+ multiplier_list.append(multiplier)
185
+ selected_experts_list.append(selected_experts)
186
+
187
+ multiplier = torch.concat(multiplier_list, dim=-1)
188
+ selected_experts = torch.concat(selected_experts_list, dim=-1)
189
+ return (
190
+ multiplier,
191
+ selected_experts,
192
+ )
193
+
194
+
195
+ def audio_dynamic_expert_selection(logits, top_p):
196
+ """
197
+ Dynamic expert selection for UniMoE-Audio based on cumulative probability threshold.
198
+ Adapts the number of experts based on audio content complexity.
199
+ """
200
+ dynamic_scores = torch.softmax(logits, dim=-1)
201
+ dynamic_scores_sorted, _ = torch.sort(dynamic_scores, dim=-1, descending=True)
202
+ dynamic_scores_cumsum = dynamic_scores_sorted.cumsum(dim=-1)
203
+ dynamic_top_k = (~(dynamic_scores_cumsum >= top_p)).sum(dim=-1)
204
+ dynamic_top_k = dynamic_top_k + 1
205
+ return dynamic_top_k
206
+
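+ # Worked example (illustrative): with routing probabilities [0.5, 0.3, 0.15, 0.05] and
+ # top_p = 0.7, the sorted cumulative sums are [0.5, 0.8, 0.95, 1.0]; only one entry is
+ # still below the threshold, so the token receives 1 + 1 = 2 experts, while a
+ # near-uniform distribution would cross the threshold later and receive more experts.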
207
+
208
+ def _audio_expert_capacity(num_tokens, num_experts, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
209
+ """Calculate expert capacity for UniMoE-Audio based on token distribution and capacity factor."""
210
+ capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
211
+ if capacity < min_capacity:
212
+ capacity = min_capacity.to(torch.int64)
213
+ return capacity
214
+
215
+
216
+ def calculate_audio_global_routing_weight(
217
+ expert_mask: torch.Tensor,
218
+ full_router_logits: torch.Tensor,
219
+ mlp_dynamic_expert_num: int,
220
+ routing_weights: torch.Tensor,
221
+ ):
222
+ """
223
+ Calculate global routing weights for UniMoE-Audio combining dynamic and fixed expert weights.
224
+ Optimized for audio generation tasks.
225
+ """
226
+ global_weight = torch.softmax(full_router_logits.masked_fill(expert_mask == 0, float("-inf")), dim=-1)
227
+ global_dynamic_weight = global_weight[:, :mlp_dynamic_expert_num]
228
+ global_fixed_weight = global_weight[:, mlp_dynamic_expert_num:]
229
+ global_dynamic_weight = routing_weights * global_dynamic_weight.sum(-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1])
230
+ global_weight = torch.cat((global_dynamic_weight, global_fixed_weight), dim=-1)
231
+ return global_weight
232
+
233
+
234
+ class UniMoEAudioSparseMoeBlock(nn.Module):
235
+ """
236
+ UniMoE-Audio Sparse Mixture of Experts block with dynamic routing and expert selection.
237
+ Optimized for audio generation tasks with efficient sparse operations and capacity management.
238
+ """
239
+
240
+ def __init__(self, config):
241
+ super().__init__()
242
+ self.hidden_dim = config.hidden_size
243
+ self.mlp_dynamic_expert_num = config.mlp_dynamic_expert_num + config.mlp_dynamic_null_expert_num
244
+ self.mlp_dynamic_real_expert_num = config.mlp_dynamic_expert_num
245
+ self.mlp_dynamic_null_expert_num = config.mlp_dynamic_null_expert_num
246
+ self.mlp_dynamic_top_p = config.mlp_dynamic_top_p
247
+ self.mlp_dynamic_top_k = config.mlp_dynamic_top_k
248
+ self.mlp_fixed_expert_num = config.mlp_fixed_expert_num
249
+ self.num_experts = self.mlp_dynamic_expert_num + self.mlp_fixed_expert_num
250
+
251
+ if self.mlp_dynamic_top_p == 0:
252
+ print(f"mlp_dynamic_top_p is 0, will use mlp_dynamic_top_k={self.mlp_dynamic_top_k} instead !!!")
253
+
254
+ self.ignore_differentiable_router = config.ignore_differentiable_router
255
+ if self.ignore_differentiable_router:
256
+ print("ignore_differentiable_router is True, will not use router_logits !!!")
257
+
258
+ self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
259
+ self.fixed_real_moe = nn.ModuleList([AudioSharedExpertMLP(config) for _ in range(self.mlp_fixed_expert_num)])
260
+ self.dynamic_real_moe = UniMoEAudioMoE(config, AudioDynamicExpertMLP(config), self.mlp_dynamic_real_expert_num, config.ep_size)
261
+
262
+ self.router_jitter_noise = config.router_jitter_noise
263
+ self.input_jitter_noise = config.input_jitter_noise
264
+
265
+ self.min_capacity = config.min_capacity
266
+ self.capacity_factor = config.capacity_factor
267
+ self.token_drop = config.token_drop
268
+ self.drop_policy = config.drop_policy
269
+
270
+ self.avg_hidden_states_last = config.avg_hidden_states_last
271
+ self.drop_token_num_print = config.drop_token_num_print
272
+ self.fp32_gate = config.fp32_gate
273
+
274
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, aux_balance_weight: torch.Tensor=None):
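+ # Routing pipeline (descriptive comment): score tokens with the gate, pick a per-token
+ # expert count via top-p (or the fixed top-k fallback), build the expert mask, compute
+ # the balance loss, optionally drop tokens beyond expert capacity, then combine the
+ # dynamic-expert output with the always-on shared experts using the global weights.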
275
+ batch_size, sequence_length, hidden_dim = hidden_states.shape
276
+ original_hidden_states = hidden_states
277
+
278
+ if self.training and self.fp32_gate:
279
+ hidden_states = hidden_states.float()
280
+
281
+ if self.training and self.input_jitter_noise > 0:
282
+ hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)
283
+
284
+ hidden_states = hidden_states.view(-1, hidden_dim)
285
+
286
+ if self.training and self.fp32_gate:
287
+ full_router_logits = torch.nn.functional.linear(hidden_states, weight=self.gate.weight.float(), bias=None)
288
+ else:
289
+ full_router_logits = self.gate(hidden_states)
290
+ dynamic_router_logits = full_router_logits[:, : self.mlp_dynamic_expert_num]
291
+
292
+ if self.mlp_dynamic_top_p != 0:
293
+ dynamic_top_k = audio_dynamic_expert_selection(dynamic_router_logits, self.mlp_dynamic_top_p)
294
+ else:
295
+ dynamic_top_k = torch.full((dynamic_router_logits.shape[0],), self.mlp_dynamic_top_k, dtype=torch.int, device=dynamic_router_logits.device)
296
+
297
+ expert_mask = torch.zeros((batch_size * sequence_length, self.num_experts), dtype=torch.int, device=hidden_states.device)
298
+
299
+ routing_weights = torch.zeros((batch_size * sequence_length, self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
300
+ for top_k in range(1, self.mlp_dynamic_expert_num + 1):
301
+ group_idx = torch.nonzero(dynamic_top_k == top_k, as_tuple=True)[0]
302
+ if len(group_idx) == 0:
303
+ continue
304
+
305
+ dynamic_group_logits = dynamic_router_logits[group_idx]
306
+ group_routing_weights, group_selected_experts = audio_sparse_expert_mixer(
307
+ dynamic_group_logits,
308
+ top_k=top_k,
309
+ jitter_eps=self.router_jitter_noise,
310
+ training=self.training and not self.ignore_differentiable_router,
311
+ )
312
+
313
+ group_expert_mask = torch.nn.functional.one_hot(group_selected_experts, num_classes=self.num_experts)
314
+ group_expert_mask = group_expert_mask.sum(dim=1)
315
+
316
+ group_weight = torch.zeros((len(group_idx), self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
317
+ group_weight.scatter_(dim=-1, index=group_selected_experts, src=group_routing_weights)
318
+ routing_weights.index_add_(0, group_idx, group_weight)
319
+
320
+ expert_mask.index_add_(0, group_idx, group_expert_mask.to(expert_mask.dtype))
321
+
322
+ routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)
323
+
324
+ if attention_mask is not None:
325
+ attention_mask = attention_mask.to(expert_mask.dtype).view(-1).unsqueeze(-1).expand(-1, self.num_experts)
326
+ expert_mask = expert_mask * attention_mask
327
+
328
+ if self.mlp_dynamic_expert_num < self.num_experts:
329
+ expert_mask[:, self.mlp_dynamic_expert_num :] = 1
330
+
331
+ aux_loss = audio_load_balancing_loss_func(
332
+ expert_mask=expert_mask,
333
+ mlp_dynamic_expert_num=self.mlp_dynamic_expert_num,
334
+ global_weight=None,
335
+ full_router_logits=full_router_logits,
336
+ routing_weights=routing_weights,
337
+ aux_balance_weight=aux_balance_weight,
338
+ )
339
+
340
+ if self.token_drop:
341
+ expert_mask_dtype = expert_mask.dtype
342
+ capacity = _audio_expert_capacity(batch_size * sequence_length, self.mlp_dynamic_expert_num, torch.tensor(self.capacity_factor), torch.tensor(self.min_capacity))
343
+ if self.drop_policy == "probs":
344
+ if capacity > dynamic_router_logits.shape[0]:
345
+ print(f"[warning] token capacity({capacity}) > token num({dynamic_router_logits.shape[0]}), setting capacity=token num")
346
+ capacity = dynamic_router_logits.shape[0]
347
+ dynamic_expert_mask = expert_mask[:, : self.mlp_dynamic_expert_num].bool()
348
+ token_drop_router_logits = torch.masked_fill(dynamic_router_logits, ~dynamic_expert_mask, torch.finfo(dynamic_router_logits.dtype).min)
349
+ capacity_probs, capacity_indices = torch.topk(token_drop_router_logits, k=capacity, dim=0, sorted=False)
350
+ capacity_mask = torch.zeros_like(expert_mask).scatter(0, capacity_indices, 1)
351
+ capacity_mask[:, self.mlp_dynamic_expert_num :] = 1
352
+ expert_mask = torch.logical_and(expert_mask, capacity_mask)
353
+
354
+ ori_token_num = dynamic_expert_mask.sum().item()
355
+ cur_token_num = expert_mask[:, : self.mlp_dynamic_expert_num].sum().item()
356
+ if self.drop_token_num_print and ("RANK" not in os.environ or int(os.environ["RANK"]) == 0):
357
+ print(f"drop {ori_token_num - cur_token_num} tokens from total {ori_token_num} tokens")
358
+
359
+ elif self.drop_policy == "position":
360
+ locations = torch.cumsum(expert_mask, dim=0) - 1
361
+ expert_mask *= torch.lt(locations, capacity)
362
+ else:
363
+ raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
364
+ expert_mask = expert_mask.to(expert_mask_dtype)
365
+
366
+ routing_weights = routing_weights.masked_fill(~(expert_mask[:, : self.mlp_dynamic_expert_num].bool()), 0.0)
367
+ routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)
368
+
369
+ if self.mlp_dynamic_expert_num < self.num_experts:
370
+ global_weight = calculate_audio_global_routing_weight(expert_mask, full_router_logits, self.mlp_dynamic_expert_num, routing_weights)
371
+ else:
372
+ global_weight = routing_weights
373
+
374
+ hidden_states = original_hidden_states.view(-1, hidden_dim)
375
+
376
+ final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
377
+ global_weight = global_weight.to(hidden_states.dtype)
378
+
379
+ current_hidden_states = self.dynamic_real_moe(hidden_states, expert_mask=expert_mask[:, : self.mlp_dynamic_real_expert_num], router_weight=global_weight[:, : self.mlp_dynamic_real_expert_num])
380
+ final_hidden_states = final_hidden_states + current_hidden_states
381
+
382
+ for expert_idx in range(self.mlp_fixed_expert_num):
383
+ expert_layer = self.fixed_real_moe[expert_idx]
384
+
385
+ current_state = hidden_states
386
+ current_global_weight = global_weight[:, self.mlp_dynamic_expert_num + expert_idx].unsqueeze(-1)
387
+ current_hidden_states = expert_layer(current_state) * current_global_weight
388
+
389
+ final_hidden_states = final_hidden_states + current_hidden_states
390
+
391
+ final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
392
+
393
+ if not self.training and self.avg_hidden_states_last:
394
+ dist.all_reduce(final_hidden_states, op=dist.ReduceOp.AVG, group=self.dynamic_real_moe.deepspeed_moe.ep_group)
395
+
396
+ return final_hidden_states, full_router_logits, dynamic_top_k, expert_mask, global_weight, aux_loss
397
+
398
+
399
+ def audio_load_balancing_loss_func(
400
+ expert_mask: torch.Tensor,
401
+ mlp_dynamic_expert_num: int,
402
+ global_weight: Optional[torch.Tensor] = None,
403
+ full_router_logits: Optional[torch.Tensor] = None,
404
+ routing_weights: Optional[torch.Tensor] = None,
405
+ aux_balance_weight: Optional[torch.Tensor] = None,
406
+ ) -> float:
407
+ """Calculate load balancing loss for UniMoE-Audio expert routing to encourage balanced usage."""
408
+ min_dtype = torch.finfo(full_router_logits.dtype).min
409
+ global_weight = full_router_logits.masked_fill(expert_mask == 0, min_dtype)
410
+ global_weight = global_weight[:, :mlp_dynamic_expert_num]
411
+ global_weight = torch.softmax(global_weight, dim=-1)
412
+ expert_mask = expert_mask[:, :mlp_dynamic_expert_num]
413
+
414
+ num_experts = expert_mask.shape[-1]
415
+ if aux_balance_weight is None:
416
+ tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
417
+ router_prob_per_expert = torch.mean(global_weight, dim=0)
418
+ else:
419
+ batch_size, sequence_length = aux_balance_weight.shape
420
+ num_hidden_layers = global_weight.shape[0] // (batch_size * sequence_length)
421
+ expert_attention_mask = aux_balance_weight[None, :, :, None].expand((num_hidden_layers, batch_size, sequence_length, num_experts)).reshape(-1, num_experts).to(global_weight.device)
422
+ tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
423
+ router_prob_per_expert = torch.sum(global_weight * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
424
+
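+ # Descriptive comment: tokens_per_expert is the (optionally weighted) fraction of tokens
+ # assigned to each dynamic expert and router_prob_per_expert is the mean router
+ # probability it receives; their dot product, scaled by num_experts below, is the
+ # Switch-Transformer-style auxiliary loss that is minimized when routing is balanced.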
425
+ overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
426
+
427
+ return overall_loss * num_experts
428
+
429
+
430
+ class AudioExperts(deepspeed.moe.experts.Experts):
431
+ """Custom Audio experts class extending DeepSpeed MoE experts with additional functionality."""
432
+
433
+ def __init__(self, expert, num_local_experts=1, expert_group_name=None):
434
+ super(deepspeed.moe.experts.Experts, self).__init__()
435
+
436
+ self.deepspeed_experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
437
+ self.num_local_experts = num_local_experts
438
+
439
+ for expert in self.deepspeed_experts:
440
+ for name, param in expert.named_parameters():
441
+ param.allreduce = False
442
+ param.group_name = expert_group_name
443
+
444
+ def forward(self, inputs):
445
+ chunks = inputs.chunk(self.num_local_experts, dim=1)
446
+ expert_outputs = []
447
+ for chunk, expert in zip(chunks, self.deepspeed_experts):
448
+ out = expert(chunk)
449
+ if type(out) is tuple:
450
+ out = out[0]
451
+ expert_outputs += [out]
452
+
453
+ expert_output = torch.cat(expert_outputs, dim=1)
454
+ return expert_output
455
+
456
+
457
+ class AudioMOELayer(deepspeed.moe.sharded_moe.MOELayer):
458
+ """Custom Audio MoE layer extending DeepSpeed MOELayer with matrix compression optimization."""
459
+
460
+ def __init__(
461
+ self,
462
+ experts: nn.Module,
463
+ ep_group_name,
464
+ ep_size,
465
+ num_local_experts: int,
466
+ use_tutel: bool = False,
467
+ ) -> None:
468
+ super(deepspeed.moe.sharded_moe.MOELayer, self).__init__()
469
+
470
+ self.experts = experts
471
+ self.ep_group = None
472
+ self.ep_size = ep_size
473
+ self.ep_group_name = ep_group_name
474
+ self.num_local_experts = num_local_experts
475
+ self.time_falltoall = 0.0
476
+ self.time_salltoall = 0.0
477
+ self.time_moe = 0.0
478
+ self.timers = SynchronizedWallClockTimer()
479
+ self.wall_clock_breakdown = False
480
+
481
+ def _set_ep_group(self, ep_group):
482
+ self.ep_group = ep_group
483
+
484
+ def forward(self, hidden_states: Tensor, expert_mask: Tensor, router_weight: Tensor) -> Tensor:
485
+ router_weight = router_weight * expert_mask
486
+
487
+ if self.wall_clock_breakdown:
488
+ self.timers(MOE_TIMER).start()
489
+
490
+ d_model = hidden_states.shape[-1]
491
+ seq_len = hidden_states.shape[0]
492
+ expert_num = expert_mask.shape[-1]
493
+ capacity = expert_mask.sum(dim=0).max()
494
+ if self.ep_group is not None:
495
+ dist.all_reduce(capacity, op=dist.ReduceOp.MAX, group=self.ep_group)
496
+
497
+ compres_hidden_states = hidden_states.unsqueeze(1).expand(seq_len, expert_num, d_model)
498
+ compres_hidden_states = compress_matrix(compres_hidden_states, expert_mask, force_dim=capacity, allow_larger_dim=True) # [C, expert_num, d_model]
499
+ compres_expert_mask = compress_matrix(expert_mask, expert_mask, force_dim=capacity, allow_larger_dim=True)
500
+ dispatched_input = einsum("ce,cem->ecm", compres_expert_mask, compres_hidden_states)
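+ # Shape note (descriptive comment): compres_expert_mask is [capacity, E] and
+ # compres_hidden_states is [capacity, E, d_model], so this einsum packs each expert's
+ # slots densely into dispatched_input of shape [E, capacity, d_model] before the
+ # all-to-all exchange.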
501
+
502
+ if self.wall_clock_breakdown:
503
+ self.timers(FIRST_ALLTOALL_TIMER).start()
504
+
505
+ dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
506
+
507
+ if self.wall_clock_breakdown:
508
+ self.timers(FIRST_ALLTOALL_TIMER).stop()
509
+ self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
510
+
511
+ dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
512
+
513
+ expert_output = self.experts(dispatched_input)
514
+
515
+ if self.wall_clock_breakdown:
516
+ self.timers(SECOND_ALLTOALL_TIMER).start()
517
+
518
+ expert_output = _AllToAll.apply(self.ep_group, expert_output)
519
+
520
+ if self.wall_clock_breakdown:
521
+ self.timers(SECOND_ALLTOALL_TIMER).stop()
522
+ self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)
523
+
524
+ expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
525
+ expert_output = decompress_matrix(expert_output.transpose(0, 1), expert_mask, allow_larger_dim=True)
526
+ combined_output = einsum("se,sem->sm", router_weight, expert_output)
527
+ if self.wall_clock_breakdown:
528
+ self.timers(MOE_TIMER).stop()
529
+ self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)
530
+
531
+ return combined_output
532
+
533
+
534
+ class UniMoEAudioMoE(deepspeed.moe.layer.MoE):
535
+ """Custom Audio MoE class extending DeepSpeed MoE with configuration and parallelism setup."""
536
+
537
+ def __init__(self, config, expert, num_experts, ep_size, moe_name_prefix="ep_size"):
538
+ super(deepspeed.moe.layer.MoE, self).__init__()
539
+ self.enable_expert_tensor_parallelism = config.enable_expert_tensor_parallelism
540
+ self.ep_size = ep_size
541
+ self.num_experts = num_experts
542
+ self.expert_group_name = f"{moe_name_prefix}_{self.ep_size}"
543
+ self.num_local_experts = self.num_experts // self.ep_size
544
+ log_dist(f"Creating MoE layer with num_experts: {self.num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}", [0])
545
+ experts = AudioExperts(expert, self.num_local_experts, self.expert_group_name)
546
+ self.deepspeed_moe = AudioMOELayer(experts, self.expert_group_name, self.ep_size, self.num_local_experts)
547
+
548
+ def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
549
+ self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)
550
+
551
+ def _create_process_groups(self, use_data_before_expert_parallel_=False):
552
+ if self.expert_group_name not in groups._get_expert_parallel_group_dict():
553
+ print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
554
+ if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
555
+ groups._create_expert_and_data_parallel(self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
556
+ else:
557
+ groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
558
+ self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))
559
+
560
+ def forward(self, *input_args, **input_kwargs):
561
+ return self.deepspeed_moe(*input_args, **input_kwargs)
README (1).md ADDED
@@ -0,0 +1,216 @@
1
+ ---
2
+ license: mit
3
+ language:
4
+ - en
5
+ - zh
6
+ base_model:
7
+ - Qwen/Qwen2-0.5B
8
+ pipeline_tag: feature-extraction
9
+ library_name: sentence-transformers
10
+ tags:
11
+ - MoE
12
+ - Unified Generation
13
+ - Speech and Music
14
+ - Multi-modal
15
+ datasets:
16
+ ---
17
+
18
+ <h1 align="center">UniMoE-Audio</h1>
19
+
20
+ **UniMoE-Audio** is a unified framework that seamlessly combines speech and music generation. Powered by a novel dynamic-capacity Mixture-of-Experts design, it adapts intelligently to input complexity, enabling high-fidelity voice and expressive music within a single model.
21
+
22
+ ## Key Innovations
23
+
24
+ #### **Top-P Dynamic Routing Strategy**
25
+ We introduce a **Top-P routing strategy** that overcomes the limitations of conventional static Top-K routing:
26
+
27
+ - **Dynamic Expert Allocation**: Instead of assigning a fixed number of experts to every token, our approach dynamically determines the number of experts based on token complexity
28
+ - **Resource Efficiency**: Simple tokens don't consume unnecessary resources, while complex tokens receive sufficient processing power
29
+ - **Performance Optimization**: Results in improved overall efficiency and performance (see the sketch below)
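+
+ A minimal sketch of how such a Top-P rule can pick the per-token expert count, mirroring `audio_dynamic_expert_selection` in `DCMoE.py` (the function name and example values here are illustrative, not from the released code):
+
+ ```python
+ import torch
+
+ def dynamic_expert_count(router_logits: torch.Tensor, top_p: float) -> torch.Tensor:
+     """Number of experts needed per token to cover `top_p` of the routing probability mass."""
+     probs = torch.softmax(router_logits, dim=-1)                  # [tokens, experts]
+     sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
+     cumulative = sorted_probs.cumsum(dim=-1)
+     # Count experts whose cumulative mass is still below top_p, then add one so the
+     # threshold is actually crossed: confident tokens get 1 expert, ambiguous ones more.
+     return (~(cumulative >= top_p)).sum(dim=-1) + 1
+
+ logits = torch.tensor([[4.0, 0.1, 0.1, 0.1],   # confident token  -> 1 expert
+                        [1.0, 0.9, 0.8, 0.7]])  # ambiguous token  -> 3 experts
+ print(dynamic_expert_count(logits, top_p=0.7))  # tensor([1, 3])
+ ```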
30
+
31
+ #### **Three-Stage Training Curriculum**
32
+ We employ a comprehensive training approach to enable effective joint learning from imbalanced data:
33
+
34
+ 1. **Independent Specialist Training** - Initial expert specialization
35
+ 2. **Integration with Warm-up** - Gradual system integration
36
+ 3. **Synergistic Joint Training** - Collaborative optimization
37
+
38
+ ## Model Information
39
+ - **Base Model**: Qwen2.5-VL with MoE extensions
40
+ - **Audio Codec**: DAC (Descript Audio Codec) with 12 channels
41
+ - **Expert Configuration**: 8 dynamic experts + 2 shared experts
42
+ - **Audio Sampling Rate**: 16kHz
43
+ - Usage:
44
+ - Text-to-Speech (TTS)
45
+ - Speech-to-Text (STT)
46
+ - Music Generation
47
+ - GPU Requirements:
48
+ - Memory: 16GB+
49
+ - CUDA-enabled GPU
50
+
51
+ ## Open-source Plan
52
+ - [☑️] Model Checkpoint
53
+ - [☑️] [UniMoE-Audio-preview](https://huggingface.co/foggyforest/UniMoE-Audio-preview)
54
+ - [☑️] Inference Code: [HITsz-TMG/UniMoE-Audio](https://github.com/HITsz-TMG/UMOE-Scaling-Unified-Multimodal-LLMs/tree/master/UniMoE-Audio)
55
+ - [☑️] Technical Report: [UniMoE-Audio: Unified Speech and Music Generation with Dynamic-Capacity MoE]()
56
+
57
+ ## Evaluation
58
+ ### Speech Synthesis
59
+ ![Speech Synthesis](./imgs/Speech_Generation.png)
60
+ ### Text to Music Generation
61
+ ![Text to Music Generation](./imgs/T2M.png)
62
+ ### Video-Text to Music Generation
63
+ ![Video-Text to Music Generation](./imgs/VT2M.png)
64
+
65
+ ## Requirements
66
+ We recommend using conda to install the environment.
67
+ ```bash
68
+ conda env create -f configs/enviroment.yml  # add -n <env-name> to choose a custom environment name
69
+ conda activate unimoe-audio # default name
70
+ ```
71
+ Then install the PyTorch packages:
72
+ ```bash
73
+ # Use the official index
74
+ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121
75
+
76
+ # Use Tsinghua mirror source
77
+ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -i https://pypi.tuna.tsinghua.edu.cn/simple/ --extra-index-url https://download.pytorch.org/whl/cu121
78
+
79
+ # Use Alibaba Cloud mirror source
80
+ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 -i https://mirrors.aliyun.com/pypi/simple/ --extra-index-url https://download.pytorch.org/whl/cu121
81
+ ```
82
+ A DAC model is also required under `/path/to/UniMoE-Audio/utils/dac_model`.
83
+ It will be downloaded automatically on the first run.
84
+
85
+
86
+ ## Usage
87
+ Please move the `utils` folder into your working directory.
88
+ Then you can use the model like this:
89
+
90
+ ```python
91
+ from modeling import UniMoEAudio
92
+
93
+ MODEL_NAME = "HIT-TMG/UniMoE-Audio-Preview"
94
+
95
+ # Load model
96
+ unimoe_audio = UniMoEAudio.from_pretrained(
97
+ MODEL_NAME,
98
+ cache_dir='./cache',
99
+ torch_dtype='bfloat16',
100
+ device_id=0
101
+ )
102
+
103
+ ```
104
+
105
+ ### TTS Example:
106
+ ```python
107
+ # TTS/Voice Cloning
108
+ target_text = "Target Text"
109
+ prompt_audio = "/path/to/your/prompt_audio.wav"
110
+ prompt_text = "Prompt Text"
111
+
112
+ # Encode prompt audio
113
+ prompt_codec = unimoe_audio.dac.encode(prompt_audio)
114
+
115
+ prompt_codec_input_ids = unimoe_audio._preprocess_codec(
116
+ codec=prompt_codec,
117
+ codec_delay_pattern=unimoe_audio.model.config.codec_delay_pattern,
118
+ codec_channels=unimoe_audio.model.num_channels,
119
+ codec_bos_value=unimoe_audio.model.config.codec_bos_value,
120
+ codec_eos_value=unimoe_audio.model.config.codec_eos_value,
121
+ codec_pad_value=unimoe_audio.model.config.codec_pad_value
122
+ )
123
+
124
+ # Construct prompt text
125
+ text_input, _, _ = unimoe_audio._prepare_prompt(task="speech", caption=target_text, prompt_text=prompt_text, prompt_codec_input_ids=prompt_codec_input_ids)
126
+
127
+ # Tokenize input text
128
+ source_input = unimoe_audio.tokenizer(text_input, add_special_tokens=False, return_tensors="pt", padding=True)
129
+ prompt_codec_input_ids = prompt_codec_input_ids.unsqueeze(0).expand(len(text_input), -1, -1).reshape(-1, prompt_codec_input_ids.shape[1])
130
+
131
+ # Speech Generation
132
+ unimoe_audio._generate_core(
133
+ source_input,
134
+ prompt_codec_input_ids,
135
+ save_name = "speech",
136
+ output_dir = "./",
137
+ cfg_scale = 1.0,
138
+ temperature = 1.0,
139
+ top_p = 1.0,
140
+ cfg_filter_top_k = 45,
141
+ eos_prob_mul_factor = 1.0,
142
+ do_sample = True,
143
+ debug_guidance_step = -1,
144
+ use_cache = True
145
+ )
146
+ ```
147
+ ### T2M Example:
148
+ ```python
149
+ caption = "music deccription"
150
+
151
+ # Construct prompt text
152
+ text_input, _, _ = unimoe_audio._prepare_prompt(task="music", caption=caption)
153
+
154
+ # Tokenize input text
155
+ source_input = unimoe_audio.tokenizer(text_input, add_special_tokens=False, return_tensors="pt", padding=True)
156
+
157
+ # Music generation with prompt text
158
+ unimoe_audio._generate_core(
159
+ source_input,
160
+ None,
161
+ save_name = "music",
162
+ output_dir = "./",
163
+ cfg_scale = 10.0,
164
+ temperature = 1.0,
165
+ top_p = 1.0,
166
+ cfg_filter_top_k = 45,
167
+ eos_prob_mul_factor = 0.6,
168
+ do_sample = True,
169
+ debug_guidance_step = -1,
170
+ use_cache = True
171
+ )
172
+ ```
173
+
174
+ ### VT2M Example:
175
+ ```python
176
+ # VT2M
177
+ caption = "music deccription"
178
+ prompt_video = "/path/to/your/video.mp4"
179
+
180
+ # Prepare prompt
181
+ text_input, video_inputs, fps_inputs = unimoe_audio._prepare_prompt(task="music", caption=caption, video=prompt_video, fps=1, sampling_fps=1, max_frames=1)
182
+
183
+ # Input processor
184
+ source_input = unimoe_audio.processor(
185
+ text=text_input,
186
+ images=None,
187
+ videos=video_inputs,
188
+ fps=fps_inputs,
189
+ padding=True,
190
+ return_tensors="pt",
191
+ do_resize=False
192
+ )
193
+
194
+ # Music generation with prompt video
195
+ unimoe_audio._generate_core(
196
+ source_input,
197
+ None,
198
+ save_name = "video_music",
199
+ output_dir = "./",
200
+ rebuild_codec=None,
201
+ cfg_scale = 10.0,
202
+ temperature = 1.0,
203
+ top_p = 1.0,
204
+ cfg_filter_top_k = 45,
205
+ eos_prob_mul_factor = 0.6,
206
+ do_sample = True,
207
+ debug_guidance_step = -1,
208
+ use_cache = True
209
+ )
210
+ ```
211
+
212
+
213
+
214
+
215
+
216
+
config.json CHANGED
@@ -2,6 +2,10 @@
   "architectures": [
     "UniAudioRVQQwen2_5VLMoEForConditionalGeneration"
   ],
+  "auto_map": {
+    "AutoConfig": "modeling.UniMoEAudioConfig",
+    "AutoModelForCausalLM": "modeling.UniMoEAudio"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
   "codec_bos_value": 1026,
deepspeed_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
2
+
3
+ import deepspeed
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from deepspeed import comm as dist
7
+ from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
8
+ from torch import Tensor
9
+
10
+ try:
11
+ # To enable Tutel MoE optimizations:
12
+ # python3 -m pip install --user --upgrade git+https://github.com/microsoft/[email protected]
13
+ from tutel import moe as tutel_moe
14
+
15
+ TUTEL_INSTALLED = True
16
+ except Exception:
17
+ # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
18
+ TUTEL_INSTALLED = False
19
+ pass
20
+
21
+
22
+ # =============================================================================
23
+ # DeepSpeed MoE Inference Utilities
24
+ # =============================================================================
25
+
26
+ def _AllToAll_forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: # type: ignore
27
+ ctx.group = group
28
+ input = input.contiguous()
29
+ return input
30
+
31
+
32
+ def gate_forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
33
+ d_model = input[0].shape[-1]
34
+ reshaped_input = input[0].reshape(-1, d_model)
35
+
36
+ if self.use_tutel:
37
+ self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True)
38
+ S, M = reshaped_input.size(0), reshaped_input.size(1)
39
+
40
+ if not hasattr(self, "_tutel_dispatcher"):
41
+ self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype)
42
+ self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C)
43
+ dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
44
+ else:
45
+ self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
46
+ dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input)
47
+
48
+ dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
49
+ expert_output = self.experts(dispatched_input)
50
+ expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, dispatched_input.shape[2], -1)
51
+
52
+ if self.use_tutel:
53
+ combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M))
54
+ else:
55
+ combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
56
+
57
+ a = combined_output.reshape(input[0].size()[:-1] + (-1,))
58
+
59
+ return a
60
+
61
+
62
+ def top2gating(
63
+ logits: Tensor, capacity_factor: float, min_capacity: int, drop_tokens: bool = True, ep_group: Union[torch.distributed.ProcessGroup, None] = None, top2_2nd_expert_sampling: bool = True
64
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
65
+ """Implements Top2Gating on logits."""
66
+ gates = F.softmax(logits, dim=1)
67
+ indices1_s = torch.argmax(gates, dim=1)
68
+ num_experts = int(gates.shape[1])
69
+ mask1 = F.one_hot(indices1_s, num_classes=num_experts)
70
+
71
+ if top2_2nd_expert_sampling:
72
+ logits += gumbel_rsample(logits.shape, device=logits.device)
73
+
74
+ logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
75
+ indices2_s = torch.argmax(logits_except1, dim=1)
76
+ mask2 = F.one_hot(indices2_s, num_classes=num_experts)
77
+
78
+ locations1 = torch.cumsum(mask1, dim=0) - 1
79
+ locations2 = torch.cumsum(mask2, dim=0) - 1
80
+ locations2 += torch.sum(mask1, dim=0, keepdim=True)
81
+
82
+ me = torch.mean(gates, dim=0)
83
+ ce = torch.mean(mask1.float(), dim=0)
84
+ l_aux = torch.mean(me * ce) * num_experts * num_experts
85
+ exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device)
86
+
87
+ if drop_tokens:
88
+ capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
89
+ mask1 *= torch.lt(locations1, capacity)
90
+ mask2 *= torch.lt(locations2, capacity)
91
+ else:
92
+ new_capacity = torch.max(exp_counts)
93
+ capacity = new_capacity
94
+
95
+ locations1_s = torch.sum(locations1 * mask1, dim=1)
96
+ locations2_s = torch.sum(locations2 * mask2, dim=1)
97
+ mask1_float = mask1.float()
98
+ mask2_float = mask2.float()
99
+
100
+ gates1_s = einsum("se,se->s", gates, mask1_float)
101
+ gates2_s = einsum("se,se->s", gates, mask2_float)
102
+ denom_s = gates1_s + gates2_s
103
+
104
+ denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
105
+ gates1_s /= denom_s
106
+ gates2_s /= denom_s
107
+
108
+ gates1 = einsum("s,se->se", gates1_s, mask1_float)
109
+ gates2 = einsum("s,se->se", gates2_s, mask2_float)
110
+ locations1_sc = _one_hot_to_float(locations1_s, capacity)
111
+ locations2_sc = _one_hot_to_float(locations2_s, capacity)
112
+ combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
113
+ combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
114
+ combine_weights = combine1_sec + combine2_sec
115
+ dispatch_mask = combine_weights.bool()
116
+
117
+ return l_aux, combine_weights, dispatch_mask, exp_counts
118
+
119
+
120
+ # Apply the modifications to deepspeed
121
+ deepspeed.moe.sharded_moe.MOELayer.forward = gate_forward
122
+ deepspeed.moe.sharded_moe.top2gating = top2gating
123
+ deepspeed.moe.sharded_moe._AllToAll.forward = _AllToAll_forward
124
+
model-00001-of-00003.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:d853dc5fdece11379a9ef43710c18f6f7fd55aaa7cf6257c183738edb6882100
- size 4999916992
+ oid sha256:254260c822c07d95dcd11f897c656eda8d08e5849832d4fd4f67c074c449b2fb
+ size 4999916960
modeling.py ADDED
@@ -0,0 +1,1182 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch Qwen2-VL model."""
21
+
22
+ from dataclasses import dataclass
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.nn import CrossEntropyLoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache
32
+ from transformers.generation import GenerationMixin
33
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput
37
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
+ from transformers.utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
41
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
42
+
43
+ from transformers import AutoConfig, AutoModelForCausalLM
44
+ from transformers.modeling_outputs import (
45
+ ModelOutput,
46
+ )
47
+ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
48
+ Qwen2_5_VLVisionConfig,
49
+ Qwen2_5_VLTextConfig,
50
+ Qwen2_5_VLConfig,
51
+ )
52
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
53
+ Qwen2_5_VLAttention,
54
+ Qwen2RMSNorm,
55
+ Qwen2_5_VLRotaryEmbedding,
56
+ )
57
+ from DCMoE import UniMoEAudioSparseMoeBlock
58
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel
59
+
60
+ logger = logging.get_logger(__name__)
61
+
62
+ FAST_INIT = True
63
+ if FAST_INIT:
64
+ logger.warning(f"using FAST initial for Grin Qwen2_vl !!!")
65
+
66
+ class Qwen2_5_VLMoETextConfig(Qwen2_5_VLTextConfig):
67
+ model_type = "qwen2_5_vl_moe_text"
68
+
69
+ def __init__(
70
+ self,
71
+ mlp_dynamic_expert_num=4,
72
+ mlp_dynamic_null_expert_num=0,
73
+ mlp_dynamic_top_p=0.7,
74
+ mlp_dynamic_top_k=2,
75
+ mlp_fixed_expert_num=2,
76
+ dynamic_intermediate_size=8960,
77
+ shared_intermediate_size=8960,
78
+ ignore_differentiable_router=False,
79
+ enable_expert_tensor_parallelism: bool = False,
80
+ ep_size=1,
81
+ fixed_ep_size=1,
82
+ router_jitter_noise=0.01,
83
+ input_jitter_noise=0.01,
84
+ token_drop=False,
85
+ drop_policy: str = "probs",
86
+ min_capacity: int = 8,
87
+ capacity_factor: float = 1.0,
88
+ fp32_gate=True,
89
+ avg_hidden_states_last=False,
90
+ drop_token_num_print=True,
91
+ **kwargs,
92
+ ):
93
+
94
+ super().__init__(**kwargs)
95
+ self.mlp_dynamic_expert_num = mlp_dynamic_expert_num
96
+ self.mlp_dynamic_top_p = mlp_dynamic_top_p
97
+ self.mlp_dynamic_top_k = mlp_dynamic_top_k
98
+ self.mlp_fixed_expert_num = mlp_fixed_expert_num
99
+ self.mlp_dynamic_null_expert_num = mlp_dynamic_null_expert_num
100
+ self.dynamic_intermediate_size = dynamic_intermediate_size
101
+ self.shared_intermediate_size = shared_intermediate_size
102
+ self.ignore_differentiable_router = ignore_differentiable_router
103
+ self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism
104
+ self.ep_size = ep_size
105
+ self.fixed_ep_size = fixed_ep_size
106
+ self.input_jitter_noise = input_jitter_noise
107
+ self.router_jitter_noise = router_jitter_noise
108
+ self.token_drop = token_drop
109
+ self.drop_policy = drop_policy
110
+ self.min_capacity = min_capacity
111
+ self.capacity_factor = capacity_factor
112
+ self.fp32_gate = fp32_gate
113
+ self.avg_hidden_states_last = avg_hidden_states_last
114
+ self.drop_token_num_print = drop_token_num_print
115
+
116
+ class UniMoEAudioConfig(PretrainedConfig):
117
+ model_type = "uni_audio_rvq_qwen2_5vl_moe"
118
+ sub_configs = {"vision_config": Qwen2_5_VLVisionConfig, "text_config": Qwen2_5_VLMoETextConfig}
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ text_config=None,
124
+ vision_config=None,
125
+ image_token_id=151655,
126
+ video_token_id=151656,
127
+ codec_vocab_size=1028,
128
+ codec_delay_pattern=[0, 8, 9, 10, 11, 12, 13, 14, 15],
129
+ codec_channels=9,
130
+ codec_eos_value=1024,
131
+ codec_pad_value=1025,
132
+ codec_bos_value=1026,
133
+ codec_placeholder_value=None,
134
+ **kwargs,
135
+ ):
136
+ if isinstance(vision_config, dict):
137
+ self.vision_config = self.sub_configs["vision_config"](**vision_config)
138
+ elif vision_config is None:
139
+ self.vision_config = self.sub_configs["vision_config"]()
140
+
141
+ if isinstance(text_config, dict):
142
+ self.text_config = self.sub_configs["text_config"](**text_config)
143
+ elif text_config is None:
144
+ self.text_config = self.sub_configs["text_config"](**kwargs)
145
+
146
+ self.image_token_id = image_token_id
147
+ self.video_token_id = video_token_id
148
+ self.codec_vocab_size = codec_vocab_size
149
+ self.codec_delay_pattern = codec_delay_pattern
150
+ self.codec_channels = codec_channels
151
+ self.codec_eos_value = codec_eos_value
152
+ self.codec_pad_value = codec_pad_value
153
+ self.codec_bos_value = codec_bos_value
154
+ self.codec_placeholder_value = codec_placeholder_value
155
+
156
+ super().__init__(**kwargs)
157
+
158
+ @dataclass
159
+ class MoEQwen2_5VLCausalLMOutputWithPast(ModelOutput):
160
+ loss: Optional[torch.FloatTensor] = None
161
+ logits: torch.FloatTensor = None
162
+ past_key_values: Optional[List[torch.FloatTensor]] = None
163
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
164
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
165
+ rope_deltas: Optional[torch.LongTensor] = None
166
+ all_router_logits: Tuple = None
167
+ all_router_top_k: Tuple = None
168
+ all_router_expert_mask: Tuple = None
169
+ all_router_weight: Tuple = None
170
+ aux_balance_loss: torch.FloatTensor = None
171
+
172
+
173
+ @dataclass
174
+ class BaseModelOutputWithPast(ModelOutput):
175
+ last_hidden_state: torch.FloatTensor = None
176
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
177
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
178
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
179
+ all_router_logits: Tuple = None
180
+ all_router_top_k: Tuple = None
181
+ all_router_weight: Tuple = None
182
+ all_router_expert_mask: Tuple = None
183
+ all_aux_loss: Tuple = None
184
+
185
+
186
+ class Qwen2_5_VLMoEDecoderLayer(GradientCheckpointingLayer):
187
+ def __init__(self, config: Qwen2_5_VLMoETextConfig, layer_idx: int):
188
+ super().__init__()
189
+ self.hidden_size = config.hidden_size
190
+
191
+ if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
192
+ logger.warning_once(
193
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
194
+ "unexpected results may be encountered."
195
+ )
196
+
197
+ self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
198
+ self.mlp = UniMoEAudioSparseMoeBlock(config)
199
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
200
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
201
+ self.attention_type = config.layer_types[layer_idx]
202
+
203
+ def forward(
204
+ self,
205
+ hidden_states: torch.Tensor,
206
+ attention_mask: Optional[torch.Tensor] = None,
207
+ padding_token_mask: Optional[torch.Tensor] = None,
208
+ position_ids: Optional[torch.LongTensor] = None,
209
+ past_key_value: Optional[tuple[torch.Tensor]] = None,
210
+ output_attentions: Optional[bool] = False,
211
+ output_router_logits_and_topk: Optional[bool] = False,
212
+ use_cache: Optional[bool] = False,
213
+ cache_position: Optional[torch.LongTensor] = None,
214
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
215
+ **kwargs: Unpack[FlashAttentionKwargs],
216
+ ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
217
+
218
+ residual = hidden_states
219
+ hidden_states = self.input_layernorm(hidden_states)
220
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
221
+ hidden_states=hidden_states,
222
+ attention_mask=attention_mask,
223
+ position_ids=position_ids,
224
+ past_key_value=past_key_value,
225
+ output_attentions=output_attentions,
226
+ use_cache=use_cache,
227
+ cache_position=cache_position,
228
+ position_embeddings=position_embeddings,
229
+ )
230
+ hidden_states = residual + hidden_states
231
+ residual = hidden_states
232
+ hidden_states = self.post_attention_layernorm(hidden_states)
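+ # Descriptive comment: the MoE block below replaces the dense MLP; besides the hidden
+ # states it returns router logits, per-token expert counts, the expert mask, the global
+ # routing weights and the auxiliary balance loss, which are exposed when
+ # output_router_logits_and_topk is set.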
233
+ hidden_states, router_logits, router_top_k, router_expert_mask, router_weight, aux_loss = self.mlp(hidden_states, padding_token_mask)
234
+ hidden_states = residual + hidden_states
235
+
236
+ outputs = (hidden_states,)
237
+
238
+ if output_attentions:
239
+ outputs += (self_attn_weights,)
240
+
241
+ if output_router_logits_and_topk:
242
+ outputs += (router_logits,)
243
+ outputs += (router_top_k,)
244
+ outputs += (router_expert_mask,)
245
+ outputs += (router_weight,)
246
+ outputs += (aux_loss,)
247
+
248
+ return outputs
249
+
250
+
251
+ class Qwen2_5_VLMoEPreTrainedModel(PreTrainedModel):
252
+ config_class = UniMoEAudioConfig
253
+ base_model_prefix = "model"
254
+ supports_gradient_checkpointing = True
255
+ _no_split_modules = ["Qwen2_5_VLMoEDecoderLayer", "Qwen2_5_VLVisionBlock"]
256
+ _skip_keys_device_placement = "past_key_values"
257
+ _supports_flash_attn_2 = True
258
+ _supports_flash_attn_3 = True
259
+ _supports_sdpa = True
260
+ _supports_cache_class = True
261
+ _supports_static_cache = True
262
+ _supports_attention_backend = True
263
+
264
+ def _init_weights(self, module):
265
+ std = self.config.initializer_range
266
+ if FAST_INIT:
267
+ if isinstance(module, UniMoEAudioSparseMoeBlock):
268
+ module.gate.weight.data.normal_(mean=0.0, std=std)
269
+ if module.gate.bias is not None:
270
+ module.gate.bias.data.zero_()
271
+ elif isinstance(module, nn.Embedding):
272
+ module.weight.data.normal_(mean=0.0, std=std)
273
+ if module.padding_idx is not None:
274
+ module.weight.data[module.padding_idx].zero_()
275
+ else:
276
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
277
+ module.weight.data.normal_(mean=0.0, std=std)
278
+ if module.bias is not None:
279
+ module.bias.data.zero_()
280
+ elif isinstance(module, nn.Embedding):
281
+ module.weight.data.normal_(mean=0.0, std=std)
282
+ if module.padding_idx is not None:
283
+ module.weight.data[module.padding_idx].zero_()
284
+ elif isinstance(module, Qwen2RMSNorm):
285
+ module.weight.data.fill_(1.0)
286
+
287
+
288
+ class Qwen2_5_VLMoETextModel(Qwen2_5_VLMoEPreTrainedModel):
289
+ config_class = Qwen2_5_VLMoETextConfig
290
+ def __init__(self, config: Qwen2_5_VLMoETextConfig):
291
+ super().__init__(config)
292
+ self.padding_idx = config.pad_token_id
293
+ self.vocab_size = config.vocab_size
294
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
295
+ self.layers = nn.ModuleList(
296
+ [Qwen2_5_VLMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
297
+ )
298
+ self._attn_implementation = config._attn_implementation
299
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
300
+ self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
301
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
302
+ self.gradient_checkpointing = False
303
+ self.post_init()
304
+
305
+ def get_input_embeddings(self):
306
+ return self.embed_tokens
307
+
308
+ def set_input_embeddings(self, value):
309
+ self.embed_tokens = value
310
+
311
+ def forward(
312
+ self,
313
+ input_ids: Optional[torch.LongTensor] = None,
314
+ attention_mask: Optional[torch.Tensor] = None,
315
+ padding_token_mask: Optional[torch.Tensor] = None,
316
+ position_ids: Optional[torch.LongTensor] = None,
317
+ past_key_values: Optional[Cache] = None,
318
+ inputs_embeds: Optional[torch.FloatTensor] = None,
319
+ use_cache: Optional[bool] = None,
320
+ output_attentions: Optional[bool] = None,
321
+ output_hidden_states: Optional[bool] = None,
322
+ output_router_logits_and_topk: Optional[bool] = None,
323
+ return_dict: Optional[bool] = None,
324
+ cache_position: Optional[torch.LongTensor] = None,
325
+ **kwargs: Unpack[FlashAttentionKwargs],
326
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
327
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
328
+ output_hidden_states = (
329
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
330
+ )
331
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
332
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
333
+
334
+ if (input_ids is None) ^ (inputs_embeds is not None):
335
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
336
+
337
+ if self.gradient_checkpointing and self.training:
338
+ if use_cache:
339
+ logger.warning_once(
340
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
341
+ )
342
+ use_cache = False
343
+
344
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
345
+ past_key_values = DynamicCache()
346
+
347
+ if inputs_embeds is None:
348
+ inputs_embeds = self.embed_tokens(input_ids)
349
+
350
+ if cache_position is None:
351
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
352
+ cache_position = torch.arange(
353
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
354
+ )
355
+
356
+ if position_ids is None:
357
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
358
+ elif position_ids.dim() == 2:
359
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
360
+
361
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
362
+ mask_kwargs = {
363
+ "config": self.config,
364
+ "input_embeds": inputs_embeds,
365
+ "attention_mask": attention_mask,
366
+ "cache_position": cache_position,
367
+ "past_key_values": past_key_values,
368
+ "position_ids": position_ids,
369
+ }
370
+ causal_mask_mapping = {
371
+ "full_attention": create_causal_mask(**mask_kwargs),
372
+ }
373
+ if self.has_sliding_layers:
374
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
375
+
376
+ hidden_states = inputs_embeds
377
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
378
+
379
+ all_hidden_states = () if output_hidden_states else None
380
+ all_self_attns = () if output_attentions else None
381
+ all_router_logits = () if output_router_logits_and_topk else None
382
+ all_router_top_k = () if output_router_logits_and_topk else None
383
+ all_router_expert_mask = ()
384
+ all_router_weight = ()
385
+ all_aux_loss = ()
386
+ next_decoder_cache = None
387
+
388
+ for decoder_layer in self.layers:
389
+ if output_hidden_states:
390
+ all_hidden_states += (hidden_states,)
391
+
392
+ layer_outputs = decoder_layer(
393
+ hidden_states,
394
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
395
+ padding_token_mask=padding_token_mask,
396
+ position_ids=position_ids,
397
+ past_key_value=past_key_values,
398
+ output_attentions=output_attentions,
399
+ output_router_logits_and_topk=output_router_logits_and_topk,
400
+ use_cache=use_cache,
401
+ cache_position=cache_position,
402
+ position_embeddings=position_embeddings,
403
+ **kwargs,
404
+ )
405
+
406
+ hidden_states = layer_outputs[0]
407
+
408
+ if output_attentions:
409
+ all_self_attns += (layer_outputs[1],)
410
+
411
+ if output_router_logits_and_topk:
412
+ all_router_logits += (layer_outputs[-5],)
413
+ all_router_top_k += (layer_outputs[-4],)
414
+ all_router_expert_mask += (layer_outputs[-3],)
415
+ all_router_weight += (layer_outputs[-2],)
416
+ all_aux_loss += (layer_outputs[-1],)
417
+
418
+ hidden_states = self.norm(hidden_states)
419
+
420
+ if output_hidden_states:
421
+ all_hidden_states += (hidden_states,)
422
+
423
+ if not return_dict:
424
+ return tuple(
425
+ v for v in [
426
+ hidden_states,
427
+ past_key_values,
428
+ all_hidden_states,
429
+ all_self_attns,
430
+ all_router_logits,
431
+ all_router_top_k,
432
+ all_router_expert_mask,
433
+ all_router_weight,
434
+ all_aux_loss]
435
+ if v is not None
436
+ )
437
+ return BaseModelOutputWithPast(
438
+ last_hidden_state=hidden_states,
439
+ past_key_values=past_key_values,
440
+ hidden_states=all_hidden_states,
441
+ attentions=all_self_attns,
442
+ all_router_logits=all_router_logits,
443
+ all_router_top_k=all_router_top_k,
444
+ all_router_expert_mask=all_router_expert_mask,
445
+ all_router_weight=all_router_weight,
446
+ all_aux_loss=all_aux_loss,
447
+ )
448
+
449
+
450
+ class UniMoEAudio(Qwen2_5_VLMoEPreTrainedModel):
451
+ base_model_prefix = ""
452
+ _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"]
453
+ config_class = UniMoEAudioConfig
454
+ _checkpoint_conversion_mapping = {
455
+ "^visual": "visual",
456
+ r"^model(?!\.(language_model|visual))": "language_model",
457
+ }
458
+ _tied_weights_keys = ["lm_head.weight"]
459
+
460
+ def __init__(self, config):
461
+ super().__init__(config)
462
+ self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config, attn_implementation=config._attn_implementation)
463
+ self.language_model = Qwen2_5_VLMoETextModel._from_config(config.text_config)
464
+ self.rope_deltas = None
465
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
466
+ self.num_channels = config.codec_channels
467
+ self.codec_vocab_size = config.codec_vocab_size
468
+ self.codec_embed_tokens = nn.ModuleList(
469
+ [nn.Embedding(self.codec_vocab_size, config.text_config.hidden_size) for embed_idx in range(self.num_channels)])
470
+ self.codec_placeholder_value = config.codec_placeholder_value
471
+ self.codec_head = nn.Linear(config.text_config.hidden_size, self.num_channels * self.codec_vocab_size, bias=False)
472
+ self.post_init()
473
+
474
+ @property
475
+ def cur_aux_weight(self):
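+ # linearly anneal the auxiliary-loss weight from l_aux_weight down to min_l_aux_weight over l_aux_weight_decay_steps training steps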
476
+ if self.training_steps >= self.l_aux_weight_decay_steps:
477
+ return self.min_l_aux_weight
478
+ return self.l_aux_weight - (self.l_aux_weight - self.min_l_aux_weight) / self.l_aux_weight_decay_steps * self.training_steps
479
+
480
+ def get_input_embeddings(self):
481
+ return self.language_model.get_input_embeddings()
482
+
483
+ def set_input_embeddings(self, value):
484
+ self.language_model.set_input_embeddings(value)
485
+
486
+ def get_output_embeddings(self):
487
+ return self.lm_head
488
+
489
+ def set_output_embeddings(self, new_embeddings):
490
+ self.lm_head = new_embeddings
491
+
492
+ def set_decoder(self, decoder):
493
+ self.language_model = decoder
494
+
495
+ def get_decoder(self):
496
+ return self.language_model
497
+
498
+ def get_rope_index(
499
+ self,
500
+ input_ids: Optional[torch.LongTensor] = None,
501
+ image_grid_thw: Optional[torch.LongTensor] = None,
502
+ video_grid_thw: Optional[torch.LongTensor] = None,
503
+ second_per_grid_ts: Optional[torch.Tensor] = None,
504
+ attention_mask: Optional[torch.Tensor] = None,
505
+ ) -> tuple[torch.Tensor, torch.Tensor]:
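+ # compute 3D M-RoPE position ids: text tokens advance the temporal/height/width axes together, while image and video patches get separate t/h/w indices derived from their grid sizes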
506
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
507
+ image_token_id = self.config.image_token_id
508
+ video_token_id = self.config.video_token_id
509
+ vision_start_token_id = self.config.vision_start_token_id
510
+ mrope_position_deltas = []
511
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
512
+ total_input_ids = input_ids
513
+ if attention_mask is None:
514
+ attention_mask = torch.ones_like(total_input_ids)
515
+ position_ids = torch.ones(
516
+ 3,
517
+ input_ids.shape[0],
518
+ input_ids.shape[1],
519
+ dtype=input_ids.dtype,
520
+ device=input_ids.device,
521
+ )
522
+ image_index, video_index = 0, 0
523
+ attention_mask = attention_mask.to(total_input_ids.device)
524
+ for i, input_ids in enumerate(total_input_ids):
525
+ input_ids = input_ids[attention_mask[i] == 1]
526
+ image_nums, video_nums = 0, 0
527
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
528
+ vision_tokens = input_ids[vision_start_indices + 1]
529
+ image_nums = (vision_tokens == image_token_id).sum()
530
+ video_nums = (vision_tokens == video_token_id).sum()
531
+ input_tokens = input_ids.tolist()
532
+ llm_pos_ids_list: list = []
533
+ st = 0
534
+ remain_images, remain_videos = image_nums, video_nums
535
+ for _ in range(image_nums + video_nums):
536
+ if image_token_id in input_tokens and remain_images > 0:
537
+ ed_image = input_tokens.index(image_token_id, st)
538
+ else:
539
+ ed_image = len(input_tokens) + 1
540
+ if video_token_id in input_tokens and remain_videos > 0:
541
+ ed_video = input_tokens.index(video_token_id, st)
542
+ else:
543
+ ed_video = len(input_tokens) + 1
544
+ if ed_image < ed_video:
545
+ t, h, w = (
546
+ image_grid_thw[image_index][0],
547
+ image_grid_thw[image_index][1],
548
+ image_grid_thw[image_index][2],
549
+ )
550
+ second_per_grid_t = 0
551
+ image_index += 1
552
+ remain_images -= 1
553
+ ed = ed_image
554
+
555
+ else:
556
+ t, h, w = (
557
+ video_grid_thw[video_index][0],
558
+ video_grid_thw[video_index][1],
559
+ video_grid_thw[video_index][2],
560
+ )
561
+ if second_per_grid_ts is not None:
562
+ second_per_grid_t = second_per_grid_ts[video_index]
563
+ else:
564
+ second_per_grid_t = 1.0
565
+ video_index += 1
566
+ remain_videos -= 1
567
+ ed = ed_video
568
+ llm_grid_t, llm_grid_h, llm_grid_w = (
569
+ t.item(),
570
+ h.item() // spatial_merge_size,
571
+ w.item() // spatial_merge_size,
572
+ )
573
+ text_len = ed - st
574
+
575
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
576
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
577
+
578
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
579
+ expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
580
+ second_per_grid_t = torch.as_tensor(
581
+ second_per_grid_t, dtype=range_tensor.dtype, device=range_tensor.device
582
+ )
583
+
584
+ time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
585
+
586
+ time_tensor_long = time_tensor.long()
587
+ t_index = time_tensor_long.flatten()
588
+
589
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
590
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
591
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
592
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
593
+
594
+ if st < len(input_tokens):
595
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
596
+ text_len = len(input_tokens) - st
597
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
598
+
599
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
600
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
601
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
602
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
603
+ return position_ids, mrope_position_deltas
604
+ else:
605
+ if attention_mask is not None:
606
+ position_ids = attention_mask.long().cumsum(-1) - 1
607
+ position_ids.masked_fill_(attention_mask == 0, 1)
608
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
609
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
610
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
611
+ else:
612
+ position_ids = (
613
+ torch.arange(input_ids.shape[1], device=input_ids.device)
614
+ .view(1, 1, -1)
615
+ .expand(3, input_ids.shape[0], -1)
616
+ )
617
+ mrope_position_deltas = torch.zeros(
618
+ [input_ids.shape[0], 1],
619
+ device=input_ids.device,
620
+ dtype=input_ids.dtype,
621
+ )
622
+
623
+ return position_ids, mrope_position_deltas
624
+
625
+ def get_video_features(self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None):
626
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
627
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
628
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
629
+ video_embeds = torch.split(video_embeds, split_sizes)
630
+ return video_embeds
631
+
632
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
633
+ pixel_values = pixel_values.type(self.visual.dtype)
634
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
635
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
636
+ image_embeds = torch.split(image_embeds, split_sizes)
637
+ return image_embeds
638
+
639
+
640
+ def codec_embedding(self, codec_input_ids):
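+ # each codec channel has its own embedding table; the per-channel embeddings are summed into a single vector per audio token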
641
+ x = None
642
+ for i in range(self.num_channels):
643
+ channel_tokens = codec_input_ids[..., i]
644
+ channel_embed = self.codec_embed_tokens[i](channel_tokens)
645
+ x = channel_embed if x is None else x + channel_embed
646
+ return x
647
+
648
+ def calculate_input_embedding(self, input_ids, codec_input_ids):
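+ # embed the text tokens, then overwrite positions holding the codec placeholder id with the summed codec embeddings via masked_scatter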
649
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
650
+ if codec_input_ids is not None:
651
+ codec_input_embeds = self.codec_embedding(codec_input_ids)
652
+
653
+ codec_mask = (input_ids == self.codec_placeholder_value).unsqueeze(-1).expand_as(inputs_embeds)
654
+ inputs_embeds = inputs_embeds.masked_scatter(codec_mask, codec_input_embeds)
655
+ return inputs_embeds
656
+
657
+ @can_return_tuple
658
+ def forward(
659
+ self,
660
+ input_ids: torch.LongTensor = None,
661
+ codec_input_ids: torch.LongTensor = None,
662
+ attention_mask: Optional[torch.Tensor] = None,
663
+ position_ids: Optional[torch.LongTensor] = None,
664
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
665
+ inputs_embeds: Optional[torch.FloatTensor] = None,
666
+ labels: Optional[torch.LongTensor] = None,
667
+ codec_labels: Optional[torch.LongTensor] = None,
668
+ padding_token_mask: Optional[torch.Tensor] = None,
669
+ use_cache: Optional[bool] = None,
670
+ output_attentions: Optional[bool] = None,
671
+ output_hidden_states: Optional[bool] = None,
672
+ output_router_logits_and_topk: Optional[bool] = None,
673
+ pixel_values: Optional[torch.Tensor] = None,
674
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
675
+ image_grid_thw: Optional[torch.LongTensor] = None,
676
+ video_grid_thw: Optional[torch.LongTensor] = None,
677
+ rope_deltas: Optional[torch.LongTensor] = None,
678
+ cache_position: Optional[torch.LongTensor] = None,
679
+ second_per_grid_ts: Optional[torch.Tensor] = None,
680
+ **kwargs,
681
+
682
+ ) -> Union[Tuple, MoEQwen2_5VLCausalLMOutputWithPast]:
683
+ return_dict = True
684
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
685
+ output_hidden_states = (
686
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
687
+ )
688
+
689
+ if inputs_embeds is None:
690
+ inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)
691
+
692
+ if pixel_values is not None:
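+ # replace image placeholder-token embeddings with features from the vision tower; the number of placeholder tokens must equal the number of image features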
693
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw)
694
+ image_embeds = torch.cat(image_embeds, dim=0)
695
+
696
+ if input_ids is None:
697
+ image_mask = inputs_embeds == self.get_input_embeddings()(
698
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
699
+ )
700
+ image_mask = image_mask.all(-1)
701
+ else:
702
+ image_mask = input_ids == self.config.image_token_id
703
+
704
+ n_image_tokens = (image_mask).sum()
705
+ image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
706
+ n_image_features = image_embeds.shape[0]
707
+ if not is_torchdynamo_compiling() and n_image_tokens != n_image_features:
708
+ raise ValueError(
709
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
710
+ )
711
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
712
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
713
+
714
+ if pixel_values_videos is not None:
715
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw)
716
+ video_embeds = torch.cat(video_embeds, dim=0)
717
+
718
+ if input_ids is None:
719
+ video_mask = inputs_embeds == self.get_input_embeddings()(
720
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
721
+ )
722
+ video_mask = video_mask.all(-1)
723
+ else:
724
+ video_mask = input_ids == self.config.video_token_id
725
+
726
+ n_video_tokens = (video_mask).sum()
727
+ n_video_features = video_embeds.shape[0]
728
+ video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
729
+ if not is_torchdynamo_compiling() and n_video_tokens != n_video_features:
730
+ raise ValueError(
731
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
732
+ )
733
+
734
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
735
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
736
+
737
+ if position_ids is None:
738
+ attention_mask_tensor = (
739
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
740
+ )
741
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
742
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
743
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
744
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
745
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
746
+ (input_ids is not None and input_ids.shape[1] != 1)
747
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
748
+ )
749
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
750
+ (cache_position is not None and cache_position[0] == 0)
751
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
752
+ )
753
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
754
+ position_ids, rope_deltas = self.get_rope_index(
755
+ input_ids,
756
+ image_grid_thw,
757
+ video_grid_thw,
758
+ second_per_grid_ts=second_per_grid_ts,
759
+ attention_mask=attention_mask_tensor,
760
+ )
761
+ self.rope_deltas = rope_deltas
762
+
763
+ else:
764
+ batch_size, seq_length, _ = inputs_embeds.shape
765
+ delta = (
766
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
767
+ if cache_position is not None
768
+ else 0
769
+ )
770
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
771
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
772
+ if cache_position is not None:
773
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
774
+ position_ids = position_ids.add(delta)
775
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
776
+
777
+ if padding_token_mask is None:
778
+ padding_token_mask = attention_mask.bool()
779
+
780
+ outputs = self.language_model(
781
+ input_ids=None,
782
+ position_ids=position_ids,
783
+ attention_mask=attention_mask,
784
+ padding_token_mask=padding_token_mask,
785
+ past_key_values=past_key_values,
786
+ inputs_embeds=inputs_embeds,
787
+ use_cache=use_cache,
788
+ output_attentions=output_attentions,
789
+ output_hidden_states=output_hidden_states,
790
+ output_router_logits_and_topk=output_router_logits_and_topk,
791
+ return_dict=return_dict,
792
+ cache_position=cache_position,
793
+ **kwargs,
794
+ )
795
+
796
+ hidden_states = outputs[0]
797
+ logits = self.lm_head(hidden_states).float()
798
+ codec_logits = self.codec_head(hidden_states).float()
799
+ codec_logits = codec_logits.view((logits.shape[0], logits.shape[1], self.num_channels, self.codec_vocab_size))
800
+
801
+ loss = aux_loss = codec_loss = all_aux_loss = None  # ensure these names exist even when labels is None (e.g. at inference)
802
+ if labels is not None:
803
+
804
+ all_aux_loss = outputs.all_aux_loss if return_dict else outputs[-1]
805
+ all_aux_loss = torch.mean(torch.cat([l.unsqueeze(0) for l in all_aux_loss], dim=0))
806
+ aux_loss = self.cur_aux_weight * all_aux_loss
807
+ self.training_steps += 1
808
+ codec_loss = None
809
+
810
+ if codec_labels is not None:
811
+ for i in range(self.num_channels):
812
+ channel_logits = codec_logits[:, :, i].float()
813
+ channel_labels = codec_labels[:, :, i]
814
+ shift_channel_logits = channel_logits[..., :-1, :].contiguous()
815
+ shift_channel_labels = channel_labels[..., 1:].contiguous()
816
+
817
+ if i != 0 and (shift_channel_labels != -100).sum() == 0:
818
+ continue
819
+
820
+ loss_fct = CrossEntropyLoss()
821
+ shift_channel_logits = shift_channel_logits.view(-1, self.codec_vocab_size)
822
+ shift_channel_labels = shift_channel_labels.view(-1)
823
+ shift_channel_labels = shift_channel_labels.to(shift_channel_logits.device)
824
+ channel_loss = loss_fct(shift_channel_logits, shift_channel_labels)
825
+ codec_loss = channel_loss if codec_loss is None else codec_loss + channel_loss
826
+
827
+ loss = aux_loss if codec_loss is None else codec_loss + aux_loss
828
+
829
+
830
+ if not return_dict:
831
+ output = (logits,) + outputs[1:]
832
+ return (loss,) + output if loss is not None else output
833
+
834
+ return MoEQwen2_5VLCausalLMOutputWithPast(
835
+ loss=loss,
836
+ logits=logits,
837
+ past_key_values=outputs.past_key_values,
838
+ hidden_states=outputs.hidden_states,
839
+ attentions=outputs.attentions,
840
+ all_router_logits=outputs.all_router_logits,
841
+ all_router_top_k=outputs.all_router_top_k,
842
+ all_router_expert_mask=outputs.all_router_expert_mask,
843
+ all_router_weight=outputs.all_router_weight,
844
+ aux_balance_loss=all_aux_loss,
845
+ )
846
+
847
+ @staticmethod
848
+ def _sample_next_token(
849
+ logits_BCxV: torch.Tensor,
850
+ temperature: float,
851
+ top_p: float,
852
+ top_k: int,
853
+ audio_eos_value: int,
854
+ ) -> torch.Tensor:
855
+ if temperature == 0.0:
856
+ return torch.argmax(logits_BCxV, dim=-1)
857
+
858
+ logits_BCxV = logits_BCxV / temperature
859
+
860
+ if audio_eos_value is not None and audio_eos_value >= 0:
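+ # mask out the EOS logit wherever EOS is not already the argmax, so EOS can only be sampled when it is the most likely token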
861
+ top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
862
+ eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
863
+ mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
864
+ mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
865
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)
866
+
867
+ if top_k is not None:
868
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
869
+ mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
870
+ mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
871
+ logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
872
+
873
+ if top_p < 1.0:
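+ # nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative probability reaches top_p, always retaining the single most likely token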
874
+ probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
875
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
876
+ cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
877
+
878
+ sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
879
+ sorted_indices_to_remove_BCxV = torch.roll(sorted_indices_to_remove_BCxV, shifts=1, dims=-1)
880
+ sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(sorted_indices_to_remove_BCxV[..., 0])
881
+
882
+ indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
883
+ indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
884
+ logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
885
+
886
+ final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
887
+
888
+ sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
889
+ sampled_indices_C = sampled_indices_BC.squeeze(-1)
890
+ return sampled_indices_C
891
+
892
+ def _decoder_step(
893
+ self,
894
+ tokens_Bx1xC: torch.Tensor,
895
+ model_kwargs,
896
+ cfg_scale: float,
897
+ neg_input_size: int,
898
+ temperature: float,
899
+ top_p: float,
900
+ top_k: int,
901
+ do_sample=True,
902
+ eos_prob_mul_factor=1.0,
903
+ labels_Bx1xC=None,
904
+ use_cache=True,
905
+ enable_eos=True,
906
+ ) -> torch.Tensor:
907
+ B = tokens_Bx1xC.shape[0]
908
+ audio_eos_value = self.config.codec_eos_value
909
+ attention_mask = model_kwargs["attention_mask"]
910
+ cache_position = model_kwargs["cache_position"]
911
+ past_key_values = model_kwargs["past_key_values"]
912
+ input_ids = model_kwargs["input_ids"]
913
+ codec_input_ids = model_kwargs["codec_input_ids"]
914
+ position_ids = attention_mask.long().cumsum(-1) - 1
915
+ position_ids.masked_fill_(attention_mask == 0, 1)
916
+ if past_key_values:
917
+ position_ids = position_ids[:, -tokens_Bx1xC.shape[1] :]
918
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
919
+
920
+ tokens_Bx1xC = tokens_Bx1xC.repeat_interleave(neg_input_size, dim=0)
921
+ codec_input_ids = torch.cat((codec_input_ids, tokens_Bx1xC), dim=1) if codec_input_ids is not None else tokens_Bx1xC.clone()
922
+ input_ids = torch.cat((input_ids, torch.ones(input_ids.shape[0], 1).to(input_ids) * self.codec_placeholder_value), dim=-1)
923
+
924
+ if use_cache:
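+ # with a KV cache only the newly appended codec tokens are embedded and fed; the no-cache branch below re-embeds the full prompt plus all codec tokens each step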
925
+ codec_input_embeds = self.codec_embedding(tokens_Bx1xC)
926
+ outputs = self.language_model(
927
+ input_ids=None,
928
+ attention_mask=attention_mask,
929
+ position_ids=position_ids,
930
+ past_key_values=past_key_values,
931
+ inputs_embeds=codec_input_embeds,
932
+ use_cache=True,
933
+ output_attentions=False,
934
+ output_hidden_states=False,
935
+ return_dict=True,
936
+ cache_position=cache_position,
937
+ )
938
+
939
+ else:
940
+ batch_codec_input_ids = codec_input_ids.contiguous().view(-1, self.num_channels)
941
+
942
+ inputs_embeds = self.calculate_input_embedding(input_ids, batch_codec_input_ids)
943
+ outputs = self.language_model(
944
+ input_ids=None,
945
+ attention_mask=attention_mask,
946
+ position_ids=attention_mask.long().cumsum(-1) - 1,
947
+ past_key_values=None,
948
+ inputs_embeds=inputs_embeds,
949
+ use_cache=True,
950
+ output_attentions=False,
951
+ output_hidden_states=False,
952
+ return_dict=True,
953
+ cache_position=None,
954
+ )
955
+
956
+ last_hidden_state = outputs.last_hidden_state
957
+ codec_logits = self.codec_head(last_hidden_state).float()
958
+ codec_logits = codec_logits.view((codec_logits.shape[0], codec_logits.shape[1], self.num_channels, self.codec_vocab_size))
959
+ model_kwargs["past_key_values"] = outputs.past_key_values
960
+ attention_mask = model_kwargs["attention_mask"]
961
+ model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
962
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
963
+ model_kwargs["input_ids"] = input_ids
964
+ model_kwargs["codec_input_ids"] = codec_input_ids
965
+
966
+ logits_Bx1xCxV = codec_logits[: , -1:].clone()
967
+ logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
968
+ logits_last_Bx2xCxV = logits_last_2BxCxV.view(B, neg_input_size, *logits_last_2BxCxV.shape[1:])
969
+ if cfg_scale is not None:
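+ # classifier-free guidance: the last entry of each group holds the conditional logits, the remaining entries are negative/unconditional branches pushed away with weight cfg_scale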
970
+ cond_logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :] # Shape [B, C, V]
971
+ logits_BxCxV = cond_logits_BxCxV
972
+ for ni in range(neg_input_size - 1):
973
+ uncond_logits_BxCxV = logits_last_Bx2xCxV[:, ni, :, :] # Shape [B, C, V]
974
+ cfg_weight = cfg_scale[ni] if isinstance(cfg_scale, List) else cfg_scale
975
+ logits_BxCxV = logits_BxCxV + cfg_weight * (cond_logits_BxCxV - uncond_logits_BxCxV)
976
+ else:
977
+ logits_BxCxV = logits_last_Bx2xCxV[:, -1, :, :] # Shape [B, C, V]
978
+
979
+ if enable_eos:
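+ # restrict the vocabulary: ids beyond EOS are banned everywhere, channels other than 0 may never emit EOS, and the EOS logit of channel 0 is scaled by eos_prob_mul_factor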
980
+ logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like(
981
+ logits_BxCxV[:, :, audio_eos_value + 1 :],
982
+ fill_value=-torch.inf,
983
+ )
984
+ logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
985
+ logits_BxCxV[:, 1:, audio_eos_value:],
986
+ fill_value=-torch.inf,
987
+ )
988
+ logits_BxCxV[:, 0, audio_eos_value] *= torch.tensor(eos_prob_mul_factor, device=self.device)
989
+
990
+ else:
991
+ logits_BxCxV[:, :, audio_eos_value:] = torch.full_like(
992
+ logits_BxCxV[:, :, audio_eos_value:],
993
+ fill_value=-torch.inf,
994
+ )
995
+
996
+
997
+ flat_logits_BCxV = logits_BxCxV.reshape(B * self.num_channels, -1)
998
+ if do_sample:
999
+ pred_BC = self._sample_next_token(
1000
+ flat_logits_BCxV.float(),
1001
+ temperature=temperature,
1002
+ top_p=top_p,
1003
+ top_k=top_k,
1004
+ audio_eos_value=audio_eos_value,
1005
+ )
1006
+ else:
1007
+ pred_BC = torch.argmax(flat_logits_BCxV, dim=1)
1008
+
1009
+ pred_BxC = pred_BC.view(B, self.num_channels)
1010
+
1011
+ return pred_BxC, model_kwargs
1012
+
1013
+ def generate(
1014
+ self,
1015
+ input_ids,
1016
+ attention_mask,
1017
+ dec_output,
1018
+ max_tokens,
1019
+ min_tokens=None,
1020
+ codec_input_ids: Optional[torch.Tensor] = None,
1021
+ pixel_values: Optional[torch.Tensor] = None,
1022
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
1023
+ image_grid_thw: Optional[torch.LongTensor] = None,
1024
+ video_grid_thw: Optional[torch.LongTensor] = None,
1025
+ second_per_grid_ts: Optional[torch.Tensor] = None,
1026
+ neg_input_size = 2,
1027
+ cfg_scale = 3.0,
1028
+ temperature: float = 1.2,
1029
+ top_p: float = 0.95,
1030
+ cfg_filter_top_k: int = 45,
1031
+ eos_prob_mul_factor: float = 0.8,
1032
+ do_sample: bool = True,
1033
+ debug_guidance_step: int = 0,
1034
+ use_cache=True,
1035
+ ):
1036
+ if codec_input_ids is not None:
1037
+ assert use_cache
1038
+ batch_size = input_ids.shape[0] // neg_input_size
1039
+ audio_eos_value = self.config.codec_eos_value
1040
+ audio_pad_value = self.config.codec_pad_value
1041
+ delay_pattern = self.config.codec_delay_pattern
1042
+ max_delay_pattern = max(delay_pattern)
1043
+ delay_pattern_Cx = torch.tensor(delay_pattern, device=self.device, dtype=torch.long)
1044
+
1045
+ dec_step = min(dec_output.prefill_steps) - 1
1046
+
1047
+ eos_detected_Bx = torch.zeros((batch_size,), dtype=torch.bool, device=self.device)
1048
+ eos_countdown_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
1049
+ finished_step_Bx = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
1050
+
1051
+ bos_over = False
1052
+ model_kwargs = dict(attention_mask=attention_mask, use_cache=True)
1053
+ model_kwargs["past_key_values"] = DynamicCache()
1054
+ model_kwargs["cache_position"] = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
1055
+ attention_mask = model_kwargs["attention_mask"]
1056
+ past_key_values = model_kwargs["past_key_values"]
1057
+ position_ids = attention_mask.long().cumsum(-1) - 1
1058
+ position_ids.masked_fill_(attention_mask == 0, 1)
1059
+ cache_position = torch.arange(0, input_ids.shape[-1], device=input_ids.device)
1060
+ inputs_embeds = self.calculate_input_embedding(input_ids, codec_input_ids)
1061
+ outputs = self.language_model(
1062
+ input_ids=None,
1063
+ attention_mask=attention_mask,
1064
+ position_ids=position_ids,
1065
+ past_key_values=past_key_values,
1066
+ inputs_embeds=inputs_embeds,
1067
+ pixel_values=pixel_values,
1068
+ pixel_values_videos=pixel_values_videos,
1069
+ image_grid_thw=image_grid_thw,
1070
+ video_grid_thw=video_grid_thw,
1071
+ second_per_grid_ts=second_per_grid_ts,
1072
+ use_cache=True,
1073
+ output_attentions=False,
1074
+ output_hidden_states=False,
1075
+ return_dict=True,
1076
+ cache_position=cache_position,
1077
+ )
1078
+
1079
+ model_kwargs["input_ids"] = input_ids
1080
+ model_kwargs["codec_input_ids"] = None
1081
+ model_kwargs["labels"] = torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100
1082
+ labels_Bx1xC = dec_output.get_labels_at(0)
1083
+ if labels_Bx1xC is not None:
1084
+ model_kwargs["codec_labels"] = (torch.ones_like(input_ids[neg_input_size-1::neg_input_size]) * -100).unsqueeze(-1).expand(-1, -1, self.num_channels)
1085
+ assert (labels_Bx1xC != self.config.codec_bos_value).sum() == 0
1086
+ labels_Bx1xC = torch.full_like(labels_Bx1xC, -100)
1087
+ model_kwargs["codec_labels"] = torch.cat((model_kwargs["codec_labels"], labels_Bx1xC), dim=1)
1088
+ model_kwargs["past_key_values"] = outputs.past_key_values
1089
+ attention_mask = model_kwargs["attention_mask"]
1090
+ model_kwargs["attention_mask"] = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
1091
+ model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
1092
+
1093
+ while dec_step < max_tokens:
1094
+ if (eos_countdown_Bx == 0).all():
1095
+ break
1096
+
1097
+ current_step_idx = dec_step + 1
1098
+ tokens_Bx1xC = dec_output.get_tokens_at(dec_step)
1099
+ labels_Bx1xC = dec_output.get_labels_at(dec_step + 1)
1100
+
1101
+ pred_BxC, model_kwargs = self._decoder_step(
1102
+ tokens_Bx1xC=tokens_Bx1xC,
1103
+ model_kwargs=model_kwargs,
1104
+ cfg_scale=cfg_scale,
1105
+ neg_input_size=neg_input_size,
1106
+ temperature=temperature,
1107
+ top_p=top_p,
1108
+ top_k=cfg_filter_top_k,
1109
+ do_sample=do_sample,
1110
+ eos_prob_mul_factor=eos_prob_mul_factor,
1111
+ labels_Bx1xC=labels_Bx1xC,
1112
+ use_cache=use_cache,
1113
+ enable_eos=(min_tokens is None or dec_step >= min_tokens),
1114
+ )
1115
+ if labels_Bx1xC is not None and (dec_step < debug_guidance_step or debug_guidance_step==-1):
1116
+ pred_BxC = labels_Bx1xC[:, 0]
1117
+
1118
+ active_mask_Bx = eos_countdown_Bx != 0
1119
+ eos_trigger_Bx = torch.zeros_like(active_mask_Bx)
1120
+ if active_mask_Bx.any():
1121
+ is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (pred_BxC[active_mask_Bx, 0] == audio_eos_value)
1122
+ is_max_len = current_step_idx >= max_tokens - max_delay_pattern
1123
+ eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len
1124
+ eos_detected_Bx |= eos_trigger_Bx
1125
+ start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0)
1126
+ if start_countdown_mask_Bx.any():
1127
+ eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern
1128
+ finished_step_Bx[start_countdown_mask_Bx] = current_step_idx
1129
+
1130
+ padding_mask_Bx = eos_countdown_Bx > 0
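+ # once EOS has been triggered for a sample, channel c is forced to emit EOS exactly delay_pattern[c] steps later and PAD on every step after that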
1131
+ if padding_mask_Bx.any():
1132
+ pred_active_BxC = pred_BxC[padding_mask_Bx].clone()
1133
+ countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx]
1134
+ step_after_eos_Bx = max_delay_pattern - countdown_active_Bx
1135
+ step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1)
1136
+ delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
1137
+ eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_
1138
+ pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_
1139
+ pred_active_BxC[eos_mask_NxC] = audio_eos_value
1140
+ pred_active_BxC[pad_mask_NxC] = audio_pad_value
1141
+ pred_BxC[padding_mask_Bx] = pred_active_BxC
1142
+ eos_countdown_Bx[padding_mask_Bx] -= 1
1143
+
1144
+ if not bos_over:
1145
+ bos_over = all(current_step_idx - prefill_step >= max_delay_pattern for prefill_step in dec_output.prefill_steps)
1146
+
1147
+ dec_output.update_one(pred_BxC, current_step_idx, not bos_over)
1148
+ dec_step += 1
1149
+
1150
+ final_step = dec_step + 1
1151
+ finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern
1152
+ prefill_steps_tensor = torch.tensor(dec_output.prefill_steps, device=self.device)
1153
+ lengths_Bx = finished_step_Bx - prefill_steps_tensor
1154
+ lengths_Bx = torch.clamp(lengths_Bx, min=0)
1155
+ max_len = lengths_Bx.max().item() + max_delay_pattern
1156
+
1157
+ if max_len > 0:
1158
+ num_channels = self.num_channels
1159
+ generated_codes = torch.full(
1160
+ (batch_size, max_len, num_channels),
1161
+ fill_value=audio_pad_value,
1162
+ dtype=torch.long,
1163
+ device=self.device,
1164
+ )
1165
+
1166
+ for i in range(batch_size):
1167
+ start_step = dec_output.prefill_steps[i]
1168
+ actual_len = lengths_Bx[i].item() + max_delay_pattern
1169
+ if actual_len > 0:
1170
+ tokens_to_copy = dec_output.generated_tokens[i, start_step : start_step + actual_len, :]
1171
+ generated_codes[i, :actual_len, :] = tokens_to_copy
1172
+
1173
+ return generated_codes, lengths_Bx
1174
+ else:
1175
+ print("Warning: Nothing generated for any sequence in the batch.")
1176
+ return None, None
1177
+
1178
+ # AutoConfig.register("qwen2_5_vl_moe_text", Qwen2_5_VLMoETextConfig)
1179
+ # AutoModelForCausalLM.register(Qwen2_5_VLMoETextConfig, Qwen2_5_VLMoETextModel)
1180
+
1181
+ # AutoConfig.register("uni_audio_rvq_qwen2_5vl_moe", UniMoEAudioConfig)
1182
+ # AutoModelForCausalLM.register(UniMoEAudioConfig, UniMoEAudio)
special_tokens_map.json CHANGED
@@ -12,7 +12,84 @@
12
  "<|vision_end|>",
13
  "<|vision_pad|>",
14
  "<|image_pad|>",
15
- "<|video_pad|>"
16
  ],
17
  "eos_token": {
18
  "content": "<|im_end|>",
 
12
  "<|vision_end|>",
13
  "<|vision_pad|>",
14
  "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ {
17
+ "content": "<|AUDIO_PLACEHOLDER|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ {
24
+ "content": "<|AUDIO_START|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ {
31
+ "content": "<|AUDIO_END|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ {
38
+ "content": "<|SPEECH_START|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ {
45
+ "content": "<|SPEECH_END|>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ },
51
+ {
52
+ "content": "<|VOICE_PROMPT_START|>",
53
+ "lstrip": false,
54
+ "normalized": false,
55
+ "rstrip": false,
56
+ "single_word": false
57
+ },
58
+ {
59
+ "content": "<|VOICE_PROMPT_END|>",
60
+ "lstrip": false,
61
+ "normalized": false,
62
+ "rstrip": false,
63
+ "single_word": false
64
+ },
65
+ {
66
+ "content": "<|SPEECH_PROMPT_START|>",
67
+ "lstrip": false,
68
+ "normalized": false,
69
+ "rstrip": false,
70
+ "single_word": false
71
+ },
72
+ {
73
+ "content": "<|SPEECH_PROMPT_END|>",
74
+ "lstrip": false,
75
+ "normalized": false,
76
+ "rstrip": false,
77
+ "single_word": false
78
+ },
79
+ {
80
+ "content": "<|MUSIC_START|>",
81
+ "lstrip": false,
82
+ "normalized": false,
83
+ "rstrip": false,
84
+ "single_word": false
85
+ },
86
+ {
87
+ "content": "<|MUSIC_END|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": false,
91
+ "single_word": false
92
+ }
93
  ],
94
  "eos_token": {
95
  "content": "<|im_end|>",
tokenizer_config.json CHANGED
@@ -177,6 +177,94 @@
177
  "rstrip": false,
178
  "single_word": false,
179
  "special": false
180
  }
181
  },
182
  "additional_special_tokens": [
@@ -192,15 +280,27 @@
192
  "<|vision_end|>",
193
  "<|vision_pad|>",
194
  "<|image_pad|>",
195
- "<|video_pad|>"
  ],
197
  "bos_token": null,
198
  "clean_up_tokenization_spaces": false,
199
  "eos_token": "<|im_end|>",
200
  "errors": "replace",
201
  "extra_special_tokens": {},
202
- "model_max_length": 131072,
203
  "pad_token": "<|endoftext|>",
 
204
  "processor_class": "Qwen2_5_VLProcessor",
205
  "split_special_tokens": false,
206
  "tokenizer_class": "Qwen2Tokenizer",
 
177
  "rstrip": false,
178
  "single_word": false,
179
  "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<|AUDIO_PLACEHOLDER|>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ },
189
+ "151666": {
190
+ "content": "<|AUDIO_START|>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": true
196
+ },
197
+ "151667": {
198
+ "content": "<|AUDIO_END|>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": true
204
+ },
205
+ "151668": {
206
+ "content": "<|SPEECH_START|>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": true
212
+ },
213
+ "151669": {
214
+ "content": "<|SPEECH_END|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ },
221
+ "151670": {
222
+ "content": "<|VOICE_PROMPT_START|>",
223
+ "lstrip": false,
224
+ "normalized": false,
225
+ "rstrip": false,
226
+ "single_word": false,
227
+ "special": true
228
+ },
229
+ "151671": {
230
+ "content": "<|VOICE_PROMPT_END|>",
231
+ "lstrip": false,
232
+ "normalized": false,
233
+ "rstrip": false,
234
+ "single_word": false,
235
+ "special": true
236
+ },
237
+ "151672": {
238
+ "content": "<|SPEECH_PROMPT_START|>",
239
+ "lstrip": false,
240
+ "normalized": false,
241
+ "rstrip": false,
242
+ "single_word": false,
243
+ "special": true
244
+ },
245
+ "151673": {
246
+ "content": "<|SPEECH_PROMPT_END|>",
247
+ "lstrip": false,
248
+ "normalized": false,
249
+ "rstrip": false,
250
+ "single_word": false,
251
+ "special": true
252
+ },
253
+ "151674": {
254
+ "content": "<|MUSIC_START|>",
255
+ "lstrip": false,
256
+ "normalized": false,
257
+ "rstrip": false,
258
+ "single_word": false,
259
+ "special": true
260
+ },
261
+ "151675": {
262
+ "content": "<|MUSIC_END|>",
263
+ "lstrip": false,
264
+ "normalized": false,
265
+ "rstrip": false,
266
+ "single_word": false,
267
+ "special": true
268
  }
269
  },
270
  "additional_special_tokens": [
 
280
  "<|vision_end|>",
281
  "<|vision_pad|>",
282
  "<|image_pad|>",
283
+ "<|video_pad|>",
284
+ "<|AUDIO_PLACEHOLDER|>",
285
+ "<|AUDIO_START|>",
286
+ "<|AUDIO_END|>",
287
+ "<|SPEECH_START|>",
288
+ "<|SPEECH_END|>",
289
+ "<|VOICE_PROMPT_START|>",
290
+ "<|VOICE_PROMPT_END|>",
291
+ "<|SPEECH_PROMPT_START|>",
292
+ "<|SPEECH_PROMPT_END|>",
293
+ "<|MUSIC_START|>",
294
+ "<|MUSIC_END|>"
295
  ],
296
  "bos_token": null,
297
  "clean_up_tokenization_spaces": false,
298
  "eos_token": "<|im_end|>",
299
  "errors": "replace",
300
  "extra_special_tokens": {},
301
+ "model_max_length": 4096,
302
  "pad_token": "<|endoftext|>",
303
+ "padding_side": "right",
304
  "processor_class": "Qwen2_5_VLProcessor",
305
  "split_special_tokens": false,
306
  "tokenizer_class": "Qwen2Tokenizer",
utils.py ADDED
@@ -0,0 +1,491 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ UniMoE Audio Utilities Module
4
+ Author: UniMoE Audio Team
5
+ """
6
+
7
+ import copy
8
+ import glob
9
+ import json
10
+ import math
11
+ import os
12
+ import re
13
+ import shutil
14
+ import sys
15
+ import time
16
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, TYPE_CHECKING, Callable
17
+
18
+ import dac
19
+ import datasets
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import torchaudio
24
+ import transformers
25
+ from audiotools import AudioSignal
26
+ from safetensors import safe_open
27
+ from tqdm import tqdm
28
+ from transformers import AutoProcessor, AutoTokenizer, LogitsProcessor, LogitsProcessorList
29
+ from moviepy.video.io.VideoFileClip import VideoFileClip
30
+ from PIL import Image
31
+ from torchvision import io, transforms
32
+ from torchvision.transforms import InterpolationMode
33
+ import torchvision
34
+
35
+ from qwen_vl_utils import smart_resize, process_vision_info
36
+
37
+ import deepspeed
38
+ from deepspeed import comm as dist
39
+ from deepspeed.moe.sharded_moe import _capacity, _one_hot_to_float, einsum, gumbel_rsample
40
+ from torch import Tensor
41
+
42
+ try:
43
+ import torch_npu
44
+ IS_CUDA = False
45
+ except ImportError:
46
+ IS_CUDA = True
47
+
48
+ try:
49
+ # To enable Tutel MoE optimizations:
50
+ # python3 -m pip install --user --upgrade git+https://github.com/microsoft/[email protected]
51
+ from tutel import moe as tutel_moe
52
+ TUTEL_INSTALLED = True
53
+ except ImportError:
54
+ # Fail silently so we don't spam logs unnecessarily if user isn't using tutel
55
+ TUTEL_INSTALLED = False
56
+ pass
57
+
58
+
59
+ SYSTEM_MESSAGE = """<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"""
60
+ INPUT_FORMAT = """<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"""
61
+ AUDIO_START = "<|AUDIO_START|>"
62
+
63
+ DEFAULT_VIDEO_PROMPT = "<|vision_start|><|video_pad|><|vision_end|>{}"
64
+ IMAGE_FACTOR = 28
65
+ MIN_PIXELS = 4 * 28 * 28
66
+ MAX_PIXELS = 16384 * 28 * 28
67
+ MAX_RATIO = 200
68
+ VIDEO_TOTAL_PIXELS = 16 * 28 * 28
69
+ VIDEO_MIN_PIXELS = 16 * 28 * 28
70
+ VIDEO_MAX_PIXELS = 64 * 28 * 28
71
+ FRAME_FACTOR = 2
72
+
73
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
74
+ IMAGENET_STD = (0.229, 0.224, 0.225)
75
+
76
+ IMG_START_TOKEN='<img>'
77
+ IMG_END_TOKEN='</img>'
78
+ IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'
79
+ IMG_PREFIX_FORMAT = "<|IMAGE_PLACE_HOLDER|>"
80
+
81
+ # =============================================================================
82
+ # DAC Utilities
83
+ # =============================================================================
84
+
85
+ class Dac:
86
+ def __init__(self):
87
+ base_dir = os.path.dirname(__file__)
88
+ dac_model_dir = os.path.join(base_dir, "dac_model")
89
+ model_path = os.path.join(dac_model_dir, "weights_16khz.pth")
90
+
91
+ if not os.path.isfile(model_path):
92
+ print(f"DAC model not found at {model_path}, downloading...")
93
+ os.makedirs(dac_model_dir, exist_ok=True)
94
+ downloaded_path = dac.utils.download(model_type="16khz")
95
+ shutil.move(downloaded_path, model_path)
96
+ print(f"DAC model downloaded and saved to {model_path}")
97
+
98
+ env_path = os.environ.get("DAC_WEIGHTS")
99
+ candidates = []
100
+ if env_path:
101
+ candidates.append(env_path)
102
+
103
+ candidates.extend([
104
+ model_path,
105
+ os.path.join(base_dir, "weights_16khz.pth"),
106
+ os.path.join(os.getcwd(), "utils", "dac_model", "weights_16khz.pth"),
107
+ os.path.join(os.getcwd(), "dac_model", "weights_16khz.pth"),
108
+ ])
109
+
110
+ final_model_path = next((p for p in candidates if p and os.path.isfile(p)), None)
111
+ if not final_model_path:
112
+ searched = "\n - " + "\n - ".join(candidates)
113
+ raise FileNotFoundError(
114
+ "DAC weights not found. Please place weights_16khz.pth in one of the following locations or set DAC_WEIGHTS to an absolute path:" + searched
115
+ )
116
+
117
+ self.model = dac.DAC.load(final_model_path)
118
+ self.resampler = dict()
119
+ if IS_CUDA:
120
+ self.model = self.model.to("cuda")
121
+ else:
122
+ self.model = self.model.to("npu")
123
+
124
+ def encode(self, audio_path):
125
+ signal = AudioSignal(audio_path)
126
+ if signal.audio_data.shape[1] == 2:
127
+ signal.audio_data = 0.5 * (signal.audio_data[:, :1, :] + signal.audio_data[:, 1:, :])
128
+ signal.to(self.model.device)
129
+
130
+ if signal.sample_rate != 16000:
131
+ if not str(signal.sample_rate) in self.resampler:
132
+ self.resampler[str(signal.sample_rate)] = torchaudio.transforms.Resample(signal.sample_rate, 16000)
133
+ if IS_CUDA:
134
+ self.resampler[str(signal.sample_rate)] = self.resampler[str(signal.sample_rate)].cuda()
135
+ else:
136
+ self.resampler[str(signal.sample_rate)] = self.resampler[str(signal.sample_rate)].npu()
137
+
138
+ signal.audio_data = self.resampler[str(signal.sample_rate)](signal.audio_data)
139
+ signal.sample_rate = 16000
140
+
141
+ x = self.model.preprocess(signal.audio_data.to(self.model.device), signal.sample_rate)
142
+ z, codes, latents, _, _ = self.model.encode(x)
143
+
144
+ codes = codes[0].clone().detach().transpose(0, 1)
145
+ assert codes.shape[1] == 12 and len(codes.shape) == 2
146
+ codes = codes.tolist()
147
+
148
+ return codes
149
+
150
+ def decode(self, codes, save_path, min_duration=None):
151
+ assert codes.shape[0] == 1 and codes.shape[1] == 12
152
+ z, _, _ = self.model.quantizer.from_codes(codes.to(self.model.device))
153
+ audio_out = self.model.decode(z)[0].detach().cpu()
154
+
155
+ sample_rate = 16000
156
+ duration = audio_out.size(1) / sample_rate
157
+ if min_duration is not None and duration < min_duration:
158
+ padding_duration = min_duration - duration
159
+ padding_samples = int(padding_duration * sample_rate)
160
+ padding = torch.zeros((audio_out.size(0), padding_samples), dtype=audio_out.dtype, device=audio_out.device)
161
+ audio_out = torch.cat((audio_out, padding), dim=1)
162
+
163
+ torchaudio.save(save_path, audio_out.detach().cpu(), sample_rate=16000, encoding="PCM_S", bits_per_sample=16)
164
+
165
+
166
+ def build_delay_indices(B: int, T: int, C: int, delay_pattern: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
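+ # precompute gather indices that shift channel c back by delay_pattern[c] steps; apply_audio_delay later maps negative times to BOS and times past the end to PAD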
167
+ delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
168
+
169
+ t_idx_BxT = torch.broadcast_to(
170
+ torch.arange(T, dtype=torch.int32)[None, :],
171
+ [B, T],
172
+ )
173
+ t_idx_BxTx1 = t_idx_BxT[..., None]
174
+ t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
175
+
176
+ b_idx_BxTxC = torch.broadcast_to(
177
+ torch.arange(B, dtype=torch.int32).view(B, 1, 1),
178
+ [B, T, C],
179
+ )
180
+ c_idx_BxTxC = torch.broadcast_to(
181
+ torch.arange(C, dtype=torch.int32).view(1, 1, C),
182
+ [B, T, C],
183
+ )
184
+ t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
185
+ indices_BTCx3 = torch.stack(
186
+ [
187
+ b_idx_BxTxC.reshape(-1),
188
+ t_clamped_BxTxC.reshape(-1),
189
+ c_idx_BxTxC.reshape(-1),
190
+ ],
191
+ dim=1,
192
+ ).long()
193
+
194
+ return t_idx_BxTxC, indices_BTCx3
195
+
196
+
197
+ def apply_audio_delay(audio_BxTxC: torch.Tensor, pad_value: int, bos_value: int, precomp: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
198
+ device = audio_BxTxC.device
199
+ t_idx_BxTxC, indices_BTCx3 = precomp
200
+ t_idx_BxTxC = t_idx_BxTxC.to(device)
201
+ indices_BTCx3 = indices_BTCx3.to(device)
202
+ gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
203
+ gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
204
+ mask_bos = t_idx_BxTxC < 0
205
+ mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]
206
+
207
+ bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
208
+ pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
209
+
210
+ result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
211
+
212
+ return result_BxTxC
213
+
214
+
215
+ def build_revert_indices(B: int, T: int, C: int, delay_pattern: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
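+ # inverse of the delay: gather channel c from time t + delay_pattern[c] (clamped to the last frame) so the delayed codes are re-aligned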
216
+ device = None
217
+ delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
218
+ t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
219
+ t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
220
+ t_idx_BxTxC = torch.minimum(
221
+ t_idx_BT1 + delay_arr.view(1, 1, C),
222
+ torch.tensor(T - 1, device=device),
223
+ )
224
+ b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
225
+ c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
226
+ indices_BTCx3 = torch.stack(
227
+ [
228
+ b_idx_BxTxC.reshape(-1),
229
+ t_idx_BxTxC.reshape(-1),
230
+ c_idx_BxTxC.reshape(-1),
231
+ ],
232
+ axis=1,
233
+ ).long()
234
+
235
+ return t_idx_BxTxC, indices_BTCx3
236
+
237
+
238
+ def revert_audio_delay(
239
+ audio_BxTxC: torch.Tensor,
240
+ pad_value: int,
241
+ precomp: Tuple[torch.Tensor, torch.Tensor],
242
+ T: int,
243
+ ) -> torch.Tensor:
244
+ t_idx_BxTxC, indices_BTCx3 = precomp
245
+ device = audio_BxTxC.device
246
+ t_idx_BxTxC = t_idx_BxTxC.to(device)
247
+ indices_BTCx3 = indices_BTCx3.to(device)
248
+ gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
249
+ gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())
250
+
251
+ pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
252
+ T_tensor = torch.tensor(T, device=device)
253
+
254
+ result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)
255
+
256
+ return result_BxTxC
257
+
258
+
259
+ def prepare_audio_prompt(model, audio_prompts: list[torch.Tensor]):
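+ # build the generation prefill: BOS at step 0, optional prompt codes afterwards, padded with -1, then the delay pattern is applied; also returns each sample's prefill length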
260
+ num_channels = model.config.codec_channels
261
+ audio_bos_value = model.config.codec_bos_value
262
+ delay_pattern = model.config.codec_delay_pattern
263
+ max_delay_pattern = max(delay_pattern)
264
+ batch_size = len(audio_prompts)
265
+ max_len = max(p.shape[0] if p is not None else 0 for p in audio_prompts) + max_delay_pattern + 1
266
+ prefill_steps = []
267
+ prefill = torch.full(
268
+ (batch_size, max_len, num_channels),
269
+ fill_value=-1,
270
+ dtype=torch.int,
271
+ device=model.device,
272
+ )
273
+ prefill[:, 0, :] = audio_bos_value
274
+ for i in range(batch_size):
275
+ prompt = audio_prompts[i]
276
+ if prompt is not None:
277
+ prompt = prompt.to(device=model.device, dtype=torch.int)
278
+ prefill[i, 1 : prompt.shape[0] + 1, :] = prompt
279
+ prefill_steps.append(prompt.shape[0] + 1)
280
+ else:
281
+ prefill_steps.append(1)
282
+
283
+ delay_precomp = build_delay_indices(
284
+ B=batch_size,
285
+ T=max_len,
286
+ C=num_channels,
287
+ delay_pattern=delay_pattern,
288
+ )
289
+
290
+ delayed_batch = apply_audio_delay(
291
+ audio_BxTxC=prefill,
292
+ pad_value=-1,
293
+ bos_value=audio_bos_value,
294
+ precomp=delay_precomp,
295
+ )
296
+
297
+ return delayed_batch, prefill_steps
298
+
299
+
300
+ class DecoderOutput:
+     def __init__(self, prefill, prefill_steps, device: torch.device, labels_prefill=None):
+         self.generated_tokens = prefill
+         self.prefill_steps = prefill_steps
+         self.labels_prefill = labels_prefill
+         self.device = device
+
+     def get_tokens_at(self, step_from: int, step_to: int = None) -> torch.Tensor:
+         if step_to is None:
+             step_to = step_from + 1
+         return self.generated_tokens[:, step_from:step_to, :].to(self.device)
+
+     def get_labels_at(self, step_from: int, step_to: int = None) -> torch.Tensor:
+         if step_to is None:
+             step_to = step_from + 1
+         if self.labels_prefill is None:
+             return None
+         return self.labels_prefill[:, step_from:step_to, :].to(self.device)
+
+     def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
+         dec_out = dec_out.to(self.generated_tokens.dtype).to(self.generated_tokens.device)
+         if apply_mask:
+             # during prefill, only overwrite positions still marked -1; keep prompt tokens
+             assert step < self.generated_tokens.shape[1]
+             mask = self.generated_tokens[:, step, :] == -1
+             self.generated_tokens[:, step, :] = torch.where(mask, dec_out, self.generated_tokens[:, step, :])
+         else:
+             # past the prefill, append a new step
+             assert step == self.generated_tokens.shape[1]
+             self.generated_tokens = torch.cat((self.generated_tokens, dec_out[:, None, :]), dim=1)
+
+
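# Sketch of the masked update in DecoderOutput.update_one (values are illustrative): during
# the prefill steps only positions still marked -1 are overwritten, so prompt tokens survive.
import torch

existing = torch.tensor([[5, -1, -1]])   # one step of generated_tokens, shape [B, C]
dec_out = torch.tensor([[9, 9, 9]])      # tokens decoded for this step
mask = existing == -1
updated = torch.where(mask, dec_out, existing)
# updated == tensor([[5, 9, 9]])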
+ def generate_output(model, generated_codes: torch.Tensor, lengths_Bx: torch.Tensor) -> list[torch.Tensor]:
+     num_channels = model.config.codec_channels
+     batch_size = generated_codes.shape[0]
+     seq_length = generated_codes.shape[1]
+     delay_pattern = model.config.codec_delay_pattern
+     audio_pad_value = model.config.codec_pad_value
+     max_delay_pattern = max(delay_pattern)
+     revert_precomp = build_revert_indices(
+         B=batch_size,
+         T=seq_length,
+         C=num_channels,
+         delay_pattern=delay_pattern,
+     )
+     codebook = revert_audio_delay(
+         audio_BxTxC=generated_codes,
+         pad_value=audio_pad_value,
+         precomp=revert_precomp,
+         T=seq_length,
+     )[:, :-max_delay_pattern, :]
+
+     audios = []
+     for i in range(batch_size):
+         audios.append(codebook[i, : lengths_Bx[i], :].cpu())
+
+     return audios
+
+ def frame_process(images, **ele):
+     images = [torchvision.transforms.functional.pil_to_tensor(img) for img in images]
+     video = torch.stack(images, dim=0)
+
+     # copy from fetch_video
+     nframes, _, height, width = video.shape
+     min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+     total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+     max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
+     max_pixels_supposed = ele.get("max_pixels", max_pixels)
+     if max_pixels_supposed > max_pixels:
+         print(f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}].")
+     max_pixels = min(max_pixels_supposed, max_pixels)
+     if "resized_height" in ele and "resized_width" in ele:
+         resized_height, resized_width = smart_resize(
+             ele["resized_height"],
+             ele["resized_width"],
+             factor=IMAGE_FACTOR,
+         )
+     else:
+         resized_height, resized_width = smart_resize(
+             height,
+             width,
+             factor=IMAGE_FACTOR,
+             min_pixels=min_pixels,
+             max_pixels=max_pixels,
+         )
+     video = transforms.functional.resize(
+         video,
+         [resized_height, resized_width],
+         interpolation=InterpolationMode.BICUBIC,
+         antialias=True,
+     ).float()
+     return video
+
+ def preprocess_codec(model, codec):
+     """Apply the per-channel delay pattern to a codec prompt: channel c starts with
+     delay_pattern[c] + 1 BOS tokens, then the codec tokens, then one EOS (when it fits)
+     followed by PAD."""
+     codec_token = torch.tensor(codec, dtype=torch.long)
+     codec_token_len = codec_token.shape[0]
+     max_delay_pattern = max(model.config.codec_delay_pattern)
+     codec_input_ids = torch.zeros((codec_token_len + max_delay_pattern + 1, model.num_channels), dtype=torch.long)
+
+     for c in range(model.num_channels):
+         start = model.config.codec_delay_pattern[c] + 1
+         codec_input_ids[:start, c] = model.config.codec_bos_value
+         codec_input_ids[start : start + codec_token_len, c] = codec_token[:, c]
+         codec_input_ids[start + codec_token_len :, c] = model.config.codec_pad_value
+         if start + codec_token_len < codec_input_ids.shape[0]:
+             codec_input_ids[start + codec_token_len, c] = model.config.codec_eos_value
+
+     return codec_input_ids
+
+
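# Worked example of the layout preprocess_codec produces, assuming 3 channels,
# delay_pattern = [0, 1, 2] and a 4-frame codec prompt (these numbers are illustrative):
#
#   step:       0    1    2    3    4    5    6
#   channel 0:  BOS  c0   c1   c2   c3   EOS  PAD
#   channel 1:  BOS  BOS  c0   c1   c2   c3   EOS
#   channel 2:  BOS  BOS  BOS  c0   c1   c2   c3
#
# A quick check with a mock model namespace (the token ids below are placeholders):
from types import SimpleNamespace
import torch

mock_model = SimpleNamespace(
    num_channels=3,
    config=SimpleNamespace(
        codec_delay_pattern=[0, 1, 2],
        codec_bos_value=1025,
        codec_eos_value=1024,
        codec_pad_value=1026,
    ),
)
codes = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]   # 4 frames x 3 channels
print(preprocess_codec(mock_model, codes))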
+ def tts_preprocess(batch_caption, prompt_codec, prompt_text, device):
+     text_input = []
+     codec_input_ids = []
+     for caption in batch_caption:
+         prompt_caption = "<|SPEECH_PROMPT_START|>" + prompt_text + "<|SPEECH_PROMPT_END|>"
+         prompt_caption += "<|VOICE_PROMPT_START|>" + "<|AUDIO_PLACEHOLDER|>" * prompt_codec.shape[0] + "<|VOICE_PROMPT_END|>"
+         prompt_caption_fn = lambda x: prompt_caption + "<|SPEECH_START|>" + x + "<|SPEECH_END|>"
+
+         # three inputs per caption: one without the voice-prompt audio, one without the target
+         # text, and the full conditional input; only the last two contain <|AUDIO_PLACEHOLDER|>
+         # tokens, so the codec prompt is appended twice
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(f"<|SPEECH_PROMPT_START|>{prompt_text}<|SPEECH_PROMPT_END|><|VOICE_PROMPT_START|><|VOICE_PROMPT_END|><|SPEECH_START|>{caption}<|SPEECH_END|>") + AUDIO_START)
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(prompt_caption_fn("")) + AUDIO_START)
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format(prompt_caption_fn(caption)) + AUDIO_START)
+         codec_input_ids.append(prompt_codec.clone())
+         codec_input_ids.append(prompt_codec.clone())
+
+     codec_input_ids = torch.cat(codec_input_ids, dim=0).to(device)
+
+     tts_generation_kwargs = {
+         "codec_input_ids": codec_input_ids,
+         "cfg_scale": [2, 3],
+         "neg_input_size": 3,
+     }
+
+     return text_input, tts_generation_kwargs
+
+
+ def t2m_preprocess(batch_caption):
+     text_input = []
+     for caption in batch_caption:
+         # per caption: a "Low quality." negative prompt and the real caption, used as a CFG pair
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + "Low quality." + "<|MUSIC_END|>") + AUDIO_START)
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + caption + "<|MUSIC_END|>") + AUDIO_START)
+
+     t2m_generation_kwargs = {
+         "cfg_scale": 10,
+         "neg_input_size": 2,
+     }
+
+     return text_input, t2m_generation_kwargs
+
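# The "Low quality." / caption pair above is a standard classifier-free-guidance setup. The
# actual combination happens inside the model's generation loop (not shown in this file);
# a common formulation, shown here only as a sketch, is:
import torch

cfg_scale = 10
neg_logits = torch.randn(1, 1027)   # logits from the "Low quality." branch (dummy vocab size)
pos_logits = torch.randn(1, 1027)   # logits from the caption branch
guided_logits = neg_logits + cfg_scale * (pos_logits - neg_logits)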
+ def v2m_preprocess(batch_caption, batch_video, fps=1):
+
+     def extract_images_from_video(video_path, fps=1, max_frames=1):
+         video = VideoFileClip(video_path)
+         duration = video.duration
+
+         # extract frames at the requested fps, stopping after max_frames
+         images = []
+         for i, t in enumerate(range(0, math.ceil(duration * fps))):
+             time_in_video = t / fps
+             frame = video.get_frame(time_in_video)
+             img = Image.fromarray(frame)
+             images.append(img)
+
+             if max_frames is not None and i >= max_frames - 1:
+                 break
+
+         return images
+
+     text_input = []
+     video_inputs = []
+     fps_inputs = []
+
+     for caption, video in zip(batch_caption, batch_video):
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + "Low quality." + "<|MUSIC_END|>") + AUDIO_START)
+         text_input.append(SYSTEM_MESSAGE + INPUT_FORMAT.format("<|MUSIC_START|>" + caption + "<|MUSIC_END|>") + AUDIO_START)
+
+         video_input = frame_process(
+             extract_images_from_video(video, fps),
+             fps=fps,
+         )
+
+         # the negative and positive branches share the same video frames and fps
+         video_inputs.append(video_input)
+         video_inputs.append(video_input)
+
+         fps_inputs.append(fps)
+         fps_inputs.append(fps)
+
+     t2m_generation_kwargs = {
+         "cfg_scale": 10,
+         "neg_input_size": 2,
+     }
+
+     return text_input, video_inputs, fps_inputs, t2m_generation_kwargs
video_preprocessor_config (1).json ADDED
@@ -0,0 +1,43 @@
+ {
+   "crop_size": null,
+   "data_format": "channels_first",
+   "default_to_square": true,
+   "device": null,
+   "do_center_crop": null,
+   "do_convert_rgb": true,
+   "do_normalize": true,
+   "do_pad": null,
+   "do_rescale": true,
+   "do_resize": true,
+   "do_sample_frames": false,
+   "fps": null,
+   "image_mean": [
+     0.48145466,
+     0.4578275,
+     0.40821073
+   ],
+   "image_std": [
+     0.26862954,
+     0.26130258,
+     0.27577711
+   ],
+   "input_data_format": null,
+   "max_frames": 768,
+   "max_pixels": 12845056,
+   "merge_size": 2,
+   "min_frames": 4,
+   "min_pixels": 3136,
+   "num_frames": null,
+   "patch_size": 14,
+   "processor_class": "Qwen2_5_VLProcessor",
+   "resample": 3,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "longest_edge": 12845056,
+     "shortest_edge": 3136
+   },
+   "size_divisor": null,
+   "temporal_patch_size": 2,
+   "video_metadata": null,
+   "video_processor_type": "Qwen2VLVideoProcessor"
+ }
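# How the patch_size / merge_size / temporal_patch_size fields above relate to visual token
# counts, following the usual Qwen2-VL accounting (this mapping is an assumption; the config
# itself only stores the raw values):
patch_size = 14           # "patch_size"
merge_size = 2            # "merge_size"
temporal_patch_size = 2   # frames are grouped in pairs along time

resized_height, resized_width = 448, 448         # example frame size, a multiple of patch_size * merge_size
grid_h = resized_height // patch_size            # 32 patches vertically
grid_w = resized_width // patch_size             # 32 patches horizontally
tokens_per_frame_pair = (grid_h * grid_w) // (merge_size ** 2)   # 256 tokens per temporal patch
print(tokens_per_frame_pair)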
vocab.json CHANGED
The diff for this file is too large to render. See raw diff