XXXXRT666 committed
Commit 5cfeca6 · 1 Parent(s): 7bdf3c3

Cache CUDA Graph

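This commit caches the CUDA graph across requests: the graph, its static `xy_pos_`/`xy_dec_` input and output buffers, the KV cache, and `input_pos` move out of the per-request `T2SSession` (AR/models/structs.py) into the long-lived `CUDAGraphRunner` (AR/models/t2s_model_flash_attn.py), so the graph is captured once on the first request and only replayed afterwards. A minimal sketch of the capture-once/replay pattern; the names (`GraphCache`, `step_fn`, `static_in`) are illustrative, not this repo's API:

```python
import torch

class GraphCache:
    def __init__(self, step_fn, example_in: torch.Tensor):
        self.step_fn = step_fn
        self.graph: torch.cuda.CUDAGraph | None = None
        self.static_in = example_in.clone()  # fixed address the graph reads from
        self.static_out: torch.Tensor | None = None  # fixed address it writes to

    def run(self, x: torch.Tensor) -> torch.Tensor:
        if self.graph is None:
            # Warm up on a side stream, then record the kernels once.
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                self.step_fn(self.static_in)
            torch.cuda.current_stream().wait_stream(s)
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_out = self.step_fn(self.static_in)
        # Every later call only refreshes the static input and replays.
        self.static_in.copy_(x)
        self.graph.replay()
        return self.static_out.clone()
```

A replayed graph reruns the recorded kernels against the recorded device addresses, which is why the diffs below only ever update these shared buffers in place (`copy_`, `add_`) and never rebind them.
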
AR/models/structs.py CHANGED
@@ -1,3 +1,7 @@
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
 from __future__ import annotations
 
 from dataclasses import dataclass
@@ -48,7 +52,6 @@ class T2SSession:
         self.y_len = y_len
 
         # Cache
-        self.kv_cache = decoder.init_cache(bsz)
         self.sampler = Sampler(bsz, decoder.vocab_size)
 
         # Forward args
@@ -62,11 +65,6 @@ class T2SSession:
         self.input_pos = torch.zeros_like(self.prefill_len)
         self.input_pos.add_(self.prefill_len)
 
-        # CUDA Graph
-        self.graph: Optional[torch.cuda.CUDAGraph] = None
-        self.xy_pos_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
-        self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
-
         # EOS
         self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
         self.y_results: List[Tensor] = [None] * len(self.x)  # type: ignore
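
With the graph state removed, `T2SSession` carries only per-request bookkeeping (sampler, prefill positions, EOS flags); the runner-side counterparts of the deleted fields appear in the `AR/models/t2s_model_flash_attn.py` diff below.
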
AR/models/t2s_model_abc.py CHANGED
@@ -1,9 +1,14 @@
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
 from __future__ import annotations
 
 import os
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from typing import Any, Dict, List, MutableSequence, Optional, Tuple, Type
+import time
 
 import torch
 import torch._inductor.config
@@ -31,6 +36,7 @@ class Sampler(nn.Module):
         self.register_buffer("samples", torch.zeros((batch_size,), dtype=torch.int32), persistent=False)
 
         self.__CUDAGraph: Optional[CUDAGraph] = None
+
 
     def empty_cache(self):
         self.logits.zero_()
@@ -139,6 +145,7 @@ class Sampler(nn.Module):
         return idx_next
 
     def capture(self, temperature: float, top_k: int, top_p: float):
+        t1 = time.perf_counter()
         s = torch.cuda.Stream()
         s.wait_stream(torch.cuda.current_stream())
 
@@ -153,7 +160,9 @@ class Sampler(nn.Module):
         with torch.cuda.graph(self.__CUDAGraph):
             self.samples = self.__sample_cuda_graph(logits, temperature, top_k, top_p)
         torch.cuda.synchronize()
+        print("Sample", time.perf_counter() - t1)
 
+    # @torch.jit.script
     def sample(
         self,
         logits: Tensor,
@@ -162,21 +171,32 @@ class Sampler(nn.Module):
         top_k: int,
         top_p: float,
         repetition_penalty: float,
-        use_cuda_graph=False,
-        idx=-1,
     ) -> Tensor:
-        if use_cuda_graph and torch.cuda.is_available() and self.__CUDAGraph is None and idx > 0:
-            self.logits.copy_(logits)
-            self.capture(temperature, top_k, top_p)
-        if self.__CUDAGraph is not None:
-            self.logits.copy_(logits)
-            self.apply_repetition_penalty(self.logits, previous_tokens, repetition_penalty)
-            self.__CUDAGraph.replay()
-            samples = self.samples.clone()
-        else:
-            samples = self.__sample(logits, previous_tokens, temperature, top_k, top_p, repetition_penalty)[0]
-
-        return samples
+        previous_tokens = previous_tokens.long()
+        score = torch.gather(logits, dim=1, index=previous_tokens)
+        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+        logits.scatter_(dim=1, index=previous_tokens, src=score)
+
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+        sorted_indices_to_remove = cum_probs > top_p
+        sorted_indices_to_remove[:, 0] = False  # keep at least one option
+        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+        logits = logits / max(temperature, 1e-5)
+
+        v, _ = torch.topk(logits, top_k)
+        pivot = v[:, -1].unsqueeze(-1)
+        logits = torch.where(logits < pivot, -float("Inf"), logits)
+
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        q = torch.empty_like(probs).exponential_(1.0)
+        idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
+
+        return idx_next
 
 
 class KVCacheABC(ABC, nn.Module):
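
The rewritten `sample` inlines the whole pipeline: repetition penalty via gather/scatter, top-p (nucleus) masking, temperature scaling, a top-k cutoff, and finally a draw using the exponential-race (Gumbel-max) trick: for i.i.d. `q ~ Exp(1)`, `argmax(probs / q)` is an exact sample from `probs`, with no `torch.multinomial` and no host synchronization, which keeps the op capturable in a CUDA graph. A small self-contained check of that identity (illustrative only, not repo code):

```python
import torch

# Empirically verify that argmax(probs / q), q ~ Exp(1), samples from probs.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])

n = 200_000
q = torch.empty((n, probs.numel())).exponential_(1.0)  # i.i.d. Exp(1) noise
samples = torch.argmax(probs / q, dim=-1)              # one categorical draw per row

freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # ~ tensor([0.1, 0.2, 0.3, 0.4])
```
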
AR/models/t2s_model_flash_attn.py CHANGED
@@ -1,8 +1,12 @@
-import gc
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
 import os
 import time
 import traceback
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Optional
+import gradio as gr
 
 import flash_attn  # type: ignore
 import torch
@@ -50,7 +54,7 @@ class Attention(AttentionABC):
 
         attn: Tensor = flash_attn.flash_attn_with_kvcache(
             q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
-        )
+        )  # type: ignore
 
         attn = self.dropout.forward(attn)
 
@@ -215,57 +219,66 @@ class CUDAGraphRunner:
 
         self.decoder_path: os.PathLike
         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
+
+        self.graph: Optional[torch.cuda.CUDAGraph] = None
+        self.xy_pos_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
+        self.xy_dec_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
+        self.kv_cache = decoder_model.init_cache(1)
+        self.input_pos = torch.tensor([10]).int().cuda()
 
     def _handle_request(self, request: T2SRequest) -> List[torch.Tensor]:
         with self.device:
+            for i in self.kv_cache:
+                i.empty()
+
             decoder = self.decoder_model
             session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
+            self.input_pos.copy_(session.input_pos)
+
             t1 = 0.0
             y = session.y
             bsz = y.size(0)
             torch_profiler = TorchProfiler(request.debug)
             with torch_profiler.profiler():
                 for idx in tqdm(range(1500)):
                     if idx == 0:
-                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, session.kv_cache)
+                        xy_dec = decoder.h.prefill(session.xy_pos, session.attn_mask_nested, self.kv_cache)
                         xy_dec = torch.stack([t[[-1]] for t in xy_dec.unbind()])
                     else:
-                        if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
-                            session.xy_pos_.copy_(session.xy_pos)
+                        if request.use_cuda_graph and self.graph is None and torch.cuda.is_available():
+                            self.xy_pos_.copy_(session.xy_pos)
                             args, kwds = decoder.pre_forward(session)
-                            session.graph = decoder.capture(
-                                session.input_pos,
-                                session.xy_pos_,
-                                session.xy_dec_,
-                                kv_caches=session.kv_cache,
+                            self.graph = decoder.capture(
+                                self.input_pos,
+                                self.xy_pos_,
+                                self.xy_dec_,
+                                kv_caches=self.kv_cache,
                                 *args,
                                 **kwds,
                             )
 
                     with torch_profiler.record("AR"):
-                        if session.graph:
-                            session.xy_pos_.copy_(session.xy_pos)
-                            session.graph.replay()
-                            xy_dec = session.xy_dec_.clone()
+                        if self.graph:
+                            self.xy_pos_.copy_(session.xy_pos)
+                            self.graph.replay()
+                            xy_dec = self.xy_dec_.clone()
                         else:
                             args, kwds = decoder.pre_forward(session)
                             xy_dec = decoder.h.forward(
-                                session.input_pos,
+                                self.input_pos,
                                 session.xy_pos,
-                                session.kv_cache,
+                                self.kv_cache,
                                 *args,
                                 **kwds,
                             )
+
                     decoder.post_forward(idx, session)
                     logits = decoder.ar_predict_layer(xy_dec[:, -1])
-                    session.input_pos.add_(1)
+                    self.input_pos.add_(1)
 
                     if idx == 0:
-                        logits = logits[:, :-1]
+                        logits[:, -1] = float("-inf")
 
                     with torch_profiler.record("Sampling"):
                         samples = session.sampler.sample(
                             logits=logits,
@@ -274,27 +287,26 @@ class CUDAGraphRunner:
                             top_p=request.top_p,
                             repetition_penalty=request.repetition_penalty,
                             temperature=request.temperature,
-                            use_cuda_graph=request.use_cuda_graph,
-                            idx=idx,
                         )
 
                     session.y = torch.cat([session.y, samples], dim=1)
 
                     with torch_profiler.record("EOS"):
                         argmax_token = torch.argmax(logits, dim=-1)
                         sample_token = samples.squeeze(1)
                         EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
-                    with torch_profiler.record("EOS1"):
+
                         newly_done_mask = EOS_mask & (~session.completed)
-                    with torch_profiler.record("EOS2"):
                         newly_done_indices = newly_done_mask.nonzero()
-                    with torch_profiler.record("EOS3"):
+
                         if newly_done_indices.numel() > 0:
                             session.y_results[newly_done_indices[0]] = session.y[
                                 newly_done_indices[0], session.y_len : -1
                             ].squeeze(0)
                             session.completed[newly_done_indices] = True
-                    with torch_profiler.record("EOS4"):
+
                         if torch.all(session.completed).item():
                             if session.y.size(1) == 0:
                                 session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
@@ -304,11 +316,12 @@ class CUDAGraphRunner:
                                 f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
                             )
                             tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                            gr.Info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s", duration=0.75)
                             break
 
                     if (
-                        request.early_stop_num != -1
-                        and (session.y.size(1) - session.y_len) > request.early_stop_num
+                        (request.early_stop_num != -1
+                        and (session.y.size(1) - session.y_len) > request.early_stop_num) or idx == 1499
                     ):
                         for i in range(bsz):
                             if not session.completed[i].item():
@@ -318,14 +331,25 @@ class CUDAGraphRunner:
 
                     with torch_profiler.record("NextPos"):
                         y_emb = decoder.ar_audio_embedding(session.y[:, -1:])
-                        session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
+                        session.xy_pos = decoder.ar_audio_position.forward(self.input_pos - session.x_lens, y_emb)
 
                     if idx == 2:
                         torch_profiler.start()
                         t1 = time.perf_counter()
 
-                    # if idx == 51:
-                    #     torch_profiler.end()
+                    if idx == 51:
+                        torch_profiler.end()
+
+                    if idx % 100 == 0:
+                        match session.device.type:
+                            case "cuda":
+                                torch.cuda.empty_cache()
+                            case "mps":
+                                torch.mps.empty_cache()
+                            case "xpu":
+                                torch.xpu.empty_cache()
+                            case "mtia":
+                                torch.mtia.empty_cache()
 
             match session.device.type:
                 case "cuda":
@@ -336,7 +360,7 @@ class CUDAGraphRunner:
                     torch.xpu.empty_cache()
                 case "mtia":
                     torch.mtia.empty_cache()
-            gc.collect()
+
             torch_profiler.end()
             return session.y_results[: request.valid_length]
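
Because the captured graph bakes in device addresses, every tensor it touches is reused in place between requests: the shared KV cache is emptied rather than reallocated, `input_pos` is refreshed with `copy_` and advanced with `add_`, and `xy_pos_` is `copy_`d into before each `replay()`. A minimal sketch of why rebinding (instead of `copy_`) silently breaks a replayed graph; this requires a CUDA device, and the captured computation `y = x * 2` is just a stand-in:

```python
import torch

x = torch.zeros(4, device="cuda")  # static input buffer
g = torch.cuda.CUDAGraph()

# Warm up on a side stream, then capture y = x * 2 at x's current address.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    y = x * 2
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(g):
    y = x * 2  # y becomes the static output buffer

x.copy_(torch.arange(4.0))  # in-place update: the graph sees it
g.replay()
print(y)  # tensor([0., 2., 4., 6.], device='cuda:0')

x = torch.full((4,), 5.0, device="cuda")  # rebinding allocates a NEW buffer
g.replay()                                # the graph still reads the old address
print(y)  # still tensor([0., 2., 4., 6.]) -- the new x is never read
```
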
 
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🤗
 colorFrom: indigo
 colorTo: red
 sdk: gradio
-sdk_version: 4.44.1
+sdk_version: 5.20.0
 app_file: inference_webui.py
 pinned: false
 license: mit
inference_webui.py CHANGED
@@ -57,6 +57,10 @@ import LangSegment
 import spaces
 import torch
 
+import threading
+
+lock = threading.Lock()
+
 version = "v2"  # os.environ.get("version","v2")
 cnhubert_base_path = os.environ.get("cnhubert_base_path", "pretrained_models/chinese-hubert-base")
 bert_path = os.environ.get("bert_path", "pretrained_models/chinese-roberta-wwm-ext-large")
@@ -540,7 +544,7 @@ def get_tts_wav(
     if i_text in cache and if_freeze == True:
         pred_semantic = cache[i_text]
     else:
-        with torch.no_grad():
+        with torch.no_grad(), lock:
             t2s_request = T2SRequest(
                 [all_phoneme_ids.squeeze(0)],
                 all_phoneme_len,
@@ -552,7 +556,7 @@ def get_tts_wav(
                 temperature=temperature,
                 early_stop_num=1500,
                 use_cuda_graph=True,
-                debug=True,
+                # debug=True,
             )
             t2s_result = t2s_model.generate(t2s_request)
             pred_semantic = t2s_result.result
@@ -836,5 +840,4 @@ if __name__ == "__main__":
         server_name="0.0.0.0",
         inbrowser=True,
         show_api=False,
-        server_port=1111,
     )
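
Since the graph, KV cache, and static buffers are now shared by every request, two concurrent Gradio handlers would corrupt each other's state; the module-level `threading.Lock` serializes inference. A sketch of the resulting call pattern (this wrapper function is hypothetical, mirroring the commit's `with torch.no_grad(), lock:` usage):

```python
import threading

import torch

lock = threading.Lock()  # guards the shared CUDAGraphRunner state

def generate_serialized(t2s_model, request):
    # Only one request at a time may replay the cached graph or mutate the KV cache.
    with torch.no_grad(), lock:
        return t2s_model.generate(request)
```
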