XXXXRT666 commited on
Commit
3011ece
·
1 Parent(s): 4ae2215
AR/models/t2s_model_abc.py CHANGED
@@ -449,7 +449,9 @@ class CUDAGraphCacheABC(ABC):
449
  def assign_graph(self, session: Any):
450
  if self.graph is None:
451
  args, kwds = self.decoder.pre_forward(session)
452
- graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, *args, **kwds)
 
 
453
  self.graph = graph
454
 
455
  if self.assigned is False:
 
449
  def assign_graph(self, session: Any):
450
  if self.graph is None:
451
  args, kwds = self.decoder.pre_forward(session)
452
+ graph = self.decoder.capture(
453
+ self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds
454
+ )
455
  self.graph = graph
456
 
457
  if self.assigned is False:
AR/models/t2s_model_flash_attn.py CHANGED
@@ -239,7 +239,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
239
  session.input_pos = self.input_pos.clone().copy_(session.input_pos)
240
 
241
  args, kwds = self.decoder.pre_forward(session)
242
- graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, *args, **kwds)
243
  session.graph = graph
244
 
245
 
 
239
  session.input_pos = self.input_pos.clone().copy_(session.input_pos)
240
 
241
  args, kwds = self.decoder.pre_forward(session)
242
+ graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds)
243
  session.graph = graph
244
 
245
 
inference_webui.py CHANGED
@@ -38,6 +38,7 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
38
  logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
39
  logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
40
  logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
 
41
 
42
  os.makedirs("pretrained_models", exist_ok=True)
43
 
 
38
  logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
39
  logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
40
  logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
41
+ logging.getLogger("filelock").setLevel(logging.INFO)
42
 
43
  os.makedirs("pretrained_models", exist_ok=True)
44