| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PretrainedConfig, PreTrainedModel |
| |
|
| |
|
class GeLU(nn.Module):
    """
    Exact, erf-based GELU activation.

    This is the gelu implementation from the original ESM repo; according to
    the original author's note, using F.gelu yields subtly wrong results.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x * 0.5 * (1 + erf(x / sqrt(2)))`` element-wise."""
        erf_term = torch.erf(x / math.sqrt(2.0))
        return x * 0.5 * (1.0 + erf_term)
| |
|
| |
|
@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
    adapting the rotary embeddings to lengths larger than those used for training.
    One such strategy is presented in the Yarn paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa

    Args:
        rescaling_factor: Optional factor used by RotaryEmbedding to rescale its
            frequency base. When None, the base is used unchanged.
    """

    rescaling_factor: Optional[float]
| |
|
| |
|
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
    Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    Args:
        dim: Size of the last axis of the query/key tensors being rotated.
        rotary_embedding_config: Object exposing ``rescaling_factor``. When that
            value is not None the frequency base is rescaled before use.
    """

    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig):
        super().__init__()

        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        self.upper_freq = 10000
        self.dim = dim

        # Most recently computed tables; they are refreshed on every forward call.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rotate ``heads`` using the given cos/sin tables.

        The last axis is split into two halves that are treated as the two
        components of the rotated pairs, then concatenated back together.
        """
        half = heads.shape[-1] // 2
        x_first, x_second = heads[..., :half], heads[..., half:]

        first_part = x_first * cos - x_second * sin
        second_part = x_second * cos + x_first * sin

        return torch.cat((first_part, second_part), dim=-1)

    def _inverse_frequencies(self, device: torch.device) -> torch.Tensor:
        """Return the inverse rotation frequencies, created directly on ``device``."""
        base = self.upper_freq
        if self.rescaling_factor is not None:
            base = base * (self.rescaling_factor ** (self.dim / (self.dim - 2)))
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        return 1.0 / (base**exponents)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build the cos/sin tables for every position along ``seq_dimension`` of ``x``.

        Returns:
            Two tensors of shape (1, seq_len, 1, len(inv_freq)).
        """
        seq_len = x.shape[seq_dimension]
        self._seq_len_cached = seq_len

        # Keep everything on the device of the tensor being rotated; otherwise
        # the tables could not be multiplied with non-CPU queries and keys.
        inv_freq = inv_freq.to(x.device)
        t = torch.arange(seq_len, device=x.device).to(inv_freq.dtype)
        freqs = torch.einsum("i, j -> ij", t, inv_freq)

        self._cos_cached = torch.cos(freqs)[None, :, None, :]
        self._sin_cached = torch.sin(freqs)[None, :, None, :]

        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the rotary embedding to queries and keys.

        Positions are read from axis -3 of ``q`` and the same tables are applied
        to ``k``. The tables are 4-dimensional, (1, seq, 1, dim / 2), so both
        inputs are expected to be laid out as (batch, seq, heads, dim).

        Returns:
            The rotated ``(q, k)`` pair, with the same shapes as the inputs.
        """
        inv_freq = self._inverse_frequencies(q.device)

        self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
            q,
            inv_freq,
            seq_dimension=-3,
        )

        return (
            self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
            self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
        )
| |
|
| |
|
class ResidualConvBlock(nn.Module):
    """
    Wraps a ConvBlock and adds the block's input back onto its output.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        block_kwargs = {
            "dim_in": dim_in,
            "dim_out": dim_out,
            "seq_len": seq_len,
            "kernel_size": kernel_size,
        }
        self.conv_block = ConvBlock(**block_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``conv_block(x)`` plus ``x`` reshaped to the block's output shape."""
        conv_out = self.conv_block(x)
        # The wrapped block may hand back a different layout than it received,
        # so the skip input is reshaped to the output shape before the addition.
        skip = x.reshape(conv_out.shape)
        return skip + conv_out
| |
|
| |
|
class ConvBlock(nn.Module):
    """
    Layer normalisation over the last axis, followed by a "same"-padded 1D
    convolution and a tanh-approximated GELU.

    Args:
        dim_in: Number of input channels of the convolution.
        dim_out: Number of output channels of the convolution.
        seq_len: Size of the last axis, which is the axis being normalised.
        kernel_size: Width of the convolution kernel.
    """

    def __init__(self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int = 1):
        super().__init__()
        self.conv = nn.Conv1d(dim_in, dim_out, kernel_size, padding="same")
        self.layer_norm = nn.LayerNorm(seq_len, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, flatten everything after axis 1, convolve and activate."""
        normed = self.layer_norm(x)
        batch, channels = normed.shape[0], normed.shape[1]
        # Collapse all trailing axes into a single one for Conv1d.
        flattened = normed.reshape(batch, channels, -1)
        return F.gelu(self.conv(flattened), approximate="tanh")
| |
|
| |
|
class ResidualDeConvBlock(nn.Module):
    """
    Wraps a DeConvBlock and adds the block's input back onto its output.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        block_kwargs = {
            "dim_in": dim_in,
            "dim_out": dim_out,
            "seq_len": seq_len,
            "kernel_size": kernel_size,
            "stride": stride,
        }
        self.deconv_block = DeConvBlock(**block_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``deconv_block(x)`` plus ``x`` reshaped to the block's output shape."""
        deconv_out = self.deconv_block(x)
        # Align the skip input with the layout produced by the wrapped block.
        skip = x.reshape(deconv_out.shape)
        return skip + deconv_out
| |
|
| |
|
class DeConvBlock(nn.Module):
    """
    Layer normalisation over the last axis, followed by an unpadded 1D
    transposed convolution and a tanh-approximated GELU.

    Args:
        dim_in: Number of input channels of the transposed convolution.
        dim_out: Number of output channels of the transposed convolution.
        seq_len: Size of the last axis, which is the axis being normalised.
        kernel_size: Width of the transposed-convolution kernel.
        stride: Stride of the transposed convolution.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        seq_len: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            dim_in, dim_out, kernel_size, stride=stride, padding=0
        )
        self.layer_norm = nn.LayerNorm(seq_len)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, flatten everything after axis 1, upsample and activate."""
        normed = self.layer_norm(x)
        flattened = normed.reshape(normed.shape[0], normed.shape[1], -1)
        upsampled = self.deconv(flattened)
        if self.kernel_size == 5:
            # Trim one leading and two trailing positions. Presumably this makes
            # a stride-2 upsampling exactly double the length -- TODO confirm.
            upsampled = upsampled[:, :, 1:-2]
        return F.gelu(upsampled, approximate="tanh")
| |
|
| |
|
class SpatialEncoding(nn.Module):
    """
    Spatial coordinates encoding module.

    Each (x, y) coordinate pair is encoded at ``num_scales`` different scales as
    ``(cos(x / c), sin(y / c))``; the per-scale features are concatenated and
    projected to ``embed_dim`` by a linear layer.

    Args:
        embed_dim: Size of the produced embedding.
        num_scales: Number of scales at which the coordinates are encoded.
        sigma_min: Smallest scale value.
        sigma_max: Largest scale value.
    """

    def __init__(
        self,
        embed_dim: int,
        num_scales: int = 10,
        sigma_min: float = 1.0,
        sigma_max: float = 10.0,
    ):
        super().__init__()
        self.num_scales = num_scales
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.g = sigma_max / sigma_min
        self.scales = torch.linspace(sigma_min, sigma_max, num_scales)
        # forward() concatenates two features (cos, sin) per scale, so the
        # projection must accept 2 * num_scales inputs rather than embed_dim.
        self.fc_layer = nn.Linear(2 * num_scales, embed_dim)

    def scale_specific_encoder(
        self, coordinates: torch.Tensor, scale: float
    ) -> torch.Tensor:
        """
        Encode coordinates at a single scale.

        Args:
            coordinates: Tensor whose last axis holds x at index 0 and y at index 1.
            scale: Scale value used to derive the divisor applied to x and y.

        Returns:
            Tensor with the leading shape of ``coordinates`` and a last axis of
            size 2 holding ``(cos(x / c), sin(y / c))``.
        """
        x, y = coordinates[..., 0], coordinates[..., 1]
        constant = self.sigma_min * (self.g ** (scale / (self.num_scales - 1)))
        x_transform = torch.cos(x / constant)
        y_transform = torch.sin(y / constant)
        return torch.stack([x_transform, y_transform], dim=-1)

    def forward(self, coordinates: torch.Tensor) -> torch.Tensor:
        """Encode ``coordinates`` at every scale and project to ``embed_dim``."""
        per_scale = [
            self.scale_specific_encoder(coordinates, scale) for scale in self.scales
        ]
        features = torch.cat(per_scale, dim=-1)
        return self.fc_layer(features)
| |
|
| |
|
class ConvTowerBlock(nn.Module):
    """
    Downsampling stage: a ConvBlock, a residual 1x1 ConvBlock and an average
    pooling with kernel size 2 and stride 2.

    ``forward`` also returns its untouched input so that callers can reuse it
    as a skip connection.
    """

    def __init__(
        self, dim_in: int, dim_out: int, seq_len: int, kernel_size: int, num_cells: int
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=1,
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
        self.num_cells = num_cells

    def _split_cells(self, tensor: torch.Tensor) -> torch.Tensor:
        """View ``tensor`` as (batch, channels, num_cells, -1)."""
        return tensor.reshape(tensor.shape[0], tensor.shape[1], self.num_cells, -1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pooled_features, original_input)``."""
        hidden = self.conv_layer(self._split_cells(x))
        hidden = self.res_conv(self._split_cells(hidden))
        pooled = self.avg_pool(hidden)
        return pooled, x
| |
|
| |
|
class DeConvTowerBlock(nn.Module):
    """
    Upsampling stage: a DeConvBlock followed by a residual 1x1 DeConvBlock, after
    which a skip tensor supplied by the caller is added.

    The residual block is built for a sequence length of ``seq_len * 2``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        seq_len: int,
        stride: int = 2,
        num_cells: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            seq_len=seq_len,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            seq_len=seq_len * 2,
            kernel_size=1,
        )
        self.num_cells = num_cells

    def _split_cells(self, tensor: torch.Tensor) -> torch.Tensor:
        """View ``tensor`` as (batch, channels, num_cells, -1)."""
        return tensor.reshape((tensor.shape[0], tensor.shape[1], self.num_cells, -1))

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and add the skip tensor ``res`` to the result."""
        hidden = self.deconv_block(self._split_cells(x))
        hidden = self.res_deconv_block(self._split_cells(hidden))
        return hidden + res
| |
|
| |
|
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with optional rotary position embeddings.

    Args:
        num_heads: Number of attention heads.
        key_size: Per-head size of the query and key projections.
        rotary_embedding_config: When provided, a RotaryEmbedding is applied to
            the projected queries and keys.
        add_bias_kv: Stored on the module; not otherwise used by this class.
        value_size: Per-head size of the value projection (defaults to ``key_size``).
        model_size: Size of the inputs and of the output (defaults to ``key_size``).
        name: Optional label stored on the module.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config

        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding layer to the per-head queries and keys."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute attention between ``query`` and ``key``/``value``.

        Args:
            query: Tensor of shape (..., t, model_size).
            key: Tensor of shape (..., T, model_size).
            value: Tensor of shape (..., T, model_size).
            attention_mask: Optional boolean tensor broadcastable to
                (..., num_heads, t, T). Positions where it is False receive a
                large negative logit before the softmax.
            attention_weight_bias: Optional tensor broadcastable to
                (..., num_heads, t, T), added to the logits before the softmax.

        Returns:
            dictionary containing attention weights
            and outputs.
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )

        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size

        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises a RuntimeError.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)

        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the head and per-head feature axes back into a single axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)

        return {"attention_weights": attention_weights, "embeddings": embeddings}
| |
|
| |
|
class SelfAttentionBlock(nn.Module):
    """
    Transformer block: multi-head self-attention followed by a feed-forward
    network, with either pre- or post-layer-normalisation.

    Args:
        num_heads: Number of attention heads.
        embed_dim: Size of the block's input and output embeddings.
        ffn_embed_dim: Hidden size of the feed-forward network.
        key_size: Per-head key size; defaults to ``embed_dim // num_heads``.
        add_bias_kv: Forwarded to MultiHeadAttention.
        add_bias_fnn: Whether the feed-forward linear layers have a bias.
        ffn_activation_name: ``"swish"``, ``"gelu-no-approx"``, the name of a
            function in ``torch.nn.functional`` (e.g. ``"relu"``) or the name of
            a module class in ``torch.nn`` (e.g. ``"ReLU"``).
        use_glu_in_ffn: When True the first linear layer is twice as wide and
            its two halves are combined as ``activation(x1) * x2``.
        layer_norm_eps: Epsilon used by both layer normalisations.
        pre_layer_norm: Normalise before (True) or after (False) each sub-layer.
        name: Accepted for interface compatibility; not used by this class.
        rotary_embedding_config: Forwarded to MultiHeadAttention.

    Raises:
        ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible
            by ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional[RotaryEmbeddingConfig] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        # With a gated linear unit the first projection produces both the gate
        # and the value, hence twice the hidden size.
        fc1_out = int(2 * ffn_embed_dim) if use_glu_in_ffn else ffn_embed_dim
        self.fc1 = nn.Linear(embed_dim, fc1_out, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self._ffn_activation_fn = self._build_ffn_activation(ffn_activation_name)

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    @staticmethod
    def _build_ffn_activation(ffn_activation_name: str):
        """
        Resolve an activation name into a callable applied to tensors.

        Raises:
            AttributeError: If the name matches neither a special-cased value,
                a ``torch.nn.functional`` function nor a ``torch.nn`` attribute.
        """
        if ffn_activation_name == "swish":
            return nn.SiLU()
        if ffn_activation_name == "gelu-no-approx":
            # "no-approx" means the exact erf-based GELU, not the tanh variant.
            return nn.GELU(approximate="none")
        functional = getattr(F, ffn_activation_name, None)
        if callable(functional):
            return functional
        # torch.nn exposes activation classes, which must be instantiated
        # before they can be applied to a tensor.
        return getattr(torch.nn, ffn_activation_name)()

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Apply the feed-forward network.

        With pre-layer-norm the input is normalised first and the raw network
        output is returned (the caller adds the residual). With post-layer-norm
        the residual is added here and the sum is normalised.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed

        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)

        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Run self-attention and the feed-forward network on ``x``.

        Returns:
            The dictionary produced by MultiHeadAttention, with ``"embeddings"``
            replaced by the output of the whole block.
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)

        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )

        if not self._pre_layer_norm:
            output["embeddings"] = self.layer_norm_self_attention(
                output["embeddings"] + res
            )
            x = output["embeddings"]
            # mlp() adds its own residual and normalises in post-layer-norm mode.
            x = self.mlp(x)
        else:
            x = res + output["embeddings"]
            x = x + self.mlp(x)

        output["embeddings"] = x
        return output
| |
|
| |
|
class LMHead(nn.Module):
    """
    Prediction head: a stack of linear layers, each followed by a
    tanh-approximated GELU, and a final linear projection.

    Args:
        dim_in: Size of the last axis of the input.
        embed_dim: Width of the hidden linear layers.
        dim_out: Size of the last axis of the output.
        num_hidden_layers: Number of hidden linear layers; the first maps
            ``dim_in`` to ``embed_dim`` and the remaining ones keep ``embed_dim``.
    """

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.linear_layers = nn.ModuleList([nn.Linear(dim_in, embed_dim)])
        # Each element must be a Module itself; wrapping every layer in its own
        # list makes ModuleList reject it as soon as num_hidden_layers > 1.
        self.linear_layers.extend(
            [nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)]
        )
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Activate the input, run the hidden layers and project to ``dim_out``."""
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = F.gelu(layer(x), approximate="tanh")
        return self.linear_out(x)
| |
|
| |
|
# NOTE(review): because __init__ is written by hand and the class declares no
# annotated fields, @dataclass only contributes generated __repr__/__eq__ methods
# here -- confirm that this is intended.
@dataclass
class sCTConfig(PretrainedConfig):
    """
    Configuration of the sCT model.

    Every setting is read from the keyword arguments with the defaults listed
    in ``__init__``; all keyword arguments are also forwarded to
    ``PretrainedConfig``.

    Raises:
        ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible
            by ``attention_heads``.
    """

    model_type = "sCT"

    def __init__(self, **kwargs):
        self.alphabet_size = kwargs.get("alphabet_size", 7)
        self.pad_token_id = kwargs.get("pad_token_id", 5)
        self.mask_token_id = kwargs.get("mask_token_id", 6)
        self.cell_len = kwargs.get("cell_len", 19968)

        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.attention_heads = kwargs.get("attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.token_embed_dim = kwargs.get("token_embed_dim", 16)

        self.embed_dim = kwargs.get("embed_dim", 1024)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 2048)
        self.num_layers = kwargs.get("num_layers", 4)
        self.layer_norm_eps = kwargs.get("layer_norm_eps", 1e-5)
        self.interpolation_method = kwargs.get("interpolation_method", "nearest")

        self.max_positions: int = kwargs.get("max_positions", 20480)
        self.num_cells: int = kwargs.get("num_cells", 50)
        self.num_hidden_layers_head: int = kwargs.get("num_hidden_layers_head", 1)

        self.use_skip_connection: bool = kwargs.get("use_skip_connection", True)

        self.use_gradient_checkpointing: bool = False

        self.embeddings_layers_to_save: Tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )

        self.use_spatial_information: bool = kwargs.get(
            "use_spatial_information", False
        )
        self.num_scales: int = kwargs.get("num_scales", 10)
        self.sigma_min: float = kwargs.get("sigma_min", 1.0)
        self.sigma_max: float = kwargs.get("sigma_max", 10.0)

        # NOTE(review): the base initializer is understood to assign
        # pad_token_id from its keyword arguments (None when absent), which
        # would discard the default chosen above -- confirm against the
        # installed transformers version. Forwarding the value is harmless
        # either way.
        kwargs.setdefault("pad_token_id", self.pad_token_id)
        super().__init__(**kwargs)

        # The hand-written __init__ replaces the dataclass-generated one, so the
        # validation hook has to be invoked explicitly.
        self.__post_init__()

    def __post_init__(self) -> None:
        """
        Checks that the given values are compatible and derives ``key_size``
        from ``embed_dim`` and ``attention_heads`` when it was not provided.
        """
        if self.key_size is None:
            if not self.embed_dim % self.attention_heads == 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension "
                    f"should be divisible by the number of heads, however "
                    f"provided embedding dimension is {self.embed_dim} and "
                    f"the number of heads is {self.attention_heads}."
                )
            self.key_size = self.embed_dim // self.attention_heads
| |
|
| |
|
class sCT(PreTrainedModel):
    """
    sCT model: token embedding and stem convolution, a downsampling
    convolution tower, a stack of self-attention blocks, an upsampling
    deconvolution tower fed with the skip tensors of the convolution tower,
    and a language-model head producing logits.
    """

    config_class = sCTConfig

    def __init__(self, config: sCTConfig):
        super().__init__(config=config)
        if config.use_spatial_information:
            self.spatial_embed_layer = SpatialEncoding(
                embed_dim=config.token_embed_dim,
                num_scales=config.num_scales,
                sigma_min=config.sigma_min,
                sigma_max=config.sigma_max,
            )
        self.cell_len = config.cell_len

        self.token_embed = nn.Embedding(config.alphabet_size, config.token_embed_dim)

        # Group the requested (layer, map) pairs by layer.
        requested_maps = config.attention_maps_to_save
        self._attention_layers_to_save = list({entry[0] for entry in requested_maps})
        self._attention_maps_per_layer_to_save = {}
        for layer in self._attention_layers_to_save:
            self._attention_maps_per_layer_to_save[layer] = [
                entry[1] for entry in requested_maps if entry[0] == layer
            ]

        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        # Channel sizes interpolated between the token and the transformer
        # embedding sizes, each rounded up to a multiple of 32.
        # NOTE(review): the rounding also applies to the first entry, so it
        # presumably only matches the stem output when token_embed_dim is a
        # multiple of 32 -- TODO confirm.
        channel_steps = np.linspace(
            config.token_embed_dim,
            config.embed_dim,
            config.num_downsamples + 1,
        )
        filter_list = (np.ceil(channel_steps / 32) * 32).astype(int).tolist()

        self._filter_list = filter_list
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)

        stem_conv_layer = nn.Conv1d(
            in_channels=config.token_embed_dim,
            out_channels=config.token_embed_dim,
            kernel_size=15,
            padding="same",
        )
        self.stem_conv = nn.Sequential(stem_conv_layer, nn.GELU(approximate="tanh"))

        num_stages = len(filter_list) - 1
        # Sequence length seen by each downsampling stage (halved at every stage).
        downsampled_seq_lens = [
            self.cell_len // (2**stage) for stage in range(num_stages)
        ]

        conv_blocks = []
        for stage, seq_len in zip(range(num_stages), downsampled_seq_lens):
            conv_blocks.append(
                ConvTowerBlock(
                    dim_in=self._filter_list[stage],
                    dim_out=self._filter_list[stage + 1],
                    kernel_size=5,
                    seq_len=seq_len,
                    num_cells=config.num_cells,
                )
            )
        self.conv_tower = nn.ModuleList(conv_blocks)

        # The upsampling tower walks the channel sizes and lengths in reverse.
        deconv_blocks = []
        for stage, seq_len in zip(range(num_stages), downsampled_seq_lens[::-1]):
            deconv_blocks.append(
                DeConvTowerBlock(
                    dim_in=filter_list[-1 - stage],
                    dim_out=filter_list[-1 - stage - 1],
                    kernel_size=5,
                    stride=2,
                    seq_len=seq_len // 2,
                    num_cells=config.num_cells,
                )
            )
        self.deconv_tower = nn.ModuleList(deconv_blocks)

        attention_blocks = []
        for layer_idx in range(config.num_layers):
            attention_blocks.append(
                SelfAttentionBlock(
                    num_heads=config.attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
            )
        self.transformer_layers = nn.ModuleList(attention_blocks)

        self.lm_head = LMHead(
            dim_in=config.token_embed_dim,
            embed_dim=config.embed_dim,
            dim_out=config.alphabet_size,
            num_hidden_layers=config.num_hidden_layers_head,
        )

    def forward(self, input_ids: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Run the model on ``input_ids``.

        Returns:
            Dictionary holding ``"logits"`` plus the embeddings and attention
            maps requested through the configuration.
        """
        outs = {}

        # Embed the tokens and move the feature axis to position 1 for Conv1d.
        x = self.token_embed(input_ids).permute(0, 2, 1)
        x = self.stem_conv(x)

        skips = []
        for conv_block in self.conv_tower:
            x, skip = conv_block(x)
            skips.append(skip)

        x = x.permute(0, 2, 1)
        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            # Layers are numbered from 1 in the configuration.
            layer_number = layer_idx + 1
            if layer_number in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"]
            if layer_number in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                    dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                    # NOTE(review): the map is read at index map_number + 1 of
                    # axis 1 -- confirm that this offset is intended.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        x = x.permute(0, 2, 1)
        # The skip tensors are consumed in the reverse of the order they were made.
        for deconv_block, skip in zip(self.deconv_tower, skips[::-1]):
            x = deconv_block(x, skip)

        x = x.permute(0, 2, 1)
        outs["logits"] = self.lm_head(x)

        return outs
| |
|