| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import numpy as np |
| | from kernels import get_kernel |
| |
|
| | flex = get_kernel("t-tech/flex-sae") |
| |
|
| | @torch.compile(fullgraph=True) |
| | def hierarchical_sae_loss( |
| | indices: torch.Tensor, |
| | weight: torch.Tensor, |
| | vals: torch.Tensor, |
| | bias: torch.Tensor, |
| | target: torch.Tensor, |
| | ) -> torch.Tensor: |
| | emb = weight[indices].to(torch.float32) |
| | recon_cum = bias.to(torch.float32) + (emb * vals.unsqueeze(-1)).cumsum(dim=1) |
| | diff = recon_cum.to(torch.float32) - target.to(torch.float32).unsqueeze(1) |
| | loss = diff.pow(2).mean() |
| | return loss |
| |
|
| |
|
| | B = 2048 |
| | K = 256 |
| | F = 1024 * 128 |
| | D = 1024 |
| | WARMUP = 5 |
| | NUM_ITER = 100 |
| | dtype = torch.float32 |
| |
|
| | vals = None |
| | decoder = None |
| | bias = None |
| | target = None |
| | indices = None |
| |
|
| |
|
| | def init_parameters(): |
| | global vals, decoder, bias, target, indices |
| | vals = torch.randn(B, K, dtype=dtype, device="cuda").abs().requires_grad_() |
| | decoder = torch.randn(F, D, dtype=dtype, device="cuda", requires_grad=True) |
| | bias = torch.randn(D, dtype=dtype, device="cuda", requires_grad=True) |
| | target = torch.randn(B, D, dtype=dtype, device="cuda") |
| | indices = torch.randint(0, F, (B, K), dtype=torch.long, device="cuda") |
| |
|
| |
|
| | timing_kernel = [] |
| | timing_vanilla = [] |
| | torch.cuda.reset_peak_memory_stats() |
| | loss_kernel_list = torch.zeros((100,)) |
| | loss_vanilla_list = torch.zeros((100,)) |
| |
|
| |
|
| | def zero_grad(): |
| | vals.grad = None |
| | decoder.grad = None |
| | bias.grad = None |
| | torch.cuda.empty_cache() |
| |
|
| |
|
| | for i in range(NUM_ITER + WARMUP): |
| | init_parameters() |
| | start_kernel = torch.cuda.Event(enable_timing=True) |
| | end_kernel = torch.cuda.Event(enable_timing=True) |
| | start_vanilla = torch.cuda.Event(enable_timing=True) |
| | end_vanilla = torch.cuda.Event(enable_timing=True) |
| |
|
| | start_kernel.record() |
| | loss_kernel = flex.triton_hierarchical_sae_loss(indices, decoder, vals, bias, target) |
| | loss_kernel.backward() |
| | end_kernel.record() |
| |
|
| | zero_grad() |
| | start_vanilla.record() |
| | loss_vanilla = hierarchical_sae_loss(indices, decoder, vals, bias, target) |
| | loss_vanilla.backward() |
| | end_vanilla.record() |
| | if i >= WARMUP: |
| | torch.cuda.synchronize() |
| | timing_kernel.append(start_kernel.elapsed_time(end_kernel)) |
| | timing_vanilla.append(start_vanilla.elapsed_time(end_vanilla)) |
| | loss_kernel_list[i-WARMUP] = loss_kernel.detach() |
| | loss_vanilla_list[i-WARMUP] = loss_vanilla.detach() |
| | zero_grad() |
| |
|
| | if torch.allclose(loss_kernel, loss_vanilla): |
| | print("β
Outputs are close! Everything is good! π") |
| | else: |
| | print("β Outputs mismatch... β οΈπ€") |
| |
|
| |
|
| | print(f"π¦ Triton Kernel Time (Ours): {np.mean(timing_kernel):.4f} Β± {np.std(timing_kernel):.4f} ms") |
| | print(f"π₯ Torch Compile Kernel Time: {np.mean(timing_vanilla):.4f} Β± {np.std(timing_vanilla):.4f} ms") |
| | print(f"π Speedup: {np.mean(timing_vanilla) / np.mean(timing_kernel):.2f}x") |