import asyncio
import json
import re
import os
from typing import List, Dict, Tuple, Any
import numpy as np
from pathlib import Path
import torch
import torchaudio
import torchaudio.functional as F
import tritonclient.grpc as grpcclient
from tritonclient.utils import *
import logging
import wespeaker
import shutil
from datetime import datetime
import multiprocessing as mp
from functools import partial
import math
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import random  # random module used for shuffling

# Set the multiprocessing start method to spawn (CUDA compatible)
mp.set_start_method('spawn', force=True)

# Word-alignment module
from alignment import AlignmentModel, batch_get_alignment_result
# from tensorrt_client import TritonSimilarityClient
from speaker_client import TritonSpeakerClient


class SpeakerSimilarityEvaluator:
    """Speaker (voice timbre) similarity evaluator."""

    def __init__(self, device="cuda", alignment_model_dir='./models/mms_fa',
                 wespeaker_model_url='localhost:8001', output_dir="./evaluation_results",
                 language="ZH", similarity_max_workers=8):
        """Initialize the evaluator."""
        self.device = device
        self.alignment_model_dir = alignment_model_dir
        self.wespeaker_model_url = wespeaker_model_url
        self.language = language.upper()  # language setting
        self.similarity_max_workers = similarity_max_workers  # similarity worker count (no longer used)

        # Set up logging first
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Output directory layout
        self.output_dir = Path(output_dir)
        self.segments_dir = self.output_dir / "segments"          # split audio segments
        self.prompts_dir = self.output_dir / "prompts"            # S1/S2 slices of the prompt audio
        self.temp_dir = self.output_dir / "temp"                  # temporary files
        self.results_dir = self.output_dir / "results"            # evaluation results
        self.temp_results_dir = self.output_dir / "temp_results"  # temporary result files
        self.alignment_dir = self.output_dir / "alignments"       # saved alignment information

        # Create every required directory
        self._create_output_directories()

        # Defer model initialization (needed in multiprocessing environments)
        self.alignment_model = None
        self.similarity_model = None

        # Thread-local storage for thread-safe model access
        self._thread_local = threading.local()

        # Log run information
        self.logger.info(f"Evaluation results will be saved to: {self.output_dir}")
        self.logger.info(f"Alignment information will be saved to: {self.alignment_dir}")
        self.logger.info(f"Language: {self.language}")

    def _create_output_directories(self):
        """Create the output directory structure."""
        for dir_path in [self.segments_dir, self.prompts_dir, self.temp_dir,
                         self.results_dir, self.temp_results_dir, self.alignment_dir]:
            dir_path.mkdir(parents=True, exist_ok=True)

    def _get_safe_filename(self, text: str, max_length: int = 50) -> str:
        """Build a filesystem-safe file name."""
        # Strip special characters, keeping only CJK, word characters, digits and whitespace
        safe_text = re.sub(r'[^\u4e00-\u9fff\w\s]', '', text)
        # Limit the length
        if len(safe_text) > max_length:
            safe_text = safe_text[:max_length]
        # Replace spaces with underscores
        safe_text = safe_text.replace(' ', '_')
        return safe_text if safe_text else "unnamed"

    def _clean_temp_files(self):
        """Remove temporary files while keeping the temp directory itself."""
        if self.temp_dir.exists():
            # Only delete files inside the temp directory, not the directory itself
            for file_path in self.temp_dir.iterdir():
                if file_path.is_file():
                    try:
                        file_path.unlink()
                    except Exception as e:
                        self.logger.warning(f"Failed to delete temp file: {file_path}, error: {e}")
        else:
            # Recreate the temp directory if it is missing
            self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _init_models_if_needed(self):
        """Lazily initialize models (for multiprocessing environments)."""
        # Initialize the alignment model - note the argument order
        if self.alignment_model is None:
            # AlignmentModel's constructor expects (device, model_dir), not (model_dir, device)
            self.alignment_model = AlignmentModel(self.device, self.alignment_model_dir)
        # Initialize the similarity model
        if self.similarity_model is None:
            self._load_wespeaker_model(self.wespeaker_model_url)
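    # For reference, the directory tree created by _create_output_directories()
    # under --output_dir looks like this (names taken from __init__ above):
    #
    #   <output_dir>/
    #     segments/       per-sentence slices of each output audio
    #     prompts/        <id>_s1.wav / <id>_s2.wav prompt slices
    #     temp/           short-lived intermediate wavs
    #     results/        final JSONL results and summary reports
    #     temp_results/   per-process partial JSONL files (parallel mode)
    #     alignments/     *_alignment.json files written by save_alignment_info()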
    def _is_english_text(self, text: str) -> bool:
        """Rough check of whether a text is mostly English."""
        # Ratio of English letters among all alphabetic characters
        english_chars = sum(1 for c in text if c.isascii() and c.isalpha())
        total_chars = sum(1 for c in text if c.isalpha())
        if total_chars == 0:
            return False
        return english_chars / total_chars > 0.8  # treat as English if more than 80% of letters are ASCII

    def _detect_language_from_text(self, text: str) -> str:
        """Detect the language from the text content."""
        clean_text = self.remove_speaker_tags(text)
        if self._is_english_text(clean_text):
            return "EN"
        else:
            return "ZH"

    def save_alignment_info(self, alignment_data: Dict[str, Any], input_id: str,
                            file_type: str = "output"):
        """
        Save alignment information to a standalone JSON file.

        Args:
            alignment_data: alignment data
            input_id: input ID
            file_type: file type ("output", "prompt", "segment")
        """
        try:
            safe_input_id = self._get_safe_filename(input_id)
            alignment_filename = f"{safe_input_id}_{file_type}_alignment.json"
            alignment_path = self.alignment_dir / alignment_filename

            # Add metadata
            alignment_info = {
                'input_id': input_id,
                'file_type': file_type,
                'language': self.language,
                'timestamp': datetime.now().isoformat(),
                'alignment_data': alignment_data
            }

            with open(alignment_path, 'w', encoding='utf-8') as f:
                json.dump(alignment_info, f, ensure_ascii=False, indent=2)

            self.logger.info(f"Alignment info saved: {alignment_path}")
            return str(alignment_path)
        except Exception as e:
            self.logger.error(f"Failed to save alignment info: {e}")
            return None

    def save_detailed_alignment_info(self, alignments: List[Dict[str, Any]],
                                     text_segments: List[Dict[str, Any]],
                                     input_id: str, audio_path: str,
                                     original_text: str, processed_text: str):
        """
        Save detailed alignment information, including segment information.

        Args:
            alignments: list of alignment results
            text_segments: text segment information
            input_id: input ID
            audio_path: audio file path
            original_text: original text
            processed_text: processed text
        """
        alignment_data = {
            'original_text': original_text,
            'processed_text': processed_text,
            'audio_path': audio_path,
            'language': self.language,
            'total_alignments': len(alignments),
            'total_segments': len(text_segments),
            'alignments': alignments,
            'text_segments': text_segments,
            'segment_alignment_mapping': []
        }

        # Map text segments onto alignment results
        for segment in text_segments:
            segment_mapping = {
                'segment_id': segment.get('segment_id', 0),
                'segment_text': segment.get('text', ''),
                'speaker_label': segment.get('speaker_label', ''),
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 0.0),
                'corresponding_alignments': []
            }

            # Find the alignment items that belong to this segment
            segment_start = segment.get('start_time', 0.0)
            segment_end = segment.get('end_time', 0.0)

            for i, align_item in enumerate(alignments):
                align_start = align_item.get('start', 0.0)
                align_end = align_item.get('end', 0.0)

                # Check whether the alignment item overlaps the segment's time range
                if (align_start >= segment_start and align_end <= segment_end) or \
                   (align_start < segment_end and align_end > segment_start):
                    segment_mapping['corresponding_alignments'].append({
                        'alignment_index': i,
                        'transcript': align_item.get('transcript', ''),
                        'start': align_start,
                        'end': align_end,
                        'score': align_item.get('score', 0.0) if 'score' in align_item else None
                    })

            alignment_data['segment_alignment_mapping'].append(segment_mapping)

        return self.save_alignment_info(alignment_data, input_id, "detailed")

    def remove_speaker_tags(self, text: str) -> str:
        """Remove the speaker tags [S1]/[S2] from the text."""
        return re.sub(r'\[S[12]\]', '', text).strip()

    def extract_speaker_segments(self, text: str) -> List[Dict[str, Any]]:
        """Extract per-speaker segments from the text."""
        segments = []
        pattern = r'\[S([12])\]([^[]*)'
        matches = re.findall(pattern, text)

        for speaker_id, content in matches:
            segments.append({
                'speaker': f'S{speaker_id}',
                'content': content.strip()
            })
        return segments

    def replace_punctuation_with_comma(self, text: str, language: str = None) -> str:
        """Replace all punctuation with commas, collapse consecutive commas, and pick the comma type for the language."""
        # If no language is given, fall back to the configured language or auto-detection
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)
        language = language.upper()
        # Choose the comma type and handling strategy based on the language
        if language == "EN" or (language == "AUTO" and self._is_english_text(text)):
            # English: drop apostrophes first, then replace the remaining punctuation
            text = re.sub(r"'", '', text)  # remove apostrophes (don't -> dont)
            target_comma = ','   # ASCII comma
            comma_pattern = r',+'  # consecutive ASCII commas
            # Punctuation class without the apostrophe
            text = re.sub(r'[.,!?;:()\[\]<>\"…·,。;:!?()【】《》""\\、]', target_comma, text)
        else:
            # Chinese: include the apostrophe in the replacement set
            target_comma = ','   # full-width comma
            comma_pattern = r',+'  # consecutive full-width commas
            # Punctuation class covering both ASCII and full-width marks
            text = re.sub(r'[.,!?;:()\[\]<>\'\"…·,。;:!?()【】《》''""\\、]', target_comma, text)

        text = re.sub(comma_pattern, target_comma, text)
        return text.strip(target_comma)

    def align_text_with_audio(self, text: str, audio_path: str, language=None) -> List[Dict[str, Any]]:
        """
        Word-level alignment between text and audio.
        Returns the audio time span for each word.
        """
        # Make sure the models are initialized
        self._init_models_if_needed()

        # If no language is given, fall back to the configured language or auto-detection
        if language is None:
            if hasattr(self, 'language'):
                language = self.language
            else:
                language = self._detect_language_from_text(text)
        else:
            language = language.upper()

        # Load the audio
        waveform, sample_rate = torchaudio.load(audio_path)

        # Resample to the rate expected by the alignment model
        if sample_rate != self.alignment_model.bundle.sample_rate:
            waveform = F.resample(waveform, sample_rate, self.alignment_model.bundle.sample_rate)

        # Convert to mono
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # remove the batch dimension

        # Move the audio to the right device
        waveform = waveform.to(self.device)

        # Run the alignment
        try:
            alignment_results = batch_get_alignment_result(
                self.alignment_model, [waveform], [text], [language]
            )
            if not alignment_results or not alignment_results[0]:
                raise RuntimeError(f"Empty alignment result: {audio_path}")
            return alignment_results[0]
        except Exception as e:
            self.logger.error(f"Audio alignment failed: {audio_path}")
            self.logger.error(f"Error details: {e}")
            raise RuntimeError(f"Audio alignment failed, aborting. File: {audio_path}, error: {e}")

    def split_audio_segment(self, audio_path: str, start_time: float, end_time: float, output_path: str):
        """Cut one segment out of an audio file."""
        waveform, sample_rate = torchaudio.load(audio_path)
        start_frame = int(start_time * sample_rate)
        end_frame = int(end_time * sample_rate)
        segment = waveform[:, start_frame:end_frame]

        # Make sure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        torchaudio.save(output_path, segment, sample_rate)
        return output_path

    def concatenate_audio_files(self, audio_files: List[str], output_path: str):
        """Concatenate several audio files."""
        if not audio_files:
            return

        waveforms = []
        sample_rate = None

        for audio_file in audio_files:
            if os.path.exists(audio_file):
                waveform, sr = torchaudio.load(audio_file)
                if sample_rate is None:
                    sample_rate = sr
                elif sr != sample_rate:
                    waveform = F.resample(waveform, sr, sample_rate)
                waveforms.append(waveform)

        if waveforms:
            concatenated = torch.cat(waveforms, dim=1)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            torchaudio.save(output_path, concatenated, sample_rate)
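    # Note on the alignment data structure: the exact return type of
    # batch_get_alignment_result() lives in the external `alignment` module, but the
    # code below only relies on each item exposing 'transcript', 'start', 'end' and
    # (optionally) 'score'. A hypothetical single result therefore looks roughly like:
    #
    #   [
    #       {'transcript': '你好,', 'start': 0.32, 'end': 0.78, 'score': 0.93},
    #       {'transcript': '世界', 'start': 0.80, 'end': 1.21, 'score': 0.95},
    #   ]
    #
    # (times in seconds; the values above are illustrative only).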
    def split_audio_by_speaker(self, prompt_text: str, prompt_audio: str, audio_id: str) -> Tuple[str, str]:
        """
        Split the prompt audio according to the speaker tags.
        Returns the paths of the S1 and S2 audio clips.
        """
        # 1. Extract the per-speaker segments
        speaker_segments = self.extract_speaker_segments(prompt_text)

        # 2. Strip the tags and run word alignment - raise immediately on failure
        clean_text = self.remove_speaker_tags(prompt_text)

        # Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        alignments = self.align_text_with_audio(clean_text, prompt_audio, alignment_language)

        # Save the prompt alignment information
        prompt_alignment_data = {
            'original_text': prompt_text,
            'clean_text': clean_text,
            'audio_path': prompt_audio,
            'language': alignment_language,
            'speaker_segments': speaker_segments,
            'alignments': alignments
        }
        self.save_alignment_info(prompt_alignment_data, audio_id, "prompt")

        # 3. Split the audio according to the alignment results
        s1_segments = []
        s2_segments = []

        # Find the time span for every speaker segment
        text_pos = 0
        for seg in speaker_segments:
            seg_text = seg['content'].strip()
            seg_length = len(seg_text)

            # Locate the start and end of this segment within the alignment results
            start_time = None
            end_time = None

            current_pos = 0
            for align_item in alignments:
                item_text = align_item['transcript']
                item_length = len(item_text)

                if current_pos >= text_pos and current_pos < text_pos + seg_length:
                    if start_time is None:
                        start_time = align_item['start']
                    end_time = align_item['end']

                current_pos += item_length

            if start_time is not None and end_time is not None:
                if seg['speaker'] == 'S1':
                    s1_segments.append((start_time, end_time))
                else:
                    s2_segments.append((start_time, end_time))

            text_pos += seg_length

        # 4. Cut out and concatenate the audio clips
        safe_audio_id = self._get_safe_filename(audio_id)
        prompts1_path = str(self.prompts_dir / f"{safe_audio_id}_s1.wav")
        prompts2_path = str(self.prompts_dir / f"{safe_audio_id}_s2.wav")

        # Cut all S1 clips
        if s1_segments:
            s1_temp_segments = []
            for i, (start, end) in enumerate(s1_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s1_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s1_temp_segments.append(temp_path)
            # Concatenate the S1 clips
            self.concatenate_audio_files(s1_temp_segments, prompts1_path)

        # Cut all S2 clips
        if s2_segments:
            s2_temp_segments = []
            for i, (start, end) in enumerate(s2_segments):
                temp_path = str(self.temp_dir / f"{safe_audio_id}_s2_temp_{i}.wav")
                self.split_audio_segment(prompt_audio, start, end, temp_path)
                s2_temp_segments.append(temp_path)
            # Concatenate the S2 clips
            self.concatenate_audio_files(s2_temp_segments, prompts2_path)

        return prompts1_path, prompts2_path

    def map_text_segments_to_speakers(self, original_text: str) -> List[Dict[str, Any]]:
        """
        Split the original text by speaker and by punctuation at the same time, keeping the mapping.
        Supports English word-level handling.
        """
        segments = []
        pattern = r'\[S([12])\]([^[]*)'
        matches = re.findall(pattern, original_text)

        # Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(original_text)

        segment_id = 0
        for speaker_id, content in matches:
            speaker = f'S{speaker_id}'
            clean_content = content.strip()
            comma_content = self.replace_punctuation_with_comma(clean_content, alignment_language)

            # Split on the comma type that matches the language
            if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_content)):
                # English: split on ASCII commas, keeping words intact
                parts = [part.strip() for part in comma_content.split(',') if part.strip()]
            else:
                # Chinese: split on full-width commas
                parts = [part.strip() for part in comma_content.split(',') if part.strip()]

            for part in parts:
                if part.strip():
                    segments.append({
                        'segment_id': segment_id,
                        'text': part.strip(),
                        'speaker_label': speaker,
                        'original_speaker_content': clean_content
                    })
                    segment_id += 1

        return segments
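    # For example (illustrative values), a tagged script such as
    #   "[S1]你好,今天天气不错。[S2]是啊,我们出去走走吧。"
    # is mapped by map_text_segments_to_speakers() to segments like:
    #   [{'segment_id': 0, 'text': '你好', 'speaker_label': 'S1', ...},
    #    {'segment_id': 1, 'text': '今天天气不错', 'speaker_label': 'S1', ...},
    #    {'segment_id': 2, 'text': '是啊', 'speaker_label': 'S2', ...},
    #    {'segment_id': 3, 'text': '我们出去走走吧', 'speaker_label': 'S2', ...}]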
    def split_output_audio_by_comma(self, text: str, output_audio: str, audio_id: str) -> List[Dict[str, Any]]:
        """
        Split the output audio at commas and return information for each piece -
        sentence boundaries are taken from the punctuation in the word-alignment results.
        """
        # 1. Get the text segments and their speakers (used for the speaker labels)
        text_segments = self.map_text_segments_to_speakers(text)

        # 2. Strip the speaker tags and replace punctuation
        clean_text = self.remove_speaker_tags(text)

        # 3. Detect the language, or use the configured one
        alignment_language = self.language
        if alignment_language == "AUTO":
            alignment_language = self._detect_language_from_text(clean_text)

        # Replace punctuation using the detected language
        comma_text = self.replace_punctuation_with_comma(clean_text, alignment_language)

        # 4. Word alignment - raise immediately on failure
        alignments = self.align_text_with_audio(comma_text, output_audio, alignment_language)

        # 5. Split into sentences at punctuation
        segments = []
        safe_audio_id = self._get_safe_filename(audio_id)

        # Pick the punctuation set for the language (English excludes the apostrophe)
        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
            punctuation_chars = set([',', '.', '!', '?', ';', ':'])  # no apostrophe
        else:
            punctuation_chars = set([',', '。', '!', '?', ';', ':'])

        # Scan the alignment results in order and cut sentences at punctuation
        sentence_start_idx = 0
        sentence_alignments = []
        segment_id = 0

        for i, align_item in enumerate(alignments):
            transcript = align_item['transcript']
            sentence_alignments.append(align_item)

            # A punctuation mark signals the end of a sentence
            has_punctuation = any(punct in transcript for punct in punctuation_chars)

            if has_punctuation or i == len(alignments) - 1:  # punctuation found, or last word
                # Build the sentence segment
                if sentence_alignments:
                    # Sentence start and end times
                    start_time = sentence_alignments[0]['start']
                    end_time = sentence_alignments[-1]['end']

                    # Build the sentence text (punctuation removed)
                    sentence_text_parts = []
                    for align in sentence_alignments:
                        # Cleaning strategy depends on the language
                        if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                            # English: strip punctuation (apostrophes were already removed)
                            clean_transcript = align['transcript'].rstrip(',.!?;:')
                        else:
                            # Chinese: strip Chinese punctuation
                            clean_transcript = align['transcript'].rstrip(',。!?;:')
                        if clean_transcript.strip():
                            sentence_text_parts.append(clean_transcript)

                    # Join according to the language
                    if alignment_language == "EN" or (alignment_language == "AUTO" and self._is_english_text(clean_text)):
                        sentence_text = ' '.join(sentence_text_parts).strip()  # English: join with spaces
                    else:
                        sentence_text = ''.join(sentence_text_parts).strip()   # Chinese: join directly

                    if sentence_text:  # only keep non-empty sentences
                        # Determine the speaker label (from text_segments when possible)
                        speaker_label = "S1"  # default
                        if segment_id < len(text_segments):
                            speaker_label = text_segments[segment_id]['speaker_label']
                        elif text_segments:
                            # Out of range: reuse the last segment's speaker
                            speaker_label = text_segments[-1]['speaker_label']

                        # Build the audio file path
                        safe_text = self._get_safe_filename(sentence_text, 30)
                        audio_path = str(self.segments_dir / f"{safe_audio_id}_segment_{segment_id:03d}_{safe_text}.wav")

                        # Cut the audio
                        try:
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)
                        except Exception as e:
                            self.logger.error(f"Failed to split audio: {e}")
                            # Fall back to a default time span
                            start_time = segment_id * 1.0
                            end_time = (segment_id + 1) * 1.0
                            self.split_audio_segment(output_audio, start_time, end_time, audio_path)

                        # Create the segment
                        segment = {
                            'segment_id': segment_id,
                            'text': sentence_text,
                            'speaker_label': speaker_label,
                            'original_speaker_content': sentence_text,  # simplified here
                            'audio_path': audio_path,
                            'start_time': start_time,
                            'end_time': end_time
                        }
                        segments.append(segment)

                        self.logger.info(f"Sentence {segment_id}: '{sentence_text}' ({speaker_label}) -> {start_time:.3f}-{end_time:.3f}s")
                        segment_id += 1

                # Reset for the next sentence
                sentence_alignments = []
                sentence_start_idx = i + 1

        # Save the detailed alignment information
        self.save_detailed_alignment_info(
            alignments, segments, audio_id, output_audio, text, comma_text
        )

        self.logger.info(f"Split into {len(segments)} sentence segments in total")
        return segments

    def _get_similarity_model_server(self):
        """Return the similarity model client, creating it lazily if needed."""
        # __init__ always sets self.similarity_model (possibly to None), so check for None
        # rather than hasattr(), otherwise the client would never be created here.
        if self.similarity_model is None:
            self.similarity_model = self._create_similarity_model()
        return self.similarity_model
    def _create_similarity_model(self):
        """Create a new similarity model client."""
        try:
            return TritonSpeakerClient(self.wespeaker_model_url)
        except Exception as e:
            self.logger.error(f"Failed to create the similarity model: {e}")
            raise

    async def compute_similarity(self, processed_audio1, processed_audio2):
        return await self.similarity_model.compute_similarity(processed_audio1, processed_audio2)

    async def calculate_voice_similarity_thread_safe(self, audio1_path: str, audio2_path: str) -> float:
        """
        Thread-safe speaker similarity computation.
        Returns None when a clip is missing or the similarity cannot be computed.
        """
        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Make sure the similarity client exists
            _ = self._get_similarity_model_server()

            # Compute the similarity
            similarity = await self.compute_similarity(audio1_path, audio2_path)
            return float(similarity)

        except Exception as e:
            # Distinguish window-size errors (clip too short) from other failures
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"Audio clip still too short to compute similarity: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None

    async def calculate_segment_similarities_parallel(self,
                                                      output_segments: List[Dict[str, Any]],
                                                      prompts1_path: str,
                                                      prompts2_path: str) -> List[Dict[str, Any]]:
        """
        Compute the similarity of all segments concurrently.

        Args:
            output_segments: list of audio segments
            prompts1_path: path of the S1 prompt audio
            prompts2_path: path of the S2 prompt audio

        Returns:
            The segment list with similarity information attached.
        """
        async def calculate_single_segment_similarity(segment):
            """Compute the similarity of one segment against both prompts."""
            try:
                # Thread-safe similarity computation
                sim1 = await self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts1_path)
                sim2 = await self.calculate_voice_similarity_thread_safe(segment['audio_path'], prompts2_path)

                return {
                    'segment': segment,
                    'sim1': sim1,
                    'sim2': sim2,
                    'success': True
                }
            except Exception as e:
                self.logger.error(f"Failed to compute similarity for segment {segment['segment_id']}: {e}")
                return {
                    'segment': segment,
                    'sim1': None,
                    'sim2': None,
                    'success': False
                }

        self.logger.info(f"Starting async similarity computation for {len(output_segments)} segments")

        # Create one task per segment (tasks are created in the original segment order)
        tasks = [
            asyncio.create_task(calculate_single_segment_similarity(segment))
            for segment in output_segments
        ]

        # Run the tasks with asyncio.as_completed so progress can be reported in real time
        return await self._run_tasks_with_progress(tasks)

    # Helper: run a collection of tasks while reporting progress
    async def _run_tasks_with_progress(self, tasks):
        """Run the tasks and report progress as they complete."""
        completed_count = 0
        total = len(tasks)
        results = []

        # Handle results in completion order
        for future in asyncio.as_completed(tasks):
            result = await future
            completed_count += 1

            # Report progress every 10 segments
            if completed_count % 10 == 0 or completed_count == total:
                seg_id = result['segment']['segment_id']
                self.logger.info(f"Similarity progress: {completed_count}/{total} (last finished: {seg_id})")

            results.append(result)

        # as_completed yields results in completion order, so restore the segment order here
        results.sort(key=lambda r: r['segment']['segment_id'])
        return results
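    # Interface assumption: TritonSpeakerClient comes from the local speaker_client
    # module and is not defined in this file. Based on how it is used above, it is
    # assumed to be constructed with the Triton server URL and to expose an async
    # compute_similarity(wav_path_a, wav_path_b) coroutine returning a float score
    # (higher = more similar). If the real client differs, compute_similarity() and
    # calculate_voice_similarity_thread_safe() are the only places that need adapting.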
    async def evaluate_single_input(self, data: Dict[str, Any], input_id: str = None) -> Dict[str, Any]:
        """Evaluate the speaker similarity for a single input."""
        # Generate an input ID
        if input_id is None:
            input_id = f"input_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

        self.logger.info(f"Starting evaluation of input: {input_id}, language: {self.language}")

        # 1. Obtain or split the prompt audio
        prompts1_path, prompts2_path = self.get_or_split_prompt_audio(data, f"{input_id}_prompt")

        # 2. Split the output audio (detailed alignment info is saved here)
        output_segments = self.split_output_audio_by_comma(data['text'], data['output_audio'], f"{input_id}_output")

        # 3. Compute the per-segment similarities concurrently
        similarity_results = await self.calculate_segment_similarities_parallel(
            output_segments, prompts1_path, prompts2_path
        )

        # 4. Process the similarity results
        segment_results = []
        correct_predictions = 0
        total_segments = 0        # only valid segments are counted
        label_similarities = []   # similarity of each segment to its labelled speaker
        skipped_segments = 0      # number of skipped segments

        for sim_result in similarity_results:
            segment = sim_result['segment']
            sim1 = sim_result['sim1']
            sim2 = sim_result['sim2']

            # Skip the segment if either similarity is None (audio too short or computation failed)
            if sim1 is None or sim2 is None:
                skipped_segments += 1
                self.logger.info(f"Skipping segment {segment['segment_id']}: similarity computation failed")
                continue

            # Only valid segments contribute to the metrics
            total_segments += 1

            # Decide which speaker the segment actually sounds like
            predicted_speaker = 'S1' if sim1 > sim2 else 'S2'
            actual_speaker = segment['speaker_label']
            is_correct = predicted_speaker == actual_speaker

            if is_correct:
                correct_predictions += 1

            # Similarity to the labelled speaker
            if actual_speaker == 'S1':
                label_similarity = sim1
            else:
                label_similarity = sim2
            label_similarities.append(label_similarity)

            segment_result = {
                'segment_id': segment['segment_id'],
                'text': segment['text'],
                'speaker_label': actual_speaker,
                'predicted_speaker': predicted_speaker,
                'sim1': sim1,
                'sim2': sim2,
                'label_similarity': label_similarity,
                'is_correct': is_correct,
                'audio_path': segment['audio_path'],
                'start_time': segment.get('start_time', 0.0),
                'end_time': segment.get('end_time', 1.0)
            }
            segment_results.append(segment_result)

        # 5. Overall metrics (valid segments only)
        accuracy = correct_predictions / total_segments if total_segments > 0 else 0.0
        average_similarity = np.mean(label_similarities) if label_similarities else 0.0

        # 6. Save a summary of the alignment files for this evaluation
        evaluation_alignment_summary = {
            'input_id': input_id,
            'language': self.language,
            'prompt_alignment_files': [
                f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            ],
            'output_alignment_file': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
            'total_segments': total_segments,
            'total_alignments_processed': len(output_segments),
            'alignment_success_rate': total_segments / len(output_segments) if output_segments else 0.0
        }
        self.save_alignment_info(evaluation_alignment_summary, input_id, "summary")

        result = {
            'input_id': input_id,
            'language': self.language,
            'input_data': data,  # keep the original input data
            'prompts1_path': prompts1_path,
            'prompts2_path': prompts2_path,
            'segments': segment_results,
            'accuracy': accuracy,
            'average_similarity': average_similarity,
            'total_segments': total_segments,                  # number of valid segments
            'correct_predictions': correct_predictions,
            'skipped_segments': skipped_segments,              # number of skipped segments
            'original_total_segments': len(output_segments),   # original number of segments
            'alignment_files': {
                'summary': f"{self._get_safe_filename(input_id)}_summary_alignment.json",
                'output_detailed': f"{self._get_safe_filename(f'{input_id}_output')}_detailed_alignment.json",
                'prompt': f"{self._get_safe_filename(f'{input_id}_prompt')}_prompt_alignment.json"
            },
            'timestamp': datetime.now().isoformat()
        }

        self.logger.info(
            f"Finished input: {input_id}, language: {self.language}, "
            f"valid segments: {total_segments}/{len(output_segments)}, skipped: {skipped_segments}, "
            f"accuracy: {accuracy:.3f}, average similarity: {average_similarity:.3f}"
        )
        return result

    def save_results_to_jsonl(self, results: List[Dict[str, Any]], filename: str = None):
        """Save the results to a JSONL file."""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        output_path = self.results_dir / filename
        with open(output_path, 'w', encoding='utf-8') as f:
            for result in results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
        return str(output_path)
    def save_summary_report(self, results: List[Dict[str, Any]], filename: str = None):
        """Save the summary report."""
        if filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"evaluation_summary_{self.language.lower()}_{timestamp}.json"

        summary_path = self.results_dir / filename

        # Overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])
        total_segments = sum([r['total_segments'] for r in results])
        total_correct = sum([r['correct_predictions'] for r in results])

        summary = {
            'evaluation_summary': {
                'language': self.language,
                'total_inputs': len(results),
                'total_segments': total_segments,
                'total_correct_predictions': total_correct,
                'overall_accuracy': total_accuracy,
                'overall_average_similarity': total_avg_similarity,
                'evaluation_timestamp': datetime.now().isoformat(),
                'output_directory': str(self.output_dir),
                'alignment_directory': str(self.alignment_dir)
            },
            'per_input_results': [
                {
                    'input_id': r['input_id'],
                    'language': r.get('language', self.language),
                    'accuracy': r['accuracy'],
                    'average_similarity': r['average_similarity'],
                    'total_segments': r['total_segments'],
                    'correct_predictions': r['correct_predictions'],
                    'output_audio_path': r['input_data']['output_audio'],
                    'alignment_files': r.get('alignment_files', {})
                }
                for r in results
            ]
        }

        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        return str(summary_path)

    def process_batch_from_jsonl_parallel(self, jsonl_path: str, processes_per_gpu: int = 16,
                                          results_filename: str = None, shuffle_data: bool = True):
        """Process inputs from a JSONL file in parallel."""
        # Load the data
        input_data = self.load_data_from_jsonl(jsonl_path)
        if not input_data:
            self.logger.error("No valid input data")
            return []

        # Shuffle the data so the workload is spread more evenly
        if shuffle_data:
            random.shuffle(input_data)
            self.logger.info(f"Shuffled {len(input_data)} items")

        return self.process_batch_parallel(input_data, processes_per_gpu, results_filename)

    def process_batch_from_jsonl(self, jsonl_path: str, results_filename: str = None):
        """Process inputs from a JSONL file (single-process version)."""
        # Load the data
        input_data = self.load_data_from_jsonl(jsonl_path)
        if not input_data:
            self.logger.error("No valid input data")
            return []

        return asyncio.run(self.process_batch_from_data(input_data, results_filename))

    async def process_batch_from_data(self, input_data: List[Dict[str, Any]], results_filename: str = None):
        """Process a list of inputs (single-process version, kept for compatibility), with incremental writes."""
        # Prepare the results file
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        results_path = self.results_dir / results_filename

        # If the file already exists, delete it (start over)
        if results_path.exists():
            results_path.unlink()

        results = []
        self.logger.info(f"Starting to process {len(input_data)} inputs, language: {self.language}...")

        for i, data in enumerate(input_data):
            input_id = f"input_{i+1:03d}"
            print(f"Processing input {i+1}/{len(input_data)}: {input_id}, language: {self.language}")

            try:
                result = await self.evaluate_single_input(data, input_id=input_id)
                results.append(result)
                # Write the result incrementally
                self.append_result_to_jsonl(result, str(results_path))
            except Exception as e:
                self.logger.error(f"Error while processing input {input_id}: {e}")
                continue

        if not results:
            self.logger.error("No input was processed successfully")
            return []

        # Save the summary report
        summary_path = self.save_summary_report(results)

        # Clean up temporary files
        self._clean_temp_files()

        # Print overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in results])

        print(f"\n=== Evaluation finished ===")
        print(f"Language: {self.language}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")
        print(f"All intermediate files are under: {self.output_dir}")
        return results

    def _load_wespeaker_model(self, wespeaker_model_url):
        """Load the wespeaker similarity model client."""
        try:
            self.similarity_model = TritonSpeakerClient(wespeaker_model_url)
        except ImportError:
            raise ImportError("Please install wespeaker: pip install git+https://github.com/wenet-e2e/wespeaker.git")
        except Exception as e:
            self.logger.error(f"Failed to load the wespeaker model: {e}")
            raise

    def load_data_from_jsonl(self, jsonl_path: str) -> List[Dict[str, Any]]:
        """Load data from a JSONL file."""
        data = []
        try:
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            item = json.loads(line)

                            # Required fields; skip the line if any is missing
                            required_fields = ['text', 'output_audio']
                            missing_fields = [field for field in required_fields if field not in item]
                            if missing_fields:
                                self.logger.error(f"Line {line_num} is missing required fields: {missing_fields}")
                                continue

                            # Prompt mode: either a combined prompt_audio + prompt_text,
                            # or separate per-speaker audio/text fields
                            has_combined_prompt = 'prompt_audio' in item and 'prompt_text' in item
                            has_separate_prompts = ('prompt_audio_speaker1' in item and 'prompt_text_speaker1' in item and
                                                    'prompt_audio_speaker2' in item and 'prompt_text_speaker2' in item)

                            if not (has_combined_prompt or has_separate_prompts):
                                self.logger.error(f"Line {line_num}: provide either prompt_audio+prompt_text or per-speaker prompt files")
                                continue

                            data.append(item)
                        except json.JSONDecodeError as e:
                            self.logger.error(f"JSON decode error on line {line_num}: {e}")
                            continue

            self.logger.info(f"Loaded {len(data)} items from {jsonl_path}")
            return data

        except FileNotFoundError:
            self.logger.error(f"JSONL file does not exist: {jsonl_path}")
            return []
        except Exception as e:
            self.logger.error(f"Failed to read the JSONL file: {e}")
            return []
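    # Example JSONL records accepted by load_data_from_jsonl() (paths are placeholders).
    # Combined-prompt mode:
    #   {"text": "[S1]...[S2]...", "output_audio": "/path/out.wav",
    #    "prompt_audio": "/path/prompt.wav", "prompt_text": "[S1]...[S2]..."}
    # Separate-prompt mode:
    #   {"text": "[S1]...[S2]...", "output_audio": "/path/out.wav",
    #    "prompt_audio_speaker1": "/path/s1.wav", "prompt_text_speaker1": "...",
    #    "prompt_audio_speaker2": "/path/s2.wav", "prompt_text_speaker2": "..."}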
    @staticmethod
    def get_gpu_count():
        """Return the number of available GPUs."""
        if torch.cuda.is_available():
            return torch.cuda.device_count()
        return 0

    @staticmethod
    def split_data_by_gpu(data: List[Dict[str, Any]], num_gpus: int) -> List[List[Dict[str, Any]]]:
        """Split the data across GPUs."""
        if num_gpus == 0:
            return [data]

        chunk_size = math.ceil(len(data) / num_gpus)
        gpu_chunks = []

        for i in range(num_gpus):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, len(data))
            if start_idx < len(data):
                gpu_chunks.append(data[start_idx:end_idx])

        return gpu_chunks

    @staticmethod
    def split_data_by_processes(data: List[Dict[str, Any]], num_processes: int) -> List[List[Dict[str, Any]]]:
        """Split the data across worker processes."""
        if num_processes <= 1:
            return [data]

        chunk_size = math.ceil(len(data) / num_processes)
        process_chunks = []

        for i in range(num_processes):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, len(data))
            if start_idx < len(data):
                process_chunks.append(data[start_idx:end_idx])

        return process_chunks

    def append_result_to_jsonl(self, result: Dict[str, Any], filepath: str):
        """Append one result to a JSONL file."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with open(filepath, 'a', encoding='utf-8') as f:
            f.write(json.dumps(result, ensure_ascii=False) + '\n')
            f.flush()  # force-flush the buffer

    def merge_temp_results(self, temp_files: List[str], final_path: str):
        """Merge the temporary result files."""
        all_results = []

        for temp_file in temp_files:
            if os.path.exists(temp_file):
                try:
                    with open(temp_file, 'r', encoding='utf-8') as f:
                        for line in f:
                            line = line.strip()
                            if line:
                                result = json.loads(line)
                                all_results.append(result)
                except Exception as e:
                    self.logger.error(f"Failed to read temp file: {temp_file}, error: {e}")

        # Write the final file
        with open(final_path, 'w', encoding='utf-8') as f:
            for result in all_results:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')

        return all_results

    def process_batch_parallel(self, input_data: List[Dict[str, Any]],
                               processes_per_gpu: int = 8,  # conservative process count
                               results_filename: str = None,
                               shuffle_data: bool = True):
        """Process the inputs in parallel across GPUs and processes."""
        # 1. Check the GPU count
        num_gpus = self.get_gpu_count()
        if num_gpus == 0:
            self.logger.warning("No GPU detected, falling back to single-process CPU processing")
            return asyncio.run(self.process_batch_from_data(input_data, results_filename))

        # Cap the processes per GPU to avoid CUDA memory contention
        max_processes_per_gpu = min(processes_per_gpu, 16)
        self.logger.info(f"Detected {num_gpus} GPU(s), using {max_processes_per_gpu} process(es) per GPU")

        # 2. Shuffle the data (if it has not been shuffled already)
        shuffled_data = input_data.copy()
        if shuffle_data:
            random.shuffle(shuffled_data)
            self.logger.info(f"Shuffled {len(shuffled_data)} items to balance the GPU load")

        # 3. Split the data across GPUs
        gpu_chunks = self.split_data_by_gpu(shuffled_data, num_gpus)

        # Log the amount of data assigned to each GPU
        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if gpu_data:
                self.logger.info(f"GPU {gpu_id}: assigned {len(gpu_data)} items")

        # 4. Prepare the results file path
        if results_filename is None:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            results_filename = f"speaker_similarity_results_{self.language.lower()}_{timestamp}.jsonl"

        final_results_path = self.results_dir / results_filename

        # 5. Prepare the worker arguments for every GPU
        all_temp_files = []
        all_gpu_tasks = []

        for gpu_id, gpu_data in enumerate(gpu_chunks):
            if not gpu_data:
                continue

            self.logger.info(f"GPU {gpu_id}: preparing to process {len(gpu_data)} items")

            # Split this GPU's data across its processes
            process_chunks = self.split_data_by_processes(gpu_data, max_processes_per_gpu)

            # Build the argument tuples for this GPU's processes
            gpu_process_args = []
            for proc_id, proc_data in enumerate(process_chunks):
                if proc_data:
                    temp_result_file = str(self.temp_results_dir / f"gpu{gpu_id}_proc{proc_id}_results.jsonl")
                    all_temp_files.append(temp_result_file)

                    # Each subprocess writes inside the main output directory
                    subprocess_output_dir = str(self.output_dir / f"gpu{gpu_id}_proc{proc_id}")

                    gpu_process_args.append((
                        proc_data,
                        gpu_id,
                        proc_id,
                        subprocess_output_dir,
                        temp_result_file,
                        self.alignment_model_dir,
                        self.wespeaker_model_url,
                        self.language,                # language setting
                        self.similarity_max_workers   # similarity worker count
                    ))

            if gpu_process_args:
                all_gpu_tasks.append((gpu_id, gpu_process_args, max_processes_per_gpu))

        # 6. Drive all GPUs concurrently with a ThreadPoolExecutor
        def process_gpu_tasks(gpu_task):
            gpu_id, process_args, actual_processes = gpu_task
            self.logger.info(f"GPU {gpu_id}: starting {len(process_args)} worker process(es)")

            # A dedicated process pool per GPU avoids cross-process interference
            with mp.Pool(processes=actual_processes) as pool:
                # run_async_worker is a sync wrapper that runs the async worker inside each subprocess
                pool.map(run_async_worker, process_args)

            self.logger.info(f"GPU {gpu_id}: all processes finished")
            return gpu_id

        # Process all GPUs at the same time with a thread pool
        with ThreadPoolExecutor(max_workers=num_gpus) as executor:
            # Submit one task per GPU
            future_to_gpu = {executor.submit(process_gpu_tasks, gpu_task): gpu_task[0]
                             for gpu_task in all_gpu_tasks}

            # Wait for every GPU to finish
            completed_gpus = []
            for future in as_completed(future_to_gpu):
                gpu_id = future_to_gpu[future]
                try:
                    result_gpu_id = future.result()
                    completed_gpus.append(result_gpu_id)
                    self.logger.info(f"GPU {result_gpu_id} finished")
                except Exception as exc:
                    self.logger.error(f"GPU {gpu_id} raised an exception: {exc}")

        self.logger.info(f"All GPUs finished: {completed_gpus}")

        # 7. Merge all temporary result files
        self.logger.info("Merging temporary result files...")
        all_results = self.merge_temp_results(all_temp_files, str(final_results_path))

        if not all_results:
            self.logger.error("No data was processed successfully")
            return []

        # 8. Generate the summary report
        summary_path = self.save_summary_report(all_results)

        # 9. Clean up the temporary files
        for temp_file in all_temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)
        # 10. Print overall statistics
        total_accuracy = np.mean([r['accuracy'] for r in all_results])
        total_avg_similarity = np.mean([r['average_similarity'] for r in all_results])

        print(f"\n=== Parallel evaluation finished ===")
        print(f"Language: {self.language}")
        print(f"Used {num_gpus} GPU(s), {max_processes_per_gpu} process(es) per GPU")
        print(f"Total inputs: {len(input_data)}")
        print(f"Successfully processed: {len(all_results)}")
        print(f"Overall accuracy: {total_accuracy:.3f}")
        print(f"Overall average similarity: {total_avg_similarity:.3f}")
        print(f"Detailed results saved to: {final_results_path}")
        print(f"Summary report saved to: {summary_path}")
        print(f"Alignment info saved to: {self.alignment_dir}")

        return all_results

    def get_or_split_prompt_audio(self, data: Dict[str, Any], audio_id: str) -> Tuple[str, str]:
        """
        Get or split the prompt audio.
        Use the per-speaker audio files directly if provided, otherwise split the combined prompt.
        """
        # Check for separate per-speaker audio files
        if ('prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data and
                'prompt_text_speaker1' in data and 'prompt_text_speaker2' in data):
            self.logger.info("Using pre-split per-speaker prompt audio files")

            # Still save alignment information for the pre-split audio
            try:
                # Detect the language, or use the configured one
                alignment_language = self.language
                if alignment_language == "AUTO":
                    alignment_language = self._detect_language_from_text(data['prompt_text_speaker1'])

                # Align the S1 audio
                s1_alignments = self.align_text_with_audio(
                    data['prompt_text_speaker1'],
                    data['prompt_audio_speaker1'],
                    alignment_language
                )
                s1_alignment_data = {
                    'speaker': 'S1',
                    'text': data['prompt_text_speaker1'],
                    'audio_path': data['prompt_audio_speaker1'],
                    'language': alignment_language,
                    'alignments': s1_alignments
                }
                self.save_alignment_info(s1_alignment_data, audio_id, "prompt_s1")

                # Align the S2 audio
                s2_alignments = self.align_text_with_audio(
                    data['prompt_text_speaker2'],
                    data['prompt_audio_speaker2'],
                    alignment_language
                )
                s2_alignment_data = {
                    'speaker': 'S2',
                    'text': data['prompt_text_speaker2'],
                    'audio_path': data['prompt_audio_speaker2'],
                    'language': alignment_language,
                    'alignments': s2_alignments
                }
                self.save_alignment_info(s2_alignment_data, audio_id, "prompt_s2")
            except Exception as e:
                self.logger.warning(f"Failed to save alignment info for pre-split prompts: {e}")

            return data['prompt_audio_speaker1'], data['prompt_audio_speaker2']

        # Otherwise split the combined prompt
        elif 'prompt_audio' in data and 'prompt_text' in data:
            self.logger.info("Splitting speaker clips from the combined prompt audio")
            return self.split_audio_by_speaker(data['prompt_text'], data['prompt_audio'], audio_id)
        else:
            raise ValueError("Either prompt_audio+prompt_text or per-speaker prompt files must be provided")

    def calculate_voice_similarity(self, audio1_path: str, audio2_path: str) -> float:
        """
        Compute the speaker similarity between two audio files (backwards-compatible version).
        Clips that are too short are repeated until they reach the minimum length.
        """
        # In a worker thread, delegate to the thread-safe version
        if threading.current_thread() != threading.main_thread():
            # The thread-safe version is a coroutine, so run it in its own event loop here
            return asyncio.run(self.calculate_voice_similarity_thread_safe(audio1_path, audio2_path))

        # Make sure the models are initialized
        self._init_models_if_needed()

        try:
            if not os.path.exists(audio1_path) or not os.path.exists(audio2_path):
                self.logger.warning(f"Audio file not found: {audio1_path} or {audio2_path}")
                return None

            # Check the audio length and pad by repetition if needed
            def process_audio_for_similarity(audio_path, min_duration=0.1):
                """
                If the clip is shorter than min_duration, repeat it until it is long enough.
                Returns the path to use and whether it is a temporary file.
                """
                try:
                    waveform, sample_rate = torchaudio.load(audio_path)
                    duration = waveform.shape[1] / sample_rate

                    if duration >= min_duration:
                        # Long enough, use the original file
                        return audio_path, False

                    # Too short: repeat the clip
                    repeat_times = math.ceil(min_duration / duration)
                    self.logger.info(f"Clip too short ({duration:.3f}s), repeating {repeat_times}x to reach {min_duration}s: {audio_path}")

                    repeated_waveform = waveform.repeat(1, repeat_times)

                    # Temporary file path
                    temp_filename = f"temp_{os.path.basename(audio_path)}"
                    temp_path = str(self.temp_dir / temp_filename)

                    # Save the repeated clip
                    torchaudio.save(temp_path, repeated_waveform, sample_rate)
                    return temp_path, True
                except Exception as e:
                    self.logger.error(f"Failed to process audio file: {audio_path}, error: {e}")
                    return audio_path, False

            # Process both audio files
            processed_audio1, is_temp1 = process_audio_for_similarity(audio1_path)
            processed_audio2, is_temp2 = process_audio_for_similarity(audio2_path)

            # Compute the similarity; the client's compute_similarity is a coroutine
            # (it is awaited elsewhere in this file), so run it to completion here
            similarity = asyncio.run(self.similarity_model.compute_similarity(processed_audio1, processed_audio2))

            # Clean up the temporary files
            if is_temp1 and os.path.exists(processed_audio1):
                try:
                    os.remove(processed_audio1)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio1}, error: {e}")
            if is_temp2 and os.path.exists(processed_audio2):
                try:
                    os.remove(processed_audio2)
                except Exception as e:
                    self.logger.warning(f"Failed to delete temp file: {processed_audio2}, error: {e}")

            return float(similarity)

        except Exception as e:
            # Distinguish window-size errors (clip too short) from other failures
            if "choose a window size" in str(e) or "window size" in str(e):
                self.logger.warning(f"Audio clip still too short to compute similarity: {audio1_path} vs {audio2_path}")
                return None
            else:
                self.logger.error(f"Failed to compute similarity between {audio1_path} and {audio2_path}: {e}")
                return None


# Module-level worker for multiprocessing (with incremental result writing)
async def process_data_chunk_incremental(args):
    """Worker coroutine that processes one data chunk (incremental-write version)."""
    (data_chunk, gpu_id, proc_id, output_dir, temp_result_file,
     alignment_model_dir, wespeaker_model_url, language, similarity_max_workers) = args

    # Pick the GPU for this process
    device = f"cuda:{gpu_id}" if torch.cuda.is_available() and gpu_id < torch.cuda.device_count() else "cpu"

    try:
        # Reset CUDA state to avoid cross-process interference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            # Bind this process to its GPU
            torch.cuda.set_device(gpu_id)

        # Small stagger so processes do not initialize at exactly the same time
        time.sleep(proc_id * 0.5)

        # Create the evaluator with the model paths, language and similarity worker count
        evaluator = SpeakerSimilarityEvaluator(
            device=device,
            alignment_model_dir=alignment_model_dir,
            wespeaker_model_url=wespeaker_model_url,
            output_dir=output_dir,
            language=language,
            similarity_max_workers=similarity_max_workers
        )

        # Lazily initialize the models
        evaluator._init_models_if_needed()

        # Clear the temporary result file if it already exists
        if os.path.exists(temp_result_file):
            os.remove(temp_result_file)

        # Process the data chunk
        for i, data in enumerate(data_chunk):
            input_id = f"gpu{gpu_id}_proc{proc_id}_input_{i+1:03d}"
            try:
                result = await evaluator.evaluate_single_input(data, input_id=input_id)
                # Write the result to the temporary file immediately
                evaluator.append_result_to_jsonl(result, temp_result_file)
                print(f"GPU{gpu_id}-proc{proc_id}: finished {input_id} (language: {language}, similarity workers: {similarity_max_workers})")

                # Free the CUDA cache after each item
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                print(f"GPU{gpu_id}-proc{proc_id}: failed on {input_id}: {e}")
                # Also free the CUDA cache on failure
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue

        print(f"GPU{gpu_id}-proc{proc_id}: all data processed, results written to {temp_result_file}")

    except Exception as e:
        print(f"GPU{gpu_id}-proc{proc_id}: initialization failed: {e}")
        # Free the CUDA cache on failure
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def run_async_worker(args):
    """
    Synchronous wrapper that sets up an asyncio event loop and runs the async worker.
    Needed because multiprocessing.Pool cannot call coroutine functions directly.
    """
    # asyncio.run() is the simplest, safest way to start a coroutine in each subprocess:
    # it creates a fresh event loop, runs the coroutine to completion and closes the loop.
    return asyncio.run(process_data_chunk_incremental(args))


def main():
    """Example entry point."""
    import argparse

    parser = argparse.ArgumentParser(description='Speaker Similarity Evaluator')
    parser.add_argument('--jsonl_path', type=str, help='Path to the input JSONL file')
    parser.add_argument('--output_dir', type=str,
                        default=f"/inspire/hdd/project/embodied-multimodality/public/yqzhang/auto_evaluation_new/eval_res/results_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                        help='Directory for the results')
    parser.add_argument('--language', type=str, choices=['zh', 'en', 'auto'], default='zh',
                        help='Language: zh=Chinese, en=English, auto=auto-detect (default: zh)')
    parser.add_argument('--no_parallel', action='store_true',
                        help='Disable parallel processing (parallel is enabled by default)')
    parser.add_argument('--processes_per_gpu', type=int, default=4,
                        help='Number of processes per GPU (4 or fewer recommended)')
    parser.add_argument('--similarity_workers', type=int, default=16,
                        help='Number of workers for similarity computation (default: 16)')
    parser.add_argument('--no_shuffle', action='store_true',
                        help='Disable data shuffling (shuffling is enabled by default)')
    parser.add_argument('--random_seed', type=int, default=None,
                        help='Random seed (optional, for reproducibility)')

    args = parser.parse_args()

    # Set the random seed if requested
    if args.random_seed is not None:
        random.seed(args.random_seed)
        np.random.seed(args.random_seed)
        torch.manual_seed(args.random_seed)
        print(f"Random seed set to: {args.random_seed}")

    # Normalize the language argument; anything other than AUTO/EN falls back to Chinese
    language = args.language.upper()
    if language not in ('AUTO', 'EN'):
        language = 'ZH'

    # Create the evaluator with the output directory, language and similarity worker count
    evaluator = SpeakerSimilarityEvaluator(
        output_dir=args.output_dir,
        language=language,
        similarity_max_workers=args.similarity_workers
    )

    # Parallel processing is the default unless explicitly disabled
    use_parallel = not args.no_parallel
    use_shuffle = not args.no_shuffle

    print(f"Language setting: {language}")
    print(f"Similarity workers: {args.similarity_workers}")

    if args.jsonl_path:
        # Process data from the JSONL file
        if use_parallel:
            evaluator.process_batch_from_jsonl_parallel(
                args.jsonl_path,
                processes_per_gpu=args.processes_per_gpu,
                shuffle_data=use_shuffle
            )
        else:
            # process_batch_from_jsonl is synchronous (it runs its own event loop internally)
            evaluator.process_batch_from_jsonl(args.jsonl_path)
    else:
        # Fall back to the built-in example data
        input_data = [
            {
                'prompt_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/zhouxingchi/zxc_enhanced.wav",
                'prompt_text': "[S1]你再往前半步我就把你给杀了。[S2]你应该这么做,我也应该死。",
                'text': "[S1]至尊宝,如果有一天我不再是紫霞仙子,只是一个普通的凡人,你还会像现在这样陪着我吗?[S2]这个嘛,那我得先问问月老,看看他给不给我打折!毕竟追仙子要花好多力气的![S1]哼!油嘴滑舌!我是认真的![S2]紫霞,不管你是仙子还是凡人,哪怕变成一根香蕉,我都认得出你。不过……你最好别真变成香蕉,我怕我会忍不住吃掉……[S1]讨厌!谁要变成香蕉啊!那……如果有一天,我们不得不分开呢?[S2]哇!你这话比牛魔王的斧头还狠!不行不行,你得赔我精神损失费![S1]怎么赔?[S2]很简单,让我亲一下,就当是定金![S1]想得美!那如果有一天,你真的忘了我呢?[S2]那我就算翻遍三界,打烂阎王殿,也要把记忆找回来。紫霞,我至尊宝这辈子,赖定你了![S1]傻瓜。",
                'output_audio': "/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_res/from_newckpt_step145000/test_set/output_7.wav"
            }
        ]

        # Process the example data
        if use_parallel:
            evaluator.process_batch_parallel(input_data, processes_per_gpu=args.processes_per_gpu)
        else:
            asyncio.run(evaluator.process_batch_from_data(input_data))


if __name__ == "__main__":
    main()
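# Example invocation (a sketch; the script name and the --jsonl_path / --output_dir
# values below are placeholders, not paths taken from this repository):
#
#   python speaker_similarity_eval.py \
#       --jsonl_path ./data/eval_set.jsonl \
#       --output_dir ./evaluation_results \
#       --language zh \
#       --processes_per_gpu 4 \
#       --similarity_workers 16
#
# Omitting --jsonl_path runs the built-in example item defined in main().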