# new_client.py
import argparse
import asyncio
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import tritonclient.grpc.aio as grpcclient
import sys
import time
import math


class TritonSpeakerClient:
    def __init__(self, url, model_name="speaker_model", verbose=False):
        try:
            self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
        except Exception as e:
            print(f"Channel creation failed: {e}", file=sys.stderr)
            sys.exit(1)
        self.model_name = model_name
        # --- Preprocessing parameters migrated from the old similarity_model ---
        self.sample_rate = 16000
        self.feature_dim = 80
        self.min_duration = 0.1
        # -----------------------------------------------------------------------

    def _preprocess_audio(self, audio_path: str):
        """
        Load and preprocess audio from a file path, producing Fbank features.
        This logic is copied verbatim from the preprocess method of the old
        similarity_model.py.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)

            # Resample if the sample rate does not match.
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            # If the audio is too short, tile it to reach the minimum duration.
            duration = waveform.shape[1] / self.sample_rate
            if duration < self.min_duration:
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)

            # Compute 80-dim Fbank features.
            # Downmix to mono first; kaldi.fbank expects a [channel, time] tensor.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            features = kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10
            )
            # kaldi.fbank returns shape [num_frames, num_mel_bins], i.e.
            # [num_frames, 80], which is exactly what we need (no squeeze required).
            return features.numpy().astype(np.float32)
        except Exception as e:
            raise RuntimeError(f"Failed during audio preprocessing for {audio_path}: {e}")

    def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray):
        """Compute cosine similarity on the client side."""
        e1 = torch.from_numpy(emb1).flatten()
        e2 = torch.from_numpy(emb2).flatten()
        similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)
        # Normalize the similarity from [-1, 1] to [0, 1].
        return (similarity.item() + 1) / 2

    async def compute_similarity(self, audio1_path: str, audio2_path: str):
        """
        Compute the similarity between two audio files.
        This function now contains the full pipeline:
        preprocessing -> batching -> inference -> postprocessing.
        """
        # 1. Preprocess both audio files on the client.
        feats1 = self._preprocess_audio(audio1_path)
        feats2 = self._preprocess_audio(audio2_path)

        # 2. Batching: to use Triton's dynamic batching, pack both feature
        #    matrices into a single request. Since their lengths (frame counts)
        #    may differ, pad them to the same length.
        max_len = max(feats1.shape[0], feats2.shape[0])

        # Zero-pad with np.pad.
        padded_feats1 = np.pad(
            feats1, ((0, max_len - feats1.shape[0]), (0, 0)),
            'constant', constant_values=0
        )
        padded_feats2 = np.pad(
            feats2, ((0, max_len - feats2.shape[0]), (0, 0)),
            'constant', constant_values=0
        )

        # Stack the padded features into one batch.
        input_batch = np.stack([padded_feats1, padded_feats2])  # Shape: [2, max_len, 80]

        # 3. Create the Triton input tensor.
        #    The input name "feats" must match the input name in the
        #    speaker_model's config.pbtxt.
        inputs = [grpcclient.InferInput("feats", input_batch.shape, "FP32")]
        inputs[0].set_data_from_numpy(input_batch)

        # 4. Request the output.
        #    The output name "embs" must match the output name in the
        #    speaker_model's config.pbtxt.
        outputs = [grpcclient.InferRequestedOutput("embs")]

        # 5. Send the inference request.
        response = await self.triton_client.infer(
            model_name=self.model_name,
            inputs=inputs,
            outputs=outputs
        )

        # 6. Parse the result.
        embeddings_batch = response.as_numpy("embs")  # Shape: [2, embedding_dim]
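        # Sanity check (added sketch, not part of the original flow): the
        # server is expected to return one embedding per input in the batch;
        # embedding_dim depends on the deployed speaker model.
        assert embeddings_batch.shape[0] == 2, (
            f"Expected a batch of 2 embeddings, got shape {embeddings_batch.shape}"
        )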
        emb1 = embeddings_batch[0]
        emb2 = embeddings_batch[1]

        # 7. Compute similarity on the client side.
        similarity = self._calculate_cosine_similarity(emb1, emb2)
        return similarity


async def main():
    parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
    parser.add_argument('-v', '--verbose', action="store_true", default=False,
                        help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, default='localhost:8001',
                        help='Inference server URL.')
    # Note: model_name here should be speaker_model.
    parser.add_argument('--model_name', default='speaker_model',
                        help='The name of the speaker embedding model on Triton.')
    parser.add_argument('--audio_file1', type=str, required=True,
                        help='Path to first audio file')
    parser.add_argument('--audio_file2', type=str, required=True,
                        help='Path to second audio file')
    FLAGS = parser.parse_args()

    client = TritonSpeakerClient(FLAGS.url, FLAGS.model_name, verbose=FLAGS.verbose)

    start_time = time.time()
    try:
        similarity = await client.compute_similarity(FLAGS.audio_file1, FLAGS.audio_file2)
        elapsed = time.time() - start_time
        print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
    except Exception as e:
        print(f"Error computing similarity: {e}", file=sys.stderr)
        sys.exit(1)

# Usage example:
# python new_client.py --audio_file1=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi1.wav --audio_file2=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi2.wav

if __name__ == '__main__':
    asyncio.run(main())
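
# Optional pre-flight check: a minimal, standalone sketch (not part of the
# original flow) showing how to confirm the server is live, the model is
# loaded, and its I/O tensor names match the "feats"/"embs" names hard-coded
# above. is_server_live / is_model_ready / get_model_metadata / close are
# standard tritonclient.grpc.aio coroutines. Run it separately, e.g.:
#   asyncio.run(check_model_ready('localhost:8001', 'speaker_model'))
async def check_model_ready(url: str, model_name: str) -> None:
    client = grpcclient.InferenceServerClient(url=url)
    try:
        if not await client.is_server_live():
            raise RuntimeError(f"Triton server at {url} is not live")
        if not await client.is_model_ready(model_name):
            raise RuntimeError(f"Model '{model_name}' is not ready on the server")
        # The metadata response lists the model's input/output tensors.
        metadata = await client.get_model_metadata(model_name)
        input_names = [inp.name for inp in metadata.inputs]
        output_names = [out.name for out in metadata.outputs]
        assert "feats" in input_names, f"expected input 'feats', server reports {input_names}"
        assert "embs" in output_names, f"expected output 'embs', server reports {output_names}"
    finally:
        await client.close()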