# USAD-Large: usad_model.py
from dataclasses import make_dataclass
import torch
import torchaudio
from torch import nn
from .usad_modules import ConformerEncoder
MAX_MEL_LENGTH = 3000  # 30 seconds at a 100 Hz mel frame rate (10 ms hop)
@torch.no_grad()
def wav_to_fbank(
wavs: torch.Tensor,
mel_dim: int = 128,
norm_mean: float = -4.268,
norm_std: float = 4.569,
) -> torch.Tensor:
"""Convert waveform to fbank features.
Args:
wavs (torch.Tensor): (B, T_wav) waveform tensor.
mel_dim (int, optional): mel dimension. Defaults to 128.
norm_mean (float, optional):
mean for normalization. Defaults to -4.268.
norm_std (float, optional):
std for normalization. Defaults to 4.569.
Returns:
torch.Tensor: (B, T_mel, mel_dim) fbank features.
"""
# ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
dtype = wavs.dtype
wavs = wavs.to(torch.float32)
wavs = wavs - wavs.mean(dim=-1, keepdim=True)
feats = [
torchaudio.compliance.kaldi.fbank(
wavs[i : i + 1],
htk_compat=True,
sample_frequency=16000,
use_energy=False,
window_type="hanning",
num_mel_bins=mel_dim,
dither=0.0,
frame_shift=10,
).to(dtype=dtype)
for i in range(wavs.shape[0])
]
mels = torch.stack(feats, dim=0)
    mels = (mels - norm_mean) / (norm_std * 2)  # EAT-style normalization (divide by 2 * std)
return mels
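# Usage sketch for wav_to_fbank (shapes assume the defaults above): a batch of
# 1 s, 16 kHz waveforms yields roughly 100 frames per second (10 ms hop,
# 25 ms Kaldi window), e.g.
#   mels = wav_to_fbank(torch.randn(2, 16000))  # -> (2, ~100, 128)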
class UsadModel(nn.Module):
def __init__(self, cfg) -> None:
"""Initialize the UsadModel.
Args:
cfg: Configuration object containing model parameters.
"""
super().__init__()
self.cfg = cfg
self.encoder = ConformerEncoder(cfg)
self.max_mel_length = MAX_MEL_LENGTH
# NOTE: The max_mel_length is set to 3000,
# which corresponds to 30 seconds of audio at 100 Hz frame rate.
@property
def sample_rate(self) -> int:
return 16000 # Hz
@property
def encoder_frame_rate(self) -> int:
return 50 # Hz
@property
def mel_dim(self) -> int:
return self.cfg.input_dim
@property
def encoder_dim(self) -> int:
return self.cfg.encoder_dim
@property
def num_layers(self) -> int:
return self.cfg.num_layers
@property
def scene_embedding_size(self) -> int:
return self.cfg.encoder_dim * self.cfg.num_layers
@property
def timestamp_embedding_size(self) -> int:
return self.cfg.encoder_dim * self.cfg.num_layers
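    # NOTE: scene_embedding_size and timestamp_embedding_size correspond to
    # concatenating the hidden states of all num_layers encoder layers along
    # the feature dimension (as in HEAR-style evaluation interfaces).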
@property
def device(self) -> torch.device:
"""Get the device on which the model is located."""
return next(self.parameters()).device
def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
"""Set the maximum chunk size for feature extraction.
Args:
seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
"""
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1 s, got {seconds} seconds."
self.max_mel_length = int(seconds * 100) # 100 Hz frame rate
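        # e.g. self.set_audio_chunk_size(10.0) -> max_mel_length = 1000 frames,
        # so forward() processes longer inputs in 10-second chunks.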
def load_audio(self, audio_path: str) -> torch.Tensor:
"""Load audio file and return waveform tensor.
Args:
audio_path (str): Path to the audio file.
Returns:
torch.Tensor: Waveform tensor of shape (wav_len,).
"""
waveform, sr = torchaudio.load(audio_path)
if sr != self.sample_rate:
waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If multi-channel, downmix to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # Drop the channel dimension -> (wav_len,)
        return waveform.to(self.device)  # Move to the model's device
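    # Example (hypothetical file path):
    #   wav = model.load_audio("example.wav")  # -> (wav_len,) tensor at 16 kHz on model.device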
def forward(
self,
wavs: torch.Tensor,
norm_mean: float = -4.268,
norm_std: float = 4.569,
) -> dict:
"""Forward pass for the model.
Args:
wavs (torch.Tensor):
Input waveform tensor of shape (batch_size, wav_len).
norm_mean (float, optional):
Mean for normalization. Defaults to -4.268.
norm_std (float, optional):
Standard deviation for normalization. Defaults to 4.569.
Returns:
dict: A dictionary containing the model's outputs.
"""
# wavs: (batch_size, wav_len)
mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
        # Trim to an even number of mel frames to match the encoder's 2x temporal
        # subsampling (100 Hz mel frames -> 50 Hz encoder frames).
        mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]
if mel.shape[1] <= self.max_mel_length:
x, x_len, layer_results = self.encoder(mel, return_hidden=True)
result = {
"x": x,
"mel": mel,
"hidden_states": layer_results["hidden_states"],
"ffn": layer_results["ffn_1"],
}
return result
result = {
"x": [],
"mel": mel,
"hidden_states": [[] for _ in range(self.cfg.num_layers)],
"ffn": [[] for _ in range(self.cfg.num_layers)],
}
        # Input is longer than max_mel_length: run the encoder on fixed-size
        # chunks and concatenate the outputs along the time dimension.
        for i in range(0, mel.shape[1], self.max_mel_length):
            if mel.shape[1] - i < 10:
                # Skip a trailing remainder shorter than 10 frames (0.1 s)
                break
x, x_len, layer_results = self.encoder(
mel[:, i : i + self.max_mel_length], return_hidden=True
)
result["x"].append(x)
for j in range(self.cfg.num_layers):
result["hidden_states"][j].append(layer_results["hidden_states"][j])
result["ffn"][j].append(layer_results["ffn_1"][j])
result["x"] = torch.cat(result["x"], dim=1)
for j in range(self.cfg.num_layers):
result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
# result["x"]: model final output (batch_size, seq_len)
# result["mel"]: mel fbank (batch_size, seq_len * 2, mel_dim)
# result["hidden_states"]: List of (batch_size, seq_len, encoder_dim)
# result["ffn"]: List of (batch_size, seq_len, encoder_dim)
return result
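    # Usage sketch: one way (mean pooling over time, not prescribed by this
    # module) to build a scene embedding of size scene_embedding_size:
    #   out = model(wavs)
    #   scene = torch.cat([h.mean(dim=1) for h in out["hidden_states"]], dim=-1)
    #   # scene: (batch_size, encoder_dim * num_layers)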
    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a UsadModel from a fairseq checkpoint, keeping only encoder weights."""
        checkpoint = torch.load(ckpt_path, weights_only=False)
        # Rebuild a lightweight config object from the stored model-config mapping
        config = checkpoint["cfg"]["model"]
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        # Keep only encoder.* parameters; drop any non-encoder entries
        # (e.g., pre-training components)
        state_dict = checkpoint["model"]
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
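# Minimal end-to-end sketch (run as part of the package, since the relative
# import above requires it); the checkpoint and audio paths are placeholders,
# not files shipped with this code.
if __name__ == "__main__":
    model = UsadModel.load_from_fairseq_ckpt("usad_large.pt")  # hypothetical path
    model.eval()
    wav = model.load_audio("example.wav")  # hypothetical audio file
    with torch.no_grad():
        out = model(wav.unsqueeze(0))  # add batch dim: (1, wav_len)
    print(out["x"].shape)             # (1, seq_len, encoder_dim)
    print(len(out["hidden_states"]))  # num_layers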