Text Generation
Transformers
Safetensors
English
cloverlm
causal-lm
quartet-ii
nvfp4
low-precision-training
pretrained
custom_code
Instructions to use daslab-testing/CloverLM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use daslab-testing/CloverLM with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="daslab-testing/CloverLM", trust_remote_code=True)# Load model directly from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("daslab-testing/CloverLM", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use daslab-testing/CloverLM with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "daslab-testing/CloverLM" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "daslab-testing/CloverLM", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker
docker model run hf.co/daslab-testing/CloverLM
- SGLang
How to use daslab-testing/CloverLM with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "daslab-testing/CloverLM" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "daslab-testing/CloverLM", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "daslab-testing/CloverLM" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "daslab-testing/CloverLM", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }' - Docker Model Runner
How to use daslab-testing/CloverLM with Docker Model Runner:
docker model run hf.co/daslab-testing/CloverLM
| import torch | |
| from . import exp_mlp as mlp | |
| from math import sqrt | |
| import math | |
| SCALE_TYPES = ["1/sqrt(d)", "1/d"] | |
| POS_TYPES = ["learned", "sinusoidal", "rope", "alibi"] | |
| BACKENDS = ["pytorch", "flash2", "flash3", "flash4", "flex", "cudnn"] | |
| NORM_TYPES = ["layer", "rms_learned", "rms_const", "sphere"] | |
| def get_causal(context): | |
| causal = torch.full((context,context), True) | |
| causal = causal.tril() | |
| return causal | |
| def get_sinusoidal(context, d, base=1024): | |
| # [pos=0, pos=1, ...] | |
| poss = torch.arange(0., context) | |
| # [i=0, i=1, ...] | |
| js = torch.arange(0., d//2) | |
| # [ω0, ω1, ...] | |
| ωs = 1/base**(2*js/d) | |
| # [pos=0*ω0, pos=0*ω1, ...] | |
| # [pos=1*ω0, pos=1*ω1, ...] | |
| φs = poss[...,None] @ ωs[None,...] | |
| # context*d | |
| sinusoidal = torch.empty((context, d)) | |
| sinusoidal[:,0::2] = torch.sin(φs) | |
| sinusoidal[:,1::2] = torch.cos(φs) | |
| return sinusoidal | |
| def get_rope(context, d, *, device, base=1024): | |
| # [m=0, m=1, ...] | |
| ms = torch.arange(0., context, device=device, dtype=torch.float32) | |
| # [i=0, i=1, ...] | |
| js = torch.arange(0., d//2, device=device, dtype=torch.float32) | |
| # [θ0, θ1, ...] | |
| θs = 1/base**(2*js/d) | |
| # [m=0*θ0, m=0*θ1, ...] | |
| # [m=1*θ0, m=1*θ1, ...] | |
| φs = ms[...,None] @ θs[None,...] | |
| # context*d/2 | |
| cos = torch.cos(φs) | |
| sin = torch.sin(φs) | |
| # context*d | |
| cos = cos.repeat_interleave(repeats=2, dim=1) | |
| sin = sin.repeat_interleave(repeats=2, dim=1) | |
| # 2*context*d | |
| rope = torch.stack((cos,sin)) | |
| return rope | |
| # (batches*)context*d | |
| def apply_rope(X, rope): | |
| X_ = torch.empty_like(X) | |
| X_[...,0::2] = -X[...,1::2] | |
| X_[...,1::2] = X[...,0::2] | |
| # context*d | |
| cos = rope[0] | |
| sin = rope[1] | |
| Y = X*cos + X_*sin | |
| return Y.to(X.dtype) | |
| def get_m(heads, base=2, exp=8): | |
| m = base**( (-exp/heads)*torch.arange(1,heads+1) ) | |
| return m | |
| def get_alibi(heads, context): | |
| # 1*context*1 | |
| i = torch.arange(0, context)[None,:,None] | |
| # 1*1*context | |
| j = i.mT | |
| # heads*1*1 | |
| m = get_m(heads)[:,None,None] | |
| alibi = -torch.abs(i - j)*m | |
| return alibi | |
| def get_swa(context, window): | |
| # context*1 | |
| i = torch.arange(0, context).unsqueeze(-1) | |
| # 1*context | |
| j = i.T | |
| swa = torch.abs(i - j) <= window | |
| return swa | |
| # (batches*)heads/groups*context*d_head | |
| def sdpa_pytorch(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False): | |
| if scale is None: | |
| d_head = Q.shape[-1] | |
| scale = 1/sqrt(d_head) | |
| # (batches*)heads*context*d_head | |
| heads = Q.shape[-3] | |
| groups = K.shape[-3] | |
| ratio = heads//groups | |
| # PyTorch only broadcasts when the operation is not defined otherwise. MM does not involve the batch dimensions, and hence PyTorch does not broadcast them. | |
| K = K.repeat_interleave(repeats=ratio, dim=-3) | |
| V = V.repeat_interleave(repeats=ratio, dim=-3) | |
| # (batches*)heads*context*context | |
| A__ = Q @ K.mT | |
| # batches*heads*context*context | |
| A_ = scale*A__ | |
| # (batches*)heads*context*context | |
| A_ = A_.reshape(A__.shape) | |
| if alibi is not None: | |
| A_ = A_ + alibi | |
| if causal is not None: | |
| A_.masked_fill_(~causal, -float("inf")) | |
| if swa is not None: | |
| A_.masked_fill_(~swa, -float("inf")) | |
| A = torch.softmax(A_, dim=-1) | |
| # (batches*)heads*context*d_head | |
| Y = A @ V | |
| if not return_A: | |
| return Y | |
| else: | |
| return Y, A__, A_, A | |
| # (batches*)heads/groups*context*d_head | |
| def sdpa_flash(Q, K, V, causal=False, alibi=None, swa=None, scale=None, backend="flash2"): | |
| if (alibi is not None) and backend != "flash2": | |
| print("\x1b[93;3m[WARNING]: backend={backend} does not support ALiBi. Hence, we force backend=flash2.\x1b[0m") | |
| backend = "flash2" | |
| # FlashAttention only supports float scale | |
| if isinstance(scale, torch.Tensor): | |
| Q_shape = Q.shape | |
| # batches*heads*context*d_head | |
| Q = scale*Q | |
| # (batches*)heads*context*d_head | |
| Q = Q.reshape(Q_shape) | |
| scale = 1 | |
| # FlashAttention2 only supports BF16 and FP16 | |
| if Q.dtype in [torch.bfloat16, torch.float16]: | |
| dtype = Q.dtype | |
| else: | |
| dtype = torch.bfloat16 | |
| heads = Q.shape[-3] | |
| groups = K.shape[-3] | |
| context = Q.shape[-2] | |
| d_head = Q.shape[-1] | |
| # CAUTION: FlashAttention expects batches*context*heads/groups*d_head | |
| Q = Q.movedim(-3,-2).reshape(-1,context,heads,d_head) | |
| K = K.movedim(-3,-2).reshape(-1,context,groups,d_head) | |
| V = V.movedim(-3,-2).reshape(-1,context,groups,d_head) | |
| if swa is None: | |
| swa = (-1,-1) | |
| if backend=="flash2": | |
| import flash_attn | |
| Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, alibi_slopes=alibi, window_size=swa, softmax_scale=scale) | |
| elif backend=="flash3": | |
| import flash_attn_interface | |
| Y = flash_attn_interface.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale) | |
| elif backend=="flash4": | |
| import flash_attn.cute | |
| # FlashAttention4 returns (out, lse) | |
| Y = flash_attn.cute.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)[0] | |
| Y = Y.to(Q.dtype) | |
| # Restore the shape to: (batches*)heads*context*d_head | |
| Y = Y.movedim(-3,-2).squeeze(0) | |
| return Y | |
| # (batches*)heads/groups*context*d_head | |
| def sdpa_flex(): | |
| return None | |
| # (batches*)heads/groups*context*d_head | |
| def sdpa_cudnn(): | |
| return None | |
| def sdpa_wrapper(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False, backend="flash2"): | |
| if backend=="pytorch": | |
| return sdpa_pytorch(Q, K, V, causal, alibi, swa, scale, return_A) | |
| elif backend in {"flash2", "flash3", "flash4"}: | |
| return sdpa_flash(Q, K, V, causal, alibi, swa, scale, backend) | |
| elif backend=="flex": | |
| return sdpa_flex() | |
| elif backend=="cudnn": | |
| return sdpa_cudnn() | |
| def test_sdpa(): | |
| batches = 32 | |
| heads = 12 | |
| context = 1024 | |
| d_head = 64 | |
| window = 256 | |
| groups = 4 | |
| dtype = torch.bfloat16 | |
| print("\x1b[1mbfloat16\x1b[0m",end="") | |
| Q = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype) | |
| K = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype) | |
| V = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype) | |
| pytorch = sdpa_wrapper(Q, K, V, backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| flash3 = sdpa_wrapper(Q, K, V, backend="flash3") | |
| torch.testing.assert_close(flash3, pytorch, check_dtype=False) | |
| flash4 = sdpa_wrapper(Q, K, V, backend="flash4") | |
| torch.testing.assert_close(flash4, pytorch, check_dtype=False) | |
| print("\x1b[32m ✔\x1b[0m") | |
| print("\x1b[1mcausal\x1b[0m",end="") | |
| pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, causal=True, backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| flash3 = sdpa_wrapper(Q, K, V, causal=True, backend="flash3") | |
| torch.testing.assert_close(flash3, pytorch, check_dtype=False) | |
| flash4 = sdpa_wrapper(Q, K, V, causal=True, backend="flash4") | |
| torch.testing.assert_close(flash4, pytorch, check_dtype=False) | |
| print("\x1b[32m ✔\x1b[0m") | |
| print("\x1b[1malibi\x1b[0m",end="") | |
| pytorch = sdpa_wrapper(Q, K, V, alibi=get_alibi(heads,context).to("cuda:0",dtype), backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, alibi=get_m(heads).to("cuda:0"), backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| # ALiBi not supported on FlashAttention3/4 | |
| print("\x1b[32m ✔\x1b[0m") | |
| print("\x1b[1mswa\x1b[0m",end="") | |
| pytorch = sdpa_wrapper(Q, K, V, swa=get_swa(context,window).to("cuda:0"), backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| flash3 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash3") | |
| torch.testing.assert_close(flash3, pytorch, check_dtype=False) | |
| flash4 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash4") | |
| torch.testing.assert_close(flash4, pytorch, check_dtype=False) | |
| print("\x1b[32m ✔\x1b[0m") | |
| print("\x1b[1mcausal+alibi\x1b[0m",end="") | |
| pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), alibi=get_alibi(heads,context).to("cuda:0",dtype), backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, causal=True, alibi=get_m(heads).to("cuda:0"), backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| # ALiBi not supported on FlashAttention3/4 | |
| print("\x1b[32m ✔\x1b[0m") | |
| print("\x1b[1mcausal+swa\x1b[0m",end="") | |
| pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), swa=get_swa(context,window).to("cuda:0"), backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| flash3 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash3") | |
| torch.testing.assert_close(flash3, pytorch, check_dtype=False) | |
| flash4 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash4") | |
| torch.testing.assert_close(flash4, pytorch, check_dtype=False) | |
| print("\x1b[32m ✔\x1b[0m") | |
| print("\x1b[1mGQA\x1b[0m",end="") | |
| Q = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype) | |
| K = torch.rand((batches, groups, context, d_head)).to("cuda:0", dtype) | |
| V = torch.rand((batches, groups, context, d_head)).to("cuda:0", dtype) | |
| pytorch = sdpa_wrapper(Q, K, V, backend="pytorch") | |
| flash2 = sdpa_wrapper(Q, K, V, backend="flash2") | |
| torch.testing.assert_close(flash2, pytorch, check_dtype=False) | |
| flash3 = sdpa_wrapper(Q, K, V, backend="flash3") | |
| torch.testing.assert_close(flash3, pytorch, check_dtype=False) | |
| flash4 = sdpa_wrapper(Q, K, V, backend="flash4") | |
| torch.testing.assert_close(flash4, pytorch, check_dtype=False) | |
| print("\x1b[32m ✔\x1b[0m") | |
| class MHSA(torch.nn.Module): | |
| def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, qk_norm=True, quartet=True, fake_quartet=False): | |
| super().__init__() | |
| self.heads = heads | |
| self.d_head = d_head | |
| self.d = heads * d_head | |
| self.scale_type = scale_type | |
| self.ratio = ratio | |
| self.groups = heads//ratio | |
| self.d_KV = self.groups * d_head | |
| self.qk_norm = qk_norm | |
| if qk_norm: | |
| # (batches*)heads*context*d_head | |
| scale = torch.full((1,heads,1,1), sqrt(d_head)) | |
| self.scale = torch.nn.Parameter(scale) | |
| else: | |
| if scale_type=="1/sqrt(d)": | |
| self.scale = 1/sqrt(d_head) | |
| elif scale_type=="1/d": | |
| self.scale = 1/d_head | |
| self.quartet = quartet | |
| self.fake_quartet = fake_quartet | |
| # Packing QKV gives negligible speed gains, while not allowing GQA, hurting code clarity and having side effects with μP | |
| if quartet: | |
| pass # quartet2 not available in HF mode | |
| self.lq = quartet2.linear.Quartet_II_linear(self.d, self.d, bias=False) | |
| self.lk = quartet2.linear.Quartet_II_linear(self.d, self.d_KV, bias=False) | |
| self.lv = quartet2.linear.Quartet_II_linear(self.d, self.d_KV, bias=False) | |
| self.lo = quartet2.linear.Quartet_II_linear(self.d, self.d, bias=False) | |
| elif fake_quartet: | |
| from . import fake_quartet as fq | |
| self.lq = fq.FakeQuartetLinear(self.d, self.d, bias=False) | |
| self.lk = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False) | |
| self.lv = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False) | |
| self.lo = fq.FakeQuartetLinear(self.d, self.d, bias=False) | |
| else: | |
| self.lq = torch.nn.Linear(self.d, self.d, bias=False) | |
| self.lk = torch.nn.Linear(self.d, self.d_KV, bias=False) | |
| self.lv = torch.nn.Linear(self.d, self.d_KV, bias=False) | |
| self.lo = torch.nn.Linear(self.d, self.d, bias=False) | |
| # (batches*)context*d | |
| def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_A=False, backend="flash2"): | |
| # (batches*)context*d | |
| Q = self.lq(X) | |
| # (batches*)context*d_KV | |
| K = self.lk(X) | |
| V = self.lv(X) | |
| # (batches*)context*heads*d_head | |
| Q = Q.unflatten(dim=-1, sizes=(self.heads, self.d_head)) | |
| # (batches*)context*groups*d_head | |
| K = K.unflatten(dim=-1, sizes=(self.groups, self.d_head)) | |
| V = V.unflatten(dim=-1, sizes=(self.groups, self.d_head)) | |
| # (batches*)heads*context*d_head | |
| Q = Q.movedim(-3,-2) | |
| # (batches*)groups*context*d_head | |
| K = K.movedim(-3,-2) | |
| V = V.movedim(-3,-2) | |
| if rope is not None: | |
| Q = apply_rope(Q,rope) | |
| K = apply_rope(K,rope) | |
| # After RoPE | |
| if self.qk_norm: | |
| Q = mlp.sphere_norm(Q) | |
| K = mlp.sphere_norm(K) | |
| # (batches*)heads*context*d_head | |
| if not return_A: | |
| Y = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend) | |
| else: | |
| Y, A__, A_, A = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend) | |
| # (batches*)context*heads*d_head | |
| Y = Y.movedim(-3,-2) | |
| # (batches*)context*d | |
| Y = Y.flatten(-2,-1) | |
| Y = self.lo(Y) | |
| if not return_A: | |
| return Y | |
| else: | |
| return Y, A__, A_, A | |
| class Block(torch.nn.Module): | |
| def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, exp_factor=4, dropout=0, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, quartet=True, fake_quartet=False): | |
| super().__init__() | |
| self.heads = heads | |
| self.d_head = d_head | |
| self.d = heads * d_head | |
| self.scale_type = scale_type | |
| self.ratio = ratio | |
| self.groups = heads//ratio | |
| self.exp_factor = exp_factor | |
| self.d_hidden = int(exp_factor*self.d) | |
| self.dropout = dropout | |
| self.norm_type = norm_type | |
| self.bias = bias | |
| self.act = act | |
| self.l1_type = l1_type | |
| self.mhsa = MHSA(heads, d_head, scale_type, ratio, qk_norm, quartet, fake_quartet) | |
| self.pre_att_norm = mlp.get_norm(pre_att_norm, norm_type, self.d, bias) | |
| self.out_att_norm = mlp.get_norm(out_att_norm, norm_type, self.d, bias) | |
| self.mlp = mlp.MLP2L(self.d, self.d_hidden, self.d, bias, act, dropout, l1_type, norm_type, act_norm, quartet, fake_quartet) | |
| self.pre_mlp_norm = mlp.get_norm(pre_mlp_norm, norm_type, self.d, bias) | |
| self.out_mlp_norm = mlp.get_norm(out_mlp_norm, norm_type, self.d, bias) | |
| self.quartet = quartet | |
| self.fake_quartet = fake_quartet | |
| def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_res=False, return_A=False, backend="flash2"): | |
| mhsa = self.mhsa(self.pre_att_norm(X) if self.pre_att_norm else X, causal, rope, alibi, swa, return_A, backend) | |
| if not return_A: | |
| Y = mhsa | |
| else: | |
| Y, A__, A_, A = mhsa | |
| if self.out_att_norm: Y = self.out_att_norm(Y) | |
| Y_ = torch.nn.functional.dropout(Y, p=self.dropout, training=self.training) | |
| Y__ = X + Y_ | |
| Z = self.mlp(self.pre_mlp_norm(Y__) if self.pre_mlp_norm else Y__) | |
| if self.out_mlp_norm: Z = self.out_mlp_norm(Z) | |
| Z_ = torch.nn.functional.dropout(Z, p=self.dropout, training=self.training) | |
| Z__ = Y__ + Z_ | |
| if not return_res: | |
| if not return_A: | |
| return Z__ | |
| else: | |
| return Z__, A__, A_, A | |
| else: | |
| if not return_A: | |
| return Z__, Y__ | |
| else: | |
| return Z__, Y__, A__, A_, A | |
| class Transformer(torch.nn.Module): | |
| def __init__(self, vocab_size=50304, num_blocks=12, heads=12, d_head=64, scale_type="1/sqrt(d)", ratio=1, is_causal=True, window=None, backend="flash2", exp_factor=4, dropout=0, pos_type="rope", max_context=128, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", std=0.02, test=False, weight_tying=True, emb_norm=False, pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, out_norm=True, fix_norm=False, quartet=True, fake_quartet=False): | |
| super().__init__() | |
| self.vocab_size = vocab_size | |
| self.num_blocks = num_blocks | |
| self.heads = heads | |
| self.d_head = d_head | |
| self.d = heads * d_head | |
| self.scale_type = scale_type | |
| self.ratio = ratio | |
| self.groups = heads//ratio | |
| self.is_causal = is_causal | |
| self.window = window | |
| self.backend = backend | |
| self.exp_factor = exp_factor | |
| self.dropout = dropout | |
| self.pos_type = pos_type | |
| self.max_context = max_context | |
| self.norm_type = norm_type | |
| self.bias = bias | |
| self.act = act | |
| self.l1_type = l1_type | |
| self.weight_tying = weight_tying | |
| self.fix_norm = fix_norm | |
| self.quartet = quartet | |
| self.fake_quartet = fake_quartet | |
| self.emb = torch.nn.Embedding(vocab_size, self.d) | |
| self.emb_norm = mlp.get_norm(emb_norm, norm_type, self.d, bias) | |
| if pos_type == "learned": | |
| pos = torch.rand((max_context, self.d)) | |
| self.pos = torch.nn.Parameter(pos) | |
| self.blocks = torch.nn.Sequential(*[Block(heads, d_head, scale_type, ratio, exp_factor, dropout, norm_type, bias, act, l1_type, pre_att_norm, qk_norm, out_att_norm, pre_mlp_norm, act_norm, out_mlp_norm, quartet, fake_quartet) for _ in range(num_blocks)]) | |
| self.out_norm = mlp.get_norm(out_norm, norm_type, self.d, bias) | |
| self.linear = torch.nn.Linear(self.d, vocab_size, bias=False) | |
| if weight_tying: self.emb.weight = self.linear.weight | |
| self.init(std, test) | |
| if fake_quartet: | |
| for m in self.modules(): | |
| if isinstance(m, (torch.nn.LayerNorm, torch.nn.RMSNorm, torch.nn.Embedding)): | |
| m.to(torch.bfloat16) | |
| def init(self, std=0.02, test=False): | |
| if test: print("\x1b[1m%36.36s %8.8s %8.8s %8.8s\x1b[0m" % ("parameter_name", "suffix", "mean", "std")) | |
| for parameter_name, parameter in self.named_parameters(): | |
| parent_name, _, suffix = parameter_name.rpartition(".") | |
| parent = self.get_submodule(parent_name) | |
| if isinstance(parent, (torch.nn.Linear, torch.nn.Embedding)) and suffix=="weight": | |
| torch.nn.init.normal_(parameter, 0, std) | |
| elif isinstance(parent, (torch.nn.Linear, torch.nn.LayerNorm)) and suffix=="bias": | |
| torch.nn.init.zeros_(parameter) | |
| elif isinstance(parent, (torch.nn.LayerNorm, torch.nn.RMSNorm)) and suffix=="weight": | |
| torch.nn.init.ones_(parameter) | |
| else: | |
| # pos | |
| if parameter.ndim == 2: | |
| torch.nn.init.zeros_(parameter) | |
| # scale | |
| elif parameter.ndim == 4: | |
| torch.nn.init.constant_(parameter, sqrt(self.d_head)) | |
| if test: | |
| print("%36.36s %8.8s %8.8s %8.8s\x1b[0m" % (parameter_name, suffix, "%f" % parameter.mean(), "%f" % parameter.std())) | |
| # (batches*)context | |
| def forward(self, ids, return_res=False, return_A=False): | |
| context = ids.shape[-1] | |
| if return_A: | |
| # (batches*)num_blocks*heads*context*context | |
| A__ = torch.empty(*ids.shape[:-1], self.num_blocks, self.heads, context, context) | |
| A_ = torch.empty_like(A__) | |
| A = torch.empty_like(A__) | |
| # (batches*)context*d | |
| X = self.emb(ids) | |
| if return_res: | |
| res_in = X | |
| # (batches*)num_blocks*context*d | |
| res_att = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d) | |
| res_mlp = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d) | |
| # Recompute in every batch in case context changes | |
| if self.is_causal: | |
| if self.backend=="pytorch": | |
| causal = get_causal(context).to(ids.device) | |
| elif self.backend in {"flash2", "flash3", "flash4"}: | |
| causal = True | |
| elif self.backend=="flex": | |
| causal = causal_mod | |
| elif self.backend=="cudnn": | |
| # right_bound | |
| causal = 0 | |
| else: causal = None | |
| if self.pos_type == "sinusoidal": | |
| pos = get_sinusoidal(context, self.d).to(ids.device) | |
| X = X + pos | |
| if self.pos_type == "learned": | |
| X = X + self.pos[:context,:] | |
| if self.pos_type == "rope": | |
| rope = get_rope(context, self.d_head, device=ids.device) | |
| else: rope = None | |
| if self.pos_type == "alibi": | |
| if self.backend=="pytorch": | |
| alibi = get_alibi(self.heads, context).to(ids.device) | |
| elif self.backend in {"flash2", "flash3", "flash4"}: | |
| alibi = get_m(self.heads).to(ids.device) | |
| elif self.backend=="flex": | |
| alibi = alibi_mod | |
| elif self.backend=="cudnn": | |
| alibi = True | |
| else: alibi = None | |
| if self.window is not None: | |
| if self.backend=="pytorch": | |
| swa = get_swa(context, self.window).to(ids.device) | |
| elif self.backend in {"flash2", "flash3", "flash4"}: | |
| swa = (self.window, self.window) | |
| elif self.backend=="flex": | |
| swa = swa_mod | |
| elif self.backend=="cudnn": | |
| # left_bound | |
| swa = self.window | |
| else: swa = None | |
| # After positional encoding | |
| if self.emb_norm: X = self.emb_norm(X) | |
| X_ = torch.nn.functional.dropout(X, p=self.dropout, training=self.training) | |
| Y = X_ | |
| for i, block in enumerate(self.blocks): | |
| if not return_res: | |
| if not return_A: | |
| Y = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend) | |
| else: | |
| Y, A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend) | |
| else: | |
| if not return_A: | |
| Y, res_att[...,i,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend) | |
| res_mlp[...,i,:,:]= Y | |
| else: | |
| Y, res_att[...,i,:,:], A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend) | |
| res_mlp[...,i,:,:]= Y | |
| if self.out_norm: Y = self.out_norm(Y) | |
| # (batches*)context*vocab_size | |
| if self.fix_norm: | |
| Z = torch.nn.functional.linear(Y, mlp.sphere_norm(self.linear.weight)) | |
| else: | |
| Z = self.linear(Y) | |
| if not return_res: | |
| if not return_A: | |
| return Z | |
| else: | |
| return Z, A__, A_, A | |
| else: | |
| if not return_A: | |
| return Z, res_in, res_att, res_mlp | |
| else: | |
| return Z, res_in, res_att, res_mlp, A__, A_, A | |
| def get_attention_header(transformer): | |
| attention_header = "" | |
| for block in range(transformer.num_blocks): | |
| for head in range(transformer.heads): | |
| attention_header += f"block{block}.head{head} " | |
| # Remove last space | |
| attention_header = attention_header[:-1] | |
| return attention_header | |
| def get_attention(W): | |
| attention = "" | |
| for block in range(W.shape[0]): | |
| for head in range(W.shape[1]): | |
| # rows->y, columns->x | |
| attention += "%.2f " % W[block, head] | |
| # Remove last space | |
| attention = attention[:-1] | |
| return attention | |
| def get_similarity_header(transformer): | |
| similarity_header = "embedding " | |
| for block in range(transformer.num_blocks): | |
| similarity_header += f"block{block} " | |
| # Remove last space | |
| similarity_header = similarity_header[:-1] | |
| return similarity_header | |
| def get_similarity(embeddings_x, embeddings_y): | |
| similarity = "" | |
| for block in range(embeddings_x.shape[0]): | |
| similarity += "%.2f " % torch.nn.functional.cosine_similarity(embeddings_x[block,:], embeddings_y[block,:], dim=0) | |
| # Remove last space | |
| similarity = similarity[:-1] | |
| return similarity | |
| def get_clustering_header(transformer): | |
| clustering_header = "embedding.random.x embedding.random.y "\ | |
| "embedding.pca.x embedding.pca.y "\ | |
| "embedding.mds.x embedding.mds.y "\ | |
| "embedding.tsne.x embedding.tsne.y "\ | |
| "embedding.umap.x embedding.umap.y " | |
| for block in range(transformer.num_blocks): | |
| clustering_header += f"block{block}.random.x block{block}.random.y "\ | |
| f"block{block}.pca.x block{block}.pca.y "\ | |
| f"block{block}.mds.x block{block}.mds.y "\ | |
| f"block{block}.tsne.x block{block}.tsne.y "\ | |
| f"block{block}.umap.x block{block}.umap.y " | |
| # Remove last space | |
| clustering_header = clustering_header[:-1] | |
| return clustering_header | |
| def get_clustering(random, pca, mds, tsne, umap): | |
| clustering = "" | |
| for block in range(random.shape[0]): | |
| clustering += "%f %f %f %f %f %f %f %f %f %f " % (random[block,0], random[block,1], pca[block,0], pca[block,1], mds[block,0], mds[block,1], tsne[block,0], tsne[block,1], umap[block,0], umap[block,1]) | |
| # Remove last space | |
| clustering = clustering[:-1] | |
| return clustering | |