| """Transformer class.""" |
|
|
| import logging |
| import math |
| from collections import OrderedDict |
| from pathlib import Path |
| from typing import Literal, Tuple |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| import yaml |
| from torch import nn |
|
|
| import sys, os |
|
|
| from .utils import blockwise_causal_norm |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def _pos_embed_fourier1d_init( |
| cutoff: float = 256, n: int = 32, cutoff_start: float = 1 |
| ): |
| return ( |
| torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n)) |
| .unsqueeze(0) |
| .unsqueeze(0) |
| ) |
|
|
|
|
| def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32): |
| |
| return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0) |
|
|
|
|
| def _rotate_half(x: torch.Tensor) -> torch.Tensor: |
| """Rotate pairs of scalars as 2d vectors by pi/2.""" |
| x = x.unflatten(-1, (-1, 2)) |
| x1, x2 = x.unbind(dim=-1) |
| return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) |
|
|
|
|
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) over multi-dimensional coordinates.

    Each coordinate dimension owns a learnable bank of Fourier frequencies;
    queries and keys are rotated pairwise by the phase coordinate * frequency.
    """

    def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
        super().__init__()
        assert len(cutoffs) == len(n_pos)
        if any(n % 2 != 0 for n in n_pos):
            raise ValueError("n_pos must be even")

        self._n_dim = len(cutoffs)
        # Half as many frequencies as channels, since channels rotate in pairs.
        self.freqs = nn.ParameterList([
            nn.Parameter(_rope_pos_embed_fourier1d_init(cutoff, n // 2))
            for cutoff, n in zip(cutoffs, n_pos)
        ])

    def get_co_si(self, coords: torch.Tensor):
        """Return (cos, sin) phase tables for coords (B, N, D), each (B, N, sum(n_pos) // 2)."""
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        phases = tuple(
            0.5 * math.pi * x.unsqueeze(-1) * freq
            for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
        )
        # NOTE: len(freq) is the size of the leading singleton dim (always 1),
        # so the sqrt normalization is numerically a no-op; kept for parity.
        co = torch.cat(
            tuple(torch.cos(p) / math.sqrt(len(f)) for p, f in zip(phases, self.freqs)),
            axis=-1,
        )
        si = torch.cat(
            tuple(torch.sin(p) / math.sqrt(len(f)) for p, f in zip(phases, self.freqs)),
            axis=-1,
        )
        return co, si

    def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
        """Rotate q and k (each (B, H, N, C)) according to coords (B, N, D)."""
        _B, _N, D = coords.shape
        _B, _H, _N, _C = q.shape

        if D != self._n_dim:
            raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")

        co, si = self.get_co_si(coords)
        # Insert head axis; duplicate each frequency for its (even, odd) channel pair.
        co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
        si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
        return q * co + _rotate_half(q) * si, k * co + _rotate_half(k) * si
|
|
|
|
class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation: d_model -> expand * d_model -> d_model."""

    def __init__(self, d_model, expand: float = 2, bias: bool = True):
        super().__init__()
        hidden = int(d_model * expand)
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model, bias=bias)
        self.act = nn.GELU()

    def forward(self, x):
        h = self.act(self.fc1(x))
        return self.fc2(h)
|
|
|
|
class PositionalEncoding(nn.Module):
    """Fourier-feature positional encoding for multi-dimensional coordinates.

    Each coordinate dimension ``d`` is embedded into ``n_pos[d]`` channels:
    sin/cos pairs over ``n_pos[d] // 2`` learnable frequencies spanning
    ``cutoffs_start[d] .. cutoffs[d]``.
    """

    def __init__(
        self,
        cutoffs: Tuple[float] = (256,),
        n_pos: Tuple[int] = (32,),
        cutoffs_start=None,
    ):
        super().__init__()
        if cutoffs_start is None:
            cutoffs_start = (1,) * len(cutoffs)

        assert len(cutoffs) == len(n_pos)
        # BUGFIX: cutoff_start was unpacked from the zip but never forwarded to
        # _pos_embed_fourier1d_init, so custom ``cutoffs_start`` values (e.g.
        # the 0.01 used for feature embeddings) were silently ignored. The
        # default (1,) * len(cutoffs) reproduces the previous behavior exactly.
        self.freqs = nn.ParameterList([
            nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2, cutoff_start))
            for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
        ])

    def forward(self, coords: torch.Tensor):
        """Embed coords of shape (B, N, D) into (B, N, sum(n_pos)).

        Channels are ordered per coordinate dimension, sin block then cos block.
        """
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        embed = torch.cat(
            tuple(
                torch.cat(
                    (
                        torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
                        torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
                    ),
                    axis=-1,
                )
                # NOTE: len(freq) is the leading singleton dim (== 1), so this
                # normalization is a no-op; kept as-is for checkpoint parity.
                / math.sqrt(len(freq))
                for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
            ),
            axis=-1,
        )
        return embed
|
|
|
|
| def _bin_init_exp(cutoff: float, n: int): |
| return torch.exp(torch.linspace(0, math.log(cutoff + 1), n)) |
|
|
|
|
| def _bin_init_linear(cutoff: float, n: int): |
| return torch.linspace(-cutoff, cutoff, n) |
|
|
|
|
class RelativePositionalBias(nn.Module):
    """Learned attention bias indexed by binned spatial and temporal offsets.

    Spatial distances fall into ``n_spatial`` exponential bins, signed
    temporal offsets into ``2 * n_temporal + 1`` linear bins; every
    (temporal, spatial) bin pair owns one learnable bias value per head.
    """

    def __init__(
        self,
        n_head: int,
        cutoff_spatial: float,
        cutoff_temporal: float,
        n_spatial: int = 32,
        n_temporal: int = 16,
    ):
        super().__init__()
        self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
        self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
        # Registered as buffers so the bin edges travel with the state dict.
        self.register_buffer("spatial_bins", self._spatial_bins)
        self.register_buffer("temporal_bins", self._temporal_bins)
        self.n_spatial = n_spatial
        self.n_head = n_head
        n_bins = (2 * n_temporal + 1) * n_spatial
        # Uniform init in [-0.5, 0.5), one row per fused bin, one column per head.
        self.bias = nn.Parameter(-0.5 + torch.rand(n_bins, n_head))

    def forward(self, coords: torch.Tensor):
        """Return a per-head bias of shape (B, n_head, N, N) for coords (B, N, 1 + S)."""
        _B, _N, _D = coords.shape
        times = coords[..., 0]
        positions = coords[..., 1:]
        # Signed time offsets and Euclidean spatial distances, both (B, N, N).
        dt = times.unsqueeze(-1) - times.unsqueeze(-2)
        dyx = torch.cdist(positions, positions)

        s_idx = torch.bucketize(dyx, self.spatial_bins).clamp(
            max=len(self.spatial_bins) - 1
        )
        t_idx = torch.bucketize(dt, self.temporal_bins).clamp(
            max=len(self.temporal_bins) - 1
        )

        # Fuse the two bin indices into one flat row index of the bias table.
        flat = s_idx.flatten() + t_idx.flatten() * self.n_spatial
        bias = self.bias.index_select(0, flat).view((*s_idx.shape, self.n_head))
        # transpose(-1, 1) moves heads to dim 1 (and also exchanges the two
        # point axes); kept exactly as in the original formulation.
        return bias.transpose(-1, 1)
|
|
|
|
class RelativePositionalAttention(nn.Module):
    """Multi-head attention whose logits are modulated by detection coordinates.

    ``mode`` selects how coordinates enter the attention: "bias" adds a learned
    binned relative bias to the logits, "rope" rotates q/k with rotary
    embeddings, "none" adds no positional term. Independently of ``mode``,
    pairs farther apart than ``cutoff_spatial`` are masked out and a
    distance-decay term (selected by ``attn_dist_mode``) is added.
    """

    def __init__(
        self,
        coord_dim: int,
        embed_dim: int,
        n_head: int,
        cutoff_spatial: float = 256,
        cutoff_temporal: float = 16,
        n_spatial: int = 32,
        n_temporal: int = 16,
        dropout: float = 0.0,
        mode: Literal["bias", "rope", "none"] = "bias",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()

        # Requires an even per-head channel count (channels rotate in pairs for RoPE).
        if not embed_dim % (2 * n_head) == 0:
            raise ValueError(
                f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
            )

        # Separate q/k/v projections plus output projection, all embed_dim -> embed_dim.
        self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout
        self.n_head = n_head
        self.embed_dim = embed_dim
        self.cutoff_spatial = cutoff_spatial
        self.attn_dist_mode = attn_dist_mode

        if mode == "bias" or mode is True:
            # Learned additive bias over binned (temporal, spatial) offsets.
            self.pos_bias = RelativePositionalBias(
                n_head=n_head,
                cutoff_spatial=cutoff_spatial,
                cutoff_temporal=cutoff_temporal,
                n_spatial=n_spatial,
                n_temporal=n_temporal,
            )
        elif mode == "rope":
            # Split each head's channels over 1 temporal + coord_dim spatial axes:
            # n_split (even) channels per spatial axis, the remainder to time.
            n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
            self.rot_pos_enc = RotaryPositionalEncoding(
                cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
                n_pos=(embed_dim // n_head - coord_dim * n_split,)
                + (n_split,) * coord_dim,
            )
        elif mode == "none":
            pass
        elif mode is None or mode is False:
            logger.warning(
                "attn_positional_bias is not set (None or False), no positional bias."
            )
        else:
            raise ValueError(f"Unknown mode {mode}")

        self._mode = mode

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Attend query over key/value, modulated by coords.

        Args:
            query, key, value: (B, N, embed_dim) token embeddings.
            coords: (B, N, 1 + S) with time in channel 0, spatial in the rest.
            padding_mask: optional (B, N) bool mask; True entries are ignored.

        Returns:
            Attention output of shape (B, N, embed_dim).
        """
        B, N, D = query.size()
        q = self.q_pro(query)
        k = self.k_pro(key)
        v = self.v_pro(value)
        # Reshape to (B, n_head, N, head_dim) for scaled_dot_product_attention.
        k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)

        # Additive logit mask; masked pairs get a large finite negative value
        # (presumably instead of -inf to avoid NaNs for fully-masked rows).
        attn_mask = torch.zeros(
            (B, self.n_head, N, N), device=query.device, dtype=q.dtype
        )
        attn_ignore_val = -1e3

        # Hard spatial cutoff: forbid attention between distant pairs.
        # NOTE(review): coords is dereferenced here unconditionally, so the
        # `if coords is not None` guard below can never see None without this
        # line raising first — confirm whether a coords=None path is intended.
        yx = coords[..., 1:]
        spatial_dist = torch.cdist(yx, yx)
        spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
        attn_mask.masked_fill_(spatial_mask, attn_ignore_val)

        if coords is not None:
            if self._mode == "bias":
                attn_mask = attn_mask + self.pos_bias(coords)
            elif self._mode == "rope":
                q, k = self.rot_pos_enc(q, k, coords)

            # Distance-decay boost: spatially/temporally close pairs get a
            # positive logit offset in (0, 1].
            if self.attn_dist_mode == "v0":
                # v0 measures distance over the full coords, time included.
                dist = torch.cdist(coords, coords, p=2)
                attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
            elif self.attn_dist_mode == "v1":
                # v1 uses spatial distance only, scaled by the cutoff.
                attn_mask += torch.exp(
                    -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
                )
            else:
                raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")

        # Ignore every pair involving a padded element (row or column).
        if padding_mask is not None:
            ignore_mask = torch.logical_or(
                padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
            ).unsqueeze(1)
            attn_mask.masked_fill_(ignore_mask, attn_ignore_val)

        # Dropout only during training.
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
        )
        # Merge heads back and apply the output projection.
        y = y.transpose(1, 2).contiguous().view(B, N, D)
        y = self.proj(y)
        return y
|
|
|
|
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer with relative positional self-attention."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        cutoff_spatial: int = 256,
        window: int = 16,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Self-attend over tokens x (B, N, d_model) at the given coords.

        Note: both residual branches start from the *normed* input.
        """
        h = self.norm1(x)
        # Self-attention: queries, keys and values are all the same tokens.
        attn_out = self.attn(
            h,
            h,
            h,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )
        h = h + attn_out
        return h + self.mlp(self.norm2(h))
|
|
|
|
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: cross-attention from x onto memory y."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        window: int = 16,
        cutoff_spatial: int = 256,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )

        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Cross-attend normed queries x onto normed memory y, residual style."""
        queries = self.norm1(x)
        memory = self.norm2(y)
        # Queries come from x; keys and values from the (encoder) memory y.
        attn_out = self.attn(
            queries,
            memory,
            memory,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )
        out = queries + attn_out
        return out + self.mlp(self.norm3(out))
|
|
|
|
|
|
class TrackingTransformer(torch.nn.Module):
    """Encoder/decoder transformer that scores pairwise associations of detections.

    Detections are given as (time, *spatial) coordinates plus optional scalar
    features. The model embeds them with Fourier features, runs an encoder and
    a decoder stack over the same token set, and scores every token pair via a
    dot product of the two output heads, yielding a (N, N) association logit
    matrix per batch element.
    """

    def __init__(
        self,
        coord_dim: int = 3,
        feat_dim: int = 0,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        pos_embed_per_dim: int = 32,
        feat_embed_per_dim: int = 1,
        window: int = 6,
        spatial_pos_cutoff: int = 256,
        attn_positional_bias: Literal["bias", "rope", "none"] = "rope",
        attn_positional_bias_n_spatial: int = 16,
        causal_norm: Literal[
            "none", "linear", "softmax", "quiet_softmax"
        ] = "quiet_softmax",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()

        # Full constructor configuration, persisted by save() and reused by
        # from_folder()/from_cfg() to rebuild the identical architecture.
        self.config = dict(
            coord_dim=coord_dim,
            feat_dim=feat_dim,
            pos_embed_per_dim=pos_embed_per_dim,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            window=window,
            dropout=dropout,
            attn_positional_bias=attn_positional_bias,
            attn_positional_bias_n_spatial=attn_positional_bias_n_spatial,
            spatial_pos_cutoff=spatial_pos_cutoff,
            feat_embed_per_dim=feat_embed_per_dim,
            causal_norm=causal_norm,
            attn_dist_mode=attn_dist_mode,
        )

        # Input projection: concatenated positional embedding
        # ((1 + coord_dim) * pos_embed_per_dim channels) and feature embedding
        # (feat_dim * feat_embed_per_dim channels) -> d_model.
        self.proj = nn.Linear(
            (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model
        )
        self.norm = nn.LayerNorm(d_model)

        self.encoder = nn.ModuleList([
            EncoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_decoder_layers)
        ])

        # Separate output heads for the encoder ("x") and decoder ("y") streams.
        self.head_x = FeedForward(d_model)
        self.head_y = FeedForward(d_model)

        # Optional Fourier embedding of scalar features; identity if
        # feat_embed_per_dim == 1 (raw features are used directly).
        if feat_embed_per_dim > 1:
            self.feat_embed = PositionalEncoding(
                cutoffs=(1000,) * feat_dim,
                n_pos=(feat_embed_per_dim,) * feat_dim,
                cutoffs_start=(0.01,) * feat_dim,
            )
        else:
            self.feat_embed = nn.Identity()

        # Positional embedding over (time, *spatial) coordinates.
        self.pos_embed = PositionalEncoding(
            cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
            n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
        )

    def forward(self, coords, features=None, padding_mask=None, attn_feat=None):
        """Return association logits A of shape (B, N, N).

        Args:
            coords: (B, N, 1 + coord_dim), time in channel 0 (2D or 3D spatial).
            features: optional (B, N, feat_dim) per-detection features.
            padding_mask: optional (B, N) bool mask of padded entries.
            attn_feat: optional (B, N, d_model) residual added after projection.
        """
        assert coords.ndim == 3 and coords.shape[-1] in (3, 4)
        _B, _N, _D = coords.shape

        # Push padded entries to the max coordinate so they end up far away
        # from all real detections.
        if padding_mask is not None:
            coords = coords.clone()
            coords[padding_mask] = coords.max()

        # Shift so each batch window starts at time 0.
        # NOTE(review): min_time has shape (B, 1, 1), so this subtraction
        # broadcasts over *all* coordinate channels, shifting the spatial
        # coordinates by the minimum time as well — confirm this is intended
        # (relative distances are unaffected, absolute embeddings are not).
        min_time = coords[:, :, :1].min(dim=1, keepdims=True).values
        coords = coords - min_time

        pos = self.pos_embed(coords)

        if features is None or features.numel() == 0:
            # No extra features: use the positional embedding alone.
            features = pos
        else:
            features = self.feat_embed(features)
            features = torch.cat((pos, features), axis=-1)

        features = self.proj(features)
        if attn_feat is not None:
            # Externally computed per-token features enter as a residual.
            features = features + attn_feat

        features = self.norm(features)

        x = features

        # Encoder: self-attention over all tokens.
        for enc in self.encoder:
            x = enc(x, coords=coords, padding_mask=padding_mask)

        y = features
        # Decoder: cross-attention of the embedded tokens onto the encoder output.
        for dec in self.decoder:
            y = dec(y, x, coords=coords, padding_mask=padding_mask)

        x = self.head_x(x)
        y = self.head_y(y)

        # Pairwise association logits: A[b, i, j] = <x_i, y_j>.
        A = torch.einsum("bnd,bmd->bnm", x, y)

        return A

    def normalize_output(
        self,
        A: torch.FloatTensor,
        timepoints: torch.LongTensor,
        coords: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply (parental) softmax, or elementwise sigmoid.

        Args:
            A: Tensor of shape B, N, N
            timepoints: Tensor of shape B, N
            coords: Tensor of shape B, N, (time + n_spatial)

        Returns:
            Normalized association matrix of shape B, N, N.
        """
        assert A.ndim == 3
        assert timepoints.ndim == 2
        assert coords.ndim == 3
        assert coords.shape[2] == 1 + self.config["coord_dim"]

        # Pairs beyond the spatial cutoff can never be valid associations.
        dist = torch.cdist(coords[:, :, 1:], coords[:, :, 1:])
        invalid = dist > self.config["spatial_pos_cutoff"]

        if self.config["causal_norm"] == "none":
            # Elementwise sigmoid; invalid pairs are zeroed in place.
            A = torch.sigmoid(A)
            A[invalid] = 0
        else:
            # Blockwise causal normalization over time blocks, one batch
            # element at a time.
            return torch.stack([
                blockwise_causal_norm(
                    _A, _t, mode=self.config["causal_norm"], mask_invalid=_m
                )
                for _A, _t, _m in zip(A, timepoints, invalid)
            ])
        return A

    def save(self, folder):
        """Write config.yaml and model.pt into ``folder`` (created if needed)."""
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)
        # BUGFIX: use a context manager so the config file handle is closed
        # deterministically (previously left to the garbage collector).
        with open(folder / "config.yaml", "w") as f:
            yaml.safe_dump(self.config, f)
        torch.save(self.state_dict(), folder / "model.pt")

    @classmethod
    def from_folder(
        cls, folder, map_location=None, checkpoint_path: str = "model.pt"
    ):
        """Rebuild a model from a folder written by save() (or a checkpoint).

        Args:
            folder: directory containing config.yaml and the checkpoint.
            map_location: forwarded to torch.load (e.g. "cpu").
            checkpoint_path: checkpoint filename inside ``folder``.
        """
        folder = Path(folder)

        # BUGFIX: context-managed open (handle was previously never closed).
        with open(folder / "config.yaml") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        model = cls(**config)

        fpath = folder / checkpoint_path
        logger.info(f"Loading model state from {fpath}")

        state = torch.load(fpath, map_location=map_location, weights_only=True)

        # Support wrapped checkpoints (e.g. a trainer that stores the weights
        # under "state_dict" with a "model." prefix) — strip the prefix.
        if "state_dict" in state:
            state = state["state_dict"]
            state = OrderedDict(
                (k[6:], v) for k, v in state.items() if k.startswith("model.")
            )
        model.load_state_dict(state)

        return model

    @classmethod
    def from_cfg(
        cls, cfg_path
    ):
        """Build an untrained model from a standalone yaml config file."""
        cfg_path = Path(cfg_path)

        # BUGFIX: context-managed open (handle was previously never closed).
        with open(cfg_path) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        model = cls(**config)

        return model
|
|