|
|
|
|
|
import argparse |
|
|
import asyncio |
|
|
import numpy as np |
|
|
import torch |
|
|
import torchaudio |
|
|
import torchaudio.compliance.kaldi as kaldi |
|
|
import tritonclient.grpc.aio as grpcclient |
|
|
import sys |
|
|
import time |
|
|
import math |
|
|
|
|
|
|
|
|
class TritonSpeakerClient:
    """Async gRPC client that scores speaker similarity with a Triton model.

    Both audio files are converted to Fbank features locally, sent to the
    server as a single two-item batch, and the returned embeddings are
    compared with cosine similarity on the client side.
    """

    def __init__(self, url, model_name="speaker_model", verbose=False):
        """Create the Triton gRPC client and set the preprocessing parameters.

        Args:
            url: Address of the Triton inference server.
            model_name: Name of the model to query on the server.
            verbose: Forwarded to the Triton client constructor.

        Note:
            If the client cannot be created, the error is printed to stderr
            and the process exits with status 1.
        """
        try:
            self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
        except Exception as e:
            print(f"Channel creation failed: {e}", file=sys.stderr)
            sys.exit(1)

        self.model_name = model_name

        # Preprocessing parameters used by _preprocess_audio.
        self.sample_rate = 16000  # target sample rate (Hz) audio is resampled to
        self.feature_dim = 80     # number of mel bins requested from kaldi.fbank
        self.min_duration = 0.1   # seconds; shorter clips are tiled up to this length

    async def close(self):
        """Close the underlying Triton gRPC client."""
        await self.triton_client.close()

    def _preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Load an audio file and turn it into Fbank features.

        The steps are: load, resample to ``self.sample_rate`` if needed, tile
        clips shorter than ``self.min_duration``, down-mix to mono, then
        compute Fbank features. The logic is derived from the ``preprocess``
        method of the older ``similarity_model.py``.

        Args:
            audio_path: Path of the audio file to load.

        Returns:
            The ``kaldi.fbank`` output as a ``float32`` NumPy array
            (expected layout: frames x ``self.feature_dim``).

        Raises:
            RuntimeError: If any step fails, including audio with no samples.
                The original exception is attached as ``__cause__``.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)

            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.sample_rate
                )
                waveform = resampler(waveform)

            # Guard explicitly: an empty clip would otherwise fail below with
            # an opaque "float division by zero".
            num_samples = waveform.shape[1]
            if num_samples == 0:
                raise ValueError("audio contains no samples")

            duration = num_samples / self.sample_rate
            if duration < self.min_duration:
                # Tile the clip along the time axis until it is long enough.
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)

            # Down-mix multi-channel audio to a single channel.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            features = kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10,
            )

            # No squeeze here: the fbank output is already 2-D, and squeezing
            # axis 0 would drop the frame axis when there is exactly one
            # frame, which breaks the 2-D padding done when batching.
            return features.numpy().astype(np.float32)

        except Exception as e:
            raise RuntimeError(
                f"Failed during audio preprocessing for {audio_path}: {e}"
            ) from e

    def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
        """Compute the cosine similarity of two embeddings on the client side.

        Args:
            emb1: First embedding; flattened before comparison.
            emb2: Second embedding; flattened before comparison.

        Returns:
            The cosine similarity rescaled from [-1, 1] to [0, 1].
        """
        e1 = torch.from_numpy(emb1).flatten()
        e2 = torch.from_numpy(emb2).flatten()

        similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)

        # Map [-1, 1] onto [0, 1].
        return (similarity.item() + 1) / 2

    @staticmethod
    def _batch_features(feats1: np.ndarray, feats2: np.ndarray) -> np.ndarray:
        """Zero-pad two 2-D feature matrices to equal length and stack them.

        Padding is appended along axis 0 only; both inputs must already share
        the same size on axis 1.

        Returns:
            An array of shape ``(2, max_len, feature_dim)``.
        """
        max_len = max(feats1.shape[0], feats2.shape[0])
        padded = [
            np.pad(
                feats,
                ((0, max_len - feats.shape[0]), (0, 0)),
                'constant',
                constant_values=0,
            )
            for feats in (feats1, feats2)
        ]
        return np.stack(padded)

    async def compute_similarity(self, audio1_path: str, audio2_path: str):
        """Compute the similarity of two audio files.

        Runs the full pipeline: preprocessing -> batching -> inference ->
        post-processing.

        Args:
            audio1_path: Path of the first audio file.
            audio2_path: Path of the second audio file.

        Returns:
            A similarity score in [0, 1] (see _calculate_cosine_similarity).

        Raises:
            RuntimeError: If preprocessing fails or the server response does
                not contain two ``embs`` rows.
        """
        # 1. Preprocess both files into feature matrices.
        feats1 = self._preprocess_audio(audio1_path)
        feats2 = self._preprocess_audio(audio2_path)

        # 2. Pad to a common length so both fit in one batch.
        input_batch = self._batch_features(feats1, feats2)

        # 3. Build the request: one FP32 input "feats", one output "embs".
        inputs = [
            grpcclient.InferInput("feats", input_batch.shape, "FP32")
        ]
        inputs[0].set_data_from_numpy(input_batch)

        outputs = [grpcclient.InferRequestedOutput("embs")]

        # 4. Run inference on the server.
        response = await self.triton_client.infer(
            model_name=self.model_name,
            inputs=inputs,
            outputs=outputs,
        )

        # 5. Extract one embedding per batch item. as_numpy returns None when
        # the named output is absent, so validate before indexing.
        embeddings_batch = response.as_numpy("embs")
        if (
            embeddings_batch is None
            or embeddings_batch.ndim < 1
            or embeddings_batch.shape[0] < 2
        ):
            raise RuntimeError(
                f"Model '{self.model_name}' did not return two 'embs' embeddings"
            )
        emb1 = embeddings_batch[0]
        emb2 = embeddings_batch[1]

        # 6. Score the pair on the client.
        similarity = self._calculate_cosine_similarity(emb1, emb2)
        return similarity
|
|
|
|
|
async def main():
    """Parse CLI arguments, score two audio files, and print the result.

    Prints the similarity and elapsed time on success. On failure the error
    is printed to stderr and the process exits with status 1. The gRPC client
    is closed in both cases.
    """
    parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
    parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, default='localhost:8001', help='Inference server URL.')

    parser.add_argument('--model_name', default='speaker_model', help='The name of the speaker embedding model on Triton.')
    parser.add_argument('--audio_file1', type=str, required=True, help='Path to first audio file')
    parser.add_argument('--audio_file2', type=str, required=True, help='Path to second audio file')

    args = parser.parse_args()

    client = TritonSpeakerClient(args.url, args.model_name, verbose=args.verbose)

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    try:
        similarity = await client.compute_similarity(args.audio_file1, args.audio_file2)
        elapsed = time.perf_counter() - start_time
        print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
    except Exception as e:
        print(f"Error computing similarity: {e}", file=sys.stderr)
        sys.exit(1)
    finally:
        # Release the gRPC channel on success and failure alike; this also
        # runs while the SystemExit raised above is propagating.
        await client.triton_client.close()
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async CLI to completion on a new event loop.
    asyncio.run(main())
|
|
|