import torch
import torch.nn as nn
import torch.nn.functional as F

# Changelog since original version:
# xATGLU instead of top linear in transformer block
# Added a learned residual scale to all blocks and all residuals. This allowed bfloat16 training to stabilize; before this it was simply exploding.

# This architecture was my attempt at the following Simple Diffusion paper, with some modifications:
# https://arxiv.org/pdf/2410.19324v1

# Very similar to GeGLU or SwiGLU: there's a learned gate function, here using arctan as the activation.
class xATGLU(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # GATE path | VALUE path
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)

        # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha

        return expanded_gate * value_path  # g(x) × y
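
# With alpha = 0 (its initial value) the gate (arctan(g) + pi/2) / pi lies in (0, 1), so xATGLU
# behaves like a GLU whose sigmoid has been replaced by a rescaled arctan; a positive learned alpha
# widens the gate's range to (-alpha, 1 + alpha), letting the gate amplify or negate the value path.
# Interface-wise it maps (..., input_dim) to (..., output_dim), the same as nn.Linear.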

# Tensor product attention, modified. Original code from:
# https://github.com/tensorgi/T6/blob/main/model/T6_ropek.py
# https://arxiv.org/pdf/2501.06425
class CPLinear(nn.Module):
    def __init__(self, in_features, n_head, head_dim, rank: int = 1, q_rank: int = 12):
        super(CPLinear, self).__init__()
        self.in_features = in_features
        self.n_head = n_head
        self.head_dim = head_dim
        self.rank = rank
        self.q_rank = q_rank

        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * rank, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * rank, bias=False)

        nn.init.xavier_normal_(self.W_A_q.weight)
        nn.init.xavier_normal_(self.W_A_k.weight)
        nn.init.xavier_normal_(self.W_A_v.weight)

        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_B_k = nn.Linear(in_features, rank * head_dim, bias=False)
        self.W_B_v = nn.Linear(in_features, rank * head_dim, bias=False)

        nn.init.xavier_normal_(self.W_B_q.weight)
        nn.init.xavier_normal_(self.W_B_k.weight)
        nn.init.xavier_normal_(self.W_B_v.weight)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # A clarification on the naming: it's somewhat standard to call the two low-rank matrices A and B, so I've followed that.
        # Compute intermediate variables A for Q, K, and V
        A_q = self.W_A_q(x).view(batch_size, seq_len, self.n_head, self.q_rank)
        A_k = self.W_A_k(x).view(batch_size, seq_len, self.n_head, self.rank)
        A_v = self.W_A_v(x).view(batch_size, seq_len, self.n_head, self.rank)

        # Compute intermediate variables B for Q, K, and V
        B_q = self.W_B_q(x).view(batch_size, seq_len, self.q_rank, self.head_dim)
        B_k = self.W_B_k(x).view(batch_size, seq_len, self.rank, self.head_dim)
        B_v = self.W_B_v(x).view(batch_size, seq_len, self.rank, self.head_dim)

        # Reshape A_q, A_k, A_v
        A_q = A_q.view(batch_size * seq_len, self.n_head, self.q_rank)
        A_k = A_k.view(batch_size * seq_len, self.n_head, self.rank)
        A_v = A_v.view(batch_size * seq_len, self.n_head, self.rank)

        # Reshape B_q, B_k, B_v
        B_q = B_q.view(batch_size * seq_len, self.q_rank, self.head_dim)
        B_k = B_k.view(batch_size * seq_len, self.rank, self.head_dim)
        B_v = B_v.view(batch_size * seq_len, self.rank, self.head_dim)

        q = torch.bmm(A_q, B_q).div_(self.q_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = torch.bmm(A_k, B_k).div_(self.rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = torch.bmm(A_v, B_v).div_(self.rank).view(batch_size, seq_len, self.n_head, self.head_dim)

        return q, k, v
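
# For intuition on the factorization above: per token, each head's query is a sum over q_rank
# factors, a (n_head x q_rank) coefficient tensor A_q contracted with a (q_rank x head_dim)
# basis tensor B_q, and K/V work the same way with `rank`. As a rough illustration (these numbers
# are assumed for the example, not taken from the model config): with in_features=512, n_head=8,
# head_dim=64, q_rank=6 the Q path costs 512*(8*6) + 512*(6*64) = 221,184 weights versus
# 512*(8*64) = 262,144 for a dense per-head Q projection; with rank=2 the K and V paths cost
# 512*(8*2 + 2*64) = 73,728 each.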

# Very possibly this is not a good method for positional encoding in DiT; in fact it may be actively harmful. It does help on small datasets though.
# No positional embedding should be a serious consideration for high-compute / large-data scenarios.
class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)
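
# Note: this is the "split-halves" RoPE layout (feature i is paired with feature i + d/2), rather
# than the interleaved pairing used in some other implementations; the two are equivalent up to a
# fixed permutation of the feature dimension, so it only matters when porting weights between them.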

class TensorProductAttentionWithRope(nn.Module):
    def __init__(self, n_head, head_dim, n_embd, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.n_embd = n_embd
        self.kv_rank = kv_rank
        self.q_rank = q_rank

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.kv_rank, self.q_rank)
        # Output projection. Bias seems sensible here, each head can learn a shift.
        self.o_proj = xATGLU(self.n_head * self.head_dim, self.n_embd, bias=True)
        # Not a layer, just a helper
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_length (T), embedding_dim

        # Get Q, K, V through CPLinear factorization
        q, k, v = self.c_qkv(x)  # Each shape: (B, T, n_head, head_dim)

        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # SDPA expects (B, n_head, T, head_dim)
        q = q.permute(0, 2, 1, 3)  # batch seq heads dim -> batch heads seq dim
        k = k.permute(0, 2, 1, 3)  # batch seq heads dim -> batch heads seq dim
        v = v.permute(0, 2, 1, 3)  # batch seq heads dim -> batch heads seq dim

        # Compute attention using scaled_dot_product_attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        # Back to B T C
        y = y.transpose(1, 2).flatten(2)

        y = self.o_proj(y)
        return y

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h * self.learned_residual_scale
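
# Pre-activation layout (GroupNorm -> SiLU -> Conv, twice), with the residual branch scaled by a
# learned scalar initialized at 0.1 so each block starts out close to the identity mapping.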

class TransformerBlock(nn.Module):
    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)

        # Params recommended by TPA paper, seem to work fine.
        self.attn = TensorProductAttentionWithRope(
            n_head=num_heads,
            head_dim=channels // num_heads,
            n_embd=channels,
            kv_rank=2,
            q_rank=6
        )

        self.mlp = nn.Sequential(
            xATGLU(channels, 2 * channels, bias=False),
            nn.Linear(2 * channels, channels, bias=False)  # Candidate for a bias
        )

        self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1)
        self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        # Input shape B C H W
        b, c, h, w = x.shape
        # Flatten spatial positions into a token sequence: B C H W -> [B, H*W, C]
        x = x.flatten(2).transpose(1, 2)

        # Pre-norm architecture, this was really helpful for network stability when using bf16
        identity = x
        x = self.norm1(x)
        h_attn = self.attn(x)
        x = identity + h_attn * self.learned_residual_scale_attn

        identity = x
        x = self.norm2(x)
        h_mlp = self.mlp(x)
        x = identity + h_mlp * self.learned_residual_scale_mlp

        # Reshape back to B C H W
        x = x.transpose(1, 2).reshape(b, c, h, w)
        return x

class LevelBlock(nn.Module):
    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            if block_type == 'transformer':
                self.blocks.append(TransformerBlock(channels))
            else:
                self.blocks.append(ResBlock(channels))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

class AsymmetricResidualUDiT(nn.Module):
    def __init__(self,
                 in_channels=3,                 # Input color channels
                 base_channels=128,             # Initial feature size; dramatically increases the parameter count of the network.
                 patch_size=2,                  # Smaller patches dramatically increase FLOPs and compute expense. Recommend >=4 unless you have real compute.
                 num_levels=3,                  # Feature downsample depth, essentially the U-Net depth -- so we down/upsample three times. Dramatically increases parameters as you increase it.
                 encoder_blocks=3,              # Can be a different number of blocks than decoder_blocks
                 decoder_blocks=7,              # Can be a different number of blocks than encoder_blocks
                 encoder_transformer_thresh=2,  # When to start using transformer blocks instead of res blocks in the encoder. (>=)
                 decoder_transformer_thresh=4,  # When to stop using transformer blocks instead of res blocks in the decoder. (<=)
                 mid_blocks=16,                 # Number of middle transformer blocks. Relatively cheap, as this sits at the bottom of the U-Net feature bottleneck.
                 ):
        super().__init__()
        self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

        # Initial projection from image space
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        self.encoders = nn.ModuleList()
        curr_channels = base_channels

        for level in range(num_levels):
            use_transformer = level >= encoder_transformer_thresh  # Use transformers for the later levels

            # Encoder blocks -- N = encoder_blocks
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks, 'transformer' if use_transformer else 'res')
            )

            # Each successive encoder level doubles the channel count, except for the last level.
            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Middle transformer blocks -- N = mid_blocks
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels) for _ in range(mid_blocks)
        ])

        # Create decoder levels
        self.decoders = nn.ModuleList()

        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh  # Use transformers for the early levels (inverse of the encoder)

            # Decoder blocks -- N = decoder_blocks
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks, 'transformer' if use_transformer else 'res')
            )

            # Each successive decoder level halves the channel count, except for the last level.
            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Final projection back to image space
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        # x shape B C H W
        # This patchifies our input; for example, with the default patch_size=2:
        #   from (2, 3, 256, 256)
        x = self.patch_embed(x)
        #   to (2, 128, 128, 128) -- more channels, with smaller H and W.

        # *Per resolution, i.e. per num_levels resolution block, more or less:
        # f(x) = fu( U(fm(D(h)) - D(h)) + h )    where h = fd(x)
        #
        # Where
        # 1. h = fd(x)      : Encoder path processes the input
        # 2. D(h)           : Downsample the encoded features
        # 3. fm(D(h))       : Middle transformer blocks process the downsampled features
        # 4. fm(D(h))-D(h)  : Subtract the original downsampled features (residual connection)
        # 5. U(...)         : Upsample the processed features
        # 6. ... + h        : Add back the original encoder features (skip connection)
        # 7. fu(...)        : Decoder path processes the combined features

        residuals = []
        curr_res = x

        # Encoder path (computing h = fd(x))
        h = x
        for i, blocks in enumerate(self.encoders):
            if isinstance(blocks, LevelBlock):
                h = blocks(h)
            else:
                # Save residual before downsampling
                residuals.append(curr_res)
                # Downsample and update current residual
                h = self.downsample(blocks(h))
                curr_res = h

        # Middle blocks (fm)
        x = h
        for block in self.middle:
            x = block(x)

        # Subtract the residual at this level (D(h))
        x = x - curr_res * self.learned_middle_residual_scale

        # Decoder path (fu)
        for i, blocks in enumerate(self.decoders):
            if isinstance(blocks, LevelBlock):
                x = blocks(x)
            else:
                # Channel reduction
                x = blocks(x)
                # Upsample
                x = self.upsample(x)
                # Add residual from the encoder at this level; LIFO -- the last residual saved is the first one we want, since this is a U-shape.
                curr_res = residuals.pop()
                x = x + curr_res * self.learned_middle_residual_scale

        # Final projection
        x = self.final_proj(x)

        return x
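
# Minimal usage sketch: a quick smoke test that builds the UDiT with its default arguments and runs
# one forward pass on random data. The batch size and the 64x64 resolution are arbitrary choices for
# the example; the spatial size just has to stay divisible through the patchify step and the
# num_levels - 1 downsamples (divisible by 8 with the defaults).
if __name__ == "__main__":
    model = AsymmetricResidualUDiT()
    dummy = torch.randn(2, 3, 64, 64)  # B C H W
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])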