# alignment_online.py — speech-similarity / forced-alignment utilities
# (source: voxblink2_samresnet100_ft, commit 29c0409, author MCplayer)
import base64
import httpx
import re
import requests
import torch
import torchaudio.functional as F
import torchaudio
import uroman as ur
import logging
import traceback
def convert_to_list_with_punctuation_mixed(text):
    """Tokenize mixed Chinese/English text.

    Chinese characters become one token each; English words (letters,
    optionally followed by alphanumerics) stay whole; punctuation is glued
    onto the preceding token, or kept standalone only at the very start of
    the text.  Returns a list of tokens (empty for blank input).
    """
    text = text.strip()
    if not text:
        return []

    # English word | single CJK character | any single char that is neither
    # word-char, whitespace, nor CJK (i.e. punctuation/symbols).
    token_pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
    word_re = re.compile(r'^[a-zA-Z]+[a-zA-Z0-9]*$')

    tokens = []
    for piece in re.findall(token_pattern, text):
        if not piece.strip():  # skip whitespace-only matches
            continue
        if word_re.match(piece) or '\u4e00' <= piece <= '\u9fff':
            # English word or a single Chinese character: standalone token.
            tokens.append(piece)
        elif tokens:
            # Punctuation: attach to the previous token.
            tokens[-1] += piece
        else:
            # Leading punctuation with nothing before it: keep standalone.
            tokens.append(piece)
    return tokens
def split_and_merge_punctuation(text):
    """Tokenize English text at word level.

    Splits on whitespace, then within each chunk keeps alphanumeric runs as
    tokens and glues each punctuation run onto the token preceding it (a
    chunk that *starts* with punctuation keeps that run fused to the first
    alphanumeric run).  Returns the flat token list.
    """
    tokens = []
    for chunk in text.split():
        # The regex yields strictly alternating runs of alphanumerics and
        # punctuation, so run parity tells us which kind each run is.
        runs = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', chunk)
        merged = []
        for idx, run in enumerate(runs):
            if idx % 2 == 1 and merged:
                merged[-1] += run   # punctuation run: fuse into previous token
            else:
                merged.append(run)  # alphanumeric run (or a leading punctuation run)
        tokens.extend(merged)
    return tokens
def get_aligned_result_text_with_punctuation(alignment_result, text, language):
    """
    Re-attach the original (punctuated) text tokens to an alignment result.

    The aligner works on punctuation-free romanized tokens, so each item in
    `alignment_result` carries a stripped transcript.  This walks the original
    text tokenized the same way (word-level for EN; character-level for ZH with
    English words kept whole) and substitutes the punctuated token into each
    aligned item, preserving any extra keys the aligner produced.

    :param alignment_result: list of dicts, each with at least "start"/"end"
    :param text: original transcript text (with punctuation)
    :param language: "EN" or "ZH"
    :return: new list of dicts with "transcript" replaced by punctuated tokens
    :raises ValueError: on an unsupported language code
    """
    logging.info("start change text to text_tokens")
    if language == "EN":
        text_tokens = split_and_merge_punctuation(text)             # word-level tokens
    elif language == "ZH":
        text_tokens = convert_to_list_with_punctuation_mixed(text)  # char-level, English words whole
    else:
        raise ValueError(f"Unsupported language: {language}")
    logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")
    # NOTE: the trailing "\\、" keeps both the literal backslash and the
    # ideographic comma that the original (invalid) "\、" escape produced,
    # so set membership is unchanged.
    punctuations = set(",.!?;:()[]<>'\"…·,。;:!?()【】《》‘’“”\\、")
    logging.info("start get align result text with punctuation")
    updated_alignment_result = []
    token_idx = 0
    for align_item in alignment_result:
        if token_idx >= len(text_tokens):
            # More aligned items than text tokens: stop rather than crash.
            logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
            break
        text_token = text_tokens[token_idx]
        if language == "ZH":
            # Merge any standalone punctuation tokens that follow this one.
            # BUG FIX: this loop previously began with `assert False`, which
            # would crash the whole request if it ever fired (and vanish under
            # python -O).  The ZH tokenizer already glues punctuation onto the
            # previous token, so this fires only for unusual tokenizations —
            # handle it instead of asserting.
            while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
                text_token += text_tokens[token_idx + 1]
                token_idx += 1
        # EN needs no extra handling: punctuation was merged during tokenization.
        updated_item = {
            "start": align_item["start"],
            "end": align_item["end"],
            "transcript": text_token,
        }
        # Carry over any extra keys the aligner added (e.g. pause, duration).
        updated_item.update({key: value for key, value in align_item.items()
                             if key not in ("start", "end", "transcript")})
        updated_alignment_result.append(updated_item)
        token_idx += 1
    if token_idx < len(text_tokens):
        # Surface (rather than silently drop) leftover text tokens.
        logging.warning(f"{len(text_tokens) - token_idx} text tokens left unaligned")
    logging.info("end get align result text with punctuation")
    return updated_alignment_result
class AlignmentModel:
    """Forced-alignment wrapper around torchaudio's MMS_FA pipeline.

    Produces per-word (EN) / per-character (ZH) timestamps for transcripts
    against their audio, batching multiple utterances per forward pass.
    """

    def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
        """
        Initialize the alignment model and load the resources it needs.

        :param device: device identifier accepted by torch.device (e.g. "cuda:0")
        :param model_dir: local cache directory for the MMS_FA model weights
        """
        self.device = torch.device(device)
        self.bundle = torchaudio.pipelines.MMS_FA
        model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)
        # --- Core optimization ---
        # JIT-compile the model with torch.compile.
        # mode="max-autotune" would take longer to compile but reach peak
        # performance; "reduce-overhead" is the trade-off chosen here.
        print("Compiling the model... This may take a moment.")
        self.align_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
        self.uroman = ur.Uroman()  # romanizer used to map CJK characters to Latin
        self.DICTIONARY = self.bundle.get_dict()  # character -> token-id mapping

    def align(self, emission, tokens):
        """
        Run forced alignment for one utterance.

        :param emission: model output (log-probs; batch dim of 1 expected,
            since the result is indexed with [0] below)
        :param tokens: target token ids for the transcript
        :return: (frame-level alignments, per-frame scores as probabilities)
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()  # convert log-probabilities back to probabilities
        return alignments, scores

    def unflatten(self, list_, lengths):
        """
        Split a flat list into consecutive sublists of the given lengths.

        :param list_: flat list
        :param lengths: length of each sublist (must sum to len(list_))
        :return: list of sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

    def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
        """
        Convert per-word token spans into start/end timestamps.

        :param waveform: audio waveform; size(1) is the sample count
        :param spans: per-word lists of token spans (emission-frame units)
        :param num_frames: number of emission frames for this utterance
        :param transcript: list of words matching `spans` one-to-one
        :param sample_rate: audio sample rate
        :return: list of dicts with transcript/start/end/pause/duration (seconds)
        """
        end = 0
        alignment_result = []
        for span, trans in zip(spans, transcript):
            # Samples per emission frame; maps frame indices to sample indices.
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)
            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            # Silence gap since the previous word's end.
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)
        return alignment_result

    def make_wav_batch(self, wav_list):
        """
        Zero-pad the 1-D wav tensors in `wav_list` to a common length.

        :param wav_list: list of 1-D waveform tensors
        :return: (padded batch tensor on self.device, original-length tensor on self.device)
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)
        # Allocate the batch directly on the target device.
        wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
        for i, wav in enumerate(wav_list):
            wav = wav.to(self.device)  # make sure each wav is on the right device
            wavs_tensors[i, :wav_lengths[i]] = wav
        return wavs_tensors, wav_lengths.to(self.device)

    def get_target(self, transcript, language):
        """
        Build the target token-id tensor for a transcript.

        For ZH, English words embedded in the text are kept whole (lowercased)
        while each Chinese character is romanized individually; for EN the
        text is used as-is.  Punctuation is then stripped and each remaining
        character is mapped through the pipeline dictionary ('*' for unknowns).

        :param transcript: raw transcript text
        :param language: "EN" or "ZH"
        :return: int32 tensor of token ids, shape (1, num_tokens), on self.device
        """
        original_transcript = transcript  # keep the raw text for debug logging
        if language == "ZH":
            # ZH: keep English words intact; romanize only Chinese characters.
            # Same tokenization pattern as the punctuation re-attachment code —
            # the two must stay in sync.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)
            # Handle Chinese characters and English words separately.
            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word
                    # Keep English words as-is (no romanization).
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':  # Chinese character
                    # Romanize Chinese characters only.
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:  # punctuation and other symbols
                    # Appended here, but removed by the punctuation strip below.
                    processed_parts.append(token)
            # Join all parts with spaces.
            transcript = ' '.join(processed_parts)
        elif language == "EN":
            # EN: keep word structure; punctuation is cleaned below.
            pass
        else:
            assert False, f"Unsupported language: {language}"
        # Strip punctuation (anything that is not word-char or whitespace).
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        TRANSCRIPT = transcript.lower().split()
        # Look up the special "star" (unknown) token id once, outside the loop.
        star_token = self.DICTIONARY['*']
        tokenized_transcript = []
        # Shared tokenization for both languages.
        for word in TRANSCRIPT:
            # Map each character of the word to its token id.
            word_tokens = []
            for c in word:
                if c in self.DICTIONARY and c != '-':
                    word_tokens.append(self.DICTIONARY[c])
                else:
                    word_tokens.append(star_token)
            tokenized_transcript.extend(word_tokens)
        logging.info(f"Original transcript: {original_transcript}")
        logging.info(f"Processed transcript: {transcript}")
        logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")
        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
        """
        Turn one utterance's alignment into word-level timing info.

        Re-applies the same transcript preprocessing as get_target so that the
        word grouping matches the token targets used for alignment.

        :param emission_padded: padded emission for this utterance
        :param emission_length: number of valid frames in the emission
        :param aligned_tokens: frame-level alignment from align()
        :param alignment_scores: frame-level scores from align()
        :param transcript: raw transcript text
        :param waveform: 1-D waveform tensor for this utterance
        :param language: "EN" or "ZH"
        :return: list of per-word timing dicts (see preview_word)
        """
        original_transcript = transcript  # keep the raw text for debug logging
        if language == "ZH":
            # Same preprocessing as get_target — the two must stay in sync.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)
            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):  # English word
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':  # Chinese character
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:  # punctuation and other symbols
                    processed_parts.append(token)
            transcript = ' '.join(processed_parts)
        elif language == "EN":
            pass
        else:
            assert False, f"Unsupported language: {language}"
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        # Drop the padding frames before computing spans.
        emission = emission_padded[:emission_length, :].unsqueeze(0)
        TRANSCRIPT = transcript.lower().split()
        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
        # Group token spans back into words (one group per word, by word length).
        word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])
        num_frames = emission.size(1)
        logging.info(f"Original transcript for alignment: {original_transcript}")
        logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")
        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)

    def batch_alignment(self, wav_list, transcript_list, language_list):
        """
        Align a batch of utterances.

        :param wav_list: list of 1-D waveform tensors
        :param transcript_list: list of transcripts (one per waveform)
        :param language_list: list of language codes ("EN"/"ZH", one per waveform)
        :return: list of per-utterance alignment results (see preview_word)
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
        logging.info("start alignment model forward")
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
            # Append an all-zero class to the emission — the slot for the
            # "star" (unknown) token produced by get_target.
            star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
            emission = torch.cat((emission, star_dim), dim=-1)
        logging.info("end alignment model forward")
        target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]
        logging.info("align success")
        # Forced alignment per utterance, on the unpadded emission slice.
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]
        logging.info("get align result")
        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]
        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
        ]
        logging.info("get align result success")
        return alignment_result_list
async def batch_get_alignment_result_remote(alignment_url, audio_path, transcript, language):
    """
    Fetch alignment results from a remote alignment service.

    :param alignment_url: POST endpoint of the alignment service
    :param audio_path: path to the audio file (as understood by the service)
    :param transcript: transcript text to align
    :param language: language code ("EN" / "ZH")
    :return: the service's "results" payload, or None on any failure
    """
    payload = {
        "audio_path": audio_path,
        "transcript": transcript,
        "language": language,
    }
    try:
        async with httpx.AsyncClient() as client:
            # Long timeout: alignment of long audio can be slow.
            response = await client.post(alignment_url, json=payload, timeout=300)
            response.raise_for_status()  # raise on non-2xx status codes
            data = response.json()
            return data['results']
    except httpx.HTTPError as e:
        # BUG FIX: this previously caught requests.exceptions.RequestException,
        # which httpx never raises, so transport/status errors always fell
        # through to the generic handler.  httpx.HTTPError is the base class
        # for both RequestError (connect/timeout) and HTTPStatusError
        # (raise_for_status failures).
        logging.error(f"Failed to connect to alignment service: {e}")
        traceback.print_exc()
        # Implicitly returns None; callers may prefer [] or a re-raise.
    except Exception as e:
        logging.error(f"An error occurred in remote alignment: {e}")
        traceback.print_exc()