from dataclasses import make_dataclass

import torch
import torchaudio
from torch import nn

from .usad_modules import ConformerEncoder

# 3000 mel frames at the 10 ms frame shift used below = 30 s of audio per chunk.
MAX_MEL_LENGTH = 3000


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
) -> torch.Tensor:
    """Convert a batch of waveforms to log-mel filterbank (fbank) features.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor sampled at 16 kHz.
        mel_dim (int, optional): Number of mel bins. Defaults to 128.
        norm_mean (float, optional):
            Mean for normalization. Defaults to -4.268.
        norm_std (float, optional):
            Standard deviation for normalization. Defaults to 4.569.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features.
    """
    dtype = wavs.dtype
    wavs = wavs.to(torch.float32)
    # Remove the DC offset from each waveform.
    wavs = wavs - wavs.mean(dim=-1, keepdim=True)
    # The Kaldi-compatible fbank processes one waveform at a time,
    # so loop over the batch.
    feats = [
        torchaudio.compliance.kaldi.fbank(
            wavs[i : i + 1],
            htk_compat=True,
            sample_frequency=16000,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        ).to(dtype=dtype)
        for i in range(wavs.shape[0])
    ]

    mels = torch.stack(feats, dim=0)
    # Center by the dataset mean and scale by twice the dataset standard
    # deviation (AST-style normalization).
    mels = (mels - norm_mean) / (norm_std * 2)

    return mels
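
# Illustrative usage (shapes assume the defaults above): two seconds of
# 16 kHz audio yields 198 frames at a 10 ms shift with a 25 ms window.
#     wavs = torch.randn(2, 32000)
#     mels = wav_to_fbank(wavs)  # -> (2, 198, 128)

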
class UsadModel(nn.Module):
    def __init__(self, cfg) -> None:
        """Initialize the UsadModel.

        Args:
            cfg: Configuration object containing model parameters.
        """
        super().__init__()

        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        self.max_mel_length = MAX_MEL_LENGTH

    @property
    def sample_rate(self) -> int:
        """Expected input sample rate in Hz."""
        return 16000

    @property
    def encoder_frame_rate(self) -> int:
        """Encoder output frames per second."""
        return 50

    @property
    def mel_dim(self) -> int:
        """Number of mel bins expected by the encoder."""
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        """Dimension of each encoder hidden state."""
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        """Number of Conformer encoder layers."""
        return self.cfg.num_layers

    @property
    def scene_embedding_size(self) -> int:
        """Size of a clip-level embedding concatenating all layers."""
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def timestamp_embedding_size(self) -> int:
        """Size of a frame-level embedding concatenating all layers."""
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Get the device on which the model is located."""
        return next(self.parameters()).device

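    # Illustrative: with encoder_dim=768 and num_layers=12 (hypothetical
    # config values), a scene or timestamp embedding that concatenates all
    # layers is 768 * 12 = 9216-dimensional.
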
    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1 s, got {seconds} seconds."
        # 100 mel frames per second at the 10 ms frame shift.
        self.max_mel_length = int(seconds * 100)

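    # Illustrative: cap chunks at 10 s to bound attention memory on long
    # inputs (the right trade-off depends on your hardware):
    #     model.set_audio_chunk_size(10.0)  # max_mel_length == 1000
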
    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return a waveform tensor.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        # Resample to the model's expected sample rate if necessary.
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        # Downmix multi-channel audio to mono.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        waveform = waveform.squeeze(0)
        return waveform.to(self.device)

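    # Illustrative usage ("speech.wav" is a hypothetical path); any sample
    # rate and channel count are accepted and converted to 16 kHz mono:
    #     wav = model.load_audio("speech.wav")  # -> (wav_len,) on model.device
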
    def forward(
        self,
        wavs: torch.Tensor,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Forward pass for the model.

        Args:
            wavs (torch.Tensor):
                Input waveform tensor of shape (batch_size, wav_len).
            norm_mean (float, optional):
                Mean for normalization. Defaults to -4.268.
            norm_std (float, optional):
                Standard deviation for normalization. Defaults to 4.569.

        Returns:
            dict: Model outputs with keys:
                "x": (B, T, encoder_dim) final encoder output,
                "mel": (B, T_mel, mel_dim) input fbank features,
                "hidden_states": per-layer encoder outputs,
                "ffn": per-layer first feed-forward ("ffn_1") outputs.
        """
        mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
        # Trim to an even number of frames so the encoder's 2x downsampling
        # (100 Hz mel -> 50 Hz output) divides evenly.
        mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]
        # Short inputs are encoded in a single pass.
        if mel.shape[1] <= self.max_mel_length:
            x, x_len, layer_results = self.encoder(mel, return_hidden=True)

            result = {
                "x": x,
                "mel": mel,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result

        # Long inputs are split into chunks of at most max_mel_length frames,
        # encoded independently, and concatenated along the time axis.
        result = {
            "x": [],
            "mel": mel,
            "hidden_states": [[] for _ in range(self.cfg.num_layers)],
            "ffn": [[] for _ in range(self.cfg.num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            # Drop a trailing chunk shorter than 10 frames.
            if mel.shape[1] - i < 10:
                break

            x, x_len, layer_results = self.encoder(
                mel[:, i : i + self.max_mel_length], return_hidden=True
            )
            result["x"].append(x)
            for j in range(self.cfg.num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])

        result["x"] = torch.cat(result["x"], dim=1)
        for j in range(self.cfg.num_layers):
            result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)

        return result

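    # Illustrative shapes for 10 s of 16 kHz audio with the 30 s default
    # chunk size: the mel input is ~1000 frames (100 per second), and the
    # encoder halves that to ~500 frames of encoder_dim features in "x".
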
    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str) -> "UsadModel":
        """Build a UsadModel from a fairseq checkpoint.

        Args:
            ckpt_path (str): Path to the fairseq checkpoint file.

        Returns:
            UsadModel: Model initialized with the checkpoint's encoder weights.
        """
        checkpoint = torch.load(ckpt_path, weights_only=False)
        # Rebuild the model config as a lightweight dataclass.
        config = checkpoint["cfg"]["model"]
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        # Keep only encoder weights; all other entries are discarded.
        state_dict = checkpoint["model"]
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
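
# End-to-end sketch (illustrative; both paths are hypothetical):
#     model = UsadModel.load_from_fairseq_ckpt("usad.pt").eval()
#     wav = model.load_audio("speech.wav")
#     out = model(wav.unsqueeze(0))
#     out["x"].shape  # (1, T, model.encoder_dim)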