|
|
import base64 |
|
|
import httpx |
|
|
import re |
|
|
import requests |
|
|
import torch |
|
|
import torchaudio.functional as F |
|
|
import torchaudio |
|
|
import uroman as ur |
|
|
import logging |
|
|
import traceback |
|
|
|
|
|
|
|
|
def convert_to_list_with_punctuation_mixed(text):
    """Tokenize Chinese text that may contain embedded English words.

    Chinese is split per character, English words are kept whole, and
    punctuation is attached to a neighbouring token (the preceding one when
    it exists, otherwise the following one). This guarantees every returned
    token contains exactly one word/character plus optional punctuation, so
    the token count matches the word count used for forced alignment.

    :param text: mixed Chinese/English input string
    :return: list of tokens (empty list for blank input)
    """
    text = text.strip()
    if not text:
        return []

    # English words (letters, then optional digits) | single CJK char | single other symbol
    pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'

    result = []
    prefix = ""  # punctuation seen before any word/character; attached to the next token
    for token in re.findall(pattern, text):
        if not token.strip():
            continue

        is_word = bool(re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token))
        is_chinese = '\u4e00' <= token <= '\u9fff'

        if is_word or is_chinese:
            # Fix: leading punctuation used to become a standalone token,
            # which desynchronized tokens from the alignment word count.
            result.append(prefix + token)
            prefix = ""
        elif result and not prefix:
            # Punctuation after a word/character: glue it onto that token.
            result[-1] += token
        else:
            # Punctuation with no preceding token yet: hold for the next one.
            prefix += token

    if prefix:
        # Trailing punctuation-only input (or dangling prefix at the end).
        if result:
            result[-1] += prefix
        else:
            result.append(prefix)

    return result
|
|
|
|
|
def split_and_merge_punctuation(text):
    """Tokenize English text by word, keeping words whole.

    Each whitespace-separated element is split into alphanumeric runs and
    punctuation runs; punctuation is merged onto the neighbouring word
    (preceding when available, otherwise following). The result contains
    exactly one alphanumeric run per token, matching the word count used
    for forced alignment.

    :param text: English input string
    :return: list of word tokens with punctuation attached
    """
    result = []

    for ele in text.split():
        # Alphanumeric runs vs punctuation runs (underscores/other \w chars are dropped,
        # matching the behavior of the downstream transcript cleaning).
        parts = re.findall(r'[a-zA-Z0-9]+|[^\w\s]+', ele)

        merged_parts = []
        prefix = ""  # punctuation preceding the next word in this element
        for part in parts:
            # Fix: the previous index-parity heuristic (i % 2) assumed parts
            # strictly alternate word/punct starting with a word; elements like
            # "(world)" then produced a stray punctuation-only token. Classify
            # each part explicitly instead.
            if re.fullmatch(r'[a-zA-Z0-9]+', part):
                merged_parts.append(prefix + part)
                prefix = ""
            elif merged_parts and not prefix:
                merged_parts[-1] += part
            else:
                prefix += part

        if prefix:
            # Element was punctuation-only, or ended with unattached punctuation.
            if merged_parts:
                merged_parts[-1] += prefix
            else:
                merged_parts.append(prefix)

        result.extend(merged_parts)

    return result
|
|
|
|
|
|
|
|
def get_aligned_result_text_with_punctuation(alignment_result, text, language):
    """
    Replace each alignment item's transcript with the original-text token,
    keeping English at word level and Chinese at character level (with
    embedded English words left whole and punctuation attached).

    :param alignment_result: list of dicts, each with at least "start"/"end"
                             (plus optional extra keys which are preserved)
    :param text: the original, un-romanized transcript
    :param language: "EN" or "ZH"
    :return: new list of alignment dicts with "transcript" set from `text`
    :raises ValueError: for unsupported languages
    """
    logging.info("start change text to text_tokens")

    if language == "EN":
        text_tokens = split_and_merge_punctuation(text)
    elif language == "ZH":
        text_tokens = convert_to_list_with_punctuation_mixed(text)
    else:
        raise ValueError(f"Unsupported language: {language}")

    logging.info(f"Text tokens count: {len(text_tokens)}, Alignment result count: {len(alignment_result)}")

    # Mixed ASCII/CJK punctuation characters. Backslash escaped (was the invalid
    # escape '\、'); the runtime set still contains both '\\' and '、'.
    punctuations = set(',.!?;:()[]<>\'\"…·,。;:!?()【】《》''""\\、')

    logging.info("start get align result text with punctuation")
    updated_alignment_result = []
    token_idx = 0

    for align_item in alignment_result:
        if token_idx >= len(text_tokens):
            # More aligned words than text tokens: tokenization diverged from
            # the transcript used for alignment; keep what we have.
            logging.warning(f"Text tokens exhausted at index {token_idx}, but alignment has more items")
            break

        start = align_item["start"]
        end = align_item["end"]
        text_token = text_tokens[token_idx]

        if language == "ZH":
            # Fold any standalone punctuation tokens into the current token.
            # Fix: removed leftover debug `assert False, "???"` which crashed
            # on the first merge (and vanished entirely under `python -O`).
            while token_idx + 1 < len(text_tokens) and text_tokens[token_idx + 1] in punctuations:
                text_token += text_tokens[token_idx + 1]
                token_idx += 1

        updated_item = {
            "start": start,
            "end": end,
            "transcript": text_token
        }
        # Preserve any additional keys (e.g. pause/duration) from the original item.
        updated_item.update({key: align_item[key] for key in align_item if key not in ["start", "end", "transcript"]})

        updated_alignment_result.append(updated_item)
        token_idx += 1

    logging.info("end get align result text with punctuation")
    return updated_alignment_result
|
|
|
|
|
|
|
|
class AlignmentModel:
    def __init__(self, device, model_dir='/data-mnt/data/wy/X-Codec-2.0/checkpoints'):
        """
        Initialize the forced-alignment model and load required resources.

        :param device: torch device identifier (e.g. "cuda:0" or "cpu")
        :param model_dir: local directory holding the MMS_FA checkpoint
        """
        self.device = torch.device(device)
        # torchaudio's multilingual forced-alignment pipeline (model + dictionary + sample rate).
        self.bundle = torchaudio.pipelines.MMS_FA
        model = self.bundle.get_model(with_star=False, dl_kwargs={'model_dir': model_dir}).to(self.device)

        # torch.compile trades a one-time warm-up cost for faster repeated forwards.
        print("Compiling the model... This may take a moment.")
        self.align_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
        print("Model compiled successfully.")

        # Romanizer: maps CJK characters to Latin so they can hit the acoustic dictionary.
        self.uroman = ur.Uroman()
        # char -> token-id mapping of the bundle (includes the '*' wildcard token).
        self.DICTIONARY = self.bundle.get_dict()

    def align(self, emission, tokens):
        """
        Run forced alignment between an emission matrix and target tokens.

        :param emission: model output log-probabilities, shape (1, frames, vocab)
        :param tokens: target token ids, shape (1, num_tokens)
        :return: (alignments, scores) for the single batch item; scores are
                 exponentiated from log-probabilities into probabilities
        """
        alignments, scores = F.forced_align(
            log_probs=emission,
            targets=tokens,
            blank=0
        )
        # Drop the batch dimension (batch size is always 1 here).
        alignments, scores = alignments[0], scores[0]
        scores = scores.exp()
        return alignments, scores

    def unflatten(self, list_, lengths):
        """
        Split a flat list into consecutive sublists of the given lengths.

        :param list_: flat list whose total length must equal sum(lengths)
        :param lengths: length of each output sublist
        :return: list of sublists
        """
        assert len(list_) == sum(lengths)
        i = 0
        ret = []
        for l in lengths:
            ret.append(list_[i:i + l])
            i += l
        return ret

    def preview_word(self, waveform, spans, num_frames, transcript, sample_rate):
        """
        Convert per-word token spans into start/end timestamps (seconds).

        :param waveform: audio waveform, shape (1, samples)
        :param spans: per-word lists of token spans (frame indices)
        :param num_frames: number of emission frames for this utterance
        :param transcript: list of words, parallel to `spans`
        :param sample_rate: audio sample rate in Hz
        :return: list of dicts with transcript/start/end/pause/duration
        """
        end = 0
        alignment_result = []
        for span, trans in zip(spans, transcript):
            # samples-per-frame ratio used to map frame indices back to samples.
            ratio = waveform.size(1) / num_frames
            x0 = int(ratio * span[0].start)
            x1 = int(ratio * span[-1].end)
            align_info = {
                "transcript": trans,
                "start": round(x0 / sample_rate, 3),
                "end": round(x1 / sample_rate, 3)
            }
            # Silence between the previous word's end and this word's start.
            align_info["pause"] = round(align_info["start"] - end, 3)
            align_info["duration"] = round(align_info["end"] - align_info["start"], 3)
            end = align_info["end"]
            alignment_result.append(align_info)
        return alignment_result

    def make_wav_batch(self, wav_list):
        """
        Zero-pad each wav tensor in `wav_list` to a common length.

        :param wav_list: list of 1-D wav tensors
        :return: (padded tensor of shape (batch, max_len), original lengths on device)
        """
        wav_lengths = torch.tensor([wav.size(0) for wav in wav_list], dtype=torch.long)
        max_length = max(wav_lengths)

        wavs_tensors = torch.zeros(len(wav_list), max_length, device=self.device)
        for i, wav in enumerate(wav_list):
            wav = wav.to(self.device)
            wavs_tensors[i, :wav_lengths[i]] = wav
        return wavs_tensors, wav_lengths.to(self.device)

    def get_target(self, transcript, language):
        """
        Build the target token tensor for a transcript, keeping English
        words whole and romanizing Chinese per character.

        :param transcript: raw transcript text
        :param language: "ZH" or "EN"
        :return: int32 tensor of token ids, shape (1, num_tokens)
        """
        original_transcript = transcript

        if language == "ZH":
            # English words kept whole | single CJK char | single other symbol.
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):
                    # English word: lowercase to match the dictionary.
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':
                    # Chinese character: romanize so it maps onto Latin dictionary entries.
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:
                    # Punctuation/other: kept here, stripped by the substitution below.
                    processed_parts.append(token)

            transcript = ' '.join(processed_parts)

        elif language == "EN":
            # English needs no preprocessing beyond the shared cleanup below.
            pass
        else:
            assert False, f"Unsupported language: {language}"

        # Strip punctuation, then split into the word list alignment operates on.
        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        TRANSCRIPT = transcript.lower().split()

        # '*' is the wildcard token for characters missing from the dictionary.
        star_token = self.DICTIONARY['*']
        tokenized_transcript = []

        for word in TRANSCRIPT:
            # '-' is explicitly mapped to the wildcard as well.
            word_tokens = []
            for c in word:
                if c in self.DICTIONARY and c != '-':
                    word_tokens.append(self.DICTIONARY[c])
                else:
                    word_tokens.append(star_token)
            tokenized_transcript.extend(word_tokens)

        logging.info(f"Original transcript: {original_transcript}")
        logging.info(f"Processed transcript: {transcript}")
        logging.info(f"Final TRANSCRIPT: {TRANSCRIPT}")

        return torch.tensor([tokenized_transcript], dtype=torch.int32, device=self.device)

    def get_alignment_result(self, emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language):
        """
        Produce per-word timestamps from an emission and its alignment.

        NOTE(review): the ZH preprocessing below duplicates get_target's —
        the two must stay in sync so word counts match the aligned tokens.

        :param emission_padded: padded emission for one utterance, (max_frames, vocab)
        :param emission_length: valid number of frames in `emission_padded`
        :param aligned_tokens: token alignment from `align`
        :param alignment_scores: per-token scores from `align`
        :param transcript: raw transcript text
        :param waveform: 1-D wav tensor for this utterance
        :param language: "ZH" or "EN"
        :return: list of word-level alignment dicts (see `preview_word`)
        """
        original_transcript = transcript

        if language == "ZH":
            pattern = r'[a-zA-Z]+[a-zA-Z0-9]*|[\u4e00-\u9fff]|[^\w\s\u4e00-\u9fff]'
            tokens = re.findall(pattern, transcript)

            processed_parts = []
            for token in tokens:
                if not token.strip():
                    continue
                elif re.match(r'^[a-zA-Z]+[a-zA-Z0-9]*$', token):
                    processed_parts.append(token.lower())
                elif '\u4e00' <= token <= '\u9fff':
                    romanized = self.uroman.romanize_string(token)
                    processed_parts.append(romanized)
                else:
                    processed_parts.append(token)

            transcript = ' '.join(processed_parts)
        elif language == "EN":
            pass
        else:
            assert False, f"Unsupported language: {language}"

        transcript = re.sub(r'[^\w\s]', r' ', transcript)
        # Trim padding frames and restore the batch dimension expected downstream.
        emission = emission_padded[:emission_length, :].unsqueeze(0)
        TRANSCRIPT = transcript.lower().split()

        # Collapse frame-level alignment into per-token spans, then group by word length.
        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)

        word_spans = self.unflatten(token_spans, [len(word) for word in TRANSCRIPT])

        num_frames = emission.size(1)

        logging.info(f"Original transcript for alignment: {original_transcript}")
        logging.info(f"Processed TRANSCRIPT: {TRANSCRIPT}")

        return self.preview_word(waveform.unsqueeze(0), word_spans, num_frames, TRANSCRIPT, self.bundle.sample_rate)

    def batch_alignment(self, wav_list, transcript_list, language_list):
        """
        Align a batch of waveforms against their transcripts.

        :param wav_list: list of 1-D wav tensors
        :param transcript_list: list of transcript strings, parallel to wav_list
        :param language_list: list of language codes ("ZH"/"EN"), parallel to wav_list
        :return: list of per-utterance alignment results (see `preview_word`)
        """
        wavs_tensors, wavs_lengths_tensor = self.make_wav_batch(wav_list)
        logging.info("start alignment model forward")
        with torch.inference_mode():
            emission, emission_lengths = self.align_model(wavs_tensors.to(self.device), wavs_lengths_tensor)
            # Append a zero-logit '*' (wildcard) class so emission matches the
            # dictionary built with the star token.
            star_dim = torch.zeros((emission.shape[0], emission.size(1), 1), dtype=emission.dtype, device=self.device)
            emission = torch.cat((emission, star_dim), dim=-1)

        logging.info("end alignment model forward")

        target_list = [self.get_target(transcript, language) for transcript, language in zip(transcript_list, language_list)]

        logging.info("align success")
        # Align each utterance separately on its unpadded emission slice.
        align_results = [
            self.align(emission_padded[:emission_length, :].unsqueeze(0), target)
            for emission_padded, emission_length, target in zip(emission, emission_lengths, target_list)
        ]

        logging.info("get align result")
        batch_aligned_tokens = [align_result[0] for align_result in align_results]
        batch_alignment_scores = [align_result[1] for align_result in align_results]

        alignment_result_list = [
            self.get_alignment_result(emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language)
            for emission_padded, emission_length, aligned_tokens, alignment_scores, transcript, waveform, language
            in zip(emission, emission_lengths, batch_aligned_tokens, batch_alignment_scores, transcript_list, wav_list, language_list)
        ]
        logging.info("get align result success")
        return alignment_result_list
|
|
|
|
|
|
|
|
async def batch_get_alignment_result_remote(alignment_url, audio_path, transcript, language):
    """
    Fetch alignment results from a remote alignment service (best-effort).

    :param alignment_url: URL of the alignment service endpoint
    :param audio_path: path to the audio file, passed through to the service
    :param transcript: transcript text to align
    :param language: language code ("ZH"/"EN")
    :return: the service's 'results' payload, or None on any failure
    """
    payload = {
        "audio_path": audio_path,
        "transcript": transcript,
        "language": language,
    }

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(alignment_url, json=payload, timeout=300)
            response.raise_for_status()
            data = response.json()
            return data['results']

    # Fix: the request is made with httpx, so requests.exceptions.RequestException
    # never matched; httpx.HTTPError covers both transport errors and the
    # HTTPStatusError raised by raise_for_status().
    except httpx.HTTPError as e:
        logging.error(f"Failed to connect to alignment service: {e}")
        traceback.print_exc()

    except Exception as e:
        logging.error(f"An error occurred in remote alignment: {e}")
        traceback.print_exc()
|