from dataclasses import make_dataclass

import torch
import torchaudio
from torch import nn

from .usad_modules import ConformerEncoder

# 3000 mel frames = 30 s of audio at the 10 ms frame shift used below.
MAX_MEL_LENGTH = 3000


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
) -> torch.Tensor:
    """Convert a batch of waveforms to fbank features.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor at 16 kHz.
        mel_dim (int, optional): Number of mel bins. Defaults to 128.
        norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
        norm_std (float, optional): Std for normalization. Defaults to 4.569.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features.
    """
    dtype = wavs.dtype
    wavs = wavs.to(torch.float32)
    # Remove the per-utterance DC offset before computing fbank features.
    wavs = wavs - wavs.mean(dim=-1, keepdim=True)
    feats = [
        torchaudio.compliance.kaldi.fbank(
            wavs[i : i + 1],
            htk_compat=True,
            sample_frequency=16000,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        ).to(dtype=dtype)
        for i in range(wavs.shape[0])
    ]

    mels = torch.stack(feats, dim=0)
    # Normalize with the provided dataset statistics (scaled by 2 * std).
    mels = (mels - norm_mean) / (norm_std * 2)

    return mels
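

# A minimal usage sketch (illustration only, not part of the module): with the
# defaults above (16 kHz input, 10 ms frame shift, 128 mel bins), one second of
# audio yields roughly 98 frames.
#
#   wavs = torch.randn(2, 16000)   # (B, T_wav): two one-second waveforms
#   mels = wav_to_fbank(wavs)      # -> (B, ~98, 128)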


class UsadModel(nn.Module):
    def __init__(self, cfg) -> None:
        """Initialize the UsadModel.

        Args:
            cfg: Configuration object containing model parameters.
        """
        super().__init__()

        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        self.max_mel_length = MAX_MEL_LENGTH

    @property
    def sample_rate(self) -> int:
        """Audio sample rate expected by the model, in Hz."""
        return 16000

    @property
    def encoder_frame_rate(self) -> int:
        """Number of encoder output frames per second of audio."""
        return 50

    @property
    def mel_dim(self) -> int:
        """Number of mel bins in the input features."""
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        """Dimension of each encoder layer's output."""
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        """Number of encoder layers."""
        return self.cfg.num_layers

    @property
    def scene_embedding_size(self) -> int:
        """Size of a clip-level embedding built from all encoder layers."""
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def timestamp_embedding_size(self) -> int:
        """Size of a per-frame embedding built from all encoder layers."""
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Get the device on which the model is located."""
        return next(self.parameters()).device

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1 s, got {seconds} seconds."
        # 10 ms frame shift -> 100 mel frames per second of audio.
        self.max_mel_length = int(seconds * 100)
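        # Example (sketch): set_audio_chunk_size(10.0) caps each encoder pass
        # at 1000 mel frames; forward() splits longer inputs into chunks of at
        # most this many frames.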

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return a waveform tensor.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            torch.Tensor: Mono waveform tensor of shape (wav_len,) at 16 kHz,
                on the model's device.
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # Downmix multi-channel audio to mono by averaging the channels.
            waveform = waveform.mean(dim=0, keepdim=True)

        waveform = waveform.squeeze(0)
        return waveform.to(self.device)
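
    # Usage sketch ("speech.wav" is a hypothetical audio file):
    #
    #   wav = model.load_audio("speech.wav")   # (wav_len,) on model.device
    #   out = model(wav.unsqueeze(0))          # add the batch dimension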

    def forward(
        self,
        wavs: torch.Tensor,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Forward pass for the model.

        Args:
            wavs (torch.Tensor):
                Input waveform tensor of shape (batch_size, wav_len).
            norm_mean (float, optional):
                Mean for normalization. Defaults to -4.268.
            norm_std (float, optional):
                Standard deviation for normalization. Defaults to 4.569.

        Returns:
            dict: Model outputs with keys "x" (final encoder output), "mel"
                (input fbank features), "hidden_states" (per-layer encoder
                outputs), and "ffn" (per-layer feed-forward outputs).
        """
        mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
        # Truncate to an even number of frames (the 100 Hz mel sequence is
        # downsampled to the 50 Hz encoder frame rate).
        mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]
        if mel.shape[1] <= self.max_mel_length:
            # Short enough to encode in a single pass.
            x, x_len, layer_results = self.encoder(mel, return_hidden=True)

            result = {
                "x": x,
                "mel": mel,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result

        # Longer inputs: encode in chunks of at most max_mel_length frames and
        # concatenate the outputs along the time axis.
        result = {
            "x": [],
            "mel": mel,
            "hidden_states": [[] for _ in range(self.cfg.num_layers)],
            "ffn": [[] for _ in range(self.cfg.num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            # Skip a trailing fragment that is too short to encode.
            if mel.shape[1] - i < 10:
                break

            x, x_len, layer_results = self.encoder(
                mel[:, i : i + self.max_mel_length], return_hidden=True
            )
            result["x"].append(x)
            for j in range(self.cfg.num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])

        # Stitch the per-chunk outputs back together along the time dimension.
        result["x"] = torch.cat(result["x"], dim=1)
        for j in range(self.cfg.num_layers):
            result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)

        return result
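
    # Output-shape sketch (assuming the encoder halves the mel frame rate):
    #
    #   out = model(wavs)                       # wavs: (B, T_wav)
    #   out["mel"]                              # (B, T_mel, mel_dim)
    #   out["x"]                                # (B, ~T_mel / 2, encoder_dim)
    #   out["hidden_states"][l], out["ffn"][l]  # one tensor per encoder layer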

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a UsadModel from a fairseq checkpoint.

        Args:
            ckpt_path (str): Path to the fairseq checkpoint file.
        """
        checkpoint = torch.load(ckpt_path, weights_only=False)
        # Wrap the flat config dict in a dataclass so fields can be accessed
        # as attributes (e.g. cfg.encoder_dim).
        config = checkpoint["cfg"]["model"]
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        # Keep only the encoder weights; drop everything outside "encoder.".
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
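

# Loading sketch ("usad.ckpt" and "speech.wav" are placeholder paths):
#
#   model = UsadModel.load_from_fairseq_ckpt("usad.ckpt").eval()
#   wav = model.load_audio("speech.wav")
#   out = model(wav.unsqueeze(0))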