XXXXRT666 committed · Commit d7f22c4 · 1 Parent(s): 8a5b90d

Add CUDA Graph

.gitignore ADDED
@@ -0,0 +1,3 @@
+ .*cache
+ __pycache__
+ pretrained_models
AR/models/structs.py ADDED
@@ -0,0 +1,83 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import List, Literal, Optional
+
+ import torch
+
+ from AR.models.t2s_model_abc import Sampler, T2SDecoderABC
+
+ Tensor = torch.Tensor
+
+
+ @dataclass
+ class T2SResult:
+     result: List[Tensor] | None = None
+     status: Literal["Success", "Error"] = "Success"
+     exception: Optional[Exception] = None
+     traceback: Optional[str] = None
+
+
+ @dataclass
+ class T2SRequest:
+     x: List[torch.Tensor]
+     x_lens: Tensor
+     prompts: torch.Tensor
+     bert_feature: List[Tensor]
+     valid_length: int
+     top_k: int = 5
+     top_p: float = 1
+     early_stop_num: int = -1
+     temperature: float = 1.0
+     repetition_penalty: float = 1.35
+     use_cuda_graph: bool = False
+     debug: bool = False
+
+
+ class T2SSession:
+     def __init__(self, decoder: T2SDecoderABC, request: T2SRequest, device: torch.device, dtype: torch.dtype):
+         with device:
+             self.decoder = decoder
+             self.request = request
+             self.device = device
+             self.dtype = dtype
+
+             bsz = len(request.x)
+             y_len = request.prompts.size(-1)
+             self.bsz = bsz
+             self.y_len = y_len
+
+             # Cache
+             self.kv_cache = decoder.init_cache(bsz)
+             self.sampler = Sampler(bsz, decoder.vocab_size)
+
+             # Forward args
+             self.x = request.x
+             self.x_lens = request.x_lens.to(torch.int32)
+             self.y = request.prompts
+             self.bert_feature = request.bert_feature
+
+             self.prefill_len = self.x_lens + self.y.size(1)
+
+             self.input_pos = torch.zeros_like(self.prefill_len)
+             self.input_pos.add_(self.prefill_len)
+
+             # CUDA Graph
+             self.graph: Optional[torch.cuda.CUDAGraph] = None
+             self.xy_pos_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
+             self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
+
+             # EOS
+             self.completed = [False] * len(self.x)
+             self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore
+
+             self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
+
+             attn_mask = []
+             for bs in range(bsz):
+                 pos = int(self.x_lens[bs].item())
+                 mask = torch.zeros(pos + y_len, pos + y_len).bool()
+                 mask[:, :pos].fill_(True)
+                 mask[-y_len:, -y_len:] = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
+                 attn_mask.append(mask)
+             self.attn_mask_nested = torch.nested.nested_tensor(attn_mask)
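The per-sample attention mask built in T2SSession above gives every position full visibility of the text prefix and a causal pattern over the semantic-token region. A tiny illustrative sketch of that construction (not part of the diff; pos and y_len are toy values):

import torch

pos, y_len = 3, 4  # toy text-prefix length and prompt length, for illustration only
mask = torch.zeros(pos + y_len, pos + y_len).bool()
mask[:, :pos].fill_(True)  # the text prefix is visible from every position
mask[-y_len:, -y_len:] = ~torch.triu(  # causal mask within the y (semantic) region
    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
)
print(mask.int())  # rows for y positions are lower-triangular over the y columns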
AR/models/t2s_model_abc.py ADDED
@@ -0,0 +1,598 @@
+ from __future__ import annotations
+
+ import os
+ from abc import ABC, abstractmethod
+ from contextlib import nullcontext
+ from typing import Any, Dict, List, MutableSequence, Optional, Tuple, Type
+
+ import torch
+ import torch._inductor.config
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.cuda.graphs import CUDAGraph
+ from torch.profiler import ProfilerAction, tensorboard_trace_handler
+
+ from AR.modules.embedding import (
+     SinePositionalEmbeddingNested as SinePositionalEmbedding,
+ )
+ from AR.modules.embedding import TokenEmbedding
+
+ Tensor = torch.Tensor
+
+
+ class Sampler(nn.Module):
+     def __init__(self, batch_size: int, vocab_size: int) -> None:
+         super().__init__()
+         self.batch_size = batch_size
+
+         self.logits: Tensor
+         self.samples: Tensor
+         self.register_buffer("logits", torch.zeros((batch_size, vocab_size)), persistent=False)
+         self.register_buffer("samples", torch.zeros((batch_size,), dtype=torch.int32), persistent=False)
+
+         self.__CUDAGraph: Optional[CUDAGraph] = None
+
+     def empty_cache(self):
+         self.logits.zero_()
+         self.__CUDAGraph = None
+
+     @staticmethod
+     def multinomial_sample_one_no_sync(probs_sort: Tensor):  # Does multinomial sampling without a cuda synchronization
+         q = torch.empty_like(probs_sort).exponential_(1)
+         return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int32)
+
+     @staticmethod
+     def logits_to_probs(
+         logits: Tensor,
+         previous_tokens: Tensor,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+         repetition_penalty: float,
+     ):
+         previous_tokens = previous_tokens.long()
+         score = torch.gather(logits, dim=1, index=previous_tokens)
+         score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+         logits.scatter_(dim=1, index=previous_tokens, src=score)
+
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+         sorted_indices_to_remove = cum_probs > top_p
+         sorted_indices_to_remove[:, 0] = False  # keep at least one option
+         indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+         logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+         logits = logits / max(temperature, 1e-5)
+
+         v, _ = torch.topk(logits, top_k)
+         pivot = v[:, -1].unsqueeze(-1)
+         logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         return probs
+
+     @staticmethod
+     def apply_repetition_penalty(logits: Tensor, previous_tokens: Tensor, repetition_penalty: float):
+         previous_tokens = previous_tokens.long()
+         score = torch.gather(logits, dim=1, index=previous_tokens)
+         score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+         logits.scatter_(dim=1, index=previous_tokens, src=score)
+         return logits
+
+     @staticmethod
+     def logits_to_probs_cuda_graph(
+         logits: Tensor,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+     ):
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+         sorted_indices_to_remove = cum_probs > top_p
+         sorted_indices_to_remove[:, 0] = False  # keep at least one option
+         indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+         logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+         logits = logits / max(temperature, 1e-5)
+
+         v, _ = torch.topk(logits, top_k)
+         pivot = v[:, -1].unsqueeze(-1)
+         logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         return probs
+
+     def __sample(
+         self,
+         logits: Tensor,
+         previous_tokens: Tensor,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+         repetition_penalty: float,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         probs = self.logits_to_probs(
+             logits=logits,
+             previous_tokens=previous_tokens,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             repetition_penalty=repetition_penalty,
+         )
+         idx_next = self.multinomial_sample_one_no_sync(probs)
+         return idx_next, probs
+
+     def __sample_cuda_graph(
+         self,
+         logits: Tensor,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+     ):
+         probs = self.logits_to_probs_cuda_graph(
+             logits=logits,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+         )
+         idx_next = self.multinomial_sample_one_no_sync(probs)
+         return idx_next
+
+     def capture(self, temperature: float, top_k: int, top_p: float):
+         s = torch.cuda.Stream()
+         s.wait_stream(torch.cuda.current_stream())
+
+         logits = self.logits
+
+         with torch.cuda.stream(s):  # type: ignore
+             for _ in range(5):
+                 self.__sample_cuda_graph(logits, temperature, top_k, top_p)
+         torch.cuda.current_stream().wait_stream(s)
+
+         self.__CUDAGraph = torch.cuda.CUDAGraph()
+         with torch.cuda.graph(self.__CUDAGraph):
+             self.samples = self.__sample_cuda_graph(logits, temperature, top_k, top_p)
+         torch.cuda.synchronize()
+
+     def sample(
+         self,
+         logits: Tensor,
+         previous_tokens: Tensor,
+         temperature: float,
+         top_k: int,
+         top_p: float,
+         repetition_penalty: float,
+         use_cuda_graph=False,
+         idx=-1,
+     ) -> Tensor:
+         if use_cuda_graph and torch.cuda.is_available() and self.__CUDAGraph is None and idx > 0:
+             self.logits.copy_(logits)
+             self.capture(temperature, top_k, top_p)
+         if self.__CUDAGraph is not None:
+             self.logits.copy_(logits)
+             self.apply_repetition_penalty(self.logits, previous_tokens, repetition_penalty)
+             self.__CUDAGraph.replay()
+             samples = self.samples.clone()
+         else:
+             samples = self.__sample(logits, previous_tokens, temperature, top_k, top_p, repetition_penalty)[0]
+
+         return samples
+
+
+ class KVCacheABC(ABC, nn.Module):
+     def __init__(self, *args, **kwds) -> None:
+         super().__init__()
+         self.k_cache: Tensor
+         self.v_cache: Tensor
+         self.n_head: int
+         self.head_dim: int
+         self.batch_size: int
+         self.max_seq_length: int
+
+     def empty(self):
+         self.k_cache.zero_()
+         self.v_cache.zero_()
+
+     @abstractmethod
+     def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> Tuple[Tensor, Tensor]: ...
+
+     @abstractmethod
+     def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int) -> None: ...
+
+     def forward(self):
+         raise NotImplementedError()
+
+
+ class KVCacheNHD(KVCacheABC):
+     def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
+         super().__init__()
+         assert batch_size > 0
+         cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
+         self.n_head = n_heads
+         self.head_dim = head_dim
+         self.batch_size = batch_size
+         self.max_seq_length = max_seq_length
+
+         self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
+         self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
+
+     def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
+         # input_pos: [B, ], k_val: [B, 1, H, D]
+
+         index = (
+             (input_pos - 1)
+             .unsqueeze(-1)
+             .unsqueeze(-1)
+             .unsqueeze(-1)
+             .expand(
+                 -1,
+                 -1,
+                 self.n_head,
+                 self.head_dim,
+             )
+             .to(torch.int64)
+         )  # (bs, 1, num_head, head_dim)
+
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out.scatter_(1, index, k_val)
+         v_out.scatter_(1, index, v_val)
+
+         return k_out, v_out
+
+     def empty(self):
+         self.k_cache.zero_()
+         self.v_cache.zero_()
+
+     def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
+         # input_pos: int, k_val: [B, S, H, D]
+
+         self.k_cache[[bs], : k_val.shape[1]] = k_val
+         self.v_cache[[bs], : v_val.shape[1]] = v_val
+
+
+ class KVCacheHND(KVCacheABC):
+     def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
+         super().__init__()
+         assert batch_size > 0
+         cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
+         self.n_head = n_heads
+         self.head_dim = head_dim
+         self.batch_size = batch_size
+         self.max_seq_length = max_seq_length
+
+         self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
+         self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
+
+     def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
+         # input_pos: [B, ], k_val: [B, H, 1, D]
+
+         index = (
+             (input_pos - 1)
+             .unsqueeze(-1)
+             .unsqueeze(-1)
+             .unsqueeze(-1)
+             .expand(
+                 -1,
+                 self.n_head,
+                 -1,
+                 self.head_dim,
+             )
+             .to(torch.int64)
+         )  # (bs, num_head, 1, head_dim)
+
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out.scatter_(2, index, k_val)
+         v_out.scatter_(2, index, v_val)
+
+         return k_out, v_out
+
+     def empty(self):
+         self.k_cache.zero_()
+         self.v_cache.zero_()
+
+     def prefill_kv(self, k_val: Tensor, v_val: Tensor, bs: int):
+         # input_pos: int, k_val: [B, S, H, D]
+
+         self.k_cache[[bs], :, : k_val.shape[1]] = k_val.transpose(1, 2)
+         self.v_cache[[bs], :, : v_val.shape[1]] = v_val.transpose(1, 2)
+
+
+ class AttentionABC(ABC, nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.n_head: int
+         self.hidden_dim: int
+         self.head_dim: int
+
+         # key, query, value projections for all heads, but in a batch
+         self.in_proj: nn.Linear
+         self.out_proj: nn.Linear
+
+         self.dropout = nn.Dropout(0.1)
+
+         self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def load_hook(self, state_dict: dict, prefix, *args):
+         keys_to_modify = [key for key in state_dict if "in_proj_" in key]
+         for key in keys_to_modify:
+             new_key = key.replace("in_proj_", "in_proj.")  # in_proj_ -> in_proj.
+             state_dict[new_key] = state_dict.pop(key)
+
+     @abstractmethod
+     def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor: ...
+
+     def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor:
+         bsz = x.size(0)
+
+         outputs = []
+
+         for bs in range(bsz):
+             x_b = x[bs].unsqueeze(0)
+
+             q, k, v = self.in_proj.forward(x_b.unsqueeze(0)).chunk(3, dim=-1)
+
+             q = q.contiguous().view(1, -1, self.n_head, self.head_dim)
+             k = k.contiguous().view(1, -1, self.n_head, self.head_dim)
+             v = v.contiguous().view(1, -1, self.n_head, self.head_dim)
+
+             kv_cache.prefill_kv(k, v, bs)
+
+             q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+             attn_mask = mask[bs].unsqueeze(0).unsqueeze(0).expand(1, self.n_head, -1, -1)
+
+             attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+
+             attn = self.dropout.forward(attn)
+
+             attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
+
+             output = self.out_proj.forward(attn)
+
+             outputs.append(output.squeeze(0))
+
+         return torch.nested.nested_tensor(outputs)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim: int, hidden_dim: int) -> None:
+         super().__init__()
+         self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
+         self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.dropout.forward(self.linear2(self.dropout.forward(F.relu(self.linear1(x)))))
+
+
+ class TransformerBlockABC(ABC, nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+         self.hidden_dim: int
+         self.attention: AttentionABC
+         self.feed_forward: FeedForward
+         self.attention_norm: nn.LayerNorm
+         self.ffn_norm: nn.LayerNorm
+         self.dropout = nn.Dropout(0.1)
+
+         self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
+         for key in list(state_dict.keys()):
+             new_key = (
+                 key.replace("self_attn", "attention")
+                 .replace("linear", "feed_forward.linear")
+                 .replace("norm1", "attention_norm")
+                 .replace("norm2", "ffn_norm")
+             )
+             state_dict[new_key] = state_dict.pop(key)
+
+     def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
+         h = self.attention_norm.forward(
+             x
+             + self.dropout.forward(
+                 self.attention.forward(
+                     x,
+                     input_pos,
+                     kv_cache,
+                     *args,
+                     **kwds,
+                 )
+             )
+         )
+         out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
+         return out
+
+     def prefill(self, x: Tensor, mask: Tensor, kv_cache: KVCacheABC) -> Tensor:
+         h = self.attention_norm.forward(
+             x
+             + self.dropout.forward(
+                 self.attention.prefill(
+                     x,
+                     mask,
+                     kv_cache,
+                 )
+             )
+         )
+         out = self.ffn_norm.forward(h + self.feed_forward.forward(h))
+         return out
+
+
+ class TransformerDecoderABC(ABC, nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+         self.hidden_dim: int
+         self.n_head: int
+         self.head_dim: int
+         self.vocab_size: int
+         self.n_layer: int
+
+         self.layers: MutableSequence[TransformerBlockABC]
+
+         self.max_seq_length: int
+         self.max_batch_size: int
+
+         self.input_pos: Tensor
+         self.xy_pos: Tensor
+         self.xy_dec: Tensor
+
+     def forward(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheABC], *args, **kwds):
+         for layer, kv_cache in zip(self.layers, kv_caches):
+             x = layer.forward(x, input_pos, kv_cache, *args, **kwds)
+         return x
+
+     def prefill(self, x: Tensor, mask: Tensor, kv_caches: MutableSequence[KVCacheABC]):
+         for layer, kv_cache in zip(self.layers, kv_caches):
+             x = layer.prefill(x, mask, kv_cache)
+         return x
+
+
+ class T2SDecoderABC(ABC, nn.Module):
+     def __init__(self) -> None:
+         super().__init__()
+
+         self.n_layer: int
+         self.hidden_dim: int
+         self.n_head: int
+
+         self.head_dim: int
+         self.embedding_dim: int
+         self.vocab_size: int
+         self.phoneme_vocab_size: int
+         self.p_dropout: float
+         self.max_seq_length: int
+         self.max_batch_size: int
+         self.EOS: int
+
+         self.bert_proj: nn.Linear
+         self.ar_text_embedding: TokenEmbedding
+         self.ar_text_position: SinePositionalEmbedding
+         self.ar_audio_embedding: TokenEmbedding
+         self.ar_audio_position: SinePositionalEmbedding
+         self.ar_predict_layer: nn.Linear
+         self.h: TransformerDecoderABC
+
+         self.kv_class: Type[KVCacheNHD] | Type[KVCacheHND]
+
+         self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def load_hook(self, state_dict, prefix, *args):
+         model_keys = [key for key in state_dict if key.startswith("model.")]
+         for key in model_keys:
+             new_key = key[len("model.") :]
+             state_dict[new_key] = state_dict.pop(key)
+
+     def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheABC]:
+         bsz = bsz or self.h.max_batch_size
+         assert bsz <= self.h.max_batch_size
+         seq_lens = self.h.max_seq_length
+         device = self.bert_proj.bias.device
+         dtype = self.bert_proj.bias.dtype
+         kvclass = self.kv_class
+         return nn.ModuleList(
+             [kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
+         ).to(device, dtype)  # type: ignore
+
+     @abstractmethod
+     def embed(self, x: List[torch.Tensor], y: torch.Tensor, bert_features: List[Tensor]) -> Tensor: ...
+
+     def compile(self, *args, **kwds):
+         torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
+         torch._inductor.config.coordinate_descent_tuning = True
+         torch._inductor.config.triton.unique_kernel_names = True
+         # Experimental features to reduce compilation times, will be on by default in future
+         torch._inductor.config.fx_graph_cache = True
+         torch._inductor.config.triton.cudagraph_trees = True
+         torch._inductor.config.triton.cudagraph_support_input_mutation = True
+         self.h.compile(fullgraph=True, mode="reduce-overhead")
+
+     def capture(self, input_pos: Tensor, x: Tensor, x_dec: Tensor, *args, **kwds) -> CUDAGraph:
+         s = torch.cuda.Stream()
+         s.wait_stream(torch.cuda.current_stream())
+
+         graph = torch.cuda.CUDAGraph()
+
+         with torch.cuda.stream(s):  # type: ignore
+             for _ in range(5):
+                 self.h.forward(input_pos, x, *args, **kwds)
+         torch.cuda.current_stream().wait_stream(s)
+
+         with torch.cuda.graph(graph):
+             x_dec.copy_(self.h.forward(input_pos, x, *args, **kwds))
+         torch.cuda.synchronize()
+
+         return graph
+
+     @abstractmethod
+     def pre_forward(self, session: Any) -> Tuple[List, Dict]: ...
+
+     @abstractmethod
+     def post_forward(self, idx: int, session: Any) -> None: ...
+
+
+ class TorchProfiler:
+     def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
+         self.debug = debug
+         self.log_dir = log_dir
+         self.__profiler: torch.profiler.profile
+
+         if self.debug and not os.path.exists(self.log_dir):
+             os.makedirs(self.log_dir)
+
+         self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
+
+     def profiler_callback(self, prof: torch.profiler.profile):
+         print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
+         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
+         self.tensorboard_handler(prof)
+
+     @staticmethod
+     def three_step_schedule(step: int) -> ProfilerAction:
+         if step == 0:
+             return ProfilerAction.NONE
+         elif step == 1:
+             return ProfilerAction.RECORD
+         elif step == 2:
+             return ProfilerAction.RECORD_AND_SAVE
+         else:
+             return ProfilerAction.NONE
+
+     def start(self):
+         if not self.debug:
+             return
+         assert self.__profiler is not None
+         self.__profiler.step()
+
+     def end(self):
+         if not self.debug:
+             return
+         assert self.__profiler is not None
+         self.__profiler.step()
+
+     def profiler(self):
+         if self.debug:
+             activities_list = [torch.profiler.ProfilerActivity.CPU]
+             if torch.cuda.is_available():
+                 activities_list.append(torch.profiler.ProfilerActivity.CUDA)
+
+             self.__profiler = torch.profiler.profile(
+                 activities=activities_list,
+                 record_shapes=True,
+                 with_stack=True,
+                 with_modules=True,
+                 profile_memory=True,
+                 schedule=self.three_step_schedule,
+                 on_trace_ready=self.profiler_callback,
+             )
+             return self.__profiler
+         else:
+             return nullcontext()
+
+     def record(self, func_name: str):
+         if self.debug:
+             return torch.profiler.record_function(func_name)
+         else:
+             return nullcontext()
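Sampler.capture and T2SDecoderABC.capture above both follow the standard CUDA Graph recipe: warm up on a side stream, record one step against static buffers, then replay with fresh data copied into those buffers. A minimal standalone sketch of that pattern (capture_once, step_fn and the buffer names are illustrative, not part of the diff):

import torch

def capture_once(step_fn, static_in: torch.Tensor, static_out: torch.Tensor) -> torch.cuda.CUDAGraph:
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):  # warm-up iterations are executed but not recorded
            step_fn(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # kernels launched inside this context are recorded
        static_out.copy_(step_fn(static_in))
    return graph

# Each later step: static_in.copy_(new_input); graph.replay(); out = static_out.clone()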
AR/models/t2s_model_flash_attn.py ADDED
@@ -0,0 +1,357 @@
+ import gc
+ import os
+ import time
+ import traceback
+ from typing import Dict, List, Tuple
+
+ import flash_attn  # type: ignore
+ import torch
+ import torch.nn as nn
+ from tqdm import tqdm
+
+ from AR.models.structs import T2SRequest, T2SResult, T2SSession
+ from AR.models.t2s_model_abc import (
+     AttentionABC,
+     FeedForward,
+     KVCacheABC,
+     KVCacheNHD,
+     T2SDecoderABC,
+     TorchProfiler,
+     TransformerBlockABC,
+     TransformerDecoderABC,
+ )
+ from AR.modules.embedding import (
+     SinePositionalEmbeddingNested as SinePositionalEmbedding,
+ )
+ from AR.modules.embedding import TokenEmbedding
+
+ Tensor = torch.Tensor
+
+
+ class Attention(AttentionABC):
+     def __init__(self, n_head: int, hidden_dim: int):
+         super().__init__()
+         self.n_head = n_head
+         self.hidden_dim = hidden_dim
+         assert hidden_dim % n_head == 0
+         self.head_dim = hidden_dim // n_head
+
+         self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
+         self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+     def forward(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheABC, *args, **kwds) -> Tensor:
+         bsz, seqlen, _ = x.shape
+
+         q, k, v = self.in_proj.forward(x).chunk(3, dim=-1)
+
+         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+         k = k.view(bsz, seqlen, self.n_head, self.head_dim)
+         v = v.view(bsz, seqlen, self.n_head, self.head_dim)
+
+         attn: Tensor = flash_attn.flash_attn_with_kvcache(
+             q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
+         )
+
+         attn = self.dropout.forward(attn)
+
+         attn = attn.view(bsz, seqlen, self.hidden_dim)
+
+         attn = self.out_proj.forward(attn)
+
+         return attn
+
+
+ class TransformerBlock(TransformerBlockABC):
+     def __init__(self, n_head, ffn_dim, hidden_dim) -> None:
+         super().__init__()
+         self.hidden_dim = hidden_dim
+         self.attention = Attention(n_head, hidden_dim)
+         self.feed_forward = FeedForward(hidden_dim, ffn_dim)
+         self.attention_norm = nn.LayerNorm([self.hidden_dim])
+         self.ffn_norm = nn.LayerNorm([self.hidden_dim])
+
+
+ class TransformerDecoder(TransformerDecoderABC):
+     def __init__(
+         self,
+         hidden_dim,
+         n_layer,
+         n_head,
+         ffn_dim,
+         vocab_size,
+         max_seq_length,
+         max_batch_size,
+     ) -> None:
+         super().__init__()
+
+         self.hidden_dim = hidden_dim
+         self.n_head = n_head
+         assert hidden_dim % n_head == 0
+
+         self.head_dim = hidden_dim // n_head
+         self.vocab_size = vocab_size
+
+         self.n_layer = n_layer
+
+         self.layers = nn.ModuleList(  # type: ignore
+             TransformerBlock(n_head, ffn_dim, hidden_dim) for _ in range(n_layer)
+         )
+
+         self.max_seq_length: int = max_seq_length
+         self.max_batch_size: int = max_batch_size
+
+         self.setup_caches(self.max_batch_size, self.max_seq_length)
+
+     def setup_caches(self, max_batch_size=10, max_seq_length=2500):
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+
+
+ class T2SDecoder(T2SDecoderABC):
+     def __init__(
+         self,
+         config,
+         *args,
+         norm_first=False,
+         max_seq_length=2500,
+         max_batch_size=10,
+         **kwds,
+     ) -> None:
+         super().__init__()
+
+         hidden_dim = config["model"]["hidden_dim"]
+         embedding_dim = config["model"]["embedding_dim"]
+         n_head = config["model"]["head"]
+         n_layer = config["model"]["n_layer"]
+         vocab_size = config["model"]["vocab_size"]
+         phoneme_vocab_size = config["model"]["phoneme_vocab_size"]
+         p_dropout = config["model"]["dropout"]
+         EOS = config["model"]["EOS"]
+         ffn_dim = hidden_dim * 4
+         self.norm_first = norm_first
+
+         self.n_layer = n_layer
+         self.hidden_dim = hidden_dim
+         self.n_head = n_head
+         assert hidden_dim % n_head == 0
+
+         self.head_dim = hidden_dim // n_head
+         self.embedding_dim = embedding_dim
+         self.vocab_size = vocab_size
+         self.phoneme_vocab_size = phoneme_vocab_size
+         self.p_dropout = p_dropout
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         self.EOS = EOS
+         assert self.EOS == self.vocab_size - 1
+
+         self.bert_proj = nn.Linear(1024, self.embedding_dim)
+         self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
+         self.ar_text_position = SinePositionalEmbedding(
+             self.embedding_dim,
+             dropout=0.1,
+             scale=False,
+             alpha=True,
+             max_batch_size=max_batch_size,
+             max_seq_len=max_seq_length,
+         )
+         self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
+         self.ar_audio_position = SinePositionalEmbedding(
+             self.embedding_dim,
+             dropout=0.1,
+             scale=False,
+             alpha=True,
+             max_batch_size=max_batch_size,
+             max_seq_len=max_seq_length,
+         )
+         self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
+         self.h: TransformerDecoderABC = TransformerDecoder(
+             hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size
+         )
+
+         self.kv_class = KVCacheNHD
+         self._register_load_state_dict_pre_hook(self.load_hook)
+
+     def embed(
+         self,
+         x: List[torch.Tensor],
+         y: torch.Tensor,
+         bert_features: List[torch.Tensor],
+     ):
+         x_nested = torch.nested.nested_tensor(x)
+         assert x_nested.size(0) <= self.max_batch_size
+         bert_features_nested = torch.nested.nested_tensor(list(map(lambda x: x.transpose(0, 1), bert_features)))
+
+         x_emb = self.ar_text_embedding.forward(x_nested)
+         bert = self.bert_proj.forward(bert_features_nested)
+         x_emb = x_emb + bert
+         x_pos = self.ar_text_position.prefill(x_emb)
+
+         y_nested = torch.nested.nested_tensor(list(y.unbind(0)))
+         y_emb = self.ar_audio_embedding.forward(y_nested)
+         y_pos = self.ar_audio_position.prefill(y_emb)
+
+         xy_pos = torch.nested.nested_tensor([torch.cat([x_pos[i], y_pos[i]]) for i in range(len(x))])
+         return xy_pos
+
+     def post_forward(self, idx: int, session: T2SSession) -> None:
+         pass
+
+     def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
+         return list(), dict()
+
+
+ class CUDAGraphRunner:
+     def __init__(
+         self,
+         decoder_model: T2SDecoderABC,
+         device: torch.device = torch.device("cpu"),
+         dtype: torch.dtype = torch.float32,
+     ) -> None:
+         assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
+         assert dtype in {torch.float16, torch.bfloat16, torch.float32}
+         self.device = device
+         self.dtype = dtype
+
+         self.decoder_path: os.PathLike
+         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
+
+     def _handle_request(self, request: T2SRequest) -> List[torch.Tensor]:
+         with self.device:
+             decoder = self.decoder_model
+             session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
+
+             y = session.y
+             bsz = y.size(0)
+             t1 = 0.0
+
+             torch_profiler = TorchProfiler(request.debug)
+
+             with torch_profiler.profiler():
+                 for idx in tqdm(range(1500)):
+                     if idx == 0:
+                         xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
+                         xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
+                     else:
+                         if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
+                             session.xy_pos_.copy_(session.xy_pos)
+                             args, kwds = decoder.pre_forward(session)
+                             session.graph = decoder.capture(
+                                 session.input_pos,
+                                 session.xy_pos_,
+                                 session.xy_dec_,
+                                 kv_caches=session.kv_cache,
+                                 *args,
+                                 **kwds,
+                             )
+
+                         torch_profiler.start()
+                         with torch_profiler.record("AR"):
+                             if session.graph:
+                                 session.xy_pos_.copy_(session.xy_pos)
+                                 session.graph.replay()
+                                 xy_dec = session.xy_dec_.clone()
+                             else:
+                                 args, kwds = decoder.pre_forward(session)
+                                 xy_dec = decoder.h.forward(
+                                     session.input_pos,
+                                     session.xy_pos,
+                                     session.kv_cache,
+                                     *args,
+                                     **kwds,
+                                 )
+                     decoder.post_forward(idx, session)
+                     logits = decoder.ar_predict_layer(xy_dec[:, -1])
+                     session.input_pos.add_(1)
+
+                     if idx == 0:
+                         logits = logits[:, :-1]
+
+                     with torch_profiler.record("Sampling"):
+                         samples = session.sampler.sample(
+                             logits=logits,
+                             previous_tokens=session.y,
+                             top_k=request.top_k,
+                             top_p=request.top_p,
+                             repetition_penalty=request.repetition_penalty,
+                             temperature=request.temperature,
+                             use_cuda_graph=False,
+                             idx=idx,
+                         )
+
+                     session.y = torch.cat([session.y, samples], dim=1)
+
+                     with torch_profiler.record("EOS"):
+                         EOS_mask = (samples[:, 0] == decoder.EOS) | (torch.argmax(logits, dim=-1) == decoder.EOS)
+                         EOS_indices: List[int] = torch.where(EOS_mask)[0].tolist()
+
+                     for i in EOS_indices:
+                         if not session.completed[i]:
+                             session.y_results[i] = session.y[i, session.y_len : -1]
+                             session.completed[i] = True
+
+                     if all(session.completed):
+                         if session.y.size(1) == 0:
+                             session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
+                             tqdm.write("Bad Zero Prediction")
+                         else:
+                             tqdm.write(
+                                 f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
+                             )
+                             tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                         break
+
+                     if (
+                         request.early_stop_num != -1
+                         and (session.y.size(1) - session.y_len) > request.early_stop_num
+                     ):
+                         for i in range(bsz):
+                             if not session.completed[i]:
+                                 session.y_results[i] = session.y[i, session.y_len :]
+                                 session.completed[i] = True
+                         break
+
+                     with torch_profiler.record("NextPos"):
+                         y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
+                         session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
+
+                     if idx == 2:
+                         t1 = time.perf_counter()
+
+                     if idx == 51:
+                         torch_profiler.end()
+
+             match session.device.type:
+                 case "cuda":
+                     torch.cuda.empty_cache()
+                 case "mps":
+                     torch.mps.empty_cache()
+                 case "xpu":
+                     torch.xpu.empty_cache()
+                 case "mtia":
+                     torch.mtia.empty_cache()
+             gc.collect()
+
+             return session.y_results[: request.valid_length]
+
+     def generate(self, request: T2SRequest):
+         try:
+             result = self._handle_request(request)
+             t2s_result = T2SResult(result=result, status="Success")
+         except Exception as e:
+             t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
+         return t2s_result
+
+     @staticmethod
+     def load_decoder(weights_path: os.PathLike, implement: str = "flash_attn"):
+         print(f"Loading Text2Semantic Weights from {weights_path} with {implement.replace('_', ' ').title()} Implement")
+         module_path = f"AR.models.t2s_model_{implement.lower()}"
+         cls_name = "T2SDecoder"
+         mod = __import__(module_path, fromlist=[cls_name])
+         decoder_cls: T2SDecoderABC = getattr(mod, cls_name)
+         dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
+         config = dict_s1["config"]
+         decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=1)
+         state_dict = dict_s1["weight"]
+         decoder.load_state_dict(state_dict)
+         return decoder.eval()
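For context, a hedged end-to-end sketch of how the new runner might be driven (not part of the diff; the dummy tensors below only stand in for phoneme ids, prompt semantic tokens and BERT features prepared by the rest of the pipeline):

import torch
from AR.models.structs import T2SRequest
from AR.models.t2s_model_flash_attn import CUDAGraphRunner

decoder = CUDAGraphRunner.load_decoder("pretrained_models/s1v3.ckpt")
runner = CUDAGraphRunner(decoder, device=torch.device("cuda"), dtype=torch.float16)

phoneme_ids = torch.randint(0, 100, (50,))        # placeholder phoneme sequence
prompt_tokens = torch.randint(0, 1024, (1, 150))  # placeholder prompt semantic tokens
bert_feats = torch.zeros(1024, 50)                # placeholder BERT features (1024 x seq_len)

request = T2SRequest(
    x=[phoneme_ids],
    x_lens=torch.tensor([phoneme_ids.size(0)]),
    prompts=prompt_tokens,
    bert_feature=[bert_feats],
    valid_length=1,
    top_k=15,
    use_cuda_graph=True,
)
result = runner.generate(request)
if result.status == "Success":
    semantic_tokens = result.result[0]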
AR/modules/embedding.py CHANGED
@@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
                 return
         pe = torch.zeros(x.size(1), self.embedding_dim)
         if self.reverse:
-             position = torch.arange(
-                 x.size(1) - 1, -1, -1.0, dtype=torch.float32
-             ).unsqueeze(1)
+             position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
         else:
             position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
-             torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
-             * -(math.log(10000.0) / self.embedding_dim)
+             torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
         )
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
@@ -79,3 +76,68 @@
         output = x.unsqueeze(-1) if x.ndim == 2 else x
         output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
         return self.dropout(output)
+
+
+ class SinePositionalEmbeddingNested(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         dropout: float = 0.0,
+         scale: bool = False,
+         alpha: bool = False,
+         max_batch_size: int = 20,
+         max_seq_len: int = 2500,
+     ):
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
+         self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
+         self.dropout = torch.nn.Dropout(p=dropout)
+         self.max_batch_size = max_batch_size
+         self.max_seq_len = max_seq_len
+
+         self.reverse = False
+         self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
+         self.pe: torch.Tensor
+         self.compute_pe()
+
+     def compute_pe(self):
+         """Reset the positional encodings."""
+         if self.reverse:
+             position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
+         else:
+             position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
+         )
+         pe = self.pe
+         pe[:, :, 0::2] = torch.sin(position * div_term)
+         pe[:, :, 1::2] = torch.cos(position * div_term)
+
+     def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             input_pos (Tensor): [batch_size, ]
+             x (Tensor): [batch_size, 1, embed_dim]
+
+         Returns:
+             embedded_x (Tensor): [batch_size, 1, embed_dim]
+         """
+
+         batch_size = x.shape[0]
+         pe_values = self.pe[torch.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)
+
+         return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)  # (batch_size, 1, embed_dim)
+
+     def prefill(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Args:
+             x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]
+
+         Returns:
+             embedded_x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]
+         """
+
+         input_pos: torch.Tensor = torch.tensor([i.shape[0] for i in x.unbind()])
+         pe_values = torch.nested.nested_tensor([self.pe[i, : input_pos[i], :] for i in range(input_pos.size(0))])
+         return x * self.x_scale + self.alpha.item() * pe_values
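During incremental decoding, SinePositionalEmbeddingNested.forward simply gathers the precomputed encoding at position input_pos - 1 for each batch entry. A small sketch of that lookup (toy shapes, not part of the diff):

import torch

max_batch, max_len, dim = 4, 8, 6
pe = torch.zeros(max_batch, max_len, dim)      # stands in for the registered "pe" buffer
input_pos = torch.tensor([3, 5, 1, 8])         # tokens emitted so far, per sample
pe_step = pe[torch.arange(max_batch), input_pos - 1]  # (batch, dim): encoding of the latest position
x_step = torch.randn(max_batch, 1, dim)
out = x_step * 1.0 + 1.0 * pe_step.unsqueeze(1)  # x_scale is 1.0 when scale=False; alpha shown at its initial value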
inference_webui.py CHANGED
@@ -1,12 +1,35 @@
1
  import os
2
- os.makedirs("pretrained_models",exist_ok=True)
 
3
  from huggingface_hub import snapshot_download
4
- snapshot_download(repo_id="lj1995/GPT-SoVITS",repo_type="model",allow_patterns="chinese*",local_dir="pretrained_models",)
5
- snapshot_download(repo_id="lj1995/GPT-SoVITS",repo_type="model",allow_patterns="s1v3.ckpt",local_dir="pretrained_models",)
6
- snapshot_download(repo_id="lj1995/GPT-SoVITS",repo_type="model",allow_patterns="sv*",local_dir="pretrained_models",)
7
8
  import logging
9
  import traceback
 
10
  logging.getLogger("markdown_it").setLevel(logging.ERROR)
11
  logging.getLogger("urllib3").setLevel(logging.ERROR)
12
  logging.getLogger("httpcore").setLevel(logging.ERROR)
@@ -17,42 +40,47 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
17
  logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
18
  logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
19
  logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
 
 
 
 
20
  from text.LangSegmenter import LangSegmenter
21
- import gradio.analytics as analytics
22
- analytics.version_check = lambda:None
23
- analytics.get_local_ip_address= lambda :"127.0.0.1"##不干掉本地联不通亚马逊的get_local_ip服务器
24
- import nltk,torchaudio
25
- nltk.download('averaged_perceptron_tagger_eng')
26
- import LangSegment, os, re, sys, json
27
  import pdb
 
 
 
 
28
  import spaces
29
  import torch
30
 
31
- version="v2"#os.environ.get("version","v2")
32
- cnhubert_base_path = os.environ.get(
33
- "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
34
- )
35
- bert_path = os.environ.get(
36
- "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
37
- )
38
 
39
- punctuation = set(['!', '?', '', ',', '.', '-'," "])
40
  import gradio as gr
41
- from transformers import AutoModelForMaskedLM, AutoTokenizer
42
- import numpy as np
43
  import librosa
 
 
 
44
  from feature_extractor import cnhubert
45
 
46
  cnhubert.cnhubert_base_path = cnhubert_base_path
47
 
 
 
 
 
 
48
  from module.models import SynthesizerTrn
49
- from AR.models.t2s_lightning_module import Text2SemanticLightningModule
50
  from text import cleaned_text_to_sequence
51
  from text.cleaner import clean_text
52
- from time import time as ttime
53
- from module.mel_processing import spectrogram_torch
54
- from tools.my_utils import load_audio
55
  from tools.i18n.i18n import I18nAuto, scan_language_list
 
56
 
57
  # language=os.environ.get("language","Auto")
58
  # language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@@ -65,30 +93,30 @@ if torch.cuda.is_available():
65
  is_half = True # eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
66
  else:
67
  device = "cpu"
68
- is_half=False
69
 
70
  dict_language_v1 = {
71
- i18n("中文"): "all_zh",#全部按中文识别
72
- i18n("英文"): "en",#全部按英文识别#######不变
73
- i18n("日文"): "all_ja",#全部按日文识别
74
- i18n("中英混合"): "zh",#按中英混合识别####不变
75
- i18n("日英混合"): "ja",#按日英混合识别####不变
76
- i18n("多语种混合"): "auto",#多语种启动切分识别语种
77
  }
78
  dict_language_v2 = {
79
- i18n("中文"): "all_zh",#全部按中文识别
80
- i18n("英文"): "en",#全部按英文识别#######不变
81
- i18n("日文"): "all_ja",#全部按日文识别
82
- i18n("粤语"): "all_yue",#全部按中文识别
83
- i18n("韩文"): "all_ko",#全部按韩文识别
84
- i18n("中英混合"): "zh",#按中英混合识别####不变
85
- i18n("日英混合"): "ja",#按日英混合识别####不变
86
- i18n("粤英混合"): "yue",#按粤英混合识别####不变
87
- i18n("韩英混合"): "ko",#按韩英混合识别####不变
88
- i18n("多语种混合"): "auto",#多语种启动切分识别语种
89
- i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
90
  }
91
- dict_language = dict_language_v1 if version =='v1' else dict_language_v2
92
 
93
  tokenizer = AutoTokenizer.from_pretrained(bert_path)
94
  bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
@@ -149,13 +177,13 @@ else:
149
  ssl_model = ssl_model.to(device)
150
 
151
 
152
- def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
153
  global vq_model, hps, version, dict_language
154
  dict_s2 = torch.load(sovits_path, map_location="cpu")
155
  hps = dict_s2["config"]
156
  hps = DictToAttrRecursive(hps)
157
  hps.model.semantic_frame_rate = "25hz"
158
- if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
159
  hps.model.version = "v1"
160
  else:
161
  hps.model.version = "v2"
@@ -165,9 +193,9 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
165
  hps.data.filter_length // 2 + 1,
166
  hps.train.segment_size // hps.data.hop_length,
167
  n_speakers=hps.data.n_speakers,
168
- **hps.model
169
  )
170
- if ("pretrained" not in sovits_path):
171
  del vq_model.enc_q
172
  if is_half == True:
173
  vq_model = vq_model.half().to(device)
@@ -175,43 +203,48 @@ def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
175
  vq_model = vq_model.to(device)
176
  vq_model.eval()
177
  print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
178
- dict_language = dict_language_v1 if version =='v1' else dict_language_v2
179
  if prompt_language is not None and text_language is not None:
180
  if prompt_language in list(dict_language.keys()):
181
- prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
 
 
 
182
  else:
183
- prompt_text_update = {'__type__':'update', 'value':''}
184
- prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
185
  if text_language in list(dict_language.keys()):
186
- text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
187
  else:
188
- text_update = {'__type__':'update', 'value':''}
189
- text_language_update = {'__type__':'update', 'value':i18n("中文")}
190
- return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
191
-
 
 
 
 
 
 
192
 
193
 
194
  change_sovits_weights("pretrained_models/v2Pro/s2Gv2ProPlus.pth")
195
 
196
 
197
  def change_gpt_weights(gpt_path):
198
- global hz, max_sec, t2s_model, config
199
- hz = 50
200
  dict_s1 = torch.load(gpt_path, map_location="cpu")
201
  config = dict_s1["config"]
202
- max_sec = config["data"]["max_sec"]
203
- t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
204
- t2s_model.load_state_dict(dict_s1["weight"])
205
- if is_half == True:
206
- t2s_model = t2s_model.half()
207
- t2s_model = t2s_model.to(device)
208
- t2s_model.eval()
209
- total = sum([param.nelement() for param in t2s_model.parameters()])
210
  print("Number of parameter: %.2fM" % (total / 1e6))
211
 
212
 
213
  change_gpt_weights("pretrained_models/s1v3.ckpt")
214
  from sv import SV
 
215
  sv_cn_model = SV(device, is_half)
216
 
217
  resample_transform_dict = {}
@@ -261,11 +294,14 @@ def clean_text_inf(text, language, version):
261
  phones = cleaned_text_to_sequence(phones, version)
262
  return phones, word2ph, norm_text
263
 
264
- dtype=torch.float16 if is_half == True else torch.float32
 
 
 
265
  def get_bert_inf(phones, word2ph, norm_text, language):
266
- language=language.replace("all_","")
267
  if language == "zh":
268
- bert = get_bert_feature(norm_text, word2ph).to(device)#.to(dtype)
269
  else:
270
  bert = torch.zeros(
271
  (1024, len(phones)),
@@ -275,7 +311,21 @@ def get_bert_inf(phones, word2ph, norm_text, language):
275
  return bert
276
 
277
 
278
- splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
279
 
280
 
281
  def get_first(text):
@@ -283,8 +333,10 @@ def get_first(text):
283
  text = re.split(pattern, text)[0].strip()
284
  return text
285
 
 
286
  from text import chinese
287
 
 
288
  def get_phones_and_bert(text, language, version, final=False):
289
  if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
290
  formattext = text
@@ -361,24 +413,44 @@ def merge_short_text_in_array(texts, threshold):
361
  if len(text) >= threshold:
362
  result.append(text)
363
  text = ""
364
- if (len(text) > 0):
365
  if len(result) == 0:
366
  result.append(text)
367
  else:
368
  result[len(result) - 1] += text
369
  return result
370
 
 
371
  ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
372
  # cache_tokens={}#暂未实现清理机制
373
- cache= {}
374
- @torch.inference_mode()
 
375
  @spaces.GPU
376
- def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6, ref_free = False,speed=1,if_freeze=False,inp_refs=123):
377
  global cache
378
- if ref_wav_path:pass
379
- else:gr.Warning(i18n('请上传参考音频'))
380
- if text:pass
381
- else:gr.Warning(i18n('请填入推理文本'))
 
 
 
 
382
  t = []
383
  if prompt_text is None or len(prompt_text) == 0:
384
  ref_free = True
@@ -386,13 +458,14 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
386
  prompt_language = dict_language[prompt_language]
387
  text_language = dict_language[text_language]
388
 
389
-
390
  if not ref_free:
391
  prompt_text = prompt_text.strip("\n")
392
- if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_language != "en" else "."
 
393
  print(i18n("实际输入的参考文本:"), prompt_text)
394
  text = text.strip("\n")
395
- if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text
 
396
 
397
  print(i18n("实际输入的目标文本:"), text)
398
  zero_wav = np.zeros(
@@ -402,7 +475,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
402
  if not ref_free:
403
  with torch.no_grad():
404
  wav16k, sr = librosa.load(ref_wav_path, sr=16000)
405
- if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
406
  gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
407
  raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
408
  wav16k = torch.from_numpy(wav16k)
@@ -414,27 +487,23 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
414
  wav16k = wav16k.to(device)
415
  zero_wav_torch = zero_wav_torch.to(device)
416
  wav16k = torch.cat([wav16k, zero_wav_torch])
417
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
418
- "last_hidden_state"
419
- ].transpose(
420
- 1, 2
421
- ) # .float()
422
  codes = vq_model.extract_latent(ssl_content)
423
  prompt_semantic = codes[0, 0]
424
  prompt = prompt_semantic.unsqueeze(0).to(device)
425
 
426
  t1 = ttime()
427
- t.append(t1-t0)
428
 
429
- if (how_to_cut == i18n("凑四句一切")):
430
  text = cut1(text)
431
- elif (how_to_cut == i18n("凑50字一切")):
432
  text = cut2(text)
433
- elif (how_to_cut == i18n("按中文句号。切")):
434
  text = cut3(text)
435
- elif (how_to_cut == i18n("按英文句号.切")):
436
  text = cut4(text)
437
- elif (how_to_cut == i18n("按标点符号切")):
438
  text = cut5(text)
439
  while "\n\n" in text:
440
  text = text.replace("\n\n", "\n")
@@ -444,19 +513,20 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
444
  texts = merge_short_text_in_array(texts, 5)
445
  audio_opt = []
446
  if not ref_free:
447
- phones1,bert1,norm_text1=get_phones_and_bert(prompt_text, prompt_language, version)
448
 
449
- for i_text,text in enumerate(texts):
450
  # 解决输入目标文本的空行导致报错的问题
451
- if (len(text.strip()) == 0):
452
  continue
453
- if (text[-1] not in splits): text += "。" if text_language != "en" else "."
 
454
  print(i18n("实际输入的目标文本(每句):"), text)
455
- phones2,bert2,norm_text2=get_phones_and_bert(text, text_language, version)
456
  print(i18n("前端处理后的文本(每句):"), norm_text2)
457
  if not ref_free:
458
  bert = torch.cat([bert1, bert2], 1)
459
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
460
  else:
461
  bert = bert2
462
  all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
@@ -467,26 +537,33 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
467
  t2 = ttime()
468
  # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
469
  # print(cache.keys(),if_freeze)
470
- if(i_text in cache and if_freeze==True):pred_semantic=cache[i_text]
 
471
  else:
472
  with torch.no_grad():
473
- pred_semantic, idx = t2s_model.model.infer_panel(
474
- all_phoneme_ids,
475
  all_phoneme_len,
476
- None if ref_free else prompt,
477
- bert,
478
- # prompt_phone_len=ph_offset,
479
  top_k=top_k,
480
  top_p=top_p,
481
  temperature=temperature,
482
- early_stop_num=hz * max_sec,
 
483
  )
484
- pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
485
- cache[i_text]=pred_semantic
 
 
 
 
 
486
  t3 = ttime()
487
- refers=[]
488
  sv_emb = []
489
- if(inp_refs):
490
  for path in inp_refs:
491
  try:
492
  refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro=True)
@@ -498,22 +575,28 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
498
  refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro=True)
499
  refers = [refers]
500
  sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
501
- audio = vq_model.decode(
502
- pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
503
- ).detach().cpu().numpy()[0][0]
504
- max_audio=np.abs(audio).max()#简单防止16bit爆音
505
- if max_audio>1:audio/=max_audio
506
  audio_opt.append(audio)
507
  audio_opt.append(zero_wav)
508
  t4 = ttime()
509
- t.extend([t2 - t1,t3 - t2, t4 - t3])
510
  t1 = ttime()
511
- print("%.3f\t%.3f\t%.3f\t%.3f" %
512
- (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3]))
513
- )
514
- yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
515
- np.int16
516
- )
517
 
518
 
519
  def split(todo_text):
@@ -543,7 +626,7 @@ def cut1(inp):
543
  if len(split_idx) > 1:
544
  opts = []
545
  for idx in range(len(split_idx) - 1):
546
- opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
547
  else:
548
  opts = [inp]
549
  opts = [item for item in opts if not set(item).issubset(punctuation)]
@@ -579,7 +662,8 @@ def cut3(inp):
579
  inp = inp.strip("\n")
580
  opts = ["%s" % item for item in inp.strip("。").split("。")]
581
  opts = [item for item in opts if not set(item).issubset(punctuation)]
582
- return "\n".join(opts)
 
583
 
584
  def cut4(inp):
585
  inp = inp.strip("\n")
@@ -591,13 +675,13 @@ def cut4(inp):
591
  # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
592
  def cut5(inp):
593
  inp = inp.strip("\n")
594
- punds = {',', '.', ';', '?', '!', '', '', '', '', '', ';', '', ''}
595
  mergeitems = []
596
  items = []
597
 
598
  for i, char in enumerate(inp):
599
  if char in punds:
600
- if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
601
  items.append(char)
602
  else:
603
  items.append(char)
@@ -615,35 +699,37 @@ def cut5(inp):
615
 
616
  def custom_sort_key(s):
617
  # 使用正则表达式提取字符串中的数字部分和非数字部分
618
- parts = re.split('(\d+)', s)
619
  # 将数字部分转换为整数,非数字部分保持不变
620
  parts = [int(part) if part.isdigit() else part for part in parts]
621
  return parts
622
 
 
623
  def process_text(texts):
624
- _text=[]
625
- if all(text in [None, " ", "\n",""] for text in texts):
626
  raise ValueError(i18n("请输入有效文本"))
627
  for text in texts:
628
- if text in [None, " ", ""]:
629
  pass
630
  else:
631
  _text.append(text)
632
  return _text
633
 
634
 
635
- def html_center(text, label='p'):
636
  return f"""<div style="text-align: center; margin: 100; padding: 50;">
637
  <{label} style="margin: 0; padding: 0;">{text}</{label}>
638
  </div>"""
639
 
640
- def html_left(text, label='p'):
 
641
  return f"""<div style="text-align: left; margin: 0; padding: 0;">
642
  <{label} style="margin: 0; padding: 0;">{text}</{label}>
643
  </div>"""
644
 
645
 
646
- with gr.Blocks(title="GPT-SoVITS WebUI") as app:
647
  gr.Markdown(
648
  value="""# GPT-SoVITS-ProPlus Zero-shot TTS demo
649
  ## https://github.com/RVC-Boss/GPT-SoVITS
@@ -656,49 +742,95 @@ This demo is open source under the MIT license. The author does not have any con
656
  """
657
  )
658
  with gr.Group():
659
- gr.Markdown(html_center(i18n("*请上传并填写参考信息"),'h3'))
660
  with gr.Row():
661
  inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
662
  with gr.Column():
663
- ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
664
- gr.Markdown(html_left(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。<br>开启后无视填写的参考文本。")))
665
  prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=3, max_lines=3)
666
  prompt_language = gr.Dropdown(
667
  label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
668
  )
669
- inp_refs = gr.File(label=i18n("可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。"),file_count="multiple")
670
- gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"),'h3'))
671
  with gr.Row():
672
  with gr.Column():
673
  text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
674
  with gr.Column():
675
  text_language = gr.Dropdown(
676
- label=i18n("需要合成的语种")+i18n(".限制范围越小判别效果越好。"), choices=list(dict_language.keys()), value=i18n("中文")
677
- )
 
 
678
  how_to_cut = gr.Dropdown(
679
- label=i18n("怎么切"),
680
- choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
681
- value=i18n("凑四句一切"),
682
- interactive=True
683
- )
684
  gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
685
- if_freeze=gr.Checkbox(label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"), value=False, interactive=True,show_label=True)
686
- speed = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label=i18n("语速"),value=1,interactive=True)
687
  gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
688
- top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=15,interactive=True)
689
- top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
690
- temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
 
 
691
  with gr.Row():
692
- inference_button = gr.Button(i18n("合成语音"), variant="primary", size='lg')
693
  output = gr.Audio(label=i18n("输出的语音"))
694
 
695
  inference_button.click(
696
  get_tts_wav,
697
- [inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature, ref_text_free,speed,if_freeze,inp_refs],
698
  [output],
699
  )
700
 
701
- if __name__ == '__main__':
702
  app.queue().launch(
703
  server_name="0.0.0.0",
704
  inbrowser=True,
 
1
  import os
2
+
3
+ os.makedirs("pretrained_models", exist_ok=True)
4
  from huggingface_hub import snapshot_download
5
+
6
+ snapshot_download(
7
+ repo_id="lj1995/GPT-SoVITS",
8
+ repo_type="model",
9
+ allow_patterns="chinese*",
10
+ local_dir="pretrained_models",
11
+ )
12
+ snapshot_download(
13
+ repo_id="lj1995/GPT-SoVITS",
14
+ repo_type="model",
15
+ allow_patterns="s1v3.ckpt",
16
+ local_dir="pretrained_models",
17
+ )
18
+ snapshot_download(
19
+ repo_id="lj1995/GPT-SoVITS",
20
+ repo_type="model",
21
+ allow_patterns="sv*",
22
+ local_dir="pretrained_models",
23
+ )
24
+ snapshot_download(
25
+ repo_id="lj1995/GPT-SoVITS",
26
+ repo_type="model",
27
+ allow_patterns="v2Pro/s2Gv2ProPlus.pth",
28
+ local_dir="pretrained_models",
29
+ )
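The four snapshot_download calls above fetch disjoint subsets of the same repo. huggingface_hub also accepts a list for allow_patterns, so the same files could be pulled in one call; a minimal equivalent sketch (same repo and patterns as above, not part of this commit):

    from huggingface_hub import snapshot_download

    # One-shot equivalent of the four calls above: allow_patterns takes a list of globs.
    snapshot_download(
        repo_id="lj1995/GPT-SoVITS",
        repo_type="model",
        allow_patterns=["chinese*", "s1v3.ckpt", "sv*", "v2Pro/s2Gv2ProPlus.pth"],
        local_dir="pretrained_models",
    )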
30
  import logging
31
  import traceback
32
+
33
  logging.getLogger("markdown_it").setLevel(logging.ERROR)
34
  logging.getLogger("urllib3").setLevel(logging.ERROR)
35
  logging.getLogger("httpcore").setLevel(logging.ERROR)
 
40
  logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
41
  logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
42
  logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
43
+
44
+ import nltk
45
+ import torchaudio
46
+
47
  from text.LangSegmenter import LangSegmenter
48
+
49
+ nltk.download("averaged_perceptron_tagger_eng")
50
+ import json
51
+ import os
 
 
52
  import pdb
53
+ import re
54
+ import sys
55
+
56
+ import LangSegment
57
  import spaces
58
  import torch
59
 
60
+ version = "v2" # os.environ.get("version","v2")
61
+ cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
62
+ bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
 
 
 
 
63
 
64
+ punctuation = set(["!", "?", "…", ",", ".", "-", " "])
65
  import gradio as gr
 
 
66
  import librosa
67
+ import numpy as np
68
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
69
+
70
  from feature_extractor import cnhubert
71
 
72
  cnhubert.cnhubert_base_path = cnhubert_base_path
73
 
74
+ from time import time as ttime
75
+
76
+ from AR.models.structs import T2SRequest
77
+ from AR.models.t2s_model_flash_attn import CUDAGraphRunner
78
+ from module.mel_processing import spectrogram_torch
79
  from module.models import SynthesizerTrn
 
80
  from text import cleaned_text_to_sequence
81
  from text.cleaner import clean_text
 
 
 
82
  from tools.i18n.i18n import I18nAuto, scan_language_list
83
+ from tools.my_utils import load_audio
84
 
85
  # language=os.environ.get("language","Auto")
86
  # language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 
93
  is_half = True # eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
94
  else:
95
  device = "cpu"
96
+ is_half = False
97
 
98
  dict_language_v1 = {
99
+ i18n("中文"): "all_zh", # 全部按中文识别
100
+ i18n("英文"): "en", # 全部按英文识别#######不变
101
+ i18n("日文"): "all_ja", # 全部按日文识别
102
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
103
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
104
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
105
  }
106
  dict_language_v2 = {
107
+ i18n("中文"): "all_zh", # 全部按中文识别
108
+ i18n("英文"): "en", # 全部按英文识别#######不变
109
+ i18n("日文"): "all_ja", # 全部按日文识别
110
+ i18n("粤语"): "all_yue", # 全部按中文识别
111
+ i18n("韩文"): "all_ko", # 全部按韩文识别
112
+ i18n("中英混合"): "zh", # 按中英混合识别####不变
113
+ i18n("日英混合"): "ja", # 按日英混合识别####不变
114
+ i18n("粤英混合"): "yue", # 按粤英混合识别####不变
115
+ i18n("韩英混合"): "ko", # 按韩英混合识别####不变
116
+ i18n("多语种混合"): "auto", # 多语种启动切分识别语种
117
+ i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
118
  }
119
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
120
 
121
  tokenizer = AutoTokenizer.from_pretrained(bert_path)
122
  bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
 
177
  ssl_model = ssl_model.to(device)
178
 
179
 
180
+ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
181
  global vq_model, hps, version, dict_language
182
  dict_s2 = torch.load(sovits_path, map_location="cpu")
183
  hps = dict_s2["config"]
184
  hps = DictToAttrRecursive(hps)
185
  hps.model.semantic_frame_rate = "25hz"
186
+ if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
187
  hps.model.version = "v1"
188
  else:
189
  hps.model.version = "v2"
 
193
  hps.data.filter_length // 2 + 1,
194
  hps.train.segment_size // hps.data.hop_length,
195
  n_speakers=hps.data.n_speakers,
196
+ **hps.model,
197
  )
198
+ if "pretrained" not in sovits_path:
199
  del vq_model.enc_q
200
  if is_half == True:
201
  vq_model = vq_model.half().to(device)
 
203
  vq_model = vq_model.to(device)
204
  vq_model.eval()
205
  print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
206
+ dict_language = dict_language_v1 if version == "v1" else dict_language_v2
207
  if prompt_language is not None and text_language is not None:
208
  if prompt_language in list(dict_language.keys()):
209
+ prompt_text_update, prompt_language_update = (
210
+ {"__type__": "update"},
211
+ {"__type__": "update", "value": prompt_language},
212
+ )
213
  else:
214
+ prompt_text_update = {"__type__": "update", "value": ""}
215
+ prompt_language_update = {"__type__": "update", "value": i18n("中文")}
216
  if text_language in list(dict_language.keys()):
217
+ text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
218
  else:
219
+ text_update = {"__type__": "update", "value": ""}
220
+ text_language_update = {"__type__": "update", "value": i18n("中文")}
221
+ return (
222
+ {"__type__": "update", "choices": list(dict_language.keys())},
223
+ {"__type__": "update", "choices": list(dict_language.keys())},
224
+ prompt_text_update,
225
+ prompt_language_update,
226
+ text_update,
227
+ text_language_update,
228
+ )
229
 
230
 
231
  change_sovits_weights("pretrained_models/v2Pro/s2Gv2ProPlus.pth")
232
 
233
 
234
  def change_gpt_weights(gpt_path):
235
+ global t2s_model, config
 
236
  dict_s1 = torch.load(gpt_path, map_location="cpu")
237
  config = dict_s1["config"]
238
+ t2s_model = CUDAGraphRunner(
239
+ CUDAGraphRunner.load_decoder(gpt_path), torch.device(device), torch.float16 if is_half else torch.float32
240
+ )
241
+ total = sum(p.numel() for p in t2s_model.decoder_model.parameters())
 
 
 
 
242
  print("Number of parameter: %.2fM" % (total / 1e6))
243
 
244
 
245
  change_gpt_weights("pretrained_models/s1v3.ckpt")
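change_gpt_weights wraps the loaded decoder in CUDAGraphRunner, the class this commit adds for CUDA Graph support. Its internals are not shown in this hunk; the sketch below only illustrates the general torch.cuda.CUDAGraph capture-and-replay pattern it builds on (static buffers, warm-up on a side stream, replay with fresh data copied in). It assumes a CUDA device and uses a stand-in linear layer; it is not the runner's actual implementation:

    import torch

    # Hypothetical fixed-shape "decode step" standing in for the real decoder.
    model = torch.nn.Linear(512, 512).cuda().half()
    static_in = torch.zeros(1, 512, device="cuda", dtype=torch.half)

    # Warm up on a side stream before capture, as the PyTorch docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the step once into a graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_in)

    # Replay: copy new data into the static input buffer, then replay.
    static_in.copy_(torch.randn(1, 512, device="cuda", dtype=torch.half))
    graph.replay()  # static_out now holds the output for the new input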
246
  from sv import SV
247
+
248
  sv_cn_model = SV(device, is_half)
249
 
250
  resample_transform_dict = {}
 
294
  phones = cleaned_text_to_sequence(phones, version)
295
  return phones, word2ph, norm_text
296
 
297
+
298
+ dtype = torch.float16 if is_half == True else torch.float32
299
+
300
+
301
  def get_bert_inf(phones, word2ph, norm_text, language):
302
+ language = language.replace("all_", "")
303
  if language == "zh":
304
+ bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
305
  else:
306
  bert = torch.zeros(
307
  (1024, len(phones)),
 
311
  return bert
312
 
313
 
314
+ splits = {
315
+ ",",
316
+ "。",
317
+ "?",
318
+ "!",
319
+ ",",
320
+ ".",
321
+ "?",
322
+ "!",
323
+ "~",
324
+ ":",
325
+ ":",
326
+ "—",
327
+ "…",
328
+ }
329
 
330
 
331
  def get_first(text):
 
333
  text = re.split(pattern, text)[0].strip()
334
  return text
335
 
336
+
337
  from text import chinese
338
 
339
+
340
  def get_phones_and_bert(text, language, version, final=False):
341
  if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
342
  formattext = text
 
413
  if len(text) >= threshold:
414
  result.append(text)
415
  text = ""
416
+ if len(text) > 0:
417
  if len(result) == 0:
418
  result.append(text)
419
  else:
420
  result[len(result) - 1] += text
421
  return result
422
 
423
+
424
  ##ref_wav_path+prompt_text+prompt_language+text(单个)+text_language+top_k+top_p+temperature
425
  # cache_tokens={}#暂未实现清理机制
426
+ cache = {}
427
+
428
+
429
  @spaces.GPU
430
+ def get_tts_wav(
431
+ ref_wav_path,
432
+ prompt_text,
433
+ prompt_language,
434
+ text,
435
+ text_language,
436
+ how_to_cut=i18n("不切"),
437
+ top_k=20,
438
+ top_p=0.6,
439
+ temperature=0.6,
440
+ ref_free=False,
441
+ speed=1,
442
+ if_freeze=False,
443
+ inp_refs=None,
444
+ ):
445
  global cache
446
+ if ref_wav_path:
447
+ pass
448
+ else:
449
+ gr.Warning(i18n("请上传参考音频"))
450
+ if text:
451
+ pass
452
+ else:
453
+ gr.Warning(i18n("请填入推理文本"))
454
  t = []
455
  if prompt_text is None or len(prompt_text) == 0:
456
  ref_free = True
 
458
  prompt_language = dict_language[prompt_language]
459
  text_language = dict_language[text_language]
460
 
 
461
  if not ref_free:
462
  prompt_text = prompt_text.strip("\n")
463
+ if prompt_text[-1] not in splits:
464
+ prompt_text += "。" if prompt_language != "en" else "."
465
  print(i18n("实际输入的参考文本:"), prompt_text)
466
  text = text.strip("\n")
467
+ if text[0] not in splits and len(get_first(text)) < 4:
468
+ text = "。" + text if text_language != "en" else "." + text
469
 
470
  print(i18n("实际输入的目标文本:"), text)
471
  zero_wav = np.zeros(
 
475
  if not ref_free:
476
  with torch.no_grad():
477
  wav16k, sr = librosa.load(ref_wav_path, sr=16000)
478
+ if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
479
gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
480
  raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
481
  wav16k = torch.from_numpy(wav16k)
 
487
  wav16k = wav16k.to(device)
488
  zero_wav_torch = zero_wav_torch.to(device)
489
  wav16k = torch.cat([wav16k, zero_wav_torch])
490
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
 
 
 
 
491
  codes = vq_model.extract_latent(ssl_content)
492
  prompt_semantic = codes[0, 0]
493
  prompt = prompt_semantic.unsqueeze(0).to(device)
494
 
495
  t1 = ttime()
496
+ t.append(t1 - t0)
497
 
498
+ if how_to_cut == i18n("凑四句一切"):
499
  text = cut1(text)
500
+ elif how_to_cut == i18n("凑50字一切"):
501
  text = cut2(text)
502
+ elif how_to_cut == i18n("按中文句号。切"):
503
  text = cut3(text)
504
+ elif how_to_cut == i18n("按英文句号.切"):
505
  text = cut4(text)
506
+ elif how_to_cut == i18n("按标点符号切"):
507
  text = cut5(text)
508
  while "\n\n" in text:
509
  text = text.replace("\n\n", "\n")
 
513
  texts = merge_short_text_in_array(texts, 5)
514
  audio_opt = []
515
  if not ref_free:
516
+ phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
517
 
518
+ for i_text, text in enumerate(texts):
519
  # 解决输入目标文本的空行导致报错的问题
520
+ if len(text.strip()) == 0:
521
  continue
522
+ if text[-1] not in splits:
523
+ text += "。" if text_language != "en" else "."
524
  print(i18n("实际输入的目标文本(每句):"), text)
525
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
526
  print(i18n("前端处理后的文本(每句):"), norm_text2)
527
  if not ref_free:
528
  bert = torch.cat([bert1, bert2], 1)
529
+ all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
530
  else:
531
  bert = bert2
532
  all_phoneme_ids = torch.LongTensor(phones2).to(device).unsqueeze(0)
 
537
  t2 = ttime()
538
  # cache_key="%s-%s-%s-%s-%s-%s-%s-%s"%(ref_wav_path,prompt_text,prompt_language,text,text_language,top_k,top_p,temperature)
539
  # print(cache.keys(),if_freeze)
540
+ if i_text in cache and if_freeze == True:
541
+ pred_semantic = cache[i_text]
542
  else:
543
  with torch.no_grad():
544
+ t2s_request = T2SRequest(
545
+ [all_phoneme_ids.squeeze(0)],
546
  all_phoneme_len,
547
+ torch.zeros((1, 0)) if ref_free else prompt,
548
+ [bert.squeeze(0)],
549
+ valid_length=1,
550
  top_k=top_k,
551
  top_p=top_p,
552
  temperature=temperature,
553
+ early_stop_num=1500,
554
+ use_cuda_graph=True,
555
  )
556
+ t2s_result = t2s_model.generate(t2s_request)
557
+ pred_semantic = t2s_result.result
558
+ if pred_semantic is None:
559
+ print(t2s_result.exception)
560
+ print(t2s_result.traceback)
561
+ raise RuntimeError("T2S inference failed; see the printed exception and traceback above")
562
+ cache[i_text] = pred_semantic
563
  t3 = ttime()
564
+ refers = []
565
  sv_emb = []
566
+ if inp_refs:
567
  for path in inp_refs:
568
  try:
569
  refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro=True)
 
575
  refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro=True)
576
  refers = [refers]
577
  sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
578
+ audio = (
579
+ vq_model.decode(
580
+ pred_semantic[0].unsqueeze(0).unsqueeze(0),
581
+ torch.LongTensor(phones2).to(device).unsqueeze(0),
582
+ refers,
583
+ speed=speed,
584
+ sv_emb=sv_emb,
585
+ )
586
+ .detach()
587
+ .cpu()
588
+ .numpy()[0][0]
589
+ )
590
+ max_audio = np.abs(audio).max() # 简单防止16bit爆音
591
+ if max_audio > 1:
592
+ audio /= max_audio
593
  audio_opt.append(audio)
594
  audio_opt.append(zero_wav)
595
  t4 = ttime()
596
+ t.extend([t2 - t1, t3 - t2, t4 - t3])
597
  t1 = ttime()
598
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
599
+ yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
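get_tts_wav is a generator that yields (sampling_rate, int16 waveform) tuples, which the Gradio Audio output consumes. For reference, a hedged sketch of calling it outside the UI; the path and texts are placeholders, soundfile is assumed to be installed, and the language labels assume the default Chinese i18n strings used by the dropdowns below:

    import soundfile as sf

    segments = get_tts_wav(
        ref_wav_path="ref.wav",          # placeholder 3~10 s reference clip
        prompt_text="参考音频的文本。",    # placeholder transcript of the reference clip
        prompt_language=i18n("中文"),
        text="需要合成的文本。",           # placeholder text to synthesize
        text_language=i18n("中文"),
        inp_refs=None,
    )
    for idx, (sr, wav) in enumerate(segments):
        sf.write(f"out_{idx}.wav", wav, sr)  # wav is already int16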
 
 
 
 
600
 
601
 
602
  def split(todo_text):
 
626
  if len(split_idx) > 1:
627
  opts = []
628
  for idx in range(len(split_idx) - 1):
629
+ opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
630
  else:
631
  opts = [inp]
632
  opts = [item for item in opts if not set(item).issubset(punctuation)]
 
662
  inp = inp.strip("\n")
663
  opts = ["%s" % item for item in inp.strip("。").split("。")]
664
  opts = [item for item in opts if not set(item).issubset(punctuation)]
665
+ return "\n".join(opts)
666
+
667
 
668
  def cut4(inp):
669
  inp = inp.strip("\n")
 
675
  # contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
676
  def cut5(inp):
677
  inp = inp.strip("\n")
678
+ punds = {",", ".", ";", "?", "!", "、", "，", "。", "？", "！", "；", "：", "…"}
679
  mergeitems = []
680
  items = []
681
 
682
  for i, char in enumerate(inp):
683
  if char in punds:
684
+ if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
685
  items.append(char)
686
  else:
687
  items.append(char)
 
699
 
700
  def custom_sort_key(s):
701
  # 使用正则表达式提取字符串中的数字部分和非数字部分
702
+ parts = re.split(r"(\d+)", s)
703
  # 将数字部分转换为整数,非数字部分保持不变
704
  parts = [int(part) if part.isdigit() else part for part in parts]
705
  return parts
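custom_sort_key splits a string into digit and non-digit runs and converts the digit runs to int, so numbered names sort numerically rather than lexicographically. For example (hypothetical file names):

    names = ["seg10.wav", "seg2.wav", "seg1.wav"]
    print(sorted(names, key=custom_sort_key))
    # -> ['seg1.wav', 'seg2.wav', 'seg10.wav']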
706
 
707
+
708
  def process_text(texts):
709
+ _text = []
710
+ if all(text in [None, " ", "\n", ""] for text in texts):
711
  raise ValueError(i18n("请输入有效文本"))
712
  for text in texts:
713
+ if text in [None, " ", ""]:
714
  pass
715
  else:
716
  _text.append(text)
717
  return _text
718
 
719
 
720
+ def html_center(text, label="p"):
721
  return f"""<div style="text-align: center; margin: 100; padding: 50;">
722
  <{label} style="margin: 0; padding: 0;">{text}</{label}>
723
  </div>"""
724
 
725
+
726
+ def html_left(text, label="p"):
727
  return f"""<div style="text-align: left; margin: 0; padding: 0;">
728
  <{label} style="margin: 0; padding: 0;">{text}</{label}>
729
  </div>"""
730
 
731
 
732
+ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False) as app:
733
  gr.Markdown(
734
  value="""# GPT-SoVITS-ProPlus Zero-shot TTS demo
735
  ## https://github.com/RVC-Boss/GPT-SoVITS
 
742
  """
743
  )
744
  with gr.Group():
745
+ gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
746
  with gr.Row():
747
  inp_ref = gr.Audio(label=i18n("请上传3~10秒内参考音频,超过会报错!"), type="filepath")
748
  with gr.Column():
749
+ ref_text_free = gr.Checkbox(
750
+ label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
751
+ value=False,
752
+ interactive=True,
753
+ show_label=True,
754
+ )
755
+ gr.Markdown(
756
+ html_left(
757
+ i18n(
758
+ "使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开。<br>开启后无视填写的参考文本。"
759
+ )
760
+ )
761
+ )
762
  prompt_text = gr.Textbox(label=i18n("参考音频的文本"), value="", lines=3, max_lines=3)
763
  prompt_language = gr.Dropdown(
764
  label=i18n("参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
765
  )
766
+ inp_refs = gr.File(
767
+ label=i18n(
768
+ "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。"
769
+ ),
770
+ file_count="multiple",
771
+ )
772
+ gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
773
  with gr.Row():
774
  with gr.Column():
775
  text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=26, max_lines=26)
776
  with gr.Column():
777
  text_language = gr.Dropdown(
778
+ label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
779
+ choices=list(dict_language.keys()),
780
+ value=i18n("中文"),
781
+ )
782
  how_to_cut = gr.Dropdown(
783
+ label=i18n("怎么切"),
784
+ choices=[
785
+ i18n("不切"),
786
+ i18n("凑四句一切"),
787
+ i18n("凑50字一切"),
788
+ i18n("按中文句号。切"),
789
+ i18n("按英文句号.切"),
790
+ i18n("按标点符号切"),
791
+ ],
792
+ value=i18n("凑四句一切"),
793
+ interactive=True,
794
+ )
795
  gr.Markdown(value=html_center(i18n("语速调整,高为更快")))
796
+ if_freeze = gr.Checkbox(
797
+ label=i18n("是否直接对上次合成结果调整语速和音色。防止随机性。"),
798
+ value=False,
799
+ interactive=True,
800
+ show_label=True,
801
+ )
802
+ speed = gr.Slider(minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True)
803
  gr.Markdown(html_center(i18n("GPT采样参数(无参考文本时不要太低。不懂就用默认):")))
804
+ top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True)
805
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
806
+ temperature = gr.Slider(
807
+ minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
808
+ )
809
  with gr.Row():
810
+ inference_button = gr.Button(i18n("合成语音"), variant="primary", size="lg")
811
  output = gr.Audio(label=i18n("输出的语音"))
812
 
813
  inference_button.click(
814
  get_tts_wav,
815
+ [
816
+ inp_ref,
817
+ prompt_text,
818
+ prompt_language,
819
+ text,
820
+ text_language,
821
+ how_to_cut,
822
+ top_k,
823
+ top_p,
824
+ temperature,
825
+ ref_text_free,
826
+ speed,
827
+ if_freeze,
828
+ inp_refs,
829
+ ],
830
  [output],
831
  )
832
 
833
+ if __name__ == "__main__":
834
  app.queue().launch(
835
  server_name="0.0.0.0",
836
  inbrowser=True,
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torch==2.5.1
requirements.txt CHANGED
@@ -30,10 +30,10 @@ g2pk2==0.0.3
30
  ko_pron==1.3
31
  opencc==1.1.0
32
  python_mecab_ko==1.3.7
33
- torch==2.5.1
34
  pydantic==2.8.2
35
  torchmetrics<=1.5
36
  nltk==3.8.1
37
  fast_langdetect==0.3.1
38
  split_lang==2.1.0
39
- ToJyutping==3.2.0
 
 
30
  ko_pron==1.3
31
  opencc==1.1.0
32
  python_mecab_ko==1.3.7
 
33
  pydantic==2.8.2
34
  torchmetrics<=1.5
35
  nltk==3.8.1
36
  fast_langdetect==0.3.1
37
  split_lang==2.1.0
38
+ ToJyutping==3.2.0
39
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
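The added flash-attn wheel is a prebuilt binary for CUDA 12, torch 2.5, Python 3.10 and the cxx11 ABI; moving torch==2.5.1 into pre-requirements.txt ensures that exact torch is installed before requirements.txt is resolved. A quick runtime sanity check (a sketch, not part of the commit):

    import torch
    import flash_attn

    # Expect torch 2.5.x built against CUDA 12.x and flash_attn 2.7.4.post1.
    print(torch.__version__, torch.version.cuda, flash_attn.__version__)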