import base64
import httpx
import re
import requests
import torch
import torchaudio.functional as F
import torchaudio
import uroman as ur
import logging
import traceback


def convert_to_list_with_punctuation_mixed(text):
    """Tokenize Chinese text that may contain embedded English words.

    Chinese is split into single characters, English words (optionally with
    trailing digits) are kept whole, and punctuation is glued onto the
    preceding token (or kept standalone at the very start of the text).

    :param text: input string
    :return: list of string tokens
    """
    result = []
    text = text.strip()
    if not text:
        return result

    # Match: English words (with optional digits) | single CJK character |
    # any other non-word, non-space symbol (punctuation).
    pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
    tokens = re.findall(pattern, text)

    for token in tokens:
        if not token.strip():
            # Skip empty/whitespace tokens.
            continue
        if re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):
            # English word (possibly containing digits) — keep intact.
            result.append(token)
        elif '\u4e00' <= token <= '\u9fff':
            # Single Chinese character.
            result.append(token)
        else:
            # Punctuation or other symbol: attach it to the previous token
            # when one exists; otherwise keep it as its own leading token.
            if result:
                result[-1] += token
            else:
                result.append(token)
    return result


def split_and_merge_punctuation(text):
    """Tokenize English text word by word, keeping words intact.

    Text is split on whitespace; within each chunk, punctuation runs are
    merged back onto the preceding alphanumeric run.

    :param text: input string
    :return: list of string tokens
    """
    result = []
    for element in text.split():
        # Extract alternating runs of alphanumerics and punctuation.
        parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', element)
        merged_parts = []
        for i, part in enumerate(parts):
            if i % 2 == 0:
                # Alphanumeric run — start a new token.
                merged_parts.append(part)
            else:
                # Punctuation run — merge onto the previous token if any.
                if merged_parts:
                    merged_parts[-1] += part
                else:
                    merged_parts.append(part)
        result.extend(merged_parts)
    return result


def get_aligned_result_text_with_punctuation(alignment_result, text, language):
    """Replace the transcripts in an alignment result with punctuation-aware
    text tokens.

    English stays at word level; Chinese stays at character level but keeps
    embedded English words whole.

    :param alignment_result: list of dicts each holding "start"/"end" (and
        possibly other keys) per aligned unit
    :param text: original transcript text
    :param language: "EN" or "ZH"
    :return: new list of alignment dicts with updated "transcript" fields
    :raises ValueError: for an unsupported language
    """
    logging.info("start change text to text_tokens")
    if language == "EN":
        text_tokens = split_and_merge_punctuation(text)
    elif language == "ZH":
        text_tokens = convert_to_list_with_punctuation_mixed(text)
    else:
        raise ValueError(f"Unsupported language: {language}")

    logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")
    punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》‘’“”、')

    logging.info("start get align result text with punctuation")
    updated_alignment_result = []
    token_idx = 0
    for index, align_item in enumerate(alignment_result):
        if token_idx >= len(text_tokens):
            # Alignment has more items than tokens — stop early.
            logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
            break
        start = align_item["start"]
        end = align_item["end"]
        text_token = text_tokens[token_idx]

        if language == "ZH":
            # Merge any standalone punctuation tokens following this one.
            # The ZH tokenizer already glues punctuation onto the previous
            # token, so this loop should be unreachable in practice; the
            # assert documents that invariant.
            while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
                assert False, "standalone punctuation token should not occur after ZH tokenization"
                text_token += text_tokens[token_idx + 1]
                token_idx += 1
        # EN needs no extra handling: punctuation was already merged by
        # split_and_merge_punctuation.

        updated_item = {
            "start": start,
            "end": end,
            "transcript": text_token
        }
        # Preserve any extra keys (e.g. pause/duration) from the original item.
        updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})
        updated_alignment_result.append(updated_item)
        token_idx += 1
    logging.info("end get align result text with punctuation")
    return updated_alignment_result


class AlignmentModel:
    def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
        """Initialize the forced-alignment model and supporting resources.

        :param device: torch device spec, e.g. "cuda:0" or "cpu"
        :param model_dir: local directory holding the MMS_FA checkpoint
        """
        self.device = torch.device(device)
        self.bundle = torchaudio.pipelines.MMS_FA
        model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)
        # JIT-compile the model with torch.compile; "reduce-overhead" trades
        # some compile time for lower per-call latency.
        print("Compiling the model... This may take a moment.")
        self.align_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")
        self.uroman = ur.Uroman()
        self.DICTIONARY = self.bundle.get_dict()

    def align(self, emission, tokens):
        """Run forced alignment.

        :param emission: model output log-probs, shape (1, frames, vocab)
        :param tokens: target token ids, shape (1, num_tokens)
        :return: (alignments, scores) for the single batch item; scores are
            exponentiated into probabilities
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()
        return alignments, scores

    def unflatten(self, list_, lengths):
        """Split a flat list into sublists of the given lengths.

        :param list_: flat list
        :param lengths: length of each sublist; must sum to len(list_)
        :return: list of sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

    def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
        """Compute start/end/pause/duration (seconds) for each word.

        :param waveform: audio tensor, shape (1, samples)
        :param spans: per-word token spans from the alignment
        :param num_frames: number of emission frames
        :param transcript: list of word strings, parallel to spans
        :param sample_rate: audio sample rate
        :return: list of dicts with transcript/start/end/pause/duration
        """
        end = 0
        alignment_result = []
        for span, trans in zip(spans, transcript):
            # Convert frame indices to sample indices.
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)
            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            # Pause is measured from the previous word's end.
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)
        return alignment_result

    def make_wav_batch(self, wav_list):
        """Zero-pad a list of 1-D wav tensors into one batch tensor.

        :param wav_list: list of 1-D audio tensors
        :return: (padded tensor of shape (batch, max_len) on self.device,
            original lengths tensor on self.device)
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)
        wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
        for i, wav in enumerate(wav_list):
            wav = wav.to(self.device)  # ensure source tensor is on the target device
            wavs_tensors[i, :wav_lengths[i]] = wav
        return wavs_tensors, wav_lengths.to(self.device)

    def _normalize_transcript(self, transcript, language):
        """Normalize a transcript for alignment (shared by get_target and
        get_alignment_result, which previously duplicated this logic).

        For ZH: romanize Chinese characters one by one while keeping
        embedded English words intact (lower-cased). For EN: keep as-is.
        Finally, replace all punctuation with spaces.

        :param transcript: raw transcript text
        :param language: "EN" or "ZH"
        :return: normalized transcript string
        :raises AssertionError: for an unsupported language
        """
        if language == "ZH":
            # Same tokenization scheme as convert_to_list_with_punctuation_mixed.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            processed_parts = []
            for token in re.findall(pattern, transcript):
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):
                    # English word: keep as-is (no romanization), lower-cased.
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':
                    # Chinese character: romanize individually.
                    processed_parts.append(self.uroman.romanize_string(token))
                else:
                    # Punctuation — removed by the final substitution below.
                    processed_parts.append(token)
            transcript = ' '.join(processed_parts)
        elif language == "EN":
            # English needs no romanization; only punctuation cleanup below.
            pass
        else:
            assert False, f"Unsupported language: {language}"
        # Replace all punctuation with spaces.
        return re.sub(r'[^\w\s]', r' ', transcript)

    def get_target(self, transcript, language):
        """Build the target token-id tensor for a transcript, keeping
        English words intact.

        :param transcript: raw transcript text
        :param language: "EN" or "ZH"
        :return: int32 tensor of shape (1, num_tokens) on self.device
        """
        original_transcript = transcript  # kept for debug logging
        transcript = self._normalize_transcript(transcript, language)
        TRANSCRIPT = transcript.lower().split()

        # Characters missing from the dictionary (and '-') map to <star>.
        star_token = self.DICTIONARY['*']
        tokenized_transcript = []
        for word in TRANSCRIPT:
            for c in word:
                if c in self.DICTIONARY and c != '-':
                    tokenized_transcript.append(self.DICTIONARY[c])
                else:
                    tokenized_transcript.append(star_token)

        logging.info(f"Original transcript: {original_transcript}")
        logging.info(f"Processed transcript: {transcript}")
        logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")
        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
        """Turn a forced-alignment output into word-level timing info.

        :param emission_padded: padded emission for one item, (frames, vocab)
        :param emission_length: valid frame count for this item
        :param aligned_tokens: alignment from :meth:`align`
        :param alignment_scores: scores from :meth:`align`
        :param transcript: raw transcript text
        :param waveform: 1-D audio tensor for this item
        :param language: "EN" or "ZH"
        :return: list of per-word dicts from :meth:`preview_word`
        """
        original_transcript = transcript  # kept for debug logging
        transcript = self._normalize_transcript(transcript, language)

        emission = emission_padded[:emission_length, :].unsqueeze(0)
        TRANSCRIPT = transcript.lower().split()
        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
        # Regroup the flat token spans into per-word spans.
        word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])
        num_frames = emission.size(1)

        logging.info(f"Original transcript for alignment: {original_transcript}")
        logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")
        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)

    def batch_alignment(self, wav_list, transcript_list, language_list):
        """Align a batch of audio/transcript pairs.

        :param wav_list: list of 1-D audio tensors
        :param transcript_list: list of transcript strings
        :param language_list: list of language codes ("EN"/"ZH")
        :return: list of per-item alignment results
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
        logging.info("start alignment model forward")
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
            # Append the <star> class column required by the MMS_FA dictionary.
            star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
            emission = torch.cat((emission, star_dim), dim=-1)
        logging.info("end alignment model forward")

        target_list = [self.get_target(transcript, language)
                       for transcript, language in zip(transcript_list, language_list)]
        logging.info("align success")

        # Align each item on its un-padded emission slice.
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]
        logging.info("get align result")

        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]
        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens,
                                      alignment_scores, transcript, waveform, language)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores,
                   transcript_list, wav_list, language_list)
        ]
        logging.info("get align result success")
        return alignment_result_list


async def batch_get_alignment_result_remote(alignment_url, audio_path, transcript, language):
    """Fetch alignment results from a remote alignment service.

    :param alignment_url: endpoint of the alignment service
    :param audio_path: path to the audio file (as understood by the service)
    :param transcript: transcript text
    :param language: language code ("EN"/"ZH")
    :return: the service's 'results' payload, or None on failure
    """
    payload = {
        "audio_path": audio_path,
        "transcript": transcript,
        "language": language,
    }
    try:
        async with httpx.AsyncClient() as client:
            # Long timeout: aligning long audio can take a while.
            response = await client.post(alignment_url, json=payload, timeout=300)
            response.raise_for_status()  # raise on non-2xx status
            data = response.json()
            return data['results']
    except httpx.HTTPError as e:
        # BUG FIX: this previously caught requests.exceptions.RequestException,
        # which httpx never raises — the handler was dead code and all network
        # errors fell through to the generic branch below.
        logging.error(f"Failed to connect to alignment service: {e}")
        traceback.print_exc()
        return None
    except Exception as e:
        logging.error(f"An error occurred in remote alignment: {e}")
        traceback.print_exc()
        return None