# new_client.py
import argparse
import asyncio
import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import tritonclient.grpc.aio as grpcclient
import sys
import time
import math


class TritonSpeakerClient:
    def __init__(self, url, model_name="speaker_model", verbose=False):
        try:
            self.triton_client = grpcclient.InferenceServerClient(url=url, verbose=verbose)
        except Exception as e:
            print(f"Channel creation failed: {e}", file=sys.stderr)
            sys.exit(1)
            
        self.model_name = model_name
        
        # --- Preprocessing parameters migrated from the old similarity_model ---
        self.sample_rate = 16000
        self.feature_dim = 80
        self.min_duration = 0.1
        # ----------------------------------------------------

    def _preprocess_audio(self, audio_path: str):
        """
        Load and preprocess audio from a file path, producing Fbank features.
        This logic is copied verbatim from the preprocess method in the old
        similarity_model.py.
        """
        try:
            waveform, sample_rate = torchaudio.load(audio_path)
            
            # Resample if the sample rate does not match
            if sample_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
                waveform = resampler(waveform)

            # If the audio is too short, tile it until it meets the minimum duration
            duration = waveform.shape[1] / self.sample_rate
            if duration < self.min_duration:
                repeat_times = math.ceil(self.min_duration / duration)
                waveform = waveform.repeat(1, repeat_times)
            
            # Compute 80-dim Fbank features.
            # kaldi.fbank expects a mono waveform of shape [1, time], so
            # downmix multi-channel audio by averaging across channels.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # downmix to mono

            features = kaldi.fbank(
                waveform,
                num_mel_bins=self.feature_dim,
                sample_frequency=self.sample_rate,
                frame_length=25,
                frame_shift=10
            )
            # kaldi.fbank returns shape [num_frames, num_mel_bins], so this is
            # already the [num_frames, 80] matrix we need.
            return features.numpy().astype(np.float32)
            
        except Exception as e:
            raise RuntimeError(f"Failed during audio preprocessing for {audio_path}: {e}")

    def _calculate_cosine_similarity(self, emb1: np.ndarray, emb2: np.ndarray):
        """在客户端计算余弦相似度。"""
        e1 = torch.from_numpy(emb1).flatten()
        e2 = torch.from_numpy(emb2).flatten()
        
        similarity = torch.nn.functional.cosine_similarity(e1, e2, dim=0)
        
        # Map the similarity from [-1, 1] to [0, 1]
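        # (e.g. identical embeddings map to 1.0, orthogonal ones to 0.5,
        # opposite ones to 0.0)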
        return (similarity.item() + 1) / 2

    async def compute_similarity(self, audio1_path: str, audio2_path: str):
        """
        Compute the similarity between two audio files.
        This function now covers the full pipeline:
        preprocess -> batch -> infer -> postprocess.
        """
        # 1. Preprocess both audio files on the client
        feats1 = self._preprocess_audio(audio1_path)
        feats2 = self._preprocess_audio(audio2_path)
        
        # 2. Batching: to use Triton's dynamic batching, pack both feature
        #    matrices into a single request. Their lengths (frame counts) may
        #    differ, so pad them to the same length.
        max_len = max(feats1.shape[0], feats2.shape[0])
        
        # Zero-pad along the frame axis with np.pad
        padded_feats1 = np.pad(feats1, ((0, max_len - feats1.shape[0]), (0, 0)), 'constant', constant_values=0)
        padded_feats2 = np.pad(feats2, ((0, max_len - feats2.shape[0]), (0, 0)), 'constant', constant_values=0)
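        # Note: zero-padding appends silent frames to the shorter utterance,
        # which assumes the speaker model is robust to trailing padding;
        # otherwise the two files could be sent as separate requests instead.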
        
        # Stack the padded features into one batch
        input_batch = np.stack([padded_feats1, padded_feats2]) # Shape: [2, max_len, 80]

        # 3. Create the Triton input tensor.
        #    The input name "feats" must match the input name in the
        #    speaker_model's config.pbtxt
        inputs = [
            grpcclient.InferInput("feats", input_batch.shape, "FP32")
        ]
        inputs[0].set_data_from_numpy(input_batch)
        
        # 4. Request the outputs.
        #    The output name "embs" must match the output name in the
        #    speaker_model's config.pbtxt
        outputs = [grpcclient.InferRequestedOutput("embs")]
        
        # 5. Send the inference request
        response = await self.triton_client.infer(
            model_name=self.model_name,
            inputs=inputs,
            outputs=outputs
        )
        
        # 6. Parse the result
        embeddings_batch = response.as_numpy("embs") # Shape: [2, embedding_dim]
        emb1 = embeddings_batch[0]
        emb2 = embeddings_batch[1]
        
        # 7. Compute the similarity on the client
        similarity = self._calculate_cosine_similarity(emb1, emb2)
        return similarity

async def main():
    parser = argparse.ArgumentParser(description="Triton client for speaker model (direct call).")
    parser.add_argument('-v', '--verbose', action="store_true", default=False, help='Enable verbose output')
    parser.add_argument('-u', '--url', type=str, default='localhost:8001', help='Inference server URL.')
    # Note: model_name here should be speaker_model
    parser.add_argument('--model_name', default='speaker_model', help='The name of the speaker embedding model on Triton.')
    parser.add_argument('--audio_file1', type=str, required=True, help='Path to first audio file')
    parser.add_argument('--audio_file2', type=str, required=True, help='Path to second audio file')
    
    FLAGS = parser.parse_args()

    client = TritonSpeakerClient(FLAGS.url, FLAGS.model_name, verbose=FLAGS.verbose)
    
    start_time = time.time()
    try:
        similarity = await client.compute_similarity(FLAGS.audio_file1, FLAGS.audio_file2)
        elapsed = time.time() - start_time
        print(f"Similarity: {similarity:.4f}, Time: {elapsed:.3f}s")
    except Exception as e:
        print(f"Error computing similarity: {e}", file=sys.stderr)
        sys.exit(1)

# Usage example:
# python new_client.py --audio_file1=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi1.wav --audio_file2=/inspire/hdd/project/embodied-multimodality/public/yqzhang/infer_prompt/testset/audio/yanzi/yanzi2.wav
if __name__ == '__main__':
    asyncio.run(main())
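
# For reference, a minimal sketch of the speaker_model config.pbtxt this
# client implies, inferred only from the tensor names, dtypes, and shapes
# used above ("feats": FP32 [num_frames, 80] in; "embs": FP32 out). The
# backend, max_batch_size, and output dims below are illustrative
# assumptions, not taken from the actual deployment:
#
#   name: "speaker_model"
#   backend: "onnxruntime"    # assumption
#   max_batch_size: 16        # assumption
#   input [
#     { name: "feats", data_type: TYPE_FP32, dims: [ -1, 80 ] }
#   ]
#   output [
#     { name: "embs", data_type: TYPE_FP32, dims: [ -1 ] }  # embedding dim
#   ]
#   dynamic_batching { }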