MCplayer committed on
Commit 29c0409 · 1 Parent(s): fe8f545

speech similarity model

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.trt filter=lfs diff=lfs merge=lfs -text
37
+ *.pt.*.partial filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ __pycache__/
2
+ .venv/
3
+ *.pyc
4
+ .DS_Store
5
+ outputs/
6
+ logs/
7
+ eval_res/
README.md CHANGED
@@ -1,3 +1,137 @@
  ---
  license: apache-2.0
  ---
+
+ # Speaker Similarity Evaluator - New Input Format Guide
+
+ ## Overview
+
+ The speaker (timbre) similarity evaluator now reads its input from a JSONL file and supports two prompt-audio input modes:
+ 1. **Pre-split mode**: provide separate audio files for S1 and S2
+ 2. **Auto-split mode**: provide a combined prompt audio; the program splits it automatically by speaker tag
+
+ ## Input Format
+
+ ### JSONL file format
+
+ Each line is a JSON object that must contain the following fields:
+
+ #### Required fields
+ - `text`: the text to be evaluated, containing the speaker tags [S1][S2]
+ - `output_audio`: path to the audio file to be evaluated
+
+ #### Prompt audio fields (choose one of the two modes)
+
+ **Mode 1: pre-split**
+ - `prompt_audio_speaker1`: audio file for speaker S1
+ - `prompt_text_speaker1`: text for speaker S1
+ - `prompt_audio_speaker2`: audio file for speaker S2
+ - `prompt_text_speaker2`: text for speaker S2
+
+ **Mode 2: auto-split**
+ - `prompt_audio`: combined audio file containing both speakers
+ - `prompt_text`: text with speaker tags, e.g. "[S1]text one[S2]text two"
+
+ ### Examples
+
+ #### Pre-split mode example
+ ```json
+ {
+     "text": "[S1]是我对不住你。[S2]没有没有!燕子幸亏咱俩没领证!",
+     "prompt_audio_speaker1": "/path/to/speaker1.wav",
+     "prompt_text_speaker1": "一共二十万我都记着呢。我一赚到钱就马上还给你。",
+     "prompt_audio_speaker2": "/path/to/speaker2.wav",
+     "prompt_text_speaker2": "没关系,我不缺钱。",
+     "output_audio": "/path/to/output.wav"
+ }
+ ```
+
+ #### Auto-split mode example
+ ```json
+ {
+     "text": "[S1]今天天气真好啊。[S2]是的,阳光明媚。",
+     "prompt_audio": "/path/to/combined_prompt.wav",
+     "prompt_text": "[S1]早上好,今天怎么样?[S2]很好,谢谢你的关心。",
+     "output_audio": "/path/to/output.wav"
+ }
+ ```
+
+ #### Mixed-mode example (both modes provided; pre-split takes precedence)
+ ```json
+ {
+     "text": "[S1]是我对不住你。[S2]没有没有!",
+     "prompt_audio": "/path/to/combined.wav",
+     "prompt_text": "[S1]一共二十万我都记着呢。[S2]没关系,我不缺钱。",
+     "prompt_audio_speaker1": "/path/to/speaker1.wav",
+     "prompt_text_speaker1": "一共二十万我都记着呢。我一赚到钱就马上还给你。",
+     "prompt_audio_speaker2": "/path/to/speaker2.wav",
+     "prompt_text_speaker2": "没关系,我不缺钱。",
+     "output_audio": "/path/to/output.wav"
+ }
+ ```
+
+ ## Usage
+
+ ### Command line
+
+ ```bash
+ # Read input from a JSONL file
+ python test.py --jsonl_path /path/to/your/input.jsonl --output_dir /path/to/results
+
+ # Use the built-in example data (backward compatible)
+ python test.py --output_dir /path/to/results
+ ```
+
+ ### Programmatic use
+
+ ```python
+ from test import SpeakerSimilarityEvaluator
+
+ # Create the evaluator
+ evaluator = SpeakerSimilarityEvaluator(output_dir="/path/to/results")
+
+ # Process a JSONL file
+ evaluator.process_batch_from_jsonl("/path/to/input.jsonl")
+
+ # Or pass a list of records directly (legacy interface, backward compatible)
+ input_data = [
+     {
+         'prompt_audio': "/path/to/prompt.wav",
+         'prompt_text': "[S1]text one[S2]text two",
+         'text': "[S1]output text one[S2]output text two",
+         'output_audio': "/path/to/output.wav"
+     }
+ ]
+ evaluator.process_batch(input_data)
+ ```
+
+ ## Advantages
+
+ ### Advantages of pre-split mode
+ 1. **Higher accuracy**: avoids errors introduced by automatic splitting
+ 2. **Faster**: skips the audio-splitting step
+ 3. **More robust**: does not depend on the accuracy of the word-alignment model
+
+ ### Advantages of auto-split mode
+ 1. **Convenience**: only a single combined audio file is needed
+ 2. **Backward compatibility**: works with the existing data format
+
+ ## Output File Structure
+
+ ```
+ results_YYYYMMDD_HHMMSS/
+ ├── segments/          # audio segments produced by splitting
+ ├── prompts/           # S1 and S2 segments of the prompt audio (auto-split mode only)
+ ├── temp/              # temporary files (cleared at the end of the run)
+ └── results/           # evaluation results
+     ├── speaker_similarity_results_YYYYMMDD_HHMMSS.jsonl
+     └── evaluation_summary_YYYYMMDD_HHMMSS.json
+ ```
+
+ ## Notes
+
+ 1. Make sure every audio file path is correct and the file exists
+ 2. Speaker tags in the text must use the `[S1]` and `[S2]` format
+ 3. If a record provides both modes, the pre-split mode takes precedence
+ 4. Every line of the JSONL file must be valid JSON
+ 5. Each record is validated before processing; malformed lines are skipped and processing continues (a minimal validation sketch follows below)
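The field requirements and mode precedence above can be checked before an evaluation run. The sketch below is illustrative only (it is not part of this repository, and `load_records` is a hypothetical helper name); it mirrors the required fields and the pre-split-over-auto-split precedence described in the notes:

```python
import json

REQUIRED  = ("text", "output_audio")
PRESPLIT  = ("prompt_audio_speaker1", "prompt_text_speaker1",
             "prompt_audio_speaker2", "prompt_text_speaker2")
AUTOSPLIT = ("prompt_audio", "prompt_text")

def load_records(jsonl_path):
    """Yield (record, mode) pairs, skipping malformed lines."""
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"line {line_no}: invalid JSON ({e}), skipping")
                continue
            if any(k not in rec for k in REQUIRED):
                print(f"line {line_no}: missing required field, skipping")
                continue
            # Pre-split fields take precedence when both modes are present.
            if all(k in rec for k in PRESPLIT):
                yield rec, "pre-split"
            elif all(k in rec for k in AUTOSPLIT):
                yield rec, "auto-split"
            else:
                print(f"line {line_no}: no complete prompt mode, skipping")

for record, mode in load_records("example_input.jsonl"):
    print(mode, record["output_audio"])
```

Records that fail these checks are reported with their line number, matching the skip-and-continue behaviour described in note 5.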
alignment.py ADDED
@@ -0,0 +1,384 @@
1
+ import re
2
+ import torch
3
+ import torchaudio.functional as F
4
+ import torchaudio
5
+ import uroman as ur
6
+ import logging
7
+ import traceback
8
+
9
+ def convert_to_list_with_punctuation_mixed(text):
10
+ """处理中文文本(可能包含英文单词) - 中文按字符分割,英文单词保持完整"""
11
+ result = []
12
+ text = text.strip()
13
+
14
+ if not text:
15
+ return result
16
+
17
+ def is_chinese(char):
18
+ """检查是否是汉字"""
19
+ return '\u4e00' <= char <= '\u9fff'
20
+
21
+ # 使用更精确的正则表达式来分割文本
22
+ # 匹配:英文单词(含数字)、单个汉字、标点符号
23
+ pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
24
+ tokens = re.findall(pattern, text)
25
+
26
+ for token in tokens:
27
+ if not token.strip(): # 跳过空字符
28
+ continue
29
+
30
+ if re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token): # 英文单词(可能包含数字)
31
+ result.append(token)
32
+ elif is_chinese(token): # 单个汉字
33
+ result.append(token)
34
+ else: # 标点符号等其他字符
35
+ # 标点符号加到前一个词后面
36
+ if result:
37
+ result[-1] += token
38
+ else:
39
+ # 如果是文本开头的标点,单独作为一项
40
+ result.append(token)
41
+
42
+ return result
43
+
44
+ def split_and_merge_punctuation(text):
45
+ """处理英文 - 按单词分割,保持单词完整性"""
46
+ # 先按空格拆分文本
47
+ elements = text.split()
48
+
49
+ # 用于保存最终的结果
50
+ result = []
51
+
52
+ # 遍历每个拆分后的元素
53
+ for ele in elements:
54
+ # 使用正则表达式提取连续字母、数字和标点
55
+ parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)
56
+
57
+ # 用于保存拆分后的部分
58
+ merged_parts = []
59
+
60
+ for i in range(len(parts)):
61
+ if i % 2 == 0: # 如果是字母或数字部分
62
+ # 将字母或数字部分添加到结果中
63
+ merged_parts.append(parts[i])
64
+ else: # 如果是标点或其他符号部分
65
+ # 将标点部分与前面的字母或数字部分合并
66
+ if merged_parts:
67
+ merged_parts[-1] += parts[i]
68
+ else:
69
+ merged_parts.append(parts[i])
70
+
71
+ # 将合并后的部分加入最终结果
72
+ result.extend(merged_parts)
73
+
74
+ return result
75
+
76
+
77
+ def get_aligned_result_text_with_punctuation(alignment_result, text, language):
78
+ """
79
+ Map the alignment result back onto the proper text tokens: English stays word-level, Chinese stays character-level (with embedded English words kept whole)
80
+ """
81
+ logging.info("start change text to text_tokens")
82
+
83
+ if language == "EN":
84
+ text_tokens = split_and_merge_punctuation(text) # 英文按单词分词
85
+ elif language == "ZH":
86
+ text_tokens = convert_to_list_with_punctuation_mixed(text) # 中文按字符分割,但英文单词保持完整
87
+ else:
88
+ raise ValueError(f"Unsupported language: {language}")
89
+
90
+ logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")
91
+
92
+ punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》''""\、')
93
+
94
+ logging.info("start get align result text with punctuation")
95
+ updated_alignment_result = []
96
+ token_idx = 0
97
+
98
+ for index, align_item in enumerate(alignment_result):
99
+ if token_idx >= len(text_tokens):
100
+ # 如果text_tokens用完了但还有对齐结果,跳出循环
101
+ logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
102
+ break
103
+
104
+ start = align_item["start"]
105
+ end = align_item["end"]
106
+ text_token = text_tokens[token_idx]
107
+
108
+ # Check whether the token is followed by consecutive punctuation (Chinese only)
109
+ if language == "ZH":
110
+ while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
111
+ assert False, "???" # 这里理论上应该进不去??
112
+ text_token += text_tokens[token_idx + 1]  # append the punctuation
113
+ token_idx += 1
114
+ else:
115
+ # English needs no special punctuation handling here; it was already handled in split_and_merge_punctuation
116
+ pass
117
+
118
+ # 更新对齐结果
119
+ updated_item = {
120
+ "start": start,
121
+ "end": end,
122
+ "transcript": text_token
123
+ }
124
+ updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})
125
+
126
+ updated_alignment_result.append(updated_item)
127
+ token_idx += 1
128
+
129
+ logging.info("end get align result text with punctuation")
130
+ return updated_alignment_result
131
+
132
+
133
+ class AlignmentModel:
134
+ def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
135
+ """
136
+ Initialize the alignment model and load the required resources
137
+ :param device: device type ("cuda" or "cpu")
138
+ :param model_dir: path to the model directory
139
+ """
140
+ self.device = torch.device(device)
141
+ self.bundle = torchaudio.pipelines.MMS_FA
142
+ self.align_model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)
143
+ self.uroman = ur.Uroman()
144
+ self.DICTIONARY = self.bundle.get_dict()
145
+
146
+ def align(self, emission, tokens):
147
+ """
148
+ Run forced alignment
149
+ :param emission: model output (log-probabilities)
150
+ :param tokens: target tokens
151
+ :return: aligned tokens and their scores
152
+ """
153
+ alignments, scores = F.forced_align(
154
+ log_probs=emission,
155
+ targets=tokens,
156
+ blank=0
157
+ )
158
+ alignments, scores = alignments[0], scores[0]
159
+ scores = scores.exp()
160
+ return alignments, scores
161
+
162
+ def unflatten(self, list_, lengths):
163
+ """
164
+ Split a flat list into sublists of the given lengths
165
+ :param list_: the flat list
166
+ :param lengths: the length of each sublist
167
+ :return: the resulting sublists
168
+ """
169
+ assert len(list_) == sum(lengths)
170
+ i = 0
171
+ ret = []
172
+ for l in lengths:
173
+ ret.append(list_[i:i + l])
174
+ i += l
175
+ return ret
176
+
177
+ def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
178
+ """
179
+ Compute the start and end time of each word
180
+ :param waveform: audio waveform
181
+ :param spans: token spans for each word
182
+ :param num_frames: number of emission frames
183
+ :param transcript: transcript tokens
184
+ :param sample_rate: sample rate
185
+ :return: per-word alignment info
186
+ """
187
+ end = 0
188
+ alignment_result = []
189
+ for span, trans in zip(spans, transcript):
190
+ ratio = waveform.size(1) / num_frames
191
+ x0 = int(ratio * span[0].start)
192
+ x1 = int(ratio * span[-1].end)
193
+ align_info = {
194
+ "transcript": trans,
195
+ "start": round(x0 / sample_rate, 3),
196
+ "end": round(x1 / sample_rate, 3)
197
+ }
198
+ align_info["pause"] = round(align_info["start"] - end, 3)
199
+ align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
200
+ end = align_info["end"]
201
+ alignment_result.append(align_info)
202
+ return alignment_result
203
+
204
+ def make_wav_batch(self, wav_list):
205
+ """
206
+ Pad every wav tensor in wav_list to the same length and return the padded batch together with each tensor's original length.
207
+ :param wav_list: list of wav tensors
208
+ :return: padded audio tensor and the original lengths
209
+ """
210
+ wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
211
+ max_length = max(wav_lengths)
212
+ # 确保张量在正确的设备上
213
+ wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
214
+ for i, wav in enumerate(wav_list):
215
+ wav = wav.to(self.device) # 确保wav在正确的设备上
216
+ wavs_tensors[i, :wav_lengths[i]] = wav
217
+ return wavs_tensors, wav_lengths.to(self.device)
218
+
219
+ def get_target(self, transcript, language):
220
+ """
221
+ Build the target tokens for the given transcript (revised version that keeps English words intact)
222
+ """
223
+ original_transcript = transcript # 保存原始文本用于调试
224
+
225
+ if language == "ZH":
226
+ # 中文处理:保持英文单词完整,只对中文字符进行romanization
227
+ # 使用相同的分词逻辑
228
+ pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
229
+ tokens = re.findall(pattern, transcript)
230
+
231
+ # 分别处理中文字符和英文单词
232
+ processed_parts = []
233
+ for token in tokens:
234
+ if not token.strip():
235
+ continue
236
+ elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token): # 英文单词
237
+ # 英文单词保持原样,不进行romanization
238
+ processed_parts.append(token.lower())
239
+ elif '\u4e00' <= token <= '\u9fff': # 中文字符
240
+ # 只对中文字符进行romanization
241
+ romanized = self.uroman.romanize_string(token)
242
+ processed_parts.append(romanized)
243
+ else: # 标点符号等
244
+ # 标点符号直接添加,但会在后续步骤中被过滤掉
245
+ processed_parts.append(token)
246
+
247
+ # 用空格连接所有部分
248
+ transcript = ' '.join(processed_parts)
249
+
250
+ elif language == "EN":
251
+ # 英文处理:保持单词结构,只是清理标点
252
+ pass
253
+ else:
254
+ assert False, f"Unsupported language: {language}"
255
+
256
+ # 清理标点符号
257
+ transcript = re.sub(r'[^\w\s]', r' ', transcript)
258
+ TRANSCRIPT = transcript.lower().split()
259
+
260
+ # 提前获取字典中的特殊符号 token
261
+ star_token = self.DICTIONARY['*']
262
+ tokenized_transcript = []
263
+
264
+ # 统一的tokenization逻辑
265
+ for word in TRANSCRIPT:
266
+ # 对每个word中的字符进行token化
267
+ word_tokens = []
268
+ for c in word:
269
+ if c in self.DICTIONARY and c != '-':
270
+ word_tokens.append(self.DICTIONARY[c])
271
+ else:
272
+ word_tokens.append(star_token)
273
+ tokenized_transcript.extend(word_tokens)
274
+
275
+ logging.info(f"Original transcript: {original_transcript}")
276
+ logging.info(f"Processed transcript: {transcript}")
277
+ logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")
278
+
279
+ return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)
280
+
281
+ def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
282
+ """
283
+ Build the alignment result from the given emission and alignment info (revised version)
284
+ """
285
+ original_transcript = transcript # 保存原始文本
286
+
287
+ if language == "ZH":
288
+ # 使用与get_target相同的处理逻辑
289
+ pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
290
+ tokens = re.findall(pattern, transcript)
291
+
292
+ processed_parts = []
293
+ for token in tokens:
294
+ if not token.strip():
295
+ continue
296
+ elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token): # 英文单词
297
+ processed_parts.append(token.lower())
298
+ elif '\u4e00' <= token <= '\u9fff': # 中文字符
299
+ romanized = self.uroman.romanize_string(token)
300
+ processed_parts.append(romanized)
301
+ else: # 标点符号等
302
+ processed_parts.append(token)
303
+
304
+ transcript = ' '.join(processed_parts)
305
+ elif language == "EN":
306
+ pass
307
+ else:
308
+ assert False, f"Unsupported language: {language}"
309
+
310
+ transcript = re.sub(r'[^\w\s]', r' ', transcript)
311
+ emission = emission_padded[:emission_length, :].unsqueeze(0)
312
+ TRANSCRIPT = transcript.lower().split()
313
+
314
+ token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
315
+
316
+ # 统一的分组逻辑
317
+ word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])
318
+
319
+ num_frames = emission.size(1)
320
+
321
+ logging.info(f"Original transcript for alignment: {original_transcript}")
322
+ logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")
323
+
324
+ return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)
325
+
326
+ def batch_alignment(self, wav_list, transcript_list, language_list):
327
+ """
328
+ Batch forced alignment
329
+ :param wav_list: list of wav tensors
330
+ :param transcript_list: list of transcripts
331
+ :param language_list: list of language codes
332
+ :return: list of alignment results
333
+ """
334
+ wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
335
+ logging.info("start alignment model forward")
336
+ with torch.inference_mode():
337
+ emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
338
+ star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
339
+ emission = torch.cat((emission, star_dim), dim=-1)
340
+
341
+ logging.info("end alignment model forward")
342
+
343
+ target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]
344
+
345
+ logging.info("align success")
346
+ align_results = [
347
+ self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
348
+ for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
349
+ ]
350
+
351
+ logging.info("get align result")
352
+ batch_aligned_tokens = [align_result[0] for align_result in align_results]
353
+ batch_alignment_scores = [align_result[1] for align_result in align_results]
354
+
355
+ alignment_result_list = [
356
+ self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
357
+ for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
358
+ in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
359
+ ]
360
+ logging.info("get align result success")
361
+ return alignment_result_list
362
+
363
+
364
+ def batch_get_alignment_result(alignment_model, wav_list, transcript_list, language_list):
365
+ """
366
+ Convenience function that runs batch alignment and restores the original text tokens and punctuation
367
+ """
368
+ alignment_results = alignment_model.batch_alignment(
369
+ wav_list=wav_list,
370
+ transcript_list=transcript_list,
371
+ language_list=language_list
372
+ )
373
+
374
+ alignments_results_with_text_and_punctuation = []
375
+ for alignment_result, transcript, language in zip(alignment_results, transcript_list, language_list):
376
+ try:
377
+ result = get_aligned_result_text_with_punctuation(alignment_result, transcript, language)
378
+ alignments_results_with_text_and_punctuation.append(result)
379
+ except Exception:
380
+ logger = logging.getLogger("tokenize")
381
+ logger.error(f"Error in processing {alignment_result}")
382
+ traceback.print_exc()
383
+ alignments_results_with_text_and_punctuation.append(alignment_result)
384
+ return alignments_results_with_text_and_punctuation
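`alignment.py` exposes the `AlignmentModel` class plus the `batch_get_alignment_result` convenience wrapper used for word-level timing. A minimal usage sketch follows; the checkpoint directory and audio path are placeholders, and the mono/16 kHz preprocessing reflects what `make_wav_batch` and the MMS_FA bundle expect rather than anything pinned down by this commit:

```python
# Illustrative only: paths below are placeholders, not files from this repository.
import torch
import torchaudio
from alignment import AlignmentModel, batch_get_alignment_result

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AlignmentModel(device, model_dir="/path/to/mms_fa")   # hypothetical checkpoint dir

waveform, sr = torchaudio.load("/path/to/output.wav")         # hypothetical audio file
if sr != model.bundle.sample_rate:                             # MMS_FA runs at 16 kHz
    waveform = torchaudio.functional.resample(waveform, sr, model.bundle.sample_rate)
wav = waveform.mean(dim=0)                                     # make_wav_batch expects 1-D tensors

results = batch_get_alignment_result(
    alignment_model=model,
    wav_list=[wav],
    transcript_list=["今天天气真好啊"],
    language_list=["ZH"],
)
for word in results[0]:
    print(word["transcript"], word["start"], word["end"])
```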
alignment_online.py ADDED
@@ -0,0 +1,398 @@
1
+ import base64
2
+ import httpx
3
+ import re
4
+ import requests
5
+ import torch
6
+ import torchaudio.functional as F
7
+ import torchaudio
8
+ import uroman as ur
9
+ import logging
10
+ import traceback
11
+
12
+
13
+ def convert_to_list_with_punctuation_mixed(text):
14
+ """处理中文文本(可能包含英文单词) - 中文按字符分割,英文单词保持完整"""
15
+ result = []
16
+ text = text.strip()
17
+
18
+ if not text:
19
+ return result
20
+
21
+ def is_chinese(char):
22
+ """检查是否是汉字"""
23
+ return '\u4e00' <= char <= '\u9fff'
24
+
25
+ # 使用更精确的正则表达式来分割文本
26
+ # 匹配:英文单词(含数字)、单个汉字、标点符号
27
+ pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
28
+ tokens = re.findall(pattern, text)
29
+
30
+ for token in tokens:
31
+ if not token.strip(): # 跳过空字符
32
+ continue
33
+
34
+ if re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token): # 英文单词(可能包含数字)
35
+ result.append(token)
36
+ elif is_chinese(token): # 单个汉字
37
+ result.append(token)
38
+ else: # 标点符号等其他字符
39
+ # 标点符号加到前一个词后面
40
+ if result:
41
+ result[-1] += token
42
+ else:
43
+ # 如果是文本开头的标点,单独作为一项
44
+ result.append(token)
45
+
46
+ return result
47
+
48
+ def split_and_merge_punctuation(text):
49
+ """处理英文 - 按单词分割,保持单词完整性"""
50
+ # 先按空格拆分文本
51
+ elements = text.split()
52
+
53
+ # 用于保存最终的结果
54
+ result = []
55
+
56
+ # 遍历每个拆分后的元素
57
+ for ele in elements:
58
+ # 使用正则表达式提取连续字母、数字和标点
59
+ parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)
60
+
61
+ # 用于保存拆分后的部分
62
+ merged_parts = []
63
+
64
+ for i in range(len(parts)):
65
+ if i % 2 == 0: # 如果是字母或数字部分
66
+ # 将字母或数字部分添加到结果中
67
+ merged_parts.append(parts[i])
68
+ else: # 如果是标点或其他符号部分
69
+ # 将标点部分与前面的字母或数字部分合并
70
+ if merged_parts:
71
+ merged_parts[-1] += parts[i]
72
+ else:
73
+ merged_parts.append(parts[i])
74
+
75
+ # 将合并后的部分加入最终结果
76
+ result.extend(merged_parts)
77
+
78
+ return result
79
+
80
+
81
+ def get_aligned_result_text_with_punctuation(alignment_result, text, language):
82
+ """
83
+ 将对齐结果转换为正确的文本tokens,英文保持单词级别,中文保持字符级别(但英文单词完整)
84
+ """
85
+ logging.info("start change text to text_tokens")
86
+
87
+ if language == "EN":
88
+ text_tokens = split_and_merge_punctuation(text) # 英文按单词分词
89
+ elif language == "ZH":
90
+ text_tokens = convert_to_list_with_punctuation_mixed(text) # 中文按字符分割,但英文单词保持完整
91
+ else:
92
+ raise ValueError(f"Unsupported language: {language}")
93
+
94
+ logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")
95
+
96
+ punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》''""\、')
97
+
98
+ logging.info("start get align result text with punctuation")
99
+ updated_alignment_result = []
100
+ token_idx = 0
101
+
102
+ for index, align_item in enumerate(alignment_result):
103
+ if token_idx >= len(text_tokens):
104
+ # 如果text_tokens用完了但还有对齐结果,跳出循环
105
+ logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
106
+ break
107
+
108
+ start = align_item["start"]
109
+ end = align_item["end"]
110
+ text_token = text_tokens[token_idx]
111
+
112
+ # 检查该 token 后是否有连续标点(仅对中文)
113
+ if language == "ZH":
114
+ while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
115
+ assert False, "???" # 这里理论上应该进不去??
116
+ text_token += text_tokens[token_idx + 1] # 将标点加入
117
+ token_idx += 1
118
+ else:
119
+ # 英文不需要特殊的标点处理,因为标点已经在split_and_merge_punctuation中处理了
120
+ pass
121
+
122
+ # 更新对齐结果
123
+ updated_item = {
124
+ "start": start,
125
+ "end": end,
126
+ "transcript": text_token
127
+ }
128
+ updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})
129
+
130
+ updated_alignment_result.append(updated_item)
131
+ token_idx += 1
132
+
133
+ logging.info("end get align result text with punctuation")
134
+ return updated_alignment_result
135
+
136
+
137
+ class AlignmentModel:
138
+ def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
139
+ """
140
+ Initialize the alignment model and load the required resources
141
+ """
142
+ self.device = torch.device(device)
143
+ self.bundle = torchaudio.pipelines.MMS_FA
144
+ model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)
145
+
146
+ # --- Core optimization ---
147
+ # JIT-compile the model with torch.compile
148
+ # mode="max-autotune" 会花费更长时间编译,但能达到最佳性能
149
+ print("Compiling the model... This may take a moment.")
150
+ self.align_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
151
+ print("Model compiled successfully.")
152
+
153
+ self.uroman = ur.Uroman()
154
+ self.DICTIONARY = self.bundle.get_dict()
155
+
156
+ def align(self, emission, tokens):
157
+ """
158
+ 执行强对齐
159
+ :param emission: 模型的输出
160
+ :param tokens: 目标 tokens
161
+ :return: 对齐的 tokens 和分数
162
+ """
163
+ alignments, scores = F.forced_align(
164
+ log_probs=emission,
165
+ targets=tokens,
166
+ blank=0
167
+ )
168
+ alignments, scores = alignments[0], scores[0]
169
+ scores = scores.exp()
170
+ return alignments, scores
171
+
172
+ def unflatten(self, list_, lengths):
173
+ """
174
+ 将一个长列表按照长度拆分成子列表
175
+ :param list_: 长列表
176
+ :param lengths: 各子列表的长度
177
+ :return: 拆分后的子列表
178
+ """
179
+ assert len(list_) == sum(lengths)
180
+ i = 0
181
+ ret = []
182
+ for l in lengths:
183
+ ret.append(list_[i:i + l])
184
+ i += l
185
+ return ret
186
+
187
+ def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
188
+ """
189
+ 预览每个单词的开始时间和结束时间
190
+ :param waveform: 音频波形
191
+ :param spans: 单词的跨度
192
+ :param num_frames: 帧数
193
+ :param transcript: 转录文本
194
+ :param sample_rate: 采样率
195
+ :return: 单词的对齐信息
196
+ """
197
+ end = 0
198
+ alignment_result = []
199
+ for span, trans in zip(spans, transcript):
200
+ ratio = waveform.size(1) / num_frames
201
+ x0 = int(ratio * span[0].start)
202
+ x1 = int(ratio * span[-1].end)
203
+ align_info = {
204
+ "transcript": trans,
205
+ "start": round(x0 / sample_rate, 3),
206
+ "end": round(x1 / sample_rate, 3)
207
+ }
208
+ align_info["pause"] = round(align_info["start"] - end, 3)
209
+ align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
210
+ end = align_info["end"]
211
+ alignment_result.append(align_info)
212
+ return alignment_result
213
+
214
+ def make_wav_batch(self, wav_list):
215
+ """
216
+ 将 wav_list 中的每个 wav 张量填充为相同的长度,返回填充后的张量和每个张量的原始长度。
217
+ :param wav_list: wav 文件列表
218
+ :return: 填充后的音频张量和原始长度
219
+ """
220
+ wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
221
+ max_length = max(wav_lengths)
222
+ # 确保张量在正确的设备上
223
+ wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
224
+ for i, wav in enumerate(wav_list):
225
+ wav = wav.to(self.device) # 确保wav在正确的设备上
226
+ wavs_tensors[i, :wav_lengths[i]] = wav
227
+ return wavs_tensors, wav_lengths.to(self.device)
228
+
229
+ def get_target(self, transcript, language):
230
+ """
231
+ 获取给定转录文本的目标 tokens - 修正版本,保持英文单词完整性
232
+ """
233
+ original_transcript = transcript # 保存原始文本用于调试
234
+
235
+ if language == "ZH":
236
+ # 中文处理:保持英文单词完整,只对中文字符进行romanization
237
+ # 使用相同的分词逻辑
238
+ pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
239
+ tokens = re.findall(pattern, transcript)
240
+
241
+ # 分别处理中文字符和英文单词
242
+ processed_parts = []
243
+ for token in tokens:
244
+ if not token.strip():
245
+ continue
246
+ elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token): # 英文单词
247
+ # 英文单词保持原样,不进行romanization
248
+ processed_parts.append(token.lower())
249
+ elif '\u4e00' <= token <= '\u9fff': # 中文字符
250
+ # 只对中文字符进行romanization
251
+ romanized = self.uroman.romanize_string(token)
252
+ processed_parts.append(romanized)
253
+ else: # 标点符号等
254
+ # 标点符号直接添加,但会在后续步骤中被过滤掉
255
+ processed_parts.append(token)
256
+
257
+ # 用空格连接所有部分
258
+ transcript = ' '.join(processed_parts)
259
+
260
+ elif language == "EN":
261
+ # 英文处理:保持单词结构,只是清理标点
262
+ pass
263
+ else:
264
+ assert False, f"Unsupported language: {language}"
265
+
266
+ # 清理标点符号
267
+ transcript = re.sub(r'[^\w\s]', r' ', transcript)
268
+ TRANSCRIPT = transcript.lower().split()
269
+
270
+ # 提前获取字典中的特殊符号 token
271
+ star_token = self.DICTIONARY['*']
272
+ tokenized_transcript = []
273
+
274
+ # 统一的tokenization逻辑
275
+ for word in TRANSCRIPT:
276
+ # 对每个word中的字符进行token化
277
+ word_tokens = []
278
+ for c in word:
279
+ if c in self.DICTIONARY and c != '-':
280
+ word_tokens.append(self.DICTIONARY[c])
281
+ else:
282
+ word_tokens.append(star_token)
283
+ tokenized_transcript.extend(word_tokens)
284
+
285
+ logging.info(f"Original transcript: {original_transcript}")
286
+ logging.info(f"Processed transcript: {transcript}")
287
+ logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")
288
+
289
+ return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)
290
+
291
+ def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
292
+ """
293
+ 根据给定的 emission 和对齐信息生成对齐结果 - 修正版本
294
+ """
295
+ original_transcript = transcript # 保存原始文本
296
+
297
+ if language == "ZH":
298
+ # 使用与get_target相同的处理逻辑
299
+ pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
300
+ tokens = re.findall(pattern, transcript)
301
+
302
+ processed_parts = []
303
+ for token in tokens:
304
+ if not token.strip():
305
+ continue
306
+ elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token): # 英文单词
307
+ processed_parts.append(token.lower())
308
+ elif '\u4e00' <= token <= '\u9fff': # 中文字符
309
+ romanized = self.uroman.romanize_string(token)
310
+ processed_parts.append(romanized)
311
+ else: # 标点符号等
312
+ processed_parts.append(token)
313
+
314
+ transcript = ' '.join(processed_parts)
315
+ elif language == "EN":
316
+ pass
317
+ else:
318
+ assert False, f"Unsupported language: {language}"
319
+
320
+ transcript = re.sub(r'[^\w\s]', r' ', transcript)
321
+ emission = emission_padded[:emission_length, :].unsqueeze(0)
322
+ TRANSCRIPT = transcript.lower().split()
323
+
324
+ token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
325
+
326
+ # 统一的分组逻辑
327
+ word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])
328
+
329
+ num_frames = emission.size(1)
330
+
331
+ logging.info(f"Original transcript for alignment: {original_transcript}")
332
+ logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")
333
+
334
+ return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)
335
+
336
+ def batch_alignment(self, wav_list, transcript_list, language_list):
337
+ """
338
+ 批量对齐
339
+ :param wav_list: wav 文件列表
340
+ :param transcript_list: 转录文本列表
341
+ :param language_list: 语言类型列表
342
+ :return: 对齐结果列表
343
+ """
344
+ wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
345
+ logging.info("start alignment model forward")
346
+ with torch.inference_mode():
347
+ emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
348
+ star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
349
+ emission = torch.cat((emission, star_dim), dim=-1)
350
+
351
+ logging.info("end alignment model forward")
352
+
353
+ target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]
354
+
355
+ logging.info("align success")
356
+ align_results = [
357
+ self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
358
+ for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
359
+ ]
360
+
361
+ logging.info("get align result")
362
+ batch_aligned_tokens = [align_result[0] for align_result in align_results]
363
+ batch_alignment_scores = [align_result[1] for align_result in align_results]
364
+
365
+ alignment_result_list = [
366
+ self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
367
+ for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
368
+ in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
369
+ ]
370
+ logging.info("get align result success")
371
+ return alignment_result_list
372
+
373
+
374
+ async def batch_get_alignment_result_remote(alignment_url, audio_path, transcript, language):
375
+ """
376
+ Fetch alignment results by calling a remote alignment service.
377
+ """
378
+ payload = {
379
+ "audio_path": audio_path,
380
+ "transcript": transcript,
381
+ "language": language,
382
+ }
383
+
384
+ try:
385
+ async with httpx.AsyncClient() as client:
386
+ response = await client.post(alignment_url, json=payload, timeout=300)  # use a generous timeout
387
+ response.raise_for_status()  # raise if the status code is not 2xx
388
+ data = response.json()
389
+ return data['results']
390
+
391
+ except httpx.HTTPError as e:
392
+ logging.error(f"Failed to connect to alignment service: {e}")
393
+ traceback.print_exc()
394
+ # Depending on requirements this could return an empty list or re-raise; currently it falls through and returns None
395
+ except Exception as e:
396
+ logging.error(f"An error occurred in remote alignment: {e}")
397
+ traceback.print_exc()
398
+
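`alignment_online.py` keeps the same batched API but compiles the acoustic model with `torch.compile` and adds `batch_get_alignment_result_remote`, an async client for an external alignment service. A hedged sketch of calling it is shown below; the service URL and audio path are assumptions, and note that the helper logs network errors and returns `None` rather than raising:

```python
# Illustrative only: the alignment service endpoint is an assumption.
import asyncio
from alignment_online import batch_get_alignment_result_remote

async def main():
    results = await batch_get_alignment_result_remote(
        alignment_url="http://localhost:8080/align",   # hypothetical service endpoint
        audio_path="/path/to/output.wav",              # hypothetical server-visible audio path
        transcript="今天天气真好啊",
        language="ZH",
    )
    if results is not None:   # None signals a logged failure
        for item in results:
            print(item)

asyncio.run(main())
```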
docker/Dockerfile ADDED
@@ -0,0 +1,40 @@
1
+ ###################################################################################################
2
+ #
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Redistribution and use in source and binary forms, with or without modification, are permitted
6
+ # provided that the following conditions are met:
7
+ # * Redistributions of source code must retain the above copyright notice, this list of
8
+ # conditions and the following disclaimer.
9
+ # * Redistributions in binary form must reproduce the above copyright notice, this list of
10
+ # conditions and the following disclaimer in the documentation and/or other materials
11
+ # provided with the distribution.
12
+ # * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
13
+ # to endorse or promote products derived from this software without specific prior written
14
+ # permission.
15
+ #
16
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
17
+ # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
18
+ # FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
19
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
20
+ # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
21
+ # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
22
+ # STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+ #
25
+ ###################################################################################################
26
+ FROM nvcr.io/nvidia/tritonserver:25.08-py3
27
+ LABEL maintainer="NVIDIA"
28
+ LABEL repository="tritonserver"
29
+
30
+ RUN apt-get update && apt-get -y install swig && apt-get -y install python3-dev && apt-get install -y cmake && apt-get install -y libsndfile1
31
+ RUN pip3 install kaldiio
32
+ RUN pip3 install torch torchvision torchaudio -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
33
+ RUN pip3 install -v kaldifeat
34
+ RUN python3 -m pip install cupy
35
+ RUN python3 -m pip install soundfile
36
+ RUN pip3 install --upgrade pip
37
+ RUN pip install --extra-index-url https://pypi.nvidia.com cudf_cu12
38
+ RUN pip install --extra-index-url https://pypi.nvidia.com cuml_cu12
39
+ RUN pip install --extra-index-url https://pypi.nvidia.com cugraph_cu12
40
+ WORKDIR /workspace
download_mms_model.py ADDED
@@ -0,0 +1,35 @@
1
+ import torchaudio
2
+ import os
3
+ from pathlib import Path
4
+
5
+ def download_mms_model(download_dir="/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation/models/mms_fa"):
6
+ """下载MMS-FA模型到指定目录"""
7
+
8
+ # Create the download directory
9
+ download_path = Path(download_dir)
10
+ download_path.mkdir(parents=True, exist_ok=True)
11
+
12
+ print(f"开始下载MMS-FA模型到: {download_path}")
13
+
14
+ try:
15
+ # Get the MMS-FA bundle
16
+ bundle = torchaudio.pipelines.MMS_FA
17
+
18
+ # Download the model
19
+ model = bundle.get_model(with_star=False, dl_kwargs={'model_dir': str(download_path)})
20
+
21
+ print(f"✅ 模型下载成功!保存在: {download_path}")
22
+ print(f"模型文件: {list(download_path.glob('*'))}")
23
+
24
+ return str(download_path)
25
+
26
+ except Exception as e:
27
+ print(f"❌ 下载失败: {e}")
28
+ return None
29
+
30
+ if __name__ == "__main__":
31
+ # Download the model
32
+ model_path = download_mms_model()
33
+ if model_path:
34
+ print(f"\n使用方法:")
35
+ print(f"evaluator = SpeakerSimilarityEvaluator(alignment_model_dir='{model_path}')")
example_input.jsonl ADDED
@@ -0,0 +1 @@
1
+ {"text": "[S1] Hey, do you know the AI world has been super lively lately?[S2] Oh, yeah, new news every day. It feels like, um, a lot of big companies are just pushing really hard to get ahead.[S1] Right, right, exactly. Like, big news popping up every other day. Recently, I saw something about Anthropic. Didn't they release Claude 4?[S2] Oh, Claude 4, yeah, I saw some reports. They said it's really powerful, their latest model.[S1] Mhm, they're calling it the world's best programming model,sounds super impressive.[S2] Mm.[S1] Hey, really? World's best? That title alone is pretty catchy.[S2] Yeah, that really makes you curious, actually.[S1] Right? And it claims that for long tasks requiring extreme focus and thousands of steps, it can maintain stable performance.[S2] Mm.[S1] Meaning, it doesn't crash easily.[S2] Wow, that's amazing. So, it doesn't crash easily, huh?[S1] Exactly. They said, like, the Japanese e-commerce giant Rakuten, you know them, right? They actually verified Claude Opus 4's capability. In a demanding open-source refactoring task, it ran independently for seven hours.[S2] Seven hours?[S1] And throughout that time, its performance remained completely stable.[S2] Wow, my goodness. It runs on its own for seven hours without a break? That's incredible.[S1] Yeah, for those tasks that need focused effort and thousands of steps, it can handle them steadily.[S2] Mm, that's really something.[S1] Uh, so it's especially suitable for complex coding and problem-solving scenarios.[S2] Oh, I see. So, how's its performance in programming, really? Is it actually much better than before?[S1] Yeah, they mentioned the SWE-bench evaluation, which is a benchmark test for software engineering tasks.[S2] Oh, I know that test, it's quite professional.[S1] Mm, their Claude Sonnet 4 achieved an accuracy of 72.7 percent.[S2] Mm, 72.7 percent, that's high.[S1] Right, and they also compared it to the previous Sonnet 3.7 version.[S2] Mm.[S1] The 3.7 version got 62.3 percent.[S2] Oh, that's about a ten-point difference, then.[S1] Exactly, so Sonnet 4 improved significantly.[S2] Hmm, so it seems like this upgrade is substantial, not just hype.[S1] Indeed. And they also released Claude Code, which is a dedicated programming tool.[S2] Hmm, like, for developers to use?[S1] Yes, they said Claude Code is officially launched and supported by both Claude 4 models.[S2] Oh, I see. So, not only are the models powerful, but they've also improved the tools, like a complete package.[S1] That's right. And they also said that Claude Code isn't just for programmers.[S2] Huh? If it's not for programmers, then who's it for?[S1] They said, even for people who aren't really good at programming,[S2] Mm.[S1] Like product managers, if they want to create a prototype for an idea, they can just ask Claude to do it.[S2] Wow, that's really interesting. So, you don't have to write the code yourself, you just let the AI help you realize your ideas, right?[S1] Yeah, they're saying that in the future, if you have an idea, you might not need to write a document; you can just have it help you create the prototype.[S2] Hmm, that sounds a bit like, uh, will programmers' jobs become less common in the future?[S1] Hmm, it might be more like, Scott White, who's their product lead, he said that Claude is transforming from a tool that provides answers into a truly capable collaborative partner.[S2] Oh, I understand. So, it helps you with, uh, more basic or repetitive tasks, allowing you to focus more on creative things.[S1] Yes, exactly. 
And the models they released this time are called Opus 4 and Sonnet 4.[S2] Mm.[S1] Opus 4, they say, is their most powerful model to date, and also the world's best programming model.[S2] Definitely the flagship model.[S1] And Sonnet 4 is a major upgrade to Sonnet 3.7.[S2] Oh, so what are the specific differences between the two?[S1] Hmm, Opus 4 is better at high-end tasks like coding, research, writing, and scientific discovery.[S2] Hmm, Opus sounds more all-around capable.[S1] Right, and Sonnet 4 is more suitable for everyday use cases; it offers cutting-edge performance for daily tasks.[S2] Oh, I see. So, one is super high-end, and the other is also super strong for everyday use.[S1] Yes, and both models use a hybrid mode design.[S2] Hybrid mode? What does that mean?[S1] It means it can provide almost instant responses, but also perform deeper reasoning and thought.[S2] Oh.[S1] Like, uh, expansive thinking.[S2] Oh, I see. So, sometimes it needs to be fast, and other times it needs to be slow and think deeply.[S1] Exactly.[S2] Hmm, so what about the pricing? Is it very expensive?[S1] The pricing is the same as the previous Opus and Sonnet models.[S2] Oh.[S1] For Opus 4, it's fifteen dollars per million input tokens and seventy-five dollars for output tokens.[S2] Wow, output is much more expensive![S1] Right. And for Sonnet 4, input is three dollars and output is fifteen dollars.[S2] Hmm, Sonnet is much more affordable then.[S1] Yes, and Sonnet 4 is also available for free users.[S2] Oh, that's good, everyone can try it out.[S1] Mm, exactly.[S2] Hey, so how does it compare to other AI giants? Where does it stand now?[S1] This release of theirs has intensified the competition with giants like OpenAI and Google in the top-tier model space.[S2] Yeah, it really feels like everyone's pushing hard lately.[S1] Right? Like, Microsoft also announced new coding agents, didn't they? And they partnered with Elon Musk's xAI.[S2] Mm.[S1] Google, meanwhile, is accelerating the integration of AI agents into their services.[S2] Right.[S1] And OpenAI is even more impressive; they just made a six-point-five-billion-dollar deal to acquire an AI hardware startup founded by the father of iPhone, former Apple design chief Jony Ive.[S2] Wow, six-point-five billion, that's a huge move. It feels like AI competition is really heating up.[S1] Exactly, so for investors, it means re-evaluating the competitive landscape in the AI sector.[S2] Hmm, makes sense. So, does this Claude 4 also bring a lot of opportunities for Anthropic?[S1] Yeah, its strong performance in coding, reasoning, and agent tasks will definitely help it capture more market share and enterprise clients.[S2] Hmm, sounds like it has huge potential indeed.[S1] Mm, it just feels like the AI competition now is all about who can push the technology to new heights.[S2] Exactly, and also who can really, uh, implement these technologies into practical applications.[S1] Right, right, exactly like that.[S2] Okay, well, this news about Claude 4 today really makes you feel like AI has taken a huge leap forward.[S1] Yeah, looking forward to it bringing more surprises in the future.", "prompt_audio_speaker1": "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/moon-en1/en_spk1_moon.wav", "prompt_text_speaker1": "OK. 
I'm starting to see how this multi-headed approach could lead to some pretty impressive results.", "prompt_audio_speaker2": "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/moon-en1/en_spk2-moon.wav", "prompt_text_speaker2": "It's not just crunching data. It's starting to develop a more sophisticated understanding of how language actually works.", "output_audio": "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step40000/test_en/gpu0/output_0.wav"}
model_repo/speaker_model/1/model.trt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17a3a3b794fa886c9b8341a08ad5e22f3bb385dd994ff891ba62d591503673f5
3
+ size 104729100
model_repo/speaker_model/config.pbtxt ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ name: "speaker_model"
16
+ backend: "tensorrt"
17
+ default_model_filename: "model.trt"
18
+
19
+ max_batch_size: 16
20
+ input [
21
+ {
22
+ name: "feats"
23
+ data_type: TYPE_FP32
24
+ dims: [ -1, 80 ] # num_mel_bins
25
+ }
26
+ ]
27
+
28
+ output [
29
+ {
30
+ name: "embs"
31
+ data_type: TYPE_FP32
32
+ dims: [ 256 ] # [embedding_size]
33
+ }
34
+ ]
35
+ dynamic_batching {
36
+ preferred_batch_size: [ 4, 8 ]
37
+ max_queue_delay_microseconds: 1000
38
+ }
39
+ instance_group [
40
+ {
41
+ count: 1
42
+ kind: KIND_GPU
43
+ }
44
+ ]
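This config declares a TensorRT speaker-embedding model: it takes variable-length 80-dimensional fbank features (`feats`) and returns a 256-dimensional embedding (`embs`), with dynamic batching up to a batch size of 16. The sketch below shows one way to query it from a running Triton server; the server URL and audio path are assumptions, while the tensor names, dims, and the 25 ms/10 ms fbank settings come from this config and the training config shipped alongside the models:

```python
# Illustrative client sketch; the Triton URL and audio path are assumptions.
import numpy as np
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")   # hypothetical server address

waveform, sr = torchaudio.load("/path/to/speaker1.wav")           # hypothetical audio file
feats = kaldi.fbank(waveform, num_mel_bins=80, frame_length=25,
                    frame_shift=10, sample_frequency=sr)          # [num_frames, 80]
feats = feats.unsqueeze(0).numpy().astype(np.float32)             # add batch dim -> [1, num_frames, 80]

inp = httpclient.InferInput("feats", list(feats.shape), "FP32")
inp.set_data_from_numpy(feats)
result = client.infer("speaker_model", inputs=[inp],
                      outputs=[httpclient.InferRequestedOutput("embs")])
embedding = result.as_numpy("embs")                               # expected shape [1, 256]
print(embedding.shape)
```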
models/mms_fa/model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20ef12963ab4924bef49ac4fc7f58ad5da2ee43b2c11bc8c853c9b90ecdbc680
3
+ size 1262047414
models/mms_fa/model.pt.2c7cc4fedf8e4a089a0095148cc9201b.partial ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0cf233f857de07296254c36332b4b984045cdc0964ec1fef6a0c6cc5aae00b7
3
+ size 1056964608
models/mms_fa/model.pt.5c5fe9893a2c462e9132dcd6a3fba337.partial ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51258936b4a1a51762ef849ec0f404920f38d03c4e018550d75ea4e1e82a451a
3
+ size 486539264
models/voxblink2_samresnet100_ft/avg_model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d92ee34668d8eb24a02df4e7869fd4bde661220a137e045f29e4a0c85eb4004
3
+ size 201115747
models/voxblink2_samresnet100_ft/avg_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5aeee438ca23c0ca6e341bab6c6bf7f465497e1dc323bb1bc1074d6a0c778b11
3
+ size 201318407
models/voxblink2_samresnet100_ft/config.yaml ADDED
@@ -0,0 +1,83 @@
1
+ data_type: shard
2
+ dataloader_args:
3
+ batch_size: 128
4
+ drop_last: true
5
+ num_workers: 16
6
+ pin_memory: false
7
+ prefetch_factor: 8
8
+ dataset_args:
9
+ aug_prob: 0.6
10
+ fbank_args:
11
+ dither: 1.0
12
+ frame_length: 25
13
+ frame_shift: 10
14
+ num_mel_bins: 80
15
+ filter: true
16
+ filter_args:
17
+ max_num_frames: 800
18
+ min_num_frames: 100
19
+ num_frms: 200
20
+ resample_rate: 16000
21
+ sample_num_per_epoch: 0
22
+ shuffle: true
23
+ shuffle_args:
24
+ shuffle_size: 2500
25
+ spec_aug: false
26
+ spec_aug_args:
27
+ max_f: 8
28
+ max_t: 10
29
+ num_f_mask: 1
30
+ num_t_mask: 1
31
+ prob: 0.6
32
+ speed_perturb: true
33
+ enable_amp: false
34
+ exp_dir: exp/samresnet100/
35
+ gpus:
36
+ - 0
37
+ - 1
38
+ log_batch_interval: 100
39
+ loss: CrossEntropyLoss
40
+ loss_args: {}
41
+ margin_scheduler: MarginScheduler
42
+ margin_update:
43
+ epoch_iter: 4265
44
+ final_margin: 0.2
45
+ fix_start_epoch: 40
46
+ increase_start_epoch: 20
47
+ increase_type: exp
48
+ initial_margin: 0.0
49
+ update_margin: true
50
+ model: SimAM_ResNet100_ASP
51
+ model_args:
52
+ embed_dim: 256
53
+ model_init: null
54
+ noise_data: data/musan/lmdb
55
+ num_avg: 1
56
+ num_epochs: 150
57
+ optimizer: SGD
58
+ optimizer_args:
59
+ lr: 0.1
60
+ momentum: 0.9
61
+ nesterov: true
62
+ weight_decay: 0.0001
63
+ projection_args:
64
+ do_lm: false
65
+ easy_margin: false
66
+ embed_dim: 256
67
+ num_class: 17982
68
+ project_type: arc_margin
69
+ scale: 32.0
70
+ reverb_data: data/rirs/lmdb
71
+ save_epoch_interval: 5
72
+ scheduler: ExponentialDecrease
73
+ scheduler_args:
74
+ epoch_iter: 4265
75
+ final_lr: 5.0e-05
76
+ initial_lr: 0.1
77
+ num_epochs: 150
78
+ scale_ratio: 4.0
79
+ warm_from_zero: true
80
+ warm_up_epoch: 6
81
+ seed: 42
82
+ train_data: data/vox2_dev/shard.list
83
+ train_label: data/vox2_dev/utt2spk
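This training config also pins down the interface of the exported speaker model: 80-dimensional kaldi fbank features in (`num_mel_bins: 80`) and a 256-dimensional embedding out (`embed_dim: 256`), matching the `feats`/`embs` dims in `model_repo/speaker_model/config.pbtxt`. The sketch below runs the exported `avg_model.onnx` directly with onnxruntime; the tensor names are read from the session rather than hard-coded, but the audio path and the mean-normalization step (common in wespeaker-style pipelines) are assumptions not confirmed by this commit:

```python
# Illustrative only: audio path and feature normalization are assumptions.
import numpy as np
import onnxruntime as ort
import torchaudio
import torchaudio.compliance.kaldi as kaldi

sess = ort.InferenceSession("models/voxblink2_samresnet100_ft/avg_model.onnx")
print([i.name for i in sess.get_inputs()], [o.name for o in sess.get_outputs()])  # inspect tensor names

waveform, sr = torchaudio.load("/path/to/speaker1.wav")          # hypothetical 16 kHz audio
feats = kaldi.fbank(waveform, num_mel_bins=80, frame_length=25,
                    frame_shift=10, sample_frequency=sr, dither=0.0)
feats = feats - feats.mean(dim=0, keepdim=True)                  # assumed mean normalization
feats = feats.unsqueeze(0).numpy().astype(np.float32)            # [1, num_frames, 80]

embedding = sess.run(None, {sess.get_inputs()[0].name: feats})[0]
print(embedding.shape)                                           # expected [1, 256]
```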
models/wespeaker/chinese/config.yaml ADDED
@@ -0,0 +1,7 @@
1
+ model: cnceleb_resnet34_LM
2
+ task: speaker_verification
3
+ domain: speech
4
+ framework: onnxruntime
5
+ dataset: cnceleb
6
+ language: chinese
7
+ sample_rate: 16000
models/wespeaker/chinese/model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7584940aeac8d5512d875e58ce6c09ba4ddad65d8128e1dac0d93aadd087ebb
3
+ size 26530309
python_backend/similarity_model/1/model.py ADDED
@@ -0,0 +1,149 @@
1
+ import io
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ import torchaudio.compliance.kaldi as kaldi
7
+ import traceback
8
+ from torch.utils.dlpack import from_dlpack
9
+ import triton_python_backend_utils as pb_utils
10
+
11
+
12
+ class TritonPythonModel:
13
+ def initialize(self, args):
14
+ self.sample_rate = 16000
15
+ self.feature_dim = 80
16
+ self.vad_enabled = True # This variable is declared but not used.
17
+ self.min_duration = 0.1
18
+
19
+ # This seems correct for BLS (Business Logic Scripting)
20
+ self.speaker_model_name = "speaker_model"
21
+
22
+ def execute(self, requests):
23
+ responses = []
24
+ for request in requests:
25
+ try:
26
+ # 1. Get the input audio BYTES, not a file path string.
27
+ # The input tensor is of type TYPE_STRING, which holds bytes.
28
+ # .as_numpy()[0] gives you the raw bytes object.
29
+ audio1_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_1").as_numpy()[0][0]
30
+ audio2_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_2").as_numpy()[0][0]
31
+
32
+ # 2. Preprocess audio from bytes
33
+ feats1 = self.preprocess(audio1_bytes)
34
+ feats2 = self.preprocess(audio2_bytes)
35
+
36
+ # 3. Call the speaker_model to compute similarity
37
+ similarity = self.compute_similarity(feats1, feats2)
38
+
39
+ # Prepare output
40
+ output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
41
+ response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
42
+ responses.append(response)
43
+
44
+ except pb_utils.TritonModelException as e:
45
+ # If a Triton-specific error occurs, create an error response
46
+ error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(str(e)))
47
+ pb_utils.Logger.log_error(str(e))
48
+ responses.append(error_response)
49
+ except Exception as e:
50
+ # For any other unexpected error, log it and return an error response
51
+ error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
52
+ pb_utils.Logger.log_error(error_message)
53
+ error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message))
54
+ responses.append(error_response)
55
+
56
+ return responses
57
+
58
+ def preprocess(self, audio_bytes: bytes):
59
+ """
60
+ Processes audio data from an in-memory byte buffer.
61
+ If the audio is too short, it's padded by repetition to meet the minimum length.
62
+ """
63
+ try:
64
+ # NOTE: despite the name, the input currently carries a UTF-8 file path, which is decoded and loaded from disk.
65
+ # buffer = io.BytesIO(audio_bytes)  # switch to this if the client sends raw audio bytes instead of a path
66
+ buffer = audio_bytes.decode('utf-8')
67
+ waveform, sample_rate = torchaudio.load(buffer)
68
+
69
+ # You might want to resample if the client's sample rate differs
70
+ if sample_rate != self.sample_rate:
71
+ # Note: This requires the 'torchaudio.transforms' module.
72
+ # Make sure torchaudio is fully installed in your Triton environment.
73
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
74
+ waveform = resampler(waveform)
75
+
76
+ duration = waveform.shape[1] / self.sample_rate
77
+
78
+ if duration < self.min_duration:
79
+ # Audio is too short, repeat it to meet the minimum duration
80
+ repeat_times = math.ceil(self.min_duration / duration)
81
+ waveform = waveform.repeat(1, repeat_times)
82
+
83
+ # --- THIS IS THE NEW, CRITICAL PART ---
84
+ # Calculate 80-dimensional Fbank features, which is what the speaker_model expects.
85
+ # The waveform needs to be shape [batch, time], so we squeeze it.
86
+ features = kaldi.fbank(
87
+ waveform.squeeze(0).unsqueeze(0), # Needs shape [1, T]
88
+ num_mel_bins=self.feature_dim, # This is 80
89
+ sample_frequency=self.sample_rate,
90
+ frame_length=25,
91
+ frame_shift=10
92
+ )
93
+ # The output of fbank is [1, num_frames, num_bins], e.g., [1, 150, 80]
94
+ # We need [num_frames, num_bins] for the speaker model
95
+ return features.squeeze(0) # Returns shape [num_frames, 80]
96
+
97
+ except Exception as e:
98
+ # Raise a specific exception that can be caught in execute()
99
+ raise pb_utils.TritonModelException(f"Failed during audio preprocessing: {e}")
100
+
101
+ def compute_similarity(self, waveform1, waveform2):
102
+ # Call speaker_model to get embeddings
103
+ # Assuming speaker_model takes a waveform and outputs an embedding
104
+ e1 = torch.from_numpy(self.call_speaker_model(waveform1)).to("cuda")
105
+ e2 = torch.from_numpy(self.call_speaker_model(waveform2)).to("cuda")
106
+
107
+ # Flatten the tensors
108
+ e1 = e1.flatten()
109
+ e2 = e2.flatten()
110
+
111
+ # Calculate cosine similarity
112
+ dot_product = torch.dot(e1, e2)
113
+ norm_e1 = torch.norm(e1)
114
+ norm_e2 = torch.norm(e2)
115
+
116
+ # Handle zero norms
117
+ if norm_e1 == 0 or norm_e2 == 0:
118
+ return 0.0
119
+
120
+ similarity = (dot_product / (norm_e1 * norm_e2)).item()
121
+
122
+ # Normalize from [-1, 1] to [0, 1]
123
+ return (similarity + 1) / 2
124
+
125
+ def call_speaker_model(self, waveform):
126
+ """Calls the speaker_model to get an embedding vector."""
127
+ # Create the input tensor for the speaker_model.
128
+ # The name 'feats' here must match the input name in speaker_model's config.pbtxt
129
+ if waveform.dim() == 2:
130
+ waveform = waveform.unsqueeze(0)
131
+ input_tensor = pb_utils.Tensor("feats", waveform.cpu().numpy().astype(np.float32))
132
+
133
+ inference_request = pb_utils.InferenceRequest(
134
+ model_name=self.speaker_model_name,
135
+ requested_output_names=["embs"], # Must match output name in speaker_model's config
136
+ inputs=[input_tensor]
137
+ )
138
+
139
+ inference_response = inference_request.exec()
140
+
141
+ if inference_response.has_error():
142
+ raise pb_utils.TritonModelException(f"Error from speaker_model: {inference_response.error().message()}")
143
+
144
+ output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
145
+ if output_tensor.is_cpu():
146
+ output_tensor = output_tensor.as_numpy()
147
+ else:
148
+ output_tensor = from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()
149
+ return output_tensor
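This Python backend implements the BLS entry point: it loads two audio inputs, extracts 80-dim fbank features, calls `speaker_model` for each embedding, and returns the cosine similarity rescaled to [0, 1] as `SIMILARITY`. A hedged client sketch follows; the model name `similarity_model`, the `[1, 1]` BYTES shapes, and the server URL are assumptions taken from the directory layout and the code above (this model's config.pbtxt is not part of this commit), and note that `preprocess()` currently decodes each input as a file path that must be visible to the server:

```python
# Illustrative client sketch; model name, input shapes, and server URL are assumptions.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")   # hypothetical Triton URL

def make_audio_input(name, path):
    # preprocess() decodes the bytes as a UTF-8 path, so send a server-visible path.
    arr = np.array([[path.encode("utf-8")]], dtype=object)        # read back as .as_numpy()[0][0]
    tensor = httpclient.InferInput(name, [1, 1], "BYTES")
    tensor.set_data_from_numpy(arr)
    return tensor

inputs = [
    make_audio_input("AUDIO_BYTES_1", "/path/to/speaker1.wav"),        # hypothetical audio paths
    make_audio_input("AUDIO_BYTES_2", "/path/to/output_segment.wav"),
]
result = client.infer("similarity_model", inputs=inputs,
                      outputs=[httpclient.InferRequestedOutput("SIMILARITY")])
print(result.as_numpy("SIMILARITY")[0])   # cosine similarity mapped to [0, 1]
```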
python_backend/similarity_model/1/model_old.py ADDED
@@ -0,0 +1,97 @@
1
+ import math
2
+ import numpy as np
3
+ import torchaudio
4
+ import traceback
5
+ import triton_python_backend_utils as pb_utils
6
+
7
+
8
+ class TritonPythonModel:
9
+ def initialize(self, args):
10
+ self.sample_rate = 16000
11
+ self.feature_dim = 80
12
+ self.vad_enabled = True
13
+ self.min_duration = 0.1
14
+
15
+ # 创建与speaker_model通信的客户端
16
+ self.speaker_model_name = "speaker_model"
17
+
18
+ def execute(self, requests):
19
+ responses = []
20
+ for request in requests:
21
+ # 获取输入音频
22
+ audio1 = pb_utils.get_input_tensor_by_name(request, "AUDIO1").as_numpy()[0].decode('utf-8')
23
+ audio2 = pb_utils.get_input_tensor_by_name(request, "AUDIO2").as_numpy()[0].decode('utf-8')
24
+
25
+ # 预处理音频
26
+ feats1 = self.preprocess(audio1)
27
+ feats2 = self.preprocess(audio2)
28
+
29
+ # 调用speaker_model计算相似度
30
+ similarity = self.compute_similarity(feats1, feats2)
31
+
32
+ # 准备输出
33
+ output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
34
+ response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
35
+ responses.append(response)
36
+
37
+ return responses
38
+
39
+ def preprocess(self, audio_path):
40
+ """
41
+ Load an audio file; if it is too short, repeat it until it satisfies the minimum length.
42
+ Returns the processed waveform (or None on failure).
43
+ """
44
+ try:
45
+ waveform, sample_rate = torchaudio.load(audio_path)
46
+ duration = waveform.shape[1] / sample_rate
47
+
48
+ if duration >= self.min_duration:
49
+ # 音频长度足够,直接返回原路径
50
+ return waveform
51
+
52
+ # 音频过短,需要复制
53
+ repeat_times = math.ceil(self.min_duration / duration)
54
+
55
+ # 复制音频
56
+ return waveform.repeat(1, repeat_times)
57
+
58
+ except Exception:
59
+ traceback.format_exc()
60
+ return None
61
+
62
+ def compute_similarity(self, feats1, feats2):
63
+ # 调用speaker_model获取嵌入向量
64
+ e1 = self.call_speaker_model(feats1)
65
+ e2 = self.call_speaker_model(feats2)
66
+
67
+ # 计算余弦相似度
68
+ dot_product = np.dot(e1, e2)
69
+ norm_e1 = np.linalg.norm(e1)
70
+ norm_e2 = np.linalg.norm(e2)
71
+ similarity = dot_product / (norm_e1 * norm_e2)
72
+
73
+ # 归一化到[0, 1]
74
+ return (similarity + 1) / 2
75
+
76
+ def call_speaker_model(self, features):
77
+ """调用speaker_model获取嵌入向量"""
78
+ # 创建输入张量
79
+ input_tensor = pb_utils.Tensor("feats", features.astype(np.float32))
80
+
81
+ # 创建推理请求
82
+ inference_request = pb_utils.InferenceRequest(
83
+ model_name=self.speaker_model_name,
84
+ requested_output_names=["embs"],
85
+ inputs=[input_tensor]
86
+ )
87
+
88
+ # 发送请求
89
+ inference_response = inference_request.exec()
90
+
91
+ # 处理响应
92
+ if inference_response.has_error():
93
+ raise pb_utils.TritonModelException(inference_response.error().message())
94
+
95
+ # 获取嵌入向量
96
+ output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
97
+ return output_tensor.as_numpy()
python_backend/similarity_model/1/model_runnable.py ADDED
@@ -0,0 +1,149 @@
1
+ import io
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ import torchaudio.compliance.kaldi as kaldi
7
+ import traceback
8
+ from torch.utils.dlpack import from_dlpack
9
+ import triton_python_backend_utils as pb_utils
10
+
11
+
12
+ class TritonPythonModel:
13
+ def initialize(self, args):
14
+ self.sample_rate = 16000
15
+ self.feature_dim = 80
16
+ self.vad_enabled = True # This variable is declared but not used.
17
+ self.min_duration = 0.1
18
+
19
+ # This seems correct for BLS (Business Logic Scripting)
20
+ self.speaker_model_name = "speaker_model"
21
+
22
+ def execute(self, requests):
23
+ responses = []
24
+ for request in requests:
25
+ try:
26
+ # 1. Get the input audio BYTES, not a file path string.
27
+ # The input tensor is of type TYPE_STRING, which holds bytes.
28
+ # .as_numpy()[0] gives you the raw bytes object.
29
+ audio1_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_1").as_numpy()[0][0]
30
+ audio2_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_2").as_numpy()[0][0]
31
+
32
+ # 2. Preprocess audio from bytes
33
+ feats1 = self.preprocess(audio1_bytes)
34
+ feats2 = self.preprocess(audio2_bytes)
35
+
36
+ # 3. Call the speaker_model to compute similarity
37
+ similarity = self.compute_similarity(feats1, feats2)
38
+
39
+ pb_utils.Logger.log_info(f"similarity: {similarity}")
40
+ # Prepare output
41
+ output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
42
+ response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
43
+ responses.append(response)
44
+
45
+ except pb_utils.TritonModelException as e:
46
+ # If a Triton-specific error occurs, create an error response
47
+ error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(str(e)))
48
+ responses.append(error_response)
49
+ except Exception as e:
50
+ # For any other unexpected error, log it and return an error response
51
+ error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
52
+ pb_utils.Logger.log_error(error_message)
53
+ error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message))
54
+ responses.append(error_response)
55
+
56
+ return responses
57
+
58
+ def preprocess(self, audio_bytes: bytes):
59
+ """
60
+ Processes audio data from an in-memory byte buffer.
61
+ If the audio is too short, it's padded by repetition to meet the minimum length.
62
+ """
63
+ try:
64
+ # Wrap the raw bytes in a file-like object for torchaudio
65
+ # buffer = io.BytesIO(audio_bytes)  # would be needed for true in-memory audio; the bytes here actually hold a UTF-8 encoded file path
66
+ buffer = audio_bytes.decode('utf-8')
67
+ waveform, sample_rate = torchaudio.load(buffer)
68
+
69
+ # You might want to resample if the client's sample rate differs
70
+ if sample_rate != self.sample_rate:
71
+ # Note: This requires the 'torchaudio.transforms' module.
72
+ # Make sure torchaudio is fully installed in your Triton environment.
73
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
74
+ waveform = resampler(waveform)
75
+
76
+ duration = waveform.shape[1] / self.sample_rate
77
+
78
+ if duration < self.min_duration:
79
+ # Audio is too short, repeat it to meet the minimum duration
80
+ repeat_times = math.ceil(self.min_duration / duration)
81
+ waveform = waveform.repeat(1, repeat_times)
82
+
83
+ # --- THIS IS THE NEW, CRITICAL PART ---
84
+ # Calculate 80-dimensional Fbank features, which is what the speaker_model expects.
85
+ # The waveform needs to be shape [batch, time], so we squeeze it.
86
+ features = kaldi.fbank(
87
+ waveform.squeeze(0).unsqueeze(0), # Needs shape [1, T]
88
+ num_mel_bins=self.feature_dim, # This is 80
89
+ sample_frequency=self.sample_rate,
90
+ frame_length=25,
91
+ frame_shift=10
92
+ )
93
+ # The output of fbank is [1, num_frames, num_bins], e.g., [1, 150, 80]
94
+ # We need [num_frames, num_bins] for the speaker model
95
+ return features.squeeze(0) # Returns shape [num_frames, 80]
96
+
97
+ except Exception as e:
98
+ # Raise a specific exception that can be caught in execute()
99
+ raise pb_utils.TritonModelException(f"Failed during audio preprocessing: {e}")
100
+
101
+ def compute_similarity(self, waveform1, waveform2):
102
+ # Call speaker_model to get embeddings
103
+ # Assuming speaker_model takes a waveform and outputs an embedding
104
+ e1 = torch.from_numpy(self.call_speaker_model(waveform1)).to("cuda")
105
+ e2 = torch.from_numpy(self.call_speaker_model(waveform2)).to("cuda")
106
+
107
+ # Flatten the tensors
108
+ e1 = e1.flatten()
109
+ e2 = e2.flatten()
110
+
111
+ # Calculate cosine similarity
112
+ dot_product = torch.dot(e1, e2)
113
+ norm_e1 = torch.norm(e1)
114
+ norm_e2 = torch.norm(e2)
115
+
116
+ # Handle zero norms
117
+ if norm_e1 == 0 or norm_e2 == 0:
118
+ return 0.0
119
+
120
+ similarity = (dot_product / (norm_e1 * norm_e2)).item()
121
+
122
+ # Normalize from [-1, 1] to [0, 1]
123
+ return (similarity + 1) / 2
124
+
125
+ def call_speaker_model(self, waveform):
126
+ """Calls the speaker_model to get an embedding vector."""
127
+ # Create the input tensor for the speaker_model.
128
+ # The name 'feats' here must match the input name in speaker_model's config.pbtxt
129
+ if waveform.dim() == 2:
130
+ waveform = waveform.unsqueeze(0)
131
+ input_tensor = pb_utils.Tensor("feats", waveform.cpu().numpy().astype(np.float32))
132
+
133
+ inference_request = pb_utils.InferenceRequest(
134
+ model_name=self.speaker_model_name,
135
+ requested_output_names=["embs"], # Must match output name in speaker_model's config
136
+ inputs=[input_tensor]
137
+ )
138
+
139
+ inference_response = inference_request.exec()
140
+
141
+ if inference_response.has_error():
142
+ raise pb_utils.TritonModelException(f"Error from speaker_model: {inference_response.error().message()}")
143
+
144
+ output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
145
+ if output_tensor.is_cpu():
146
+ output_tensor = output_tensor.as_numpy()
147
+ else:
148
+ output_tensor = from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()
149
+ return output_tensor
python_backend/similarity_model/config.pbtxt.back ADDED
@@ -0,0 +1,46 @@
1
+ name: "similarity_model"
2
+ backend: "python"
3
+ max_batch_size: 128
4
+
5
+ parameters: {
6
+ key: "EXECUTION_ENV_PATH",
7
+ value: {
8
+ # string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env/audio.tar.gz"
9
+ # string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env/audio_clean.tar.gz"
10
+ # string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env_tar/audio_env.tar.gz"
11
+ string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env/mooncast/bin/python"
12
+ }
13
+ }
14
+
15
+ input [
16
+ {
17
+ name: "AUDIO1"
18
+ data_type: TYPE_STRING
19
+ dims: [ 1 ] # 音频路径
20
+ },
21
+ {
22
+ name: "AUDIO2"
23
+ data_type: TYPE_STRING
24
+ dims: [ 1 ] # 音频路径
25
+ }
26
+ ]
27
+
28
+ output [
29
+ {
30
+ name: "SIMILARITY"
31
+ data_type: TYPE_FP32
32
+ dims: [ 1 ] # 相似度分数
33
+ }
34
+ ]
35
+
36
+ dynamic_batching {
37
+ preferred_batch_size: [ 16, 32 ]
38
+ }
39
+
40
+ instance_group [
41
+ {
42
+ count: 1
43
+ kind: KIND_GPU
44
+ }
45
+ ]
46
+
python_backend/similarity_model/config.pbtxt.disabled ADDED
@@ -0,0 +1,26 @@
1
+ name: "similarity_model" # Or whatever you call this model
2
+ backend: "python"
3
+ max_batch_size: 8
4
+
5
+ # Input tensors are now raw audio bytes
6
+ input [
7
+ {
8
+ name: "AUDIO_BYTES_1"
9
+ data_type: TYPE_STRING # TYPE_STRING is used for variable-length binary data
10
+ dims: [ 1 ]
11
+ },
12
+ {
13
+ name: "AUDIO_BYTES_2"
14
+ data_type: TYPE_STRING
15
+ dims: [ 1 ]
16
+ }
17
+ ]
18
+
19
+ # Output is a single similarity score
20
+ output [
21
+ {
22
+ name: "SIMILARITY"
23
+ data_type: TYPE_FP32
24
+ dims: [ 1 ]
25
+ }
26
+ ]
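For reference, a request against this (currently disabled) bytes-based interface could be issued with the standard Triton gRPC client. The sketch below is an assumption-heavy illustration, not part of the repo: it assumes a server at localhost:8001 with tritonclient[grpc] installed, and it encodes a file path as the bytes payload, matching how model_runnable.py decodes its input with UTF-8.

```python
# Hypothetical client sketch for the bytes-based similarity_model interface above.
# Assumes a Triton server at localhost:8001 and the tritonclient[grpc] package.
import numpy as np
import tritonclient.grpc as grpcclient

def query_similarity(url: str, wav_path_1: str, wav_path_2: str) -> float:
    client = grpcclient.InferenceServerClient(url=url)

    # TYPE_STRING inputs are sent as object arrays of bytes; with batching enabled
    # the effective shape for a single request is [1, 1].
    def to_input(name: str, payload: bytes) -> grpcclient.InferInput:
        tensor = grpcclient.InferInput(name, [1, 1], "BYTES")
        tensor.set_data_from_numpy(np.array([[payload]], dtype=object))
        return tensor

    inputs = [
        to_input("AUDIO_BYTES_1", wav_path_1.encode("utf-8")),  # path bytes, per model_runnable.py
        to_input("AUDIO_BYTES_2", wav_path_2.encode("utf-8")),
    ]
    outputs = [grpcclient.InferRequestedOutput("SIMILARITY")]

    result = client.infer(model_name="similarity_model", inputs=inputs, outputs=outputs)
    return float(result.as_numpy("SIMILARITY").reshape(-1)[0])

if __name__ == "__main__":
    print(query_similarity("localhost:8001", "speaker_a.wav", "speaker_b.wav"))
```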
similarity.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright (c) 2023 Binbin Zhang ([email protected])
2
+ # Shuai Wang ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import sys
18
+ from typing import List, Tuple
19
+
20
+ import numpy as np
21
+ from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
22
+ import torch
23
+ import torchaudio
24
+ import torchaudio.compliance.kaldi as kaldi
25
+ import yaml
26
+ import kaldiio
27
+ from tqdm import tqdm
28
+
29
+ from wespeaker.cli.hub import Hub
30
+ from wespeaker.cli.utils import get_args
31
+ from wespeaker.models.speaker_model import get_speaker_model
32
+ from wespeaker.utils.checkpoint import load_checkpoint
33
+ from wespeaker.diar.umap_clusterer import cluster
34
+ from wespeaker.diar.extract_emb import subsegment
35
+ from wespeaker.diar.make_rttm import merge_segments
36
+ from wespeaker.utils.utils import set_seed
37
+
38
+
39
+ class Speaker:
40
+
41
+ def __init__(self, model_dir: str):
42
+ set_seed()
43
+
44
+ config_path = os.path.join(model_dir, 'config.yaml')
45
+ model_path = os.path.join(model_dir, 'avg_model.pt')
46
+ with open(config_path, 'r') as fin:
47
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
48
+ self.model = get_speaker_model(
49
+ configs['model'])(**configs['model_args'])
50
+ load_checkpoint(self.model, model_path)
51
+ self.model.eval()
52
+ self.vad = load_silero_vad()
53
+ self.table = {}
54
+ self.resample_rate = 16000
55
+ self.apply_vad = False
56
+ self.device = torch.device('cpu')
57
+ self.wavform_norm = False
58
+
59
+ # diarization params
60
+ self.diar_min_duration = 0.255
61
+ self.diar_window_secs = 1.5
62
+ self.diar_period_secs = 0.75
63
+ self.diar_frame_shift = 10
64
+ self.diar_batch_size = 32
65
+ self.diar_subseg_cmn = True
66
+
67
+ def set_wavform_norm(self, wavform_norm: bool):
68
+ self.wavform_norm = wavform_norm
69
+
70
+ def set_resample_rate(self, resample_rate: int):
71
+ self.resample_rate = resample_rate
72
+
73
+ def set_vad(self, apply_vad: bool):
74
+ self.apply_vad = apply_vad
75
+
76
+ def set_device(self, device: str):
77
+ self.device = torch.device(device)
78
+ self.model = self.model.to(self.device)
79
+
80
+ def set_diarization_params(self,
81
+ min_duration: float = 0.255,
82
+ window_secs: float = 1.5,
83
+ period_secs: float = 0.75,
84
+ frame_shift: int = 10,
85
+ batch_size: int = 32,
86
+ subseg_cmn: bool = True):
87
+ self.diar_min_duration = min_duration
88
+ self.diar_window_secs = window_secs
89
+ self.diar_period_secs = period_secs
90
+ self.diar_frame_shift = frame_shift
91
+ self.diar_batch_size = batch_size
92
+ self.diar_subseg_cmn = subseg_cmn
93
+
94
+ def compute_fbank(self,
95
+ wavform,
96
+ sample_rate=16000,
97
+ num_mel_bins=80,
98
+ frame_length=25,
99
+ frame_shift=10,
100
+ cmn=True):
101
+ feat = kaldi.fbank(wavform,
102
+ num_mel_bins=num_mel_bins,
103
+ frame_length=frame_length,
104
+ frame_shift=frame_shift,
105
+ sample_frequency=sample_rate,
106
+ window_type='hamming')
107
+ if cmn:
108
+ feat = feat - torch.mean(feat, 0)
109
+ return feat
110
+
111
+ def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
112
+ fbanks_array = np.stack(fbanks)
113
+ if subseg_cmn:
114
+ fbanks_array = fbanks_array - np.mean(
115
+ fbanks_array, axis=1, keepdims=True)
116
+ embeddings = []
117
+ fbanks_array = torch.from_numpy(fbanks_array).to(self.device)
118
+ for i in tqdm(range(0, fbanks_array.shape[0], batch_size)):
119
+ batch_feats = fbanks_array[i:i + batch_size]
120
+ with torch.no_grad():
121
+ batch_embs = self.model(batch_feats)
122
+ batch_embs = batch_embs[-1] if isinstance(
123
+ batch_embs, tuple) else batch_embs
124
+ embeddings.append(batch_embs.detach().cpu().numpy())
125
+ embeddings = np.vstack(embeddings)
126
+ return embeddings
127
+
128
+ def extract_embedding(self, audio_path: str):
129
+ pcm, sample_rate = torchaudio.load(audio_path,
130
+ normalize=self.wavform_norm)
131
+ return self.extract_embedding_from_pcm(pcm, sample_rate)
132
+
133
+ def extract_embedding_from_pcm(self, pcm: torch.Tensor, sample_rate: int):
134
+ if self.apply_vad:
135
+ # TODO(Binbin Zhang): Refine the segments logic, here we just
136
+ # suppose there is only silence at the start/end of the speech
137
+ vad_sample_rate = 16000
138
+ wav = pcm
139
+ if wav.size(0) > 1:
140
+ wav = wav.mean(dim=0, keepdim=True)
141
+
142
+ if sample_rate != vad_sample_rate:
143
+ transform = torchaudio.transforms.Resample(
144
+ orig_freq=sample_rate, new_freq=vad_sample_rate)
145
+ wav = transform(wav)
146
+ segments = get_speech_timestamps(wav,
147
+ self.vad,
148
+ return_seconds=True)
149
+ pcmTotal = torch.Tensor()
150
+ if len(segments) > 0: # remove all the silence
151
+ for segment in segments:
152
+ start = int(segment['start'] * sample_rate)
153
+ end = int(segment['end'] * sample_rate)
154
+ pcmTemp = pcm[0, start:end]
155
+ pcmTotal = torch.cat([pcmTotal, pcmTemp], 0)
156
+ pcm = pcmTotal.unsqueeze(0)
157
+ else: # all silence, nospeech
158
+ return None
159
+ pcm = pcm.to(torch.float)
160
+ if sample_rate != self.resample_rate:
161
+ pcm = torchaudio.transforms.Resample(
162
+ orig_freq=sample_rate, new_freq=self.resample_rate)(pcm)
163
+ feats = self.compute_fbank(pcm,
164
+ sample_rate=self.resample_rate,
165
+ cmn=True)
166
+ feats = feats.unsqueeze(0)
167
+ feats = feats.to(self.device)
168
+
169
+ with torch.no_grad():
170
+ outputs = self.model(feats)
171
+ outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
172
+ embedding = outputs[0].to(torch.device('cpu'))
173
+ return embedding
174
+
175
+ def extract_embedding_list(self, scp_path: str):
176
+ names = []
177
+ embeddings = []
178
+ with open(scp_path, 'r') as read_scp:
179
+ for line in tqdm(read_scp):
180
+ name, wav_path = line.strip().split()
181
+ names.append(name)
182
+ embedding = self.extract_embedding(wav_path)
183
+ embeddings.append(embedding.detach().numpy())
184
+ return names, embeddings
185
+
186
+ def compute_similarity(self, audio_path1: str, audio_path2: str) -> float:
187
+ e1 = self.extract_embedding(audio_path1)
188
+ e2 = self.extract_embedding(audio_path2)
189
+ if e1 is None or e2 is None:
190
+ return 0.0
191
+ else:
192
+ return self.cosine_similarity(e1, e2)
193
+
194
+ def compute_similarity_batch(
195
+ self, audio_pairs: List[Tuple[str, str]]) -> List[float]:
196
+ """
197
+ Computes cosine similarity for a batch of audio file pairs.
198
+ This method is optimized to extract embedding for each unique audio file
199
+ only once.
200
+
201
+ Args:
202
+ audio_pairs (List[Tuple[str, str]]): A list of tuples, where each
203
+ tuple contains two audio paths.
204
+ e.g., [('audio1.wav', 'audio2.wav'),
205
+ ('audio1.wav', 'audio3.wav')]
206
+
207
+ Returns:
208
+ List[float]: A list of similarity scores, corresponding to the
209
+ input pairs.
210
+ """
211
+ # 1. Collect all unique audio paths to avoid redundant computations
212
+ unique_audio_paths = set()
213
+ for path1, path2 in audio_pairs:
214
+ unique_audio_paths.add(path1)
215
+ unique_audio_paths.add(path2)
216
+
217
+ # 2. Extract embeddings for all unique files and store them in a cache
218
+ embedding_cache = {}
219
+ print(f"Extracting embeddings for {len(unique_audio_paths)} "
220
+ "unique audio files...")
221
+ for path in tqdm(list(unique_audio_paths)):
222
+ embedding_cache[path] = self.extract_embedding(path)
223
+
224
+ # 3. Compute similarity for each pair using the cached embeddings
225
+ scores = []
226
+ for path1, path2 in audio_pairs:
227
+ e1 = embedding_cache.get(path1)
228
+ e2 = embedding_cache.get(path2)
229
+
230
+ if e1 is None or e2 is None:
231
+ # Handle cases where embedding extraction failed (e.g., all
232
+ # silence)
233
+ scores.append(0.0)
234
+ else:
235
+ score = self.cosine_similarity(e1, e2)
236
+ scores.append(score)
237
+
238
+ return scores
239
+
240
+ def cosine_similarity(self, e1, e2):
241
+ cosine_score = torch.dot(e1, e2) / (torch.norm(e1) * torch.norm(e2))
242
+ cosine_score = cosine_score.item()
243
+ return (cosine_score + 1.0) / 2 # normalize: [-1, 1] => [0, 1]
244
+
245
+ def register(self, name: str, audio_path: str):
246
+ if name in self.table:
247
+ print('Speaker {} already registered, ignore'.format(name))
248
+ else:
249
+ self.table[name] = self.extract_embedding(audio_path)
250
+
251
+ def recognize(self, audio_path: str):
252
+ q = self.extract_embedding(audio_path)
253
+ best_score = 0.0
254
+ best_name = ''
255
+ for name, e in self.table.items():
256
+ score = self.cosine_similarity(q, e)
257
+ if best_score < score:
258
+ best_score = score
259
+ best_name = name
260
+ result = {}
261
+ result['name'] = best_name
262
+ result['confidence'] = best_score
263
+ return result
264
+
265
+ def diarize(self, audio_path: str, utt: str = "unk"):
266
+
267
+ pcm, sample_rate = torchaudio.load(audio_path, normalize=False)
268
+ # 1. vad
269
+ wav = read_audio(audio_path)
270
+ vad_segments = get_speech_timestamps(wav,
271
+ self.vad,
272
+ return_seconds=True)
273
+ if not vad_segments:
274
+ return []
275
+ # 2. extract fbanks
276
+ subsegs, subseg_fbanks = [], []
277
+ window_fs = int(self.diar_window_secs * 1000) // self.diar_frame_shift
278
+ period_fs = int(self.diar_period_secs * 1000) // self.diar_frame_shift
279
+ for item in vad_segments:
280
+ begin, end = item['start'], item['end']
281
+ if end - begin >= self.diar_min_duration:
282
+ begin_idx = int(begin * sample_rate)
283
+ end_idx = int(end * sample_rate)
284
+ tmp_wavform = pcm[0, begin_idx:end_idx].unsqueeze(0).to(
285
+ torch.float)
286
+ fbank = self.compute_fbank(tmp_wavform,
287
+ sample_rate=sample_rate,
288
+ cmn=False)
289
+ tmp_subsegs, tmp_subseg_fbanks = subsegment(
290
+ fbank=fbank,
291
+ seg_id="{:08d}-{:08d}".format(int(begin * 1000),
292
+ int(end * 1000)),
293
+ window_fs=window_fs,
294
+ period_fs=period_fs,
295
+ frame_shift=self.diar_frame_shift)
296
+ subsegs.extend(tmp_subsegs)
297
+ subseg_fbanks.extend(tmp_subseg_fbanks)
298
+
299
+ # 3. extract embedding
300
+ embeddings = self.extract_embedding_feats(subseg_fbanks,
301
+ self.diar_batch_size,
302
+ self.diar_subseg_cmn)
303
+
304
+ # 4. cluster
305
+ subseg2label = []
306
+ labels = cluster(embeddings)
307
+ for (_subseg, _label) in zip(subsegs, labels):
308
+ # b, e = process_seg_id(_subseg, frame_shift=self.diar_frame_shift)
309
+ # subseg2label.append([b, e, _label])
310
+ begin_ms, end_ms, begin_frames, end_frames = _subseg.split('-')
311
+ begin = (int(begin_ms) +
312
+ int(begin_frames) * self.diar_frame_shift) / 1000.0
313
+ end = (int(begin_ms) +
314
+ int(end_frames) * self.diar_frame_shift) / 1000.0
315
+ subseg2label.append([begin, end, _label])
316
+
317
+ # 5. merged segments
318
+ # [[utt, ([begin, end, label], [])], [utt, ([], [])]]
319
+ merged_segment_to_labels = merge_segments({utt: subseg2label})
320
+
321
+ return merged_segment_to_labels
322
+
323
+ def diarize_list(self, scp_path: str):
324
+ utts = []
325
+ segment2labels = []
326
+ with open(scp_path, 'r', encoding='utf-8') as read_scp:
327
+ for line in tqdm(read_scp):
328
+ utt, wav_path = line.strip().split()
329
+ utts.append(utt)
330
+ segment2label = self.diarize(wav_path, utt)
331
+ segment2labels.append(segment2label)
332
+ return utts, segment2labels
333
+
334
+ def make_rttm(self, merged_segment_to_labels, outfile):
335
+ with open(outfile, 'w', encoding='utf-8') as fin:
336
+ for (utt, begin, end, label) in merged_segment_to_labels:
337
+ fin.write(
338
+ "SPEAKER {} {} {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>\n".
339
+ format(utt, 1, float(begin),
340
+ float(end) - float(begin), label))
341
+
342
+
343
+ def load_model(language: str) -> Speaker:
344
+ model_path = Hub.get_model(language)
345
+ return Speaker(model_path)
346
+
347
+
348
+ def load_model_local(model_dir: str) -> Speaker:
349
+ return Speaker(model_dir)
350
+
351
+
352
+ def main():
353
+ args = get_args()
354
+ if args.pretrain == "":
355
+ if args.campplus:
356
+ model = load_model("campplus")
357
+ model.set_wavform_norm(True)
358
+ elif args.eres2net:
359
+ model = load_model("eres2net")
360
+ model.set_wavform_norm(True)
361
+ elif args.vblinkp:
362
+ model = load_model("vblinkp")
363
+ elif args.vblinkf:
364
+ model = load_model("vblinkf")
365
+ else:
366
+ model = load_model(args.language)
367
+ else:
368
+ model = load_model_local(args.pretrain)
369
+ model.set_resample_rate(args.resample_rate)
370
+ model.set_vad(args.vad)
371
+ model.set_device(args.device)
372
+ model.set_diarization_params(min_duration=args.diar_min_duration,
373
+ window_secs=args.diar_window_secs,
374
+ period_secs=args.diar_period_secs,
375
+ frame_shift=args.diar_frame_shift,
376
+ batch_size=args.diar_emb_bs,
377
+ subseg_cmn=args.diar_subseg_cmn)
378
+ if args.task == 'embedding':
379
+ embedding = model.extract_embedding(args.audio_file)
380
+ if embedding is not None:
381
+ np.savetxt(args.output_file, embedding.detach().numpy())
382
+ print('Succeed, see {}'.format(args.output_file))
383
+ else:
384
+ print('Fails to extract embedding')
385
+ elif args.task == 'embedding_kaldi':
386
+ names, embeddings = model.extract_embedding_list(args.wav_scp)
387
+ embed_ark = args.output_file + ".ark"
388
+ embed_scp = args.output_file + ".scp"
389
+ with kaldiio.WriteHelper('ark,scp:' + embed_ark + "," +
390
+ embed_scp) as writer:
391
+ for name, embedding in zip(names, embeddings):
392
+ writer(name, embedding)
393
+ elif args.task == 'similarity':
394
+ print(model.compute_similarity(args.audio_file, args.audio_file2))
395
+ elif args.task == 'diarization':
396
+ diar_result = model.diarize(args.audio_file)
397
+ if args.output_file is None:
398
+ for (_, start, end, spkid) in diar_result:
399
+ print("{:.3f}\t{:.3f}\t{:d}".format(start, end, spkid))
400
+ else:
401
+ model.make_rttm(diar_result, args.output_file)
402
+ elif args.task == 'diarization_list':
403
+ utts, segment2labels = model.diarize_list(args.wav_scp)
404
+ assert args.output_file is not None
405
+ model.make_rttm(np.vstack(segment2labels), args.output_file)
406
+ else:
407
+ print('Unsupported task {}'.format(args.task))
408
+ sys.exit(-1)
409
+
410
+
411
+ if __name__ == '__main__':
412
+ main()
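As a quick orientation, the Speaker wrapper above can also be driven directly from Python rather than through the CLI in main(). A minimal usage sketch follows; the model directory is a placeholder and must contain config.yaml and avg_model.pt.

```python
# Minimal usage sketch for the Speaker wrapper defined in similarity.py.
# "/path/to/wespeaker_model_dir" is a placeholder model directory.
from similarity import load_model_local

model = load_model_local("/path/to/wespeaker_model_dir")
model.set_device("cuda:0")   # or "cpu"
model.set_vad(True)          # trim leading/trailing silence before embedding

# Single pair: cosine similarity already normalized from [-1, 1] to [0, 1].
score = model.compute_similarity("speaker_a.wav", "speaker_b.wav")
print(f"similarity: {score:.4f}")

# Batched pairs: each unique file is embedded only once, then all pairs are scored.
pairs = [("speaker_a.wav", "speaker_b.wav"), ("speaker_a.wav", "speaker_c.wav")]
print(model.compute_similarity_batch(pairs))
```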
speaker_client.py ADDED
@@ -0,0 +1,149 @@
1
+ # new_client.py
2
+ import argparse
3
+ import asyncio
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+ import torchaudio.compliance.kaldi as kaldi
8
+ import tritonclient.grpc.aio as grpcclient
9
+ import sys
10
+ import time
11
+ import math
12
+
13
+
14
+ class TritonSpeakerClient:
15
+ def __init__(self, url, model_name="speaker_model", verbose=False):
16
+ try:
17
+ self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
18
+ except Exception as e:
19
+ print(f"Channel creation failed: {e}", file=sys.stderr)
20
+ sys.exit(1)
21
+
22
+ self.model_name = model_name
23
+
24
+ # --- 从旧的 similarity_model 迁移过来的预处理参数 ---
25
+ self.sample_rate = 16000
26
+ self.feature_dim = 80
27
+ self.min_duration = 0.1
28
+ # ----------------------------------------------------
29
+
30
+ def _preprocess_audio(self, audio_path: str):
31
+ """
32
+ 从音频文件路径加载并预处理音频,生成Fbank特征。
33
+ 这段逻辑完全复制自旧的 similarity_model.py 中的 preprocess 方法。
34
+ """
35
+ try:
36
+ waveform, sample_rate = torchaudio.load(audio_path)
37
+
38
+ # 如果采样率不匹配,则重采样
39
+ if sample_rate != self.sample_rate:
40
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
41
+ waveform = resampler(waveform)
42
+
43
+ # 如果音频太短,则重复填充以满足最小长度
44
+ duration = waveform.shape[1] / self.sample_rate
45
+ if duration < self.min_duration:
46
+ repeat_times = math.ceil(self.min_duration / duration)
47
+ waveform = waveform.repeat(1, repeat_times)
48
+
49
+ # 计算80维Fbank特征
50
+ # waveform 需要是 [batch, time] 格式,所以我们移除通道维度
51
+ if waveform.shape[0] > 1:
52
+ waveform = torch.mean(waveform, dim=0, keepdim=True) # 转为单声道
53
+
54
+ features = kaldi.fbank(
55
+ waveform,
56
+ num_mel_bins=self.feature_dim,
57
+ sample_frequency=self.sample_rate,
58
+ frame_length=25,
59
+ frame_shift=10
60
+ )
61
+ # fbank 输出 shape [1, num_frames, num_bins], 我们需要 [num_frames, 80]
62
+ return features.squeeze(0).numpy().astype(np.float32)
63
+
64
+ except Exception as e:
65
+ raise RuntimeError(f"Failed during audio preprocessing for {audio_path}: {e}")
66
+
67
+ def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray):
68
+ """在客户端计算余弦相似度。"""
69
+ e1 = torch.from_numpy(emb1).flatten()
70
+ e2 = torch.from_numpy(emb2).flatten()
71
+
72
+ similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)
73
+
74
+ # 将相似度从 [-1, 1] 范围归一化到 [0, 1]
75
+ return (similarity.item() + 1) / 2
76
+
77
+ async def compute_similarity(self, audio1_path: str, audio2_path: str):
78
+ """
79
+ 计算两个音频文件的相似度。
80
+ 此函数现在包含完整的处理流程:预处理 -> 批处理 -> 推理 -> 后处理。
81
+ """
82
+ # 1. 在客户端对两个音频文件进行预处理
83
+ feats1 = self._preprocess_audio(audio1_path)
84
+ feats2 = self._preprocess_audio(audio2_path)
85
+
86
+ # 2. 批处理:为了使用Triton的动态批处理,我们将两个特征打包成一个请求。
87
+ # 由于它们的长度(帧数)可能不同,我们需要将它们填充到相同的长度。
88
+ max_len = max(feats1.shape[0], feats2.shape[0])
89
+
90
+ # 使用np.pad进行填充
91
+ padded_feats1 = np.pad(feats1, ((0, max_len - feats1.shape[0]), (0, 0)), 'constant', constant_values=0)
92
+ padded_feats2 = np.pad(feats2, ((0, max_len - feats2.shape[0]), (0, 0)), 'constant', constant_values=0)
93
+
94
+ # 将填充后的特征堆叠成一个批次
95
+ input_batch = np.stack([padded_feats1, padded_feats2]) # Shape: [2, max_len, 80]
96
+
97
+ # 3. 创建Triton输入张量
98
+ # 输入名称 "feats" 必须与 speaker_model 的 config.pbtxt 中的输入名匹配
99
+ inputs = [
100
+ grpcclient.InferInput("feats", input_batch.shape, "FP32")
101
+ ]
102
+ inputs[0].set_data_from_numpy(input_batch)
103
+
104
+ # 4. 设置请求的输出
105
+ # 输出名称 "embs" 必须与 speaker_model 的 config.pbtxt 中的输出名匹配
106
+ outputs = [grpcclient.InferRequestedOutput("embs")]
107
+
108
+ # 5. 发送推理请求
109
+ response = await self.triton_client.infer(
110
+ model_name=self.model_name,
111
+ inputs=inputs,
112
+ outputs=outputs
113
+ )
114
+
115
+ # 6. 解析结果
116
+ embeddings_batch = response.as_numpy("embs") # Shape: [2, embedding_dim]
117
+ emb1 = embeddings_batch[0]
118
+ emb2 = embeddings_batch[1]
119
+
120
+ # 7. 在客户端计算相似度
121
+ similarity = self._calculate_cosine_similarity(emb1, emb2)
122
+ return similarity
123
+
124
+ async def main():
125
+ parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
126
+ parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Enable verbose output')
127
+ parser.add_argument('-u', '--url', type=str, default='localhost:8001', help='Inference server URL.')
128
+ # 注意:这里的 model_name 应该是 speaker_model
129
+ parser.add_argument('--model_name', default='speaker_model', help='The name of the speaker embedding model on Triton.')
130
+ parser.add_argument('--audio_file1', type=str, required=True, help='Path to first audio file')
131
+ parser.add_argument('--audio_file2', type=str, required=True, help='Path to second audio file')
132
+
133
+ FLAGS = parser.parse_args()
134
+
135
+ client = TritonSpeakerClient(FLAGS.url, FLAGS.model_name, verbose=FLAGS.verbose)
136
+
137
+ start_time = time.time()
138
+ try:
139
+ similarity = await client.compute_similarity(FLAGS.audio_file1, FLAGS.audio_file2)
140
+ elapsed = time.time() - start_time
141
+ print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
142
+ except Exception as e:
143
+ print(f"Error computing similarity: {e}", file=sys.stderr)
144
+ sys.exit(1)
145
+
146
+ # 使用示例:
147
+ # python speaker_client.py --audio_file1=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi1.wav --audio_file2=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi2.wav
148
+ if __name__ == '__main__':
149
+ asyncio.run(main())
test.py ADDED
@@ -0,0 +1,1643 @@
1
+ import json
2
+ import re
3
+ import os
4
+ from typing import List, Dict, Tuple, Any
5
+ import numpy as np
6
+ from pathlib import Path
7
+ import torch
8
+ import torchaudio
9
+ import torchaudio.functional as F
10
+ import logging
11
+ import wespeaker
12
+ import shutil
13
+ from datetime import datetime
14
+ import multiprocessing as mp
15
+ from functools import partial
16
+ import math
17
+ import threading
18
+ import time
19
+ from concurrent.futures import ThreadPoolExecutor, as_completed
20
+ import random # 添加random模块用于shuffle
21
+
22
+ # 设置multiprocessing启动方式为spawn(CUDA兼容)
23
+ mp.set_start_method('spawn', force=True)
24
+
25
+ # 引用词对齐模块
26
+ from alignment import AlignmentModel, batch_get_alignment_result
27
+
28
+ class SpeakerSimilarityEvaluator:
29
+ """音色相似度评估器"""
30
+
31
+ def __init__(self, device="cuda",
32
+ alignment_model_dir='/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/models/mms_fa',
33
+ wespeaker_model_dir='/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft',
34
+ output_dir="./evaluation_results",
35
+ language="ZH",
36
+ similarity_max_workers=8):
37
+ """初始化评估器"""
38
+ self.device = device
39
+ self.alignment_model_dir = alignment_model_dir
40
+ self.wespeaker_model_dir = wespeaker_model_dir
41
+ self.language = language.upper() # 添加语言参数
42
+ self.similarity_max_workers = similarity_max_workers # 相似度计算线程数
43
+
44
+ # 先设置日志系统
45
+ logging.basicConfig(level=logging.INFO)
46
+ self.logger = logging.getLogger(__name__)
47
+
48
+ # 设置输出目录结构
49
+ self.output_dir = Path(output_dir)
50
+ self.segments_dir = self.output_dir / "segments" # 分割后的音频片段
51
+ self.prompts_dir = self.output_dir / "prompts" # prompt音频的S1和S2片段
52
+ self.temp_dir = self.output_dir / "temp" # 临时文件
53
+ self.results_dir = self.output_dir / "results" # 评估结果
54
+ self.temp_results_dir = self.output_dir / "temp_results" # 临时结果文件
55
+ self.alignment_dir = self.output_dir / "alignments" # 对齐信息保存目录
56
+
57
+ # 创建所有必要的目录
58
+ self._create_output_directories()
59
+
60
+ # 在多进程环境中延迟模型初始化
61
+ self.alignment_model = None
62
+ self.similarity_model = None
63
+
64
+ # 线程局部存储,用于线程安全的模型访问
65
+ self._thread_local = threading.local()
66
+
67
+ # 记录运行信息
68
+ self.logger.info(f"评估结果将保存到: {self.output_dir}")
69
+ self.logger.info(f"对齐信息将保存到: {self.alignment_dir}")
70
+ self.logger.info(f"使用语言: {self.language}")
71
+
72
+ def _create_output_directories(self):
73
+ """创建输出目录结构"""
74
+ for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir,
75
+ self.results_dir, self.temp_results_dir, self.alignment_dir]:
76
+ dir_path.mkdir(parents=True, exist_ok=True)
77
+
78
+ def _get_safe_filename(self, text: str, max_length: int = 50) -> str:
79
+ """生成安全的文件名"""
80
+ # 移除特殊字符,只保留中文、英文、数字和基本符号
81
+ safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text)
82
+ # 限制长度
83
+ if len(safe_text) > max_length:
84
+ safe_text = safe_text[:max_length]
85
+ # 替换空格为下划线
86
+ safe_text = safe_text.replace(' ', '_')
87
+ return safe_text if safe_text else "unnamed"
88
+
89
+ def _clean_temp_files(self):
90
+ """清理临时文件,但保留临时目录"""
91
+ if self.temp_dir.exists():
92
+ # 只删除临时目录中的文件,不删除目录本身
93
+ for file_path in self.temp_dir.iterdir():
94
+ if file_path.is_file():
95
+ try:
96
+ file_path.unlink()
97
+ except Exception as e:
98
+ self.logger.warning(f"删除临时文件失败: {file_path}, 错误: {e}")
99
+ else:
100
+ # 如果临时目录不存在,重新创建
101
+ self.temp_dir.mkdir(parents=True, exist_ok=True)
102
+
103
+ def _init_models_if_needed(self):
104
+ """延迟初始化模型(用于多进程环境)"""
105
+ # 初始化对齐模型 - 修正参数顺序
106
+ if self.alignment_model is None:
107
+ # 根据AlignmentModel的构造函数,应该是(device, model_dir)而不是(model_dir, device)
108
+ self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)
109
+
110
+ # 初始化相似度模型
111
+ if self.similarity_model is None:
112
+ self._load_wespeaker_model(self.wespeaker_model_dir)
113
+
114
+ def _is_english_text(self, text: str) -> bool:
115
+ """简单判断文本是否主要是英文"""
116
+ # 计算英文字符的比例
117
+ english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
118
+ total_chars = sum(1 for c in text if c.isalpha())
119
+
120
+ if total_chars == 0:
121
+ return False
122
+
123
+ return english_chars / total_chars > 0.8 # 如果80%以上是英文字符,认为是英文
124
+
125
+ def _detect_language_from_text(self, text: str) -> str:
126
+ """从文本内容检测语言"""
127
+ clean_text = self.remove_speaker_tags(text)
128
+ if self._is_english_text(clean_text):
129
+ return "EN"
130
+ else:
131
+ return "ZH"
132
+
133
+ def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str, file_type: str = "output"):
134
+ """
135
+ 保存对齐信息到单独的JSON文件
136
+
137
+ Args:
138
+ alignment_data: 对齐信息数据
139
+ input_id: 输入ID
140
+ file_type: 文件类型 ("output", "prompt", "segment")
141
+ """
142
+ try:
143
+ safe_input_id = self._get_safe_filename(input_id)
144
+ alignment_filename = f"{safe_input_id}_{file_type}_alignment.json"
145
+ alignment_path = self.alignment_dir / alignment_filename
146
+
147
+ # 添加元数据
148
+ alignment_info = {
149
+ 'input_id': input_id,
150
+ 'file_type': file_type,
151
+ 'language': self.language,
152
+ 'timestamp': datetime.now().isoformat(),
153
+ 'alignment_data': alignment_data
154
+ }
155
+
156
+ with open(alignment_path, 'w', encoding='utf-8') as f:
157
+ json.dump(alignment_info, f, ensure_ascii=False, indent=2)
158
+
159
+ self.logger.info(f"对齐信息已保存: {alignment_path}")
160
+ return str(alignment_path)
161
+
162
+ except Exception as e:
163
+ self.logger.error(f"保存对齐信息失败: {e}")
164
+ return None
165
+
166
+ def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
167
+ text_segments: List[Dict[str, Any]],
168
+ input_id: str, audio_path: str,
169
+ original_text: str, processed_text: str):
170
+ """
171
+ 保存详细的对齐信息,包括分段信息
172
+
173
+ Args:
174
+ alignments: 对齐结果列表
175
+ text_segments: 文本分段信息
176
+ input_id: 输入ID
177
+ audio_path: 音频文件路径
178
+ original_text: 原始文本
179
+ processed_text: 处理后的文本
180
+ """
181
+ alignment_data = {
182
+ 'original_text': original_text,
183
+ 'processed_text': processed_text,
184
+ 'audio_path': audio_path,
185
+ 'language': self.language,
186
+ 'total_alignments': len(alignments),
187
+ 'total_segments': len(text_segments),
188
+ 'alignments': alignments,
189
+ 'text_segments': text_segments,
190
+ 'segment_alignment_mapping': []
191
+ }
192
+
193
+ # 建立文本段和对齐结果的映射关系
194
+ for segment in text_segments:
195
+ segment_mapping = {
196
+ 'segment_id': segment.get('segment_id', 0),
197
+ 'segment_text': segment.get('text', ''),
198
+ 'speaker_label': segment.get('speaker_label', ''),
199
+ 'start_time': segment.get('start_time', 0.0),
200
+ 'end_time': segment.get('end_time', 0.0),
201
+ 'corresponding_alignments': []
202
+ }
203
+
204
+ # 找到对应的对齐项
205
+ segment_start = segment.get('start_time', 0.0)
206
+ segment_end = segment.get('end_time', 0.0)
207
+
208
+ for i, align_item in enumerate(alignments):
209
+ align_start = align_item.get('start', 0.0)
210
+ align_end = align_item.get('end', 0.0)
211
+
212
+ # 检查对齐项是否在当前段的时间范围内
213
+ if (align_start >= segment_start and align_end <= segment_end) or \
214
+ (align_start < segment_end and align_end > segment_start):
215
+ segment_mapping['corresponding_alignments'].append({
216
+ 'alignment_index': i,
217
+ 'transcript': align_item.get('transcript', ''),
218
+ 'start': align_start,
219
+ 'end': align_end,
220
+ 'score': align_item.get('score', 0.0) if 'score' in align_item else None
221
+ })
222
+
223
+ alignment_data['segment_alignment_mapping'].append(segment_mapping)
224
+
225
+ return self.save_alignment_info(alignment_data, input_id, "detailed")
226
+
227
+ def remove_speaker_tags(self, text: str) -> str:
228
+ """删除文本中的说话人标签[S1][S2]"""
229
+ return re.sub(r'\[S[12]\]', '', text).strip()
230
+
231
+ def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]:
232
+ """提取文本中的说话人片段信息"""
233
+ segments = []
234
+ pattern = r'\[S([12])\]([^[]*)'
235
+ matches = re.findall(pattern, text)
236
+
237
+ for speaker_id, content in matches:
238
+ segments.append({
239
+ 'speaker': f'S{speaker_id}',
240
+ 'content': content.strip()
241
+ })
242
+ return segments
243
+
244
+ def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
245
+ """将所有标点符号替换为逗号,连续逗号只保留一个,根据语言选择正确的逗号类型"""
246
+ # 如果未指定语言,使用类的默认语言设置或自动检测
247
+ if language is None:
248
+ if hasattr(self, 'language'):
249
+ language = self.language
250
+ else:
251
+ language = self._detect_language_from_text(text)
252
+
253
+ language = language.upper()
254
+
255
+ # 根据语言选择逗号类型和处理策略
256
+ if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
257
+ # 英文处理:先删除撇号,再替换其他标点符号
258
+ text = re.sub(r"'", '', text) # 删除撇号(don't -> dont)
259
+ target_comma = ',' # 英文逗号
260
+ comma_pattern = r',+' # 匹配连续英文逗号
261
+ # 更新正则表达式,不包含撇号
262
+ text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
263
+ else:
264
+ # 中文处理:包含撇号在替换范围内
265
+ target_comma = ',' # 中文逗号
266
+ comma_pattern = r',+' # 匹配连续中文逗号
267
+ # 更新正则表达式以匹配更多的标点符号
268
+ text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)
269
+
270
+ text = re.sub(comma_pattern, target_comma, text)
271
+ return text.strip(target_comma)
272
+
273
+ def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
274
+ """
275
+ 文本和音频的词对齐
276
+ 返回每个词对应的音频时间段
277
+ """
278
+ # 确保模型已初始化
279
+ self._init_models_if_needed()
280
+
281
+ # 如果未指定语言,使用类的默认语言设置或自动检测
282
+ if language is None:
283
+ if hasattr(self, 'language'):
284
+ language = self.language
285
+ else:
286
+ language = self._detect_language_from_text(text)
287
+ else:
288
+ language = language.upper()
289
+
290
+ # 加载音频
291
+ waveform, sample_rate = torchaudio.load(audio_path)
292
+
293
+ # 重采样到模型要求的采样率
294
+ if sample_rate != self.alignment_model.bundle.sample_rate:
295
+ waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)
296
+
297
+ # 转换为单声道
298
+ if waveform.shape[0] > 1:
299
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
300
+
301
+ waveform = waveform.squeeze(0) # 移除批次维度
302
+
303
+ # 将音频移动到正确的设备
304
+ waveform = waveform.to(self.device)
305
+
306
+ # 执行对齐
307
+ try:
308
+ alignment_results = batch_get_alignment_result(
309
+ self.alignment_model,
310
+ [waveform],
311
+ [text],
312
+ [language]
313
+ )
314
+ if not alignment_results or not alignment_results[0]:
315
+ raise RuntimeError(f"对齐结果为空: {audio_path}")
316
+ return alignment_results[0]
317
+ except Exception as e:
318
+ self.logger.error(f"音频对齐失败: {audio_path}")
319
+ self.logger.error(f"错误详情: {e}")
320
+ raise RuntimeError(f"音频对齐失败,程序终止。文件: {audio_path},错误: {e}")
321
+
322
+ def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str):
323
+ """分割音频片段"""
324
+ waveform, sample_rate = torchaudio.load(audio_path)
325
+
326
+ start_frame = int(start_time * sample_rate)
327
+ end_frame = int(end_time * sample_rate)
328
+
329
+ segment = waveform[:, start_frame:end_frame]
330
+
331
+ # 确保输出目录存在
332
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
333
+
334
+ torchaudio.save(output_path, segment, sample_rate)
335
+ return output_path
336
+
337
+ def concatenate_audio_files(self, audio_files: List[str], output_path: str):
338
+ """拼接多个音频文件"""
339
+ if not audio_files:
340
+ return
341
+
342
+ waveforms = []
343
+ sample_rate = None
344
+
345
+ for audio_file in audio_files:
346
+ if os.path.exists(audio_file):
347
+ waveform, sr = torchaudio.load(audio_file)
348
+ if sample_rate is None:
349
+ sample_rate = sr
350
+ elif sr != sample_rate:
351
+ waveform = F.resample(waveform, sr, sample_rate)
352
+ waveforms.append(waveform)
353
+
354
+ if waveforms:
355
+ concatenated = torch.cat(waveforms, dim=1)
356
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
357
+ torchaudio.save(output_path, concatenated, sample_rate)
358
+
359
+ def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
360
+ """
361
+ 根据说话人标签分割prompt音频
362
+ 返回S1和S2的音频片段路径
363
+ """
364
+ # 1. 提取说话人片段
365
+ speaker_segments = self.extract_speaker_segments(prompt_text)
366
+
367
+ # 2. 删除标签后进行词对齐 - 如果失败则直接抛出异常
368
+ clean_text = self.remove_speaker_tags(prompt_text)
369
+
370
+ # 检测语言或使用设置的语言
371
+ alignment_language = self.language
372
+ if alignment_language == "AUTO":
373
+ alignment_language = self._detect_language_from_text(clean_text)
374
+
375
+ alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)
376
+
377
+ # 保存prompt对齐信息
378
+ prompt_alignment_data = {
379
+ 'original_text': prompt_text,
380
+ 'clean_text': clean_text,
381
+ 'audio_path': prompt_audio,
382
+ 'language': alignment_language,
383
+ 'speaker_segments': speaker_segments,
384
+ 'alignments': alignments
385
+ }
386
+ self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")
387
+
388
+ # 3. 根据对齐结果分割音频
389
+ s1_segments = []
390
+ s2_segments = []
391
+
392
+ # 为每个说话人片段找到对应的时间段
393
+ text_pos = 0
394
+ for seg in speaker_segments:
395
+ seg_text = seg['content'].strip()
396
+ seg_length = len(seg_text)
397
+
398
+ # 找到这个片段在对齐结果中的起始和结束
399
+ start_time = None
400
+ end_time = None
401
+
402
+ current_pos = 0
403
+ for align_item in alignments:
404
+ item_text = align_item['transcript']
405
+ item_length = len(item_text)
406
+
407
+ if current_pos >= text_pos and current_pos < text_pos + seg_length:
408
+ if start_time is None:
409
+ start_time = align_item['start']
410
+ end_time = align_item['end']
411
+
412
+ current_pos += item_length
413
+
414
+ if start_time is not None and end_time is not None:
415
+ if seg['speaker'] == 'S1':
416
+ s1_segments.append((start_time, end_time))
417
+ else:
418
+ s2_segments.append((start_time, end_time))
419
+
420
+ text_pos += seg_length
421
+
422
+ # 4. 分割并拼接音频片段
423
+ safe_audio_id = self._get_safe_filename(audio_id)
424
+ prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
425
+ prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")
426
+
427
+ # 分割S1的所有片段
428
+ if s1_segments:
429
+ s1_temp_segments = []
430
+ for i, (start, end) in enumerate(s1_segments):
431
+ temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
432
+ self.split_audio_segment(prompt_audio, start, end, temp_path)
433
+ s1_temp_segments.append(temp_path)
434
+
435
+ # 拼接S1片段
436
+ self.concatenate_audio_files(s1_temp_segments, prompts1_path)
437
+
438
+ # 分割S2的所有片段
439
+ if s2_segments:
440
+ s2_temp_segments = []
441
+ for i, (start, end) in enumerate(s2_segments):
442
+ temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
443
+ self.split_audio_segment(prompt_audio, start, end, temp_path)
444
+ s2_temp_segments.append(temp_path)
445
+
446
+ # 拼接S2片段
447
+ self.concatenate_audio_files(s2_temp_segments, prompts2_path)
448
+
449
+ return prompts1_path, prompts2_path
450
+
451
+ def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]:
452
+ """
453
+ 将原始文本按说话人和标点符号同时分割,保持映射关系
454
+ 支持英文单词级别的处理
455
+ """
456
+ segments = []
457
+ pattern = r'\[S([12])\]([^[]*)'
458
+ matches = re.findall(pattern, original_text)
459
+
460
+ # 检测语言或使用设置的语言
461
+ alignment_language = self.language
462
+ if alignment_language == "AUTO":
463
+ alignment_language = self._detect_language_from_text(original_text)
464
+
465
+ segment_id = 0
466
+ for speaker_id, content in matches:
467
+ speaker = f'S{speaker_id}'
468
+ clean_content = content.strip()
469
+ comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language)
470
+
471
+ # 根据语言选择正确的逗号分割
472
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)):
473
+ # 英文:按英文逗号分割,保持单词完整性
474
+ parts = [part.strip() for part in comma_content.split(',') if part.strip()]
475
+ else:
476
+ # 中文:按中文逗号分割
477
+ parts = [part.strip() for part in comma_content.split(',') if part.strip()]
478
+
479
+ for part in parts:
480
+ if part.strip():
481
+ segments.append({
482
+ 'segment_id': segment_id,
483
+ 'text': part.strip(),
484
+ 'speaker_label': speaker,
485
+ 'original_speaker_content': clean_content
486
+ })
487
+ segment_id += 1
488
+
489
+ return segments
490
+
491
+ def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
492
+ """
493
+ 根据逗号分割输出音频,返回每小段的信息 - 基于词对齐结果中的标点符号划分句子
494
+ """
495
+ # 1. 获取文本片段和对应的说话人(用于获取speaker标签)
496
+ text_segments = self.map_text_segments_to_speakers(text)
497
+
498
+ # 2. 删除标签并替换标点符号
499
+ clean_text = self.remove_speaker_tags(text)
500
+
501
+ # 3. 检测语言或使用设置的语言
502
+ alignment_language = self.language
503
+ if alignment_language == "AUTO":
504
+ alignment_language = self._detect_language_from_text(clean_text)
505
+
506
+ # 使用检测到的语言替换标点符号
507
+ comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)
508
+
509
+ # 4. 词对齐 - 如果失败则直接抛出异常
510
+ alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)
511
+
512
+ # 5. 根据标点符号划分句子
513
+ segments = []
514
+ safe_audio_id = self._get_safe_filename(audio_id)
515
+
516
+ # 确定标点符号(根据语言选择,英文不包含撇号)
517
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
518
+ punctuation_chars = set([',', '.', '!', '?', ';', ':']) # 不包含撇号
519
+ else:
520
+ punctuation_chars = set([',', '。', '!', '?', ';', ':'])
521
+
522
+ # 顺序扫描对齐结果,根据标点符号划分句子
523
+ sentence_start_idx = 0
524
+ sentence_alignments = []
525
+ segment_id = 0
526
+
527
+ for i, align_item in enumerate(alignments):
528
+ transcript = align_item['transcript']
529
+ sentence_alignments.append(align_item)
530
+
531
+ # 检查是否包含标点符号(句子结束标志)
532
+ has_punctuation = any(punct in transcript for punct in punctuation_chars)
533
+
534
+ if has_punctuation or i == len(alignments) - 1: # 遇到标点符号或最后一个词
535
+ # 创建句子片段
536
+ if sentence_alignments:
537
+ # 获取句子的开始和结束时间
538
+ start_time = sentence_alignments[0]['start']
539
+ end_time = sentence_alignments[-1]['end']
540
+
541
+ # 构建句子文本(去除标点符号)
542
+ sentence_text_parts = []
543
+ for align in sentence_alignments:
544
+ # 根据语言选择不同的清理策略
545
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
546
+ # 英文:去除标点符号,但保留撇号已被删除的单词
547
+ clean_transcript = align['transcript'].rstrip(',.!?;:')
548
+ else:
549
+ # 中文:去除中文标点符号
550
+ clean_transcript = align['transcript'].rstrip(',。!?;:')
551
+
552
+ if clean_transcript.strip():
553
+ sentence_text_parts.append(clean_transcript)
554
+
555
+ # 根据语言选择连接方式
556
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
557
+ sentence_text = ' '.join(sentence_text_parts).strip() # 英文用空格连接
558
+ else:
559
+ sentence_text = ''.join(sentence_text_parts).strip() # 中文直接连接
560
+
561
+ if sentence_text: # 只有非空句子才处理
562
+ # 确定说话人标签(从原始text_segments中获取,如果可能的话)
563
+ speaker_label = "S1" # 默认
564
+ if segment_id < len(text_segments):
565
+ speaker_label = text_segments[segment_id]['speaker_label']
566
+ elif text_segments:
567
+ # 如果超出范围,使用最后一个片段的speaker
568
+ speaker_label = text_segments[-1]['speaker_label']
569
+
570
+ # 生成音频文件路径
571
+ safe_text = self._get_safe_filename(sentence_text, 30)
572
+ audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")
573
+
574
+ # 分割音频
575
+ try:
576
+ self.split_audio_segment(output_audio, start_time, end_time, audio_path)
577
+ except Exception as e:
578
+ self.logger.error(f"分割音频失败: {e}")
579
+ # 使用默认时间间隔
580
+ start_time = segment_id * 1.0
581
+ end_time = (segment_id + 1) * 1.0
582
+ self.split_audio_segment(output_audio, start_time, end_time, audio_path)
583
+
584
+ # 创建segment
585
+ segment = {
586
+ 'segment_id': segment_id,
587
+ 'text': sentence_text,
588
+ 'speaker_label': speaker_label,
589
+ 'original_speaker_content': sentence_text, # 这里简化处理
590
+ 'audio_path': audio_path,
591
+ 'start_time': start_time,
592
+ 'end_time': end_time
593
+ }
594
+
595
+ segments.append(segment)
596
+
597
+ self.logger.info(f"句子 {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
598
+ segment_id += 1
599
+
600
+ # 重置为下一个句子
601
+ sentence_alignments = []
602
+ sentence_start_idx = i + 1
603
+
604
+ # 保存详细的对齐信息
605
+ self.save_detailed_alignment_info(
606
+ alignments, segments, audio_id, output_audio, text, comma_text
607
+ )
608
+
609
+ self.logger.info(f"总共分割出 {len(segments)} 个句子片段")
610
+ return segments
611
+
612
+ def _get_thread_local_similarity_model(self):
613
+ """获取线程局部的相似度模型实例(线程安全)"""
614
+ if not hasattr(self._thread_local, 'similarity_model'):
615
+ # 为当前线程创建独立的模型实例
616
+ self._thread_local.similarity_model = self._create_similarity_model()
617
+ return self._thread_local.similarity_model
618
+
619
+ def _create_similarity_model(self):
620
+ """创建新的相似度模型实例"""
621
+ try:
622
+ import wespeaker
623
+
624
+ # 使用与主模型相同的加载逻辑
625
+ local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'
626
+
627
+ try:
628
+ model = wespeaker.load_model_local(local_model_path)
629
+ return model
630
+ except Exception as e:
631
+ self.logger.warning(f"加载指定本地模型失败: {e}")
632
+
633
+ # 回退方案
634
+ if os.path.exists(self.wespeaker_model_dir):
635
+ try:
636
+ model = wespeaker.load_model_local(self.wespeaker_model_dir)
637
+ return model
638
+ except Exception as e:
639
+ self.logger.warning(f"加载传入本地模型失败: {e}")
640
+
641
+ # 最终回退到预训练模型
642
+ try:
643
+ model = wespeaker.load_model('chinese')
644
+ return model
645
+ except Exception as e:
646
+ model = wespeaker.load_model('english')
647
+ return model
648
+
649
+ except Exception as e:
650
+ self.logger.error(f"创建相似度模型失败: {e}")
651
+ raise
652
+
653
+ def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
654
+ """
655
+ 线程安全的音色相似度计算
656
+ 对于过短的音频片段,通过复制来达到最小长度要求
657
+ """
658
+ try:
659
+ if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
660
+ self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
661
+ return None
662
+
663
+ # 获取线程局部的模型实例
664
+ similarity_model = self._get_thread_local_similarity_model()
665
+
666
+ # 检查并处理音频文件长度
667
+ def process_audio_for_similarity(audio_path, min_duration=0.1):
668
+ """
669
+ 处理音频文件,如果过短则复制到满足最小长度要求
670
+ 返回处理后的音频路径和是否为临时文件的标志
671
+ """
672
+ try:
673
+ waveform, sample_rate = torchaudio.load(audio_path)
674
+ duration = waveform.shape[1] / sample_rate
675
+
676
+ if duration >= min_duration:
677
+ # 音频长度足够,直接返回原路径
678
+ return audio_path, False
679
+
680
+ # 音频过短,需要复制
681
+ repeat_times = math.ceil(min_duration / duration)
682
+ thread_id = threading.get_ident()
683
+
684
+ # 复制音频
685
+ repeated_waveform = waveform.repeat(1, repeat_times)
686
+
687
+ # 生成临时文件路径(包含线程ID避免冲突)
688
+ temp_filename = f"temp_{thread_id}_{os.path.basename(audio_path)}"
689
+ temp_path = str(self.temp_dir / temp_filename)
690
+
691
+ # 保存复制后的音频
692
+ torchaudio.save(temp_path, repeated_waveform, sample_rate)
693
+
694
+ return temp_path, True
695
+
696
+ except Exception as e:
697
+ self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}")
698
+ return audio_path, False
699
+
700
+ # 处理两个音频文件
701
+ processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
702
+ processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)
703
+
704
+ # 计算相似度
705
+ similarity = similarity_model.compute_similarity(processed_audio1, processed_audio2)
706
+
707
+ # 清理临时文件
708
+ if is_temp1 and os.path.exists(processed_audio1):
709
+ try:
710
+ os.remove(processed_audio1)
711
+ except Exception as e:
712
+ self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}")
713
+
714
+ if is_temp2 and os.path.exists(processed_audio2):
715
+ try:
716
+ os.remove(processed_audio2)
717
+ except Exception as e:
718
+ self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}")
719
+
720
+ return float(similarity)
721
+
722
+ except Exception as e:
723
+ # 检查是否是窗口大小错误或其他计算错误
724
+ if "choose a window size" in str(e) or "window size" in str(e):
725
+ self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
726
+ return None
727
+ else:
728
+ self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
729
+ return None
730
+
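
For clarity, the short-clip handling inside `process_audio_for_similarity` boils down to repeating the waveform until it reaches the minimum duration. A minimal, self-contained sketch of that logic (the 0.1 s threshold, function name, and paths here are illustrative, not part of the original file):

```python
import math
import torchaudio

def repeat_to_min_duration(audio_path: str, out_path: str, min_duration: float = 0.1) -> str:
    # Load the clip and measure its duration in seconds.
    waveform, sample_rate = torchaudio.load(audio_path)
    duration = waveform.shape[1] / sample_rate
    if duration >= min_duration:
        return audio_path  # long enough, use as-is
    # e.g. a 0.04 s clip needs ceil(0.1 / 0.04) = 3 copies to reach 0.1 s
    repeat_times = math.ceil(min_duration / duration)
    torchaudio.save(out_path, waveform.repeat(1, repeat_times), sample_rate)
    return out_path
```
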
731
+ def calculate_segment_similarities_parallel(self, output_segments: List[Dict[str, Any]],
732
+ prompts1_path: str, prompts2_path: str) -> List[Dict[str, Any]]:
733
+ """
734
+ 并行计算所有segments的相似度
735
+ Args:
736
+ output_segments: 音频segments列表
737
+ prompts1_path: S1 prompt音频路径
738
+ prompts2_path: S2 prompt音频路径
739
+ Returns:
740
+ 包含相似度信息的segment列表
741
+ """
742
+
743
+ def calculate_single_segment_similarity(segment):
744
+ """计算单个segment与两个prompts的相似度"""
745
+ try:
746
+ # 使用线程安全的相似度计算方法
747
+ sim1 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path)
748
+ sim2 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path)
749
+
750
+ return {
751
+ 'segment': segment,
752
+ 'sim1': sim1,
753
+ 'sim2': sim2,
754
+ 'success': True
755
+ }
756
+ except Exception as e:
757
+ self.logger.error(f"计算segment {segment['segment_id']} 相似度失败: {e}")
758
+ return {
759
+ 'segment': segment,
760
+ 'sim1': None,
761
+ 'sim2': None,
762
+ 'success': False
763
+ }
764
+
765
+ # 使用线程池并行处理所有segments
766
+ self.logger.info(f"开始并行计算 {len(output_segments)} 个segments的相似度,使用 {self.similarity_max_workers} 个线程")
767
+
768
+ results = []
769
+ with ThreadPoolExecutor(max_workers=self.similarity_max_workers) as executor:
770
+ # 提交所有segment任务
771
+ future_to_segment = {
772
+ executor.submit(calculate_single_segment_similarity, segment): segment
773
+ for segment in output_segments
774
+ }
775
+
776
+ # 收集结果(保持原有顺序)
777
+ segment_to_result = {}
778
+ completed_count = 0
779
+ for future in as_completed(future_to_segment):
780
+ result = future.result()
781
+ segment_id = result['segment']['segment_id']
782
+ segment_to_result[segment_id] = result
783
+ completed_count += 1
784
+
785
+ # 每完成10个segment报告一次进度
786
+ if completed_count % 10 == 0 or completed_count == len(output_segments):
787
+ self.logger.info(f"相似度计算进度: {completed_count}/{len(output_segments)}")
788
+
789
+ # 按segment_id顺序返回结果
790
+ for segment in output_segments:
791
+ segment_id = segment['segment_id']
792
+ if segment_id in segment_to_result:
793
+ results.append(segment_to_result[segment_id])
794
+
795
+ return results
796
+
797
+ def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
798
+ """评估单个输入的音色相似度"""
799
+
800
+ # 生成输入ID
801
+ if input_id is None:
802
+ input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
803
+
804
+ self.logger.info(f"开始评估输入: {input_id},使用语言: {self.language}")
805
+
806
+ # 1. 获取或分割prompt音频
807
+ prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")
808
+
809
+ # 2. 分割output音频(这里会保存详细对齐信息)
810
+ output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")
811
+
812
+ # 3. 并行计算每小段的相似度
813
+ similarity_results = self.calculate_segment_similarities_parallel(
814
+ output_segments, prompts1_path, prompts2_path
815
+ )
816
+
817
+ # 4. 处理相似度结果
818
+ segment_results = []
819
+ correct_predictions = 0
820
+ total_segments = 0 # 只计算有效段数
821
+ label_similarities = [] # 每小段与其标签的相似度
822
+ skipped_segments = 0 # 跳过的段数
823
+
824
+ for sim_result in similarity_results:
825
+ segment = sim_result['segment']
826
+ sim1 = sim_result['sim1']
827
+ sim2 = sim_result['sim2']
828
+
829
+ # 如果任一相似度为None(音频过短或计算失败),跳过该段
830
+ if sim1 is None or sim2 is None:
831
+ skipped_segments += 1
832
+ self.logger.info(f"跳过段 {segment['segment_id']}: 相似度计算失败")
833
+ continue
834
+
835
+ # 只有有效段才参与计算
836
+ total_segments += 1
837
+
838
+ # 判断实际音色
839
+ predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
840
+ actual_speaker = segment['speaker_label']
841
+ is_correct = predicted_speaker == actual_speaker
842
+
843
+ if is_correct:
844
+ correct_predictions += 1
845
+
846
+ # 计算与标签的相似度
847
+ if actual_speaker == 'S1':
848
+ label_similarity = sim1
849
+ else:
850
+ label_similarity = sim2
851
+ label_similarities.append(label_similarity)
852
+
853
+ segment_result = {
854
+ 'segment_id': segment['segment_id'],
855
+ 'text': segment['text'],
856
+ 'speaker_label': actual_speaker,
857
+ 'predicted_speaker': predicted_speaker,
858
+ 'sim1': sim1,
859
+ 'sim2': sim2,
860
+ 'label_similarity': label_similarity,
861
+ 'is_correct': is_correct,
862
+ 'audio_path': segment['audio_path'],
863
+ 'start_time': segment.get('start_time', 0.0),
864
+ 'end_time': segment.get('end_time', 1.0)
865
+ }
866
+ segment_results.append(segment_result)
867
+
868
+ # 4. 计算整体指标(只基于有效段)
869
+ accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
870
+ average_similarity = np.mean(label_similarities) if label_similarities else 0.0
871
+
872
+ # 5. 保存评估结果的对齐信息摘要
873
+ evaluation_alignment_summary = {
874
+ 'input_id': input_id,
875
+ 'language': self.language,
876
+ 'prompt_alignment_files': [
877
+ f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
878
+ ],
879
+ 'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
880
+ 'total_segments': total_segments,
881
+ 'total_alignments_processed': len(output_segments),
882
+ 'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
883
+ }
884
+ self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")
885
+
886
+ result = {
887
+ 'input_id': input_id,
888
+ 'language': self.language,
889
+ 'input_data': data, # 保存原始输入数据
890
+ 'prompts1_path': prompts1_path,
891
+ 'prompts2_path': prompts2_path,
892
+ 'segments': segment_results,
893
+ 'accuracy': accuracy,
894
+ 'average_similarity': average_similarity,
895
+ 'total_segments': total_segments, # 有效段数
896
+ 'correct_predictions': correct_predictions,
897
+ 'skipped_segments': skipped_segments, # 跳过的段数
898
+ 'original_total_segments': len(output_segments), # 原始总段数
899
+ 'alignment_files': {
900
+ 'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
901
+ 'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
902
+ 'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
903
+ },
904
+ 'timestamp': datetime.now().isoformat()
905
+ }
906
+
907
+ self.logger.info(f"完成评估输入: {input_id}, 语言: {self.language}, 有效段: {total_segments}/{len(output_segments)}, 跳过: {skipped_segments}, 准确率: {accuracy:.3f}, 平均相似度: {average_similarity:.3f}")
908
+
909
+ return result
910
+
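
Restating the metrics computed above: segments whose similarity to either prompt could not be computed are excluded entirely, so accuracy = correct_predictions / total_segments over valid segments only, and average_similarity is the mean of each valid segment's similarity to its labelled speaker's prompt. For example, 6 correct predictions out of 8 valid segments (with 2 skipped) gives an accuracy of 0.75, independent of the skipped segments.
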
911
+ def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None):
912
+ """保存结果到JSONL文件"""
913
+ if filename is None:
914
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
915
+ filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
916
+
917
+ output_path = self.results_dir / filename
918
+
919
+ with open(output_path, 'w', encoding='utf-8') as f:
920
+ for result in results:
921
+ f.write(json.dumps(result, ensure_ascii=False) + '\n')
922
+
923
+ return str(output_path)
924
+
925
+ def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None):
926
+ """保存汇总报告"""
927
+ if filename is None:
928
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
929
+ filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json"
930
+
931
+ summary_path = self.results_dir / filename
932
+
933
+ # 计算总体统计
934
+ total_accuracy = np.mean([r['accuracy'] for r in results])
935
+ total_avg_similarity = np.mean([r['average_similarity'] for r in results])
936
+ total_segments = sum([r['total_segments'] for r in results])
937
+ total_correct = sum([r['correct_predictions'] for r in results])
938
+
939
+ summary = {
940
+ 'evaluation_summary': {
941
+ 'language': self.language,
942
+ 'total_inputs': len(results),
943
+ 'total_segments': total_segments,
944
+ 'total_correct_predictions': total_correct,
945
+ 'overall_accuracy': total_accuracy,
946
+ 'overall_average_similarity': total_avg_similarity,
947
+ 'evaluation_timestamp': datetime.now().isoformat(),
948
+ 'output_directory': str(self.output_dir),
949
+ 'alignment_directory': str(self.alignment_dir)
950
+ },
951
+ 'per_input_results': [
952
+ {
953
+ 'input_id': r['input_id'],
954
+ 'language': r.get('language', self.language),
955
+ 'accuracy': r['accuracy'],
956
+ 'average_similarity': r['average_similarity'],
957
+ 'total_segments': r['total_segments'],
958
+ 'correct_predictions': r['correct_predictions'],
959
+ 'output_audio_path': r['input_data']['output_audio'],
960
+ 'alignment_files': r.get('alignment_files', {})
961
+ }
962
+ for r in results
963
+ ]
964
+ }
965
+
966
+ with open(summary_path, 'w', encoding='utf-8') as f:
967
+ json.dump(summary, f, ensure_ascii=False, indent=2)
968
+
969
+ return str(summary_path)
970
+
971
+ def process_batch_from_jsonl_parallel(self, jsonl_path: str,
972
+ processes_per_gpu: int = 16,
973
+ results_filename: str = None,
974
+ shuffle_data: bool = True):
975
+ """从JSONL文件并行批量处理输入数据"""
976
+ # 加载数据
977
+ input_data = self.load_data_from_jsonl(jsonl_path)
978
+
979
+ if not input_data:
980
+ self.logger.error("没有有效的输入数据")
981
+ return []
982
+
983
+ # 对数据进行shuffle,使分配更均匀
984
+ if shuffle_data:
985
+ random.shuffle(input_data)
986
+ self.logger.info(f"已对 {len(input_data)} 条数据进行随机shuffle")
987
+
988
+ return self.process_batch_parallel(input_data, processes_per_gpu, results_filename)
989
+
990
+ def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None):
991
+ """从JSONL文件批量处理输入数据(单进程版本)"""
992
+ # 加载数据
993
+ input_data = self.load_data_from_jsonl(jsonl_path)
994
+
995
+ if not input_data:
996
+ self.logger.error("没有有效的输入数据")
997
+ return []
998
+
999
+ return self.process_batch_from_data(input_data, results_filename)
1000
+
1001
+ def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
1002
+ """处理数据列表(单进程版本,用于兼容),支持增量写入"""
1003
+ # 准备结果文件
1004
+ if results_filename is None:
1005
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
1006
+ results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
1007
+
1008
+ results_path = self.results_dir / results_filename
1009
+
1010
+ # 如果文件已存在,删除它(重新开始)
1011
+ if results_path.exists():
1012
+ results_path.unlink()
1013
+
1014
+ results = []
1015
+
1016
+ self.logger.info(f"开始处理 {len(input_data)} 个输入,使用语言: {self.language}...")
1017
+
1018
+ for i, data in enumerate(input_data):
1019
+ input_id = f"input_{i+1:03d}"
1020
+ print(f"处理第{i+1}/{len(input_data)}个输入: {input_id},语言: {self.language}")
1021
+
1022
+ try:
1023
+ result = self.evaluate_single_input(data, input_id=input_id)
1024
+ results.append(result)
1025
+
1026
+ # 增量写入结果
1027
+ self.append_result_to_jsonl(result, str(results_path))
1028
+
1029
+ except Exception as e:
1030
+ self.logger.error(f"处理输入{input_id}时出错: {e}")
1031
+ continue
1032
+
1033
+ if not results:
1034
+ self.logger.error("没有成功处理的输入")
1035
+ return []
1036
+
1037
+ # 保存汇总报告
1038
+ summary_path = self.save_summary_report(results)
1039
+
1040
+ # 清理临时文件
1041
+ self._clean_temp_files()
1042
+
1043
+ # 打印总体统计
1044
+ total_accuracy = np.mean([r['accuracy'] for r in results])
1045
+ total_avg_similarity = np.mean([r['average_similarity'] for r in results])
1046
+
1047
+ print(f"\n=== 评估完成 ===")
1048
+ print(f"使用语言: {self.language}")
1049
+ print(f"总体准确率: {total_accuracy:.3f}")
1050
+ print(f"总体平均相似度: {total_avg_similarity:.3f}")
1051
+ print(f"详细结果已保存到: {results_path}")
1052
+ print(f"汇总报告已保存到: {summary_path}")
1053
+ print(f"对齐信息已保存到: {self.alignment_dir}")
1054
+ print(f"所有中间文件保存在: {self.output_dir}")
1055
+
1056
+ return results
1057
+
1058
+ def _load_wespeaker_model(self, wespeaker_model_dir):
1059
+ """加载wespeaker模型"""
1060
+ try:
1061
+ import wespeaker
1062
+
1063
+ # 使用load_model_local方法加载本地模型
1064
+ # 根据你提供的参考,使用你指定的模型路径
1065
+ local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'
1066
+
1067
+ try:
1068
+ self.similarity_model = wespeaker.load_model_local(local_model_path)
1069
+ self.logger.info(f"成功加载本地wespeaker模型: {local_model_path}")
1070
+ return
1071
+ except Exception as e:
1072
+ self.logger.warning(f"加载指定本地模型失败: {e}")
1073
+
1074
+ # 回退方案1: 尝试使用传入的模型目录
1075
+ if os.path.exists(wespeaker_model_dir):
1076
+ try:
1077
+ self.similarity_model = wespeaker.load_model_local(wespeaker_model_dir)
1078
+ self.logger.info(f"成功加载传入的本地wespeaker模型: {wespeaker_model_dir}")
1079
+ return
1080
+ except Exception as e:
1081
+ self.logger.warning(f"加载传入本地模型失败: {e}")
1082
+
1083
+ # 回退方案2: 使用预训练的中文模型
1084
+ try:
1085
+ self.similarity_model = wespeaker.load_model('chinese')
1086
+ self.logger.info("回退到wespeaker预训练中文模型")
1087
+ return
1088
+ except Exception as e:
1089
+ self.logger.warning(f"加载预训练中文模型失败: {e}")
1090
+
1091
+ # 回退方案3: 使用预训练的英文模型
1092
+ try:
1093
+ self.similarity_model = wespeaker.load_model('english')
1094
+ self.logger.info("回退到wespeaker预训练英文模型")
1095
+ return
1096
+ except Exception as e:
1097
+ self.logger.error(f"加载英文模型也失败: {e}")
1098
+
1099
+ # 如果所有方法都失败,抛出异常
1100
+ raise Exception("无法加载任何wespeaker模型")
1101
+
1102
+ except ImportError:
1103
+ raise ImportError("请安装wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git")
1104
+ except Exception as e:
1105
+ self.logger.error(f"加载wespeaker模型失败: {e}")
1106
+ raise
1107
+
1108
+ def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]:
1109
+ """从JSONL文件加载数据"""
1110
+ data = []
1111
+ try:
1112
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
1113
+ for line_num, line in enumerate(f, 1):
1114
+ line = line.strip()
1115
+ if line:
1116
+ try:
1117
+ item = json.loads(line)
1118
+ # 验证必要字段
1119
+ required_fields = ['text', 'output_audio']
1120
+ missing_fields = [field for field in required_fields if field not in item]
1121
+ if missing_fields:
1122
+ self.logger.error(f"第{line_num}行缺少必要字段: {missing_fields}")
1123
+ continue
1124
+
1125
+ # 验证音频路径模式:要么有prompt_audio和prompt_text,要么有分别的speaker音频文件
1126
+ has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item
1127
+ has_separate_prompts = ('prompt_audio_speaker1' in item and
1128
+ 'prompt_text_speaker1' in item and
1129
+ 'prompt_audio_speaker2' in item and
1130
+ 'prompt_text_speaker2' in item)
1131
+
1132
+ if not (has_combined_prompt or has_separate_prompts):
1133
+ self.logger.error(f"第{line_num}行:需要提供prompt_audio+prompt_text或者分别的speaker音频文件")
1134
+ continue
1135
+
1136
+ data.append(item)
1137
+
1138
+ except json.JSONDecodeError as e:
1139
+ self.logger.error(f"第{line_num}行JSON解析错误: {e}")
1140
+ continue
1141
+
1142
+ self.logger.info(f"从{jsonl_path}成功加载{len(data)}条数据")
1143
+ return data
1144
+
1145
+ except FileNotFoundError:
1146
+ self.logger.error(f"JSONL文件不存在: {jsonl_path}")
1147
+ return []
1148
+ except Exception as e:
1149
+ self.logger.error(f"读取JSONL文件失败: {e}")
1150
+ return []
1151
+
1152
+ @staticmethod
1153
+ def get_gpu_count():
1154
+ """获取可用GPU数量"""
1155
+ if torch.cuda.is_available():
1156
+ return torch.cuda.device_count()
1157
+ return 0
1158
+
1159
+ @staticmethod
1160
+ def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]:
1161
+ """根据GPU数量分割数据"""
1162
+ if num_gpus == 0:
1163
+ return [data]
1164
+
1165
+ chunk_size = math.ceil(len(data) / num_gpus)
1166
+ gpu_chunks = []
1167
+
1168
+ for i in range(num_gpus):
1169
+ start_idx = i * chunk_size
1170
+ end_idx = min((i + 1) * chunk_size, len(data))
1171
+ if start_idx < len(data):
1172
+ gpu_chunks.append(data[start_idx:end_idx])
1173
+
1174
+ return gpu_chunks
1175
+
1176
+ @staticmethod
1177
+ def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]:
1178
+ """根据进程数量分割数据"""
1179
+ if num_processes <= 1:
1180
+ return [data]
1181
+
1182
+ chunk_size = math.ceil(len(data) / num_processes)
1183
+ process_chunks = []
1184
+
1185
+ for i in range(num_processes):
1186
+ start_idx = i * chunk_size
1187
+ end_idx = min((i + 1) * chunk_size, len(data))
1188
+ if start_idx < len(data):
1189
+ process_chunks.append(data[start_idx:end_idx])
1190
+
1191
+ return process_chunks
1192
+
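
Both `split_data_by_gpu` and `split_data_by_processes` use the same ceil-based slicing, so the last chunk may be smaller than the others. A tiny illustration of that behaviour (standalone sketch, not part of the original file):

```python
import math

def chunk(data, n):
    # Same ceil-based slicing as split_data_by_gpu / split_data_by_processes.
    size = math.ceil(len(data) / n)
    return [data[i * size:(i + 1) * size] for i in range(n) if i * size < len(data)]

print(chunk(list(range(10)), 3))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```
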
1193
+ def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str):
1194
+ """增量写入结果到JSONL文件"""
1195
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
1196
+ with open(filepath, 'a', encoding='utf-8') as f:
1197
+ f.write(json.dumps(result, ensure_ascii=False) + '\n')
1198
+ f.flush() # 强制刷新缓冲区
1199
+
1200
+ def merge_temp_results(self, temp_files: List[str], final_path: str):
1201
+ """合并临时结果文件"""
1202
+ all_results = []
1203
+
1204
+ for temp_file in temp_files:
1205
+ if os.path.exists(temp_file):
1206
+ try:
1207
+ with open(temp_file, 'r', encoding='utf-8') as f:
1208
+ for line in f:
1209
+ line = line.strip()
1210
+ if line:
1211
+ result = json.loads(line)
1212
+ all_results.append(result)
1213
+ except Exception as e:
1214
+ self.logger.error(f"读取临时文件失败: {temp_file}, 错误: {e}")
1215
+
1216
+ # 写入最终文件
1217
+ with open(final_path, 'w', encoding='utf-8') as f:
1218
+ for result in all_results:
1219
+ f.write(json.dumps(result, ensure_ascii=False) + '\n')
1220
+
1221
+ return all_results
1222
+
1223
+ def process_batch_parallel(self, input_data: List[Dict[str, Any]],
1224
+ processes_per_gpu: int = 8, # 降低进程数
1225
+ results_filename: str = None,
1226
+ shuffle_data: bool = True):
1227
+ """并行批量处理输入数据"""
1228
+ # 1. 检查GPU数量
1229
+ num_gpus = self.get_gpu_count()
1230
+ if num_gpus == 0:
1231
+ self.logger.warning("未检测到GPU,将使用CPU单进程处理")
1232
+ return self.process_batch_from_data(input_data, results_filename)
1233
+
1234
+ # 限制每个GPU的进程数,避免CUDA内存冲突
1235
+ max_processes_per_gpu = min(processes_per_gpu, 16)
1236
+ self.logger.info(f"检测到 {num_gpus} 个GPU,每个GPU将使用 {max_processes_per_gpu} 个进程")
1237
+
1238
+ # 2. 对数据进行shuffle(如果还没有shuffle过)
1239
+ shuffled_data = input_data.copy()
1240
+ if shuffle_data:
1241
+ random.shuffle(shuffled_data)
1242
+ self.logger.info(f"已对 {len(shuffled_data)} 条数据进行随机shuffle以平衡GPU负载")
1243
+
1244
+ # 3. 按GPU分割数据
1245
+ gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)
1246
+
1247
+ # 打印每个GPU分配到的数据量
1248
+ for gpu_id, gpu_data in enumerate(gpu_chunks):
1249
+ if gpu_data:
1250
+ self.logger.info(f"GPU {gpu_id}: 分配到 {len(gpu_data)} 条数据")
1251
+
1252
+ # 4. 准备结果文件路径
1253
+ if results_filename is None:
1254
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
1255
+ results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
1256
+
1257
+ final_results_path = self.results_dir / results_filename
1258
+
1259
+ # 5. 为所有GPU准备进程参数
1260
+ all_temp_files = []
1261
+ all_gpu_tasks = []
1262
+
1263
+ for gpu_id, gpu_data in enumerate(gpu_chunks):
1264
+ if not gpu_data:
1265
+ continue
1266
+
1267
+ self.logger.info(f"GPU {gpu_id}: 准备处理 {len(gpu_data)} 条数据")
1268
+
1269
+ # 按进程数分割当前GPU的数据
1270
+ process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)
1271
+
1272
+ # 为当前GPU准备所有进程参数
1273
+ gpu_process_args = []
1274
+ for proc_id, proc_data in enumerate(process_chunks):
1275
+ if proc_data:
1276
+ temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
1277
+ all_temp_files.append(temp_result_file)
1278
+
1279
+ # 子进程输出目录在主输出目录内部
1280
+ subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")
1281
+
1282
+ gpu_process_args.append((
1283
+ proc_data,
1284
+ gpu_id,
1285
+ proc_id,
1286
+ subprocess_output_dir,
1287
+ temp_result_file,
1288
+ self.alignment_model_dir,
1289
+ self.wespeaker_model_dir,
1290
+ self.language, # 语言参数
1291
+ self.similarity_max_workers # 添加相似度计算线程数参数
1292
+ ))
1293
+
1294
+ if gpu_process_args:
1295
+ all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))
1296
+
1297
+ # 6. 使用ThreadPoolExecutor并行处理所有GPU
1298
+ def process_gpu_tasks(gpu_task):
1299
+ gpu_id, process_args, actual_processes = gpu_task
1300
+ self.logger.info(f"GPU {gpu_id}: 开始并行处理 {len(process_args)} 个进程")
1301
+
1302
+ # 为每个GPU使用独立的进程池,避免进程间冲突
1303
+ with mp.Pool(processes=actual_processes) as pool:
1304
+ pool.map(process_data_chunk_incremental, process_args)
1305
+
1306
+ self.logger.info(f"GPU {gpu_id}: 所有进程处理完成")
1307
+ return gpu_id
1308
+
1309
+ # 使用线程池同时处理所有GPU
1310
+ with ThreadPoolExecutor(max_workers=num_gpus) as executor:
1311
+ # 提交所有GPU任务
1312
+ future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
1313
+ for gpu_task in all_gpu_tasks}
1314
+
1315
+ # 等待所有GPU完成
1316
+ completed_gpus = []
1317
+ for future in as_completed(future_to_gpu):
1318
+ gpu_id = future_to_gpu[future]
1319
+ try:
1320
+ result_gpu_id = future.result()
1321
+ completed_gpus.append(result_gpu_id)
1322
+ self.logger.info(f"GPU {result_gpu_id} 完成处理")
1323
+ except Exception as exc:
1324
+ self.logger.error(f"GPU {gpu_id} 处理时发生异常: {exc}")
1325
+
1326
+ self.logger.info(f"所有GPU处理完成: {completed_gpus}")
1327
+
1328
+ # 7. 合并所有临时结果文件
1329
+ self.logger.info("合并所有临时结果文件...")
1330
+ all_results = self.merge_temp_results(all_temp_files, str(final_results_path))
1331
+
1332
+ if not all_results:
1333
+ self.logger.error("没有成功处理的数据")
1334
+ return []
1335
+
1336
+ # 8. 生成汇总报告
1337
+ summary_path = self.save_summary_report(all_results)
1338
+
1339
+ # 9. 清理临时文件
1340
+ for temp_file in all_temp_files:
1341
+ if os.path.exists(temp_file):
1342
+ os.remove(temp_file)
1343
+
1344
+ # 10. 打印总体统计
1345
+ total_accuracy = np.mean([r['accuracy'] for r in all_results])
1346
+ total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])
1347
+
1348
+ print(f"\n=== 并行评估完成 ===")
1349
+ print(f"使用语言: {self.language}")
1350
+ print(f"使用 {num_gpus} 个GPU,每GPU {max_processes_per_gpu} 个进程")
1351
+ print(f"总处理数据: {len(input_data)} 条")
1352
+ print(f"成功处理: {len(all_results)} 条")
1353
+ print(f"总体准确率: {total_accuracy:.3f}")
1354
+ print(f"总体平均相似度: {total_avg_similarity:.3f}")
1355
+ print(f"详细结果已保存到: {final_results_path}")
1356
+ print(f"汇总报告已保存到: {summary_path}")
1357
+ print(f"对齐信息已保存到: {self.alignment_dir}")
1358
+
1359
+ return all_results
1360
+
1361
+ def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]:
1362
+ """
1363
+ 获取或分割prompt音频
1364
+ 如果提供了分别的speaker音频文件则直接使用,否则从combined prompt分割
1365
+ """
1366
+ # 检查是否有分别的speaker音频文件
1367
+ if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and
1368
+ 'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data):
1369
+
1370
+ self.logger.info(f"使用预分割的speaker音频文件")
1371
+
1372
+ # 即使使用预分割的音频,也保存对齐信息
1373
+ try:
1374
+ # 检测语言或使用设置的语言
1375
+ alignment_language = self.language
1376
+ if alignment_language == "AUTO":
1377
+ alignment_language = self._detect_language_from_text(data['prompt_text_speaker1'])
1378
+
1379
+ # 对S1音频进行对齐
1380
+ s1_alignments = self.align_text_with_audio(
1381
+ data['prompt_text_speaker1'], data['prompt_audio_speaker1'], alignment_language
1382
+ )
1383
+ s1_alignment_data = {
1384
+ 'speaker': 'S1',
1385
+ 'text': data['prompt_text_speaker1'],
1386
+ 'audio_path': data['prompt_audio_speaker1'],
1387
+ 'language': alignment_language,
1388
+ 'alignments': s1_alignments
1389
+ }
1390
+ self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1")
1391
+
1392
+ # 对S2音频进行对齐
1393
+ s2_alignments = self.align_text_with_audio(
1394
+ data['prompt_text_speaker2'], data['prompt_audio_speaker2'], alignment_language
1395
+ )
1396
+ s2_alignment_data = {
1397
+ 'speaker': 'S2',
1398
+ 'text': data['prompt_text_speaker2'],
1399
+ 'audio_path': data['prompt_audio_speaker2'],
1400
+ 'language': alignment_language,
1401
+ 'alignments': s2_alignments
1402
+ }
1403
+ self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2")
1404
+
1405
+ except Exception as e:
1406
+ self.logger.warning(f"保存预分割音频对齐信息失败: {e}")
1407
+
1408
+ return data['prompt_audio_speaker1'], data['prompt_audio_speaker2']
1409
+
1410
+ # 否则从combined prompt分割
1411
+ elif 'prompt_audio' in data and 'prompt_text' in data:
1412
+ self.logger.info(f"从combined prompt音频分割speaker片段")
1413
+ return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id)
1414
+
1415
+ else:
1416
+ raise ValueError("必须提供prompt_audio+prompt_text或者分别的speaker音频文件")
1417
+
1418
+ def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float:
1419
+ """
1420
+ 计算两个音频的音色相似度(向后兼容版本)
1421
+ 对于过短的音频片段,通过复制来达到最小长度要求
1422
+ """
1423
+ # 如果在多线程环境中,使用线程安全版本
1424
+ if threading.current_thread() != threading.main_thread():
1425
+ return self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path)
1426
+
1427
+ # 确保模型已初始化
1428
+ self._init_models_if_needed()
1429
+
1430
+ try:
1431
+ if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
1432
+ self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
1433
+ return None
1434
+
1435
+ # 检查并处理音频文件长度
1436
+ def process_audio_for_similarity(audio_path, min_duration=0.1):
1437
+ """
1438
+ 处理音频文件,如果过短则复制到满足最小长度要求
1439
+ 返回处理后的音频路径和是否为临时文件的标志
1440
+ """
1441
+ try:
1442
+ waveform, sample_rate = torchaudio.load(audio_path)
1443
+ duration = waveform.shape[1] / sample_rate
1444
+
1445
+ if duration >= min_duration:
1446
+ # 音频长度足够,直接返回原路径
1447
+ return audio_path, False
1448
+
1449
+ # 音频过短,需要复制
1450
+ repeat_times = math.ceil(min_duration / duration)
1451
+ self.logger.info(f"音频过短 ({duration:.3f}s),复制 {repeat_times} 次达到 {min_duration}s 要求: {audio_path}")
1452
+
1453
+ # 复制音频
1454
+ repeated_waveform = waveform.repeat(1, repeat_times)
1455
+
1456
+ # 生成临时文件路径
1457
+ temp_filename = f"temp_{os.path.basename(audio_path)}"
1458
+ temp_path = str(self.temp_dir / temp_filename)
1459
+
1460
+ # 保存复制后的音频
1461
+ torchaudio.save(temp_path, repeated_waveform, sample_rate)
1462
+
1463
+ return temp_path, True
1464
+
1465
+ except Exception as e:
1466
+ self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}")
1467
+ return audio_path, False
1468
+
1469
+ # 处理两个音频文件
1470
+ processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
1471
+ processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)
1472
+
1473
+ # 计算相似度
1474
+ similarity = self.similarity_model.compute_similarity(processed_audio1, processed_audio2)
1475
+
1476
+ # 清理临时文件
1477
+ if is_temp1 and os.path.exists(processed_audio1):
1478
+ try:
1479
+ os.remove(processed_audio1)
1480
+ except Exception as e:
1481
+ self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}")
1482
+
1483
+ if is_temp2 and os.path.exists(processed_audio2):
1484
+ try:
1485
+ os.remove(processed_audio2)
1486
+ except Exception as e:
1487
+ self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}")
1488
+
1489
+ return float(similarity)
1490
+
1491
+ except Exception as e:
1492
+ # 检查是否是窗口大小错误或其他计算错误
1493
+ if "choose a window size" in str(e) or "window size" in str(e):
1494
+ self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
1495
+ return None
1496
+ else:
1497
+ self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
1498
+ return None
1499
+
1500
+ # 全局函数,用于多进程处理(支持增量写入)
1501
+ def process_data_chunk_incremental(args):
1502
+ """处理数据块的工作函数(增量写入版本)"""
1503
+ data_chunk, gpu_id, proc_id, output_dir, temp_result_file, alignment_model_dir, wespeaker_model_dir, language, similarity_max_workers = args
1504
+
1505
+ # 设置当前进程使用的GPU
1506
+ device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"
1507
+
1508
+ try:
1509
+ # 清理CUDA状态,避免进程间冲突
1510
+ if torch.cuda.is_available():
1511
+ torch.cuda.empty_cache()
1512
+ # 设置当前进程的GPU设备
1513
+ torch.cuda.set_device(gpu_id)
1514
+ # 添加小延迟,避免同时初始化冲突
1515
+ time.sleep(proc_id * 0.5)
1516
+
1517
+ # 创建评估器实例,传入模型路径、语言参数和相似度计算线程数
1518
+ evaluator = SpeakerSimilarityEvaluator(
1519
+ device=device,
1520
+ alignment_model_dir=alignment_model_dir,
1521
+ wespeaker_model_dir=wespeaker_model_dir,
1522
+ output_dir=output_dir,
1523
+ language=language, # 传入语言参数
1524
+ similarity_max_workers=similarity_max_workers # 传入相似度计算线程数
1525
+ )
1526
+
1527
+ # 延迟初始化模型
1528
+ evaluator._init_models_if_needed()
1529
+
1530
+ # 清空临时结果文件(如果存在)
1531
+ if os.path.exists(temp_result_file):
1532
+ os.remove(temp_result_file)
1533
+
1534
+ # 处理数据块
1535
+ for i, data in enumerate(data_chunk):
1536
+ input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"
1537
+
1538
+ try:
1539
+ result = evaluator.evaluate_single_input(data, input_id=input_id)
1540
+
1541
+ # 立即写入结果到临时文件
1542
+ evaluator.append_result_to_jsonl(result, temp_result_file)
1543
+
1544
+ print(f"GPU{gpu_id}-进程{proc_id}: 完成 {input_id} (语言: {language}, 相似度线程: {similarity_max_workers})")
1545
+
1546
+ # 每处理完一个数据项,清理CUDA缓存
1547
+ if torch.cuda.is_available():
1548
+ torch.cuda.empty_cache()
1549
+
1550
+ except Exception as e:
1551
+ print(f"GPU{gpu_id}-进程{proc_id}: 处理 {input_id} 失败: {e}")
1552
+ # 出错时也清理CUDA缓存
1553
+ if torch.cuda.is_available():
1554
+ torch.cuda.empty_cache()
1555
+ continue
1556
+
1557
+ print(f"GPU{gpu_id}-进程{proc_id}: 所有数据处理完成,结果已写入 {temp_result_file}")
1558
+
1559
+ except Exception as e:
1560
+ print(f"GPU{gpu_id}-进程{proc_id}: 初始化失败: {e}")
1561
+ # 出错时清理CUDA缓存
1562
+ if torch.cuda.is_available():
1563
+ torch.cuda.empty_cache()
1564
+
1565
+ def main():
1566
+ """主函数示例"""
1567
+ import argparse
1568
+
1569
+ parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
1570
+ parser.add_argument('--jsonl_path', type=str, help='JSONL文件路径')
1571
+ parser.add_argument('--output_dir', type=str,
1572
+ default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
1573
+ help='结果保存目录')
1574
+ parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
1575
+ help='指定语言: zh=中文, en=英文, auto=自动检测 (默认: zh)')
1576
+ parser.add_argument('--no_parallel', action='store_true', help='禁用并行处理(默认启用并行)')
1577
+ parser.add_argument('--processes_per_gpu', type=int, default=4, help='每个GPU的进程数(建议不超过4)')
1578
+ parser.add_argument('--similarity_workers', type=int, default=16, help='相似度计算的线程数(默认: 16)')
1579
+ parser.add_argument('--no_shuffle', action='store_true', help='禁用数据shuffle(默认启用shuffle)')
1580
+ parser.add_argument('--random_seed', type=int, default=None, help='随机种子(可选,用于结果复现)')
1581
+
1582
+ args = parser.parse_args()
1583
+
1584
+ # 设置随机种子(如果指定)
1585
+ if args.random_seed is not None:
1586
+ random.seed(args.random_seed)
1587
+ np.random.seed(args.random_seed)
1588
+ torch.manual_seed(args.random_seed)
1589
+ print(f"设置随机种子: {args.random_seed}")
1590
+
1591
+ # 语言参数处理
1592
+ language = args.language.upper()
1593
+ if language not in ('AUTO', 'EN'):
1594
+ language = 'ZH'  # 默认中文
1599
+
1600
+ # 创建评估器,指定结果保存目录、语言和相似度计算线程数
1601
+ evaluator = SpeakerSimilarityEvaluator(
1602
+ output_dir=args.output_dir,
1603
+ language=language,
1604
+ similarity_max_workers=args.similarity_workers
1605
+ )
1606
+
1607
+ # 默认使用并行处理,除非明确禁用
1608
+ use_parallel = not args.no_parallel
1609
+ use_shuffle = not args.no_shuffle
1610
+
1611
+ print(f"使用语言设置: {language}")
1612
+ print(f"相似度计算线程数: {args.similarity_workers}")
1613
+
1614
+ if args.jsonl_path:
1615
+ # 从JSONL文件处理数据
1616
+ if use_parallel:
1617
+ evaluator.process_batch_from_jsonl_parallel(
1618
+ args.jsonl_path,
1619
+ processes_per_gpu=args.processes_per_gpu,
1620
+ shuffle_data=use_shuffle
1621
+ )
1622
+ else:
1623
+ evaluator.process_batch_from_jsonl(args.jsonl_path)
1624
+ else:
1625
+ # 使用示例数据(兼容性)
1626
+ input_data = [
1627
+ {
1628
+ 'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
1629
+ 'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
1630
+ 'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
1631
+ 'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
1632
+ }
1633
+ ]
1634
+
1635
+ # 处理数据
1636
+ if use_parallel:
1637
+ evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
1638
+ else:
1639
+ evaluator.process_batch_from_data(input_data)
1640
+
1641
+
1642
+ if __name__ == "__main__":
1643
+ main()
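
For reference, a typical multi-GPU run wiring together the options defined in `main()` above might look like the following (all paths are placeholders):

```bash
python test.py \
    --jsonl_path /path/to/input.jsonl \
    --output_dir /path/to/eval_res/run_001 \
    --language zh \
    --processes_per_gpu 4 \
    --similarity_workers 16 \
    --random_seed 42
```
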
test.sh ADDED
@@ -0,0 +1,82 @@
1
+ #!/bin/bash
2
+
3
+ source /inspire/hdd/project/embodied-multimodality/public/yqzhang/miniconda3/bin/activate
4
+ conda activate /inspire/hdd/project/embodied-multimodality/public/cchang/env/mooncast/
5
+
6
+ # 设置CUDA环境变量
7
+ export CUDA_LAUNCH_BLOCKING=1
8
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
9
+
10
+ # 创建日志目录和文件名
11
+ LOG_DIR="/inspire/hdd/project/embodied-multimodality/public/cchang/projects/auto_evaluation_new/logs"
12
+ mkdir -p "$LOG_DIR"
13
+ LOG_FILE="$LOG_DIR/evaluation_$(date +%Y%m%d_%H%M%S).log"
14
+
15
+ # 记录开始时间
16
+ START_TIME=$(date +%s)
17
+ START_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')
18
+
19
+ echo "========================================="
20
+ echo "音色相似度评估开始"
21
+ echo "开始时间: $START_TIME_READABLE"
22
+ echo "日志文件: $LOG_FILE"
23
+ echo "========================================="
24
+ echo "可以使用以下命令实时查看日志:"
25
+ echo "tail -f $LOG_FILE"
26
+ echo ""
27
+
28
+ # 将开始时间信息也写入日志文件
29
+ {
30
+ echo "========================================="
31
+ echo "音色相似度评估开始"
32
+ echo "开始时间: $START_TIME_READABLE"
33
+ echo "进程配置: 每GPU 8个进程"
34
+ echo "语言设置: zh (中文)"
35
+ echo "========================================="
36
+ echo ""
37
+ } | tee "$LOG_FILE"
38
+
39
+ # 使用更保守的进程数
40
+ python -u /inspire/hdd/project/embodied-multimodality/public/cchang/projects/auto_evaluation_new/test.py \
41
+ --jsonl_path /inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step70000/eval_new/output.jsonl \
42
+ --output_dir /inspire/hdd/project/embodied-multimodality/public/cchang/projects/auto_evaluation_new/eval_res/new_test \
43
+ --processes_per_gpu 8 \
44
+ --language zh \
45
+ 2>&1 | tee -a "$LOG_FILE"
46
+
47
+
48
+ # 记录结束时间
49
+ END_TIME=$(date +%s)
50
+ END_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')
51
+
52
+ # 计算耗时
53
+ DURATION=$((END_TIME - START_TIME))
54
+ HOURS=$((DURATION / 3600))
55
+ MINUTES=$(((DURATION % 3600) / 60))
56
+ SECONDS=$((DURATION % 60))
57
+
58
+ # 输出结束信息
59
+ {
60
+ echo ""
61
+ echo "========================================="
62
+ echo "音色相似度评估完成!"
63
+ echo "结束时间: $END_TIME_READABLE"
64
+ echo "总耗时: ${HOURS}小时${MINUTES}分钟${SECONDS}秒 (共${DURATION}秒)"
65
+ echo "日志文件: $LOG_FILE"
66
+ echo "========================================="
67
+ } | tee -a "$LOG_FILE"
68
+
69
+ # 显示在终端
70
+ echo ""
71
+ echo "评估完成!"
72
+ echo "开始时间: $START_TIME_READABLE"
73
+ echo "结束时间: $END_TIME_READABLE"
74
+ echo "总耗时: ${HOURS}小时${MINUTES}分钟${SECONDS}秒"
75
+ echo "日志已保存到: $LOG_FILE"
76
+
77
+ # 如果耗时超过1小时,发送额外提醒
78
+ if [ $DURATION -gt 3600 ]; then
79
+ echo ""
80
+ echo "⏰ 注意:本次评估耗时较长,超过1小时"
81
+ echo " 建议检查性能优化效果"
82
+ fi
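
For quick debugging on a handful of samples, the same entry point can also be run single-process; this sketch only uses flags that `main()` already defines, with placeholder paths:

```bash
python -u test.py \
    --jsonl_path /path/to/small_sample.jsonl \
    --output_dir ./eval_res/debug_run \
    --language zh \
    --no_parallel \
    --no_shuffle \
    --random_seed 0
```
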
test_alignment.py ADDED
@@ -0,0 +1,416 @@
1
+ import re
2
+ import torch
3
+ import torchaudio.functional as F
4
+ import torchaudio
5
+ import uroman as ur
6
+ import logging
7
+ from typing import List, Dict, Any, Optional
8
+
9
+ def split_and_merge_punctuation(text: str) -> List[str]:
10
+ """
11
+ 处理英文文本,按空格分词并将标点符号合并到前面的单词
12
+
13
+ Args:
14
+ text: 输入的英文文本
15
+
16
+ Returns:
17
+ 处理后的单词列表,标点符号已合并到对应单词
18
+ """
19
+ # 先按空格拆分文本
20
+ elements = text.split()
21
+
22
+ # 用于保存最终的结果
23
+ result = []
24
+
25
+ # 遍历每个拆分后的元素
26
+ for ele in elements:
27
+ # 使用正则表达式提取连续字母、数字和标点
28
+ parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)
29
+
30
+ # 用于保存拆分后的部分
31
+ merged_parts = []
32
+
33
+ for i in range(len(parts)):
34
+ if i % 2 == 0: # 如果是字母或数字部分
35
+ # 将字母或数字部分添加到结果中
36
+ merged_parts.append(parts[i])
37
+ else: # 如果是标点或其他符号部分
38
+ # 将标点部分与前面的字母或数字部分合并
39
+ if merged_parts:
40
+ merged_parts[-1] += parts[i]
41
+ else:
42
+ merged_parts.append(parts[i])
43
+
44
+ # 将合并后的部分加入最终结果
45
+ result.extend(merged_parts)
46
+
47
+ return result
48
+
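
To make the tokenisation rule concrete, this is how the function above behaves on a short sentence (a direct trace of the regex-and-merge logic, shown as a usage example rather than extra library behaviour):

```python
# The regex splits "Hello," into ["Hello", ","] and the merge step glues the
# punctuation back onto the preceding word, giving ["Hello,", "world!"].
print(split_and_merge_punctuation("Hello, world!"))
```
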
49
+ def restore_spaces_in_english_text(tokens: List[str]) -> str:
50
+ """
51
+ 在英文单词之间恢复空格
52
+
53
+ Args:
54
+ tokens: 单词列表
55
+
56
+ Returns:
57
+ 恢复空格后的文本
58
+ """
59
+ result = []
60
+ for i, token in enumerate(tokens):
61
+ # 检查是否需要在单词前添加空格
62
+ if i > 0 and token[0].isalnum() and not any(p in tokens[i-1] for p in ',.!?;:()[]<>\'\"…'):
63
+ result.append(" ")
64
+ result.append(token)
65
+
66
+ return "".join(result)
67
+
68
+ def get_aligned_result_with_punctuation(alignment_result: List[Dict], text: str) -> List[Dict]:
69
+ """
70
+ 将对齐结果转换为包含标点符号的格式
71
+
72
+ Args:
73
+ alignment_result: 原始对齐结果
74
+ text: 原始文本
75
+
76
+ Returns:
77
+ 处理后的对齐结果,标点符号已合并
78
+ """
79
+ text_tokens = split_and_merge_punctuation(text)
80
+
81
+ updated_alignment_result = []
82
+ token_idx = 0
83
+
84
+ for index, align_item in enumerate(alignment_result):
85
+ if token_idx >= len(text_tokens):
86
+ break
87
+
88
+ start = align_item["start"]
89
+ end = align_item["end"]
90
+ text_token = text_tokens[token_idx]
91
+
92
+ updated_item = {
93
+ "start": start,
94
+ "end": end,
95
+ "transcript": text_token
96
+ }
97
+
98
+ # 保留原始对齐结果中的其他字段
99
+ updated_item.update({key: align_item[key] for key in align_item
100
+ if key not in ["start", "end", "transcript"]})
101
+
102
+ updated_alignment_result.append(updated_item)
103
+ token_idx += 1
104
+
105
+ return updated_alignment_result
106
+
107
+ class EnglishAlignmentModel:
108
+ def __init__(self, device: str = "cuda", model_dir: Optional[str] = None):
109
+ """
110
+ 初始化英文对齐模型
111
+
112
+ Args:
113
+ device: 设备类型 ("cuda" 或 "cpu")
114
+ model_dir: 模型目录路径,如果为None则使用默认路径
115
+ """
116
+ self.device = torch.device(device)
117
+ self.bundle = torchaudio.pipelines.MMS_FA
118
+
119
+ # 设置模型下载参数
120
+ dl_kwargs = {}
121
+ if model_dir:
122
+ dl_kwargs['model_dir'] = model_dir
123
+
124
+ self.align_model = self.bundle.get_model(
125
+ with_star=False,
126
+ dl_kwargs=dl_kwargs
127
+ ).to(self.device)
128
+
129
+ self.uroman = ur.Uroman()
130
+ self.DICTIONARY = self.bundle.get_dict()
131
+
132
+ def align(self, emission: torch.Tensor, tokens: torch.Tensor):
133
+ """
134
+ 执行强对齐
135
+
136
+ Args:
137
+ emission: 模型的输出
138
+ tokens: 目标tokens
139
+
140
+ Returns:
141
+ 对齐的tokens和分数
142
+ """
143
+ alignments, scores = F.forced_align(
144
+ log_probs=emission,
145
+ targets=tokens,
146
+ blank=0
147
+ )
148
+ alignments, scores = alignments[0], scores[0]
149
+ scores = scores.exp()
150
+ return alignments, scores
151
+
152
+ def unflatten(self, list_: List, lengths: List[int]) -> List[List]:
153
+ """
154
+ 将一个长列表按照长度拆分成子列表
155
+
156
+ Args:
157
+ list_: 长列表
158
+ lengths: 各子列表的长度
159
+
160
+ Returns:
161
+ 拆分后的子列表
162
+ """
163
+ assert len(list_) == sum(lengths)
164
+ i = 0
165
+ ret = []
166
+ for l in lengths:
167
+ ret.append(list_[i:i + l])
168
+ i += l
169
+ return ret
170
+
171
+ def preview_word(self, waveform: torch.Tensor, spans: List, num_frames: int,
172
+ transcript: List[str], sample_rate: int) -> List[Dict]:
173
+ """
174
+ 生成每个单词的时间对齐信息
175
+
176
+ Args:
177
+ waveform: 音频波形
178
+ spans: 单词的跨度
179
+ num_frames: 帧数
180
+ transcript: 转录文本单词列表
181
+ sample_rate: 采样率
182
+
183
+ Returns:
184
+ 单词的对齐信息列表
185
+ """
186
+ end = 0
187
+ alignment_result = []
188
+
189
+ for span, trans in zip(spans, transcript):
190
+ ratio = waveform.size(1) / num_frames
191
+ x0 = int(ratio * span[0].start)
192
+ x1 = int(ratio * span[-1].end)
193
+
194
+ align_info = {
195
+ "transcript": trans,
196
+ "start": round(x0 / sample_rate, 3),
197
+ "end": round(x1 / sample_rate, 3)
198
+ }
199
+ align_info["pause"] = round(align_info["start"] - end, 3)
200
+ align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
201
+ end = align_info["end"]
202
+ alignment_result.append(align_info)
203
+
204
+ return alignment_result
205
+
206
+ def make_wav_batch(self, wav_list: List[torch.Tensor]):
207
+ """
208
+ 将wav_list中的每个wav张量填充为相同的长度
209
+
210
+ Args:
211
+ wav_list: wav文件列表
212
+
213
+ Returns:
214
+ 填充后的音频张量和原始长度
215
+ """
216
+ wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
217
+ max_length = max(wav_lengths)
218
+ wavs_tensors = torch.zeros(len(wav_list), max_length, device=wav_list[0].device)
219
+
220
+ for i, wav in enumerate(wav_list):
221
+ wavs_tensors[i, :wav_lengths[i]] = wav
222
+
223
+ return wavs_tensors, wav_lengths.to(wavs_tensors.device)
224
+
225
+ def get_target(self, transcript: str) -> torch.Tensor:
226
+ """
227
+ 获取给定英文转录文本的目标tokens
228
+
229
+ Args:
230
+ transcript: 英文转录文本
231
+
232
+ Returns:
233
+ 转录文本的目标tokens
234
+ """
235
+ # 移除标点符号并转换为小写
236
+ transcript = re.sub(r'[^\w\s]', r' ', transcript)
237
+ words = transcript.lower().split()
238
+
239
+ # 获取字典中的特殊符号token
240
+ star_token = self.DICTIONARY['*']
241
+
242
+ # 将每个字符转换为对应的token
243
+ tokenized_transcript = []
244
+ for word in words:
245
+ tokenized_transcript.extend([
246
+ self.DICTIONARY[c] if c in self.DICTIONARY and c != '-' else star_token
247
+ for c in word
248
+ ])
249
+
250
+ return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)
251
+
252
+ def get_alignment_result(self, emission_padded: torch.Tensor, emission_length: int,
253
+ aligned_tokens: torch.Tensor, alignment_scores: torch.Tensor,
254
+ transcript: str, waveform: torch.Tensor) -> List[Dict]:
255
+ """
256
+ 根据给定的emission和对齐信息生成对齐结果
257
+
258
+ Args:
259
+ emission_padded: 填充后的emission
260
+ emission_length: emission的有效长度
261
+ aligned_tokens: 对齐的tokens
262
+ alignment_scores: 对齐的分数
263
+ transcript: 转录文本
264
+ waveform: 音频波形
265
+
266
+ Returns:
267
+ 对齐结果
268
+ """
269
+ # 处理文本
270
+ processed_transcript = re.sub(r'[^\w\s]', r' ', transcript)
271
+ words = processed_transcript.lower().split()
272
+
273
+ emission = emission_padded[:emission_length, :].unsqueeze(0)
274
+ token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
275
+ word_spans = self.unflatten(token_spans, [len(word) for word in words])
276
+ num_frames = emission.size(1)
277
+
278
+ return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames,
279
+ words, self.bundle.sample_rate)
280
+
281
+ def align_audio_text(self, waveform: torch.Tensor, transcript: str) -> List[Dict]:
282
+ """
283
+ 对单个音频和文本进行对齐
284
+
285
+ Args:
286
+ waveform: 音频波形张量 (1D tensor)
287
+ transcript: 英文转录文本
288
+
289
+ Returns:
290
+ 对齐结果列表,包含每个单词的时间信息
291
+ """
292
+ # 确保音频在正确的设备上
293
+ waveform = waveform.to(self.device)
294
+
295
+ # 如果需要重采样
296
+ if hasattr(self, 'original_sample_rate'):
297
+ if self.original_sample_rate != self.bundle.sample_rate:
298
+ waveform = F.resample(waveform, self.original_sample_rate, self.bundle.sample_rate)
299
+
300
+ # 批量处理(单个样本)
301
+ return self.batch_alignment([waveform], [transcript])[0]
302
+
303
+ def batch_alignment(self, wav_list: List[torch.Tensor], transcript_list: List[str]) -> List[List[Dict]]:
304
+ """
305
+ 批量对齐
306
+
307
+ Args:
308
+ wav_list: wav文件列表
309
+ transcript_list: 转录文本列表
310
+
311
+ Returns:
312
+ 对齐结果列表
313
+ """
314
+ wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
315
+
316
+ # 前向传播
317
+ with torch.inference_mode():
318
+ emission, emission_lengths = self.align_model(
319
+ wavs_tensors.to(self.device),
320
+ wavs_lengths_tensor
321
+ )
322
+ # 添加star维度
323
+ star_dim = torch.zeros(
324
+ (emission.shape[0], emission.size(1), 1),
325
+ dtype=emission.dtype,
326
+ device=self.device
327
+ )
328
+ emission = torch.cat((emission, star_dim), dim=-1)
329
+
330
+ # 获取目标tokens
331
+ target_list = [self.get_target(transcript) for transcript in transcript_list]
332
+
333
+ # 执行对齐
334
+ align_results = [
335
+ self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
336
+ for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
337
+ ]
338
+
339
+ batch_aligned_tokens = [align_result[0] for align_result in align_results]
340
+ batch_alignment_scores = [align_result[1] for align_result in align_results]
341
+
342
+ # 生成对齐结果
343
+ alignment_result_list = [
344
+ self.get_alignment_result(emission_padded, emission_length, aligned_tokens,
345
+ alignment_scores, transcript, waveform)
346
+ for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform
347
+ in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores,
348
+ transcript_list, wav_list)
349
+ ]
350
+
351
+ # 处理标点符号
352
+ final_results = []
353
+ for alignment_result, transcript in zip(alignment_result_list, transcript_list):
354
+ processed_result = get_aligned_result_with_punctuation(alignment_result, transcript)
355
+ final_results.append(processed_result)
356
+
357
+ return final_results
358
+
359
+ def align_english_audio_text(audio_path: str, transcript: str, device: str = "cuda",
360
+ model_dir: Optional[str] = None) -> List[Dict]:
361
+ """
362
+ 便捷函数:对英文音频和文本进行对齐
363
+
364
+ Args:
365
+ audio_path: 音频文件路径
366
+ transcript: 英文转录文本
367
+ device: 设备类型 ("cuda" 或 "cpu")
368
+ model_dir: 模型目录路径
369
+
370
+ Returns:
371
+ 对齐结果列表,包含每个单词的时间信息
372
+
373
+ Example:
374
+ >>> result = align_english_audio_text("audio.wav", "Hello world!")
375
+ >>> print(result)
376
+ [
377
+ {"transcript": "Hello", "start": 0.0, "end": 0.5, "duration": 0.5, "pause": 0.0},
378
+ {"transcript": "world!", "start": 0.6, "end": 1.2, "duration": 0.6, "pause": 0.1}
379
+ ]
380
+ """
381
+ # 加载音频
382
+ waveform, sample_rate = torchaudio.load(audio_path)
383
+
384
+ # 转换为单声道
385
+ if waveform.size(0) > 1:
386
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
387
+ waveform = waveform.squeeze(0) # 移除批次维度
388
+
389
+ # 初始化模型
390
+ model = EnglishAlignmentModel(device=device, model_dir=model_dir)
391
+ model.original_sample_rate = sample_rate
392
+
393
+ # 执行对齐
394
+ return model.align_audio_text(waveform, transcript)
395
+
396
+ if __name__ == "__main__":
397
+ # 使用示例
398
+ audio_file = "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step40000/test_en/gpu4/output_0.wav"
399
+ text = "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone. [S2]I see. Some voice assistants do still have that mechanical feel, especially during multi-turn conversations. So how amazing is this Asteroid exactly? [S1]I heard they internally call Asteroid China's own version of NotebookLM. [S2]NotebookLM? Oh, I know that one. Isn't that the personal AI that Google made? The one that helps organize notes and answers all kinds of questions? So Asteroid has similar functions? [S1]Right. That's probably what they mean. It's not just that the voice sounds incredibly human. The intelligence level is also really high. It can have these really logical, contextual, in-depth conversations with you. It's just like chatting with a real person. [S2]Wow, that sounds amazing. If they can really achieve that... [S1]Yeah, it's basically like having a personal assistant that's both articulate and really understands you. [S2]Hmm. That does sound appealing. [S1]And some people are saying it's like the, what's it called again in the voice technology circle? Oh right, DeepSeek. [S2]DeepSeek? Isn't that the company making large language models? Their models are pretty popular now. That's high praise. So they're saying Asteroid is top-tier technology? [S1]Yeah, I think that's what they mean. It's like they've reached a whole new level in voice synthesis. Similar to the impact DeepSeek has had in natural language processing. It could be that kind of groundbreaking technology. [S2]If Asteroid is really that impressive, where could it be used? I feel like there must be huge potential there. [S1]Absolutely. Just imagine future smart customer service, audiobook reading, and those virtual livestreamers that are so popular now. The quality would improve dramatically. We might even have personal assistants using Asteroid to talk to us directly. How natural would that be? [S2]Yeah. That does sound exciting. When can we actually try it out? Are there any demos available? [S1]I haven't looked into that carefully yet. But since they've already announced it, I'm guessing it won't be long. I'm really eager to try it and see just how human-like it is. [S2]Yeah, yeah. If it can really deliver what they're promising, getting information and interacting with machines will be so much more convenient. The experience will be much better too. [S1]Exactly, exactly. We're just waiting for MoSi AI to give us this big surprise."
400
+
401
+ # 对文本进行归一化,删除所有[S1][S2]标记
402
+ import re
403
+ normalized_text = re.sub(r'\[S[12]\]', '', text).strip()
404
+
405
+ # 设置本地模型目录
406
+ alignment_model_dir = '/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/models/mms_fa'
407
+
408
+ try:
409
+ alignment_result = align_english_audio_text(audio_file, normalized_text, model_dir=alignment_model_dir)
410
+
411
+ print("对齐结果:")
412
+ for item in alignment_result:
413
+ print(f"单词: '{item['transcript']}', 开始: {item['start']}s, 结束: {item['end']}s, 持续: {item['duration']}s")
414
+
415
+ except Exception as e:
416
+ print(f"对齐失败: {e}")
test_online.py ADDED
@@ -0,0 +1,1550 @@
1
+ import asyncio
2
+ import json
3
+ import re
4
+ import os
5
+ from typing import List, Dict, Tuple, Any
6
+ import numpy as np
7
+ from pathlib import Path
8
+ import torch
9
+ import torchaudio
10
+ import torchaudio.functional as F
11
+ import tritonclient.grpc as grpcclient
12
+ from tritonclient.utils import *
13
+ import logging
14
+ import wespeaker
15
+ import shutil
16
+ from datetime import datetime
17
+ import multiprocessing as mp
18
+ from functools import partial
19
+ import math
20
+ import threading
21
+ import time
22
+ from concurrent.futures import ThreadPoolExecutor, as_completed
23
+ import random # 添加random模块用于shuffle
24
+
25
+ # 设置multiprocessing启动方式为spawn(CUDA兼容)
26
+ mp.set_start_method('spawn', force=True)
27
+
28
+ # 引用词对齐模块
29
+ from alignment import AlignmentModel, batch_get_alignment_result
30
+ # from tensorrt_client import TritonSimilarityClient
31
+ from speaker_client import TritonSpeakerClient
32
+
33
+
34
+ class SpeakerSimilarityEvaluator:
35
+ """音色相似度评估器"""
36
+
37
+ def __init__(self, device="cuda",
38
+ alignment_model_dir='./models/mms_fa',
39
+ wespeaker_model_url='localhost:8001',
40
+ output_dir="./evaluation_results",
41
+ language="ZH",
42
+ similarity_max_workers=8):
43
+ """初始化评估器"""
44
+ self.device = device
45
+ self.alignment_model_dir = alignment_model_dir
46
+ self.wespeaker_model_url = wespeaker_model_url
47
+ self.language = language.upper() # 添加语言参数
48
+ self.similarity_max_workers = similarity_max_workers # 相似度计算线程数(当前未使用,仅为兼容旧接口保留)
49
+
50
+ # 先设置日志系统
51
+ logging.basicConfig(level=logging.INFO)
52
+ self.logger = logging.getLogger(__name__)
53
+
54
+ # 设置输出目录结构
55
+ self.output_dir = Path(output_dir)
56
+ self.segments_dir = self.output_dir / "segments" # 分割后的音频片段
57
+ self.prompts_dir = self.output_dir / "prompts" # prompt音频的S1和S2片段
58
+ self.temp_dir = self.output_dir / "temp" # 临时文件
59
+ self.results_dir = self.output_dir / "results" # 评估结果
60
+ self.temp_results_dir = self.output_dir / "temp_results" # 临时结果文件
61
+ self.alignment_dir = self.output_dir / "alignments" # 对齐信息保存目录
62
+
63
+ # 创建所有必要的目录
64
+ self._create_output_directories()
65
+
66
+ # 在多进程环境中延迟模型初始化
67
+ self.alignment_model = None
68
+ self.similarity_model = None
69
+
70
+ # 线程局部存储,用于线程安全的模型访问
71
+ self._thread_local = threading.local()
72
+
73
+ # 记录运行信息
74
+ self.logger.info(f"评估结果将保存到: {self.output_dir}")
75
+ self.logger.info(f"对齐信息将保存到: {self.alignment_dir}")
76
+ self.logger.info(f"使用语言: {self.language}")
77
+
78
+ def _create_output_directories(self):
79
+ """创建输出目录结构"""
80
+ for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir,
81
+ self.results_dir, self.temp_results_dir, self.alignment_dir]:
82
+ dir_path.mkdir(parents=True, exist_ok=True)
83
+
84
+ def _get_safe_filename(self, text: str, max_length: int = 50) -> str:
85
+ """生成安全的文件名"""
86
+ # 移除特殊字符,只保留中文、英文、数字和基本符号
87
+ safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text)
88
+ # 限制长度
89
+ if len(safe_text) > max_length:
90
+ safe_text = safe_text[:max_length]
91
+ # 替换空格为下划线
92
+ safe_text = safe_text.replace(' ', '_')
93
+ return safe_text if safe_text else "unnamed"
94
+
95
+ def _clean_temp_files(self):
96
+ """清理临时文件,但保留临时目录"""
97
+ if self.temp_dir.exists():
98
+ # 只删除临时目录中的文件,不删除目录本身
99
+ for file_path in self.temp_dir.iterdir():
100
+ if file_path.is_file():
101
+ try:
102
+ file_path.unlink()
103
+ except Exception as e:
104
+ self.logger.warning(f"删除临时文件失败: {file_path}, 错误: {e}")
105
+ else:
106
+ # 如果临时目录不存在,重新创建
107
+ self.temp_dir.mkdir(parents=True, exist_ok=True)
108
+
109
+ def _init_models_if_needed(self):
110
+ """延迟初始化模型(用于多进程环境)"""
111
+ # 初始化对齐模型 - 修正参数顺序
112
+ if self.alignment_model is None:
113
+ # 根据AlignmentModel的构造函数,应该是(device, model_dir)而不是(model_dir, device)
114
+ self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)
115
+
116
+ # 初始化相似度模型
117
+ if self.similarity_model is None:
118
+ self._load_wespeaker_model(self.wespeaker_model_url)
119
+
120
+ def _is_english_text(self, text: str) -> bool:
121
+ """简单判断文本是否主要是英文"""
122
+ # 计算英文字符的比例
123
+ english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
124
+ total_chars = sum(1 for c in text if c.isalpha())
125
+
126
+ if total_chars == 0:
127
+ return False
128
+
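+ # 例如:"Hello 世界" 中字母共7个(5个英文字母+2个汉字),英文占比5/7≈0.71,低于0.8,判定为非英文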
129
+ return english_chars / total_chars > 0.8 # 如果80%以上是英文字符,认为是英文
130
+
131
+ def _detect_language_from_text(self, text: str) -> str:
132
+ """从文本内容检测语言"""
133
+ clean_text = self.remove_speaker_tags(text)
134
+ if self._is_english_text(clean_text):
135
+ return "EN"
136
+ else:
137
+ return "ZH"
138
+
139
+ def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str, file_type: str = "output"):
140
+ """
141
+ 保存对齐信息到单独的JSON文件
142
+
143
+ Args:
144
+ alignment_data: 对齐信息数据
145
+ input_id: 输入ID
146
+ file_type: 文件类型 ("output", "prompt", "segment")
147
+ """
148
+ try:
149
+ safe_input_id = self._get_safe_filename(input_id)
150
+ alignment_filename = f"{safe_input_id}_{file_type}_alignment.json"
151
+ alignment_path = self.alignment_dir / alignment_filename
152
+
153
+ # 添加元数据
154
+ alignment_info = {
155
+ 'input_id': input_id,
156
+ 'file_type': file_type,
157
+ 'language': self.language,
158
+ 'timestamp': datetime.now().isoformat(),
159
+ 'alignment_data': alignment_data
160
+ }
161
+
162
+ with open(alignment_path, 'w', encoding='utf-8') as f:
163
+ json.dump(alignment_info, f, ensure_ascii=False, indent=2)
164
+
165
+ self.logger.info(f"对齐信息已保存: {alignment_path}")
166
+ return str(alignment_path)
167
+
168
+ except Exception as e:
169
+ self.logger.error(f"保存对齐信息失败: {e}")
170
+ return None
171
+
172
+ def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
173
+ text_segments: List[Dict[str, Any]],
174
+ input_id: str, audio_path: str,
175
+ original_text: str, processed_text: str):
176
+ """
177
+ 保存详细的对齐信息,包括分段信息
178
+
179
+ Args:
180
+ alignments: 对齐结果列表
181
+ text_segments: 文本分段信息
182
+ input_id: 输入ID
183
+ audio_path: 音频文件路径
184
+ original_text: 原始文本
185
+ processed_text: 处理后的文本
186
+ """
187
+ alignment_data = {
188
+ 'original_text': original_text,
189
+ 'processed_text': processed_text,
190
+ 'audio_path': audio_path,
191
+ 'language': self.language,
192
+ 'total_alignments': len(alignments),
193
+ 'total_segments': len(text_segments),
194
+ 'alignments': alignments,
195
+ 'text_segments': text_segments,
196
+ 'segment_alignment_mapping': []
197
+ }
198
+
199
+ # 建立文本段和对齐结果的映射关系
200
+ for segment in text_segments:
201
+ segment_mapping = {
202
+ 'segment_id': segment.get('segment_id', 0),
203
+ 'segment_text': segment.get('text', ''),
204
+ 'speaker_label': segment.get('speaker_label', ''),
205
+ 'start_time': segment.get('start_time', 0.0),
206
+ 'end_time': segment.get('end_time', 0.0),
207
+ 'corresponding_alignments': []
208
+ }
209
+
210
+ # 找到对应的对齐项
211
+ segment_start = segment.get('start_time', 0.0)
212
+ segment_end = segment.get('end_time', 0.0)
213
+
214
+ for i, align_item in enumerate(alignments):
215
+ align_start = align_item.get('start', 0.0)
216
+ align_end = align_item.get('end', 0.0)
217
+
218
+ # 检查对齐项是否在当前段的时间范围内
219
+ if (align_start >= segment_start and align_end <= segment_end) or \
220
+ (align_start < segment_end and align_end > segment_start):
221
+ segment_mapping['corresponding_alignments'].append({
222
+ 'alignment_index': i,
223
+ 'transcript': align_item.get('transcript', ''),
224
+ 'start': align_start,
225
+ 'end': align_end,
226
+ 'score': align_item.get('score', 0.0) if 'score' in align_item else None
227
+ })
228
+
229
+ alignment_data['segment_alignment_mapping'].append(segment_mapping)
230
+
231
+ return self.save_alignment_info(alignment_data, input_id, "detailed")
232
+
233
+ def remove_speaker_tags(self, text: str) -> str:
234
+ """删除文本中的说话人标签[S1][S2]"""
235
+ return re.sub(r'\[S[12]\]', '', text).strip()
236
+
237
+ def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]:
238
+ """提取文本中的说话人片段信息"""
239
+ segments = []
240
+ pattern = r'\[S([12])\]([^[]*)'
241
+ matches = re.findall(pattern, text)
242
+
243
+ for speaker_id, content in matches:
244
+ segments.append({
245
+ 'speaker': f'S{speaker_id}',
246
+ 'content': content.strip()
247
+ })
248
+ return segments
249
+
250
+ def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
251
+ """将所有标点符号替换为逗号,连续逗号只保留一个,根据语言选择正确的逗号类型"""
252
+ # 如果未指定语言,使用类的默认语言设置或自动检测
253
+ if language is None:
254
+ if hasattr(self, 'language'):
255
+ language = self.language
256
+ else:
257
+ language = self._detect_language_from_text(text)
258
+
259
+ language = language.upper()
260
+
261
+ # 根据语言选择逗号类型和处理策略
262
+ if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
263
+ # 英文处理:先删除撇号,再替换其他标点符号
264
+ text = re.sub(r"'", '', text) # 删除撇号(don't -> dont)
265
+ target_comma = ',' # 英文逗号
266
+ comma_pattern = r',+' # 匹配连续英文逗号
267
+ # 更新正则表达式,不包含撇号
268
+ text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
269
+ else:
270
+ # 中文处理:包含撇号在替换范围内
271
+ target_comma = ',' # 中文逗号
272
+ comma_pattern = r',+' # 匹配连续中文逗号
273
+ # 更新正则表达式以匹配更多的标点符号
274
+ text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)
275
+
276
+ text = re.sub(comma_pattern, target_comma, text)
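+ # 例如(中文):"你好!!世界。。" 先替换为 "你好,,世界,,",合并连续逗号得 "你好,世界,",strip后返回 "你好,世界"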
277
+ return text.strip(target_comma)
278
+
279
+ def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
280
+ """
281
+ 文本和音频的词对齐
282
+ 返回每个词对应的音频时间段
283
+ """
284
+ # 确保模型已初始化
285
+ self._init_models_if_needed()
286
+
287
+ # 如果未指定语言,使用类的默认语言设置或自动检测
288
+ if language is None:
289
+ if hasattr(self, 'language'):
290
+ language = self.language
291
+ else:
292
+ language = self._detect_language_from_text(text)
293
+ else:
294
+ language = language.upper()
295
+
296
+ # 加载音频
297
+ waveform, sample_rate = torchaudio.load(audio_path)
298
+
299
+ # 重采样到模型要求的采样率
300
+ if sample_rate != self.alignment_model.bundle.sample_rate:
301
+ waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)
302
+
303
+ # 转换为单声道
304
+ if waveform.shape[0] > 1:
305
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
306
+
307
+ waveform = waveform.squeeze(0) # 移除批次维度
308
+
309
+ # 将音频移动到正确的设备
310
+ waveform = waveform.to(self.device)
311
+
312
+ # 执行对齐
313
+ try:
314
+ alignment_results = batch_get_alignment_result(
315
+ self.alignment_model,
316
+ [waveform],
317
+ [text],
318
+ [language]
319
+ )
320
+ if not alignment_results or not alignment_results[0]:
321
+ raise RuntimeError(f"对齐结果为空: {audio_path}")
322
+ return alignment_results[0]
323
+ except Exception as e:
324
+ self.logger.error(f"音频对齐失败: {audio_path}")
325
+ self.logger.error(f"错误详情: {e}")
326
+ raise RuntimeError(f"音频对齐失败,程序终止。文件: {audio_path},错误: {e}")
327
+
328
+ def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str):
329
+ """分割音频片段"""
330
+ waveform, sample_rate = torchaudio.load(audio_path)
331
+
332
+ start_frame = int(start_time * sample_rate)
333
+ end_frame = int(end_time * sample_rate)
334
+
335
+ segment = waveform[:, start_frame:end_frame]
336
+
337
+ # 确保输出目录存在
338
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
339
+
340
+ torchaudio.save(output_path, segment, sample_rate)
341
+ return output_path
342
+
343
+ def concatenate_audio_files(self, audio_files: List[str], output_path: str):
344
+ """拼接多个音频文件"""
345
+ if not audio_files:
346
+ return
347
+
348
+ waveforms = []
349
+ sample_rate = None
350
+
351
+ for audio_file in audio_files:
352
+ if os.path.exists(audio_file):
353
+ waveform, sr = torchaudio.load(audio_file)
354
+ if sample_rate is None:
355
+ sample_rate = sr
356
+ elif sr != sample_rate:
357
+ waveform = F.resample(waveform, sr, sample_rate)
358
+ waveforms.append(waveform)
359
+
360
+ if waveforms:
361
+ concatenated = torch.cat(waveforms, dim=1)
362
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
363
+ torchaudio.save(output_path, concatenated, sample_rate)
364
+
365
+ def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
366
+ """
367
+ 根据说话人标签分割prompt音频
368
+ 返回S1和S2的音频片段路径
369
+ """
370
+ # 1. 提取说话人片段
371
+ speaker_segments = self.extract_speaker_segments(prompt_text)
372
+
373
+ # 2. 删除标签后进行词对齐 - 如果失败则直接抛出异常
374
+ clean_text = self.remove_speaker_tags(prompt_text)
375
+
376
+ # 检测语言或使用设置的语言
377
+ alignment_language = self.language
378
+ if alignment_language == "AUTO":
379
+ alignment_language = self._detect_language_from_text(clean_text)
380
+
381
+ alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)
382
+
383
+ # 保存prompt对齐信息
384
+ prompt_alignment_data = {
385
+ 'original_text': prompt_text,
386
+ 'clean_text': clean_text,
387
+ 'audio_path': prompt_audio,
388
+ 'language': alignment_language,
389
+ 'speaker_segments': speaker_segments,
390
+ 'alignments': alignments
391
+ }
392
+ self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")
393
+
394
+ # 3. 根据对齐结果分割音频
395
+ s1_segments = []
396
+ s2_segments = []
397
+
398
+ # 为每个说话人片段找到对应的时间段
399
+ text_pos = 0
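+ # 思路:沿去除标签后的文本按字符数推进text_pos,凡起始字符位置落在当前片段区间内的对齐词,其时间戳即归入该说话人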
400
+ for seg in speaker_segments:
401
+ seg_text = seg['content'].strip()
402
+ seg_length = len(seg_text)
403
+
404
+ # 找到这个片段在对齐结果中的起始和结束
405
+ start_time = None
406
+ end_time = None
407
+
408
+ current_pos = 0
409
+ for align_item in alignments:
410
+ item_text = align_item['transcript']
411
+ item_length = len(item_text)
412
+
413
+ if current_pos >= text_pos and current_pos < text_pos + seg_length:
414
+ if start_time is None:
415
+ start_time = align_item['start']
416
+ end_time = align_item['end']
417
+
418
+ current_pos += item_length
419
+
420
+ if start_time is not None and end_time is not None:
421
+ if seg['speaker'] == 'S1':
422
+ s1_segments.append((start_time, end_time))
423
+ else:
424
+ s2_segments.append((start_time, end_time))
425
+
426
+ text_pos += seg_length
427
+
428
+ # 4. 分割并拼接音频片段
429
+ safe_audio_id = self._get_safe_filename(audio_id)
430
+ prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
431
+ prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")
432
+
433
+ # 分割S1的所有片段
434
+ if s1_segments:
435
+ s1_temp_segments = []
436
+ for i, (start, end) in enumerate(s1_segments):
437
+ temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
438
+ self.split_audio_segment(prompt_audio, start, end, temp_path)
439
+ s1_temp_segments.append(temp_path)
440
+
441
+ # 拼接S1片段
442
+ self.concatenate_audio_files(s1_temp_segments, prompts1_path)
443
+
444
+ # 分割S2的所有片段
445
+ if s2_segments:
446
+ s2_temp_segments = []
447
+ for i, (start, end) in enumerate(s2_segments):
448
+ temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
449
+ self.split_audio_segment(prompt_audio, start, end, temp_path)
450
+ s2_temp_segments.append(temp_path)
451
+
452
+ # 拼接S2片段
453
+ self.concatenate_audio_files(s2_temp_segments, prompts2_path)
454
+
455
+ return prompts1_path, prompts2_path
456
+
457
+ def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]:
458
+ """
459
+ 将原始文本按说话人和标点符号同时分割,保持映射关系
460
+ 支持英文单词级别的处理
461
+ """
462
+ segments = []
463
+ pattern = r'\[S([12])\]([^[]*)'
464
+ matches = re.findall(pattern, original_text)
465
+
466
+ # 检测语言或使用设置的语言
467
+ alignment_language = self.language
468
+ if alignment_language == "AUTO":
469
+ alignment_language = self._detect_language_from_text(original_text)
470
+
471
+ segment_id = 0
472
+ for speaker_id, content in matches:
473
+ speaker = f'S{speaker_id}'
474
+ clean_content = content.strip()
475
+ comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language)
476
+
477
+ # 根据语言选择正确的逗号分割
478
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)):
479
+ # 英文:按英文逗号分割,保持单词完整性
480
+ parts = [part.strip() for part in comma_content.split(',') if part.strip()]
481
+ else:
482
+ # 中文:按中文逗号分割
483
+ parts = [part.strip() for part in comma_content.split(',') if part.strip()]
484
+
485
+ for part in parts:
486
+ if part.strip():
487
+ segments.append({
488
+ 'segment_id': segment_id,
489
+ 'text': part.strip(),
490
+ 'speaker_label': speaker,
491
+ 'original_speaker_content': clean_content
492
+ })
493
+ segment_id += 1
494
+
495
+ return segments
496
+
497
+ def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
498
+ """
499
+ 根据逗号分割输出音频,返回每小段的信息 - 基于词对齐结果中的标点符号划分句子
500
+ """
501
+ # 1. 获取文本片段和对应的说话人(用于获取speaker标签)
502
+ text_segments = self.map_text_segments_to_speakers(text)
503
+
504
+ # 2. 删除标签并替换标点符号
505
+ clean_text = self.remove_speaker_tags(text)
506
+
507
+ # 3. 检测语言或使用设置的语言
508
+ alignment_language = self.language
509
+ if alignment_language == "AUTO":
510
+ alignment_language = self._detect_language_from_text(clean_text)
511
+
512
+ # 使用检测到的语言替换标点符号
513
+ comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)
514
+
515
+ # 4. 词对齐 - 如果失败则直接抛出异常
516
+ alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)
517
+
518
+ # 5. 根据标点符号划分句子
519
+ segments = []
520
+ safe_audio_id = self._get_safe_filename(audio_id)
521
+
522
+ # 确定标点符号(根据语言选择,英文不包含撇号)
523
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
524
+ punctuation_chars = set([',', '.', '!', '?', ';', ':']) # 不包含撇号
525
+ else:
526
+ punctuation_chars = set([',', '。', '!', '?', ';', ':'])
527
+
528
+ # 顺序扫描对齐结果,根据标点符号划分句子
529
+ sentence_start_idx = 0
530
+ sentence_alignments = []
531
+ segment_id = 0
532
+
533
+ for i, align_item in enumerate(alignments):
534
+ transcript = align_item['transcript']
535
+ sentence_alignments.append(align_item)
536
+
537
+ # 检查是否包含标点符号(句子结束标志)
538
+ has_punctuation = any(punct in transcript for punct in punctuation_chars)
539
+
540
+ if has_punctuation or i == len(alignments) - 1: # 遇到标点符号或最后一个词
541
+ # 创建句子片段
542
+ if sentence_alignments:
543
+ # 获取句子的开始和结束时间
544
+ start_time = sentence_alignments[0]['start']
545
+ end_time = sentence_alignments[-1]['end']
546
+
547
+ # 构建句子文本(去除标点符号)
548
+ sentence_text_parts = []
549
+ for align in sentence_alignments:
550
+ # 根据语言选择不同的清理策略
551
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
552
+ # 英文:去除标点符号,但保留撇号已被删除的单词
553
+ clean_transcript = align['transcript'].rstrip(',.!?;:')
554
+ else:
555
+ # 中文:去除中文标点符号
556
+ clean_transcript = align['transcript'].rstrip(',。!?;:')
557
+
558
+ if clean_transcript.strip():
559
+ sentence_text_parts.append(clean_transcript)
560
+
561
+ # 根据语言选择连接方式
562
+ if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
563
+ sentence_text = ' '.join(sentence_text_parts).strip() # 英文用空格连接
564
+ else:
565
+ sentence_text = ''.join(sentence_text_parts).strip() # 中文直接连接
566
+
567
+ if sentence_text: # 只有非空句子才处理
568
+ # 确定说话人标签(从原始text_segments中获取,如果可能的话)
569
+ speaker_label = "S1" # 默认
570
+ if segment_id < len(text_segments):
571
+ speaker_label = text_segments[segment_id]['speaker_label']
572
+ elif text_segments:
573
+ # 如果超出范围,使用最后一个片段的speaker
574
+ speaker_label = text_segments[-1]['speaker_label']
575
+
576
+ # 生成音频文件路径
577
+ safe_text = self._get_safe_filename(sentence_text, 30)
578
+ audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")
579
+
580
+ # 分割音频
581
+ try:
582
+ self.split_audio_segment(output_audio, start_time, end_time, audio_path)
583
+ except Exception as e:
584
+ self.logger.error(f"分割音频失败: {e}")
585
+ # 使用默认时间间隔
586
+ start_time = segment_id * 1.0
587
+ end_time = (segment_id + 1) * 1.0
588
+ self.split_audio_segment(output_audio, start_time, end_time, audio_path)
589
+
590
+ # 创建segment
591
+ segment = {
592
+ 'segment_id': segment_id,
593
+ 'text': sentence_text,
594
+ 'speaker_label': speaker_label,
595
+ 'original_speaker_content': sentence_text, # 这里简化处理
596
+ 'audio_path': audio_path,
597
+ 'start_time': start_time,
598
+ 'end_time': end_time
599
+ }
600
+
601
+ segments.append(segment)
602
+
603
+ self.logger.info(f"句子 {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
604
+ segment_id += 1
605
+
606
+ # 重置为下一个句子
607
+ sentence_alignments = []
608
+ sentence_start_idx = i + 1
609
+
610
+ # 保存详细的对齐信息
611
+ self.save_detailed_alignment_info(
612
+ alignments, segments, audio_id, output_audio, text, comma_text
613
+ )
614
+
615
+ self.logger.info(f"总共分割出 {len(segments)} 个句子片段")
616
+ return segments
617
+
618
+ def _get_similarity_model_server(self):
619
+ """获取线程局部的相似度模型实例(线程安全)"""
620
+ if not hasattr(self, 'similarity_model'):
621
+ # 为当前线程创建独立的模型实例
622
+ self.similarity_model = self._create_similarity_model()
623
+ return self.similarity_model
624
+
625
+ def _create_similarity_model(self):
626
+ """创建新的相似度模型实例"""
627
+ try:
628
+ return TritonSpeakerClient(self.wespeaker_model_url)
629
+ except Exception as e:
630
+ self.logger.error(f"创建相似度模型失败: {e}")
631
+ raise
632
+
633
+ async def compute_similarity(self, processed_audio1, processed_audio2):
634
+ return await self.similarity_model.compute_similarity(processed_audio1, processed_audio2)
635
+
636
+ async def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
637
+ """
638
+ 线程安全的音色相似度计算
639
+ 计算失败或音频过短时返回None,由调用方跳过该片段
640
+ """
641
+ try:
642
+ if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
643
+ self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
644
+ return None
645
+
646
+ # 获取线程局部的模型实例
647
+ _ = self._get_similarity_model_server()
648
+
649
+ # 计算相似度
650
+ similarity = await self.compute_similarity(audio1_path, audio2_path)
651
+
652
+ return float(similarity)
653
+
654
+ except Exception as e:
655
+ # 检查是否是窗口大小错误或其他计算错误
656
+ if "choose a window size" in str(e) or "window size" in str(e):
657
+ self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
658
+ return None
659
+ else:
660
+ self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
661
+ return None
662
+
663
+ async def calculate_segment_similarities_parallel(
664
+ self, output_segments: List[Dict[str, Any]], prompts1_path: str, prompts2_path: str
665
+ ) -> List[Dict[str, Any]]:
666
+ """
667
+ 并行计算所有segments的相似度
668
+ Args:
669
+ output_segments: 音频segments列表
670
+ prompts1_path: S1 prompt音频路径
671
+ prompts2_path: S2 prompt音频路径
672
+ Returns:
673
+ 包含相似度信息的segment列表
674
+ """
675
+
676
+ async def calculate_single_segment_similarity(segment):
677
+ """计算单个segment与两个prompts的相似度"""
678
+ try:
679
+ # 使用线程安全的相似度计算方法
680
+ sim1 = await self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path)
681
+ sim2 = await self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path)
682
+
683
+ return {
684
+ 'segment': segment,
685
+ 'sim1': sim1,
686
+ 'sim2': sim2,
687
+ 'success': True
688
+ }
689
+ except Exception as e:
690
+ self.logger.error(f"计算segment {segment['segment_id']} 相似度失败: {e}")
691
+ return {
692
+ 'segment': segment,
693
+ 'sim1': None,
694
+ 'sim2': None,
695
+ 'success': False
696
+ }
697
+
698
+ # 使用线程池并行处理所有segments
699
+ self.logger.info(f"开始异步计算 {len(output_segments)} 个segments的相似度")
700
+
701
+ # 创建任务并保留原始segment的顺序(gather会保持顺序)
702
+ tasks = [
703
+ asyncio.create_task(calculate_single_segment_similarity(segment))
704
+ for segment in output_segments
705
+ ]
706
+
707
+ # 使用asyncio.as_completed在任务完成时实时报告进度
708
+ return await self._run_tasks_with_progress(tasks)
709
+
710
+ # 新增辅助方法:带进度报告的任务执行
711
+ async def _run_tasks_with_progress(self, tasks):
712
+ """执行任务集合并实时报告进度"""
713
+ completed_count = 0
714
+ total = len(tasks)
715
+ results = []
716
+
717
+ # 按完成顺序处理结果
718
+ for future in asyncio.as_completed(tasks):
719
+ result = await future
720
+ completed_count += 1
721
+
722
+ # 每完成10个segment报告一次进度
723
+ if completed_count % 10 == 0 or completed_count == total:
724
+ seg_id = result['segment']['segment_id']
725
+ self.logger.info(f"相似度计算进度: {completed_count}/{total} (最近完成: {seg_id})")
726
+
727
+ results.append(result)
728
+
729
+ # as_completed按完成顺序返回结果,这里按segment_id重新排序,恢复与输入segments一致的顺序
730
+ return sorted(results, key=lambda r: r['segment']['segment_id'])
731
+
732
+ async def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
733
+ """评估单个输入的音色相似度"""
734
+
735
+ # 生成输入ID
736
+ if input_id is None:
737
+ input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
738
+
739
+ self.logger.info(f"开始评估输入: {input_id},使用语言: {self.language}")
740
+
741
+ # 1. 获取或分割prompt音频
742
+ prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")
743
+
744
+ # 2. 分割output音频(这里会保存详细对齐信息)
745
+ output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")
746
+
747
+ # 3. 并行计算每小段的相似度
748
+ similarity_results = await self.calculate_segment_similarities_parallel(
749
+ output_segments, prompts1_path, prompts2_path
750
+ )
751
+
752
+ # 4. 处理相似度结果
753
+ segment_results = []
754
+ correct_predictions = 0
755
+ total_segments = 0 # 只计算有效段数
756
+ label_similarities = [] # 每小段与其标签的相似度
757
+ skipped_segments = 0 # 跳过的段数
758
+
759
+ for sim_result in similarity_results:
760
+ segment = sim_result['segment']
761
+ sim1 = sim_result['sim1']
762
+ sim2 = sim_result['sim2']
763
+
764
+ # 如果任一相似度为None(音频过短或计算失败),跳过该段
765
+ if sim1 is None or sim2 is None:
766
+ skipped_segments += 1
767
+ self.logger.info(f"跳过段 {segment['segment_id']}: 相似度计算失败")
768
+ continue
769
+
770
+ # 只有有效段才参与计算
771
+ total_segments += 1
772
+
773
+ # 判断实际音色
774
+ predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
775
+ actual_speaker = segment['speaker_label']
776
+ is_correct = predicted_speaker == actual_speaker
777
+
778
+ if is_correct:
779
+ correct_predictions += 1
780
+
781
+ # 计算与标签的相似度
782
+ if actual_speaker == 'S1':
783
+ label_similarity = sim1
784
+ else:
785
+ label_similarity = sim2
786
+ label_similarities.append(label_similarity)
787
+
788
+ segment_result = {
789
+ 'segment_id': segment['segment_id'],
790
+ 'text': segment['text'],
791
+ 'speaker_label': actual_speaker,
792
+ 'predicted_speaker': predicted_speaker,
793
+ 'sim1': sim1,
794
+ 'sim2': sim2,
795
+ 'label_similarity': label_similarity,
796
+ 'is_correct': is_correct,
797
+ 'audio_path': segment['audio_path'],
798
+ 'start_time': segment.get('start_time', 0.0),
799
+ 'end_time': segment.get('end_time', 1.0)
800
+ }
801
+ segment_results.append(segment_result)
802
+
803
+ # 5. 计算整体指标(只基于有效段)
804
+ accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
805
+ average_similarity = np.mean(label_similarities) if label_similarities else 0.0
806
+
807
+ # 6. 保存评估结果的对齐信息摘要
808
+ evaluation_alignment_summary = {
809
+ 'input_id': input_id,
810
+ 'language': self.language,
811
+ 'prompt_alignment_files': [
812
+ f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
813
+ ],
814
+ 'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
815
+ 'total_segments': total_segments,
816
+ 'total_alignments_processed': len(output_segments),
817
+ 'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
818
+ }
819
+ self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")
820
+
821
+ result = {
822
+ 'input_id': input_id,
823
+ 'language': self.language,
824
+ 'input_data': data, # 保存原始输入数据
825
+ 'prompts1_path': prompts1_path,
826
+ 'prompts2_path': prompts2_path,
827
+ 'segments': segment_results,
828
+ 'accuracy': accuracy,
829
+ 'average_similarity': average_similarity,
830
+ 'total_segments': total_segments, # 有效段数
831
+ 'correct_predictions': correct_predictions,
832
+ 'skipped_segments': skipped_segments, # 跳过的段数
833
+ 'original_total_segments': len(output_segments), # 原始总段数
834
+ 'alignment_files': {
835
+ 'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
836
+ 'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
837
+ 'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
838
+ },
839
+ 'timestamp': datetime.now().isoformat()
840
+ }
841
+
842
+ self.logger.info(f"完成评估输入: {input_id}, 语言: {self.language}, 有效段: {total_segments}/{len(output_segments)}, 跳过: {skipped_segments}, 准确率: {accuracy:.3f}, 平均相似度: {average_similarity:.3f}")
843
+
844
+ return result
845
+
846
+ def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None):
847
+ """保存结果到JSONL文件"""
848
+ if filename is None:
849
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
850
+ filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
851
+
852
+ output_path = self.results_dir / filename
853
+
854
+ with open(output_path, 'w', encoding='utf-8') as f:
855
+ for result in results:
856
+ f.write(json.dumps(result, ensure_ascii=False) + '\n')
857
+
858
+ return str(output_path)
859
+
860
+ def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None):
861
+ """保存汇总报告"""
862
+ if filename is None:
863
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
864
+ filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json"
865
+
866
+ summary_path = self.results_dir / filename
867
+
868
+ # 计算总体统计
869
+ total_accuracy = np.mean([r['accuracy'] for r in results])
870
+ total_avg_similarity = np.mean([r['average_similarity'] for r in results])
871
+ total_segments = sum([r['total_segments'] for r in results])
872
+ total_correct = sum([r['correct_predictions'] for r in results])
873
+
874
+ summary = {
875
+ 'evaluation_summary': {
876
+ 'language': self.language,
877
+ 'total_inputs': len(results),
878
+ 'total_segments': total_segments,
879
+ 'total_correct_predictions': total_correct,
880
+ 'overall_accuracy': total_accuracy,
881
+ 'overall_average_similarity': total_avg_similarity,
882
+ 'evaluation_timestamp': datetime.now().isoformat(),
883
+ 'output_directory': str(self.output_dir),
884
+ 'alignment_directory': str(self.alignment_dir)
885
+ },
886
+ 'per_input_results': [
887
+ {
888
+ 'input_id': r['input_id'],
889
+ 'language': r.get('language', self.language),
890
+ 'accuracy': r['accuracy'],
891
+ 'average_similarity': r['average_similarity'],
892
+ 'total_segments': r['total_segments'],
893
+ 'correct_predictions': r['correct_predictions'],
894
+ 'output_audio_path': r['input_data']['output_audio'],
895
+ 'alignment_files': r.get('alignment_files', {})
896
+ }
897
+ for r in results
898
+ ]
899
+ }
900
+
901
+ with open(summary_path, 'w', encoding='utf-8') as f:
902
+ json.dump(summary, f, ensure_ascii=False, indent=2)
903
+
904
+ return str(summary_path)
905
+
906
+ def process_batch_from_jsonl_parallel(self, jsonl_path: str,
907
+ processes_per_gpu: int = 16,
908
+ results_filename: str = None,
909
+ shuffle_data: bool = True):
910
+ """从JSONL文件并行批量处理输入数据"""
911
+ # 加载数据
912
+ input_data = self.load_data_from_jsonl(jsonl_path)
913
+
914
+ if not input_data:
915
+ self.logger.error("没有有效的输入数据")
916
+ return []
917
+
918
+ # 对数据进行shuffle,使分配更均匀
919
+ if shuffle_data:
920
+ random.shuffle(input_data)
921
+ self.logger.info(f"已对 {len(input_data)} 条数据进行随机shuffle")
922
+
923
+ return self.process_batch_parallel(input_data, processes_per_gpu, results_filename)
924
+
925
+ def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None):
926
+ """从JSONL文件批量处理输入数据(单进程版本)"""
927
+ # 加载数据
928
+ input_data = self.load_data_from_jsonl(jsonl_path)
929
+
930
+ if not input_data:
931
+ self.logger.error("没有有效的输入数据")
932
+ return []
933
+
934
+ return asyncio.run(self.process_batch_from_data(input_data, results_filename))
935
+
936
+ async def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
937
+ """处理数据列表(单进程版本,用于兼容),支持增量写入"""
938
+ # 准备结果文件
939
+ if results_filename is None:
940
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
941
+ results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
942
+
943
+ results_path = self.results_dir / results_filename
944
+
945
+ # 如果文件已存在,删除它(重新开始)
946
+ if results_path.exists():
947
+ results_path.unlink()
948
+
949
+ results = []
950
+
951
+ self.logger.info(f"开始处理 {len(input_data)} 个输入,使用语言: {self.language}...")
952
+
953
+ for i, data in enumerate(input_data):
954
+ input_id = f"input_{i+1:03d}"
955
+ print(f"处理第{i+1}/{len(input_data)}个输入: {input_id},语言: {self.language}")
956
+
957
+ try:
958
+ result = await self.evaluate_single_input(data, input_id=input_id)
959
+ results.append(result)
960
+
961
+ # 增量写入结果
962
+ self.append_result_to_jsonl(result, str(results_path))
963
+
964
+ except Exception as e:
965
+ self.logger.error(f"处理输入{input_id}时出错: {e}")
966
+ continue
967
+
968
+ if not results:
969
+ self.logger.error("没有成功处理的输入")
970
+ return []
971
+
972
+ # 保存汇总报告
973
+ summary_path = self.save_summary_report(results)
974
+
975
+ # 清理临时文件
976
+ self._clean_temp_files()
977
+
978
+ # 打印总体统计
979
+ total_accuracy = np.mean([r['accuracy'] for r in results])
980
+ total_avg_similarity = np.mean([r['average_similarity'] for r in results])
981
+
982
+ print(f"\n=== 评估完成 ===")
983
+ print(f"使用语言: {self.language}")
984
+ print(f"总体准确率: {total_accuracy:.3f}")
985
+ print(f"总体平均相似度: {total_avg_similarity:.3f}")
986
+ print(f"详细结果已保存到: {results_path}")
987
+ print(f"汇总报告已保存到: {summary_path}")
988
+ print(f"对齐信息已保存到: {self.alignment_dir}")
989
+ print(f"所有中间文件保存在: {self.output_dir}")
990
+
991
+ return results
992
+
993
+ def _load_wespeaker_model(self, wespeaker_model_url):
994
+ """加载wespeaker模型"""
995
+ try:
996
+ self.similarity_model = TritonSpeakerClient(wespeaker_model_url)
997
+ except ImportError:
998
+ raise ImportError("请安装wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git")
999
+ except Exception as e:
1000
+ self.logger.error(f"加载wespeaker模型失败: {e}")
1001
+ raise
1002
+
1003
+ def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]:
1004
+ """从JSONL文件加载数据"""
1005
+ data = []
1006
+ try:
1007
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
1008
+ for line_num, line in enumerate(f, 1):
1009
+ line = line.strip()
1010
+ if line:
1011
+ try:
1012
+ item = json.loads(line)
1013
+ # 验证必要字段
1014
+ required_fields = ['text', 'output_audio']
1015
+ for field in required_fields:
1016
+ if field not in item:
1017
+ self.logger.error(f"第{line_num}行缺少必要字段: {field}")
1018
+ continue
1019
+
1020
+ # 验证音频路径模式:要么有prompt_audio和prompt_text,要么有分别的speaker音频文件
1021
+ has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item
1022
+ has_separate_prompts = ('prompt_audio_speaker1' in item and
1023
+ 'prompt_text_speaker1' in item and
1024
+ 'prompt_audio_speaker2' in item and
1025
+ 'prompt_text_speaker2' in item)
1026
+
1027
+ if not (has_combined_prompt or has_separate_prompts):
1028
+ self.logger.error(f"第{line_num}行:需要提供prompt_audio+prompt_text或者分别的speaker音频文件")
1029
+ continue
1030
+
1031
+ data.append(item)
1032
+
1033
+ except json.JSONDecodeError as e:
1034
+ self.logger.error(f"第{line_num}行JSON解析错误: {e}")
1035
+ continue
1036
+
1037
+ self.logger.info(f"从{jsonl_path}成功加载{len(data)}条数据")
1038
+ return data
1039
+
1040
+ except FileNotFoundError:
1041
+ self.logger.error(f"JSONL文件不存在: {jsonl_path}")
1042
+ return []
1043
+ except Exception as e:
1044
+ self.logger.error(f"读取JSONL文件失败: {e}")
1045
+ return []
1046
+
1047
+ @staticmethod
1048
+ def get_gpu_count():
1049
+ """获取可用GPU数量"""
1050
+ if torch.cuda.is_available():
1051
+ return torch.cuda.device_count()
1052
+ return 0
1053
+
1054
+ @staticmethod
1055
+ def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]:
1056
+ """根据GPU数量分割数据"""
1057
+ if num_gpus == 0:
1058
+ return [data]
1059
+
1060
+ chunk_size = math.ceil(len(data) / num_gpus)
1061
+ gpu_chunks = []
1062
+
1063
+ for i in range(num_gpus):
1064
+ start_idx = i * chunk_size
1065
+ end_idx = min((i + 1) * chunk_size, len(data))
1066
+ if start_idx < len(data):
1067
+ gpu_chunks.append(data[start_idx:end_idx])
1068
+
1069
+ return gpu_chunks
1070
+
1071
+ @staticmethod
1072
+ def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]:
1073
+ """根据进程数量分割数据"""
1074
+ if num_processes <= 1:
1075
+ return [data]
1076
+
1077
+ chunk_size = math.ceil(len(data) / num_processes)
1078
+ process_chunks = []
1079
+
1080
+ for i in range(num_processes):
1081
+ start_idx = i * chunk_size
1082
+ end_idx = min((i + 1) * chunk_size, len(data))
1083
+ if start_idx < len(data):
1084
+ process_chunks.append(data[start_idx:end_idx])
1085
+
1086
+ return process_chunks
1087
+
1088
+ def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str):
1089
+ """增量写入结果到JSONL文件"""
1090
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
1091
+ with open(filepath, 'a', encoding='utf-8') as f:
1092
+ f.write(json.dumps(result, ensure_ascii=False) + '\n')
1093
+ f.flush() # 强制刷新缓冲区
1094
+
1095
+ def merge_temp_results(self, temp_files: List[str], final_path: str):
1096
+ """合并临时结果文件"""
1097
+ all_results = []
1098
+
1099
+ for temp_file in temp_files:
1100
+ if os.path.exists(temp_file):
1101
+ try:
1102
+ with open(temp_file, 'r', encoding='utf-8') as f:
1103
+ for line in f:
1104
+ line = line.strip()
1105
+ if line:
1106
+ result = json.loads(line)
1107
+ all_results.append(result)
1108
+ except Exception as e:
1109
+ self.logger.error(f"读取临时文件失败: {temp_file}, 错误: {e}")
1110
+
1111
+ # 写入最终文件
1112
+ with open(final_path, 'w', encoding='utf-8') as f:
1113
+ for result in all_results:
1114
+ f.write(json.dumps(result, ensure_ascii=False) + '\n')
1115
+
1116
+ return all_results
1117
+
1118
+ def process_batch_parallel(self, input_data: List[Dict[str, Any]],
1119
+ processes_per_gpu: int = 8, # 降低进程数
1120
+ results_filename: str = None,
1121
+ shuffle_data: bool = True):
1122
+ """并行批量处理输入数据"""
1123
+ # 1. 检查GPU数量
1124
+ num_gpus = self.get_gpu_count()
1125
+ if num_gpus == 0:
1126
+ self.logger.warning("未检测到GPU,将使用CPU单进程处理")
1127
+ return asyncio.run(self.process_batch_from_data(input_data, results_filename))
1128
+
1129
+ # 限制每个GPU的进程数,避免CUDA内存冲突
1130
+ max_processes_per_gpu = min(processes_per_gpu, 16)
1131
+ self.logger.info(f"检测到 {num_gpus} 个GPU,每个GPU将使用 {max_processes_per_gpu} 个进程")
1132
+
1133
+ # 2. 对数据进行shuffle(如果还没有shuffle过)
1134
+ shuffled_data = input_data.copy()
1135
+ if shuffle_data:
1136
+ random.shuffle(shuffled_data)
1137
+ self.logger.info(f"已对 {len(shuffled_data)} 条数据进行随机shuffle以平衡GPU负载")
1138
+
1139
+ # 3. 按GPU分割数据
1140
+ gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)
1141
+
1142
+ # 打印每个GPU分配到的数据量
1143
+ for gpu_id, gpu_data in enumerate(gpu_chunks):
1144
+ if gpu_data:
1145
+ self.logger.info(f"GPU {gpu_id}: 分配到 {len(gpu_data)} 条数据")
1146
+
1147
+ # 4. 准备结果文件路径
1148
+ if results_filename is None:
1149
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
1150
+ results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"
1151
+
1152
+ final_results_path = self.results_dir / results_filename
1153
+
1154
+ # 5. 为所有GPU准备进程参数
1155
+ all_temp_files = []
1156
+ all_gpu_tasks = []
1157
+
1158
+ for gpu_id, gpu_data in enumerate(gpu_chunks):
1159
+ if not gpu_data:
1160
+ continue
1161
+
1162
+ self.logger.info(f"GPU {gpu_id}: 准备处理 {len(gpu_data)} 条数据")
1163
+
1164
+ # 按进程数分割当前GPU的数据
1165
+ process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)
1166
+
1167
+ # 为当前GPU准备所有进程参数
1168
+ gpu_process_args = []
1169
+ for proc_id, proc_data in enumerate(process_chunks):
1170
+ if proc_data:
1171
+ temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
1172
+ all_temp_files.append(temp_result_file)
1173
+
1174
+ # 子进程输出目录在主输出目录内部
1175
+ subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")
1176
+
1177
+ gpu_process_args.append((
1178
+ proc_data,
1179
+ gpu_id,
1180
+ proc_id,
1181
+ subprocess_output_dir,
1182
+ temp_result_file,
1183
+ self.alignment_model_dir,
1184
+ self.wespeaker_model_url,
1185
+ self.language, # 语言参数
1186
+ self.similarity_max_workers # 添加相似度计算线程数参数
1187
+ ))
1188
+
1189
+ if gpu_process_args:
1190
+ all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))
1191
+
1192
+ # 6. 使用ThreadPoolExecutor并行处理所有GPU
1193
+ def process_gpu_tasks(gpu_task):
1194
+ gpu_id, process_args, actual_processes = gpu_task
1195
+ self.logger.info(f"GPU {gpu_id}: 开始并行处理 {len(process_args)} 个进程")
1196
+
1197
+ # 为每个GPU使用独立的进程池,避免进程间冲突
1198
+ with mp.Pool(processes=actual_processes) as pool:
1199
+ # 调用同步包装器 run_async_worker,在每个子进程内部运行异步函数。
1200
+ pool.map(run_async_worker, process_args)
1201
+
1202
+ self.logger.info(f"GPU {gpu_id}: 所有进程处理完成")
1203
+ return gpu_id
1204
+
1205
+ # 使用线程池同时处理所有GPU
1206
+ with ThreadPoolExecutor(max_workers=num_gpus) as executor:
1207
+ # 提交所有GPU任务
1208
+ future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
1209
+ for gpu_task in all_gpu_tasks}
1210
+
1211
+ # 等待所有GPU完成
1212
+ completed_gpus = []
1213
+ for future in as_completed(future_to_gpu):
1214
+ gpu_id = future_to_gpu[future]
1215
+ try:
1216
+ result_gpu_id = future.result()
1217
+ completed_gpus.append(result_gpu_id)
1218
+ self.logger.info(f"GPU {result_gpu_id} 完成处理")
1219
+ except Exception as exc:
1220
+ self.logger.error(f"GPU {gpu_id} 处理时发生异常: {exc}")
1221
+
1222
+ self.logger.info(f"所有GPU处理完成: {completed_gpus}")
1223
+
1224
+ # 7. 合并所有临时结果文件
1225
+ self.logger.info("合并所有临时结果文件...")
1226
+ all_results = self.merge_temp_results(all_temp_files, str(final_results_path))
1227
+
1228
+ if not all_results:
1229
+ self.logger.error("没有成功处理的数据")
1230
+ return []
1231
+
1232
+ # 8. 生成汇总报告
1233
+ summary_path = self.save_summary_report(all_results)
1234
+
1235
+ # 9. 清理临时文件
1236
+ for temp_file in all_temp_files:
1237
+ if os.path.exists(temp_file):
1238
+ os.remove(temp_file)
1239
+
1240
+ # 10. 打印总体统计
1241
+ total_accuracy = np.mean([r['accuracy'] for r in all_results])
1242
+ total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])
1243
+
1244
+ print(f"\n=== 并行评估完成 ===")
1245
+ print(f"使用语言: {self.language}")
1246
+ print(f"使用 {num_gpus} 个GPU,每GPU {max_processes_per_gpu} 个进程")
1247
+ print(f"总处理数据: {len(input_data)} 条")
1248
+ print(f"成功处理: {len(all_results)} 条")
1249
+ print(f"总体准确率: {total_accuracy:.3f}")
1250
+ print(f"总体平均相似度: {total_avg_similarity:.3f}")
1251
+ print(f"详细结果已保存到: {final_results_path}")
1252
+ print(f"汇总报告已保存到: {summary_path}")
1253
+ print(f"对齐信息已保存到: {self.alignment_dir}")
1254
+
1255
+ return all_results
1256
+
1257
+ def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]:
1258
+ """
1259
+ 获取或分割prompt音频
1260
+ 如果提供了分别的speaker音频文件则直接使用,否则从combined prompt分割
1261
+ """
1262
+ # 检查是否有分别的speaker音频文件
1263
+ if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and
1264
+ 'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data):
1265
+
1266
+ self.logger.info(f"使用预分割的speaker音频文件")
1267
+
1268
+ # 即使使用预分割的音频,也保存对齐信息
1269
+ try:
1270
+ # 检测语言或使用设置的语言
1271
+ alignment_language = self.language
1272
+ if alignment_language == "AUTO":
1273
+ alignment_language = self._detect_language_from_text(data['prompt_text_speaker1'])
1274
+
1275
+ # 对S1音频进行对齐
1276
+ s1_alignments = self.align_text_with_audio(
1277
+ data['prompt_text_speaker1'], data['prompt_audio_speaker1'], alignment_language
1278
+ )
1279
+ s1_alignment_data = {
1280
+ 'speaker': 'S1',
1281
+ 'text': data['prompt_text_speaker1'],
1282
+ 'audio_path': data['prompt_audio_speaker1'],
1283
+ 'language': alignment_language,
1284
+ 'alignments': s1_alignments
1285
+ }
1286
+ self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1")
1287
+
1288
+ # 对S2音频进行对齐
1289
+ s2_alignments = self.align_text_with_audio(
1290
+ data['prompt_text_speaker2'], data['prompt_audio_speaker2'], alignment_language
1291
+ )
1292
+ s2_alignment_data = {
1293
+ 'speaker': 'S2',
1294
+ 'text': data['prompt_text_speaker2'],
1295
+ 'audio_path': data['prompt_audio_speaker2'],
1296
+ 'language': alignment_language,
1297
+ 'alignments': s2_alignments
1298
+ }
1299
+ self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2")
1300
+
1301
+ except Exception as e:
1302
+ self.logger.warning(f"保存预分割音频对齐信息失败: {e}")
1303
+
1304
+ return data['prompt_audio_speaker1'], data['prompt_audio_speaker2']
1305
+
1306
+ # 否则从combined prompt分割
1307
+ elif 'prompt_audio' in data and 'prompt_text' in data:
1308
+ self.logger.info(f"从combined prompt音频分割speaker片段")
1309
+ return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id)
1310
+
1311
+ else:
1312
+ raise ValueError("必须提供prompt_audio+prompt_text或者分别的speaker音频文件")
1313
+
1314
+ def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float:
1315
+ """
1316
+ 计算两个音频的音色相似度(向后兼容版本)
1317
+ 对于过短的音频片段,通过复制来达到最小长度要求
1318
+ """
1319
+ # 如果在多线程环境中,使用线程安全版本
1320
+ if threading.current_thread() != threading.main_thread():
1321
+ return asyncio.run(self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path))
1322
+
1323
+ # 确保模型已初始化
1324
+ self._init_models_if_needed()
1325
+
1326
+ try:
1327
+ if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
1328
+ self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
1329
+ return None
1330
+
1331
+ # 检查并处理音频文件长度
1332
+ def process_audio_for_similarity(audio_path, min_duration=0.1):
1333
+ """
1334
+ 处理音频文件,如果过短则复制到满足最小长度要求
1335
+ 返回处理后的音频路径和是否为临时文件的标志
1336
+ """
1337
+ try:
1338
+ waveform, sample_rate = torchaudio.load(audio_path)
1339
+ duration = waveform.shape[1] / sample_rate
1340
+
1341
+ if duration >= min_duration:
1342
+ # 音频长度足够,直接返回原路径
1343
+ return audio_path, False
1344
+
1345
+ # 音频过短,需要复制
1346
+ repeat_times = math.ceil(min_duration / duration)
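+ # 例如:时长0.03s、min_duration=0.1s 时 repeat_times=ceil(0.1/0.03)=4,复制后约0.12s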
1347
+ self.logger.info(f"音频过短 ({duration:.3f}s),复制 {repeat_times} 次达到 {min_duration}s 要求: {audio_path}")
1348
+
1349
+ # 复制音频
1350
+ repeated_waveform = waveform.repeat(1, repeat_times)
1351
+
1352
+ # 生成临时文件路径
1353
+ temp_filename = f"temp_{os.path.basename(audio_path)}"
1354
+ temp_path = str(self.temp_dir / temp_filename)
1355
+
1356
+ # 保存复制后的音频
1357
+ torchaudio.save(temp_path, repeated_waveform, sample_rate)
1358
+
1359
+ return temp_path, True
1360
+
1361
+ except Exception as e:
1362
+ self.logger.error(f"处理音频文件失败: {audio_path}, 错误: {e}")
1363
+ return audio_path, False
1364
+
1365
+ # 处理两个音频文件
1366
+ processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
1367
+ processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)
1368
+
1369
+ # 计算相似度
1370
+ similarity = asyncio.run(self.similarity_model.compute_similarity(processed_audio1, processed_audio2))  # compute_similarity为协程,同步路径下用asyncio.run执行
1371
+
1372
+ # 清理临时文件
1373
+ if is_temp1 and os.path.exists(processed_audio1):
1374
+ try:
1375
+ os.remove(processed_audio1)
1376
+ except Exception as e:
1377
+ self.logger.warning(f"删除临时文件失败: {processed_audio1}, 错误: {e}")
1378
+
1379
+ if is_temp2 and os.path.exists(processed_audio2):
1380
+ try:
1381
+ os.remove(processed_audio2)
1382
+ except Exception as e:
1383
+ self.logger.warning(f"删除临时文件失败: {processed_audio2}, 错误: {e}")
1384
+
1385
+ return float(similarity)
1386
+
1387
+ except Exception as e:
1388
+ # 检查是否是窗口大小错误或其他计算错误
1389
+ if "choose a window size" in str(e) or "window size" in str(e):
1390
+ self.logger.warning(f"音频片段仍然过短,无法计算相似度: {audio1_path} vs {audio2_path}")
1391
+ return None
1392
+ else:
1393
+ self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
1394
+ return None
1395
+
1396
+ # 全局函数,用于多进程处理(支持增量写入)
1397
+ async def process_data_chunk_incremental(args):
1398
+ """处理数据块的工作函数(增量写入版本)"""
1399
+ data_chunk, gpu_id, proc_id, output_dir, temp_result_file, alignment_model_dir, wespeaker_model_url, language, similarity_max_workers = args
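+ # 参数元组的顺序必须与process_batch_parallel中gpu_process_args的构造顺序保持一致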
1400
+
1401
+ # 设置当前进程使用的GPU
1402
+ device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"
1403
+
1404
+ try:
1405
+ # 清理CUDA状态,避免进程间冲突
1406
+ if torch.cuda.is_available():
1407
+ torch.cuda.empty_cache()
1408
+ # 设置当前进程的GPU设备
1409
+ torch.cuda.set_device(gpu_id)
1410
+ # 添加小延迟,避免同时初始化冲突
1411
+ time.sleep(proc_id * 0.5)
1412
+
1413
+ # 创建评估器实例,传入模型路径、语言参数和相似度计算线程数
1414
+ evaluator = SpeakerSimilarityEvaluator(
1415
+ device=device,
1416
+ alignment_model_dir=alignment_model_dir,
1417
+ wespeaker_model_url=wespeaker_model_url,
1418
+ output_dir=output_dir,
1419
+ language=language, # 传入语言参数
1420
+ similarity_max_workers=similarity_max_workers # 传入相似度计算线程数
1421
+ )
1422
+
1423
+ # 延迟初始化模型
1424
+ evaluator._init_models_if_needed()
1425
+
1426
+ # 清空临时结果文件(如果存在)
1427
+ if os.path.exists(temp_result_file):
1428
+ os.remove(temp_result_file)
1429
+
1430
+ # 处理数据块
1431
+ for i, data in enumerate(data_chunk):
1432
+ input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"
1433
+
1434
+ try:
1435
+ result = await evaluator.evaluate_single_input(data, input_id=input_id)
1436
+
1437
+ # 立即写入结果到临时文件
1438
+ evaluator.append_result_to_jsonl(result, temp_result_file)
1439
+
1440
+ print(f"GPU{gpu_id}-进程{proc_id}: 完成 {input_id} (语言: {language}, 相似度线程: {similarity_max_workers})")
1441
+
1442
+ # 每处理完一个数据项,清理CUDA缓存
1443
+ if torch.cuda.is_available():
1444
+ torch.cuda.empty_cache()
1445
+
1446
+ except Exception as e:
1447
+ print(f"GPU{gpu_id}-进程{proc_id}: 处理 {input_id} 失败: {e}")
1448
+ # 出错时也清理CUDA缓存
1449
+ if torch.cuda.is_available():
1450
+ torch.cuda.empty_cache()
1451
+ continue
1452
+
1453
+ print(f"GPU{gpu_id}-进程{proc_id}: 所有数据处理完成,结果已写入 {temp_result_file}")
1454
+
1455
+ except Exception as e:
1456
+ print(f"GPU{gpu_id}-进程{proc_id}: 初始化失败: {e}")
1457
+ # 出错时清理CUDA缓存
1458
+ if torch.cuda.is_available():
1459
+ torch.cuda.empty_cache()
1460
+
1461
+
1462
+ def run_async_worker(args):
1463
+ """
1464
+ 一个同步包装器,为我们的异步工作函数设置并运行asyncio事件循环。
1465
+ 这是必需的,因为 multiprocessing.Pool 不能直接调用异步函数。
1466
+ """
1467
+ # asyncio.run() 是在每个子进程中启动和运行协程最简单、最安全的方式。
1468
+ # 它会创建一个新的事件循环,运行协程直到完成,然后关闭事件循环。
1469
+ return asyncio.run(process_data_chunk_incremental(args))
1470
+
1471
+
1472
+ def main():
1473
+ """主函数示例"""
1474
+ import argparse
1475
+
1476
+ parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
1477
+ parser.add_argument('--jsonl_path', type=str, help='JSONL文件路径')
1478
+ parser.add_argument('--output_dir', type=str,
1479
+ default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
1480
+ help='结果保存目录')
1481
+ parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
1482
+ help='指定语言: zh=中文, en=英文, auto=自动检测 (默认: zh)')
1483
+ parser.add_argument('--no_parallel', action='store_true', help='禁用并行处理(默认启用并行)')
1484
+ parser.add_argument('--processes_per_gpu', type=int, default=4, help='每个GPU的进程数(建议不超过4)')
1485
+ parser.add_argument('--similarity_workers', type=int, default=16, help='相似度计算的线程数(默认: 16)')
1486
+ parser.add_argument('--no_shuffle', action='store_true', help='禁用数据shuffle(默认启用shuffle)')
1487
+ parser.add_argument('--random_seed', type=int, default=None, help='随机种子(可选,用于结果复现)')
1488
+
1489
+ args = parser.parse_args()
1490
+
1491
+ # 设置随机种子(如果指定)
1492
+ if args.random_seed is not None:
1493
+ random.seed(args.random_seed)
1494
+ np.random.seed(args.random_seed)
1495
+ torch.manual_seed(args.random_seed)
1496
+ print(f"设置随机种子: {args.random_seed}")
1497
+
1498
+ # 语言参数处理
1499
+ language = args.language.upper()  # argparse的choices已限定为zh/en/auto,转大写后即为ZH/EN/AUTO
1506
+
1507
+ # 创建评估器,指定结果保存目录、语言和相似度计算线程数
1508
+ evaluator = SpeakerSimilarityEvaluator(
1509
+ output_dir=args.output_dir,
1510
+ language=language,
1511
+ similarity_max_workers=args.similarity_workers
1512
+ )
1513
+
1514
+ # 默认使用并行处理,除非明确禁用
1515
+ use_parallel = not args.no_parallel
1516
+ use_shuffle = not args.no_shuffle
1517
+
1518
+ print(f"使用语言设置: {language}")
1519
+ print(f"相似度计算线程数: {args.similarity_workers}")
1520
+
1521
+ if args.jsonl_path:
1522
+ # 从JSONL文件处理数据
1523
+ if use_parallel:
1524
+ evaluator.process_batch_from_jsonl_parallel(
1525
+ args.jsonl_path,
1526
+ processes_per_gpu=args.processes_per_gpu,
1527
+ shuffle_data=use_shuffle
1528
+ )
1529
+ else:
1530
+ asyncio.run(evaluator.process_batch_from_jsonl(args.jsonl_path))
1531
+ else:
1532
+ # 使用示例数据(兼容性)
1533
+ input_data = [
1534
+ {
1535
+ 'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
1536
+ 'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
1537
+ 'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
1538
+ 'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
1539
+ }
1540
+ ]
1541
+
1542
+ # 处理数据
1543
+ if use_parallel:
1544
+ evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
1545
+ else:
1546
+ asyncio.run(evaluator.process_batch_from_data(input_data))
1547
+
1548
+
1549
+ if __name__ == "__main__":
1550
+ main()
test_online.sh ADDED
@@ -0,0 +1,150 @@
1
+ #!/bin/bash
2
+
3
+ # 设置CUDA环境变量
4
+ export CUDA_LAUNCH_BLOCKING=1
5
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
6
+
7
+ # Color variables
8
+ RESET=$'\033[0m'   # 使用$'...'让转义序列生效,echo无需-e即可输出颜色
9
+ RED=$'\033[0;31m'
10
+ GREEN=$'\033[0;32m'
11
+ YELLOW=$'\033[0;33m'
12
+ BLUE=$'\033[0;34m'
13
+ MAGENTA=$'\033[0;35m'
14
+ CYAN=$'\033[0;36m'
15
+ WHITE=$'\033[0;37m'
16
+
17
+ # 创建日志目录和文件名
18
+ LOG_DIR="./logs"
19
+ mkdir -p "$LOG_DIR"
20
+ # 记录开始时间
21
+ START_TIME=$(date +%s)
22
+ START_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')
23
+ LOG_TIME=$(date +%Y%m%d_%H%M%S)
24
+ LOG_FILE="$LOG_DIR/evaluation_$LOG_TIME.log"
25
+ SERVER_FILE="$LOG_DIR/server_$LOG_TIME.log"
26
+
27
+ # A function to ensure the server is killed, which we'll call on exit.
28
+ cleanup() {
29
+ echo "--- Cleanup ---"
30
+ # Check if the server process is still running
31
+ if kill -0 $SERVER_PID 2>/dev/null; then
32
+ echo "Client has finished. Sending SIGTERM to shut down the server (PID: $SERVER_PID)..."
33
+ # Send the SIGTERM signal, allowing the server to shut down gracefully if it handles the signal.
34
+ kill $SERVER_PID
35
+ # Wait a moment for it to terminate
36
+ wait $SERVER_PID 2>/dev/null
37
+ echo "Server has been shut down."
38
+ else
39
+ echo "Server (PID: $SERVER_PID) was already stopped."
40
+ fi
41
+ }
42
+
43
+ # Use 'trap' to register the 'cleanup' function to be called when the script exits.
44
+ # This works for normal exit, Ctrl+C (SIGINT), or termination (SIGTERM).
45
+ trap cleanup EXIT
46
+
47
+ # 1. Start the server in the background
48
+ echo "Starting alignment models' remote_server.py in the background..."
49
+ /opt/tritonserver/bin/tritonserver --model-repository=./model_repo > "$SERVER_FILE" 2>&1 &
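+ # Triton默认在8001端口提供gRPC服务,与test_online.py中wespeaker_model_url的默认值'localhost:8001'对应;如修改端口需同步调整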
50
+
51
+ # 2. Capture the Process ID (PID) of the server
52
+ SERVER_PID=$!
53
+ echo "Server started with PID: $SERVER_PID"
54
+
55
+ # Give the server a moment to initialize and start listening on its port.
56
+ # This is crucial, otherwise the client might try to connect before the server is ready.
57
+ echo "Waited 3 seconds for server to initialize."
58
+ echo "------------------------------------------"
59
+ sleep 3
60
+
61
+ echo "${GREEN}========================================="
62
+ echo "音色相似度评估开始"
63
+ echo "开始时间: $START_TIME_READABLE"
64
+ echo "日志文件: $LOG_FILE"
65
+ echo "========================================="
66
+ echo "可以使用以下命令实时查看日志:"
67
+ echo "tail -f $LOG_FILE${RESET}"
68
+ echo ""
69
+
70
+ # 将开始时间信息也写入日志文件
71
+ {
72
+ echo "${GREEN}========================================="
73
+ echo "音色相似度评估开始"
74
+ echo "开始时间: $START_TIME_READABLE"
75
+ echo "进程配置: 每GPU 8个进程"
76
+ echo "语言设置: zh (中文)"
77
+ echo "=========================================${RESET}"
78
+ echo ""
79
+ } | tee "$LOG_FILE"
80
+
81
+ # 3. Run the client in the foreground
82
+ echo "Starting similarity test client test.py in the foreground..."
83
+ # The script will pause here and wait for client.py to complete.
84
+ # We wrap this in a block to capture the exit code.
85
+ {
86
+ # 使用更保守的进程数
87
+ python -u ./test_online.py \
88
+ --jsonl_path /data-mnt/data/yqzhang/testset_ttsd/test_set_zh_304/output_new.jsonl \
89
+ --output_dir ./eval_res/new_test_online \
90
+ --processes_per_gpu 8 \
91
+ --language zh \
92
+ 2>&1 | tee -a "$LOG_FILE"
93
+ CLIENT_EXIT_CODE=${PIPESTATUS[0]}  # 取python进程的退出码,而不是管道末端tee的退出码
94
+ }
95
+ echo "------------------------------------------"
96
+ echo "${YELLOW}Client.py has finished with exit code: $CLIENT_EXIT_CODE${RESET}"
97
+ # 记录结束时间
98
+ END_TIME=$(date +%s)
99
+ END_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')
100
+
101
+ # 计算耗时
102
+ DURATION=$((END_TIME - START_TIME))
103
+ HOURS=$((DURATION / 3600))
104
+ MINUTES=$(((DURATION % 3600) / 60))
105
+ SECONDS=$((DURATION % 60))
106
+
107
+ # 输出结束信息
108
+ {
109
+ echo "${GREEN}"
110
+ echo "========================================="
111
+ echo "音色相似度评估完成!"
112
+ echo "结束时间: $END_TIME_READABLE"
113
+ echo "总耗时: ${HOURS}小时${MINUTES}分钟${SECONDS}秒 (共${DURATION}秒)"
114
+ echo "日志文件: $LOG_FILE"
115
+ echo "========================================="
116
+ echo "${RESET}"
117
+ } | tee -a "$LOG_FILE"
118
+
119
+ # 显示在终端
120
+ echo "${GREEN}"
121
+ echo "评估完成!"
122
+ echo "开始时间: $START_TIME_READABLE"
123
+ echo "结束时间: $END_TIME_READABLE"
124
+ echo "总耗时: ${HOURS}小时${MINUTES}分钟${SECONDS}秒"
125
+ echo "日志已保存到: $LOG_FILE"
126
+ echo "${RESET}"
127
+
128
+ # 如果耗时超过1小时,发送额外提醒
129
+ if [ $DURATION -gt 3600 ]; then
130
+ echo "${RED}"
131
+ echo "⏰ 注意:本次评估耗时较长,超过1小时"
132
+ echo " 建议检查性能优化效果"
133
+ echo "${RESET}"
134
+ fi
135
+
136
+ # The 'trap' will automatically call the 'cleanup' function now that the script is exiting.
137
+ # The exit is triggered because the client process (the last foreground command) has finished.
138
+
139
+ # You can add logic based on the client's exit code if needed.
140
+ if [ $CLIENT_EXIT_CODE -ne 0 ]; then
141
+ echo "Warning: Client exited with an error."
142
+ exit 1 # Exit the main script with an error code as well
143
+ fi
144
+
145
+ echo "Script finished successfully."
146
+ exit 0
147
+
148
+
149
+
150
+