Spaces:
Running
on
Zero
Running
on
Zero
XXXXRT666
commited on
Commit
·
3011ece
1
Parent(s):
4ae2215
- AR/models/t2s_model_abc.py +3 -1
- AR/models/t2s_model_flash_attn.py +1 -1
- inference_webui.py +1 -0
AR/models/t2s_model_abc.py
CHANGED
@@ -449,7 +449,9 @@ class CUDAGraphCacheABC(ABC):
|
|
449 |
def assign_graph(self, session: Any):
|
450 |
if self.graph is None:
|
451 |
args, kwds = self.decoder.pre_forward(session)
|
452 |
-
graph = self.decoder.capture(
|
|
|
|
|
453 |
self.graph = graph
|
454 |
|
455 |
if self.assigned is False:
|
|
|
449 |
def assign_graph(self, session: Any):
|
450 |
if self.graph is None:
|
451 |
args, kwds = self.decoder.pre_forward(session)
|
452 |
+
graph = self.decoder.capture(
|
453 |
+
self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds
|
454 |
+
)
|
455 |
self.graph = graph
|
456 |
|
457 |
if self.assigned is False:
|
AR/models/t2s_model_flash_attn.py
CHANGED
@@ -239,7 +239,7 @@ class CUDAGraphCache(CUDAGraphCacheABC):
|
|
239 |
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
240 |
|
241 |
args, kwds = self.decoder.pre_forward(session)
|
242 |
-
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, *args, **kwds)
|
243 |
session.graph = graph
|
244 |
|
245 |
|
|
|
239 |
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
240 |
|
241 |
args, kwds = self.decoder.pre_forward(session)
|
242 |
+
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, kv_caches=self.kv_cache, *args, **kwds)
|
243 |
session.graph = graph
|
244 |
|
245 |
|
inference_webui.py
CHANGED
@@ -38,6 +38,7 @@ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
|
38 |
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
39 |
logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
|
40 |
logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
|
|
|
41 |
|
42 |
os.makedirs("pretrained_models", exist_ok=True)
|
43 |
|
|
|
38 |
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
39 |
logging.getLogger("python_multipart.multipart").setLevel(logging.ERROR)
|
40 |
logging.getLogger("split_lang.split.splitter").setLevel(logging.ERROR)
|
41 |
+
logging.getLogger("filelock").setLevel(logging.INFO)
|
42 |
|
43 |
os.makedirs("pretrained_models", exist_ok=True)
|
44 |
|