Feature Extraction
Transformers
Safetensors
fast_esmfold
protein
structure-prediction
esmfold
test-time-training
custom_code
Instructions to use Synthyra/FastESMFold with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/FastESMFold with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="Synthyra/FastESMFold", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import torch | |
| import torch._inductor.config as inductor_config | |
| import torch._dynamo as dynamo | |
# Process-wide PyTorch performance settings, applied as a side effect of
# importing this module.
# NOTE(review): these mutate global torch/inductor/dynamo state for the whole
# process, not only for this model (any other code the user runs is affected).
# Confirm that this is intended for a library module.
# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs)
# Provides significant speedup with minimal precision loss
torch.set_float32_matmul_precision('high')
# Enable TF32 for matrix multiplications and cuDNN operations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enable cuDNN autotuner - finds fastest algorithms for your hardware
# Best when input sizes are consistent; may slow down first iterations
torch.backends.cudnn.benchmark = True
# Deterministic operations off for speed (set True if reproducibility needed)
torch.backends.cudnn.deterministic = False
# Restrict the GEMM backends that inductor's max-autotune considers.
# NOTE(review): the accepted backend names depend on the installed torch
# version. Confirm that "FBGEMM" is recognized there.
inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
# NOTE(review): presumably lets dynamo capture ops that return Python scalars
# (e.g. `.item()`) instead of graph-breaking. Confirm against the torch dynamo
# config docs.
dynamo.config.capture_scalar_outputs = True
# Bound on how many recompilations dynamo allows per function.
# NOTE(review): `recompile_limit` appears to be the newer name of
# `cache_size_limit`. Confirm that the installed torch exposes it.
torch._dynamo.config.recompile_limit = 16
| """Shared attention infrastructure for all FastPLMs models. | |
| Contains: AttentionBackend enum, backend resolution, mask creation, | |
| flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. | |
| """ | |
| from enum import Enum | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from einops import rearrange | |
| try: | |
| from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask | |
| except ImportError: | |
| create_block_mask = None | |
| flex_attention = None | |
| BlockMask = None | |
# Lazily-built `torch.compile` wrapper around `flex_attention` (see below).
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Choose the flex-attention callable for the current call.

    Returns:
        None when flex attention could not be imported.
        The eager `flex_attention` when torch's
        `_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG` debug flag is truthy.
        Otherwise a `torch.compile(..., dynamic=False)` wrapper, built on first
        use and cached in `_compiled_flex_attention`.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    debug_eager = getattr(
        torch.nn.attention.flex_attention,
        "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG",
        False,
    )
    if debug_eager:
        return flex_attention
    compiled = _compiled_flex_attention
    if compiled is None:
        compiled = torch.compile(flex_attention, dynamic=False)
        _compiled_flex_attention = compiled
    return compiled
| # HuggingFace `kernels` exposes slightly different APIs for Flash Attention 2 | |
| # and 3. Detect the loaded variant once so every caller uses the same dispatch. | |
| def _infer_kernels_flash_variant(kernel) -> Optional[str]: | |
| if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): | |
| return "flash_attn2" | |
| if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): | |
| return "flash_attn3" | |
| return None | |
| def _try_get_kernels_flash(): | |
| try: | |
| from kernels import get_kernel | |
| except ImportError: | |
| return None, None | |
| flash_kernel = None | |
| flash_kernel_variant = None | |
| try: | |
| flash_kernel = get_kernel("kernels-community/flash-attn3") | |
| flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) | |
| assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." | |
| except Exception: | |
| try: | |
| flash_kernel = get_kernel("kernels-community/flash-attn2") | |
| flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) | |
| assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." | |
| except Exception: | |
| flash_kernel = None | |
| flash_kernel_variant = None | |
| return flash_kernel, flash_kernel_variant | |
# Process-level cache of the hub flash kernel, filled by
# `_ensure_flash_kernels_loaded`. Other helpers read FLASH_KERNEL /
# FLASH_KERNEL_VARIANT directly as module globals.
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Populate FLASH_KERNEL / FLASH_KERNEL_VARIANT at most once.

    The loaded flag is set before calling `_try_get_kernels_flash`. A lookup
    that returns `(None, None)`, or that raises, is therefore not retried on
    later calls.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Dense (non-varlen) flash attention through the loaded hub kernel.

    Args:
        query_states, key_states, value_states: forwarded unchanged to the kernel.
        causal: forwarded as `is_causal` (FA2) or `causal` (FA3).
        softmax_scale: None lets the kernel apply its default
            `1 / sqrt(head_dim)`. Pass 1.0 when Q has already been pre-scaled
            (the convention used by ESM2, DPLM, DPLM2, E1, ESMFold). Leaving it
            None with a pre-scaled Q applies the scale twice; on DPLM-150M that
            gave pooled-embedding cosine around -0.12 and argmax agreement
            around 0.27 vs SDPA.

    Raises:
        AssertionError: no kernel is loaded, or the loaded variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    kernel = FLASH_KERNEL
    variant = FLASH_KERNEL_VARIANT

    if variant == "flash_attn2":
        # The FA2 `fwd` result is indexed with [0] to get the attention output.
        fa2_outputs = kernel.fwd(
            q=query_states, k=key_states, v=value_states,
            softmax_scale=softmax_scale, is_causal=causal,
        )
        return fa2_outputs[0]

    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        result = kernel.flash_attn_func(
            q=query_states, k=key_states, v=value_states,
            softmax_scale=softmax_scale, causal=causal,
        )
    except TypeError:
        # Positional fallback for builds that reject these keyword names.
        # NOTE(review): the 0.0 presumably fills a dropout slot. Confirm the
        # positional order against the loaded kernel's signature.
        result = kernel.flash_attn_func(
            query_states, key_states, value_states,
            0.0, softmax_scale, causal,
        )
    # Some builds return a tuple; the attention output is its first element.
    return result[0] if isinstance(result, tuple) else result
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Variable-length (packed) flash attention through the loaded hub kernel.

    `cu_seqlens_*` and `max_seqlen_in_batch_*` are forwarded unchanged as the
    kernel's `cu_seqlens_*` / `max_seqlen_*` arguments. As with the dense
    variant, pass `softmax_scale=1.0` when Q has already been pre-scaled by the
    caller; otherwise the kernel's default `1 / sqrt(head_dim)` is applied a
    second time.

    Raises:
        AssertionError: no kernel is loaded, or the loaded variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    kernel = FLASH_KERNEL
    variant = FLASH_KERNEL_VARIANT

    if variant == "flash_attn2":
        # The FA2 `varlen_fwd` result is indexed with [0] to get the output.
        fa2_outputs = kernel.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale, is_causal=causal,
        )
        return fa2_outputs[0]

    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        result = kernel.flash_attn_varlen_func(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale, causal=causal,
        )
    except TypeError:
        # Positional fallback for builds that reject these keyword names.
        # NOTE(review): the 0.0 presumably fills a dropout slot. Confirm the
        # positional order against the loaded kernel's signature.
        result = kernel.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, softmax_scale, causal,
        )
    # Some builds return a tuple; the attention output is its first element.
    return result[0] if isinstance(result, tuple) else result
| # Varlen flash attention runs only on real tokens. These helpers remove padding | |
| # before the kernel call and restore the original padded batch shape afterward. | |
class IndexFirstAxis(torch.autograd.Function):
    """Autograd-aware `input[indices]` along dim 0 (used to drop padding rows).

    Backward scatters the incoming gradient into a zero tensor whose first
    dimension is the original `input.shape[0]`.
    """

    # torch.autograd.Function requires forward/backward to be static methods.
    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # Flatten trailing dims so a single 2-D gather handles any rank >= 2.
        flat = input.reshape(ctx.first_axis_dim, second_dim)
        gathered = torch.gather(flat, 0, indices.unsqueeze(1).expand(-1, second_dim))
        return gathered.reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = grad_output.reshape(grad_output.shape[0], other_shape.numel())
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


class IndexPutFirstAxis(torch.autograd.Function):
    """Inverse of IndexFirstAxis: write `values` into rows `indices` of a zero
    tensor with `first_axis_dim` rows. Backward gathers those rows back out.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None


index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed rows back into a zero-padded `(batch, seqlen, ...)` tensor.

    Args:
        hidden_states: packed rows, one per entry of `indices`.
        indices: flat positions (in `batch * seqlen` space) of each packed row.
        batch, seqlen: target padded shape.
    """
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return output.reshape(batch, seqlen, *output.shape[1:])
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Pack Q/K/V down to real tokens for the varlen flash kernel.

    Q/K/V are unpacked as (batch, seq_len, num_heads, head_dim), and K/V are
    reshaped with Q's dimensions. `attention_mask_2d` is (batch, seq_len),
    where nonzero means a real token.

    Returns:
        (q, k, v) packed to (num_real_tokens, num_heads, head_dim); the flat
        indices of the kept tokens; `(cu_seqlens, cu_seqlens)` as int32 prefix
        sums with a leading 0; and `(max_len, max_len)`. Q and K share the same
        lengths, so each pair holds the same value twice.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    flat_shape = (batch_size * seq_len, num_heads, head_dim)

    lengths = attention_mask_2d.sum(dim=1).int()
    cu_seqlens = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    longest = int(lengths.max().item())
    keep = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    packed_q, packed_k, packed_v = (
        index_first_axis(tensor.reshape(flat_shape), keep)
        for tensor in (query_layer, key_layer, value_layer)
    )
    return packed_q, packed_k, packed_v, keep, (cu_seqlens, cu_seqlens), (longest, longest)
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Public flash-attention entry point with optional padding handling.

    When `attention_mask_2d` is given, padding is stripped, the varlen kernel
    runs on the packed real tokens (with `causal` applied within each packed
    sequence), and the result is scattered back into a zero-filled padded
    tensor. Without a mask, the dense kernel runs on the inputs as-is.

    `softmax_scale`:
        None  -> kernel applies its default `1 / sqrt(head_dim)`.
        float -> kernel uses the given scale (pass 1.0 when Q is pre-scaled
                 by the caller).

    Caller contract: if a model family pre-scales Q by `1/sqrt(head_dim)`
    before calling this function (ESM2, DPLM, DPLM2, E1, and ESMFold do), pass
    `softmax_scale=1.0`. Otherwise the flash kernel applies its default scale
    again, yielding an effective `1/head_dim` scale that drifts across layers.

    Raises:
        AssertionError: no flash kernel is loaded.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if attention_mask_2d is None:
        return _kernels_flash_forward(
            query_states=query_states, key_states=key_states, value_states=value_states,
            causal=causal, softmax_scale=softmax_scale,
        )
    # The padding mask is honored for causal calls too. The dense path would
    # silently ignore it, letting left-padded or holed batches attend to pads.
    batch_size, q_len = query_states.shape[:2]
    (
        query_states, key_states, value_states,
        indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
    ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
    attn_output_unpad = _kernels_flash_varlen_forward(
        query_states=query_states, key_states=key_states, value_states=value_states,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k,
        causal=causal, softmax_scale=softmax_scale,
    )
    return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
| # User-facing backend strings resolve to this enum before attention dispatch. | |
class AttentionBackend(Enum):
    """Attention implementations selectable through the config's backend string."""

    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


# Every accepted backend string, in declaration order.
VALID_ATTENTION_BACKENDS = tuple([backend.value for backend in AttentionBackend])
# Set after the first resolution has been printed, so the message appears once.
_BACKEND_CONFIRMED = False


def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a user-facing backend string to a concrete AttentionBackend.

    "auto" picks kernels-flash if the hub kernel loads, else flex attention if
    it was importable, else SDPA. An explicit request is checked for
    availability. The first resolution in the process is printed.

    Raises:
        AssertionError: unknown backend string, or an explicitly requested
            backend that is unavailable. Raised explicitly (not via `assert`)
            so the checks still run under `python -O`.
    """
    global _BACKEND_CONFIRMED
    if requested_backend not in VALID_ATTENTION_BACKENDS:
        raise AssertionError(
            f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        )
    if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
        _ensure_flash_kernels_loaded()

    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
        if FLASH_KERNEL is None:
            raise AssertionError("Kernels Flash Attention is not available in this environment.")
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        if flex_attention is None:
            raise AssertionError("Flex Attention is not available in this environment.")
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")

    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
    """Build the padding masks once, for reuse by every encoder layer.

    Returns:
        `(attention_mask_2d, attention_mask_4d, flex_block_mask)`:
        - all None when `attention_mask` is None;
        - `(bool 2-D, None, None)` for KERNELS_FLASH;
        - `(bool 2-D, None, BlockMask)` for FLEX;
        - `(bool 2-D, bool (batch, 1, 1, seq), None)` for any other backend.

    Raises:
        AssertionError: FLEX requested but `create_block_mask` is unavailable.
    """
    if attention_mask is None:
        return None, None, None
    mask_2d = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # The varlen flash path unpads directly from the 2-D mask.
        return mask_2d, None, None

    if effective_backend != AttentionBackend.FLEX:
        # SDPA/manual masks only keys. Padding queries still attend to real
        # keys, so their outputs stay finite instead of softmaxing over all
        # -inf scores.
        return mask_2d, mask_2d[:, None, None, :], None

    assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
    valid_lens = mask_2d.sum(dim=-1)

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Compares positions against each row's valid-token count. This is
        # only correct when the valid tokens form a prefix (right padding).
        limit = valid_lens[batch_idx]
        return (q_idx < limit) & (kv_idx < limit)

    block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
    return mask_2d, None, block_mask
def bool_to_additive_mask(
    bool_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a bool validity mask into an additive SDPA mask.

    True (valid) becomes 0.0 and False (invalid) becomes -inf, in a new
    tensor of `dtype` with the same shape and device as `bool_mask`.

    Why this exists: calling `masked_fill(..., float('-inf'))` directly on a
    bool tensor keeps the bool dtype, because -inf casts to True, so the mask
    is silently lost. The float tensor is therefore allocated first and filled
    afterwards. Use this helper whenever an SDPA additive mask is needed.

    Raises:
        AssertionError: `bool_mask` is not a bool tensor.
    """
    assert bool_mask.dtype == torch.bool, (
        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
    )
    zeros = torch.zeros_like(bool_mask, dtype=dtype)
    return zeros.masked_fill(~bool_mask, float("-inf"))
| """FastESMFold: Self-contained ESMFold with FastESM2 attention backends + built-in Test-Time Training. | |
| Usage: | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained("Synthyra/FastESMFold", trust_remote_code=True).cuda() | |
| # Basic folding | |
| result = model.fold_protein("MKTLLILAVVA...") | |
| print(result["plddt"], result["pdb_string"][:100]) | |
| # Folding with TTT (test-time training improves structure prediction) | |
| result = model.fold_protein("MKTLLILAVVA...", ttt=True) | |
| Dependencies: torch, transformers, einops, peft (for LoRA TTT only) | |
| No dependency on: esm (fair-esm), proteinttt, openfold | |
| """ | |
| import copy | |
| from dataclasses import dataclass, field | |
| from functools import wraps | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from einops import rearrange | |
| from transformers import EsmTokenizer, PretrainedConfig, PreTrainedModel | |
| from transformers.modeling_outputs import ModelOutput | |
| from transformers.models.esm.configuration_esm import EsmConfig | |
| from transformers.models.esm.modeling_esm import ( | |
| EsmContactPredictionHead, | |
| EsmEmbeddings, | |
| EsmIntermediate, | |
| EsmLMHead, | |
| EsmOutput, | |
| EsmSelfOutput, | |
| RotaryEmbedding, | |
| ) | |
| from transformers.models.esm.modeling_esmfold import EsmForProteinFolding | |
| # ============================================================================= | |
| # Output Dataclass | |
| # ============================================================================= | |
# transformers' ModelOutput requires subclasses to be dataclasses (its
# __init__ raises TypeError otherwise).
@dataclass
class FastEsmEncoderOutput(ModelOutput):
    """Output container for the FastESM encoder/backbone.

    Fields:
        last_hidden_state: final encoder hidden states.
        hidden_states: per-layer hidden states, populated only when requested.
        attentions: per-layer attention weights, populated only when requested.
    """

    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
| # ============================================================================= | |
| # FastESM2 Attention Layers (multi-backend: SDPA, Flash, Flex) | |
| # ============================================================================= | |
class EsmSelfAttention(nn.Module):
    """ESM self-attention with selectable kernels (kernels-flash, flex, SDPA, eager).

    Q is multiplied by `1/sqrt(head_dim)` in `forward` before rotary embeddings
    are applied. Every backend below is therefore called with a softmax scale
    of 1.0.
    """

    def __init__(self, config, position_embedding_type: Optional[str] = None):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0, (
            f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
            f"heads ({config.num_attention_heads})"
        )
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.scale = self.attention_head_size**-0.5
        self.dropout_prob = config.attention_probs_dropout_prob
        self.config = config
        self.attn_backend = resolve_attention_backend(config.attn_backend)
        self.position_embedding_type = position_embedding_type or config.position_embedding_type
        self.rotary_embeddings = None
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Project to Q/K/V, pre-scale Q, apply rotary if configured, then attend.

        `hidden_states` is (batch, seq_len, hidden). Returns the context as
        (batch, seq_len, num_heads * head_dim), plus attention weights (only
        from the eager path used when `output_attentions=True`, else None).
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
        query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_BHLD = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_BHLD = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
        query_BHLD = query_BHLD * self.scale
        if self.position_embedding_type == "rotary":
            query_BHLD, key_BHLD = self.rotary_embeddings(query_BHLD, key_BHLD)
        attn_output, attn_weights = self._attn(
            query_BHLD, key_BHLD, value_BHLD,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return attn_output, attn_weights

    def _attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Dispatch to the resolved backend, or to eager attention if weights are needed."""
        if output_attentions:
            # Callers may supply only the 2-D mask (get_attention_mask does so
            # for the flash and flex backends). Derive the key-padding mask here
            # so the eager fallback does not attend to padding tokens.
            if attention_mask_4d is None and attention_mask_2d is not None:
                attention_mask_4d = attention_mask_2d.bool()[:, None, None, :]
            return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        if self.attn_backend == AttentionBackend.KERNELS_FLASH:
            return self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d)
        elif self.attn_backend == AttentionBackend.FLEX:
            return self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask)
        elif self.attn_backend == AttentionBackend.SDPA:
            return self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d)
        else:
            raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}")

    def _manual_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Eager attention that also returns the (post-dropout) attention weights.

        `attention_mask_4d` is a bool mask where False positions get -inf scores.
        """
        attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
        if attention_mask_4d is not None:
            attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.dropout_prob > 0 and self.training:
            attn_weights = F.dropout(attn_weights, p=self.dropout_prob, training=self.training)
        context_BHLD = torch.matmul(attn_weights, value_BHLD)
        attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)")
        return attn_output, attn_weights

    def _kernels_flash_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        """Hub flash-attention path. The kernel expects (batch, seq, heads, dim)."""
        query_BLHD = query_BHLD.transpose(1, 2).contiguous()
        key_BLHD = key_BHLD.transpose(1, 2).contiguous()
        value_BLHD = value_BHLD.transpose(1, 2).contiguous()
        # Q is pre-scaled by self.scale in forward() -- pass softmax_scale=1.0
        # to prevent the kernel from applying its default 1/sqrt(head_dim).
        attn_output = kernels_flash_attention_func(
            query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD,
            attention_mask_2d=attention_mask_2d, causal=False,
            softmax_scale=1.0,
        )
        return rearrange(attn_output, "b s h d -> b s (h d)"), None

    def _flex_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        flex_block_mask: Optional[BlockMask] = None,
    ) -> Tuple[torch.Tensor, None]:
        """Flex-attention path (scale=1.0 because Q is pre-scaled; no dropout)."""
        assert flex_attention is not None, "Flex attention is not available in this environment."
        fn = _get_flex_attention_fn()
        context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None

    def _sdpa_attn(
        self,
        query_BHLD: torch.Tensor,
        key_BHLD: torch.Tensor,
        value_BHLD: torch.Tensor,
        attention_mask_4d: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, None]:
        """PyTorch SDPA path (scale=1.0 because Q is pre-scaled; dropout only in training)."""
        context_BHLD = F.scaled_dot_product_attention(
            query_BHLD, key_BHLD, value_BHLD,
            attn_mask=attention_mask_4d,
            dropout_p=self.dropout_prob if self.training else 0.0,
            scale=1.0,
        )
        return rearrange(context_BHLD, "b h s d -> b s (h d)"), None
class EsmAttention(nn.Module):
    """Pre-LayerNorm self-attention sub-block.

    The input is normalized, fed to EsmSelfAttention, and the result is passed
    to EsmSelfOutput together with the un-normalized input (presumably for the
    residual connection; see transformers' EsmSelfOutput).
    """

    def __init__(self, config):
        super().__init__()
        # Registration order (self, output, LayerNorm) is kept as-is because it
        # fixes parameter order and RNG consumption at init.
        self.self = EsmSelfAttention(config)
        self.output = EsmSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return `(attention_output, attention_weights_or_None)`."""
        normed = self.LayerNorm(hidden_states)
        context, weights = self.self(
            normed,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self.output(context, hidden_states), weights
class EsmLayer(nn.Module):
    """One pre-norm ESM transformer layer: attention sub-block, then feed-forward.

    The feed-forward applies LayerNorm -> EsmIntermediate -> EsmOutput, and
    EsmOutput also receives the un-normalized attention output (presumably
    for the residual connection; see transformers' EsmOutput).
    """

    def __init__(self, config):
        super().__init__()
        # Registration order is kept: it fixes parameter order and init RNG use.
        self.attention = EsmAttention(config)
        self.intermediate = EsmIntermediate(config)
        self.output = EsmOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return `(layer_output, attention_weights_or_None)`."""
        attended, weights = self.attention(
            hidden_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            output_attentions=output_attentions,
        )
        return self._feed_forward(attended), weights

    def _feed_forward(self, attention_output: torch.Tensor) -> torch.Tensor:
        """LayerNorm -> intermediate -> output, with the raw attention output passed through."""
        expanded = self.intermediate(self.LayerNorm(attention_output))
        return self.output(expanded, attention_output)
class FastEsmEncoder(nn.Module):
    """Stack of EsmLayer modules followed by a final LayerNorm.

    Padding masks are built once per forward (via `get_attention_mask`) and
    shared by all layers.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Resolved before the layers are built, matching the original call order.
        self.attention_backend = resolve_attention_backend(config.attn_backend)
        self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> FastEsmEncoderOutput:
        """Run every layer and apply the final LayerNorm.

        When `output_hidden_states` is set, `hidden_states` holds each layer's
        input followed by the final normalized output. When `output_attentions`
        is set, `attentions` holds each layer's returned attention weights.
        """
        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=hidden_states.shape[0],
            seq_len=hidden_states.shape[1],
            device=hidden_states.device,
            attention_mask=attention_mask,
        )
        hidden_history = [] if output_hidden_states else None
        attention_history = [] if output_attentions else None

        for layer_module in self.layer:
            if hidden_history is not None:
                hidden_history.append(hidden_states)
            hidden_states, attn_weights = layer_module(
                hidden_states,
                attention_mask_2d=attention_mask_2d,
                attention_mask_4d=attention_mask_4d,
                flex_block_mask=flex_block_mask,
                output_attentions=output_attentions,
            )
            if attention_history is not None:
                attention_history.append(attn_weights)

        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)
        if hidden_history is not None:
            hidden_history.append(hidden_states)

        return FastEsmEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=None if hidden_history is None else tuple(hidden_history),
            attentions=None if attention_history is None else tuple(attention_history),
        )
| # ============================================================================= | |
| # FastESM Backbone (replaces EsmModel inside ESMFold) | |
| # ============================================================================= | |
class FastEsmBackbone(nn.Module):
    """FastESM2 backbone with multi-backend attention.

    Drop-in replacement for transformers.EsmModel inside EsmForProteinFolding.
    State dict keys match HuggingFace EsmModel exactly, so pretrained weights
    load without any key remapping.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = EsmEmbeddings(config)
        self.encoder = FastEsmEncoder(config)
        # Input width is one feature per (layer, head) pair.
        self.contact_head = EsmContactPredictionHead(
            in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> FastEsmEncoderOutput:
        """Embed the inputs and run the encoder.

        `return_dict` and any extra keyword arguments are accepted for call-site
        compatibility but ignored; a FastEsmEncoderOutput is always returned.
        A `None` output flag is treated as False.
        """
        if output_attentions is None:
            output_attentions = False
        if output_hidden_states is None:
            output_hidden_states = False

        embedded = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        encoded = self.encoder(
            embedded,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        return FastEsmEncoderOutput(
            last_hidden_state=encoded.last_hidden_state,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
| # ============================================================================= | |
| # TTT (Test-Time Training) Configuration and Utilities | |
| # ============================================================================= | |
| _ESM_STANDARD_AA = list("ACDEFGHIKLMNPQRSTVWY") | |
class LoraInjectedLinear(nn.Module):
    """LoRA-augmented linear layer matching lora_diffusion's behavior.

    Wraps an existing ``nn.Linear`` and computes
    ``linear(x) + lora_up(lora_down(x)) * scale``. Initialization follows
    cloneofsimo/lora: ``lora_down`` weights ~ Normal(0, std=1/r) and ``lora_up``
    weights = 0. Since ``lora_up`` has no bias, the output equals the wrapped
    layer's output until the adapter is trained.

    Args:
        original_linear: layer to wrap; kept (not copied) as ``self.linear``.
        r: adapter rank; must satisfy ``1 <= r <= min(in_features, out_features)``.
        scale: multiplier applied to the adapter branch.

    Raises:
        AssertionError: ``r`` is outside the valid range.
    """

    def __init__(self, original_linear: nn.Linear, r: int = 4, scale: float = 1.0):
        super().__init__()
        self.linear = original_linear
        in_features = original_linear.in_features
        out_features = original_linear.out_features
        # Reject r < 1 up front: r == 0 would otherwise fail as ZeroDivisionError
        # in std=1/r, and r < 0 as a negative-dimension error inside nn.Linear.
        assert 1 <= r <= min(in_features, out_features), (
            f"LoRA rank {r} must be in [1, min({in_features}, {out_features})]"
        )
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        nn.init.normal_(self.lora_down.weight, std=1.0 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base projection plus the scaled low-rank update."""
        return self.linear(x) + self.lora_up(self.lora_down(x)) * self.scale
def inject_trainable_lora(
    model: nn.Module,
    target_class_name: str,
    r: int,
    scale: float,
) -> List[nn.Parameter]:
    """Replace nn.Linear layers inside modules matching target_class_name with LoRA.

    Matches lora_diffusion's inject_trainable_lora behavior: finds all modules whose
    class name equals ``target_class_name`` and replaces each of their direct
    ``nn.Linear`` children with a ``LoraInjectedLinear`` on the child's device and
    dtype.

    Args:
        model: module tree to modify in place.
        target_class_name: class ``__name__`` of the parent modules to adapt.
        r: LoRA rank passed to each adapter.
        scale: LoRA scale passed to each adapter.

    Returns:
        The adapters' ``lora_down`` and ``lora_up`` parameters, in injection order.
    """
    # Collect the targets before mutating anything. Replacing children while the
    # named_modules() generator is still walking the tree would make it descend
    # into the freshly inserted adapters (and, if target_class_name were
    # "LoraInjectedLinear", keep re-wrapping its own output).
    targets = [
        module for module in model.modules()
        if module.__class__.__name__ == target_class_name
    ]
    lora_params: List[nn.Parameter] = []
    for parent_module in targets:
        # list(): setattr below mutates the children mapping being iterated.
        for child_name, child_module in list(parent_module.named_children()):
            if not isinstance(child_module, nn.Linear):
                continue
            lora_linear = LoraInjectedLinear(child_module, r=r, scale=scale).to(
                device=child_module.weight.device,
                dtype=child_module.weight.dtype,
            )
            setattr(parent_module, child_name, lora_linear)
            lora_params.extend(lora_linear.lora_down.parameters())
            lora_params.extend(lora_linear.lora_up.parameters())
    return lora_params
from dataclasses import dataclass


@dataclass
class TTTConfig:
    """Hyper-parameters for masked-LM test-time training (TTT).

    Declared as a dataclass so it can be built from keyword arguments, as in
    ``TTTConfig(**ttt_kwargs)``. A plain class with only annotated attributes
    accepts no constructor arguments at all.
    """

    lr: float = 4e-4  # optimizer learning rate
    ags: int = 4  # micro-batches accumulated per optimizer step
    steps: int = 10  # optimizer steps per TTT run (0 disables training)
    batch_size: int = 4  # random crops per micro-batch
    mask_ratio: float = 0.15  # fraction of standard-residue positions selected per crop
    crop_size: int = 1024  # maximum crop length in tokens
    bert_leave_prob: float = 0.1  # selected positions left unchanged
    bert_replace_prob: float = 0.1  # selected positions replaced by a random standard residue
    optimizer: str = "sgd"  # "sgd" or "adamw"
    momentum: float = 0.0  # SGD only
    weight_decay: float = 0.0
    seed: Optional[int] = 0  # seeds the TTT sampling generator and LoRA init; None = unseeded
    initial_state_reset: bool = True  # snapshot the initial trainable state for ttt_reset()
    freeze_embeddings: bool = True  # legacy (full fine-tune) path keeps embeddings frozen
    lora_rank: int = 8  # 0 selects the legacy full fine-tuning path
    lora_alpha: float = 32.0  # used as the LoRA adapter scale
    lora_target_class: str = "EsmSelfAttention"  # class whose nn.Linear children get adapters

    def verify(self) -> None:
        """Validate the settings; raises AssertionError on the first invalid value."""
        assert self.lr > 0.0, "TTT learning rate must be positive."
        assert self.ags > 0, "TTT ags must be positive."
        assert self.steps >= 0, "TTT steps must be non-negative."
        assert self.batch_size > 0, "TTT batch_size must be positive."
        assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]."
        assert self.crop_size > 0, "TTT crop_size must be positive."
        assert 0.0 <= self.bert_leave_prob <= 1.0, "TTT bert_leave_prob must be in [0, 1]."
        assert 0.0 <= self.bert_replace_prob <= 1.0, "TTT bert_replace_prob must be in [0, 1]."
        assert self.bert_leave_prob + self.bert_replace_prob <= 1.0, (
            "TTT bert_leave_prob + bert_replace_prob must not exceed 1."
        )
        assert self.optimizer in {"sgd", "adamw"}, "TTT optimizer must be 'sgd' or 'adamw'."
        assert self.lora_rank >= 0, "TTT lora_rank must be non-negative."
        assert self.lora_alpha > 0.0, "TTT lora_alpha must be positive."
def preserve_model_state(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorate a model method so it restores mode, device and requires_grad afterwards.

    On return or on exception, the wrapper:
    - restores the training flag;
    - moves the model back to the device its first parameter was on (skipped for
      models without parameters);
    - resets every parameter's ``requires_grad`` to its value before the call.
      Parameters created during the call are left with ``requires_grad=False``.

    ``functools.wraps`` keeps the wrapped method's name and docstring.
    """
    import functools

    @functools.wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        was_training = self.training
        first_parameter = next(self.parameters(), None)
        original_device = None if first_parameter is None else first_parameter.device
        original_requires_grad = {
            name: parameter.requires_grad
            for name, parameter in self.named_parameters()
        }
        try:
            return func(self, *args, **kwargs)
        finally:
            self.train(was_training)
            if original_device is not None:
                self.to(original_device)
            for name, parameter in self.named_parameters():
                parameter.requires_grad = original_requires_grad.get(name, False)

    return wrapper
| # ============================================================================= | |
| # FastEsmFoldConfig | |
| # ============================================================================= | |
class FastEsmFoldConfig(EsmConfig):
    """``EsmConfig`` extended with an attention-backend choice and TTT settings.

    Args:
        attn_backend: attention backend name, stored as ``self.attn_backend``.
        ttt_config: keyword dict for TTT. When falsy (None or empty), a default
            dict with lr / steps / lora_rank / lora_alpha is used instead.
        **kwargs: forwarded to ``EsmConfig``.
    """

    model_type = "fast_esmfold"

    def __init__(self, attn_backend: str = "sdpa", ttt_config: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)
        self.attn_backend = attn_backend
        if not ttt_config:
            ttt_config = {
                "lr": 4e-4,
                "steps": 10,
                "lora_rank": 8,
                "lora_alpha": 32.0,
            }
        self.ttt_config = ttt_config
| # ============================================================================= | |
| # FastEsmForProteinFolding | |
| # ============================================================================= | |
| class FastEsmForProteinFolding(EsmForProteinFolding): | |
| """ESMFold with FastESM2 attention backends + built-in Test-Time Training. | |
| Inherits all folding logic (trunk, structure module, output_to_pdb, infer) | |
| from transformers.EsmForProteinFolding. Replaces the ESM2 backbone with | |
| FastESM2 for optimized attention and adds TTT for improved structure prediction. | |
    Key API:
        result = model.fold_protein("MKTL...")
        # result = {"plddt": float, "ptm": float, "pdb_string": str,
        #           "step_plddts": list[float], "best_step": int}
| """ | |
| config_class = FastEsmFoldConfig | |
def __init__(self, config: FastEsmFoldConfig):
    """Build ESMFold, swap in the FastESM2 backbone and set up lazy TTT state.

    LoRA injection and the initial-state snapshot are deferred to
    ``_ensure_ttt_ready`` so that they happen after weights are loaded.
    """
    super().__init__(config)
    ttt_settings = config.ttt_config

    # ``use_standard_backbone`` keeps the parent's ESM2 module (TTT debugging /
    # compatibility); otherwise use FastESM2 with its multiple attention backends.
    if not ttt_settings.get("use_standard_backbone", False):
        self.esm = FastEsmBackbone(config)
        self.esm.requires_grad_(False)
        if config.esmfold_config.fp16_esm:
            self.esm.half()

    # Pretrained MLM head (Dense -> GELU -> LN -> Linear) used by the LoRA TTT path.
    self.mlm_head = EsmLMHead(config)

    # Every ttt_config entry except the backbone switch is a TTTConfig field.
    self._ttt_cfg = TTTConfig(**{
        key: value
        for key, value in ttt_settings.items()
        if key != "use_standard_backbone"
    })
    self._ttt_cfg.verify()

    # Lazily populated TTT state.
    self._ttt_initialized = False
    self._ttt_initial_state = None
    self._non_special_tokens_cache = None
    self._ttt_tokenizer = None
    self._ttt_generator = torch.Generator()
    if self._ttt_cfg.seed is not None:
        self._ttt_generator.manual_seed(self._ttt_cfg.seed)
| def _get_ttt_tokenizer(self) -> EsmTokenizer: | |
| if self._ttt_tokenizer is None: | |
| self._ttt_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") | |
| return self._ttt_tokenizer | |
| def _ensure_ttt_ready(self) -> None: | |
| """Lazy TTT initialization. Injects LoRA adapters and saves initial state. | |
| Must be called after weights are loaded (not in __init__).""" | |
| if self._ttt_initialized: | |
| return | |
| self._ttt_initialized = True | |
| tokenizer = self._get_ttt_tokenizer() | |
| vocab = tokenizer.get_vocab() | |
| self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] | |
| if self._ttt_cfg.lora_rank > 0: | |
| self.mlm_head.eval() | |
| for p in self.mlm_head.parameters(): | |
| p.requires_grad = False | |
| # Seed global state before LoRA init for reproducible weight initialization | |
| if self._ttt_cfg.seed is not None: | |
| torch.manual_seed(self._ttt_cfg.seed) | |
| self._inject_lora() | |
| else: | |
| # Legacy path: jointly-trained random linear projection head | |
| H = self.config.hidden_size | |
| V = self.config.vocab_size | |
| device = next(self.esm.parameters()).device | |
| self._ttt_lm_proj = nn.Linear(H, V, bias=True).to(device) | |
| if self._ttt_cfg.initial_state_reset: | |
| self._ttt_initial_state = self._ttt_get_state() | |
| def _uses_lora(self) -> bool: | |
| return self._ttt_cfg.lora_rank > 0 | |
def _inject_lora(self) -> None:
    """Wrap linear layers in the backbone's attention modules with LoRA adapters.

    The adapter parameters are stored in ``self._lora_params`` (matching
    lora_diffusion behavior).

    Raises:
        AssertionError: no LoRA parameters were injected, e.g. because
            ``lora_target_class`` matches no module in ``self.esm``.
    """
    cfg = self._ttt_cfg
    injected = inject_trainable_lora(
        self.esm,
        target_class_name=cfg.lora_target_class,
        r=cfg.lora_rank,
        scale=cfg.lora_alpha,
    )
    self._lora_params = injected
    assert injected, (
        f"No LoRA params injected. Check target_class_name='{self._ttt_cfg.lora_target_class}' "
        f"matches attention modules in the backbone."
    )
# ---- TTT State Management ----
def _get_lora_modules(self) -> List[LoraInjectedLinear]:
    """Return every ``LoraInjectedLinear`` inside ``self.esm``, in ``modules()`` order.

    The order is stable across calls, which ``_ttt_set_state`` relies on when it
    zips the current adapters against a saved snapshot list.
    """
    return list(filter(lambda module: isinstance(module, LoraInjectedLinear), self.esm.modules()))
| def _ttt_get_state(self) -> Dict[str, Any]: | |
| if self._uses_lora: | |
| lora_state = [] | |
| for m in self._get_lora_modules(): | |
| lora_state.append({ | |
| "down": m.lora_down.weight.data.clone(), | |
| "up": m.lora_up.weight.data.clone(), | |
| }) | |
| return {"_lora_state": lora_state} | |
| return { | |
| "esm": copy.deepcopy(self.esm), | |
| "_ttt_lm_proj": copy.deepcopy(self._ttt_lm_proj), | |
| } | |
| def _ttt_set_state(self, state: Dict[str, Any]) -> None: | |
| if "_lora_state" in state: | |
| modules = self._get_lora_modules() | |
| assert len(modules) == len(state["_lora_state"]) | |
| for m, saved in zip(modules, state["_lora_state"]): | |
| m.lora_down.weight.data.copy_(saved["down"]) | |
| m.lora_up.weight.data.copy_(saved["up"]) | |
| return | |
| if "esm" in state: | |
| self.esm = copy.deepcopy(state["esm"]) | |
| if "_ttt_lm_proj" in state: | |
| self._ttt_lm_proj = copy.deepcopy(state["_ttt_lm_proj"]) | |
def ttt_reset(self) -> None:
    """Restore the TTT-trainable weights (LoRA adapters or backbone) to the initial snapshot.

    Raises:
        AssertionError: no snapshot exists (``initial_state_reset`` was False).
    """
    snapshot = self._ttt_initial_state
    assert snapshot is not None, "TTT reset requires initial_state_reset=True."
    self._ttt_set_state(snapshot)
| # ---- TTT Core ---- | |
| def _ttt_tokenize(self, seq: str) -> torch.Tensor: | |
| tokenizer = self._get_ttt_tokenizer() | |
| out = tokenizer( | |
| seq, | |
| return_tensors="pt", | |
| add_special_tokens=self._uses_lora, | |
| padding=False, | |
| truncation=False, | |
| ) | |
| return out["input_ids"] | |
| def _ttt_mask_token(self) -> int: | |
| return self._get_ttt_tokenizer().mask_token_id | |
| def _ttt_get_non_special_tokens(self) -> List[int]: | |
| if self._non_special_tokens_cache is not None: | |
| return self._non_special_tokens_cache | |
| tokenizer = self._get_ttt_tokenizer() | |
| vocab = tokenizer.get_vocab() | |
| self._non_special_tokens_cache = [vocab[c] for c in _ESM_STANDARD_AA if c in vocab] | |
| return self._non_special_tokens_cache | |
| def _ttt_predict_logits(self, batch: torch.Tensor) -> torch.Tensor: | |
| """Run ESM2 backbone + LM head to get MLM logits.""" | |
| # Temporarily unfreeze backbone for gradient flow during TTT | |
| output = self.esm(input_ids=batch) | |
| hidden = output.last_hidden_state | |
| if self._uses_lora: | |
| return self.mlm_head(hidden) | |
| return self._ttt_lm_proj(hidden) | |
| def _ttt_sample_batch( | |
| self, | |
| x: torch.Tensor, | |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
| _, seq_len = x.shape | |
| batch_size = self._ttt_cfg.batch_size | |
| crop_size = min(self._ttt_cfg.crop_size, seq_len) | |
| x_expanded = x.expand(batch_size, -1) | |
| if seq_len == crop_size: | |
| start_indices = torch.zeros(batch_size, dtype=torch.long) | |
| else: | |
| start_indices = torch.randint( | |
| 0, seq_len - crop_size + 1, (batch_size,), | |
| generator=self._ttt_generator, | |
| ).to(torch.long) | |
| batch_cropped = torch.stack([ | |
| x_expanded[index, start : start + crop_size] | |
| for index, start in enumerate(start_indices) | |
| ]) | |
| non_special_tokens = set(self._ttt_get_non_special_tokens()) | |
| mask = torch.zeros((batch_size, crop_size), dtype=torch.bool) | |
| mask_token_id = self._ttt_mask_token() | |
| for row_index in range(batch_size): | |
| non_special_positions = [ | |
| col for col in range(crop_size) | |
| if batch_cropped[row_index, col].item() in non_special_tokens | |
| ] | |
| assert len(non_special_positions) > 0, "Sequence must contain at least one non-special token." | |
| num_to_mask = max(1, int(round(len(non_special_positions) * self._ttt_cfg.mask_ratio))) | |
| sampled_indices = torch.randperm( | |
| len(non_special_positions), generator=self._ttt_generator, | |
| )[:num_to_mask] | |
| positions_to_mask = torch.tensor(non_special_positions, dtype=torch.long)[sampled_indices] | |
| mask[row_index, positions_to_mask] = True | |
| batch_masked = batch_cropped.clone() | |
| for row_index in range(batch_size): | |
| masked_positions = torch.nonzero(mask[row_index], as_tuple=True)[0] | |
| for masked_position in masked_positions: | |
| probability = float(torch.rand(1, generator=self._ttt_generator).item()) | |
| if probability < 1.0 - self._ttt_cfg.bert_leave_prob - self._ttt_cfg.bert_replace_prob: | |
| batch_masked[row_index, masked_position] = mask_token_id | |
| continue | |
| if probability < 1.0 - self._ttt_cfg.bert_leave_prob: | |
| replacement_candidates = self._ttt_get_non_special_tokens() | |
| replacement_index = int(torch.randint( | |
| 0, len(replacement_candidates), (1,), generator=self._ttt_generator, | |
| ).item()) | |
| batch_masked[row_index, masked_position] = replacement_candidates[replacement_index] | |
| return batch_masked, batch_cropped, mask, start_indices | |
| def _ttt_cross_entropy_loss( | |
| self, | |
| logits: torch.Tensor, | |
| targets: torch.Tensor, | |
| mask: torch.Tensor, | |
| ) -> torch.Tensor: | |
| assert logits.ndim == 3, "Logits must be [batch, seq, vocab]." | |
| _, _, vocab_size = logits.shape | |
| logits_flat = logits.reshape(-1, vocab_size) | |
| targets_flat = targets.reshape(-1) | |
| mask_flat = mask.reshape(-1) | |
| assert int(mask_flat.sum().item()) > 0, "TTT mask must select at least one token." | |
| loss = F.cross_entropy( | |
| logits_flat[mask_flat], | |
| targets_flat[mask_flat], | |
| reduction="none", | |
| ) | |
| masked_tokens_per_seq = mask.sum(dim=1).tolist() | |
| per_sequence_losses = torch.split(loss, masked_tokens_per_seq) | |
| return torch.stack([sl.mean() for sl in per_sequence_losses]).mean() | |
| def _ttt_get_optimizer(self, parameters) -> torch.optim.Optimizer: | |
| if self._ttt_cfg.optimizer == "sgd": | |
| return torch.optim.SGD( | |
| parameters, | |
| lr=self._ttt_cfg.lr, | |
| momentum=self._ttt_cfg.momentum, | |
| weight_decay=self._ttt_cfg.weight_decay, | |
| ) | |
| return torch.optim.AdamW( | |
| parameters, | |
| lr=self._ttt_cfg.lr, | |
| weight_decay=self._ttt_cfg.weight_decay, | |
| ) | |
| def _lora_ttt(self, seq: str) -> Dict[str, List[float]]: | |
| """LoRA TTT: only LoRA adapter weights are trained, mlm_head is frozen.""" | |
| x = self._ttt_tokenize(seq) | |
| device = next(self.parameters()).device | |
| non_blocking = device.type == "cuda" | |
| losses = [] | |
| if self._ttt_cfg.steps == 0: | |
| return {"losses": losses} | |
| for parameter in self.parameters(): | |
| parameter.requires_grad = False | |
| for p in self._lora_params: | |
| p.requires_grad = True | |
| optimizer = self._ttt_get_optimizer(self._lora_params) | |
| optimizer.zero_grad(set_to_none=True) | |
| self.eval() | |
| for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags): | |
| batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x) | |
| batch_masked = batch_masked.to(device, non_blocking=non_blocking) | |
| targets = targets.to(device, non_blocking=non_blocking) | |
| mask = mask.to(device, non_blocking=non_blocking) | |
| self.train() | |
| logits = self._ttt_predict_logits(batch_masked) | |
| loss = self._ttt_cross_entropy_loss(logits, targets, mask) | |
| loss.backward() | |
| losses.append(float(loss.detach().cpu().item())) | |
| if (step + 1) % self._ttt_cfg.ags == 0: | |
| optimizer.step() | |
| optimizer.zero_grad(set_to_none=True) | |
| self.eval() | |
| return {"losses": losses} | |
| def _legacy_ttt(self, seq: str) -> Dict[str, List[float]]: | |
| """Legacy TTT: full fine-tuning of ESM2 backbone with random linear projection head.""" | |
| x = self._ttt_tokenize(seq) | |
| device = next(self.parameters()).device | |
| non_blocking = device.type == "cuda" | |
| losses = [] | |
| if self._ttt_cfg.steps == 0: | |
| return {"losses": losses} | |
| # Full fine-tune: all backbone params trainable | |
| for parameter in self.parameters(): | |
| parameter.requires_grad = False | |
| for parameter in self.esm.parameters(): | |
| parameter.requires_grad = True | |
| if self._ttt_cfg.freeze_embeddings: | |
| for parameter in self.esm.embeddings.parameters(): | |
| parameter.requires_grad = False | |
| for parameter in self._ttt_lm_proj.parameters(): | |
| parameter.requires_grad = True | |
| trainable_params = filter(lambda p: p.requires_grad, self.parameters()) | |
| optimizer = self._ttt_get_optimizer(trainable_params) | |
| optimizer.zero_grad(set_to_none=True) | |
| self.eval() | |
| for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags): | |
| batch_masked, targets, mask, start_indices = self._ttt_sample_batch(x) | |
| batch_masked = batch_masked.to(device, non_blocking=non_blocking) | |
| targets = targets.to(device, non_blocking=non_blocking) | |
| mask = mask.to(device, non_blocking=non_blocking) | |
| self.train() | |
| logits = self._ttt_predict_logits(batch_masked) | |
| loss = self._ttt_cross_entropy_loss(logits, targets, mask) | |
| loss.backward() | |
| losses.append(float(loss.detach().cpu().item())) | |
| if (step + 1) % self._ttt_cfg.ags == 0: | |
| optimizer.step() | |
| optimizer.zero_grad(set_to_none=True) | |
| self.eval() | |
| return {"losses": losses} | |
def ttt(self, seq: str) -> Dict[str, List[float]]:
    """Run test-time training on a single sequence using masked language modeling.

    Adapts the ESM2 backbone (via LoRA or full fine-tuning) to the input sequence.
    The adapted weights are kept afterwards; ``fold_protein(seq)`` combines TTT
    with structure prediction.

    TTT runs in fp32, because small LoRA updates vanish in half precision. The
    backbone and the MLM head are each cast to float32 if needed. Each is then
    restored to its *own* original dtype, also when TTT raises. Restoring the head
    to the backbone's dtype would permanently round an fp32 head to fp16 and
    change later TTT runs.

    Args:
        seq: Protein sequence (single-letter amino acid codes)

    Returns:
        Dict with a "losses" key: the MLM loss of every micro-batch (steps * ags values)
    """
    self._ensure_ttt_ready()
    esm_dtype = next(self.esm.parameters()).dtype
    head_dtype = next(self.mlm_head.parameters()).dtype
    if esm_dtype != torch.float32:
        self.esm.float()
    if head_dtype != torch.float32:
        self.mlm_head.float()
    try:
        if self._uses_lora:
            return self._lora_ttt(seq)
        return self._legacy_ttt(seq)
    finally:
        if esm_dtype != torch.float32:
            self.esm.to(esm_dtype)
        if head_dtype != torch.float32:
            self.mlm_head.to(head_dtype)
| # ---- High-Level API ---- | |
| def _fold_single(self, sequence: str, return_pdb_string: bool = True) -> Dict[str, Any]: | |
| """Fold a sequence once and return pLDDT, ptm, and optionally PDB string.""" | |
| with torch.no_grad(): | |
| output = self.infer(sequence) | |
| plddt = output["plddt"] | |
| # plddt shape is (batch, L, 37) - per-atom across atom37 types. | |
| # Use CA atom (index 1) only, matching PDB B-factor output. | |
| if plddt.dim() == 3: | |
| mean_plddt = float(plddt[:, :, 1].mean().item()) | |
| elif plddt.dim() == 2: | |
| mean_plddt = float(plddt[:, 1].mean().item()) | |
| else: | |
| mean_plddt = float(plddt.mean().item()) | |
| result = { | |
| "plddt": mean_plddt, | |
| "ptm": float(output["ptm"].item()) if "ptm" in output else None, | |
| } | |
| if return_pdb_string: | |
| pdb_strings = self.output_to_pdb(output) | |
| result["pdb_string"] = pdb_strings[0] if isinstance(pdb_strings, list) else pdb_strings | |
| return result | |
def fold_protein(
    self,
    sequence: str,
    return_pdb_string: bool = True,
) -> Dict[str, Any]:
    """Fold a protein sequence with test-time training.

    Runs TTT (masked-LM adaptation via LoRA, or full fine-tuning in legacy mode)
    for the configured number of optimizer steps. It folds after each optimizer
    step to track pLDDT, and returns the structure with the highest pLDDT across
    all steps (including the un-adapted baseline).

    The model is restored on exit, including when an exception is raised:
    - every parameter is set to ``requires_grad=False``;
    - the adapted weights are reset to the initial snapshot when one was saved
      (``initial_state_reset=True``); otherwise the adaptation is kept;
    - the backbone and the MLM head are cast back to their own original dtypes;
    - the model is left in eval mode.

    Args:
        sequence: Protein sequence (single-letter amino acid codes)
        return_pdb_string: If True, include PDB string in output

    Returns:
        Dict with keys:
            - plddt: float, best mean pLDDT across all TTT steps
            - ptm: float, predicted TM-score from best step
            - pdb_string: str (if return_pdb_string=True), PDB from best step
            - step_plddts: list[float], pLDDT at each step [baseline, s1, ..., s_steps]
            - best_step: int, which step produced best structure (0=baseline)
    """
    self._ensure_ttt_ready()
    # TTT needs fp32. Track each module's dtype separately: restoring mlm_head to
    # the backbone's dtype would permanently round an fp32 head to fp16.
    esm_dtype = next(self.esm.parameters()).dtype
    head_dtype = next(self.mlm_head.parameters()).dtype
    if esm_dtype != torch.float32:
        self.esm.float()
    if head_dtype != torch.float32:
        self.mlm_head.float()
    try:
        device = next(self.parameters()).device
        non_blocking = device.type == "cuda"
        # Step 0: baseline fold (no TTT adaptation)
        best = self._fold_single(sequence, return_pdb_string=return_pdb_string)
        step_plddts = [best["plddt"]]
        if self._ttt_cfg.steps > 0:
            # Tokenize for masked LM training
            x = self._ttt_tokenize(sequence)
            # Freeze all, then unfreeze the TTT-trainable parameters.
            for p in self.parameters():
                p.requires_grad = False
            if self._uses_lora:
                for p in self._lora_params:
                    p.requires_grad = True
                optimizer = self._ttt_get_optimizer(self._lora_params)
            else:
                for p in self.esm.parameters():
                    p.requires_grad = True
                if self._ttt_cfg.freeze_embeddings:
                    for p in self.esm.embeddings.parameters():
                        p.requires_grad = False
                for p in self._ttt_lm_proj.parameters():
                    p.requires_grad = True
                trainable = [p for p in self.parameters() if p.requires_grad]
                optimizer = self._ttt_get_optimizer(trainable)
            optimizer.zero_grad(set_to_none=True)
            self.eval()
            for step in range(self._ttt_cfg.steps * self._ttt_cfg.ags):
                batch_masked, targets, mask, _ = self._ttt_sample_batch(x)
                batch_masked = batch_masked.to(device, non_blocking=non_blocking)
                targets = targets.to(device, non_blocking=non_blocking)
                mask = mask.to(device, non_blocking=non_blocking)
                self.train()
                logits = self._ttt_predict_logits(batch_masked)
                loss = self._ttt_cross_entropy_loss(logits, targets, mask)
                loss.backward()
                if (step + 1) % self._ttt_cfg.ags == 0:
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    # Fold after this optimizer step
                    self.eval()
                    current = self._fold_single(sequence, return_pdb_string=return_pdb_string)
                    step_plddts.append(current["plddt"])
                    if current["plddt"] > best["plddt"]:
                        best = current
    finally:
        for p in self.parameters():
            p.requires_grad = False
        # Reset adapted weights for the next sequence. When initial_state_reset is
        # False no snapshot exists, and ttt_reset() would raise after all the work.
        if self._ttt_initial_state is not None:
            self.ttt_reset()
        if esm_dtype != torch.float32:
            self.esm.to(esm_dtype)
        if head_dtype != torch.float32:
            self.mlm_head.to(head_dtype)
        self.eval()
    return {
        "plddt": best["plddt"],
        "ptm": best["ptm"],
        "pdb_string": best.get("pdb_string"),
        "step_plddts": step_plddts,
        "best_step": step_plddts.index(max(step_plddts)),
    }