# speaker_client.py
import argparse
import asyncio
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import tritonclient.grpc.aio as grpcclient
import sys
import time
import math
class TritonSpeakerClient:
def __init__(self, url, model_name="speaker_model", verbose=False):
try:
self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
except Exception as e:
print(f"Channel creation failed: {e}", file=sys.stderr)
sys.exit(1)
self.model_name = model_name
        # --- Preprocessing parameters migrated from the old similarity_model ---
self.sample_rate = 16000
self.feature_dim = 80
self.min_duration = 0.1
# ----------------------------------------------------
def _preprocess_audio(self, audio_path: str):
"""
从音频文件路径加载并预处理音频,生成Fbank特征。
这段逻辑完全复制自旧的 similarity_model.py 中的 preprocess 方法。
"""
try:
waveform, sample_rate = torchaudio.load(audio_path)
            # Resample if the sample rate does not match
if sample_rate != self.sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
waveform = resampler(waveform)
            # If the audio is too short, tile it to reach the minimum duration,
            # e.g. a 0.04 s clip is repeated ceil(0.1 / 0.04) = 3 times
duration = waveform.shape[1] / self.sample_rate
if duration < self.min_duration:
repeat_times = math.ceil(self.min_duration / duration)
waveform = waveform.repeat(1, repeat_times)
            # Compute the 80-dim Fbank features.
            # kaldi.fbank expects a [channel, time] tensor, so mix
            # multi-channel audio down to mono first.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # convert to mono
features = kaldi.fbank(
waveform,
num_mel_bins=self.feature_dim,
sample_frequency=self.sample_rate,
frame_length=25,
frame_shift=10
)
            # kaldi.fbank returns [num_frames, num_mel_bins], which is already
            # the [num_frames, 80] layout we need, so no squeeze is required
            return features.numpy().astype(np.float32)
except Exception as e:
raise RuntimeError(f"Failed during audio preprocessing for {audio_path}: {e}")
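    # Quick sanity check for the Fbank settings above (25 ms window, 10 ms
    # shift at 16 kHz, i.e. 400-sample frames with a 160-sample hop, using
    # kaldi's default snip_edges=True):
    #     num_frames = 1 + floor((num_samples - 400) / 160)
    # so 1 s of audio (16000 samples) yields 98 frames, a [98, 80] matrix.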
def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray):
"""在客户端计算余弦相似度。"""
e1 = torch.from_numpy(emb1).flatten()
e2 = torch.from_numpy(emb2).flatten()
similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)
        # Map the similarity from [-1, 1] to [0, 1]
return (similarity.item() + 1) / 2
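    # Worked examples of the [0, 1] mapping above: identical embeddings give
    # cosine 1.0 -> score 1.0; orthogonal embeddings give cosine 0.0 ->
    # score 0.5; opposite embeddings give cosine -1.0 -> score 0.0.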
async def compute_similarity(self, audio1_path: str, audio2_path: str):
"""
计算两个音频文件的相似度。
此函数现在包含完整的处理流程:预处理 -> 批处理 -> 推理 -> 后处理。
"""
# 1. 在客户端对两个音频文件进行预处理
feats1 = self._preprocess_audio(audio1_path)
feats2 = self._preprocess_audio(audio2_path)
        # 2. Batching: to use Triton's dynamic batching, pack both feature
        #    matrices into a single request. Their lengths (frame counts) may
        #    differ, so pad them to a common length first.
max_len = max(feats1.shape[0], feats2.shape[0])
        # Zero-pad the shorter feature matrix with np.pad
padded_feats1 = np.pad(feats1, ((0, max_len - feats1.shape[0]), (0, 0)), 'constant', constant_values=0)
padded_feats2 = np.pad(feats2, ((0, max_len - feats2.shape[0]), (0, 0)), 'constant', constant_values=0)
        # Stack the padded features into one batch
input_batch = np.stack([padded_feats1, padded_feats2]) # Shape: [2, max_len, 80]
        # 3. Build the Triton input tensor.
        #    The input name "feats" must match the input name in
        #    speaker_model's config.pbtxt.
inputs = [
grpcclient.InferInput("feats", input_batch.shape, "FP32")
]
inputs[0].set_data_from_numpy(input_batch)
        # 4. Declare the requested output.
        #    The output name "embs" must match the output name in
        #    speaker_model's config.pbtxt.
outputs = [grpcclient.InferRequestedOutput("embs")]
        # 5. Send the inference request
response = await self.triton_client.infer(
model_name=self.model_name,
inputs=inputs,
outputs=outputs
)
        # 6. Parse the result
embeddings_batch = response.as_numpy("embs") # Shape: [2, embedding_dim]
emb1 = embeddings_batch[0]
emb2 = embeddings_batch[1]
        # 7. Compute the similarity on the client
similarity = self._calculate_cosine_similarity(emb1, emb2)
return similarity
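    # The sketch below is not part of the original client: a hypothetical
    # single-file variant of the same request flow, assuming the identical
    # "feats"/"embs" tensor names, for callers that need one embedding at a time.
    async def compute_embedding(self, audio_path: str) -> np.ndarray:
        """Sketch: fetch the embedding of a single audio file as a batch of one."""
        feats = self._preprocess_audio(audio_path)
        input_batch = feats[np.newaxis, ...]  # Shape: [1, num_frames, 80]
        inputs = [grpcclient.InferInput("feats", input_batch.shape, "FP32")]
        inputs[0].set_data_from_numpy(input_batch)
        outputs = [grpcclient.InferRequestedOutput("embs")]
        response = await self.triton_client.infer(
            model_name=self.model_name,
            inputs=inputs,
            outputs=outputs
        )
        return response.as_numpy("embs")[0]  # Shape: [embedding_dim]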
async def main():
parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Enable verbose output')
parser.add_argument('-u', '--url', type=str, default='localhost:8001', help='Inference server URL.')
    # Note: model_name here should be speaker_model
parser.add_argument('--model_name', default='speaker_model', help='The name of the speaker embedding model on Triton.')
parser.add_argument('--audio_file1', type=str, required=True, help='Path to first audio file')
parser.add_argument('--audio_file2', type=str, required=True, help='Path to second audio file')
FLAGS = parser.parse_args()
client = TritonSpeakerClient(FLAGS.url, FLAGS.model_name, verbose=FLAGS.verbose)
start_time = time.time()
try:
similarity = await client.compute_similarity(FLAGS.audio_file1, FLAGS.audio_file2)
elapsed = time.time() - start_time
print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
except Exception as e:
print(f"Error computing similarity: {e}", file=sys.stderr)
sys.exit(1)
# Example usage:
# python speaker_client.py --audio_file1=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi1.wav --audio_file2=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi2.wav
if __name__ == '__main__':
asyncio.run(main())