# new_client.py
import argparse
import asyncio
import math
import sys
import time

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import tritonclient.grpc.aio as grpcclient

class TritonSpeakerClient:
    def __init__(self, url, model_name="speaker_model", verbose=False):
        try:
            self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
        except Exception as e:
            print(f"Channel creation failed: {e}", file=sys.stderr)
            sys.exit(1)
        self.model_name = model_name
        # --- Preprocessing parameters migrated from the old similarity_model ---
        self.sample_rate = 16000
        self.feature_dim = 80
        self.min_duration = 0.1
        # ------------------------------------------------------------------------
    def _preprocess_audio(self, audio_path: str):
        """
        Load audio from a file path and preprocess it into Fbank features.
        This logic is copied verbatim from the preprocess method of the old
        similarity_model.py.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            # Resample if the sample rate does not match.
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
                waveform = resampler(waveform)
            # If the audio is too short, tile it to meet the minimum length.
            duration = waveform.shape[1] / self.sample_rate
            if duration < self.min_duration:
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)
            # Compute 80-dim Fbank features.
            # kaldi.fbank expects [channel, time] and reads only the first
            # channel, so downmix multi-channel audio to mono first.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # convert to mono
            features = kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10
            )
            # kaldi.fbank returns shape [num_frames, num_mel_bins], which is
            # already the [num_frames, 80] layout the model expects.
            return features.numpy().astype(np.float32)
        except Exception as e:
            raise RuntimeError(f"Failed during audio preprocessing for {audio_path}: {e}")
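
    # Shape sanity check (for reference, not executed): with a 25 ms window
    # (400 samples) and a 10 ms shift (160 samples) at 16 kHz, a 1-second mono
    # clip yields 1 + (16000 - 400) // 160 = 98 frames, so the returned array
    # is [98, 80].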

    def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray):
        """Compute cosine similarity on the client side."""
        e1 = torch.from_numpy(emb1).flatten()
        e2 = torch.from_numpy(emb2).flatten()
        similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)
        # Map the similarity from [-1, 1] to [0, 1].
        return (similarity.item() + 1) / 2
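
    # Worked example of the mapping (hypothetical embeddings):
    #   identical vectors  -> cosine  1.0 -> ( 1.0 + 1) / 2 = 1.0
    #   orthogonal vectors -> cosine  0.0 -> ( 0.0 + 1) / 2 = 0.5
    #   opposite vectors   -> cosine -1.0 -> (-1.0 + 1) / 2 = 0.0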

    async def compute_similarity(self, audio1_path: str, audio2_path: str):
        """
        Compute the similarity between two audio files.
        This method now covers the full pipeline:
        preprocessing -> batching -> inference -> postprocessing.
        """
        # 1. Preprocess both audio files on the client side.
        feats1 = self._preprocess_audio(audio1_path)
        feats2 = self._preprocess_audio(audio2_path)
        # 2. Batching: to exploit Triton's dynamic batching, pack both feature
        #    matrices into one request. Their lengths (frame counts) may
        #    differ, so pad them to the same length first.
        max_len = max(feats1.shape[0], feats2.shape[0])
        # Zero-pad along the time axis with np.pad.
        padded_feats1 = np.pad(feats1, ((0, max_len - feats1.shape[0]), (0, 0)), 'constant', constant_values=0)
        padded_feats2 = np.pad(feats2, ((0, max_len - feats2.shape[0]), (0, 0)), 'constant', constant_values=0)
        # Stack the padded features into one batch.
        input_batch = np.stack([padded_feats1, padded_feats2])  # shape: [2, max_len, 80]
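        # Padding example (hypothetical frame counts): if feats1 is [120, 80]
        # and feats2 is [95, 80], feats2 gets 25 zero frames appended and the
        # stacked batch is [2, 120, 80].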
        # 3. Build the Triton input tensor.
        #    The input name "feats" must match the input name in the
        #    speaker_model config.pbtxt.
        inputs = [
            grpcclient.InferInput("feats", list(input_batch.shape), "FP32")
        ]
        inputs[0].set_data_from_numpy(input_batch)
        # 4. Request the output.
        #    The output name "embs" must match the output name in the
        #    speaker_model config.pbtxt.
        outputs = [grpcclient.InferRequestedOutput("embs")]
        # 5. Send the inference request.
        response = await self.triton_client.infer(
            model_name=self.model_name,
            inputs=inputs,
            outputs=outputs
        )
        # 6. Parse the result.
        embeddings_batch = response.as_numpy("embs")  # shape: [2, embedding_dim]
        emb1 = embeddings_batch[0]
        emb2 = embeddings_batch[1]
        # 7. Compute the similarity on the client side.
        similarity = self._calculate_cosine_similarity(emb1, emb2)
        return similarity
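
# For reference only: this client assumes a speaker_model config.pbtxt along
# these lines (a hypothetical sketch; the real config may differ in
# max_batch_size, dims, and backend):
#
#   name: "speaker_model"
#   max_batch_size: 16
#   dynamic_batching { }
#   input [
#     { name: "feats", data_type: TYPE_FP32, dims: [ -1, 80 ] }
#   ]
#   output [
#     { name: "embs", data_type: TYPE_FP32, dims: [ -1 ] }
#   ]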

async def main():
    parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
    parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, default='localhost:8001', help='Inference server URL.')
    # Note: model_name here should be speaker_model.
    parser.add_argument('--model_name', default='speaker_model', help='The name of the speaker embedding model on Triton.')
    parser.add_argument('--audio_file1', type=str, required=True, help='Path to first audio file')
    parser.add_argument('--audio_file2', type=str, required=True, help='Path to second audio file')
    FLAGS = parser.parse_args()
    client = TritonSpeakerClient(FLAGS.url, FLAGS.model_name, verbose=FLAGS.verbose)
    start_time = time.time()
    try:
        similarity = await client.compute_similarity(FLAGS.audio_file1, FLAGS.audio_file2)
        elapsed = time.time() - start_time
        print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
    except Exception as e:
        print(f"Error computing similarity: {e}", file=sys.stderr)
        sys.exit(1)

# Usage example:
# python new_client.py --audio_file1=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi1.wav --audio_file2=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi2.wav
if __name__ == '__main__':
    asyncio.run(main())
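
# Because the client uses tritonclient's asyncio gRPC API, multiple pairs can
# be scored concurrently over one client (hypothetical sketch; file names are
# assumptions, and the calls must run inside an async function):
#
#   pairs = [("a1.wav", "a2.wav"), ("b1.wav", "b2.wav")]
#   scores = await asyncio.gather(
#       *(client.compute_similarity(p1, p2) for p1, p2 in pairs)
#   )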