Nymbo committed
Commit 487b5c4 · verified · 1 Parent(s): 4442589

Update helper.py

Files changed (1)
  1. helper.py +351 -349
helper.py CHANGED
@@ -1,349 +1,351 @@
-import json
-import os
-import time
-from contextlib import contextmanager
-from typing import Optional
-from unicodedata import normalize
-import re
-
-import numpy as np
-import onnxruntime as ort
-
-
-class UnicodeProcessor:
-    def __init__(self, unicode_indexer_path: str):
-        with open(unicode_indexer_path, "r") as f:
-            self.indexer = json.load(f)
-
-    def _preprocess_text(self, text: str) -> str:
-        # TODO: add more preprocessing
-        text = normalize("NFKD", text)
-        return text
-
-    def _get_text_mask(self, text_ids_lengths: np.ndarray) -> np.ndarray:
-        text_mask = length_to_mask(text_ids_lengths)
-        return text_mask
-
-    def _text_to_unicode_values(self, text: str) -> np.ndarray:
-        unicode_values = np.array(
-            [ord(char) for char in text], dtype=np.uint16
-        )  # 2 bytes
-        return unicode_values
-
-    def __call__(self, text_list: list[str]) -> tuple[np.ndarray, np.ndarray]:
-        text_list = [self._preprocess_text(t) for t in text_list]
-        text_ids_lengths = np.array([len(text) for text in text_list], dtype=np.int64)
-        text_ids = np.zeros((len(text_list), text_ids_lengths.max()), dtype=np.int64)
-        for i, text in enumerate(text_list):
-            unicode_vals = self._text_to_unicode_values(text)
-            text_ids[i, : len(unicode_vals)] = np.array(
-                [self.indexer[val] for val in unicode_vals], dtype=np.int64
-            )
-        text_mask = self._get_text_mask(text_ids_lengths)
-        return text_ids, text_mask
-
-
-class Style:
-    def __init__(self, style_ttl_onnx: np.ndarray, style_dp_onnx: np.ndarray):
-        self.ttl = style_ttl_onnx
-        self.dp = style_dp_onnx
-
-
-class TextToSpeech:
-    def __init__(
-        self,
-        cfgs: dict,
-        text_processor: UnicodeProcessor,
-        dp_ort: ort.InferenceSession,
-        text_enc_ort: ort.InferenceSession,
-        vector_est_ort: ort.InferenceSession,
-        vocoder_ort: ort.InferenceSession,
-    ):
-        self.cfgs = cfgs
-        self.text_processor = text_processor
-        self.dp_ort = dp_ort
-        self.text_enc_ort = text_enc_ort
-        self.vector_est_ort = vector_est_ort
-        self.vocoder_ort = vocoder_ort
-        self.sample_rate = cfgs["ae"]["sample_rate"]
-        self.base_chunk_size = cfgs["ae"]["base_chunk_size"]
-        self.chunk_compress_factor = cfgs["ttl"]["chunk_compress_factor"]
-        self.ldim = cfgs["ttl"]["latent_dim"]
-
-    def sample_noisy_latent(
-        self, duration: np.ndarray
-    ) -> tuple[np.ndarray, np.ndarray]:
-        bsz = len(duration)
-        wav_len_max = duration.max() * self.sample_rate
-        wav_lengths = (duration * self.sample_rate).astype(np.int64)
-        chunk_size = self.base_chunk_size * self.chunk_compress_factor
-        latent_len = ((wav_len_max + chunk_size - 1) / chunk_size).astype(np.int32)
-        latent_dim = self.ldim * self.chunk_compress_factor
-        noisy_latent = np.random.randn(bsz, latent_dim, latent_len).astype(np.float32)
-        latent_mask = get_latent_mask(
-            wav_lengths, self.base_chunk_size, self.chunk_compress_factor
-        )
-
-        noisy_latent = noisy_latent * latent_mask
-        return noisy_latent, latent_mask
-
-    def _infer(
-        self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
-    ) -> tuple[np.ndarray, np.ndarray]:
-        assert (
-            len(text_list) == style.ttl.shape[0]
-        ), "Number of texts must match number of style vectors"
-        bsz = len(text_list)
-        text_ids, text_mask = self.text_processor(text_list)
-        dur_onnx, *_ = self.dp_ort.run(
-            None, {"text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask}
-        )
-        dur_onnx = dur_onnx / speed
-        text_emb_onnx, *_ = self.text_enc_ort.run(
-            None,
-            {"text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask},
-        )  # dur_onnx: [bsz]
-        xt, latent_mask = self.sample_noisy_latent(dur_onnx)
-        total_step_np = np.array([total_step] * bsz, dtype=np.float32)
-        for step in range(total_step):
-            current_step = np.array([step] * bsz, dtype=np.float32)
-            xt, *_ = self.vector_est_ort.run(
-                None,
-                {
-                    "noisy_latent": xt,
-                    "text_emb": text_emb_onnx,
-                    "style_ttl": style.ttl,
-                    "text_mask": text_mask,
-                    "latent_mask": latent_mask,
-                    "current_step": current_step,
-                    "total_step": total_step_np,
-                },
-            )
-        wav, *_ = self.vocoder_ort.run(None, {"latent": xt})
-        return wav, dur_onnx
-
-    def __call__(
-        self,
-        text: str,
-        style: Style,
-        total_step: int,
-        speed: float = 1.05,
-        silence_duration: float = 0.3,
-    ) -> tuple[np.ndarray, np.ndarray]:
-        assert (
-            style.ttl.shape[0] == 1
-        ), "Single speaker text to speech only supports single style"
-        text_list = chunk_text(text)
-        wav_cat = None
-        dur_cat = None
-        for text in text_list:
-            wav, dur_onnx = self._infer([text], style, total_step, speed)
-            if wav_cat is None:
-                wav_cat = wav
-                dur_cat = dur_onnx
-            else:
-                silence = np.zeros(
-                    (1, int(silence_duration * self.sample_rate)), dtype=np.float32
-                )
-                wav_cat = np.concatenate([wav_cat, silence, wav], axis=1)
-                dur_cat += dur_onnx + silence_duration
-        return wav_cat, dur_cat
-
-    def stream(
-        self,
-        text: str,
-        style: Style,
-        total_step: int,
-        speed: float = 1.05,
-        silence_duration: float = 0.3,
-    ):
-        assert (
-            style.ttl.shape[0] == 1
-        ), "Single speaker text to speech only supports single style"
-        text_list = chunk_text(text)
-
-        for i, text in enumerate(text_list):
-            wav, _ = self._infer([text], style, total_step, speed)
-            yield wav.flatten()
-
-            if i < len(text_list) - 1:
-                silence = np.zeros(
-                    (int(silence_duration * self.sample_rate),), dtype=np.float32
-                )
-                yield silence
-
-    def batch(
-        self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
-    ) -> tuple[np.ndarray, np.ndarray]:
-        return self._infer(text_list, style, total_step, speed)
-
-
-def length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
-    """
-    Convert lengths to binary mask.
-
-    Args:
-        lengths: (B,)
-        max_len: int
-
-    Returns:
-        mask: (B, 1, max_len)
-    """
-    max_len = max_len or lengths.max()
-    ids = np.arange(0, max_len)
-    mask = (ids < np.expand_dims(lengths, axis=1)).astype(np.float32)
-    return mask.reshape(-1, 1, max_len)
-
-
-def get_latent_mask(
-    wav_lengths: np.ndarray, base_chunk_size: int, chunk_compress_factor: int
-) -> np.ndarray:
-    latent_size = base_chunk_size * chunk_compress_factor
-    latent_lengths = (wav_lengths + latent_size - 1) // latent_size
-    latent_mask = length_to_mask(latent_lengths)
-    return latent_mask
-
-
-def load_onnx(
-    onnx_path: str, opts: ort.SessionOptions, providers: list[str]
-) -> ort.InferenceSession:
-    return ort.InferenceSession(onnx_path, sess_options=opts, providers=providers)
-
-
-def load_onnx_all(
-    onnx_dir: str, opts: ort.SessionOptions, providers: list[str]
-) -> tuple[
-    ort.InferenceSession,
-    ort.InferenceSession,
-    ort.InferenceSession,
-    ort.InferenceSession,
-]:
-    dp_onnx_path = os.path.join(onnx_dir, "duration_predictor.onnx")
-    text_enc_onnx_path = os.path.join(onnx_dir, "text_encoder.onnx")
-    vector_est_onnx_path = os.path.join(onnx_dir, "vector_estimator.onnx")
-    vocoder_onnx_path = os.path.join(onnx_dir, "vocoder.onnx")
-
-    dp_ort = load_onnx(dp_onnx_path, opts, providers)
-    text_enc_ort = load_onnx(text_enc_onnx_path, opts, providers)
-    vector_est_ort = load_onnx(vector_est_onnx_path, opts, providers)
-    vocoder_ort = load_onnx(vocoder_onnx_path, opts, providers)
-    return dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
-
-
-def load_cfgs(onnx_dir: str) -> dict:
-    cfg_path = os.path.join(onnx_dir, "tts.json")
-    with open(cfg_path, "r") as f:
-        cfgs = json.load(f)
-    return cfgs
-
-
-def load_text_processor(onnx_dir: str) -> UnicodeProcessor:
-    unicode_indexer_path = os.path.join(onnx_dir, "unicode_indexer.json")
-    text_processor = UnicodeProcessor(unicode_indexer_path)
-    return text_processor
-
-
-def load_text_to_speech(onnx_dir: str, use_gpu: bool = False) -> TextToSpeech:
-    opts = ort.SessionOptions()
-    if use_gpu:
-        raise NotImplementedError("GPU mode is not fully tested")
-    else:
-        providers = ["CPUExecutionProvider"]
-        print("Using CPU for inference")
-    cfgs = load_cfgs(onnx_dir)
-    dp_ort, text_enc_ort, vector_est_ort, vocoder_ort = load_onnx_all(
-        onnx_dir, opts, providers
-    )
-    text_processor = load_text_processor(onnx_dir)
-    return TextToSpeech(
-        cfgs, text_processor, dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
-    )
-
-
-def load_voice_style(voice_style_paths: list[str], verbose: bool = False) -> Style:
-    bsz = len(voice_style_paths)
-
-    # Read first file to get dimensions
-    with open(voice_style_paths[0], "r") as f:
-        first_style = json.load(f)
-    ttl_dims = first_style["style_ttl"]["dims"]
-    dp_dims = first_style["style_dp"]["dims"]
-
-    # Pre-allocate arrays with full batch size
-    ttl_style = np.zeros([bsz, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
-    dp_style = np.zeros([bsz, dp_dims[1], dp_dims[2]], dtype=np.float32)
-
-    # Fill in the data
-    for i, voice_style_path in enumerate(voice_style_paths):
-        with open(voice_style_path, "r") as f:
-            voice_style = json.load(f)
-
-        ttl_data = np.array(
-            voice_style["style_ttl"]["data"], dtype=np.float32
-        ).flatten()
-        ttl_style[i] = ttl_data.reshape(ttl_dims[1], ttl_dims[2])
-
-        dp_data = np.array(
-            voice_style["style_dp"]["data"], dtype=np.float32
-        ).flatten()
-        dp_style[i] = dp_data.reshape(dp_dims[1], dp_dims[2])
-
-    if verbose:
-        print(f"Loaded {bsz} voice styles")
-    return Style(ttl_style, dp_style)
-
-
-@contextmanager
-def timer(name: str):
-    start = time.time()
-    print(f"{name}...")
-    yield
-    print(f" -> {name} completed in {time.time() - start:.2f} sec")
-
-
-def sanitize_filename(text: str, max_len: int) -> str:
-    """Sanitize filename by replacing non-alphanumeric characters with underscores"""
-    prefix = text[:max_len]
-    return re.sub(r"[^a-zA-Z0-9]", "_", prefix)
-
-
-def chunk_text(text: str, max_len: int = 300) -> list[str]:
-    """
-    Split text into chunks by paragraphs and sentences.
-
-    Args:
-        text: Input text to chunk
-        max_len: Maximum length of each chunk (default: 300)
-
-    Returns:
-        List of text chunks
-    """
-    # Split by paragraph (two or more newlines)
-    paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text.strip()) if p.strip()]
-
-    chunks = []
-
-    for paragraph in paragraphs:
-        paragraph = paragraph.strip()
-        if not paragraph:
-            continue
-
-        # Split by sentence boundaries (period, question mark, exclamation mark followed by space)
-        # But exclude common abbreviations like Mr., Mrs., Dr., etc. and single capital letters like F.
-        pattern = r"(?<!Mr\.)(?<!Mrs\.)(?<!Ms\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)(?<!Ph\.D\.)(?<!etc\.)(?<!e\.g\.)(?<!i\.e\.)(?<!vs\.)(?<!Inc\.)(?<!Ltd\.)(?<!Co\.)(?<!Corp\.)(?<!St\.)(?<!Ave\.)(?<!Blvd\.)(?<!\b[A-Z]\.)(?<=[.!?])\s+"
-        sentences = re.split(pattern, paragraph)
-
-        current_chunk = ""
-
-        for sentence in sentences:
-            if len(current_chunk) + len(sentence) + 1 <= max_len:
-                current_chunk += (" " if current_chunk else "") + sentence
-            else:
-                if current_chunk:
-                    chunks.append(current_chunk.strip())
-                current_chunk = sentence
-
-        if current_chunk:
-            chunks.append(current_chunk.strip())
-
-    return chunks
+import json
+import os
+import time
+from contextlib import contextmanager
+from typing import Optional
+from unicodedata import normalize
+import re
+
+import numpy as np
+import onnxruntime as ort
+
+
+class UnicodeProcessor:
+    def __init__(self, unicode_indexer_path: str):
+        with open(unicode_indexer_path, "r") as f:
+            self.indexer = json.load(f)
+
+    def _preprocess_text(self, text: str) -> str:
+        # TODO: add more preprocessing
+        text = normalize("NFKD", text)
+        return text
+
+    def _get_text_mask(self, text_ids_lengths: np.ndarray) -> np.ndarray:
+        text_mask = length_to_mask(text_ids_lengths)
+        return text_mask
+
+    def _text_to_unicode_values(self, text: str) -> np.ndarray:
+        unicode_values = np.array(
+            [ord(char) for char in text], dtype=np.uint16
+        )  # 2 bytes
+        return unicode_values
+
+    def __call__(self, text_list: list[str]) -> tuple[np.ndarray, np.ndarray]:
+        text_list = [self._preprocess_text(t) for t in text_list]
+        text_ids_lengths = np.array([len(text) for text in text_list], dtype=np.int64)
+        text_ids = np.zeros((len(text_list), text_ids_lengths.max()), dtype=np.int64)
+        for i, text in enumerate(text_list):
+            unicode_vals = self._text_to_unicode_values(text)
+            text_ids[i, : len(unicode_vals)] = np.array(
+                [self.indexer[val] for val in unicode_vals], dtype=np.int64
+            )
+        text_mask = self._get_text_mask(text_ids_lengths)
+        return text_ids, text_mask
+
+
+class Style:
+    def __init__(self, style_ttl_onnx: np.ndarray, style_dp_onnx: np.ndarray):
+        self.ttl = style_ttl_onnx
+        self.dp = style_dp_onnx
+
+
+class TextToSpeech:
+    def __init__(
+        self,
+        cfgs: dict,
+        text_processor: UnicodeProcessor,
+        dp_ort: ort.InferenceSession,
+        text_enc_ort: ort.InferenceSession,
+        vector_est_ort: ort.InferenceSession,
+        vocoder_ort: ort.InferenceSession,
+    ):
+        self.cfgs = cfgs
+        self.text_processor = text_processor
+        self.dp_ort = dp_ort
+        self.text_enc_ort = text_enc_ort
+        self.vector_est_ort = vector_est_ort
+        self.vocoder_ort = vocoder_ort
+        self.sample_rate = cfgs["ae"]["sample_rate"]
+        self.base_chunk_size = cfgs["ae"]["base_chunk_size"]
+        self.chunk_compress_factor = cfgs["ttl"]["chunk_compress_factor"]
+        self.ldim = cfgs["ttl"]["latent_dim"]
+
+    def sample_noisy_latent(
+        self, duration: np.ndarray
+    ) -> tuple[np.ndarray, np.ndarray]:
+        bsz = len(duration)
+        wav_len_max = duration.max() * self.sample_rate
+        wav_lengths = (duration * self.sample_rate).astype(np.int64)
+        chunk_size = self.base_chunk_size * self.chunk_compress_factor
+        latent_len = ((wav_len_max + chunk_size - 1) / chunk_size).astype(np.int32)
+        latent_dim = self.ldim * self.chunk_compress_factor
+        noisy_latent = np.random.randn(bsz, latent_dim, latent_len).astype(np.float32)
+        latent_mask = get_latent_mask(
+            wav_lengths, self.base_chunk_size, self.chunk_compress_factor
+        )
+
+        noisy_latent = noisy_latent * latent_mask
+        return noisy_latent, latent_mask
+
+    def _infer(
+        self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
+    ) -> tuple[np.ndarray, np.ndarray]:
+        assert (
+            len(text_list) == style.ttl.shape[0]
+        ), "Number of texts must match number of style vectors"
+        bsz = len(text_list)
+        text_ids, text_mask = self.text_processor(text_list)
+        dur_onnx, *_ = self.dp_ort.run(
+            None, {"text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask}
+        )
+        dur_onnx = dur_onnx / speed
+        text_emb_onnx, *_ = self.text_enc_ort.run(
+            None,
+            {"text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask},
+        )  # dur_onnx: [bsz]
+        xt, latent_mask = self.sample_noisy_latent(dur_onnx)
+        total_step_np = np.array([total_step] * bsz, dtype=np.float32)
+        for step in range(total_step):
+            current_step = np.array([step] * bsz, dtype=np.float32)
+            xt, *_ = self.vector_est_ort.run(
+                None,
+                {
+                    "noisy_latent": xt,
+                    "text_emb": text_emb_onnx,
+                    "style_ttl": style.ttl,
+                    "text_mask": text_mask,
+                    "latent_mask": latent_mask,
+                    "current_step": current_step,
+                    "total_step": total_step_np,
+                },
+            )
+        wav, *_ = self.vocoder_ort.run(None, {"latent": xt})
+        return wav, dur_onnx
+
+    def __call__(
+        self,
+        text: str,
+        style: Style,
+        total_step: int,
+        speed: float = 1.05,
+        silence_duration: float = 0.3,
+        max_len: int = 300,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        assert (
+            style.ttl.shape[0] == 1
+        ), "Single speaker text to speech only supports single style"
+        text_list = chunk_text(text, max_len=max_len)
+        wav_cat = None
+        dur_cat = None
+        for text in text_list:
+            wav, dur_onnx = self._infer([text], style, total_step, speed)
+            if wav_cat is None:
+                wav_cat = wav
+                dur_cat = dur_onnx
+            else:
+                silence = np.zeros(
+                    (1, int(silence_duration * self.sample_rate)), dtype=np.float32
+                )
+                wav_cat = np.concatenate([wav_cat, silence, wav], axis=1)
+                dur_cat += dur_onnx + silence_duration
+        return wav_cat, dur_cat
+
+    def stream(
+        self,
+        text: str,
+        style: Style,
+        total_step: int,
+        speed: float = 1.05,
+        silence_duration: float = 0.3,
+        max_len: int = 300,
+    ):
+        assert (
+            style.ttl.shape[0] == 1
+        ), "Single speaker text to speech only supports single style"
+        text_list = chunk_text(text, max_len=max_len)
+
+        for i, text in enumerate(text_list):
+            wav, _ = self._infer([text], style, total_step, speed)
+            yield wav.flatten()
+
+            if i < len(text_list) - 1:
+                silence = np.zeros(
+                    (int(silence_duration * self.sample_rate),), dtype=np.float32
+                )
+                yield silence
+
+    def batch(
+        self, text_list: list[str], style: Style, total_step: int, speed: float = 1.05
+    ) -> tuple[np.ndarray, np.ndarray]:
+        return self._infer(text_list, style, total_step, speed)
+
+
+def length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
+    """
+    Convert lengths to binary mask.
+
+    Args:
+        lengths: (B,)
+        max_len: int
+
+    Returns:
+        mask: (B, 1, max_len)
+    """
+    max_len = max_len or lengths.max()
+    ids = np.arange(0, max_len)
+    mask = (ids < np.expand_dims(lengths, axis=1)).astype(np.float32)
+    return mask.reshape(-1, 1, max_len)
+
+
+def get_latent_mask(
+    wav_lengths: np.ndarray, base_chunk_size: int, chunk_compress_factor: int
+) -> np.ndarray:
+    latent_size = base_chunk_size * chunk_compress_factor
+    latent_lengths = (wav_lengths + latent_size - 1) // latent_size
+    latent_mask = length_to_mask(latent_lengths)
+    return latent_mask
+
+
+def load_onnx(
+    onnx_path: str, opts: ort.SessionOptions, providers: list[str]
+) -> ort.InferenceSession:
+    return ort.InferenceSession(onnx_path, sess_options=opts, providers=providers)
+
+
+def load_onnx_all(
+    onnx_dir: str, opts: ort.SessionOptions, providers: list[str]
+) -> tuple[
+    ort.InferenceSession,
+    ort.InferenceSession,
+    ort.InferenceSession,
+    ort.InferenceSession,
+]:
+    dp_onnx_path = os.path.join(onnx_dir, "duration_predictor.onnx")
+    text_enc_onnx_path = os.path.join(onnx_dir, "text_encoder.onnx")
+    vector_est_onnx_path = os.path.join(onnx_dir, "vector_estimator.onnx")
+    vocoder_onnx_path = os.path.join(onnx_dir, "vocoder.onnx")
+
+    dp_ort = load_onnx(dp_onnx_path, opts, providers)
+    text_enc_ort = load_onnx(text_enc_onnx_path, opts, providers)
+    vector_est_ort = load_onnx(vector_est_onnx_path, opts, providers)
+    vocoder_ort = load_onnx(vocoder_onnx_path, opts, providers)
+    return dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
+
+
+def load_cfgs(onnx_dir: str) -> dict:
+    cfg_path = os.path.join(onnx_dir, "tts.json")
+    with open(cfg_path, "r") as f:
+        cfgs = json.load(f)
+    return cfgs
+
+
+def load_text_processor(onnx_dir: str) -> UnicodeProcessor:
+    unicode_indexer_path = os.path.join(onnx_dir, "unicode_indexer.json")
+    text_processor = UnicodeProcessor(unicode_indexer_path)
+    return text_processor
+
+
+def load_text_to_speech(onnx_dir: str, use_gpu: bool = False) -> TextToSpeech:
+    opts = ort.SessionOptions()
+    if use_gpu:
+        raise NotImplementedError("GPU mode is not fully tested")
+    else:
+        providers = ["CPUExecutionProvider"]
+        print("Using CPU for inference")
+    cfgs = load_cfgs(onnx_dir)
+    dp_ort, text_enc_ort, vector_est_ort, vocoder_ort = load_onnx_all(
+        onnx_dir, opts, providers
+    )
+    text_processor = load_text_processor(onnx_dir)
+    return TextToSpeech(
+        cfgs, text_processor, dp_ort, text_enc_ort, vector_est_ort, vocoder_ort
+    )
+
+
+def load_voice_style(voice_style_paths: list[str], verbose: bool = False) -> Style:
+    bsz = len(voice_style_paths)
+
+    # Read first file to get dimensions
+    with open(voice_style_paths[0], "r") as f:
+        first_style = json.load(f)
+    ttl_dims = first_style["style_ttl"]["dims"]
+    dp_dims = first_style["style_dp"]["dims"]
+
+    # Pre-allocate arrays with full batch size
+    ttl_style = np.zeros([bsz, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
+    dp_style = np.zeros([bsz, dp_dims[1], dp_dims[2]], dtype=np.float32)
+
+    # Fill in the data
+    for i, voice_style_path in enumerate(voice_style_paths):
+        with open(voice_style_path, "r") as f:
+            voice_style = json.load(f)
+
+        ttl_data = np.array(
+            voice_style["style_ttl"]["data"], dtype=np.float32
+        ).flatten()
+        ttl_style[i] = ttl_data.reshape(ttl_dims[1], ttl_dims[2])
+
+        dp_data = np.array(
+            voice_style["style_dp"]["data"], dtype=np.float32
+        ).flatten()
+        dp_style[i] = dp_data.reshape(dp_dims[1], dp_dims[2])
+
+    if verbose:
+        print(f"Loaded {bsz} voice styles")
+    return Style(ttl_style, dp_style)
+
+
+@contextmanager
+def timer(name: str):
+    start = time.time()
+    print(f"{name}...")
+    yield
+    print(f" -> {name} completed in {time.time() - start:.2f} sec")
+
+
+def sanitize_filename(text: str, max_len: int) -> str:
+    """Sanitize filename by replacing non-alphanumeric characters with underscores"""
+    prefix = text[:max_len]
+    return re.sub(r"[^a-zA-Z0-9]", "_", prefix)
+
+
+def chunk_text(text: str, max_len: int = 300) -> list[str]:
+    """
+    Split text into chunks by paragraphs and sentences.
+
+    Args:
+        text: Input text to chunk
+        max_len: Maximum length of each chunk (default: 300)
+
+    Returns:
+        List of text chunks
+    """
+    # Split by paragraph (two or more newlines)
+    paragraphs = [p.strip() for p in re.split(r"\n\s*\n+", text.strip()) if p.strip()]
+
+    chunks = []
+
+    for paragraph in paragraphs:
+        paragraph = paragraph.strip()
+        if not paragraph:
+            continue
+
+        # Split by sentence boundaries (period, question mark, exclamation mark followed by space)
+        # But exclude common abbreviations like Mr., Mrs., Dr., etc. and single capital letters like F.
+        pattern = r"(?<!Mr\.)(?<!Mrs\.)(?<!Ms\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)(?<!Ph\.D\.)(?<!etc\.)(?<!e\.g\.)(?<!i\.e\.)(?<!vs\.)(?<!Inc\.)(?<!Ltd\.)(?<!Co\.)(?<!Corp\.)(?<!St\.)(?<!Ave\.)(?<!Blvd\.)(?<!\b[A-Z]\.)(?<=[.!?])\s+"
+        sentences = re.split(pattern, paragraph)
+
+        current_chunk = ""
+
+        for sentence in sentences:
+            if len(current_chunk) + len(sentence) + 1 <= max_len:
+                current_chunk += (" " if current_chunk else "") + sentence
+            else:
+                if current_chunk:
+                    chunks.append(current_chunk.strip())
+                current_chunk = sentence
+
+        if current_chunk:
+            chunks.append(current_chunk.strip())
+
+    return chunks
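
The net change: `TextToSpeech.__call__` and `TextToSpeech.stream` gain a `max_len` parameter (default 300, matching the old hard-coded behavior) that is forwarded to `chunk_text`, so callers can tune how aggressively input text is split before inference. A minimal usage sketch of the updated API follows; the `onnx/` model directory, `voice.json` style file, the step count, and the `soundfile` dependency for writing the output are illustrative assumptions, not part of this commit:

```python
import soundfile as sf  # assumed audio I/O dependency; helper.py itself does not use it

from helper import load_text_to_speech, load_voice_style

tts = load_text_to_speech("onnx")         # loads tts.json plus the four ONNX sessions
style = load_voice_style(["voice.json"])  # batch size 1, as __call__/stream require

# Smaller max_len means shorter chunks per ONNX call (lower per-chunk latency),
# at the cost of more silence joins between chunks.
wav, dur = tts(
    "Hello there. This is a short demo of the updated helper.",
    style,
    total_step=16,  # number of vector-estimator refinement steps (illustrative value)
    max_len=120,    # new in this commit: forwarded to chunk_text
)
sf.write("out.wav", wav.flatten(), tts.sample_rate)

# Streaming variant: yields one flattened waveform chunk at a time,
# with a silence buffer between consecutive chunks.
for i, chunk in enumerate(tts.stream("A longer text ...", style, total_step=16, max_len=200)):
    sf.write(f"chunk_{i}.wav", chunk, tts.sample_rate)  # or feed each chunk to an audio sink
```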