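"""Building blocks for a Conformer-style speech encoder: feed-forward,
convolution, and relative-position multi-head attention modules, frame-wise
and patch-wise Conv2d subsampling front-ends, and the full ConformerEncoder
stack (Conformer: Gulati et al., 2020)."""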

import contextlib
import math
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


class SamePad(nn.Module):
    """Trims trailing frames after a padded Conv1d so the output length matches
    the input: removes kernel_size - 1 frames for causal convolutions, one
    frame for even kernel sizes, nothing otherwise."""

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class TransposeLast(nn.Module):
    """Swaps the last dimension with `transpose_dim`, optionally indexing into
    the input with `deconstruct_idx` first."""

    def __init__(self, deconstruct_idx=None, transpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.transpose_dim = transpose_dim

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(self.transpose_dim, -1)


class Swish(nn.Module):
    """Swish (SiLU) activation: x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return inputs * inputs.sigmoid()


class GLU(nn.Module):
    """Gated linear unit: splits the input in half along `dim` and gates one
    half with the sigmoid of the other."""

    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()


class ResidualConnectionModule(nn.Module):
    """Residual connection with scaling:
    output = module(x) * module_factor + x * input_factor."""

    def __init__(
        self,
        module: nn.Module,
        module_factor: float = 1.0,
        input_factor: float = 1.0,
    ):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)


class Linear(nn.Module):
    """nn.Linear with Xavier-uniform weight initialization and zero bias."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class View(nn.Module):
    """Reshapes the input to `shape`, optionally making it contiguous first."""

    def __init__(self, shape: tuple, contiguous: bool = False):
        super(View, self).__init__()
        self.shape = shape
        self.contiguous = contiguous

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.contiguous:
            x = x.contiguous()
        return x.view(*self.shape)


class Transpose(nn.Module):
    """Transposes the two dimensions given in `shape`."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.transpose(*self.shape)


class FeedForwardModule(nn.Module):
    """Conformer feed-forward module: pre-LayerNorm, expansion by
    `expansion_factor`, Swish activation, dropout, and projection back to
    `encoder_dim`."""

    def __init__(
        self,
        encoder_dim: int = 512,
        expansion_factor: int = 4,
        dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        self.sequential = nn.Sequential(
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.sequential(inputs)


class DepthwiseConv1d(nn.Module):
    """Depthwise 1D convolution (groups == in_channels)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)


class PointwiseConv1d(nn.Module):
    """Pointwise (kernel size 1) 1D convolution."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    """Conformer convolution module: LayerNorm -> pointwise conv + GLU ->
    depthwise conv -> BatchNorm -> Swish -> pointwise conv -> dropout.
    Operates on (batch, time, channels) inputs."""

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"

        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels,
                in_channels * expansion_factor,
                stride=1,
                padding=0,
                bias=True,
            ),
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels,
                in_channels,
                kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.sequential(inputs).transpose(1, 2)


class FramewiseConv2dSubsampling(nn.Module):
    """Conv2d front-end that subsamples the time axis by a factor of 2 or 4
    using two stacked 3x3 convolutions over (time, freq)."""

    def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
        super(FramewiseConv2dSubsampling, self).__init__()
        assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
        self.subsample_rate = subsample_rate
        self.cnn = nn.Sequential(
            nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=(2 if subsample_rate == 4 else 1, 2),
                padding=(0 if subsample_rate == 4 else 1, 0),
            ),
            nn.ReLU(),
        )

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # Pad the time axis to an odd length so the stride-2 convolution yields
        # exactly input_length >> 1 frames.
        if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
            inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
        outputs = self.cnn(inputs.unsqueeze(1))
        batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()

        # (B, C, T', F') -> (B, T', C * F')
        outputs = outputs.permute(0, 2, 1, 3)
        outputs = outputs.contiguous().view(
            batch_size, subsampled_lengths, channels * subsampled_dim
        )

        if self.subsample_rate == 4:
            output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
        else:
            output_lengths = input_lengths >> 1

        return outputs, output_lengths


class PatchwiseConv2dSubsampling(nn.Module):
    """Patch-based front-end: projects non-overlapping
    (patch_size_time x patch_size_freq) spectrogram patches with a strided
    Conv2d, refines them with two 3x3 convolutions, then flattens the
    (time, freq) patch grid into a single sequence axis."""

    def __init__(
        self,
        mel_dim: int,
        out_channels: int,
        patch_size_time: int = 16,
        patch_size_freq: int = 16,
    ) -> None:
        super(PatchwiseConv2dSubsampling, self).__init__()

        self.mel_dim = mel_dim
        self.patch_size_time = patch_size_time
        self.patch_size_freq = patch_size_freq

        self.proj = nn.Conv2d(
            1,
            out_channels,
            kernel_size=(patch_size_time, patch_size_freq),
            stride=(patch_size_time, patch_size_freq),
            padding=0,
        )
        self.cnn = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    @property
    def subsample_rate(self) -> int:
        # Effective time subsampling after flattening the patch grid: every
        # patch_size_time frames produce mel_dim / patch_size_freq tokens.
        return self.patch_size_time * self.patch_size_freq // self.mel_dim

    def forward(
        self, inputs: torch.Tensor, input_lengths: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        assert (
            inputs.shape[2] == self.mel_dim
        ), "inputs.shape[2] should be equal to mel_dim"

        outputs = self.proj(inputs.unsqueeze(1))
        outputs = self.cnn(outputs)
        # (B, C, T', F') -> (B, T' * F', C)
        outputs = outputs.flatten(2, 3).transpose(1, 2)

        output_lengths = (
            input_lengths
            // self.patch_size_time
            * (self.mel_dim // self.patch_size_freq)
        )

        return outputs, output_lengths


class RelPositionalEncoding(nn.Module):
    """Relative positional encoding (Transformer-XL style): caches sinusoidal
    embeddings for positions T-1 .. -(T-1) and returns the slice of length
    2 * T - 1 matching the input sequence length T."""

    def __init__(self, d_model: int, max_len: int = 10000) -> None:
        super(RelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor) -> None:
        # Reuse the cached table if it is already long enough.
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return

        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Order the table as [positions T-1 .. 0, then -1 .. -(T-1)].
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
        ]
        return pos_emb


class RelativeMultiHeadAttention(nn.Module):
    """Multi-head attention with Transformer-XL style relative positional
    encoding. `u_bias` and `v_bias` are the learned content and position
    biases; inputs are (batch, time, d_model) and `pos_embedding` is the
    (batch, 2 * time - 1, d_model) output of RelPositionalEncoding."""

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 16,
        dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(self.d_head)

        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(d_model, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_embedding: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )

        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2),
            pos_embedding.permute(0, 2, 3, 1),
        )
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        attn = F.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
        # Transformer-XL shift trick: pad with a zero column and reshape so the
        # scores line up with relative positions, then keep the first
        # seq_length2 // 2 + 1 columns.
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
            :, :, :, : seq_length2 // 2 + 1
        ]

        return pos_score


class MultiHeadedSelfAttentionModule(nn.Module):
    """Conformer self-attention module: relative positional encoding,
    pre-LayerNorm, relative multi-head attention, and dropout."""

    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = RelPositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = inputs.size(0)
        pos_embedding = self.positional_encoding(inputs)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)

        inputs = self.layer_norm(inputs)
        outputs, attn = self.attention(
            inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
        )

        return self.dropout(outputs), attn


class ConformerBlock(nn.Module):
    """Single Conformer block: FFN -> attention -> conv -> FFN with (half-step)
    residuals and a final LayerNorm. With `transformer_style=True` the conv
    module and second FFN are skipped. The mamba_* arguments are accepted for
    config compatibility, but this module only ships the "mhsa" attention type."""

    def __init__(
        self,
        encoder_dim: int = 512,
        attention_type: str = "mhsa",
        num_attention_heads: int = 8,
        mamba_d_state: int = 16,
        mamba_d_conv: int = 4,
        mamba_expand: int = 2,
        mamba_bidirectional: bool = True,
        feed_forward_expansion_factor: int = 4,
        conv_expansion_factor: int = 2,
        feed_forward_dropout_p: float = 0.1,
        attention_dropout_p: float = 0.1,
        conv_dropout_p: float = 0.1,
        conv_kernel_size: int = 31,
        half_step_residual: bool = True,
        transformer_style: bool = False,
    ):
        super(ConformerBlock, self).__init__()

        self.transformer_style = transformer_style
        self.attention_type = attention_type

        if half_step_residual and not transformer_style:
            self.feed_forward_residual_factor = 0.5
        else:
            self.feed_forward_residual_factor = 1

        assert attention_type in ["mhsa", "mamba"]
        if attention_type == "mhsa":
            attention = MultiHeadedSelfAttentionModule(
                d_model=encoder_dim,
                num_heads=num_attention_heads,
                dropout_p=attention_dropout_p,
            )
        else:
            # The Mamba mixer is not defined in this file; fail loudly instead
            # of hitting an unbound local below.
            raise NotImplementedError(
                "attention_type 'mamba' is not implemented in this module"
            )

        self.ffn_1 = FeedForwardModule(
            encoder_dim=encoder_dim,
            expansion_factor=feed_forward_expansion_factor,
            dropout_p=feed_forward_dropout_p,
        )
        self.attention = attention
        if not transformer_style:
            self.conv = ConformerConvModule(
                in_channels=encoder_dim,
                kernel_size=conv_kernel_size,
                expansion_factor=conv_expansion_factor,
                dropout_p=conv_dropout_p,
            )
            self.ffn_2 = FeedForwardModule(
                encoder_dim=encoder_dim,
                expansion_factor=feed_forward_expansion_factor,
                dropout_p=feed_forward_dropout_p,
            )
        self.layernorm = nn.LayerNorm(encoder_dim)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
        # First (half-step) feed-forward module.
        ffn_1_out = self.ffn_1(x)
        x = ffn_1_out * self.feed_forward_residual_factor + x

        # Attention module. Non-MHSA modules are assumed to return a single
        # tensor without attention weights.
        if not isinstance(self.attention, MultiHeadedSelfAttentionModule):
            attn_out = self.attention(x)
            attn = None
        else:
            attn_out, attn = self.attention(x)
        x = attn_out + x

        if self.transformer_style:
            x = self.layernorm(x)
            return x, {
                "ffn_1": ffn_1_out,
                "attn": attn,
                "conv": None,
                "ffn_2": None,
            }

        # Convolution module.
        conv_out = self.conv(x)
        x = conv_out + x

        # Second (half-step) feed-forward module and final LayerNorm.
        ffn_2_out = self.ffn_2(x)
        x = ffn_2_out * self.feed_forward_residual_factor + x
        x = self.layernorm(x)

        other = {
            "ffn_1": ffn_1_out,
            "attn": attn,
            "conv": conv_out,
            "ffn_2": ffn_2_out,
        }

        return x, other


class ConformerEncoder(nn.Module):
    """Conformer encoder: one or both Conv2d subsampling front-ends (summed
    when both are enabled), optional convolutional positional encoding, and a
    stack of ConformerBlocks."""

    def __init__(self, cfg):
        super(ConformerEncoder, self).__init__()

        self.cfg = cfg
        self.framewise_subsample = None
        self.patchwise_subsample = None
        self.framewise_in_proj = None
        self.patchwise_in_proj = None
        assert (
            cfg.use_framewise_subsample or cfg.use_patchwise_subsample
        ), "At least one subsampling method should be used"
        if cfg.use_framewise_subsample:
            self.framewise_subsample = FramewiseConv2dSubsampling(
                out_channels=cfg.conv_subsample_channels,
                subsample_rate=cfg.conv_subsample_rate,
            )
            # The frequency axis is reduced by both stride-2 convolutions:
            # ((input_dim - 1) // 2 - 1) // 2 bins remain per channel.
            self.framewise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )
        if cfg.use_patchwise_subsample:
            self.patchwise_subsample = PatchwiseConv2dSubsampling(
                mel_dim=cfg.input_dim,
                out_channels=cfg.conv_subsample_channels,
                patch_size_time=cfg.patch_size_time,
                patch_size_freq=cfg.patch_size_freq,
            )
            self.patchwise_in_proj = nn.Sequential(
                Linear(
                    cfg.conv_subsample_channels,
                    cfg.encoder_dim,
                ),
                nn.Dropout(p=cfg.input_dropout_p),
            )
            # When both front-ends are used their outputs are summed, so their
            # time subsampling rates must match.
            assert not cfg.use_framewise_subsample or (
                cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
            ), (
                f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate "
                f"({self.patchwise_subsample.subsample_rate})"
            )

        self.framewise_norm, self.patchwise_norm = None, None
        if getattr(cfg, "subsample_normalization", False):
            if cfg.use_framewise_subsample:
                self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
            if cfg.use_patchwise_subsample:
                self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)

        self.conv_pos = None
        if getattr(cfg, "conv_pos", False):
            num_pos_layers = cfg.conv_pos_depth
            k = max(3, cfg.conv_pos_width // num_pos_layers)
            self.conv_pos = nn.Sequential(
                TransposeLast(),
                *[
                    nn.Sequential(
                        nn.Conv1d(
                            cfg.encoder_dim,
                            cfg.encoder_dim,
                            kernel_size=k,
                            padding=k // 2,
                            groups=cfg.conv_pos_groups,
                        ),
                        SamePad(k),
                        TransposeLast(),
                        nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
                        TransposeLast(),
                        nn.GELU(),
                    )
                    for _ in range(num_pos_layers)
                ],
                TransposeLast(),
            )
            self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)

        self.layers = nn.ModuleList(
            [
                ConformerBlock(
                    encoder_dim=cfg.encoder_dim,
                    attention_type=cfg.attention_type,
                    num_attention_heads=cfg.num_attention_heads,
                    mamba_d_state=cfg.mamba_d_state,
                    mamba_d_conv=cfg.mamba_d_conv,
                    mamba_expand=cfg.mamba_expand,
                    mamba_bidirectional=cfg.mamba_bidirectional,
                    feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
                    conv_expansion_factor=cfg.conv_expansion_factor,
                    feed_forward_dropout_p=cfg.feed_forward_dropout_p,
                    attention_dropout_p=cfg.attention_dropout_p,
                    conv_dropout_p=cfg.conv_dropout_p,
                    conv_kernel_size=cfg.conv_kernel_size,
                    half_step_residual=cfg.half_step_residual,
                    transformer_style=getattr(cfg, "transformer_style", False),
                )
                for _ in range(cfg.num_layers)
            ]
        )

    def count_parameters(self) -> int:
        """Count trainable parameters of the encoder."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def update_dropout(self, dropout_p: float) -> None:
        """Update the dropout probability of all (nested) dropout modules."""
        for module in self.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_p

    def forward(
        self,
        inputs: torch.Tensor,
        input_lengths: Optional[torch.Tensor] = None,
        return_hidden: bool = False,
        freeze_input_layers: bool = False,
        target_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
        if input_lengths is None:
            input_lengths = torch.full(
                (inputs.size(0),),
                inputs.size(1),
                dtype=torch.long,
                device=inputs.device,
            )

        with torch.no_grad() if freeze_input_layers else contextlib.ExitStack():
            frame_feat, patch_feat = None, None
            if self.framewise_subsample is not None:
                frame_feat, frame_lengths = self.framewise_subsample(
                    inputs, input_lengths
                )
                frame_feat = self.framewise_in_proj(frame_feat)
                if self.framewise_norm is not None:
                    frame_feat = self.framewise_norm(frame_feat)

            if self.patchwise_subsample is not None:
                patch_feat, patch_lengths = self.patchwise_subsample(
                    inputs, input_lengths
                )
                patch_feat = self.patchwise_in_proj(patch_feat)
                if self.patchwise_norm is not None:
                    patch_feat = self.patchwise_norm(patch_feat)

            if frame_feat is not None and patch_feat is not None:
                # The two front-ends can disagree by a few frames due to
                # rounding; truncate both to the shorter one before summing.
                min_len = min(frame_feat.size(1), patch_feat.size(1))
                frame_feat = frame_feat[:, :min_len]
                patch_feat = patch_feat[:, :min_len]

                features = frame_feat + patch_feat
                # Take the per-sample minimum so no length can exceed the
                # truncated feature length.
                output_lengths = torch.minimum(frame_lengths, patch_lengths)
            elif frame_feat is not None:
                features = frame_feat
                output_lengths = frame_lengths
            else:
                features = patch_feat
                output_lengths = patch_lengths

            if self.conv_pos is not None:
                features = features + self.conv_pos(features)
                features = self.conv_pos_post_ln(features)

        layer_results = defaultdict(list)

        outputs = features
        for i, layer in enumerate(self.layers):
            outputs, other = layer(outputs)
            if return_hidden:
                layer_results["hidden_states"].append(outputs)
                for k, v in other.items():
                    layer_results[k].append(v)

            if target_layer is not None and i == target_layer:
                break

        return outputs, output_lengths, layer_results
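

if __name__ == "__main__":
    # Minimal smoke-test sketch. The hyperparameters below are illustrative
    # assumptions, not values shipped with this module; any config object
    # exposing these attributes works.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        input_dim=80,
        encoder_dim=144,
        num_layers=2,
        use_framewise_subsample=True,
        use_patchwise_subsample=False,
        conv_subsample_channels=64,
        conv_subsample_rate=4,
        input_dropout_p=0.1,
        attention_type="mhsa",
        num_attention_heads=4,
        mamba_d_state=16,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_bidirectional=True,
        feed_forward_expansion_factor=4,
        conv_expansion_factor=2,
        feed_forward_dropout_p=0.1,
        attention_dropout_p=0.1,
        conv_dropout_p=0.1,
        conv_kernel_size=31,
        half_step_residual=True,
    )

    encoder = ConformerEncoder(cfg)
    mels = torch.randn(2, 160, cfg.input_dim)  # (batch, time, mel)
    lengths = torch.tensor([160, 120])
    out, out_lengths, _ = encoder(mels, lengths)
    # With subsample_rate 4, expect out.shape == (2, 39, 144) and
    # out_lengths == tensor([39, 29]).
    print(out.shape, out_lengths)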