speech similarity model
- .gitattributes +2 -0
- .gitignore +7 -0
- README.md +134 -0
- alignment.py +384 -0
- alignment_online.py +398 -0
- docker/Dockerfile +40 -0
- download_mms_model.py +35 -0
- example_input.jsonl +1 -0
- model_repo/speaker_model/1/model.trt +3 -0
- model_repo/speaker_model/config.pbtxt +44 -0
- models/mms_fa/model.pt +3 -0
- models/mms_fa/model.pt.2c7cc4fedf8e4a089a0095148cc9201b.partial +3 -0
- models/mms_fa/model.pt.5c5fe9893a2c462e9132dcd6a3fba337.partial +3 -0
- models/voxblink2_samresnet100_ft/avg_model.onnx +3 -0
- models/voxblink2_samresnet100_ft/avg_model.pt +3 -0
- models/voxblink2_samresnet100_ft/config.yaml +83 -0
- models/wespeaker/chinese/config.yaml +7 -0
- models/wespeaker/chinese/model.onnx +3 -0
- python_backend/similarity_model/1/model.py +149 -0
- python_backend/similarity_model/1/model_old.py +97 -0
- python_backend/similarity_model/1/model_runnable.py +149 -0
- python_backend/similarity_model/config.pbtxt.back +46 -0
- python_backend/similarity_model/config.pbtxt.disabled +26 -0
- similarity.py +412 -0
- speaker_client.py +149 -0
- test.py +1643 -0
- test.sh +82 -0
- test_alignment.py +416 -0
- test_online.py +1550 -0
- test_online.sh +150 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.trt filter=lfs diff=lfs merge=lfs -text
+*.pt.*.partial filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,7 @@
__pycache__/
.venv/
*.pyc
.DS_Store
outputs/
logs/
eval_res/
README.md
CHANGED
@@ -1,3 +1,137 @@
---
license: apache-2.0
---

# Speaker Similarity Evaluator - New Input Format Guide

## Overview

The speaker similarity evaluator now reads its input from a JSONL file and supports two prompt-audio input modes:
1. **Pre-split mode**: provide separate audio files for S1 and S2 directly
2. **Auto-split mode**: provide a combined prompt audio file; the program splits it by speaker tag automatically

## Input Format

### JSONL file format

Each line is a JSON object and must contain the following fields:

#### Required fields
- `text`: the text to evaluate, containing the speaker tags [S1][S2]
- `output_audio`: path of the audio file to evaluate

#### Prompt audio fields (choose one of the two modes)

**Mode 1: pre-split**
- `prompt_audio_speaker1`: audio file of speaker S1
- `prompt_text_speaker1`: text of speaker S1
- `prompt_audio_speaker2`: audio file of speaker S2
- `prompt_text_speaker2`: text of speaker S2

**Mode 2: auto-split**
- `prompt_audio`: combined audio file containing both speakers
- `prompt_text`: text with speaker tags, e.g. "[S1]text1[S2]text2"

### Examples

#### Pre-split mode example
```json
{
    "text": "[S1]是我对不住你。[S2]没有没有!燕子幸亏咱俩没领证!",
    "prompt_audio_speaker1": "/path/to/speaker1.wav",
    "prompt_text_speaker1": "一共二十万我都记着呢。我一赚到钱就马上还给你。",
    "prompt_audio_speaker2": "/path/to/speaker2.wav",
    "prompt_text_speaker2": "没关系,我不缺钱。",
    "output_audio": "/path/to/output.wav"
}
```

#### Auto-split mode example
```json
{
    "text": "[S1]今天天气真好啊。[S2]是的,阳光明媚。",
    "prompt_audio": "/path/to/combined_prompt.wav",
    "prompt_text": "[S1]早上好,今天怎么样?[S2]很好,谢谢你的关心。",
    "output_audio": "/path/to/output.wav"
}
```

#### Mixed example (both modes provided; pre-split takes precedence)
```json
{
    "text": "[S1]是我对不住你。[S2]没有没有!",
    "prompt_audio": "/path/to/combined.wav",
    "prompt_text": "[S1]一共二十万我都记着呢。[S2]没关系,我不缺钱。",
    "prompt_audio_speaker1": "/path/to/speaker1.wav",
    "prompt_text_speaker1": "一共二十万我都记着呢。我一赚到钱就马上还给你。",
    "prompt_audio_speaker2": "/path/to/speaker2.wav",
    "prompt_text_speaker2": "没关系,我不缺钱。",
    "output_audio": "/path/to/output.wav"
}
```

## Usage

### Command line

```bash
# Read input from a JSONL file
python test.py --jsonl_path /path/to/your/input.jsonl --output_dir /path/to/results

# Use the default sample data (backward compatible)
python test.py --output_dir /path/to/results
```

### Programmatic use

```python
from test import SpeakerSimilarityEvaluator

# Create the evaluator
evaluator = SpeakerSimilarityEvaluator(output_dir="/path/to/results")

# Process a JSONL file
evaluator.process_batch_from_jsonl("/path/to/input.jsonl")

# Or pass a list of records directly (legacy interface, backward compatible)
input_data = [
    {
        'prompt_audio': "/path/to/prompt.wav",
        'prompt_text': "[S1]文本1[S2]文本2",
        'text': "[S1]输出文本1[S2]输出文本2",
        'output_audio': "/path/to/output.wav"
    }
]
evaluator.process_batch(input_data)
```

## Advantages

### Pre-split mode
1. **Higher accuracy**: avoids the errors that automatic splitting may introduce
2. **Faster**: skips the audio-splitting step
3. **More robust**: does not depend on the accuracy of the word-alignment model

### Auto-split mode
1. **Convenience**: only a single combined audio file is required
2. **Backward compatible**: works with the existing data format

## Output Layout

```
results_YYYYMMDD_HHMMSS/
├── segments/    # split audio segments
├── prompts/     # S1/S2 segments of the prompt audio (auto-split mode only)
├── temp/        # temporary files (cleared after the run)
└── results/     # evaluation results
    ├── speaker_similarity_results_YYYYMMDD_HHMMSS.jsonl
    └── evaluation_summary_YYYYMMDD_HHMMSS.json
```

## Notes

1. Make sure every audio file path is correct and the file exists
2. Speaker tags in the text must use the exact format `[S1]` and `[S2]`
3. If data for both modes is provided, the program prefers pre-split mode
4. Every line of the JSONL file must be valid JSON
5. The program validates each input record, skips malformed lines, and continues with the rest (a minimal validation sketch follows this section)
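As a companion to note 5, here is a minimal sketch of the per-line validation and mode dispatch the README describes, assuming only the field names defined above; the authoritative logic lives in `test.py`, which is not shown in this diff:

```python
import json

REQUIRED = ("text", "output_audio")
PRESPLIT = ("prompt_audio_speaker1", "prompt_text_speaker1",
            "prompt_audio_speaker2", "prompt_text_speaker2")
AUTOSPLIT = ("prompt_audio", "prompt_text")

def iter_valid_records(jsonl_path):
    """Yield (mode, record) pairs, skipping malformed lines."""
    with open(jsonl_path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                print(f"line {lineno}: invalid JSON, skipped")
                continue
            if any(k not in rec for k in REQUIRED):
                print(f"line {lineno}: missing required field, skipped")
                continue
            # Pre-split mode takes precedence when both are present (note 3).
            if all(k in rec for k in PRESPLIT):
                yield "pre-split", rec
            elif all(k in rec for k in AUTOSPLIT):
                yield "auto-split", rec
            else:
                print(f"line {lineno}: no complete prompt mode, skipped")
```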
alignment.py
ADDED
@@ -0,0 +1,384 @@
import re
import torch
import torchaudio.functional as F
import torchaudio
import uroman as ur
import logging
import traceback

def convert_to_list_with_punctuation_mixed(text):
    """Tokenize Chinese text (which may contain English words): Chinese is split per character, English words stay intact."""
    result = []
    text = text.strip()

    if not text:
        return result

    def is_chinese(char):
        """Check whether the character is a CJK ideograph."""
        return '\u4e00' <= char <= '\u9fff'

    # Use a precise regex to split the text.
    # Matches: English words (with digits), single Chinese characters, punctuation.
    pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
    tokens = re.findall(pattern, text)

    for token in tokens:
        if not token.strip():  # skip empty tokens
            continue

        if re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word (possibly with digits)
            result.append(token)
        elif is_chinese(token):  # single Chinese character
            result.append(token)
        else:  # punctuation or other characters
            # Attach punctuation to the preceding token.
            if result:
                result[-1] += token
            else:
                # Punctuation at the start of the text becomes its own token.
                result.append(token)

    return result

def split_and_merge_punctuation(text):
    """Tokenize English text: split on whitespace, keeping each word intact."""
    # Split the text on whitespace first.
    elements = text.split()

    # Final result.
    result = []

    # Walk through every element.
    for ele in elements:
        # Extract runs of letters/digits and runs of punctuation.
        parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)

        # Parts after merging punctuation.
        merged_parts = []

        for part in parts:
            # Classify each run by content: letter/digit runs are kept,
            # punctuation runs are merged into the preceding run.
            if re.match(r'^[a-zA-Z0-9]+$', part):  # letter/digit run
                merged_parts.append(part)
            else:  # punctuation run
                if merged_parts:
                    merged_parts[-1] += part
                else:
                    # Leading punctuation becomes its own token.
                    merged_parts.append(part)

        # Add the merged parts to the final result.
        result.extend(merged_parts)

    return result


def get_aligned_result_text_with_punctuation(alignment_result, text, language):
    """
    Map the alignment result back onto the original text tokens: English stays at
    word level, Chinese at character level (with English words kept intact).
    """
    logging.info("start change text to text_tokens")

    if language == "EN":
        text_tokens = split_and_merge_punctuation(text)  # English: word-level tokens
    elif language == "ZH":
        text_tokens = convert_to_list_with_punctuation_mixed(text)  # Chinese: per character, English words intact
    else:
        raise ValueError(f"Unsupported language: {language}")

    logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")

    punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》''""\、')

    logging.info("start get align result text with punctuation")
    updated_alignment_result = []
    token_idx = 0

    for index, align_item in enumerate(alignment_result):
        if token_idx >= len(text_tokens):
            # Text tokens exhausted but alignment items remain; stop here.
            logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
            break

        start = align_item["start"]
        end = align_item["end"]
        text_token = text_tokens[token_idx]

        # Check for trailing punctuation tokens (Chinese only).
        if language == "ZH":
            while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
                assert False, "???"  # in theory this branch should be unreachable??
                text_token += text_tokens[token_idx + 1]  # append the punctuation
                token_idx += 1
        else:
            # English needs no special punctuation handling here;
            # split_and_merge_punctuation already merged it.
            pass

        # Update the alignment item.
        updated_item = {
            "start": start,
            "end": end,
            "transcript": text_token
        }
        updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})

        updated_alignment_result.append(updated_item)
        token_idx += 1

    logging.info("end get align result text with punctuation")
    return updated_alignment_result


class AlignmentModel:
    def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
        """
        Initialize the alignment model and load the required resources.
        :param device: device type ("cuda" or "cpu")
        :param model_dir: model directory path
        """
        self.device = torch.device(device)
        self.bundle = torchaudio.pipelines.MMS_FA
        self.align_model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)
        self.uroman = ur.Uroman()
        self.DICTIONARY = self.bundle.get_dict()

    def align(self, emission, tokens):
        """
        Run forced alignment.
        :param emission: model output
        :param tokens: target tokens
        :return: aligned tokens and scores
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()
        return alignments, scores

    def unflatten(self, list_, lengths):
        """
        Split a flat list into sublists of the given lengths.
        :param list_: flat list
        :param lengths: length of each sublist
        :return: list of sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

    def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
        """
        Compute the start and end time of each word.
        :param waveform: audio waveform
        :param spans: word spans
        :param num_frames: number of frames
        :param transcript: transcript text
        :param sample_rate: sample rate
        :return: per-word alignment info
        """
        end = 0
        alignment_result = []
        for span, trans in zip(spans, transcript):
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)
            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)
        return alignment_result

    def make_wav_batch(self, wav_list):
        """
        Pad every wav tensor in wav_list to the same length; return the padded batch and the original lengths.
        :param wav_list: list of wav tensors
        :return: padded audio tensor and original lengths
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)
        # Make sure the batch tensor lives on the right device.
        wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
        for i, wav in enumerate(wav_list):
            wav = wav.to(self.device)  # move each wav to the right device
            wavs_tensors[i, :wav_lengths[i]] = wav
        return wavs_tensors, wav_lengths.to(self.device)

    def get_target(self, transcript, language):
        """
        Build the target tokens for a transcript - revised version that keeps English words intact.
        """
        original_transcript = transcript  # keep the original text for debugging

        if language == "ZH":
            # Chinese: keep English words intact, romanize only the Chinese characters.
            # Uses the same tokenization logic as above.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            # Handle Chinese characters and English words separately.
            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word
                    # Keep English words as-is; no romanization.
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':  # Chinese character
                    # Romanize only the Chinese characters.
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:  # punctuation etc.
                    # Added verbatim here; filtered out in a later step.
                    processed_parts.append(token)

            # Join all parts with spaces.
            transcript = ' '.join(processed_parts)

        elif language == "EN":
            # English: keep the word structure; punctuation is cleaned below.
            pass
        else:
            assert False, f"Unsupported language: {language}"

        # Strip punctuation.
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        TRANSCRIPT = transcript.lower().split()

        # Fetch the special star token from the dictionary up front.
        star_token = self.DICTIONARY['*']
        tokenized_transcript = []

        # Unified tokenization logic.
        for word in TRANSCRIPT:
            # Tokenize each character of the word.
            word_tokens = []
            for c in word:
                if c in self.DICTIONARY and c != '-':
                    word_tokens.append(self.DICTIONARY[c])
                else:
                    word_tokens.append(star_token)
            tokenized_transcript.extend(word_tokens)

        logging.info(f"Original transcript: {original_transcript}")
        logging.info(f"Processed transcript: {transcript}")
        logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")

        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
        """
        Build the alignment result from the given emission and alignment info - revised version.
        """
        original_transcript = transcript  # keep the original text

        if language == "ZH":
            # Same processing logic as get_target.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':  # Chinese character
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:  # punctuation etc.
                    processed_parts.append(token)

            transcript = ' '.join(processed_parts)
        elif language == "EN":
            pass
        else:
            assert False, f"Unsupported language: {language}"

        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        emission = emission_padded[:emission_length, :].unsqueeze(0)
        TRANSCRIPT = transcript.lower().split()

        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

        # Unified grouping logic.
        word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])

        num_frames = emission.size(1)

        logging.info(f"Original transcript for alignment: {original_transcript}")
        logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")

        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)

    def batch_alignment(self, wav_list, transcript_list, language_list):
        """
        Batch alignment.
        :param wav_list: list of wav tensors
        :param transcript_list: list of transcripts
        :param language_list: list of language codes
        :return: list of alignment results
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
        logging.info("start alignment model forward")
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
            star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
            emission = torch.cat((emission, star_dim), dim=-1)

        logging.info("end alignment model forward")

        target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]

        logging.info("align success")
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]

        logging.info("get align result")
        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]

        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
        ]
        logging.info("get align result success")
        return alignment_result_list


def batch_get_alignment_result(alignment_model, wav_list, transcript_list, language_list):
    """
    Convenience function that runs batch alignment and restores text and punctuation.
    """
    alignment_results = alignment_model.batch_alignment(
        wav_list=wav_list,
        transcript_list=transcript_list,
        language_list=language_list
    )

    alignments_results_with_text_and_punctuation = []
    for alignment_result, transcript, language in zip(alignment_results, transcript_list, language_list):
        try:
            result = get_aligned_result_text_with_punctuation(alignment_result, transcript, language)
            alignments_results_with_text_and_punctuation.append(result)
        except Exception:
            logger = logging.getLogger("tokenize")
            logger.error(f"Error in processing {alignment_result}")
            traceback.print_exc()
            alignments_results_with_text_and_punctuation.append(alignment_result)
    return alignments_results_with_text_and_punctuation
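A minimal usage sketch for `AlignmentModel` and `batch_get_alignment_result` above, assuming a CUDA device, MMS-FA weights under `models/mms_fa` (as added in this commit), and a placeholder wav path and transcript:

```python
import torchaudio
import torchaudio.functional as F
from alignment import AlignmentModel, batch_get_alignment_result

model = AlignmentModel(device="cuda", model_dir="models/mms_fa")

# MMS_FA expects 16 kHz mono audio; resample and take the first channel.
wav, sr = torchaudio.load("example.wav")  # placeholder path
wav = F.resample(wav, sr, model.bundle.sample_rate)[0]

results = batch_get_alignment_result(
    model,
    wav_list=[wav],
    transcript_list=["是我对不住你。"],  # placeholder transcript
    language_list=["ZH"],
)
for item in results[0]:
    print(item["transcript"], item["start"], item["end"], item["pause"])
```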
alignment_online.py
ADDED
@@ -0,0 +1,398 @@
import httpx
import re
import torch
import torchaudio.functional as F
import torchaudio
import uroman as ur
import logging
import traceback


def convert_to_list_with_punctuation_mixed(text):
    """Tokenize Chinese text (which may contain English words): Chinese is split per character, English words stay intact."""
    result = []
    text = text.strip()

    if not text:
        return result

    def is_chinese(char):
        """Check whether the character is a CJK ideograph."""
        return '\u4e00' <= char <= '\u9fff'

    # Use a precise regex to split the text.
    # Matches: English words (with digits), single Chinese characters, punctuation.
    pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
    tokens = re.findall(pattern, text)

    for token in tokens:
        if not token.strip():  # skip empty tokens
            continue

        if re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word (possibly with digits)
            result.append(token)
        elif is_chinese(token):  # single Chinese character
            result.append(token)
        else:  # punctuation or other characters
            # Attach punctuation to the preceding token.
            if result:
                result[-1] += token
            else:
                # Punctuation at the start of the text becomes its own token.
                result.append(token)

    return result

def split_and_merge_punctuation(text):
    """Tokenize English text: split on whitespace, keeping each word intact."""
    # Split the text on whitespace first.
    elements = text.split()

    # Final result.
    result = []

    # Walk through every element.
    for ele in elements:
        # Extract runs of letters/digits and runs of punctuation.
        parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)

        # Parts after merging punctuation.
        merged_parts = []

        for part in parts:
            # Classify each run by content: letter/digit runs are kept,
            # punctuation runs are merged into the preceding run.
            if re.match(r'^[a-zA-Z0-9]+$', part):  # letter/digit run
                merged_parts.append(part)
            else:  # punctuation run
                if merged_parts:
                    merged_parts[-1] += part
                else:
                    # Leading punctuation becomes its own token.
                    merged_parts.append(part)

        # Add the merged parts to the final result.
        result.extend(merged_parts)

    return result


def get_aligned_result_text_with_punctuation(alignment_result, text, language):
    """
    Map the alignment result back onto the original text tokens: English stays at
    word level, Chinese at character level (with English words kept intact).
    """
    logging.info("start change text to text_tokens")

    if language == "EN":
        text_tokens = split_and_merge_punctuation(text)  # English: word-level tokens
    elif language == "ZH":
        text_tokens = convert_to_list_with_punctuation_mixed(text)  # Chinese: per character, English words intact
    else:
        raise ValueError(f"Unsupported language: {language}")

    logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")

    punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》''""\、')

    logging.info("start get align result text with punctuation")
    updated_alignment_result = []
    token_idx = 0

    for index, align_item in enumerate(alignment_result):
        if token_idx >= len(text_tokens):
            # Text tokens exhausted but alignment items remain; stop here.
            logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
            break

        start = align_item["start"]
        end = align_item["end"]
        text_token = text_tokens[token_idx]

        # Check for trailing punctuation tokens (Chinese only).
        if language == "ZH":
            while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
                assert False, "???"  # in theory this branch should be unreachable??
                text_token += text_tokens[token_idx + 1]  # append the punctuation
                token_idx += 1
        else:
            # English needs no special punctuation handling here;
            # split_and_merge_punctuation already merged it.
            pass

        # Update the alignment item.
        updated_item = {
            "start": start,
            "end": end,
            "transcript": text_token
        }
        updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})

        updated_alignment_result.append(updated_item)
        token_idx += 1

    logging.info("end get align result text with punctuation")
    return updated_alignment_result


class AlignmentModel:
    def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
        """
        Initialize the alignment model and load the required resources.
        """
        self.device = torch.device(device)
        self.bundle = torchaudio.pipelines.MMS_FA
        model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)

        # --- Core optimization ---
        # JIT-compile the model with torch.compile.
        # mode="max-autotune" would take longer to compile but reach peak
        # performance; "reduce-overhead" is used here.
        print("Compiling the model... This may take a moment.")
        self.align_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")

        self.uroman = ur.Uroman()
        self.DICTIONARY = self.bundle.get_dict()

    def align(self, emission, tokens):
        """
        Run forced alignment.
        :param emission: model output
        :param tokens: target tokens
        :return: aligned tokens and scores
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()
        return alignments, scores

    def unflatten(self, list_, lengths):
        """
        Split a flat list into sublists of the given lengths.
        :param list_: flat list
        :param lengths: length of each sublist
        :return: list of sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

    def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
        """
        Compute the start and end time of each word.
        :param waveform: audio waveform
        :param spans: word spans
        :param num_frames: number of frames
        :param transcript: transcript text
        :param sample_rate: sample rate
        :return: per-word alignment info
        """
        end = 0
        alignment_result = []
        for span, trans in zip(spans, transcript):
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)
            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)
        return alignment_result

    def make_wav_batch(self, wav_list):
        """
        Pad every wav tensor in wav_list to the same length; return the padded batch and the original lengths.
        :param wav_list: list of wav tensors
        :return: padded audio tensor and original lengths
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)
        # Make sure the batch tensor lives on the right device.
        wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
        for i, wav in enumerate(wav_list):
            wav = wav.to(self.device)  # move each wav to the right device
            wavs_tensors[i, :wav_lengths[i]] = wav
        return wavs_tensors, wav_lengths.to(self.device)

    def get_target(self, transcript, language):
        """
        Build the target tokens for a transcript - revised version that keeps English words intact.
        """
        original_transcript = transcript  # keep the original text for debugging

        if language == "ZH":
            # Chinese: keep English words intact, romanize only the Chinese characters.
            # Uses the same tokenization logic as above.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            # Handle Chinese characters and English words separately.
            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word
                    # Keep English words as-is; no romanization.
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':  # Chinese character
                    # Romanize only the Chinese characters.
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:  # punctuation etc.
                    # Added verbatim here; filtered out in a later step.
                    processed_parts.append(token)

            # Join all parts with spaces.
            transcript = ' '.join(processed_parts)

        elif language == "EN":
            # English: keep the word structure; punctuation is cleaned below.
            pass
        else:
            assert False, f"Unsupported language: {language}"

        # Strip punctuation.
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        TRANSCRIPT = transcript.lower().split()

        # Fetch the special star token from the dictionary up front.
        star_token = self.DICTIONARY['*']
        tokenized_transcript = []

        # Unified tokenization logic.
        for word in TRANSCRIPT:
            # Tokenize each character of the word.
            word_tokens = []
            for c in word:
                if c in self.DICTIONARY and c != '-':
                    word_tokens.append(self.DICTIONARY[c])
                else:
                    word_tokens.append(star_token)
            tokenized_transcript.extend(word_tokens)

        logging.info(f"Original transcript: {original_transcript}")
        logging.info(f"Processed transcript: {transcript}")
        logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")

        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
        """
        Build the alignment result from the given emission and alignment info - revised version.
        """
        original_transcript = transcript  # keep the original text

        if language == "ZH":
            # Same processing logic as get_target.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':  # Chinese character
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:  # punctuation etc.
                    processed_parts.append(token)

            transcript = ' '.join(processed_parts)
        elif language == "EN":
            pass
        else:
            assert False, f"Unsupported language: {language}"

        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        emission = emission_padded[:emission_length, :].unsqueeze(0)
        TRANSCRIPT = transcript.lower().split()

        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

        # Unified grouping logic.
        word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])

        num_frames = emission.size(1)

        logging.info(f"Original transcript for alignment: {original_transcript}")
        logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")

        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)

    def batch_alignment(self, wav_list, transcript_list, language_list):
        """
        Batch alignment.
        :param wav_list: list of wav tensors
        :param transcript_list: list of transcripts
        :param language_list: list of language codes
        :return: list of alignment results
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
        logging.info("start alignment model forward")
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
            star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
            emission = torch.cat((emission, star_dim), dim=-1)

        logging.info("end alignment model forward")

        target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]

        logging.info("align success")
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]

        logging.info("get align result")
        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]

        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
        ]
        logging.info("get align result success")
        return alignment_result_list


async def batch_get_alignment_result_remote(alignment_url, audio_path, transcript, language):
    """
    Fetch alignment results by calling a remote alignment service.
    """
    payload = {
        "audio_path": audio_path,
        "transcript": transcript,
        "language": language,
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(alignment_url, json=payload, timeout=300)  # generous timeout
            response.raise_for_status()  # raise on non-2xx status codes
            data = response.json()
            return data['results']

    except httpx.HTTPError as e:
        # The request is made with httpx, so httpx errors are the ones to catch.
        logging.error(f"Failed to connect to alignment service: {e}")
        traceback.print_exc()
        # Depending on requirements this could return an empty list or re-raise.
    except Exception as e:
        logging.error(f"An error occurred in remote alignment: {e}")
        traceback.print_exc()
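A minimal sketch of driving the remote helper above, assuming a hypothetical alignment service at `http://localhost:8080/align` that accepts the payload shape built in `batch_get_alignment_result_remote` and returns a JSON body with a `results` field:

```python
import asyncio
from alignment_online import batch_get_alignment_result_remote

async def main():
    # Endpoint, audio path, and transcript are placeholders for illustration.
    results = await batch_get_alignment_result_remote(
        alignment_url="http://localhost:8080/align",
        audio_path="/path/to/output.wav",
        transcript="是我对不住你。",
        language="ZH",
    )
    print(results)  # None if the service call failed

asyncio.run(main())
```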
docker/Dockerfile
ADDED
@@ -0,0 +1,40 @@
###################################################################################################
#
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
#  * Redistributions of source code must retain the above copyright notice, this list of
#    conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright notice, this list of
#    conditions and the following disclaimer in the documentation and/or other materials
#    provided with the distribution.
#  * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
#    to endorse or promote products derived from this software without specific prior written
#    permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
###################################################################################################
FROM nvcr.io/nvidia/tritonserver:25.08-py3
LABEL maintainer="NVIDIA"
LABEL repository="tritonserver"

RUN apt-get update && apt-get -y install swig && apt-get -y install python3-dev && apt-get install -y cmake && apt-get install -y libsndfile1
RUN pip3 install kaldiio
RUN pip3 install torch torchvision torchaudio -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
RUN pip3 install -v kaldifeat
RUN python3 -m pip install cupy
RUN python3 -m pip install soundfile
RUN pip3 install --upgrade pip
RUN pip install --extra-index-url https://pypi.nvidia.com cudf_cu12
RUN pip install --extra-index-url https://pypi.nvidia.com cuml_cu12
RUN pip install --extra-index-url https://pypi.nvidia.com cugraph_cu12
WORKDIR /workspace
download_mms_model.py
ADDED
@@ -0,0 +1,35 @@
import torchaudio
import os
from pathlib import Path

def download_mms_model(download_dir="/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation/models/mms_fa"):
    """Download the MMS-FA model into the given directory."""

    # Create the download directory.
    download_path = Path(download_dir)
    download_path.mkdir(parents=True, exist_ok=True)

    print(f"Downloading the MMS-FA model to: {download_path}")

    try:
        # Get the MMS-FA bundle.
        bundle = torchaudio.pipelines.MMS_FA

        # Download the model.
        model = bundle.get_model(with_star=False, dl_kwargs={'model_dir': str(download_path)})

        print(f"✅ Model downloaded successfully! Saved in: {download_path}")
        print(f"Model files: {list(download_path.glob('*'))}")

        return str(download_path)

    except Exception as e:
        print(f"❌ Download failed: {e}")
        return None

if __name__ == "__main__":
    # Download the model.
    model_path = download_mms_model()
    if model_path:
        print("\nUsage:")
        print(f"evaluator = SpeakerSimilarityEvaluator(alignment_model_dir='{model_path}')")
example_input.jsonl
ADDED
@@ -0,0 +1 @@
{"text": "[S1] Hey, do you know the AI world has been super lively lately?[S2] Oh, yeah, new news every day. It feels like, um, a lot of big companies are just pushing really hard to get ahead.[S1] Right, right, exactly. Like, big news popping up every other day. Recently, I saw something about Anthropic. Didn't they release Claude 4?[S2] Oh, Claude 4, yeah, I saw some reports. They said it's really powerful, their latest model.[S1] Mhm, they're calling it the world's best programming model,sounds super impressive.[S2] Mm.[S1] Hey, really? World's best? That title alone is pretty catchy.[S2] Yeah, that really makes you curious, actually.[S1] Right? And it claims that for long tasks requiring extreme focus and thousands of steps, it can maintain stable performance.[S2] Mm.[S1] Meaning, it doesn't crash easily.[S2] Wow, that's amazing. So, it doesn't crash easily, huh?[S1] Exactly. They said, like, the Japanese e-commerce giant Rakuten, you know them, right? They actually verified Claude Opus 4's capability. In a demanding open-source refactoring task, it ran independently for seven hours.[S2] Seven hours?[S1] And throughout that time, its performance remained completely stable.[S2] Wow, my goodness. It runs on its own for seven hours without a break? That's incredible.[S1] Yeah, for those tasks that need focused effort and thousands of steps, it can handle them steadily.[S2] Mm, that's really something.[S1] Uh, so it's especially suitable for complex coding and problem-solving scenarios.[S2] Oh, I see. So, how's its performance in programming, really? Is it actually much better than before?[S1] Yeah, they mentioned the SWE-bench evaluation, which is a benchmark test for software engineering tasks.[S2] Oh, I know that test, it's quite professional.[S1] Mm, their Claude Sonnet 4 achieved an accuracy of 72.7 percent.[S2] Mm, 72.7 percent, that's high.[S1] Right, and they also compared it to the previous Sonnet 3.7 version.[S2] Mm.[S1] The 3.7 version got 62.3 percent.[S2] Oh, that's about a ten-point difference, then.[S1] Exactly, so Sonnet 4 improved significantly.[S2] Hmm, so it seems like this upgrade is substantial, not just hype.[S1] Indeed. And they also released Claude Code, which is a dedicated programming tool.[S2] Hmm, like, for developers to use?[S1] Yes, they said Claude Code is officially launched and supported by both Claude 4 models.[S2] Oh, I see. So, not only are the models powerful, but they've also improved the tools, like a complete package.[S1] That's right. And they also said that Claude Code isn't just for programmers.[S2] Huh? If it's not for programmers, then who's it for?[S1] They said, even for people who aren't really good at programming,[S2] Mm.[S1] Like product managers, if they want to create a prototype for an idea, they can just ask Claude to do it.[S2] Wow, that's really interesting. So, you don't have to write the code yourself, you just let the AI help you realize your ideas, right?[S1] Yeah, they're saying that in the future, if you have an idea, you might not need to write a document; you can just have it help you create the prototype.[S2] Hmm, that sounds a bit like, uh, will programmers' jobs become less common in the future?[S1] Hmm, it might be more like, Scott White, who's their product lead, he said that Claude is transforming from a tool that provides answers into a truly capable collaborative partner.[S2] Oh, I understand. So, it helps you with, uh, more basic or repetitive tasks, allowing you to focus more on creative things.[S1] Yes, exactly. And the models they released this time are called Opus 4 and Sonnet 4.[S2] Mm.[S1] Opus 4, they say, is their most powerful model to date, and also the world's best programming model.[S2] Definitely the flagship model.[S1] And Sonnet 4 is a major upgrade to Sonnet 3.7.[S2] Oh, so what are the specific differences between the two?[S1] Hmm, Opus 4 is better at high-end tasks like coding, research, writing, and scientific discovery.[S2] Hmm, Opus sounds more all-around capable.[S1] Right, and Sonnet 4 is more suitable for everyday use cases; it offers cutting-edge performance for daily tasks.[S2] Oh, I see. So, one is super high-end, and the other is also super strong for everyday use.[S1] Yes, and both models use a hybrid mode design.[S2] Hybrid mode? What does that mean?[S1] It means it can provide almost instant responses, but also perform deeper reasoning and thought.[S2] Oh.[S1] Like, uh, expansive thinking.[S2] Oh, I see. So, sometimes it needs to be fast, and other times it needs to be slow and think deeply.[S1] Exactly.[S2] Hmm, so what about the pricing? Is it very expensive?[S1] The pricing is the same as the previous Opus and Sonnet models.[S2] Oh.[S1] For Opus 4, it's fifteen dollars per million input tokens and seventy-five dollars for output tokens.[S2] Wow, output is much more expensive![S1] Right. And for Sonnet 4, input is three dollars and output is fifteen dollars.[S2] Hmm, Sonnet is much more affordable then.[S1] Yes, and Sonnet 4 is also available for free users.[S2] Oh, that's good, everyone can try it out.[S1] Mm, exactly.[S2] Hey, so how does it compare to other AI giants? Where does it stand now?[S1] This release of theirs has intensified the competition with giants like OpenAI and Google in the top-tier model space.[S2] Yeah, it really feels like everyone's pushing hard lately.[S1] Right? Like, Microsoft also announced new coding agents, didn't they? And they partnered with Elon Musk's xAI.[S2] Mm.[S1] Google, meanwhile, is accelerating the integration of AI agents into their services.[S2] Right.[S1] And OpenAI is even more impressive; they just made a six-point-five-billion-dollar deal to acquire an AI hardware startup founded by the father of iPhone, former Apple design chief Jony Ive.[S2] Wow, six-point-five billion, that's a huge move. It feels like AI competition is really heating up.[S1] Exactly, so for investors, it means re-evaluating the competitive landscape in the AI sector.[S2] Hmm, makes sense. So, does this Claude 4 also bring a lot of opportunities for Anthropic?[S1] Yeah, its strong performance in coding, reasoning, and agent tasks will definitely help it capture more market share and enterprise clients.[S2] Hmm, sounds like it has huge potential indeed.[S1] Mm, it just feels like the AI competition now is all about who can push the technology to new heights.[S2] Exactly, and also who can really, uh, implement these technologies into practical applications.[S1] Right, right, exactly like that.[S2] Okay, well, this news about Claude 4 today really makes you feel like AI has taken a huge leap forward.[S1] Yeah, looking forward to it bringing more surprises in the future.", "prompt_audio_speaker1": "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/moon-en1/en_spk1_moon.wav", "prompt_text_speaker1": "OK. I'm starting to see how this multi-headed approach could lead to some pretty impressive results.", "prompt_audio_speaker2": "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/moon-en1/en_spk2-moon.wav", "prompt_text_speaker2": "It's not just crunching data. It's starting to develop a more sophisticated understanding of how language actually works.", "output_audio": "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step40000/test_en/gpu0/output_0.wav"}
model_repo/speaker_model/1/model.trt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17a3a3b794fa886c9b8341a08ad5e22f3bb385dd994ff891ba62d591503673f5
size 104729100
model_repo/speaker_model/config.pbtxt
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "speaker_model"
backend: "tensorrt"
default_model_filename: "model.trt"

max_batch_size: 16
input [
  {
    name: "feats"
    data_type: TYPE_FP32
    dims: [ -1, 80 ]  # num_mel_bins
  }
]

output [
  {
    name: "embs"
    data_type: TYPE_FP32
    dims: [ 256 ]  # [embedding_size]
  }
]
dynamic_batching {
  preferred_batch_size: [ 4, 8 ]
  max_queue_delay_microseconds: 1000
}
instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]
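For reference, a minimal sketch of querying this Triton endpoint over HTTP, assuming the server is reachable on localhost:8000 and that `feats` are 80-dim fbank frames; the repo's `speaker_client.py` is the real client and may differ:

```python
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# One utterance of 200 frames x 80 mel bins; config.pbtxt declares dims [-1, 80]
# plus the implicit batch dimension that comes with max_batch_size.
feats = np.random.randn(1, 200, 80).astype(np.float32)  # placeholder features

inp = httpclient.InferInput("feats", feats.shape, "FP32")
inp.set_data_from_numpy(feats)
out = httpclient.InferRequestedOutput("embs")

resp = client.infer(model_name="speaker_model", inputs=[inp], outputs=[out])
print(resp.as_numpy("embs").shape)  # expected: (1, 256)
```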
models/mms_fa/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:20ef12963ab4924bef49ac4fc7f58ad5da2ee43b2c11bc8c853c9b90ecdbc680
size 1262047414
models/mms_fa/model.pt.2c7cc4fedf8e4a089a0095148cc9201b.partial
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e0cf233f857de07296254c36332b4b984045cdc0964ec1fef6a0c6cc5aae00b7
size 1056964608
models/mms_fa/model.pt.5c5fe9893a2c462e9132dcd6a3fba337.partial
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:51258936b4a1a51762ef849ec0f404920f38d03c4e018550d75ea4e1e82a451a
|
| 3 |
+
size 486539264
|
models/voxblink2_samresnet100_ft/avg_model.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d92ee34668d8eb24a02df4e7869fd4bde661220a137e045f29e4a0c85eb4004
|
| 3 |
+
size 201115747
|
models/voxblink2_samresnet100_ft/avg_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5aeee438ca23c0ca6e341bab6c6bf7f465497e1dc323bb1bc1074d6a0c778b11
|
| 3 |
+
size 201318407
|
models/voxblink2_samresnet100_ft/config.yaml
ADDED
@@ -0,0 +1,83 @@
+data_type: shard
+dataloader_args:
+  batch_size: 128
+  drop_last: true
+  num_workers: 16
+  pin_memory: false
+  prefetch_factor: 8
+dataset_args:
+  aug_prob: 0.6
+  fbank_args:
+    dither: 1.0
+    frame_length: 25
+    frame_shift: 10
+    num_mel_bins: 80
+  filter: true
+  filter_args:
+    max_num_frames: 800
+    min_num_frames: 100
+  num_frms: 200
+  resample_rate: 16000
+  sample_num_per_epoch: 0
+  shuffle: true
+  shuffle_args:
+    shuffle_size: 2500
+  spec_aug: false
+  spec_aug_args:
+    max_f: 8
+    max_t: 10
+    num_f_mask: 1
+    num_t_mask: 1
+    prob: 0.6
+  speed_perturb: true
+enable_amp: false
+exp_dir: exp/samresnet100/
+gpus:
+- 0
+- 1
+log_batch_interval: 100
+loss: CrossEntropyLoss
+loss_args: {}
+margin_scheduler: MarginScheduler
+margin_update:
+  epoch_iter: 4265
+  final_margin: 0.2
+  fix_start_epoch: 40
+  increase_start_epoch: 20
+  increase_type: exp
+  initial_margin: 0.0
+  update_margin: true
+model: SimAM_ResNet100_ASP
+model_args:
+  embed_dim: 256
+model_init: null
+noise_data: data/musan/lmdb
+num_avg: 1
+num_epochs: 150
+optimizer: SGD
+optimizer_args:
+  lr: 0.1
+  momentum: 0.9
+  nesterov: true
+  weight_decay: 0.0001
+projection_args:
+  do_lm: false
+  easy_margin: false
+  embed_dim: 256
+  num_class: 17982
+  project_type: arc_margin
+  scale: 32.0
+reverb_data: data/rirs/lmdb
+save_epoch_interval: 5
+scheduler: ExponentialDecrease
+scheduler_args:
+  epoch_iter: 4265
+  final_lr: 5.0e-05
+  initial_lr: 0.1
+  num_epochs: 150
+  scale_ratio: 4.0
+  warm_from_zero: true
+  warm_up_epoch: 6
+seed: 42
+train_data: data/vox2_dev/shard.list
+train_label: data/vox2_dev/utt2spk

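The `fbank_args` in this training config (80 mel bins, 25 ms frames, 10 ms shift, 16 kHz) line up with the features the inference code in this repo computes with torchaudio. A minimal sketch of producing matching features; the wav path is a placeholder, and `dither` is left at torchaudio's default since the inference code here does not set it:

```python
# Sketch: fbank features consistent with the core fbank_args above.
import torchaudio
import torchaudio.compliance.kaldi as kaldi

waveform, sr = torchaudio.load("example.wav")  # placeholder path
if sr != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(waveform)

feats = kaldi.fbank(
    waveform,                 # [1, T]
    num_mel_bins=80,
    frame_length=25,
    frame_shift=10,
    sample_frequency=16000,
)
print(feats.shape)            # [num_frames, 80]
```
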
models/wespeaker/chinese/config.yaml
ADDED
@@ -0,0 +1,7 @@
+model: cnceleb_resnet34_LM
+task: speaker_verification
+domain: speech
+framework: onnxruntime
+dataset: cnceleb
+language: chinese
+sample_rate: 16000

models/wespeaker/chinese/model.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7584940aeac8d5512d875e58ce6c09ba4ddad65d8128e1dac0d93aadd087ebb
+size 26530309

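Since the config above declares `framework: onnxruntime`, this exported model can be run without the wespeaker runtime. A hedged sketch; the tensor names are not asserted here, the session is inspected instead, and the feature batch is a random placeholder:

```python
# Sketch: run the ONNX speaker model with onnxruntime.
# Input/output tensor names are read from the session rather than assumed.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("models/wespeaker/chinese/model.onnx")
print([i.name for i in sess.get_inputs()], [o.name for o in sess.get_outputs()])

feats = np.random.rand(1, 200, 80).astype(np.float32)  # placeholder fbank batch
input_name = sess.get_inputs()[0].name
emb = sess.run(None, {input_name: feats})[0]
print(emb.shape)
```
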
python_backend/similarity_model/1/model.py
ADDED
@@ -0,0 +1,149 @@
+import io
+import math
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import traceback
+from torch.utils.dlpack import from_dlpack
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.sample_rate = 16000
+        self.feature_dim = 80
+        self.vad_enabled = True  # This variable is declared but not used.
+        self.min_duration = 0.1
+
+        # This seems correct for BLS (Business Logic Scripting)
+        self.speaker_model_name = "speaker_model"
+
+    def execute(self, requests):
+        responses = []
+        for request in requests:
+            try:
+                # 1. Get the input audio payload. The input tensor is of type
+                # TYPE_STRING, which holds bytes; .as_numpy()[0][0] gives the
+                # raw bytes object (here: a UTF-8 encoded file path).
+                audio1_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_1").as_numpy()[0][0]
+                audio2_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_2").as_numpy()[0][0]
+
+                # 2. Preprocess audio from bytes
+                feats1 = self.preprocess(audio1_bytes)
+                feats2 = self.preprocess(audio2_bytes)
+
+                # 3. Call the speaker_model to compute similarity
+                similarity = self.compute_similarity(feats1, feats2)
+
+                # Prepare output
+                output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
+                response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
+                responses.append(response)
+
+            except pb_utils.TritonModelException as e:
+                # If a Triton-specific error occurs, log it and create an error response
+                pb_utils.Logger.log_error(str(e))
+                error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(str(e)))
+                responses.append(error_response)
+            except Exception as e:
+                # For any other unexpected error, log it and return an error response
+                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
+                pb_utils.Logger.log_error(error_message)
+                error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message))
+                responses.append(error_response)
+
+        return responses
+
+    def preprocess(self, audio_bytes: bytes):
+        """
+        Processes the audio referenced by an in-memory byte payload.
+        If the audio is too short, it's padded by repetition to meet the minimum length.
+        """
+        try:
+            # The payload is a UTF-8 encoded file path; decode it for torchaudio.
+            # (To accept raw audio bytes instead, wrap them in a file-like object:
+            # buffer = io.BytesIO(audio_bytes))
+            buffer = audio_bytes.decode('utf-8')
+            waveform, sample_rate = torchaudio.load(buffer)
+
+            # Resample if the client's sample rate differs
+            if sample_rate != self.sample_rate:
+                # Note: This requires the 'torchaudio.transforms' module.
+                # Make sure torchaudio is fully installed in your Triton environment.
+                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
+                waveform = resampler(waveform)
+
+            duration = waveform.shape[1] / self.sample_rate
+
+            if duration < self.min_duration:
+                # Audio is too short, repeat it to meet the minimum duration
+                repeat_times = math.ceil(self.min_duration / duration)
+                waveform = waveform.repeat(1, repeat_times)
+
+            # --- THIS IS THE NEW, CRITICAL PART ---
+            # Calculate 80-dimensional Fbank features, which is what the speaker_model expects.
+            features = kaldi.fbank(
+                waveform,                        # Needs shape [1, T]
+                num_mel_bins=self.feature_dim,   # This is 80
+                sample_frequency=self.sample_rate,
+                frame_length=25,
+                frame_shift=10
+            )
+            # kaldi.fbank returns [num_frames, num_bins], e.g., [150, 80],
+            # which is the per-utterance layout the speaker model expects.
+            return features
+
+        except Exception as e:
+            # Raise a specific exception that can be caught in execute()
+            raise pb_utils.TritonModelException(f"Failed during audio preprocessing: {e}")
+
+    def compute_similarity(self, feats1, feats2):
+        # Call speaker_model to get an embedding for each fbank feature matrix
+        e1 = torch.from_numpy(self.call_speaker_model(feats1)).to("cuda")
+        e2 = torch.from_numpy(self.call_speaker_model(feats2)).to("cuda")
+
+        # Flatten the tensors
+        e1 = e1.flatten()
+        e2 = e2.flatten()
+
+        # Calculate cosine similarity
+        dot_product = torch.dot(e1, e2)
+        norm_e1 = torch.norm(e1)
+        norm_e2 = torch.norm(e2)
+
+        # Handle zero norms
+        if norm_e1 == 0 or norm_e2 == 0:
+            return 0.0
+
+        similarity = (dot_product / (norm_e1 * norm_e2)).item()
+
+        # Normalize from [-1, 1] to [0, 1]
+        return (similarity + 1) / 2
+
+    def call_speaker_model(self, feats):
+        """Calls the speaker_model to get an embedding vector."""
+        # Create the input tensor for the speaker_model.
+        # The name 'feats' here must match the input name in speaker_model's config.pbtxt
+        if feats.dim() == 2:
+            feats = feats.unsqueeze(0)  # add the batch dimension: [1, num_frames, 80]
+        input_tensor = pb_utils.Tensor("feats", feats.cpu().numpy().astype(np.float32))
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name=self.speaker_model_name,
+            requested_output_names=["embs"],  # Must match output name in speaker_model's config
+            inputs=[input_tensor]
+        )
+
+        inference_response = inference_request.exec()
+
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(f"Error from speaker_model: {inference_response.error().message()}")
+
+        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
+        if output_tensor.is_cpu():
+            output_tensor = output_tensor.as_numpy()
+        else:
+            output_tensor = from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()
+        return output_tensor

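If this BLS model is deployed, a request against it mirrors the `AUDIO_BYTES_1`/`AUDIO_BYTES_2` and `SIMILARITY` names used in the code above. A sketch, assuming a local server and noting that the model currently decodes each byte payload as a UTF-8 file path that must be readable by the Triton server; the wav paths are placeholders:

```python
# Sketch: call the Python-backend similarity_model over gRPC.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

def string_input(name, path):
    # TYPE_STRING tensors travel as BYTES; shape [1, 1] = batch dim + dims [1].
    arr = np.array([[path.encode("utf-8")]], dtype=object)
    inp = grpcclient.InferInput(name, list(arr.shape), "BYTES")
    inp.set_data_from_numpy(arr)
    return inp

inputs = [
    string_input("AUDIO_BYTES_1", "/path/to/a.wav"),  # placeholder paths
    string_input("AUDIO_BYTES_2", "/path/to/b.wav"),
]
resp = client.infer(model_name="similarity_model", inputs=inputs,
                    outputs=[grpcclient.InferRequestedOutput("SIMILARITY")])
print(resp.as_numpy("SIMILARITY").flatten()[0])
```
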
python_backend/similarity_model/1/model_old.py
ADDED
@@ -0,0 +1,97 @@
+import math
+import numpy as np
+import torchaudio
+import traceback
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.sample_rate = 16000
+        self.feature_dim = 80
+        self.vad_enabled = True
+        self.min_duration = 0.1
+
+        # Name of the speaker_model this model communicates with
+        self.speaker_model_name = "speaker_model"
+
+    def execute(self, requests):
+        responses = []
+        for request in requests:
+            # Get the input audio paths
+            audio1 = pb_utils.get_input_tensor_by_name(request, "AUDIO1").as_numpy()[0].decode('utf-8')
+            audio2 = pb_utils.get_input_tensor_by_name(request, "AUDIO2").as_numpy()[0].decode('utf-8')
+
+            # Preprocess the audio
+            feats1 = self.preprocess(audio1)
+            feats2 = self.preprocess(audio2)
+
+            # Call speaker_model to compute the similarity
+            similarity = self.compute_similarity(feats1, feats2)
+
+            # Prepare the output
+            output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
+            response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
+            responses.append(response)
+
+        return responses
+
+    def preprocess(self, audio_path):
+        """
+        Load an audio file; if it is too short, repeat it to meet the
+        minimum length requirement. Returns the processed waveform.
+        """
+        try:
+            waveform, sample_rate = torchaudio.load(audio_path)
+            duration = waveform.shape[1] / sample_rate
+
+            if duration >= self.min_duration:
+                # Audio is long enough; return the waveform as-is
+                return waveform
+
+            # Audio is too short and needs repetition
+            repeat_times = math.ceil(self.min_duration / duration)
+
+            # Repeat the audio
+            return waveform.repeat(1, repeat_times)
+
+        except Exception:
+            traceback.print_exc()
+            return None
+
+    def compute_similarity(self, feats1, feats2):
+        # Call speaker_model to get the embedding vectors
+        e1 = self.call_speaker_model(feats1).flatten()
+        e2 = self.call_speaker_model(feats2).flatten()
+
+        # Compute cosine similarity
+        dot_product = np.dot(e1, e2)
+        norm_e1 = np.linalg.norm(e1)
+        norm_e2 = np.linalg.norm(e2)
+        similarity = dot_product / (norm_e1 * norm_e2)
+
+        # Normalize to [0, 1]
+        return (similarity + 1) / 2
+
+    def call_speaker_model(self, features):
+        """Calls speaker_model to get an embedding vector."""
+        # Create the input tensor (features is a torch tensor here)
+        input_tensor = pb_utils.Tensor("feats", features.numpy().astype(np.float32))
+
+        # Create the inference request
+        inference_request = pb_utils.InferenceRequest(
+            model_name=self.speaker_model_name,
+            requested_output_names=["embs"],
+            inputs=[input_tensor]
+        )
+
+        # Send the request
+        inference_response = inference_request.exec()
+
+        # Handle the response
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(inference_response.error().message())
+
+        # Get the embedding vector
+        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
+        return output_tensor.as_numpy()

python_backend/similarity_model/1/model_runnable.py
ADDED
@@ -0,0 +1,149 @@
+import io
+import math
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import traceback
+from torch.utils.dlpack import from_dlpack
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.sample_rate = 16000
+        self.feature_dim = 80
+        self.vad_enabled = True  # This variable is declared but not used.
+        self.min_duration = 0.1
+
+        # This seems correct for BLS (Business Logic Scripting)
+        self.speaker_model_name = "speaker_model"
+
+    def execute(self, requests):
+        responses = []
+        for request in requests:
+            try:
+                # 1. Get the input audio payload. The input tensor is of type
+                # TYPE_STRING, which holds bytes; .as_numpy()[0][0] gives the
+                # raw bytes object (here: a UTF-8 encoded file path).
+                audio1_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_1").as_numpy()[0][0]
+                audio2_bytes = pb_utils.get_input_tensor_by_name(request, "AUDIO_BYTES_2").as_numpy()[0][0]
+
+                # 2. Preprocess audio from bytes
+                feats1 = self.preprocess(audio1_bytes)
+                feats2 = self.preprocess(audio2_bytes)
+
+                # 3. Call the speaker_model to compute similarity
+                similarity = self.compute_similarity(feats1, feats2)
+
+                pb_utils.Logger.log_info(f"similarity: {similarity}")
+                # Prepare output
+                output_tensor = pb_utils.Tensor("SIMILARITY", np.array([similarity], dtype=np.float32))
+                response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
+                responses.append(response)
+
+            except pb_utils.TritonModelException as e:
+                # If a Triton-specific error occurs, create an error response
+                error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(str(e)))
+                responses.append(error_response)
+            except Exception as e:
+                # For any other unexpected error, log it and return an error response
+                error_message = f"Unexpected error: {e}\n{traceback.format_exc()}"
+                pb_utils.Logger.log_error(error_message)
+                error_response = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_message))
+                responses.append(error_response)
+
+        return responses
+
+    def preprocess(self, audio_bytes: bytes):
+        """
+        Processes the audio referenced by an in-memory byte payload.
+        If the audio is too short, it's padded by repetition to meet the minimum length.
+        """
+        try:
+            # The payload is a UTF-8 encoded file path; decode it for torchaudio.
+            # (To accept raw audio bytes instead, wrap them in a file-like object:
+            # buffer = io.BytesIO(audio_bytes))
+            buffer = audio_bytes.decode('utf-8')
+            waveform, sample_rate = torchaudio.load(buffer)
+
+            # Resample if the client's sample rate differs
+            if sample_rate != self.sample_rate:
+                # Note: This requires the 'torchaudio.transforms' module.
+                # Make sure torchaudio is fully installed in your Triton environment.
+                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
+                waveform = resampler(waveform)
+
+            duration = waveform.shape[1] / self.sample_rate
+
+            if duration < self.min_duration:
+                # Audio is too short, repeat it to meet the minimum duration
+                repeat_times = math.ceil(self.min_duration / duration)
+                waveform = waveform.repeat(1, repeat_times)
+
+            # --- THIS IS THE NEW, CRITICAL PART ---
+            # Calculate 80-dimensional Fbank features, which is what the speaker_model expects.
+            features = kaldi.fbank(
+                waveform,                        # Needs shape [1, T]
+                num_mel_bins=self.feature_dim,   # This is 80
+                sample_frequency=self.sample_rate,
+                frame_length=25,
+                frame_shift=10
+            )
+            # kaldi.fbank returns [num_frames, num_bins], e.g., [150, 80],
+            # which is the per-utterance layout the speaker model expects.
+            return features
+
+        except Exception as e:
+            # Raise a specific exception that can be caught in execute()
+            raise pb_utils.TritonModelException(f"Failed during audio preprocessing: {e}")
+
+    def compute_similarity(self, feats1, feats2):
+        # Call speaker_model to get an embedding for each fbank feature matrix
+        e1 = torch.from_numpy(self.call_speaker_model(feats1)).to("cuda")
+        e2 = torch.from_numpy(self.call_speaker_model(feats2)).to("cuda")
+
+        # Flatten the tensors
+        e1 = e1.flatten()
+        e2 = e2.flatten()
+
+        # Calculate cosine similarity
+        dot_product = torch.dot(e1, e2)
+        norm_e1 = torch.norm(e1)
+        norm_e2 = torch.norm(e2)
+
+        # Handle zero norms
+        if norm_e1 == 0 or norm_e2 == 0:
+            return 0.0
+
+        similarity = (dot_product / (norm_e1 * norm_e2)).item()
+
+        # Normalize from [-1, 1] to [0, 1]
+        return (similarity + 1) / 2
+
+    def call_speaker_model(self, feats):
+        """Calls the speaker_model to get an embedding vector."""
+        # Create the input tensor for the speaker_model.
+        # The name 'feats' here must match the input name in speaker_model's config.pbtxt
+        if feats.dim() == 2:
+            feats = feats.unsqueeze(0)  # add the batch dimension: [1, num_frames, 80]
+        input_tensor = pb_utils.Tensor("feats", feats.cpu().numpy().astype(np.float32))
+
+        inference_request = pb_utils.InferenceRequest(
+            model_name=self.speaker_model_name,
+            requested_output_names=["embs"],  # Must match output name in speaker_model's config
+            inputs=[input_tensor]
+        )
+
+        inference_response = inference_request.exec()
+
+        if inference_response.has_error():
+            raise pb_utils.TritonModelException(f"Error from speaker_model: {inference_response.error().message()}")
+
+        output_tensor = pb_utils.get_output_tensor_by_name(inference_response, "embs")
+        if output_tensor.is_cpu():
+            output_tensor = output_tensor.as_numpy()
+        else:
+            output_tensor = from_dlpack(output_tensor.to_dlpack()).detach().cpu().numpy()
+        return output_tensor

python_backend/similarity_model/config.pbtxt.back
ADDED
@@ -0,0 +1,46 @@
+name: "similarity_model"
+backend: "python"
+max_batch_size: 128
+
+parameters: {
+  key: "EXECUTION_ENV_PATH",
+  value: {
+    # string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env/audio.tar.gz"
+    # string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env/audio_clean.tar.gz"
+    # string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env_tar/audio_env.tar.gz"
+    string_value: "/inspire/hdd/project/embodied-multimodality/public/cchang/env/mooncast/bin/python"
+  }
+}
+
+input [
+  {
+    name: "AUDIO1"
+    data_type: TYPE_STRING
+    dims: [ 1 ]  # audio path
+  },
+  {
+    name: "AUDIO2"
+    data_type: TYPE_STRING
+    dims: [ 1 ]  # audio path
+  }
+]
+
+output [
+  {
+    name: "SIMILARITY"
+    data_type: TYPE_FP32
+    dims: [ 1 ]  # similarity score
+  }
+]
+
+dynamic_batching {
+  preferred_batch_size: [ 16, 32 ]
+}
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_GPU
+  }
+]

python_backend/similarity_model/config.pbtxt.disabled
ADDED
@@ -0,0 +1,26 @@
+name: "similarity_model"  # Or whatever you call this model
+backend: "python"
+max_batch_size: 8
+
+# Input tensors are now raw audio bytes
+input [
+  {
+    name: "AUDIO_BYTES_1"
+    data_type: TYPE_STRING  # TYPE_STRING is used for variable-length binary data
+    dims: [ 1 ]
+  },
+  {
+    name: "AUDIO_BYTES_2"
+    data_type: TYPE_STRING
+    dims: [ 1 ]
+  }
+]
+
+# Output is a single similarity score
+output [
+  {
+    name: "SIMILARITY"
+    data_type: TYPE_FP32
+    dims: [ 1 ]
+  }
+]

similarity.py
ADDED
@@ -0,0 +1,412 @@
+# Copyright (c) 2023 Binbin Zhang ([email protected])
+#               Shuai Wang ([email protected])
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+from typing import List, Tuple
+
+import numpy as np
+from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import yaml
+import kaldiio
+from tqdm import tqdm
+
+from wespeaker.cli.hub import Hub
+from wespeaker.cli.utils import get_args
+from wespeaker.models.speaker_model import get_speaker_model
+from wespeaker.utils.checkpoint import load_checkpoint
+from wespeaker.diar.umap_clusterer import cluster
+from wespeaker.diar.extract_emb import subsegment
+from wespeaker.diar.make_rttm import merge_segments
+from wespeaker.utils.utils import set_seed
+
+
+class Speaker:
+
+    def __init__(self, model_dir: str):
+        set_seed()
+
+        config_path = os.path.join(model_dir, 'config.yaml')
+        model_path = os.path.join(model_dir, 'avg_model.pt')
+        with open(config_path, 'r') as fin:
+            configs = yaml.load(fin, Loader=yaml.FullLoader)
+        self.model = get_speaker_model(
+            configs['model'])(**configs['model_args'])
+        load_checkpoint(self.model, model_path)
+        self.model.eval()
+        self.vad = load_silero_vad()
+        self.table = {}
+        self.resample_rate = 16000
+        self.apply_vad = False
+        self.device = torch.device('cpu')
+        self.wavform_norm = False
+
+        # diarization params
+        self.diar_min_duration = 0.255
+        self.diar_window_secs = 1.5
+        self.diar_period_secs = 0.75
+        self.diar_frame_shift = 10
+        self.diar_batch_size = 32
+        self.diar_subseg_cmn = True
+
+    def set_wavform_norm(self, wavform_norm: bool):
+        self.wavform_norm = wavform_norm
+
+    def set_resample_rate(self, resample_rate: int):
+        self.resample_rate = resample_rate
+
+    def set_vad(self, apply_vad: bool):
+        self.apply_vad = apply_vad
+
+    def set_device(self, device: str):
+        self.device = torch.device(device)
+        self.model = self.model.to(self.device)
+
+    def set_diarization_params(self,
+                               min_duration: float = 0.255,
+                               window_secs: float = 1.5,
+                               period_secs: float = 0.75,
+                               frame_shift: int = 10,
+                               batch_size: int = 32,
+                               subseg_cmn: bool = True):
+        self.diar_min_duration = min_duration
+        self.diar_window_secs = window_secs
+        self.diar_period_secs = period_secs
+        self.diar_frame_shift = frame_shift
+        self.diar_batch_size = batch_size
+        self.diar_subseg_cmn = subseg_cmn
+
+    def compute_fbank(self,
+                      wavform,
+                      sample_rate=16000,
+                      num_mel_bins=80,
+                      frame_length=25,
+                      frame_shift=10,
+                      cmn=True):
+        feat = kaldi.fbank(wavform,
+                           num_mel_bins=num_mel_bins,
+                           frame_length=frame_length,
+                           frame_shift=frame_shift,
+                           sample_frequency=sample_rate,
+                           window_type='hamming')
+        if cmn:
+            feat = feat - torch.mean(feat, 0)
+        return feat
+
+    def extract_embedding_feats(self, fbanks, batch_size, subseg_cmn):
+        fbanks_array = np.stack(fbanks)
+        if subseg_cmn:
+            fbanks_array = fbanks_array - np.mean(
+                fbanks_array, axis=1, keepdims=True)
+        embeddings = []
+        fbanks_array = torch.from_numpy(fbanks_array).to(self.device)
+        for i in tqdm(range(0, fbanks_array.shape[0], batch_size)):
+            batch_feats = fbanks_array[i:i + batch_size]
+            with torch.no_grad():
+                batch_embs = self.model(batch_feats)
+                batch_embs = batch_embs[-1] if isinstance(
+                    batch_embs, tuple) else batch_embs
+            embeddings.append(batch_embs.detach().cpu().numpy())
+        embeddings = np.vstack(embeddings)
+        return embeddings
+
+    def extract_embedding(self, audio_path: str):
+        pcm, sample_rate = torchaudio.load(audio_path,
+                                           normalize=self.wavform_norm)
+        return self.extract_embedding_from_pcm(pcm, sample_rate)
+
+    def extract_embedding_from_pcm(self, pcm: torch.Tensor, sample_rate: int):
+        if self.apply_vad:
+            # TODO(Binbin Zhang): Refine the segments logic, here we just
+            # suppose there is only silence at the start/end of the speech
+            vad_sample_rate = 16000
+            wav = pcm
+            if wav.size(0) > 1:
+                wav = wav.mean(dim=0, keepdim=True)
+
+            if sample_rate != vad_sample_rate:
+                transform = torchaudio.transforms.Resample(
+                    orig_freq=sample_rate, new_freq=vad_sample_rate)
+                wav = transform(wav)
+            segments = get_speech_timestamps(wav,
+                                             self.vad,
+                                             return_seconds=True)
+            pcmTotal = torch.Tensor()
+            if len(segments) > 0:  # remove all the silence
+                for segment in segments:
+                    start = int(segment['start'] * sample_rate)
+                    end = int(segment['end'] * sample_rate)
+                    pcmTemp = pcm[0, start:end]
+                    pcmTotal = torch.cat([pcmTotal, pcmTemp], 0)
+                pcm = pcmTotal.unsqueeze(0)
+            else:  # all silence, no speech
+                return None
+        pcm = pcm.to(torch.float)
+        if sample_rate != self.resample_rate:
+            pcm = torchaudio.transforms.Resample(
+                orig_freq=sample_rate, new_freq=self.resample_rate)(pcm)
+        feats = self.compute_fbank(pcm,
+                                   sample_rate=self.resample_rate,
+                                   cmn=True)
+        feats = feats.unsqueeze(0)
+        feats = feats.to(self.device)
+
+        with torch.no_grad():
+            outputs = self.model(feats)
+            outputs = outputs[-1] if isinstance(outputs, tuple) else outputs
+        embedding = outputs[0].to(torch.device('cpu'))
+        return embedding
+
+    def extract_embedding_list(self, scp_path: str):
+        names = []
+        embeddings = []
+        with open(scp_path, 'r') as read_scp:
+            for line in tqdm(read_scp):
+                name, wav_path = line.strip().split()
+                names.append(name)
+                embedding = self.extract_embedding(wav_path)
+                embeddings.append(embedding.detach().numpy())
+        return names, embeddings
+
+    def compute_similarity(self, audio_path1: str, audio_path2: str) -> float:
+        e1 = self.extract_embedding(audio_path1)
+        e2 = self.extract_embedding(audio_path2)
+        if e1 is None or e2 is None:
+            return 0.0
+        else:
+            return self.cosine_similarity(e1, e2)
+
+    def compute_similarity_batch(
+            self, audio_pairs: List[Tuple[str, str]]) -> List[float]:
+        """
+        Computes cosine similarity for a batch of audio file pairs.
+        This method is optimized to extract the embedding for each unique
+        audio file only once.
+
+        Args:
+            audio_pairs (List[Tuple[str, str]]): A list of tuples, where each
+                tuple contains two audio paths.
+                e.g., [('audio1.wav', 'audio2.wav'),
+                       ('audio1.wav', 'audio3.wav')]
+
+        Returns:
+            List[float]: A list of similarity scores, corresponding to the
+                input pairs.
+        """
+        # 1. Collect all unique audio paths to avoid redundant computations
+        unique_audio_paths = set()
+        for path1, path2 in audio_pairs:
+            unique_audio_paths.add(path1)
+            unique_audio_paths.add(path2)
+
+        # 2. Extract embeddings for all unique files and store them in a cache
+        embedding_cache = {}
+        print(f"Extracting embeddings for {len(unique_audio_paths)} "
+              "unique audio files...")
+        for path in tqdm(list(unique_audio_paths)):
+            embedding_cache[path] = self.extract_embedding(path)
+
+        # 3. Compute similarity for each pair using the cached embeddings
+        scores = []
+        for path1, path2 in audio_pairs:
+            e1 = embedding_cache.get(path1)
+            e2 = embedding_cache.get(path2)
+
+            if e1 is None or e2 is None:
+                # Handle cases where embedding extraction failed (e.g., all
+                # silence)
+                scores.append(0.0)
+            else:
+                score = self.cosine_similarity(e1, e2)
+                scores.append(score)
+
+        return scores
+
+    def cosine_similarity(self, e1, e2):
+        cosine_score = torch.dot(e1, e2) / (torch.norm(e1) * torch.norm(e2))
+        cosine_score = cosine_score.item()
+        return (cosine_score + 1.0) / 2  # normalize: [-1, 1] => [0, 1]
+
+    def register(self, name: str, audio_path: str):
+        if name in self.table:
+            print('Speaker {} already registered, ignore'.format(name))
+        else:
+            self.table[name] = self.extract_embedding(audio_path)
+
+    def recognize(self, audio_path: str):
+        q = self.extract_embedding(audio_path)
+        best_score = 0.0
+        best_name = ''
+        for name, e in self.table.items():
+            score = self.cosine_similarity(q, e)
+            if best_score < score:
+                best_score = score
+                best_name = name
+        result = {}
+        result['name'] = best_name
+        result['confidence'] = best_score
+        return result
+
+    def diarize(self, audio_path: str, utt: str = "unk"):
+
+        pcm, sample_rate = torchaudio.load(audio_path, normalize=False)
+        # 1. vad
+        wav = read_audio(audio_path)
+        vad_segments = get_speech_timestamps(wav,
+                                             self.vad,
+                                             return_seconds=True)
+        if not vad_segments:
+            return []
+        # 2. extract fbanks
+        subsegs, subseg_fbanks = [], []
+        window_fs = int(self.diar_window_secs * 1000) // self.diar_frame_shift
+        period_fs = int(self.diar_period_secs * 1000) // self.diar_frame_shift
+        for item in vad_segments:
+            begin, end = item['start'], item['end']
+            if end - begin >= self.diar_min_duration:
+                begin_idx = int(begin * sample_rate)
+                end_idx = int(end * sample_rate)
+                tmp_wavform = pcm[0, begin_idx:end_idx].unsqueeze(0).to(
+                    torch.float)
+                fbank = self.compute_fbank(tmp_wavform,
+                                           sample_rate=sample_rate,
+                                           cmn=False)
+                tmp_subsegs, tmp_subseg_fbanks = subsegment(
+                    fbank=fbank,
+                    seg_id="{:08d}-{:08d}".format(int(begin * 1000),
+                                                  int(end * 1000)),
+                    window_fs=window_fs,
+                    period_fs=period_fs,
+                    frame_shift=self.diar_frame_shift)
+                subsegs.extend(tmp_subsegs)
+                subseg_fbanks.extend(tmp_subseg_fbanks)
+
+        # 3. extract embedding
+        embeddings = self.extract_embedding_feats(subseg_fbanks,
+                                                  self.diar_batch_size,
+                                                  self.diar_subseg_cmn)
+
+        # 4. cluster
+        subseg2label = []
+        labels = cluster(embeddings)
+        for (_subseg, _label) in zip(subsegs, labels):
+            # b, e = process_seg_id(_subseg, frame_shift=self.diar_frame_shift)
+            # subseg2label.append([b, e, _label])
+            begin_ms, end_ms, begin_frames, end_frames = _subseg.split('-')
+            begin = (int(begin_ms) +
+                     int(begin_frames) * self.diar_frame_shift) / 1000.0
+            end = (int(begin_ms) +
+                   int(end_frames) * self.diar_frame_shift) / 1000.0
+            subseg2label.append([begin, end, _label])
+
+        # 5. merge segments
+        # [[utt, ([begin, end, label], [])], [utt, ([], [])]]
+        merged_segment_to_labels = merge_segments({utt: subseg2label})
+
+        return merged_segment_to_labels
+
+    def diarize_list(self, scp_path: str):
+        utts = []
+        segment2labels = []
+        with open(scp_path, 'r', encoding='utf-8') as read_scp:
+            for line in tqdm(read_scp):
+                utt, wav_path = line.strip().split()
+                utts.append(utt)
+                segment2label = self.diarize(wav_path, utt)
+                segment2labels.append(segment2label)
+        return utts, segment2labels
+
+    def make_rttm(self, merged_segment_to_labels, outfile):
+        with open(outfile, 'w', encoding='utf-8') as fin:
+            for (utt, begin, end, label) in merged_segment_to_labels:
+                fin.write(
+                    "SPEAKER {} {} {:.3f} {:.3f} <NA> <NA> {} <NA> <NA>\n".
+                    format(utt, 1, float(begin),
+                           float(end) - float(begin), label))
+
+
+def load_model(language: str) -> Speaker:
+    model_path = Hub.get_model(language)
+    return Speaker(model_path)
+
+
+def load_model_local(model_dir: str) -> Speaker:
+    return Speaker(model_dir)
+
+
+def main():
+    args = get_args()
+    if args.pretrain == "":
+        if args.campplus:
+            model = load_model("campplus")
+            model.set_wavform_norm(True)
+        elif args.eres2net:
+            model = load_model("eres2net")
+            model.set_wavform_norm(True)
+        elif args.vblinkp:
+            model = load_model("vblinkp")
+        elif args.vblinkf:
+            model = load_model("vblinkf")
+        else:
+            model = load_model(args.language)
+    else:
+        model = load_model_local(args.pretrain)
+    model.set_resample_rate(args.resample_rate)
+    model.set_vad(args.vad)
+    model.set_device(args.device)
+    model.set_diarization_params(min_duration=args.diar_min_duration,
+                                 window_secs=args.diar_window_secs,
+                                 period_secs=args.diar_period_secs,
+                                 frame_shift=args.diar_frame_shift,
+                                 batch_size=args.diar_emb_bs,
+                                 subseg_cmn=args.diar_subseg_cmn)
+    if args.task == 'embedding':
+        embedding = model.extract_embedding(args.audio_file)
+        if embedding is not None:
+            np.savetxt(args.output_file, embedding.detach().numpy())
+            print('Succeed, see {}'.format(args.output_file))
+        else:
+            print('Fails to extract embedding')
+    elif args.task == 'embedding_kaldi':
+        names, embeddings = model.extract_embedding_list(args.wav_scp)
+        embed_ark = args.output_file + ".ark"
+        embed_scp = args.output_file + ".scp"
+        with kaldiio.WriteHelper('ark,scp:' + embed_ark + "," +
+                                 embed_scp) as writer:
+            for name, embedding in zip(names, embeddings):
+                writer(name, embedding)
+    elif args.task == 'similarity':
+        print(model.compute_similarity(args.audio_file, args.audio_file2))
+    elif args.task == 'diarization':
+        diar_result = model.diarize(args.audio_file)
+        if args.output_file is None:
+            for (_, start, end, spkid) in diar_result:
+                print("{:.3f}\t{:.3f}\t{:d}".format(start, end, spkid))
+        else:
+            model.make_rttm(diar_result, args.output_file)
+    elif args.task == 'diarization_list':
+        utts, segment2labels = model.diarize_list(args.wav_scp)
+        assert args.output_file is not None
+        model.make_rttm(np.vstack(segment2labels), args.output_file)
+    else:
+        print('Unsupported task {}'.format(args.task))
+        sys.exit(-1)
+
+
+if __name__ == '__main__':
+    main()

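A quick usage sketch of the `Speaker` wrapper above, used directly instead of via the CLI; the model directory and wav paths are placeholders:

```python
# Sketch: direct use of the Speaker wrapper from similarity.py.
from similarity import load_model_local

model = load_model_local("models/voxblink2_samresnet100_ft")
model.set_device("cuda:0")   # or "cpu"
model.set_vad(True)          # trim silence with silero-vad before embedding

score = model.compute_similarity("a.wav", "b.wav")  # placeholder paths
print(f"normalized cosine similarity: {score:.4f}")  # in [0, 1]
```
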
speaker_client.py
ADDED
@@ -0,0 +1,149 @@
+# new_client.py
+import argparse
+import asyncio
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+import tritonclient.grpc.aio as grpcclient
+import sys
+import time
+import math
+
+
+class TritonSpeakerClient:
+    def __init__(self, url, model_name="speaker_model", verbose=False):
+        try:
+            self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
+        except Exception as e:
+            print(f"Channel creation failed: {e}", file=sys.stderr)
+            sys.exit(1)
+
+        self.model_name = model_name
+
+        # --- Preprocessing parameters migrated from the old similarity_model ---
+        self.sample_rate = 16000
+        self.feature_dim = 80
+        self.min_duration = 0.1
+        # ----------------------------------------------------
+
+    def _preprocess_audio(self, audio_path: str):
+        """
+        Load and preprocess audio from a file path, producing fbank features.
+        This logic is copied from the preprocess method of the old
+        similarity_model.py.
+        """
+        try:
+            waveform, sample_rate = torchaudio.load(audio_path)
+
+            # Resample if the sample rate does not match
+            if sample_rate != self.sample_rate:
+                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
+                waveform = resampler(waveform)
+
+            # If the audio is too short, repeat it to meet the minimum length
+            duration = waveform.shape[1] / self.sample_rate
+            if duration < self.min_duration:
+                repeat_times = math.ceil(self.min_duration / duration)
+                waveform = waveform.repeat(1, repeat_times)
+
+            # Compute 80-dim fbank features.
+            # waveform must be [1, time], so collapse multiple channels first
+            if waveform.shape[0] > 1:
+                waveform = torch.mean(waveform, dim=0, keepdim=True)  # downmix to mono
+
+            features = kaldi.fbank(
+                waveform,
+                num_mel_bins=self.feature_dim,
+                sample_frequency=self.sample_rate,
+                frame_length=25,
+                frame_shift=10
+            )
+            # fbank returns [num_frames, num_bins]; hand back float32 numpy
+            return features.numpy().astype(np.float32)
+
+        except Exception as e:
+            raise RuntimeError(f"Failed during audio preprocessing for {audio_path}: {e}")
+
+    def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray):
+        """Compute cosine similarity on the client."""
+        e1 = torch.from_numpy(emb1).flatten()
+        e2 = torch.from_numpy(emb2).flatten()
+
+        similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)
+
+        # Normalize the similarity from [-1, 1] to [0, 1]
+        return (similarity.item() + 1) / 2
+
+    async def compute_similarity(self, audio1_path: str, audio2_path: str):
+        """
+        Compute the similarity of two audio files.
+        This function covers the full pipeline:
+        preprocessing -> batching -> inference -> postprocessing.
+        """
+        # 1. Preprocess both audio files on the client
+        feats1 = self._preprocess_audio(audio1_path)
+        feats2 = self._preprocess_audio(audio2_path)
+
+        # 2. Batching: to use Triton's dynamic batching, pack both feature
+        # matrices into one request. Since their lengths (frame counts) may
+        # differ, pad them to the same length.
+        max_len = max(feats1.shape[0], feats2.shape[0])
+
+        # Pad with np.pad
+        padded_feats1 = np.pad(feats1, ((0, max_len - feats1.shape[0]), (0, 0)), 'constant', constant_values=0)
+        padded_feats2 = np.pad(feats2, ((0, max_len - feats2.shape[0]), (0, 0)), 'constant', constant_values=0)
+
+        # Stack the padded features into one batch
+        input_batch = np.stack([padded_feats1, padded_feats2])  # Shape: [2, max_len, 80]
+
+        # 3. Create the Triton input tensor.
+        # The input name "feats" must match the input name in speaker_model's config.pbtxt
+        inputs = [
+            grpcclient.InferInput("feats", input_batch.shape, "FP32")
+        ]
+        inputs[0].set_data_from_numpy(input_batch)
+
+        # 4. Set the requested outputs.
+        # The output name "embs" must match the output name in speaker_model's config.pbtxt
+        outputs = [grpcclient.InferRequestedOutput("embs")]
+
+        # 5. Send the inference request
+        response = await self.triton_client.infer(
+            model_name=self.model_name,
+            inputs=inputs,
+            outputs=outputs
+        )
+
+        # 6. Parse the result
+        embeddings_batch = response.as_numpy("embs")  # Shape: [2, embedding_dim]
+        emb1 = embeddings_batch[0]
+        emb2 = embeddings_batch[1]
+
+        # 7. Compute the similarity on the client
+        similarity = self._calculate_cosine_similarity(emb1, emb2)
+        return similarity
+
+async def main():
+    parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
+    parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Enable verbose output')
+    parser.add_argument('-u', '--url', type=str, default='localhost:8001', help='Inference server URL.')
+    # Note: model_name here should be speaker_model
+    parser.add_argument('--model_name', default='speaker_model', help='The name of the speaker embedding model on Triton.')
+    parser.add_argument('--audio_file1', type=str, required=True, help='Path to first audio file')
+    parser.add_argument('--audio_file2', type=str, required=True, help='Path to second audio file')
+
+    FLAGS = parser.parse_args()
+
+    client = TritonSpeakerClient(FLAGS.url, FLAGS.model_name, verbose=FLAGS.verbose)
+
+    start_time = time.time()
+    try:
+        similarity = await client.compute_similarity(FLAGS.audio_file1, FLAGS.audio_file2)
+        elapsed = time.time() - start_time
+        print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
+    except Exception as e:
+        print(f"Error computing similarity: {e}", file=sys.stderr)
+        sys.exit(1)
+
+# Usage example:
+# python speaker_client.py --audio_file1=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi1.wav --audio_file2=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi2.wav
+if __name__ == '__main__':
+    asyncio.run(main())

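One caveat in the client above: zero-padding the shorter feature matrix before stacking means padded frames flow into the model's pooling, so the resulting embedding can differ slightly from unpadded inference. A hedged alternative sketch, one utterance per request using the same aio API, letting Triton's dynamic batcher group requests server-side:

```python
# Sketch: embed each utterance separately instead of zero-padding a pair.
import numpy as np
import tritonclient.grpc.aio as grpcclient

async def embed(client, feats):
    # feats: [num_frames, 80] float32; only the batch dimension is added.
    batch = feats[np.newaxis, ...]
    inp = grpcclient.InferInput("feats", list(batch.shape), "FP32")
    inp.set_data_from_numpy(batch)
    resp = await client.infer(model_name="speaker_model", inputs=[inp],
                              outputs=[grpcclient.InferRequestedOutput("embs")])
    return resp.as_numpy("embs")[0]
```
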
test.py
ADDED
|
@@ -0,0 +1,1643 @@
import json
import re
import os
from typing import List, Dict, Tuple, Any
import numpy as np
from pathlib import Path
import torch
import torchaudio
import torchaudio.functional as F
import logging
import wespeaker
import shutil
from datetime import datetime
import multiprocessing as mp
from functools import partial
import math
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import random  # random is used to shuffle the input data

# Set the multiprocessing start method to spawn (CUDA-compatible)
mp.set_start_method('spawn', force=True)

# Word-alignment module
from alignment import AlignmentModel, batch_get_alignment_result

class SpeakerSimilarityEvaluator:
    """Speaker (timbre) similarity evaluator"""

    def __init__(self, device="cuda",
                 alignment_model_dir='/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/models/mms_fa',
                 wespeaker_model_dir='/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft',
                 output_dir="./evaluation_results",
                 language="ZH",
                 similarity_max_workers=8):
        """Initialize the evaluator"""
        self.device = device
        self.alignment_model_dir = alignment_model_dir
        self.wespeaker_model_dir = wespeaker_model_dir
        self.language = language.upper()  # language setting
        self.similarity_max_workers = similarity_max_workers  # worker threads for similarity computation

        # Set up logging first
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Output directory layout
        self.output_dir = Path(output_dir)
        self.segments_dir = self.output_dir / "segments"  # split audio segments
        self.prompts_dir = self.output_dir / "prompts"  # S1/S2 pieces of the prompt audio
        self.temp_dir = self.output_dir / "temp"  # temporary files
        self.results_dir = self.output_dir / "results"  # evaluation results
        self.temp_results_dir = self.output_dir / "temp_results"  # temporary result files
        self.alignment_dir = self.output_dir / "alignments"  # saved alignment info

        # Create every required directory
        self._create_output_directories()

        # Defer model initialization (needed in multiprocessing environments)
        self.alignment_model = None
        self.similarity_model = None

        # Thread-local storage for thread-safe model access
        self._thread_local = threading.local()

        # Log run info
        self.logger.info(f"Evaluation results will be saved to: {self.output_dir}")
        self.logger.info(f"Alignment info will be saved to: {self.alignment_dir}")
        self.logger.info(f"Language: {self.language}")

    def _create_output_directories(self):
        """Create the output directory structure"""
        for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir,
                         self.results_dir, self.temp_results_dir, self.alignment_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

    def _get_safe_filename(self, text: str, max_length: int = 50) -> str:
        """Generate a filesystem-safe filename"""
        # Strip special characters; keep only CJK, word characters and whitespace
        safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text)
        # Cap the length
        if len(safe_text) > max_length:
            safe_text = safe_text[:max_length]
        # Replace spaces with underscores
        safe_text = safe_text.replace(' ', '_')
        return safe_text if safe_text else "unnamed"

    def _clean_temp_files(self):
        """Delete temporary files but keep the temp directory itself"""
        if self.temp_dir.exists():
            # Only remove files inside the temp dir, not the dir itself
            for file_path in self.temp_dir.iterdir():
                if file_path.is_file():
                    try:
                        file_path.unlink()
                    except Exception as e:
                        self.logger.warning(f"Failed to delete temp file: {file_path}, error: {e}")
        else:
            # Recreate the temp dir if it is missing
            self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _init_models_if_needed(self):
        """Lazily initialize the models (for multiprocessing environments)"""
        # Initialize the alignment model - note the corrected argument order
        if self.alignment_model is None:
            # Per AlignmentModel's constructor the order is (device, model_dir), not (model_dir, device)
            self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)

        # Initialize the similarity model
        if self.similarity_model is None:
            self._load_wespeaker_model(self.wespeaker_model_dir)

    def _is_english_text(self, text: str) -> bool:
        """Rough check for whether the text is mostly English"""
        # Ratio of ASCII letters among all alphabetic characters
        english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
        total_chars = sum(1 for c in text if c.isalpha())

        if total_chars == 0:
            return False

        return english_chars / total_chars > 0.8  # treat as English when more than 80% of the letters are ASCII
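    # Illustrative example (not from the repo): in "Hello 世界" 5 of the 7
    # alphabetic characters are ASCII (~0.71, below the 0.8 threshold), so the
    # text is classified as Chinese.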

    def _detect_language_from_text(self, text: str) -> str:
        """Detect the language from the text content"""
        clean_text = self.remove_speaker_tags(text)
        if self._is_english_text(clean_text):
            return "EN"
        else:
            return "ZH"

    def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str, file_type: str = "output"):
        """
        Save alignment info to a standalone JSON file

        Args:
            alignment_data: the alignment data
            input_id: input ID
            file_type: file type ("output", "prompt", "segment")
        """
        try:
            safe_input_id = self._get_safe_filename(input_id)
            alignment_filename = f"{safe_input_id}_{file_type}_alignment.json"
            alignment_path = self.alignment_dir / alignment_filename

            # Attach metadata
            alignment_info = {
                'input_id': input_id,
                'file_type': file_type,
                'language': self.language,
                'timestamp': datetime.now().isoformat(),
                'alignment_data': alignment_data
            }

            with open(alignment_path, 'w', encoding='utf-8') as f:
                json.dump(alignment_info, f, ensure_ascii=False, indent=2)

            self.logger.info(f"Alignment info saved: {alignment_path}")
            return str(alignment_path)

        except Exception as e:
            self.logger.error(f"Failed to save alignment info: {e}")
            return None

    def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
                                     text_segments: List[Dict[str, Any]],
                                     input_id: str, audio_path: str,
                                     original_text: str, processed_text: str):
        """
        Save detailed alignment info, including segmentation info

        Args:
            alignments: list of alignment results
            text_segments: text segmentation info
            input_id: input ID
            audio_path: audio file path
            original_text: the original text
            processed_text: the processed text
        """
        alignment_data = {
            'original_text': original_text,
            'processed_text': processed_text,
            'audio_path': audio_path,
            'language': self.language,
            'total_alignments': len(alignments),
            'total_segments': len(text_segments),
            'alignments': alignments,
            'text_segments': text_segments,
            'segment_alignment_mapping': []
        }

        # Build the mapping between text segments and alignment results
        for segment in text_segments:
            segment_mapping = {
                'segment_id': segment.get('segment_id', 0),
                'segment_text': segment.get('text', ''),
                'speaker_label': segment.get('speaker_label', ''),
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 0.0),
                'corresponding_alignments': []
            }

            # Find the alignment items that belong to this segment
            segment_start = segment.get('start_time', 0.0)
            segment_end = segment.get('end_time', 0.0)

            for i, align_item in enumerate(alignments):
                align_start = align_item.get('start', 0.0)
                align_end = align_item.get('end', 0.0)

                # Check whether the alignment item overlaps this segment's time range
                if (align_start >= segment_start and align_end <= segment_end) or \
                   (align_start < segment_end and align_end > segment_start):
                    segment_mapping['corresponding_alignments'].append({
                        'alignment_index': i,
                        'transcript': align_item.get('transcript', ''),
                        'start': align_start,
                        'end': align_end,
                        'score': align_item.get('score', 0.0) if 'score' in align_item else None
                    })

            alignment_data['segment_alignment_mapping'].append(segment_mapping)

        return self.save_alignment_info(alignment_data, input_id, "detailed")

    def remove_speaker_tags(self, text: str) -> str:
        """Remove the [S1][S2] speaker tags from the text"""
        return re.sub(r'\[S[12]\]', '', text).strip()

    def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]:
        """Extract the per-speaker segments from the text"""
        segments = []
        pattern = r'\[S([12])\]([^[]*)'
        matches = re.findall(pattern, text)

        for speaker_id, content in matches:
            segments.append({
                'speaker': f'S{speaker_id}',
                'content': content.strip()
            })
        return segments
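    # Illustrative example (not from the repo):
    #   extract_speaker_segments("[S1]你好。[S2]再见。")
    #   -> [{'speaker': 'S1', 'content': '你好。'}, {'speaker': 'S2', 'content': '再见。'}]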

    def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
        """Replace all punctuation with commas, collapse runs of commas into one, and pick the comma type for the language"""
        # If no language is given, use the instance default or auto-detect
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)

        language = language.upper()

        # Choose the comma type and strategy by language
        if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
            # English: drop apostrophes first, then replace the remaining punctuation
            text = re.sub(r"'", '', text)  # drop apostrophes (don't -> dont)
            target_comma = ','  # ASCII comma
            comma_pattern = r',+'  # runs of ASCII commas
            # Replacement regex; the apostrophe is deliberately excluded
            text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
        else:
            # Chinese: the apostrophe is included in the replacement set
            target_comma = ','  # full-width comma
            comma_pattern = r',+'  # runs of full-width commas
            # Replacement regex covering a wider punctuation set
            text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)

        text = re.sub(comma_pattern, target_comma, text)
        return text.strip(target_comma)
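    # Illustrative example (not from the repo): with language="ZH",
    #   replace_punctuation_with_comma("你好!今天……天气。")
    # maps every mark to a full-width comma, collapses the run from "……", and
    # strips the trailing comma, yielding "你好,今天,天气".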

    def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
        """
        Word-align text and audio
        Returns the audio time span of each word
        """
        # Make sure the models are initialized
        self._init_models_if_needed()

        # If no language is given, use the instance default or auto-detect
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)
        else:
            language = language.upper()

        # Load the audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to the sample rate the model expects
        if sample_rate != self.alignment_model.bundle.sample_rate:
            waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)

        # Convert to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        waveform = waveform.squeeze(0)  # drop the channel dimension

        # Move the audio to the right device
        waveform = waveform.to(self.device)

        # Run the alignment
        try:
            alignment_results = batch_get_alignment_result(
                self.alignment_model,
                [waveform],
                [text],
                [language]
            )
            if not alignment_results or not alignment_results[0]:
                raise RuntimeError(f"Empty alignment result: {audio_path}")
            return alignment_results[0]
        except Exception as e:
            self.logger.error(f"Audio alignment failed: {audio_path}")
            self.logger.error(f"Error details: {e}")
            raise RuntimeError(f"Audio alignment failed, aborting. File: {audio_path}, error: {e}")

    def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str):
        """Cut one audio segment"""
        waveform, sample_rate = torchaudio.load(audio_path)

        start_frame = int(start_time * sample_rate)
        end_frame = int(end_time * sample_rate)

        segment = waveform[:, start_frame:end_frame]

        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        torchaudio.save(output_path, segment, sample_rate)
        return output_path

    def concatenate_audio_files(self, audio_files: List[str], output_path: str):
        """Concatenate several audio files"""
        if not audio_files:
            return

        waveforms = []
        sample_rate = None

        for audio_file in audio_files:
            if os.path.exists(audio_file):
                waveform, sr = torchaudio.load(audio_file)
                if sample_rate is None:
                    sample_rate = sr
                elif sr != sample_rate:
                    waveform = F.resample(waveform, sr, sample_rate)
                waveforms.append(waveform)

        if waveforms:
            concatenated = torch.cat(waveforms, dim=1)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, concatenated, sample_rate)

    def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
        """
        Split the prompt audio by speaker tag
        Returns the paths of the S1 and S2 audio pieces
        """
        # 1. Extract the speaker segments
        speaker_segments = self.extract_speaker_segments(prompt_text)

        # 2. Strip the tags and run word alignment - raises on failure
        clean_text = self.remove_speaker_tags(prompt_text)

        # Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)

        # Save the prompt alignment info
        prompt_alignment_data = {
            'original_text': prompt_text,
            'clean_text': clean_text,
            'audio_path': prompt_audio,
            'language': alignment_language,
            'speaker_segments': speaker_segments,
            'alignments': alignments
        }
        self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")

        # 3. Split the audio according to the alignment result
        s1_segments = []
        s2_segments = []

        # Find the time span of each speaker segment
        text_pos = 0
        for seg in speaker_segments:
            seg_text = seg['content'].strip()
            seg_length = len(seg_text)

            # Locate this segment's start and end in the alignment result
            start_time = None
            end_time = None

            current_pos = 0
            for align_item in alignments:
                item_text = align_item['transcript']
                item_length = len(item_text)

                if current_pos >= text_pos and current_pos < text_pos + seg_length:
                    if start_time is None:
                        start_time = align_item['start']
                    end_time = align_item['end']

                current_pos += item_length

            if start_time is not None and end_time is not None:
                if seg['speaker'] == 'S1':
                    s1_segments.append((start_time, end_time))
                else:
                    s2_segments.append((start_time, end_time))

            text_pos += seg_length

        # 4. Cut and concatenate the audio pieces
        safe_audio_id = self._get_safe_filename(audio_id)
        prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
        prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")

        # Cut every S1 piece
        if s1_segments:
            s1_temp_segments = []
            for i, (start, end) in enumerate(s1_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s1_temp_segments.append(temp_path)

            # Concatenate the S1 pieces
            self.concatenate_audio_files(s1_temp_segments, prompts1_path)

        # Cut every S2 piece
        if s2_segments:
            s2_temp_segments = []
            for i, (start, end) in enumerate(s2_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s2_temp_segments.append(temp_path)

            # Concatenate the S2 pieces
            self.concatenate_audio_files(s2_temp_segments, prompts2_path)

        return prompts1_path, prompts2_path

    def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]:
        """
        Split the original text by speaker and by punctuation at the same time, keeping the mapping
        Supports word-level handling for English
        """
        segments = []
        pattern = r'\[S([12])\]([^[]*)'
        matches = re.findall(pattern, original_text)

        # Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(original_text)

        segment_id = 0
        for speaker_id, content in matches:
            speaker = f'S{speaker_id}'
            clean_content = content.strip()
            comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language)

            # Split on the comma type that matches the language
            if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)):
                # English: split on ASCII commas, keeping words intact
                parts = [part.strip() for part in comma_content.split(',') if part.strip()]
            else:
                # Chinese: split on full-width commas
                parts = [part.strip() for part in comma_content.split(',') if part.strip()]

            for part in parts:
                if part.strip():
                    segments.append({
                        'segment_id': segment_id,
                        'text': part.strip(),
                        'speaker_label': speaker,
                        'original_speaker_content': clean_content
                    })
                    segment_id += 1

        return segments

    def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
        """
        Split the output audio at commas and return the info for each piece - sentences are delimited by the punctuation in the word-alignment result
        """
        # 1. Get the text segments and their speakers (for the speaker labels)
        text_segments = self.map_text_segments_to_speakers(text)

        # 2. Strip the tags and replace punctuation
        clean_text = self.remove_speaker_tags(text)

        # 3. Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        # Replace punctuation according to the detected language
        comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)

        # 4. Word alignment - raises on failure
        alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)

        # 5. Delimit sentences at punctuation
        segments = []
        safe_audio_id = self._get_safe_filename(audio_id)

        # Decide the punctuation set (per language; English excludes the apostrophe)
        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
            punctuation_chars = set([',', '.', '!', '?', ';', ':'])  # no apostrophe
        else:
            punctuation_chars = set([',', '。', '!', '?', ';', ':'])

        # Scan the alignment result in order and close a sentence at each punctuation mark
        sentence_start_idx = 0
        sentence_alignments = []
        segment_id = 0

        for i, align_item in enumerate(alignments):
            transcript = align_item['transcript']
            sentence_alignments.append(align_item)

            # Check for punctuation (end-of-sentence marker)
            has_punctuation = any(punct in transcript for punct in punctuation_chars)

            if has_punctuation or i == len(alignments) - 1:  # punctuation hit, or last word
                # Build a sentence segment
                if sentence_alignments:
                    # Sentence start and end times
                    start_time = sentence_alignments[0]['start']
                    end_time = sentence_alignments[-1]['end']

                    # Build the sentence text (with punctuation removed)
                    sentence_text_parts = []
                    for align in sentence_alignments:
                        # Pick the cleanup strategy by language
                        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                            # English: strip punctuation (apostrophes were already removed from the words)
                            clean_transcript = align['transcript'].rstrip(',.!?;:')
                        else:
                            # Chinese: strip Chinese punctuation
                            clean_transcript = align['transcript'].rstrip(',。!?;:')

                        if clean_transcript.strip():
                            sentence_text_parts.append(clean_transcript)

                    # Join according to the language
                    if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                        sentence_text = ' '.join(sentence_text_parts).strip()  # English: join with spaces
                    else:
                        sentence_text = ''.join(sentence_text_parts).strip()  # Chinese: join directly

                    if sentence_text:  # only handle non-empty sentences
                        # Determine the speaker label (from the original text_segments when possible)
                        speaker_label = "S1"  # default
                        if segment_id < len(text_segments):
                            speaker_label = text_segments[segment_id]['speaker_label']
                        elif text_segments:
                            # Out of range: reuse the last segment's speaker
                            speaker_label = text_segments[-1]['speaker_label']

                        # Build the audio file path
                        safe_text = self._get_safe_filename(sentence_text, 30)
                        audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")

                        # Cut the audio
                        try:
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)
                        except Exception as e:
                            self.logger.error(f"Audio split failed: {e}")
                            # Fall back to a default time span
                            start_time = segment_id * 1.0
                            end_time = (segment_id + 1) * 1.0
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)

                        # Create the segment
                        segment = {
                            'segment_id': segment_id,
                            'text': sentence_text,
                            'speaker_label': speaker_label,
                            'original_speaker_content': sentence_text,  # simplified here
                            'audio_path': audio_path,
                            'start_time': start_time,
                            'end_time': end_time
                        }

                        segments.append(segment)

                        self.logger.info(f"Sentence {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
                        segment_id += 1

                    # Reset for the next sentence
                    sentence_alignments = []
                    sentence_start_idx = i + 1

        # Save the detailed alignment info
        self.save_detailed_alignment_info(
            alignments, segments, audio_id, output_audio, text, comma_text
        )

        self.logger.info(f"Split out {len(segments)} sentence segments in total")
        return segments

    def _get_thread_local_similarity_model(self):
        """Get the thread-local similarity model instance (thread-safe)"""
        if not hasattr(self._thread_local, 'similarity_model'):
            # Create an independent model instance for the current thread
            self._thread_local.similarity_model = self._create_similarity_model()
        return self._thread_local.similarity_model

    def _create_similarity_model(self):
        """Create a new similarity model instance"""
        try:
            import wespeaker

            # Same loading logic as the main model
            local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'

            try:
                model = wespeaker.load_model_local(local_model_path)
                return model
            except Exception as e:
                self.logger.warning(f"Failed to load the specified local model: {e}")

            # Fallback
            if os.path.exists(self.wespeaker_model_dir):
                try:
                    model = wespeaker.load_model_local(self.wespeaker_model_dir)
                    return model
                except Exception as e:
                    self.logger.warning(f"Failed to load the provided local model: {e}")

            # Final fallback to the pretrained models
            try:
                model = wespeaker.load_model('chinese')
                return model
            except Exception as e:
                model = wespeaker.load_model('english')
                return model

        except Exception as e:
            self.logger.error(f"Failed to create the similarity model: {e}")
            raise

    def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
        """
        Thread-safe timbre similarity computation
        Audio clips that are too short are tiled (repeated) up to the minimum length
        """
        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Get the thread-local model instance
            similarity_model = self._get_thread_local_similarity_model()

            # Check and fix the audio file length
            def process_audio_for_similarity(audio_path, min_duration=0.1):
                """
                If the audio is too short, tile it up to the minimum length
                Returns the processed path and whether it is a temporary file
                """
                try:
                    waveform, sample_rate = torchaudio.load(audio_path)
                    duration = waveform.shape[1] / sample_rate

                    if duration >= min_duration:
                        # Long enough; return the original path
                        return audio_path, False

                    # Too short; tile it
                    repeat_times = math.ceil(min_duration / duration)
                    thread_id = threading.get_ident()

                    # Tile the audio
                    repeated_waveform = waveform.repeat(1, repeat_times)

                    # Temp file path (the thread ID avoids collisions)
                    temp_filename = f"temp_{thread_id}_{os.path.basename(audio_path)}"
                    temp_path = str(self.temp_dir / temp_filename)

                    # Save the tiled audio
                    torchaudio.save(temp_path, repeated_waveform, sample_rate)

                    return temp_path, True

                except Exception as e:
                    self.logger.error(f"Failed to process audio file: {audio_path}, error: {e}")
                    return audio_path, False

            # Process both audio files
            processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
            processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)

            # Compute the similarity
            similarity = similarity_model.compute_similarity(processed_audio1, processed_audio2)

            # Clean up the temp files
            if is_temp1 and os.path.exists(processed_audio1):
                try:
                    os.remove(processed_audio1)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio1}, error: {e}")

            if is_temp2 and os.path.exists(processed_audio2):
                try:
                    os.remove(processed_audio2)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio2}, error: {e}")

            return float(similarity)

        except Exception as e:
            # A window-size error means the clip is still too short
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"Audio clip still too short to compute similarity: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None
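    # Illustrative example (not from the repo) of the tiling rule above: a
    # 0.03 s clip with min_duration=0.1 is repeated ceil(0.1 / 0.03) = 4
    # times, producing a 0.12 s temporary file before scoring.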

    def calculate_segment_similarities_parallel(self, output_segments: List[Dict[str, Any]],
                                                prompts1_path: str, prompts2_path: str) -> List[Dict[str, Any]]:
        """
        Compute the similarities of all segments in parallel
        Args:
            output_segments: list of audio segments
            prompts1_path: path of the S1 prompt audio
            prompts2_path: path of the S2 prompt audio
        Returns:
            the segment list with similarity info attached
        """

        def calculate_single_segment_similarity(segment):
            """Compute one segment's similarity against both prompts"""
            try:
                # Thread-safe similarity computation
                sim1 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path)
                sim2 = self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path)

                return {
                    'segment': segment,
                    'sim1': sim1,
                    'sim2': sim2,
                    'success': True
                }
            except Exception as e:
                self.logger.error(f"Failed to compute similarity for segment {segment['segment_id']}: {e}")
                return {
                    'segment': segment,
                    'sim1': None,
                    'sim2': None,
                    'success': False
                }

        # Process all segments in parallel with a thread pool
        self.logger.info(f"Starting parallel similarity computation for {len(output_segments)} segments with {self.similarity_max_workers} threads")

        results = []
        with ThreadPoolExecutor(max_workers=self.similarity_max_workers) as executor:
            # Submit every segment task
            future_to_segment = {
                executor.submit(calculate_single_segment_similarity, segment): segment
                for segment in output_segments
            }

            # Collect the results (keeping the original order)
            segment_to_result = {}
            completed_count = 0
            for future in as_completed(future_to_segment):
                result = future.result()
                segment_id = result['segment']['segment_id']
                segment_to_result[segment_id] = result
                completed_count += 1

                # Report progress every 10 segments
                if completed_count % 10 == 0 or completed_count == len(output_segments):
                    self.logger.info(f"Similarity progress: {completed_count}/{len(output_segments)}")

        # Return the results in segment_id order
        for segment in output_segments:
            segment_id = segment['segment_id']
            if segment_id in segment_to_result:
                results.append(segment_to_result[segment_id])

        return results

    def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
        """Evaluate the timbre similarity of a single input"""

        # Generate the input ID
        if input_id is None:
            input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        self.logger.info(f"Start evaluating input: {input_id}, language: {self.language}")

        # 1. Get or split the prompt audio
        prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")

        # 2. Split the output audio (detailed alignment info is saved here)
        output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")

        # 3. Compute the per-segment similarities in parallel
        similarity_results = self.calculate_segment_similarities_parallel(
            output_segments, prompts1_path, prompts2_path
        )

        # 4. Process the similarity results
        segment_results = []
        correct_predictions = 0
        total_segments = 0  # counts valid segments only
        label_similarities = []  # each segment's similarity to its labeled speaker
        skipped_segments = 0  # number of skipped segments

        for sim_result in similarity_results:
            segment = sim_result['segment']
            sim1 = sim_result['sim1']
            sim2 = sim_result['sim2']

            # Skip the segment if either similarity is None (too short or failed)
            if sim1 is None or sim2 is None:
                skipped_segments += 1
                self.logger.info(f"Skipping segment {segment['segment_id']}: similarity computation failed")
                continue

            # Only valid segments are counted
            total_segments += 1

            # Decide which timbre was actually produced
            predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
            actual_speaker = segment['speaker_label']
            is_correct = predicted_speaker == actual_speaker

            if is_correct:
                correct_predictions += 1

            # Similarity to the labeled speaker
            if actual_speaker == 'S1':
                label_similarity = sim1
            else:
                label_similarity = sim2
            label_similarities.append(label_similarity)

            segment_result = {
                'segment_id': segment['segment_id'],
                'text': segment['text'],
                'speaker_label': actual_speaker,
                'predicted_speaker': predicted_speaker,
                'sim1': sim1,
                'sim2': sim2,
                'label_similarity': label_similarity,
                'is_correct': is_correct,
                'audio_path': segment['audio_path'],
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 1.0)
            }
            segment_results.append(segment_result)

        # 5. Compute the overall metrics (valid segments only)
        accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
        average_similarity = np.mean(label_similarities) if label_similarities else 0.0

        # 6. Save a summary of the evaluation's alignment info
        evaluation_alignment_summary = {
            'input_id': input_id,
            'language': self.language,
            'prompt_alignment_files': [
                f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            ],
            'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
            'total_segments': total_segments,
            'total_alignments_processed': len(output_segments),
            'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
        }
        self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")

        result = {
            'input_id': input_id,
            'language': self.language,
            'input_data': data,  # keep the original input data
            'prompts1_path': prompts1_path,
            'prompts2_path': prompts2_path,
            'segments': segment_results,
            'accuracy': accuracy,
            'average_similarity': average_similarity,
            'total_segments': total_segments,  # number of valid segments
            'correct_predictions': correct_predictions,
            'skipped_segments': skipped_segments,  # number of skipped segments
            'original_total_segments': len(output_segments),  # original segment count
            'alignment_files': {
                'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
                'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
                'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            },
            'timestamp': datetime.now().isoformat()
        }

        self.logger.info(f"Finished input: {input_id}, language: {self.language}, valid segments: {total_segments}/{len(output_segments)}, skipped: {skipped_segments}, accuracy: {accuracy:.3f}, average similarity: {average_similarity:.3f}")

        return result
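    # Illustrative example (not from the repo): with 12 sentence segments of
    # which 10 scored successfully and 8 were attributed correctly,
    # skipped_segments = 2 and accuracy = 8 / 10 = 0.8; skipped segments
    # never enter the metrics.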

    def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None):
        """Save the results to a JSONL file"""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        output_path = self.results_dir / filename

        with open(output_path, 'w', encoding='utf-8') as f:
            for result in results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')

        return str(output_path)

    def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None):
        """Save the summary report"""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json"

        summary_path = self.results_dir / filename

        # Overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])
        total_segments = sum([r['total_segments'] for r in results])
        total_correct = sum([r['correct_predictions'] for r in results])

        summary = {
            'evaluation_summary': {
                'language': self.language,
                'total_inputs': len(results),
                'total_segments': total_segments,
                'total_correct_predictions': total_correct,
                'overall_accuracy': total_accuracy,
                'overall_average_similarity': total_avg_similarity,
                'evaluation_timestamp': datetime.now().isoformat(),
                'output_directory': str(self.output_dir),
                'alignment_directory': str(self.alignment_dir)
            },
            'per_input_results': [
                {
                    'input_id': r['input_id'],
                    'language': r.get('language', self.language),
                    'accuracy': r['accuracy'],
                    'average_similarity': r['average_similarity'],
                    'total_segments': r['total_segments'],
                    'correct_predictions': r['correct_predictions'],
                    'output_audio_path': r['input_data']['output_audio'],
                    'alignment_files': r.get('alignment_files', {})
                }
                for r in results
            ]
        }

        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        return str(summary_path)

    def process_batch_from_jsonl_parallel(self, jsonl_path: str,
                                          processes_per_gpu: int = 16,
                                          results_filename: str = None,
                                          shuffle_data: bool = True):
        """Process the inputs from a JSONL file in parallel"""
        # Load the data
        input_data = self.load_data_from_jsonl(jsonl_path)

        if not input_data:
            self.logger.error("No valid input data")
            return []

        # Shuffle the data for a more even distribution
        if shuffle_data:
            random.shuffle(input_data)
            self.logger.info(f"Shuffled {len(input_data)} records")

        return self.process_batch_parallel(input_data, processes_per_gpu, results_filename)

    def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None):
        """Process the inputs from a JSONL file (single-process version)"""
        # Load the data
        input_data = self.load_data_from_jsonl(jsonl_path)

        if not input_data:
            self.logger.error("No valid input data")
            return []

        return self.process_batch_from_data(input_data, results_filename)

    def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
        """Process a data list (single-process version, kept for compatibility), with incremental writes"""
        # Prepare the results file
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        results_path = self.results_dir / results_filename

        # If the file already exists, delete it (fresh start)
        if results_path.exists():
            results_path.unlink()

        results = []

        self.logger.info(f"Start processing {len(input_data)} inputs, language: {self.language}...")

        for i, data in enumerate(input_data):
            input_id = f"input_{i+1:03d}"
            print(f"Processing input {i+1}/{len(input_data)}: {input_id}, language: {self.language}")

            try:
                result = self.evaluate_single_input(data, input_id=input_id)
                results.append(result)

                # Write the result incrementally
                self.append_result_to_jsonl(result, str(results_path))

            except Exception as e:
                self.logger.error(f"Error while processing input {input_id}: {e}")
                continue

        if not results:
            self.logger.error("No input was processed successfully")
            return []

        # Save the summary report
        summary_path = self.save_summary_report(results)

        # Clean up the temp files
        self._clean_temp_files()

        # Print the overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])

        print(f"\n=== Evaluation finished ===")
        print(f"Language: {self.language}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")
        print(f"All intermediate files are under: {self.output_dir}")

        return results

    def _load_wespeaker_model(self, wespeaker_model_dir):
        """Load the wespeaker model"""
        try:
            import wespeaker

            # Load the local model with load_model_local
            # Per the reference provided, use the specified model path
            local_model_path = '/inspire/ssd/project/embodied-multimodality/public/zylin/speaker_embedding/wespeaker_pretrain/voxblink2_samresnet100_ft'

            try:
                self.similarity_model = wespeaker.load_model_local(local_model_path)
                self.logger.info(f"Loaded local wespeaker model: {local_model_path}")
                return
            except Exception as e:
                self.logger.warning(f"Failed to load the specified local model: {e}")

            # Fallback 1: try the model directory passed in
            if os.path.exists(wespeaker_model_dir):
                try:
                    self.similarity_model = wespeaker.load_model_local(wespeaker_model_dir)
                    self.logger.info(f"Loaded the provided local wespeaker model: {wespeaker_model_dir}")
                    return
                except Exception as e:
                    self.logger.warning(f"Failed to load the provided local model: {e}")

            # Fallback 2: the pretrained Chinese model
            try:
                self.similarity_model = wespeaker.load_model('chinese')
                self.logger.info("Fell back to the pretrained Chinese wespeaker model")
                return
            except Exception as e:
                self.logger.warning(f"Failed to load the pretrained Chinese model: {e}")

            # Fallback 3: the pretrained English model
            try:
                self.similarity_model = wespeaker.load_model('english')
                self.logger.info("Fell back to the pretrained English wespeaker model")
                return
            except Exception as e:
                self.logger.error(f"Failed to load the English model as well: {e}")

            # Everything failed
            raise Exception("Could not load any wespeaker model")

        except ImportError:
            raise ImportError("Please install wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git")
        except Exception as e:
            self.logger.error(f"Failed to load the wespeaker model: {e}")
            raise

    def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]:
        """Load data from a JSONL file"""
        data = []
        try:
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            item = json.loads(line)
                            # Validate the required fields
                            required_fields = ['text', 'output_audio']
                            for field in required_fields:
                                if field not in item:
                                    self.logger.error(f"Line {line_num} is missing required field: {field}")
                                    continue

                            # Validate the prompt mode: either prompt_audio+prompt_text, or per-speaker audio files
                            has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item
                            has_separate_prompts = ('prompt_audio_speaker1' in item and
                                                    'prompt_text_speaker1' in item and
                                                    'prompt_audio_speaker2' in item and
                                                    'prompt_text_speaker2' in item)

                            if not (has_combined_prompt or has_separate_prompts):
                                self.logger.error(f"Line {line_num}: either prompt_audio+prompt_text or per-speaker audio files must be provided")
                                continue

                            data.append(item)

                        except json.JSONDecodeError as e:
                            self.logger.error(f"JSON parse error on line {line_num}: {e}")
                            continue

            self.logger.info(f"Loaded {len(data)} records from {jsonl_path}")
            return data

        except FileNotFoundError:
            self.logger.error(f"JSONL file does not exist: {jsonl_path}")
            return []
        except Exception as e:
            self.logger.error(f"Failed to read the JSONL file: {e}")
            return []
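    # Illustrative example (not from the repo) of a minimal record in
    # pre-split mode; the paths are placeholders:
    # {"text": "[S1]你好[S2]你好", "output_audio": "out.wav",
    #  "prompt_audio_speaker1": "s1.wav", "prompt_text_speaker1": "[S1]你好",
    #  "prompt_audio_speaker2": "s2.wav", "prompt_text_speaker2": "[S2]你好"}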

    @staticmethod
    def get_gpu_count():
        """Return the number of available GPUs"""
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        return 0

    @staticmethod
    def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]:
        """Split the data across GPUs"""
        if num_gpus == 0:
            return [data]

        chunk_size = math.ceil(len(data) / num_gpus)
        gpu_chunks = []

        for i in range(num_gpus):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, len(data))
            if start_idx < len(data):
                gpu_chunks.append(data[start_idx:end_idx])

        return gpu_chunks
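    # Illustrative example (not from the repo): splitting 5 records across
    # 2 GPUs gives chunk_size = ceil(5 / 2) = 3, i.e. chunks of 3 and 2 records.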
|
| 1175 |
+
|
| 1176 |
+
@staticmethod
|
| 1177 |
+
def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]:
|
| 1178 |
+
"""根据进程数量分割数据"""
|
| 1179 |
+
if num_processes <= 1:
|
| 1180 |
+
return [data]
|
| 1181 |
+
|
| 1182 |
+
chunk_size = math.ceil(len(data) / num_processes)
|
| 1183 |
+
process_chunks = []
|
| 1184 |
+
|
| 1185 |
+
for i in range(num_processes):
|
| 1186 |
+
start_idx = i * chunk_size
|
| 1187 |
+
end_idx = min((i + 1) * chunk_size, len(data))
|
| 1188 |
+
if start_idx < len(data):
|
| 1189 |
+
process_chunks.append(data[start_idx:end_idx])
|
| 1190 |
+
|
| 1191 |
+
return process_chunks
|
| 1192 |
+
|
| 1193 |
+
def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str):
|
| 1194 |
+
"""增量写入结果到JSONL文件"""
|
| 1195 |
+
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
| 1196 |
+
with open(filepath, 'a', encoding='utf-8') as f:
|
| 1197 |
+
f.write(json.dumps(result, ensure_ascii=False) + '\n')
|
| 1198 |
+
f.flush() # 强制刷新缓冲区
|
| 1199 |
+
|
| 1200 |
+
def merge_temp_results(self, temp_files: List[str], final_path: str):
|
| 1201 |
+
"""合并临时结果文件"""
|
| 1202 |
+
all_results = []
|
| 1203 |
+
|
| 1204 |
+
for temp_file in temp_files:
|
| 1205 |
+
if os.path.exists(temp_file):
|
| 1206 |
+
try:
|
| 1207 |
+
with open(temp_file, 'r', encoding='utf-8') as f:
|
| 1208 |
+
for line in f:
|
| 1209 |
+
line = line.strip()
|
| 1210 |
+
if line:
|
| 1211 |
+
result = json.loads(line)
|
| 1212 |
+
all_results.append(result)
|
| 1213 |
+
except Exception as e:
|
| 1214 |
+
self.logger.error(f"读取临时文件失败: {temp_file}, 错误: {e}")
|
| 1215 |
+
|
| 1216 |
+
# 写入最终文件
|
| 1217 |
+
with open(final_path, 'w', encoding='utf-8') as f:
|
| 1218 |
+
for result in all_results:
|
| 1219 |
+
f.write(json.dumps(result, ensure_ascii=False) + '\n')
|
| 1220 |
+
|
| 1221 |
+
return all_results
|
| 1222 |
+
|
| 1223 |
+
    def process_batch_parallel(self, input_data: List[Dict[str, Any]],
                               processes_per_gpu: int = 8,  # keep the process count modest
                               results_filename: str = None,
                               shuffle_data: bool = True):
        """Process the input data in parallel across GPUs and processes."""
        # 1. Check how many GPUs are available
        num_gpus = self.get_gpu_count()
        if num_gpus == 0:
            self.logger.warning("No GPU detected; falling back to single-process CPU processing")
            return self.process_batch_from_data(input_data, results_filename)

        # Cap the per-GPU process count to avoid CUDA memory contention
        max_processes_per_gpu = min(processes_per_gpu, 16)
        self.logger.info(f"Detected {num_gpus} GPU(s); each GPU will run {max_processes_per_gpu} process(es)")

        # 2. Shuffle the data (unless the caller disabled it)
        shuffled_data = input_data.copy()
        if shuffle_data:
            random.shuffle(shuffled_data)
            self.logger.info(f"Shuffled {len(shuffled_data)} items to balance the GPU load")

        # 3. Split the data across GPUs
        gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)

        # Log how many items each GPU received
        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if gpu_data:
                self.logger.info(f"GPU {gpu_id}: assigned {len(gpu_data)} items")

        # 4. Prepare the results file path
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        final_results_path = self.results_dir / results_filename

        # 5. Prepare the process arguments for every GPU
        all_temp_files = []
        all_gpu_tasks = []

        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if not gpu_data:
                continue

            self.logger.info(f"GPU {gpu_id}: preparing to process {len(gpu_data)} items")

            # Split this GPU's data across its worker processes
            process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)

            # Build the argument tuples for this GPU's processes
            gpu_process_args = []
            for proc_id, proc_data in enumerate(process_chunks):
                if proc_data:
                    temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
                    all_temp_files.append(temp_result_file)

                    # Each subprocess gets its own output directory inside the main one
                    subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")

                    gpu_process_args.append((
                        proc_data,
                        gpu_id,
                        proc_id,
                        subprocess_output_dir,
                        temp_result_file,
                        self.alignment_model_dir,
                        self.wespeaker_model_dir,
                        self.language,  # language setting
                        self.similarity_max_workers  # thread count for similarity computation
                    ))

            if gpu_process_args:
                all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))

        # 6. Drive all GPUs in parallel with a ThreadPoolExecutor
        def process_gpu_tasks(gpu_task):
            gpu_id, process_args, actual_processes = gpu_task
            self.logger.info(f"GPU {gpu_id}: starting {len(process_args)} process(es) in parallel")

            # Each GPU gets its own process pool to avoid cross-process conflicts
            with mp.Pool(processes=actual_processes) as pool:
                pool.map(process_data_chunk_incremental, process_args)

            self.logger.info(f"GPU {gpu_id}: all processes finished")
            return gpu_id

        # One thread per GPU
        with ThreadPoolExecutor(max_workers=num_gpus) as executor:
            # Submit every GPU task
            future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
                             for gpu_task in all_gpu_tasks}

            # Wait for every GPU to finish
            completed_gpus = []
            for future in as_completed(future_to_gpu):
                gpu_id = future_to_gpu[future]
                try:
                    result_gpu_id = future.result()
                    completed_gpus.append(result_gpu_id)
                    self.logger.info(f"GPU {result_gpu_id} finished processing")
                except Exception as exc:
                    self.logger.error(f"GPU {gpu_id} raised an exception: {exc}")

        self.logger.info(f"All GPUs finished: {completed_gpus}")

        # 7. Merge all temporary result files
        self.logger.info("Merging all temporary result files...")
        all_results = self.merge_temp_results(all_temp_files, str(final_results_path))

        if not all_results:
            self.logger.error("No data was processed successfully")
            return []

        # 8. Generate the summary report
        summary_path = self.save_summary_report(all_results)

        # 9. Clean up the temporary files
        for temp_file in all_temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)

        # 10. Print overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in all_results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])

        print("\n=== Parallel evaluation finished ===")
        print(f"Language: {self.language}")
        print(f"Used {num_gpus} GPU(s) with {max_processes_per_gpu} process(es) each")
        print(f"Total items: {len(input_data)}")
        print(f"Successfully processed: {len(all_results)}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {final_results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")

        return all_results

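    # Parallelism layout used above: one Python thread per GPU, and inside each
    # thread an independent mp.Pool of workers pinned to that GPU via
    # torch.cuda.set_device() in process_data_chunk_incremental(). The pools
    # assume a CUDA-safe start method such as 'spawn', since a forked worker
    # inheriting an initialized CUDA context is not supported by the runtime.
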
    def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]:
        """
        Get or split the prompt audio.
        If separate per-speaker audio files are provided, use them directly;
        otherwise split them out of the combined prompt audio.
        """
        # Check for separate per-speaker audio files
        if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and
                'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data):

            self.logger.info("Using pre-split per-speaker audio files")

            # Even with pre-split audio, still save the alignment info
            try:
                # Detect the language, or use the configured one
                alignment_language = self.language
                if alignment_language == "AUTO":
                    alignment_language = self._detect_language_from_text(data['prompt_text_speaker1'])

                # Align the S1 audio
                s1_alignments = self.align_text_with_audio(
                    data['prompt_text_speaker1'], data['prompt_audio_speaker1'], alignment_language
                )
                s1_alignment_data = {
                    'speaker': 'S1',
                    'text': data['prompt_text_speaker1'],
                    'audio_path': data['prompt_audio_speaker1'],
                    'language': alignment_language,
                    'alignments': s1_alignments
                }
                self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1")

                # Align the S2 audio
                s2_alignments = self.align_text_with_audio(
                    data['prompt_text_speaker2'], data['prompt_audio_speaker2'], alignment_language
                )
                s2_alignment_data = {
                    'speaker': 'S2',
                    'text': data['prompt_text_speaker2'],
                    'audio_path': data['prompt_audio_speaker2'],
                    'language': alignment_language,
                    'alignments': s2_alignments
                }
                self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2")

            except Exception as e:
                self.logger.warning(f"Failed to save alignment info for pre-split audio: {e}")

            return data['prompt_audio_speaker1'], data['prompt_audio_speaker2']

        # Otherwise split the combined prompt audio
        elif 'prompt_audio' in data and 'prompt_text' in data:
            self.logger.info("Splitting speaker segments from the combined prompt audio")
            return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id)

        else:
            raise ValueError("Either prompt_audio + prompt_text or separate per-speaker audio files must be provided")

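    # The two accepted prompt layouts, shown as hypothetical JSONL entries:
    #   pre-split: {"prompt_audio_speaker1": "s1.wav", "prompt_text_speaker1": "...",
    #               "prompt_audio_speaker2": "s2.wav", "prompt_text_speaker2": "..."}
    #   combined:  {"prompt_audio": "both.wav", "prompt_text": "[S1]...[S2]..."}
    # Anything else raises the ValueError above.
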
    def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float:
        """
        Compute the voice similarity between two audio files (backward-compatible version).
        Audio segments that are too short are repeated until they reach the minimum length.
        """
        # In a worker thread, use the thread-safe version
        if threading.current_thread() != threading.main_thread():
            return self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path)

        # Make sure the models are initialized
        self._init_models_if_needed()

        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Check the audio length and pad by repetition if necessary
            def process_audio_for_similarity(audio_path, min_duration=0.1):
                """
                If the audio is shorter than min_duration, repeat it until it is long enough.
                Returns the (possibly temporary) audio path and a flag marking temp files.
                """
                try:
                    waveform, sample_rate = torchaudio.load(audio_path)
                    duration = waveform.shape[1] / sample_rate

                    if duration >= min_duration:
                        # Long enough: return the original path
                        return audio_path, False

                    # Too short: repeat the waveform
                    repeat_times = math.ceil(min_duration / duration)
                    self.logger.info(f"Audio too short ({duration:.3f}s); repeating {repeat_times}x to reach {min_duration}s: {audio_path}")

                    # Tile the audio
                    repeated_waveform = waveform.repeat(1, repeat_times)

                    # Build a temporary file path
                    temp_filename = f"temp_{os.path.basename(audio_path)}"
                    temp_path = str(self.temp_dir / temp_filename)

                    # Save the repeated audio
                    torchaudio.save(temp_path, repeated_waveform, sample_rate)

                    return temp_path, True

                except Exception as e:
                    self.logger.error(f"Failed to process audio file: {audio_path}, error: {e}")
                    return audio_path, False

            # Process both audio files
            processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
            processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)

            # Compute the similarity
            similarity = self.similarity_model.compute_similarity(processed_audio1, processed_audio2)

            # Clean up the temporary files
            if is_temp1 and os.path.exists(processed_audio1):
                try:
                    os.remove(processed_audio1)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio1}, error: {e}")

            if is_temp2 and os.path.exists(processed_audio2):
                try:
                    os.remove(processed_audio2)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio2}, error: {e}")

            return float(similarity)

        except Exception as e:
            # Distinguish window-size errors from other computation errors
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"Audio segment still too short to compute similarity: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None

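# Worked example of the short-audio fix in calculate_voice_similarity (numbers
# hypothetical): a 0.03 s clip checked against min_duration = 0.1 s gives
# repeat_times = ceil(0.1 / 0.03) = 4, so the waveform is tiled to 0.12 s before
# the speaker embedding is computed, and the tiled temp file is deleted afterwards.
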
# Module-level worker function for multiprocessing (with incremental writes)
def process_data_chunk_incremental(args):
    """Worker that processes one data chunk and writes results incrementally."""
    data_chunk, gpu_id, proc_id, output_dir, temp_result_file, alignment_model_dir, wespeaker_model_dir, language, similarity_max_workers = args

    # Pick the GPU for this process
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"

    try:
        # Reset CUDA state to avoid cross-process conflicts
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Bind this process to its GPU
            torch.cuda.set_device(gpu_id)
            # Stagger startup slightly so processes do not initialize at once
            time.sleep(proc_id * 0.5)

        # Create the evaluator with model paths, language and similarity thread count
        evaluator = SpeakerSimilarityEvaluator(
            device=device,
            alignment_model_dir=alignment_model_dir,
            wespeaker_model_dir=wespeaker_model_dir,
            output_dir=output_dir,
            language=language,  # pass through the language setting
            similarity_max_workers=similarity_max_workers  # pass through the similarity thread count
        )

        # Lazily initialize the models
        evaluator._init_models_if_needed()

        # Clear the temp result file if it already exists
        if os.path.exists(temp_result_file):
            os.remove(temp_result_file)

        # Process the chunk
        for i, data in enumerate(data_chunk):
            input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"

            try:
                result = evaluator.evaluate_single_input(data, input_id=input_id)

                # Write the result to the temp file immediately
                evaluator.append_result_to_jsonl(result, temp_result_file)

                print(f"GPU{gpu_id}-proc{proc_id}: finished {input_id} (language: {language}, similarity threads: {similarity_max_workers})")

                # Free the CUDA cache after every item
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"GPU{gpu_id}-proc{proc_id}: failed on {input_id}: {e}")
                # Also free the CUDA cache on error
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue

        print(f"GPU{gpu_id}-proc{proc_id}: all data processed; results written to {temp_result_file}")

    except Exception as e:
        print(f"GPU{gpu_id}-proc{proc_id}: initialization failed: {e}")
        # Free the CUDA cache on error
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def main():
    """Example entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
    parser.add_argument('--jsonl_path', type=str, help='Path to the input JSONL file')
    parser.add_argument('--output_dir', type=str,
                        default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                        help='Directory for the results')
    parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
                        help='Language: zh=Chinese, en=English, auto=auto-detect (default: zh)')
    parser.add_argument('--no_parallel', action='store_true', help='Disable parallel processing (enabled by default)')
    parser.add_argument('--processes_per_gpu', type=int, default=4, help='Processes per GPU (4 or fewer recommended)')
    parser.add_argument('--similarity_workers', type=int, default=16, help='Thread count for similarity computation (default: 16)')
    parser.add_argument('--no_shuffle', action='store_true', help='Disable data shuffling (enabled by default)')
    parser.add_argument('--random_seed', type=int, default=None, help='Random seed (optional, for reproducibility)')

    args = parser.parse_args()

    # Set the random seed if one was given
    if args.random_seed is not None:
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)
        torch.manual_seed(args.random_seed)
        print(f"Random seed set to: {args.random_seed}")

    # Normalize the language argument; anything unrecognized falls back to Chinese
    language = args.language.upper()
    if language not in ('AUTO', 'EN'):
        language = 'ZH'  # default to Chinese

    # Create the evaluator with the results directory, language and similarity thread count
    evaluator = SpeakerSimilarityEvaluator(
        output_dir=args.output_dir,
        language=language,
        similarity_max_workers=args.similarity_workers
    )

    # Parallel processing is the default unless explicitly disabled
    use_parallel = not args.no_parallel
    use_shuffle = not args.no_shuffle

    print(f"Language setting: {language}")
    print(f"Similarity thread count: {args.similarity_workers}")

    if args.jsonl_path:
        # Process data from the JSONL file
        if use_parallel:
            evaluator.process_batch_from_jsonl_parallel(
                args.jsonl_path,
                processes_per_gpu=args.processes_per_gpu,
                shuffle_data=use_shuffle
            )
        else:
            evaluator.process_batch_from_jsonl(args.jsonl_path)
    else:
        # Fall back to the built-in example data (for compatibility)
        input_data = [
            {
                'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
                'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
                'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
                'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
            }
        ]

        # Process the data
        if use_parallel:
            evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
        else:
            evaluator.process_batch_from_data(input_data)


if __name__ == "__main__":
    main()
test.sh
ADDED
@@ -0,0 +1,82 @@
#!/bin/bash

source /inspire/hdd/project/embodied-multimodality/public/yqzhang/miniconda3/bin/activate
conda activate /inspire/hdd/project/embodied-multimodality/public/cchang/env/mooncast/

# Set the CUDA environment variables
export CUDA_LAUNCH_BLOCKING=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# Create the log directory and log file name
LOG_DIR="/inspire/hdd/project/embodied-multimodality/public/cchang/projects/auto_evaluation_new/logs"
mkdir -p "$LOG_DIR"
LOG_FILE="$LOG_DIR/evaluation_$(date +%Y%m%d_%H%M%S).log"

# Record the start time
START_TIME=$(date +%s)
START_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')

echo "========================================="
echo "Speaker similarity evaluation starting"
echo "Start time: $START_TIME_READABLE"
echo "Log file: $LOG_FILE"
echo "========================================="
echo "Follow the log in real time with:"
echo "tail -f $LOG_FILE"
echo ""

# Also write the start information to the log file
{
    echo "========================================="
    echo "Speaker similarity evaluation starting"
    echo "Start time: $START_TIME_READABLE"
    echo "Process configuration: 8 processes per GPU"
    echo "Language setting: zh (Chinese)"
    echo "========================================="
    echo ""
} | tee "$LOG_FILE"

# Use a conservative process count
python -u /inspire/hdd/project/embodied-multimodality/public/cchang/projects/auto_evaluation_new/test.py \
    --jsonl_path /inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step70000/eval_new/output.jsonl \
    --output_dir /inspire/hdd/project/embodied-multimodality/public/cchang/projects/auto_evaluation_new/eval_res/new_test \
    --processes_per_gpu 8 \
    --language zh \
    2>&1 | tee -a "$LOG_FILE"


# Record the end time
END_TIME=$(date +%s)
END_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')

# Compute the elapsed time (use SECS rather than bash's special, auto-incrementing SECONDS variable)
DURATION=$((END_TIME - START_TIME))
HOURS=$((DURATION / 3600))
MINUTES=$(((DURATION % 3600) / 60))
SECS=$((DURATION % 60))

# Print the completion summary
{
    echo ""
    echo "========================================="
    echo "Speaker similarity evaluation finished!"
    echo "End time: $END_TIME_READABLE"
    echo "Elapsed: ${HOURS}h ${MINUTES}m ${SECS}s (${DURATION}s total)"
    echo "Log file: $LOG_FILE"
    echo "========================================="
} | tee -a "$LOG_FILE"

# Show the summary in the terminal
echo ""
echo "Evaluation finished!"
echo "Start time: $START_TIME_READABLE"
echo "End time: $END_TIME_READABLE"
echo "Elapsed: ${HOURS}h ${MINUTES}m ${SECS}s"
echo "Log saved to: $LOG_FILE"

# Extra reminder if the run took more than an hour
if [ $DURATION -gt 3600 ]; then
    echo ""
    echo "⏰ Note: this evaluation took more than one hour"
    echo "   Consider checking the performance optimizations"
fi
test_alignment.py
ADDED
@@ -0,0 +1,416 @@
import re
import torch
import torchaudio.functional as F
import torchaudio
import uroman as ur
import logging
from typing import List, Dict, Any, Optional

def split_and_merge_punctuation(text: str) -> List[str]:
    """
    Tokenize English text on whitespace and merge punctuation onto the preceding word.

    Args:
        text: input English text

    Returns:
        List of words with punctuation merged onto the corresponding word
    """
    # First split the text on whitespace
    elements = text.split()

    # Collect the final result here
    result = []

    # Walk over every whitespace-separated element
    for ele in elements:
        # Use a regex to pull out runs of letters/digits and runs of punctuation
        parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)

        # Collect the merged parts for this element
        merged_parts = []

        for i in range(len(parts)):
            if i % 2 == 0:  # letter/digit run
                # Append the run as its own token
                merged_parts.append(parts[i])
            else:  # punctuation (or other symbol) run
                # Merge the punctuation onto the preceding token
                if merged_parts:
                    merged_parts[-1] += parts[i]
                else:
                    merged_parts.append(parts[i])

        # Add the merged parts to the final result
        result.extend(merged_parts)

    return result

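# Worked example: split_and_merge_punctuation("Hello, world!") splits on
# whitespace into ["Hello,", "world!"]; re.findall then separates each element
# into letter/digit runs and punctuation runs (["Hello", ","], ["world", "!"])
# and the loop glues the punctuation back onto the preceding word, returning
# ["Hello,", "world!"].
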
def restore_spaces_in_english_text(tokens: List[str]) -> str:
    """
    Restore spaces between English words.

    Args:
        tokens: list of words

    Returns:
        Text with spaces restored
    """
    result = []
    for i, token in enumerate(tokens):
        # Decide whether a space is needed before this word
        if i > 0 and token[0].isalnum() and not any(p in tokens[i-1] for p in ',.!?;:()[]<>\'\"…'):
            result.append(" ")
        result.append(token)

    return "".join(result)

def get_aligned_result_with_punctuation(alignment_result: List[Dict], text: str) -> List[Dict]:
    """
    Convert the alignment result into a form that keeps punctuation.

    Args:
        alignment_result: raw alignment result
        text: original text

    Returns:
        Alignment result with punctuation merged back in
    """
    text_tokens = split_and_merge_punctuation(text)

    updated_alignment_result = []
    token_idx = 0

    for index, align_item in enumerate(alignment_result):
        if token_idx >= len(text_tokens):
            break

        start = align_item["start"]
        end = align_item["end"]
        text_token = text_tokens[token_idx]

        updated_item = {
            "start": start,
            "end": end,
            "transcript": text_token
        }

        # Preserve any other fields from the raw alignment result
        updated_item.update({key: align_item[key] for key in align_item
                             if key not in ["start", "end", "transcript"]})

        updated_alignment_result.append(updated_item)
        token_idx += 1

    return updated_alignment_result

class EnglishAlignmentModel:
    def __init__(self, device: str = "cuda", model_dir: Optional[str] = None):
        """
        Initialize the English alignment model.

        Args:
            device: device type ("cuda" or "cpu")
            model_dir: model directory; the default download path is used if None
        """
        self.device = torch.device(device)
        self.bundle = torchaudio.pipelines.MMS_FA

        # Download options for the model
        dl_kwargs = {}
        if model_dir:
            dl_kwargs['model_dir'] = model_dir

        self.align_model = self.bundle.get_model(
            with_star=False,
            dl_kwargs=dl_kwargs
        ).to(self.device)

        self.uroman = ur.Uroman()
        self.DICTIONARY = self.bundle.get_dict()

    def align(self, emission: torch.Tensor, tokens: torch.Tensor):
        """
        Run forced alignment.

        Args:
            emission: model output
            tokens: target tokens

        Returns:
            Aligned tokens and their scores
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()
        return alignments, scores

    def unflatten(self, list_: List, lengths: List[int]) -> List[List]:
        """
        Split a flat list into sublists of the given lengths.

        Args:
            list_: flat list
            lengths: length of each sublist

        Returns:
            The list split into sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

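    # Example: unflatten([t1, t2, t3, t4, t5], [2, 3]) -> [[t1, t2], [t3, t4, t5]].
    # get_alignment_result() uses this below to regroup per-character token spans
    # into per-word spans via each word's character length.
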
    def preview_word(self, waveform: torch.Tensor, spans: List, num_frames: int,
                     transcript: List[str], sample_rate: int) -> List[Dict]:
        """
        Build the time alignment info for each word.

        Args:
            waveform: audio waveform
            spans: word spans
            num_frames: number of frames
            transcript: list of transcript words
            sample_rate: sample rate

        Returns:
            List of per-word alignment info
        """
        end = 0
        alignment_result = []

        for span, trans in zip(spans, transcript):
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)

            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)

        return alignment_result

    def make_wav_batch(self, wav_list: List[torch.Tensor]):
        """
        Pad every wav tensor in wav_list to the same length.

        Args:
            wav_list: list of wav tensors

        Returns:
            Padded audio tensor and the original lengths
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)
        wavs_tensors = torch.zeros(len(wav_list), max_length, device=wav_list[0].device)

        for i, wav in enumerate(wav_list):
            wavs_tensors[i, :wav_lengths[i]] = wav

        return wavs_tensors, wav_lengths.to(wavs_tensors.device)

    def get_target(self, transcript: str) -> torch.Tensor:
        """
        Get the target tokens for an English transcript.

        Args:
            transcript: English transcript

        Returns:
            Target tokens for the transcript
        """
        # Strip punctuation and lowercase
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        words = transcript.lower().split()

        # Token for the special star symbol in the dictionary
        star_token = self.DICTIONARY['*']

        # Map every character to its token
        tokenized_transcript = []
        for word in words:
            tokenized_transcript.extend([
                self.DICTIONARY[c] if c in self.DICTIONARY and c != '-' else star_token
                for c in word
            ])

        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded: torch.Tensor, emission_length: int,
                             aligned_tokens: torch.Tensor, alignment_scores: torch.Tensor,
                             transcript: str, waveform: torch.Tensor) -> List[Dict]:
        """
        Build the alignment result from an emission and its alignment.

        Args:
            emission_padded: padded emission
            emission_length: valid length of the emission
            aligned_tokens: aligned tokens
            alignment_scores: alignment scores
            transcript: transcript text
            waveform: audio waveform

        Returns:
            Alignment result
        """
        # Preprocess the text
        processed_transcript = re.sub(r'[^\w\s]', r' ', transcript)
        words = processed_transcript.lower().split()

        emission = emission_padded[:emission_length, :].unsqueeze(0)
        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
        word_spans = self.unflatten(token_spans, [len(word) for word in words])
        num_frames = emission.size(1)

        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames,
                                 words, self.bundle.sample_rate)

    def align_audio_text(self, waveform: torch.Tensor, transcript: str) -> List[Dict]:
        """
        Align a single audio clip with its text.

        Args:
            waveform: audio waveform (1D tensor)
            transcript: English transcript

        Returns:
            Alignment result list with per-word timing info
        """
        # Make sure the audio is on the right device
        waveform = waveform.to(self.device)

        # Resample if needed
        if hasattr(self, 'original_sample_rate'):
            if self.original_sample_rate != self.bundle.sample_rate:
                waveform = F.resample(waveform, self.original_sample_rate, self.bundle.sample_rate)

        # Run the batch path with a single sample
        return self.batch_alignment([waveform], [transcript])[0]

    def batch_alignment(self, wav_list: List[torch.Tensor], transcript_list: List[str]) -> List[List[Dict]]:
        """
        Batch alignment.

        Args:
            wav_list: list of wav tensors
            transcript_list: list of transcripts

        Returns:
            List of alignment results
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)

        # Forward pass
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(
                wavs_tensors.to(self.device),
                wavs_lengths_tensor
            )
            # Append the star dimension
            star_dim = torch.zeros(
                (emission.shape[0], emission.size(1), 1),
                dtype=emission.dtype,
                device=self.device
            )
            emission = torch.cat((emission, star_dim), dim=-1)

        # Get the target tokens
        target_list = [self.get_target(transcript) for transcript in transcript_list]

        # Run the alignment
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]

        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]

        # Build the alignment results
        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens,
                                      alignment_scores, transcript, waveform)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores,
                   transcript_list, wav_list)
        ]

        # Merge punctuation back in
        final_results = []
        for alignment_result, transcript in zip(alignment_result_list, transcript_list):
            processed_result = get_aligned_result_with_punctuation(alignment_result, transcript)
            final_results.append(processed_result)

        return final_results

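# Note on the star handling above: the all-zero column appended to the emission
# in batch_alignment() lines up with the '*' entry that bundle.get_dict()
# includes by default, and get_target() maps any character missing from the
# dictionary to that star token.
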
def align_english_audio_text(audio_path: str, transcript: str, device: str = "cuda",
                             model_dir: Optional[str] = None) -> List[Dict]:
    """
    Convenience function: align an English audio file with its transcript.

    Args:
        audio_path: path to the audio file
        transcript: English transcript
        device: device type ("cuda" or "cpu")
        model_dir: model directory

    Returns:
        Alignment result list with per-word timing info

    Example:
        >>> result = align_english_audio_text("audio.wav", "Hello world!")
        >>> print(result)
        [
            {"transcript": "Hello", "start": 0.0, "end": 0.5, "duration": 0.5, "pause": 0.0},
            {"transcript": "world!", "start": 0.6, "end": 1.2, "duration": 0.6, "pause": 0.1}
        ]
    """
    # Load the audio
    waveform, sample_rate = torchaudio.load(audio_path)

    # Convert to mono
    if waveform.size(0) > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    waveform = waveform.squeeze(0)  # drop the batch dimension

    # Initialize the model
    model = EnglishAlignmentModel(device=device, model_dir=model_dir)
    model.original_sample_rate = sample_rate

    # Run the alignment
    return model.align_audio_text(waveform, transcript)

if __name__ == "__main__":
    # Usage example
    audio_file = "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step40000/test_en/gpu4/output_0.wav"
    text = "[S1]Hey, did you hear about that company called MoSi AI? [S2]MoSi AI? Yeah, I think I've heard of them. Aren't they the ones doing AI stuff? What new thing have they come up with now? [S1]Yeah, that's them! They recently launched this super hot new product called, um, Asteroid. [S2]Asteroid. That's a pretty cool name. Does it mean like the space rock? [S1]Yeah, I think that's what it means. Let me tell you, this thing is incredible. They say it's currently the most realistic, human-like conversational TTS model out there. [S2]Oh, TTS technology? You mean the text-to-speech thing? Aren't there already a lot of those on the market? What makes this one so special? [S1]Well, it's completely different. They say the voice produced by Asteroid sounds almost exactly like a real person talking. And it's super smooth and natural. Not at all like, you know, that stiff robotic tone. [S2]I see. Some voice assistants do still have that mechanical feel, especially during multi-turn conversations. So how amazing is this Asteroid exactly? [S1]I heard they internally call Asteroid China's own version of NotebookLM. [S2]NotebookLM? Oh, I know that one. Isn't that the personal AI that Google made? The one that helps organize notes and answers all kinds of questions? So Asteroid has similar functions? [S1]Right. That's probably what they mean. It's not just that the voice sounds incredibly human. The intelligence level is also really high. It can have these really logical, contextual, in-depth conversations with you. It's just like chatting with a real person. [S2]Wow, that sounds amazing. If they can really achieve that... [S1]Yeah, it's basically like having a personal assistant that's both articulate and really understands you. [S2]Hmm. That does sound appealing. [S1]And some people are saying it's like the, what's it called again in the voice technology circle? Oh right, DeepSeek. [S2]DeepSeek? Isn't that the company making large language models? Their models are pretty popular now. That's high praise. So they're saying Asteroid is top-tier technology? [S1]Yeah, I think that's what they mean. It's like they've reached a whole new level in voice synthesis. Similar to the impact DeepSeek has had in natural language processing. It could be that kind of groundbreaking technology. [S2]If Asteroid is really that impressive, where could it be used? I feel like there must be huge potential there. [S1]Absolutely. Just imagine future smart customer service, audiobook reading, and those virtual livestreamers that are so popular now. The quality would improve dramatically. We might even have personal assistants using Asteroid to talk to us directly. How natural would that be? [S2]Yeah. That does sound exciting. When can we actually try it out? Are there any demos available? [S1]I haven't looked into that carefully yet. But since they've already announced it, I'm guessing it won't be long. I'm really eager to try it and see just how human-like it is. [S2]Yeah, yeah. If it can really deliver what they're promising, getting information and interacting with machines will be so much more convenient. The experience will be much better too. [S1]Exactly, exactly. We're just waiting for MoSi AI to give us this big surprise."

    # Normalize the text: strip all [S1][S2] speaker tags
    normalized_text = re.sub(r'\[S[12]\]', '', text).strip()

    # Point at the local model directory
    alignment_model_dir = '/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/models/mms_fa'

    try:
        alignment_result = align_english_audio_text(audio_file, normalized_text, model_dir=alignment_model_dir)

        print("Alignment result:")
        for item in alignment_result:
            print(f"Word: '{item['transcript']}', start: {item['start']}s, end: {item['end']}s, duration: {item['duration']}s")

    except Exception as e:
        print(f"Alignment failed: {e}")
test_online.py
ADDED
@@ -0,0 +1,1550 @@
import asyncio
import json
import re
import os
from typing import List, Dict, Tuple, Any
import numpy as np
from pathlib import Path
import torch
import torchaudio
import torchaudio.functional as F
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
import logging
import wespeaker
import shutil
from datetime import datetime
import multiprocessing as mp
from functools import partial
import math
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import random  # used to shuffle the data

# Use the 'spawn' multiprocessing start method (required for CUDA)
mp.set_start_method('spawn', force=True)

# Word-alignment module
from alignment import AlignmentModel, batch_get_alignment_result
# from tensorrt_client import TritonSimilarityClient
from speaker_client import TritonSpeakerClient

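# Why 'spawn' matters here: the worker processes initialize CUDA, and a forked
# child that inherits an already-initialized CUDA context is not supported by
# the CUDA runtime, so Linux's default 'fork' start method would be unsafe.
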
class SpeakerSimilarityEvaluator:
    """Speaker (voice timbre) similarity evaluator."""

    def __init__(self, device="cuda",
                 alignment_model_dir='./models/mms_fa',
                 wespeaker_model_url='localhost:8001',
                 output_dir="./evaluation_results",
                 language="ZH",
                 similarity_max_workers=8):
        """Initialize the evaluator."""
        self.device = device
        self.alignment_model_dir = alignment_model_dir
        self.wespeaker_model_url = wespeaker_model_url
        self.language = language.upper()  # language setting
        self.similarity_max_workers = similarity_max_workers  # similarity thread count (no longer used)

        # Set up logging first
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Output directory layout
        self.output_dir = Path(output_dir)
        self.segments_dir = self.output_dir / "segments"          # split audio segments
        self.prompts_dir = self.output_dir / "prompts"            # S1/S2 segments of the prompt audio
        self.temp_dir = self.output_dir / "temp"                  # temporary files
        self.results_dir = self.output_dir / "results"            # evaluation results
        self.temp_results_dir = self.output_dir / "temp_results"  # temporary result files
        self.alignment_dir = self.output_dir / "alignments"       # alignment info

        # Create every required directory
        self._create_output_directories()

        # Defer model initialization in multi-process environments
        self.alignment_model = None
        self.similarity_model = None

        # Thread-local storage for thread-safe model access
        self._thread_local = threading.local()

        # Log the run configuration
        self.logger.info(f"Evaluation results will be saved to: {self.output_dir}")
        self.logger.info(f"Alignment info will be saved to: {self.alignment_dir}")
        self.logger.info(f"Language: {self.language}")

    def _create_output_directories(self):
        """Create the output directory structure."""
        for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir,
                         self.results_dir, self.temp_results_dir, self.alignment_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

    def _get_safe_filename(self, text: str, max_length: int = 50) -> str:
        """Build a filesystem-safe filename."""
        # Keep only CJK characters, word characters and whitespace
        safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text)
        # Cap the length
        if len(safe_text) > max_length:
            safe_text = safe_text[:max_length]
        # Replace spaces with underscores
        safe_text = safe_text.replace(' ', '_')
        return safe_text if safe_text else "unnamed"

    def _clean_temp_files(self):
        """Delete temporary files but keep the temp directory itself."""
        if self.temp_dir.exists():
            # Only remove the files inside the temp directory, not the directory
            for file_path in self.temp_dir.iterdir():
                if file_path.is_file():
                    try:
                        file_path.unlink()
                    except Exception as e:
                        self.logger.warning(f"Failed to delete temp file: {file_path}, error: {e}")
        else:
            # Recreate the temp directory if it is missing
            self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _init_models_if_needed(self):
        """Lazily initialize the models (for multi-process environments)."""
        # Initialize the alignment model -- note the argument order
        if self.alignment_model is None:
            # AlignmentModel's constructor takes (device, model_dir), not (model_dir, device)
            self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)

        # Initialize the similarity model
        if self.similarity_model is None:
            self._load_wespeaker_model(self.wespeaker_model_url)

    def _is_english_text(self, text: str) -> bool:
        """Rough check for whether the text is mostly English."""
        # Ratio of ASCII letters among all alphabetic characters
        english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
        total_chars = sum(1 for c in text if c.isalpha())

        if total_chars == 0:
            return False

        return english_chars / total_chars > 0.8  # treat >80% ASCII letters as English

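    # Worked example of the 0.8 threshold: "Hello 世界" has 5 ASCII letters out
    # of 7 alphabetic characters (CJK characters count as alphabetic in Python),
    # so 5/7 ≈ 0.71 < 0.8 and the text is not treated as English.
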
def _detect_language_from_text(self, text: str) -> str:
|
| 132 |
+
"""从文本内容检测语言"""
|
| 133 |
+
clean_text = self.remove_speaker_tags(text)
|
| 134 |
+
if self._is_english_text(clean_text):
|
| 135 |
+
return "EN"
|
| 136 |
+
else:
|
| 137 |
+
return "ZH"
|
| 138 |
+
|
| 139 |
+
def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str, file_type: str = "output"):
|
| 140 |
+
"""
|
| 141 |
+
保存对齐信息到单独的JSON文件
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
alignment_data: 对齐信息数据
|
| 145 |
+
input_id: 输入ID
|
| 146 |
+
file_type: 文件类型 ("output", "prompt", "segment")
|
| 147 |
+
"""
|
| 148 |
+
try:
|
| 149 |
+
safe_input_id = self._get_safe_filename(input_id)
|
| 150 |
+
alignment_filename = f"{safe_input_id}_{file_type}_alignment.json"
|
| 151 |
+
alignment_path = self.alignment_dir / alignment_filename
|
| 152 |
+
|
| 153 |
+
# 添加元数据
|
| 154 |
+
alignment_info = {
|
| 155 |
+
'input_id': input_id,
|
| 156 |
+
'file_type': file_type,
|
| 157 |
+
'language': self.language,
|
| 158 |
+
'timestamp': datetime.now().isoformat(),
|
| 159 |
+
'alignment_data': alignment_data
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
with open(alignment_path, 'w', encoding='utf-8') as f:
|
| 163 |
+
json.dump(alignment_info, f, ensure_ascii=False, indent=2)
|
| 164 |
+
|
| 165 |
+
self.logger.info(f"对齐信息已保存: {alignment_path}")
|
| 166 |
+
return str(alignment_path)
|
| 167 |
+
|
| 168 |
+
except Exception as e:
|
| 169 |
+
self.logger.error(f"保存对齐信息失败: {e}")
|
| 170 |
+
return None

    def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
                                     text_segments: List[Dict[str, Any]],
                                     input_id: str, audio_path: str,
                                     original_text: str, processed_text: str):
        """
        Save detailed alignment info, including segmentation data.

        Args:
            alignments: list of alignment results
            text_segments: text segmentation info
            input_id: input ID
            audio_path: audio file path
            original_text: original text
            processed_text: processed text
        """
        alignment_data = {
            'original_text': original_text,
            'processed_text': processed_text,
            'audio_path': audio_path,
            'language': self.language,
            'total_alignments': len(alignments),
            'total_segments': len(text_segments),
            'alignments': alignments,
            'text_segments': text_segments,
            'segment_alignment_mapping': []
        }

        # Build the mapping between text segments and alignment results
        for segment in text_segments:
            segment_mapping = {
                'segment_id': segment.get('segment_id', 0),
                'segment_text': segment.get('text', ''),
                'speaker_label': segment.get('speaker_label', ''),
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 0.0),
                'corresponding_alignments': []
            }

            # Find the matching alignment items
            segment_start = segment.get('start_time', 0.0)
            segment_end = segment.get('end_time', 0.0)

            for i, align_item in enumerate(alignments):
                align_start = align_item.get('start', 0.0)
                align_end = align_item.get('end', 0.0)

                # Check whether the alignment item overlaps the segment's time range
                if (align_start >= segment_start and align_end <= segment_end) or \
                   (align_start < segment_end and align_end > segment_start):
                    segment_mapping['corresponding_alignments'].append({
                        'alignment_index': i,
                        'transcript': align_item.get('transcript', ''),
                        'start': align_start,
                        'end': align_end,
                        'score': align_item.get('score', 0.0) if 'score' in align_item else None
                    })

            alignment_data['segment_alignment_mapping'].append(segment_mapping)

        return self.save_alignment_info(alignment_data, input_id, "detailed")

    def remove_speaker_tags(self, text: str) -> str:
        """Remove the speaker tags [S1]/[S2] from the text."""
        return re.sub(r'\[S[12]\]', '', text).strip()

    def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]:
        """Extract the per-speaker segments from the text."""
        segments = []
        pattern = r'\[S([12])\]([^[]*)'
        matches = re.findall(pattern, text)

        for speaker_id, content in matches:
            segments.append({
                'speaker': f'S{speaker_id}',
                'content': content.strip()
            })
        return segments
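
    # Illustrative example (hypothetical input):
    #   extract_speaker_segments("[S1]你好。[S2]Hi there.")
    #   -> [{'speaker': 'S1', 'content': '你好。'}, {'speaker': 'S2', 'content': 'Hi there.'}]
    #   remove_speaker_tags("[S1]你好。[S2]Hi there.") -> "你好。Hi there."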

    def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
        """Replace all punctuation with commas, collapse runs of commas into one, and pick the comma type by language."""
        # If no language is given, use the class default or auto-detect
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)

        language = language.upper()

        # Choose the comma type and handling strategy by language
        if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
            # English: drop apostrophes first, then replace the remaining punctuation
            text = re.sub(r"'", '', text)  # drop apostrophes (don't -> dont)
            target_comma = ','  # ASCII comma
            comma_pattern = r',+'  # runs of ASCII commas
            # Character class without the apostrophe
            text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
        else:
            # Chinese: the apostrophe is included in the replacement set
            target_comma = ','  # fullwidth comma
            comma_pattern = r',+'  # runs of fullwidth commas
            # Wider character class covering more punctuation
            text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)

        text = re.sub(comma_pattern, target_comma, text)
        return text.strip(target_comma)
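
    # Illustrative examples (hypothetical inputs):
    #   replace_punctuation_with_comma("Don't stop. Go!", "EN") -> "Dont stop, Go"
    #   replace_punctuation_with_comma("你好!世界……再见。", "ZH") -> "你好,世界,再见"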

    def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
        """
        Word-level alignment between text and audio.
        Returns the audio time span for each word.
        """
        # Make sure the models are initialized
        self._init_models_if_needed()

        # If no language is given, use the class default or auto-detect
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)
        else:
            language = language.upper()

        # Load the audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to the rate required by the model
        if sample_rate != self.alignment_model.bundle.sample_rate:
            waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)

        # Convert to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        waveform = waveform.squeeze(0)  # remove the batch/channel dimension

        # Move the audio to the right device
        waveform = waveform.to(self.device)

        # Run the alignment
        try:
            alignment_results = batch_get_alignment_result(
                self.alignment_model,
                [waveform],
                [text],
                [language]
            )
            if not alignment_results or not alignment_results[0]:
                raise RuntimeError(f"Empty alignment result: {audio_path}")
            return alignment_results[0]
        except Exception as e:
            self.logger.error(f"Audio alignment failed: {audio_path}")
            self.logger.error(f"Error details: {e}")
            raise RuntimeError(f"Audio alignment failed, aborting. File: {audio_path}, error: {e}")

    def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str):
        """Cut one segment out of an audio file."""
        waveform, sample_rate = torchaudio.load(audio_path)

        start_frame = int(start_time * sample_rate)
        end_frame = int(end_time * sample_rate)

        segment = waveform[:, start_frame:end_frame]

        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        torchaudio.save(output_path, segment, sample_rate)
        return output_path

    def concatenate_audio_files(self, audio_files: List[str], output_path: str):
        """Concatenate multiple audio files."""
        if not audio_files:
            return

        waveforms = []
        sample_rate = None

        for audio_file in audio_files:
            if os.path.exists(audio_file):
                waveform, sr = torchaudio.load(audio_file)
                if sample_rate is None:
                    sample_rate = sr
                elif sr != sample_rate:
                    waveform = F.resample(waveform, sr, sample_rate)
                waveforms.append(waveform)

        if waveforms:
            concatenated = torch.cat(waveforms, dim=1)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, concatenated, sample_rate)

    def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
        """
        Split the prompt audio by speaker tags.
        Returns the audio segment paths for S1 and S2.
        """
        # 1. Extract the speaker segments
        speaker_segments = self.extract_speaker_segments(prompt_text)

        # 2. Strip the tags and run word alignment - raise immediately on failure
        clean_text = self.remove_speaker_tags(prompt_text)

        # Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)

        # Save the prompt alignment info
        prompt_alignment_data = {
            'original_text': prompt_text,
            'clean_text': clean_text,
            'audio_path': prompt_audio,
            'language': alignment_language,
            'speaker_segments': speaker_segments,
            'alignments': alignments
        }
        self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")

        # 3. Split the audio according to the alignment results
        s1_segments = []
        s2_segments = []

        # Find the time span for each speaker segment
        text_pos = 0
        for seg in speaker_segments:
            seg_text = seg['content'].strip()
            seg_length = len(seg_text)

            # Locate where this segment starts and ends in the alignment results
            start_time = None
            end_time = None

            current_pos = 0
            for align_item in alignments:
                item_text = align_item['transcript']
                item_length = len(item_text)

                if current_pos >= text_pos and current_pos < text_pos + seg_length:
                    if start_time is None:
                        start_time = align_item['start']
                    end_time = align_item['end']

                current_pos += item_length

            if start_time is not None and end_time is not None:
                if seg['speaker'] == 'S1':
                    s1_segments.append((start_time, end_time))
                else:
                    s2_segments.append((start_time, end_time))

            text_pos += seg_length

        # 4. Cut and concatenate the audio segments
        safe_audio_id = self._get_safe_filename(audio_id)
        prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
        prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")

        # Cut all S1 segments
        if s1_segments:
            s1_temp_segments = []
            for i, (start, end) in enumerate(s1_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s1_temp_segments.append(temp_path)

            # Concatenate the S1 segments
            self.concatenate_audio_files(s1_temp_segments, prompts1_path)

        # Cut all S2 segments
        if s2_segments:
            s2_temp_segments = []
            for i, (start, end) in enumerate(s2_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s2_temp_segments.append(temp_path)

            # Concatenate the S2 segments
            self.concatenate_audio_files(s2_temp_segments, prompts2_path)

        return prompts1_path, prompts2_path

    def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]:
        """
        Split the original text by speaker and punctuation at the same time, keeping the mapping.
        Handles English at the word level as well.
        """
        segments = []
        pattern = r'\[S([12])\]([^[]*)'
        matches = re.findall(pattern, original_text)

        # Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(original_text)

        segment_id = 0
        for speaker_id, content in matches:
            speaker = f'S{speaker_id}'
            clean_content = content.strip()
            comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language)

            # Split on the comma type matching the language
            if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)):
                # English: split on ASCII commas, keeping words intact
                parts = [part.strip() for part in comma_content.split(',') if part.strip()]
            else:
                # Chinese: split on fullwidth commas
                parts = [part.strip() for part in comma_content.split(',') if part.strip()]

            for part in parts:
                if part.strip():
                    segments.append({
                        'segment_id': segment_id,
                        'text': part.strip(),
                        'speaker_label': speaker,
                        'original_speaker_content': clean_content
                    })
                    segment_id += 1

        return segments
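
    # Illustrative example (hypothetical input, ZH mode):
    #   map_text_segments_to_speakers("[S1]你好,朋友。[S2]再见!")
    #   -> [{'segment_id': 0, 'text': '你好', 'speaker_label': 'S1', ...},
    #       {'segment_id': 1, 'text': '朋友', 'speaker_label': 'S1', ...},
    #       {'segment_id': 2, 'text': '再见', 'speaker_label': 'S2', ...}]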

    def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
        """
        Split the output audio on commas and return info for each short segment -
        sentences are delimited by the punctuation found in the word-alignment results.
        """
        # 1. Get the text segments and their speakers (used for the speaker labels)
        text_segments = self.map_text_segments_to_speakers(text)

        # 2. Strip the tags and replace punctuation
        clean_text = self.remove_speaker_tags(text)

        # 3. Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        # Replace punctuation using the detected language
        comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)

        # 4. Word alignment - raise immediately on failure
        alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)

        # 5. Delimit sentences at punctuation marks
        segments = []
        safe_audio_id = self._get_safe_filename(audio_id)

        # Pick the punctuation set by language (no apostrophe for English)
        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
            punctuation_chars = set([',', '.', '!', '?', ';', ':'])  # apostrophe excluded
        else:
            punctuation_chars = set([',', '。', '!', '?', ';', ':'])

        # Scan the alignment results in order, cutting a sentence at each punctuation mark
        sentence_start_idx = 0
        sentence_alignments = []
        segment_id = 0

        for i, align_item in enumerate(alignments):
            transcript = align_item['transcript']
            sentence_alignments.append(align_item)

            # Does this item contain punctuation (i.e. end a sentence)?
            has_punctuation = any(punct in transcript for punct in punctuation_chars)

            if has_punctuation or i == len(alignments) - 1:  # punctuation hit, or last word
                # Build a sentence segment
                if sentence_alignments:
                    # Sentence start and end times
                    start_time = sentence_alignments[0]['start']
                    end_time = sentence_alignments[-1]['end']

                    # Build the sentence text (punctuation stripped)
                    sentence_text_parts = []
                    for align in sentence_alignments:
                        # Choose the cleanup strategy by language
                        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                            # English: strip punctuation (apostrophes were already removed from the words)
                            clean_transcript = align['transcript'].rstrip(',.!?;:')
                        else:
                            # Chinese: strip Chinese punctuation
                            clean_transcript = align['transcript'].rstrip(',。!?;:')

                        if clean_transcript.strip():
                            sentence_text_parts.append(clean_transcript)

                    # Join by language
                    if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                        sentence_text = ' '.join(sentence_text_parts).strip()  # English: join with spaces
                    else:
                        sentence_text = ''.join(sentence_text_parts).strip()  # Chinese: join directly

                    if sentence_text:  # only handle non-empty sentences
                        # Determine the speaker label (from the original text_segments when possible)
                        speaker_label = "S1"  # default
                        if segment_id < len(text_segments):
                            speaker_label = text_segments[segment_id]['speaker_label']
                        elif text_segments:
                            # Out of range: fall back to the last segment's speaker
                            speaker_label = text_segments[-1]['speaker_label']

                        # Build the audio file path
                        safe_text = self._get_safe_filename(sentence_text, 30)
                        audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")

                        # Cut the audio
                        try:
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)
                        except Exception as e:
                            self.logger.error(f"Failed to cut audio: {e}")
                            # Fall back to a default one-second window
                            start_time = segment_id * 1.0
                            end_time = (segment_id + 1) * 1.0
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)

                        # Create the segment
                        segment = {
                            'segment_id': segment_id,
                            'text': sentence_text,
                            'speaker_label': speaker_label,
                            'original_speaker_content': sentence_text,  # simplified here
                            'audio_path': audio_path,
                            'start_time': start_time,
                            'end_time': end_time
                        }

                        segments.append(segment)

                        self.logger.info(f"Sentence {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
                        segment_id += 1

                    # Reset for the next sentence
                    sentence_alignments = []
                    sentence_start_idx = i + 1

        # Save the detailed alignment info
        self.save_detailed_alignment_info(
            alignments, segments, audio_id, output_audio, text, comma_text
        )

        self.logger.info(f"Split into {len(segments)} sentence segments in total")
        return segments
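
    # Each returned segment is shaped roughly like this (hypothetical values):
    #   {'segment_id': 0, 'text': '你好', 'speaker_label': 'S1',
    #    'original_speaker_content': '你好', 'audio_path': '.../xxx_segment_000_你好.wav',
    #    'start_time': 0.12, 'end_time': 0.58}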

    def _get_similarity_model_server(self):
        """Get the thread-local similarity model instance (thread-safe)."""
        if not hasattr(self, 'similarity_model'):
            # Create a dedicated model instance for the current thread
            self.similarity_model = self._create_similarity_model()
        return self.similarity_model

    def _create_similarity_model(self):
        """Create a new similarity model instance."""
        try:
            return TritonSpeakerClient(self.wespeaker_model_url)
        except Exception as e:
            self.logger.error(f"Failed to create the similarity model: {e}")
            raise

    async def compute_similarity(self, processed_audio1, processed_audio2):
        return await self.similarity_model.compute_similarity(processed_audio1, processed_audio2)

    async def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
        """
        Thread-safe voice-similarity computation.
        Audio segments that are too short are tiled (repeated) to reach the minimum length.
        """
        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Get the thread-local model instance
            _ = self._get_similarity_model_server()

            # Compute the similarity
            similarity = await self.compute_similarity(audio1_path, audio2_path)

            return float(similarity)

        except Exception as e:
            # Window-size errors mean the segment is too short; anything else is a real failure
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"Audio segment still too short to compute similarity: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None

    async def calculate_segment_similarities_parallel(
        self, output_segments: List[Dict[str, Any]], prompts1_path: str, prompts2_path: str
    ) -> List[Dict[str, Any]]:
        """
        Compute the similarities of all segments concurrently.
        Args:
            output_segments: list of audio segments
            prompts1_path: S1 prompt audio path
            prompts2_path: S2 prompt audio path
        Returns:
            segment list enriched with similarity info
        """

        async def calculate_single_segment_similarity(segment):
            """Compute the similarity of one segment against both prompts."""
            try:
                # Use the thread-safe similarity computation
                sim1 = await self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path)
                sim2 = await self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path)

                return {
                    'segment': segment,
                    'sim1': sim1,
                    'sim2': sim2,
                    'success': True
                }
            except Exception as e:
                self.logger.error(f"Failed to compute similarity for segment {segment['segment_id']}: {e}")
                return {
                    'segment': segment,
                    'sim1': None,
                    'sim2': None,
                    'success': False
                }

        # Process all segments concurrently
        self.logger.info(f"Starting async similarity computation for {len(output_segments)} segments")

        # Create one task per segment
        tasks = [
            asyncio.create_task(calculate_single_segment_similarity(segment))
            for segment in output_segments
        ]

        # Use asyncio.as_completed so progress can be reported as tasks finish
        return await self._run_tasks_with_progress(tasks)

    # Helper: run a set of tasks and report progress as they complete
    async def _run_tasks_with_progress(self, tasks):
        """Run a collection of tasks, reporting progress in real time."""
        completed_count = 0
        total = len(tasks)
        results = []

        # Handle results in completion order
        for future in asyncio.as_completed(tasks):
            result = await future
            completed_count += 1

            # Report progress every 10 completed segments
            if completed_count % 10 == 0 or completed_count == total:
                seg_id = result['segment']['segment_id']
                self.logger.info(f"Similarity progress: {completed_count}/{total} (last finished: {seg_id})")

            results.append(result)

        # as_completed yields results in completion order, not submission order,
        # so restore the original segment order before returning
        results.sort(key=lambda r: r['segment']['segment_id'])
        return results

    async def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
        """Evaluate the voice similarity for one input."""

        # Generate an input ID
        if input_id is None:
            input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        self.logger.info(f"Evaluating input: {input_id}, language: {self.language}")

        # 1. Get or split the prompt audio
        prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")

        # 2. Split the output audio (this also saves detailed alignment info)
        output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")

        # 3. Compute the per-segment similarities concurrently
        similarity_results = await self.calculate_segment_similarities_parallel(
            output_segments, prompts1_path, prompts2_path
        )

        # 4. Process the similarity results
        segment_results = []
        correct_predictions = 0
        total_segments = 0  # counts valid segments only
        label_similarities = []  # similarity of each segment to its labeled speaker
        skipped_segments = 0  # number of skipped segments

        for sim_result in similarity_results:
            segment = sim_result['segment']
            sim1 = sim_result['sim1']
            sim2 = sim_result['sim2']

            # Skip the segment if either similarity is None (audio too short or computation failed)
            if sim1 is None or sim2 is None:
                skipped_segments += 1
                self.logger.info(f"Skipping segment {segment['segment_id']}: similarity computation failed")
                continue

            # Only valid segments enter the statistics
            total_segments += 1

            # Decide the predicted speaker
            predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
            actual_speaker = segment['speaker_label']
            is_correct = predicted_speaker == actual_speaker

            if is_correct:
                correct_predictions += 1

            # Similarity against the labeled speaker
            if actual_speaker == 'S1':
                label_similarity = sim1
            else:
                label_similarity = sim2
            label_similarities.append(label_similarity)

            segment_result = {
                'segment_id': segment['segment_id'],
                'text': segment['text'],
                'speaker_label': actual_speaker,
                'predicted_speaker': predicted_speaker,
                'sim1': sim1,
                'sim2': sim2,
                'label_similarity': label_similarity,
                'is_correct': is_correct,
                'audio_path': segment['audio_path'],
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 1.0)
            }
            segment_results.append(segment_result)

        # 5. Compute the overall metrics (valid segments only)
        accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
        average_similarity = np.mean(label_similarities) if label_similarities else 0.0

        # 6. Save an alignment-info summary for this evaluation
        evaluation_alignment_summary = {
            'input_id': input_id,
            'language': self.language,
            'prompt_alignment_files': [
                f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            ],
            'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
            'total_segments': total_segments,
            'total_alignments_processed': len(output_segments),
            'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
        }
        self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")

        result = {
            'input_id': input_id,
            'language': self.language,
            'input_data': data,  # keep the original input data
            'prompts1_path': prompts1_path,
            'prompts2_path': prompts2_path,
            'segments': segment_results,
            'accuracy': accuracy,
            'average_similarity': average_similarity,
            'total_segments': total_segments,  # valid segment count
            'correct_predictions': correct_predictions,
            'skipped_segments': skipped_segments,  # skipped segment count
            'original_total_segments': len(output_segments),  # original segment count
            'alignment_files': {
                'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
                'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
                'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            },
            'timestamp': datetime.now().isoformat()
        }

        self.logger.info(f"Finished input: {input_id}, language: {self.language}, valid segments: {total_segments}/{len(output_segments)}, skipped: {skipped_segments}, accuracy: {accuracy:.3f}, average similarity: {average_similarity:.3f}")

        return result

    def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None):
        """Save the results to a JSONL file."""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        output_path = self.results_dir / filename

        with open(output_path, 'w', encoding='utf-8') as f:
            for result in results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')

        return str(output_path)

    def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None):
        """Save the summary report."""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json"

        summary_path = self.results_dir / filename

        # Compute the overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])
        total_segments = sum([r['total_segments'] for r in results])
        total_correct = sum([r['correct_predictions'] for r in results])

        summary = {
            'evaluation_summary': {
                'language': self.language,
                'total_inputs': len(results),
                'total_segments': total_segments,
                'total_correct_predictions': total_correct,
                'overall_accuracy': total_accuracy,
                'overall_average_similarity': total_avg_similarity,
                'evaluation_timestamp': datetime.now().isoformat(),
                'output_directory': str(self.output_dir),
                'alignment_directory': str(self.alignment_dir)
            },
            'per_input_results': [
                {
                    'input_id': r['input_id'],
                    'language': r.get('language', self.language),
                    'accuracy': r['accuracy'],
                    'average_similarity': r['average_similarity'],
                    'total_segments': r['total_segments'],
                    'correct_predictions': r['correct_predictions'],
                    'output_audio_path': r['input_data']['output_audio'],
                    'alignment_files': r.get('alignment_files', {})
                }
                for r in results
            ]
        }

        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        return str(summary_path)

    def process_batch_from_jsonl_parallel(self, jsonl_path: str,
                                          processes_per_gpu: int = 16,
                                          results_filename: str = None,
                                          shuffle_data: bool = True):
        """Batch-process inputs from a JSONL file in parallel."""
        # Load the data
        input_data = self.load_data_from_jsonl(jsonl_path)

        if not input_data:
            self.logger.error("No valid input data")
            return []

        # Shuffle the data so the split across workers is more even
        if shuffle_data:
            random.shuffle(input_data)
            self.logger.info(f"Shuffled {len(input_data)} data items")

        return self.process_batch_parallel(input_data, processes_per_gpu, results_filename)

    def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None):
        """Batch-process inputs from a JSONL file (single-process version)."""
        # Load the data
        input_data = self.load_data_from_jsonl(jsonl_path)

        if not input_data:
            self.logger.error("No valid input data")
            return []

        return asyncio.run(self.process_batch_from_data(input_data, results_filename))

    async def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
        """Process a data list (single-process version, kept for compatibility); supports incremental writes."""
        # Prepare the results file
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        results_path = self.results_dir / results_filename

        # If the file already exists, delete it (start fresh)
        if results_path.exists():
            results_path.unlink()

        results = []

        self.logger.info(f"Processing {len(input_data)} inputs, language: {self.language}...")

        for i, data in enumerate(input_data):
            input_id = f"input_{i+1:03d}"
            print(f"Processing input {i+1}/{len(input_data)}: {input_id}, language: {self.language}")

            try:
                result = await self.evaluate_single_input(data, input_id=input_id)
                results.append(result)

                # Write the result incrementally
                self.append_result_to_jsonl(result, str(results_path))

            except Exception as e:
                self.logger.error(f"Error while processing input {input_id}: {e}")
                continue

        if not results:
            self.logger.error("No inputs were processed successfully")
            return []

        # Save the summary report
        summary_path = self.save_summary_report(results)

        # Clean up the temporary files
        self._clean_temp_files()

        # Print the overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])

        print(f"\n=== Evaluation finished ===")
        print(f"Language: {self.language}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")
        print(f"All intermediate files are under: {self.output_dir}")

        return results

    def _load_wespeaker_model(self, wespeaker_model_url):
        """Load the wespeaker model."""
        try:
            self.similarity_model = TritonSpeakerClient(wespeaker_model_url)
        except ImportError:
            raise ImportError("Please install wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git")
        except Exception as e:
            self.logger.error(f"Failed to load the wespeaker model: {e}")
            raise

    def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]:
        """Load data from a JSONL file."""
        data = []
        try:
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            item = json.loads(line)
                            # Validate the required fields; the flag ensures a missing field
                            # skips the whole item, not just the inner loop step
                            required_fields = ['text', 'output_audio']
                            missing_field = False
                            for field in required_fields:
                                if field not in item:
                                    self.logger.error(f"Line {line_num} is missing required field: {field}")
                                    missing_field = True
                            if missing_field:
                                continue

                            # Validate the prompt mode: either prompt_audio + prompt_text,
                            # or separate per-speaker audio files
                            has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item
                            has_separate_prompts = ('prompt_audio_speaker1' in item and
                                                    'prompt_text_speaker1' in item and
                                                    'prompt_audio_speaker2' in item and
                                                    'prompt_text_speaker2' in item)

                            if not (has_combined_prompt or has_separate_prompts):
                                self.logger.error(f"Line {line_num}: either prompt_audio+prompt_text or separate speaker audio files must be provided")
                                continue

                            data.append(item)

                        except json.JSONDecodeError as e:
                            self.logger.error(f"JSON parse error on line {line_num}: {e}")
                            continue

            self.logger.info(f"Loaded {len(data)} items from {jsonl_path}")
            return data

        except FileNotFoundError:
            self.logger.error(f"JSONL file does not exist: {jsonl_path}")
            return []
        except Exception as e:
            self.logger.error(f"Failed to read the JSONL file: {e}")
            return []
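
    # A minimal accepted input line, as validated above (one JSON object per line; paths are hypothetical):
    #   {"text": "[S1]你好。[S2]Hi.", "output_audio": "/path/to/output.wav",
    #    "prompt_audio": "/path/to/prompt.wav", "prompt_text": "[S1]你好。[S2]Hi."}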

    @staticmethod
    def get_gpu_count():
        """Return the number of available GPUs."""
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        return 0

    @staticmethod
    def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]:
        """Split the data across GPUs."""
        if num_gpus == 0:
            return [data]

        chunk_size = math.ceil(len(data) / num_gpus)
        gpu_chunks = []

        for i in range(num_gpus):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, len(data))
            if start_idx < len(data):
                gpu_chunks.append(data[start_idx:end_idx])

        return gpu_chunks

    @staticmethod
    def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]:
        """Split the data across processes."""
        if num_processes <= 1:
            return [data]

        chunk_size = math.ceil(len(data) / num_processes)
        process_chunks = []

        for i in range(num_processes):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, len(data))
            if start_idx < len(data):
                process_chunks.append(data[start_idx:end_idx])

        return process_chunks
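
    # Illustrative example: 10 items over 4 workers -> ceil(10/4) = 3 per chunk,
    # so the chunks have sizes 3, 3, 3, 1 (a trailing worker may get less, never more).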

    def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str):
        """Append one result to a JSONL file incrementally."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with open(filepath, 'a', encoding='utf-8') as f:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
            f.flush()  # force-flush the buffer

    def merge_temp_results(self, temp_files: List[str], final_path: str):
        """Merge the temporary result files."""
        all_results = []

        for temp_file in temp_files:
            if os.path.exists(temp_file):
                try:
                    with open(temp_file, 'r', encoding='utf-8') as f:
                        for line in f:
                            line = line.strip()
                            if line:
                                result = json.loads(line)
                                all_results.append(result)
                except Exception as e:
                    self.logger.error(f"Failed to read temp file: {temp_file}, error: {e}")

        # Write the final file
        with open(final_path, 'w', encoding='utf-8') as f:
            for result in all_results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')

        return all_results

    def process_batch_parallel(self, input_data: List[Dict[str, Any]],
                               processes_per_gpu: int = 8,  # conservative process count
                               results_filename: str = None,
                               shuffle_data: bool = True):
        """Batch-process inputs in parallel."""
        # 1. Check the GPU count
        num_gpus = self.get_gpu_count()
        if num_gpus == 0:
            self.logger.warning("No GPU detected, falling back to single-process CPU processing")
            return asyncio.run(self.process_batch_from_data(input_data, results_filename))

        # Cap the processes per GPU to avoid CUDA memory conflicts
        max_processes_per_gpu = min(processes_per_gpu, 16)
        self.logger.info(f"Detected {num_gpus} GPUs, using {max_processes_per_gpu} processes per GPU")

        # 2. Shuffle the data (unless it was already shuffled)
        shuffled_data = input_data.copy()
        if shuffle_data:
            random.shuffle(shuffled_data)
            self.logger.info(f"Shuffled {len(shuffled_data)} items to balance the GPU load")

        # 3. Split the data across GPUs
        gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)

        # Log how much data each GPU received
        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if gpu_data:
                self.logger.info(f"GPU {gpu_id}: assigned {len(gpu_data)} items")

        # 4. Prepare the results file path
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        final_results_path = self.results_dir / results_filename

        # 5. Prepare the process arguments for all GPUs
        all_temp_files = []
        all_gpu_tasks = []

        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if not gpu_data:
                continue

            self.logger.info(f"GPU {gpu_id}: preparing to process {len(gpu_data)} items")

            # Split this GPU's data across its processes
            process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)

            # Build the argument tuples for this GPU's processes
            gpu_process_args = []
            for proc_id, proc_data in enumerate(process_chunks):
                if proc_data:
                    temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
                    all_temp_files.append(temp_result_file)

                    # Each subprocess writes inside the main output directory
                    subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")

                    gpu_process_args.append((
                        proc_data,
                        gpu_id,
                        proc_id,
                        subprocess_output_dir,
                        temp_result_file,
                        self.alignment_model_dir,
                        self.wespeaker_model_url,
                        self.language,  # language setting
                        self.similarity_max_workers  # similarity-computation thread count
                    ))

            if gpu_process_args:
                all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))

        # 6. Drive all GPUs in parallel with a ThreadPoolExecutor
        def process_gpu_tasks(gpu_task):
            gpu_id, process_args, actual_processes = gpu_task
            self.logger.info(f"GPU {gpu_id}: starting {len(process_args)} worker processes")

            # Each GPU gets its own process pool to avoid cross-process conflicts
            with mp.Pool(processes=actual_processes) as pool:
                # Call the synchronous wrapper run_async_worker, which runs
                # the async worker inside each subprocess.
                pool.map(run_async_worker, process_args)

            self.logger.info(f"GPU {gpu_id}: all worker processes finished")
            return gpu_id

        # One thread per GPU
        with ThreadPoolExecutor(max_workers=num_gpus) as executor:
            # Submit all GPU tasks
            future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
                             for gpu_task in all_gpu_tasks}

            # Wait for all GPUs to finish
            completed_gpus = []
            for future in as_completed(future_to_gpu):
                gpu_id = future_to_gpu[future]
                try:
                    result_gpu_id = future.result()
                    completed_gpus.append(result_gpu_id)
                    self.logger.info(f"GPU {result_gpu_id} finished")
                except Exception as exc:
                    self.logger.error(f"GPU {gpu_id} raised an exception: {exc}")

        self.logger.info(f"All GPUs finished: {completed_gpus}")

        # 7. Merge all temporary result files
        self.logger.info("Merging all temporary result files...")
        all_results = self.merge_temp_results(all_temp_files, str(final_results_path))

        if not all_results:
            self.logger.error("No data was processed successfully")
            return []

        # 8. Generate the summary report
        summary_path = self.save_summary_report(all_results)

        # 9. Clean up the temporary files
        for temp_file in all_temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)

        # 10. Print the overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in all_results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])

        print(f"\n=== Parallel evaluation finished ===")
        print(f"Language: {self.language}")
        print(f"Used {num_gpus} GPUs with {max_processes_per_gpu} processes each")
        print(f"Total inputs: {len(input_data)}")
        print(f"Successfully processed: {len(all_results)}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {final_results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")

        return all_results

    def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]:
        """
        Get or split the prompt audio.
        Uses the separate per-speaker audio files when provided, otherwise splits the combined prompt.
        """
        # Check for separate per-speaker audio files
        if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and
                'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data):

            self.logger.info("Using pre-split speaker audio files")

            # Save alignment info even for pre-split audio
            try:
                # Detect the language, or use the configured one
                alignment_language = self.language
                if alignment_language == "AUTO":
                    alignment_language = self._detect_language_from_text(data['prompt_text_speaker1'])

                # Align the S1 audio
                s1_alignments = self.align_text_with_audio(
                    data['prompt_text_speaker1'], data['prompt_audio_speaker1'], alignment_language
                )
                s1_alignment_data = {
                    'speaker': 'S1',
                    'text': data['prompt_text_speaker1'],
                    'audio_path': data['prompt_audio_speaker1'],
                    'language': alignment_language,
                    'alignments': s1_alignments
                }
                self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1")

                # Align the S2 audio
                s2_alignments = self.align_text_with_audio(
                    data['prompt_text_speaker2'], data['prompt_audio_speaker2'], alignment_language
                )
                s2_alignment_data = {
                    'speaker': 'S2',
                    'text': data['prompt_text_speaker2'],
                    'audio_path': data['prompt_audio_speaker2'],
                    'language': alignment_language,
                    'alignments': s2_alignments
                }
                self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2")

            except Exception as e:
                self.logger.warning(f"Failed to save alignment info for the pre-split audio: {e}")

            return data['prompt_audio_speaker1'], data['prompt_audio_speaker2']

        # Otherwise split the combined prompt
        elif 'prompt_audio' in data and 'prompt_text' in data:
            self.logger.info("Splitting speaker segments from the combined prompt audio")
            return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id)

        else:
            raise ValueError("Either prompt_audio+prompt_text or separate speaker audio files must be provided")

    def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float:
        """
        Compute the voice similarity of two audio files (backwards-compatible version).
        Audio segments that are too short are tiled (repeated) to reach the minimum length.
        """
        # In a worker thread, delegate to the thread-safe version; it is a coroutine,
        # so it has to be driven by an event loop (asyncio.run starts a fresh one here)
        if threading.current_thread() != threading.main_thread():
            return asyncio.run(self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path))

        # Make sure the models are initialized
        self._init_models_if_needed()

        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Check the audio lengths and pad by repetition if needed
            def process_audio_for_similarity(audio_path, min_duration=0.1):
                """
                If the audio is shorter than min_duration, tile it until it is long enough.
                Returns the processed audio path and whether it is a temporary file.
                """
                try:
                    waveform, sample_rate = torchaudio.load(audio_path)
                    duration = waveform.shape[1] / sample_rate

                    if duration >= min_duration:
                        # Long enough, return the original path
                        return audio_path, False

                    # Too short, tile the audio
                    repeat_times = math.ceil(min_duration / duration)
                    self.logger.info(f"Audio too short ({duration:.3f}s), repeating {repeat_times}x to reach {min_duration}s: {audio_path}")

                    # Repeat the audio
                    repeated_waveform = waveform.repeat(1, repeat_times)

                    # Build a temp file path
                    temp_filename = f"temp_{os.path.basename(audio_path)}"
                    temp_path = str(self.temp_dir / temp_filename)

                    # Save the tiled audio
                    torchaudio.save(temp_path, repeated_waveform, sample_rate)

                    return temp_path, True

                except Exception as e:
                    self.logger.error(f"Failed to process audio file: {audio_path}, error: {e}")
                    return audio_path, False

            # Process both audio files
            processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
            processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)

            # Compute the similarity
            similarity = self.similarity_model.compute_similarity(processed_audio1, processed_audio2)

            # Clean up the temporary files
            if is_temp1 and os.path.exists(processed_audio1):
                try:
                    os.remove(processed_audio1)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio1}, error: {e}")

            if is_temp2 and os.path.exists(processed_audio2):
                try:
                    os.remove(processed_audio2)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio2}, error: {e}")

            return float(similarity)

        except Exception as e:
            # Window-size errors mean the segment is too short; anything else is a real failure
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"Audio segment still too short to compute similarity: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None

# Module-level worker for multiprocessing (supports incremental writes)
async def process_data_chunk_incremental(args):
    """Worker coroutine that processes one data chunk (incremental-write version)."""
    data_chunk, gpu_id, proc_id, output_dir, temp_result_file, alignment_model_dir, wespeaker_model_url, language, similarity_max_workers = args

    # Pick the GPU for this process
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"

    try:
        # Reset the CUDA state to avoid cross-process conflicts
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Bind this process to its GPU
            torch.cuda.set_device(gpu_id)
            # Stagger startups slightly to avoid simultaneous initialization
            time.sleep(proc_id * 0.5)

        # Create the evaluator, passing the model paths, language, and similarity thread count
        evaluator = SpeakerSimilarityEvaluator(
            device=device,
            alignment_model_dir=alignment_model_dir,
            wespeaker_model_url=wespeaker_model_url,
            output_dir=output_dir,
            language=language,  # language setting
            similarity_max_workers=similarity_max_workers  # similarity-computation thread count
        )

        # Lazily initialize the models
        evaluator._init_models_if_needed()

        # Clear the temp result file if it exists
        if os.path.exists(temp_result_file):
            os.remove(temp_result_file)

        # Process the data chunk
        for i, data in enumerate(data_chunk):
            input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"

            try:
                result = await evaluator.evaluate_single_input(data, input_id=input_id)

                # Write the result to the temp file immediately
                evaluator.append_result_to_jsonl(result, temp_result_file)

                print(f"GPU{gpu_id}-proc{proc_id}: finished {input_id} (language: {language}, similarity threads: {similarity_max_workers})")

                # Free the CUDA cache after each item
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"GPU{gpu_id}-proc{proc_id}: failed on {input_id}: {e}")
                # Also free the CUDA cache on errors
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue

        print(f"GPU{gpu_id}-proc{proc_id}: all data processed, results written to {temp_result_file}")

    except Exception as e:
        print(f"GPU{gpu_id}-proc{proc_id}: initialization failed: {e}")
        # Free the CUDA cache on errors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def run_async_worker(args):
    """
    A synchronous wrapper that sets up and runs the asyncio event loop for the async worker.
    This is required because multiprocessing.Pool cannot call async functions directly.
    """
    # asyncio.run() is the simplest and safest way to run a coroutine in each subprocess:
    # it creates a new event loop, runs the coroutine to completion, then closes the loop.
    return asyncio.run(process_data_chunk_incremental(args))


def main():
    """Example entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
    parser.add_argument('--jsonl_path', type=str, help='Path to the JSONL input file')
    parser.add_argument('--output_dir', type=str,
                        default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                        help='Directory for the results')
    parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
                        help='Language: zh=Chinese, en=English, auto=auto-detect (default: zh)')
    parser.add_argument('--no_parallel', action='store_true', help='Disable parallel processing (enabled by default)')
    parser.add_argument('--processes_per_gpu', type=int, default=4, help='Processes per GPU (4 or fewer recommended)')
    parser.add_argument('--similarity_workers', type=int, default=16, help='Thread count for similarity computation (default: 16)')
    parser.add_argument('--no_shuffle', action='store_true', help='Disable data shuffling (enabled by default)')
    parser.add_argument('--random_seed', type=int, default=None, help='Random seed (optional, for reproducibility)')

    args = parser.parse_args()

    # Set the random seed if given
    if args.random_seed is not None:
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)
        torch.manual_seed(args.random_seed)
        print(f"Random seed set to: {args.random_seed}")

    # Normalize the language argument: 'auto' -> AUTO, 'en' -> EN, anything else -> ZH (Chinese default)
    language = args.language.upper()
    if language not in ('AUTO', 'EN'):
        language = 'ZH'

    # Create the evaluator with the output directory, language, and similarity thread count
    evaluator = SpeakerSimilarityEvaluator(
        output_dir=args.output_dir,
        language=language,
        similarity_max_workers=args.similarity_workers
    )

    # Parallel processing is the default unless explicitly disabled
    use_parallel = not args.no_parallel
    use_shuffle = not args.no_shuffle

    print(f"Language setting: {language}")
    print(f"Similarity thread count: {args.similarity_workers}")

    if args.jsonl_path:
        # Process the data from the JSONL file
        if use_parallel:
            evaluator.process_batch_from_jsonl_parallel(
                args.jsonl_path,
                processes_per_gpu=args.processes_per_gpu,
                shuffle_data=use_shuffle
            )
        else:
            asyncio.run(evaluator.process_batch_from_jsonl(args.jsonl_path))
    else:
        # Fall back to sample data (for compatibility)
        input_data = [
            {
                'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
                'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
                'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
                'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
            }
        ]

        # Process the data
        if use_parallel:
            evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
        else:
            asyncio.run(evaluator.process_batch_from_data(input_data))


if __name__ == "__main__":
    main()
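
A minimal usage sketch follows (illustrative only: the importable module name test_online and all paths are assumptions; adjust them to your checkout):

    # run_eval_demo.py - illustrative sketch, not part of the repo
    import asyncio
    from test_online import SpeakerSimilarityEvaluator

    evaluator = SpeakerSimilarityEvaluator(output_dir="./eval_res/demo", language="ZH")
    data = [{
        "text": "[S1]你好。[S2]再见。",           # hypothetical two-speaker script
        "output_audio": "./demo/output.wav",      # hypothetical synthesized dialogue
        "prompt_audio": "./demo/prompt.wav",      # hypothetical combined prompt
        "prompt_text": "[S1]你好。[S2]再见。",
    }]
    asyncio.run(evaluator.process_batch_from_data(data))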
test_online.sh
ADDED
@@ -0,0 +1,150 @@
#!/bin/bash

# CUDA environment variables
export CUDA_LAUNCH_BLOCKING=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# Color variables
RESET='\033[0m'
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
BLUE='\033[0;34m'
MAGENTA='\033[0;35m'
CYAN='\033[0;36m'
WHITE='\033[0;37m'

# Create the log directory and file names
LOG_DIR="./logs"
mkdir -p "$LOG_DIR"
# Record the start time
START_TIME=$(date +%s)
START_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')
LOG_TIME=$(date +%Y%m%d_%H%M%S)
LOG_FILE="$LOG_DIR/evaluation_$LOG_TIME.log"
SERVER_FILE="$LOG_DIR/server_$LOG_TIME.log"

# A function to ensure the server is killed; we register it to run on exit.
cleanup() {
    echo "--- Cleanup ---"
    # Check whether the server process is still running
    if kill -0 $SERVER_PID 2>/dev/null; then
        echo "Client has finished. Sending SIGTERM to shut down the server (PID: $SERVER_PID)..."
        # SIGTERM lets the server shut down gracefully if it handles the signal.
        kill $SERVER_PID
        # Wait a moment for it to terminate
        wait $SERVER_PID 2>/dev/null
        echo "Server has been shut down."
    else
        echo "Server (PID: $SERVER_PID) was already stopped."
    fi
}

# Use 'trap' to register the 'cleanup' function to be called when the script exits.
# This covers normal exit, Ctrl+C (SIGINT), and termination (SIGTERM).
trap cleanup EXIT
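For reference, `kill -0` in the cleanup function sends no signal at all; it only tests whether the PID still exists and is signalable by the current user, which is why it is safe to call on an already-stopped server. A self-contained illustration, runnable in any bash shell:

    sleep 30 &
    pid=$!
    kill -0 "$pid" && echo "process $pid is alive"              # prints: alive
    kill "$pid"; wait "$pid" 2>/dev/null
    kill -0 "$pid" 2>/dev/null || echo "process $pid is gone"   # prints: gone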
# 1. Start the server in the background.
# Note the redirection order: '> "$SERVER_FILE" 2>&1' sends both stdout and
# stderr to the log file; the reversed '2>&1 > file' would leave stderr on the terminal.
echo "Starting the Triton inference server in the background..."
/opt/tritonserver/bin/tritonserver --model-repository=./model_repo > "$SERVER_FILE" 2>&1 &

# 2. Capture the server's process ID (PID)
SERVER_PID=$!
echo "Server started with PID: $SERVER_PID"

# Give the server a moment to initialize and start listening on its port.
# This is crucial; otherwise the client might try to connect before the server is ready.
echo "Waiting 3 seconds for the server to initialize..."
echo "------------------------------------------"
sleep 3
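A fixed `sleep 3` can race a slow model load. If the HTTP endpoint is enabled, polling Triton's standard readiness route is more robust; a minimal sketch, assuming the default HTTP port 8000 and `curl` being available in the container:

    # Poll the readiness endpoint (KServe v2 API) instead of sleeping blindly
    for _ in $(seq 1 60); do
        if curl -sf http://localhost:8000/v2/health/ready > /dev/null; then
            echo "Triton server is ready."
            break
        fi
        sleep 1
    done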
echo -e "${GREEN}========================================="
echo "Speaker similarity evaluation starting"
echo "Start time: $START_TIME_READABLE"
echo "Log file: $LOG_FILE"
echo "========================================="
echo "Follow the log in real time with:"
echo -e "tail -f $LOG_FILE${RESET}"
echo ""

# Also write the start banner to the log file
{
    echo -e "${GREEN}========================================="
    echo "Speaker similarity evaluation starting"
    echo "Start time: $START_TIME_READABLE"
    echo "Process configuration: 8 processes per GPU"
    echo "Language setting: zh (Chinese)"
    echo -e "=========================================${RESET}"
    echo ""
} | tee "$LOG_FILE"
# 3. Run the client in the foreground
echo "Starting similarity test client test_online.py in the foreground..."
# The script pauses here and waits for the client to complete.
# We wrap the call in a block to capture its exit code.
{
    # Use a conservative number of processes per GPU
    python -u ./test_online.py \
        --jsonl_path /data-mnt/data/yqzhang/testset_ttsd/test_set_zh_304/output_new.jsonl \
        --output_dir ./eval_res/new_test_online \
        --processes_per_gpu 8 \
        --language zh \
        2>&1 | tee -a "$LOG_FILE"
    # '$?' here would report tee's status; PIPESTATUS[0] holds the python client's exit code.
    CLIENT_EXIT_CODE=${PIPESTATUS[0]}
}
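Why `PIPESTATUS[0]` rather than `$?`: after a pipeline, `$?` reports the status of the last command (here `tee`), so a crashing client would still look successful. A two-line demonstration in any bash shell:

    false | tee /dev/null; echo "\$? = $?"                    # 0 -- tee's status
    false | tee /dev/null; echo "client = ${PIPESTATUS[0]}"   # 1 -- the real exit code of 'false'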
echo "------------------------------------------"
echo -e "${YELLOW}The client has finished with exit code: $CLIENT_EXIT_CODE${RESET}"
# Record the end time
END_TIME=$(date +%s)
END_TIME_READABLE=$(date '+%Y-%m-%d %H:%M:%S')

# Compute the elapsed time (SECS avoids clobbering bash's special SECONDS variable)
DURATION=$((END_TIME - START_TIME))
HOURS=$((DURATION / 3600))
MINUTES=$(((DURATION % 3600) / 60))
SECS=$((DURATION % 60))

# Print the completion banner
{
    echo -e "${GREEN}"
    echo "========================================="
    echo "Speaker similarity evaluation finished!"
    echo "End time: $END_TIME_READABLE"
    echo "Total elapsed: ${HOURS}h ${MINUTES}m ${SECS}s (${DURATION}s in total)"
    echo "Log file: $LOG_FILE"
    echo "========================================="
    echo -e "${RESET}"
} | tee -a "$LOG_FILE"

# Show a summary on the terminal
echo -e "${GREEN}"
echo "Evaluation finished!"
echo "Start time: $START_TIME_READABLE"
echo "End time: $END_TIME_READABLE"
echo "Total elapsed: ${HOURS}h ${MINUTES}m ${SECS}s"
echo "Log saved to: $LOG_FILE"
echo -e "${RESET}"

# Send an extra reminder if the run took more than an hour
if [ $DURATION -gt 3600 ]; then
    echo -e "${RED}"
    echo "⏰ Note: this evaluation took longer than one hour."
    echo "   Consider checking whether the performance optimizations are taking effect."
    echo -e "${RESET}"
fi

# The 'trap' automatically calls the 'cleanup' function now that the script is exiting.
# The exit is triggered because the client process (the last foreground command) has finished.

# Additional logic based on the client's exit code, if needed.
if [ $CLIENT_EXIT_CODE -ne 0 ]; then
    echo "Warning: the client exited with an error."
    exit 1  # Exit the main script with an error code as well
fi

echo "Script finished successfully."
exit 0
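Taken together, a typical run looks like the sketch below, assuming a container where `/opt/tritonserver` and `./model_repo` exist (the JSONL path baked into the script above is site-specific):

    bash test_online.sh
    # in another terminal, follow progress:
    tail -f ./logs/evaluation_*.log

Per-sample results land under the `--output_dir` passed to the client (here `./eval_res/new_test_online`), and the Triton server log is written to `./logs/server_<timestamp>.log`.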