from dataclasses import make_dataclass

import torch
import torchaudio
from torch import nn

from .usad_modules import ConformerEncoder

MAX_MEL_LENGTH = 3000  # 30 seconds


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
) -> torch.Tensor:
    """Convert waveform to fbank features.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): mel dimension. Defaults to 128.
        norm_mean (float, optional): mean for normalization. Defaults to -4.268.
        norm_std (float, optional): std for normalization. Defaults to 4.569.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features.
    """
    # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
    dtype = wavs.dtype
    wavs = wavs.to(torch.float32)
    wavs = wavs - wavs.mean(dim=-1, keepdim=True)
    feats = [
        torchaudio.compliance.kaldi.fbank(
            wavs[i : i + 1],
            htk_compat=True,
            sample_frequency=16000,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        ).to(dtype=dtype)
        for i in range(wavs.shape[0])
    ]
    mels = torch.stack(feats, dim=0)
    mels = (mels - norm_mean) / (norm_std * 2)
    return mels


class UsadModel(nn.Module):
    def __init__(self, cfg) -> None:
        """Initialize the UsadModel.

        Args:
            cfg: Configuration object containing model parameters.
        """
        super().__init__()
        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        self.max_mel_length = MAX_MEL_LENGTH
        # NOTE: max_mel_length is set to 3000, which corresponds to
        # 30 seconds of audio at the 100 Hz mel frame rate.

    @property
    def sample_rate(self) -> int:
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        return 50  # Hz

    @property
    def mel_dim(self) -> int:
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        return self.cfg.num_layers

    @property
    def scene_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def timestamp_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Get the device on which the model is located."""
        return next(self.parameters()).device

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1s, got {seconds} seconds."
        self.max_mel_length = int(seconds * 100)  # 100 Hz mel frame rate

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return a waveform tensor.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If stereo, convert to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # Remove the channel dimension
        return waveform.to(self.device)  # Ensure tensor is on the model's device

    def forward(
        self,
        wavs: torch.Tensor,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Forward pass for the model.

        Args:
            wavs (torch.Tensor): Input waveform tensor of shape (batch_size, wav_len).
            norm_mean (float, optional): Mean for normalization. Defaults to -4.268.
            norm_std (float, optional): Standard deviation for normalization.
                Defaults to 4.569.

        Returns:
            dict: A dictionary containing the model's outputs.
        """
        # wavs: (batch_size, wav_len)
        mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
        # Trim to an even number of frames (the encoder halves the time resolution:
        # 100 Hz mel -> 50 Hz encoder output).
        mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]

        if mel.shape[1] <= self.max_mel_length:
            x, x_len, layer_results = self.encoder(mel, return_hidden=True)
            result = {
                "x": x,
                "mel": mel,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result

        # Longer inputs are processed in chunks of at most max_mel_length frames,
        # and the per-chunk outputs are concatenated along the time axis.
        result = {
            "x": [],
            "mel": mel,
            "hidden_states": [[] for _ in range(self.cfg.num_layers)],
            "ffn": [[] for _ in range(self.cfg.num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            if mel.shape[1] - i < 10:
                # Drop a trailing remainder shorter than 10 frames.
                break
            x, x_len, layer_results = self.encoder(
                mel[:, i : i + self.max_mel_length], return_hidden=True
            )
            result["x"].append(x)
            for j in range(self.cfg.num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])
        result["x"] = torch.cat(result["x"], dim=1)
        for j in range(self.cfg.num_layers):
            result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)

        # result["x"]: model final output (batch_size, seq_len)
        # result["mel"]: mel fbank (batch_size, seq_len * 2, mel_dim)
        # result["hidden_states"]: List of (batch_size, seq_len, encoder_dim)
        # result["ffn"]: List of (batch_size, seq_len, encoder_dim)
        return result

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a UsadModel from a fairseq checkpoint, keeping only encoder weights."""
        checkpoint = torch.load(ckpt_path, weights_only=False)
        config = checkpoint["cfg"]["model"]
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
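# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of how this module might be used, assuming a fairseq
# checkpoint is available locally. The file names "usad.pt" and "example.wav"
# are hypothetical placeholders; only methods defined above are used.
if __name__ == "__main__":
    model = UsadModel.load_from_fairseq_ckpt("usad.pt")  # hypothetical checkpoint path
    model.eval()

    # load_audio resamples to 16 kHz mono and moves the tensor to model.device.
    wav = model.load_audio("example.wav")  # hypothetical audio path

    with torch.no_grad():
        outputs = model(wav.unsqueeze(0))  # add batch dimension -> (1, wav_len)

    # Hidden states are returned per layer at the 50 Hz encoder frame rate.
    for layer_idx, hidden in enumerate(outputs["hidden_states"]):
        print(f"layer {layer_idx}: {tuple(hidden.shape)}")  # (1, seq_len, encoder_dim)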