Instructions to use WhaletechAI/W1-4B-dLLM-Base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WhaletechAI/W1-4B-dLLM-Base with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("WhaletechAI/W1-4B-dLLM-Base", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """ | |
| Model loading and wrapping. | |
| Provides: | |
| - load_checkpoint(ckpt_path, config, device, dtype, use_ema, strict) | |
| -> ModelWrapper | |
| - ModelWrapper.__call__(x [1,L], t [1]) -> logits [1,L,V] | |
| with autocast handled internally | |
| """ | |
| from __future__ import annotations | |
| import re | |
| from contextlib import nullcontext | |
| from pathlib import Path | |
| from typing import Optional | |
| import torch | |
| from .model import LangDiT, create_model # noqa: F401 | |
| STEP_CHECKPOINT_RE = re.compile(r"step_(\d+)(?:\.pt|\.safetensors)$") | |
| IGNORED_KEY_SUFFIXES = ("._extra_state",) | |
| IGNORED_EXACT_KEYS = {"rope.rope.inv_freq"} | |
# ── checkpoint helpers ────────────────────────────────────────────────────────
def resolve_checkpoint(path: str) -> str:
    """Resolve *path* to one concrete checkpoint file.

    A path that is already a file is returned as-is (converted to ``str``).
    For a directory, the first rule that matches wins:

    1. ``step_*.pt`` files -- the one with the highest step number;
    2. ``step_*.safetensors`` files (only consulted when no ``step_*.pt``
       exists) -- same ordering;
    3. ``model.safetensors``, then ``checkpoint.safetensors``;
    4. the only ``*.safetensors`` file, when there is exactly one.

    Raises:
        FileNotFoundError: if nothing matches, or if the directory contains
            ``model.safetensors.index.json`` (sharded weights are rejected
            with a dedicated message).
    """
    root = Path(path)
    if root.is_file():
        return str(root)

    def step_number(entry: Path) -> int:
        # Names that do not fit ``step_<N>.<ext>`` get -1 so they sort first
        # and are only picked when nothing better exists.
        found = STEP_CHECKPOINT_RE.match(entry.name)
        return -1 if found is None else int(found.group(1))

    if root.is_dir():
        # ``.pt`` step files take precedence over ``.safetensors`` ones.
        for pattern in ("step_*.pt", "step_*.safetensors"):
            ranked = sorted(root.glob(pattern), key=step_number)
            if ranked:
                return str(ranked[-1])

        for filename in ("model.safetensors", "checkpoint.safetensors"):
            target = root / filename
            if target.is_file():
                return str(target)

        loose = sorted(root.glob("*.safetensors"))
        if len(loose) == 1:
            return str(loose[0])

        if (root / "model.safetensors.index.json").is_file():
            raise FileNotFoundError(
                "Sharded safetensors are not supported by whale4b yet. "
                "Pass a single .safetensors file instead."
            )

    raise FileNotFoundError(f"No checkpoint found at: {path}")
def load_state_dict(ckpt_path: str, use_ema: bool = True):
    """Load a raw state dict from a ``.pt`` or ``.safetensors`` file.

    Args:
        ckpt_path: Path to the checkpoint file (``str`` or path-like).
        use_ema: When the ``.pt`` payload is a dict with a dict-valued
            ``"ema"`` entry, return that entry in preference to the others.

    Returns:
        ``(state, source)`` where ``source`` names where the weights came
        from: ``"safetensors"``, ``"raw"`` (payload is not a dict), ``"ema"``,
        ``"model"``, ``"state_dict"`` or ``"root"`` (the payload dict itself).

    Note:
        ``.pt`` files are read with ``weights_only=False``, i.e. full pickle
        deserialisation. SECURITY: only load checkpoints from trusted
        sources; an untrusted ``.pt`` file can execute arbitrary code.
    """
    ckpt_path = str(ckpt_path)  # tolerate pathlib.Path input
    if ckpt_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        return load_file(ckpt_path, device="cpu"), "safetensors"

    load_kwargs = {"map_location": "cpu", "weights_only": False}
    try:
        # Memory-map when possible so large checkpoints are not read eagerly.
        ckpt = torch.load(ckpt_path, mmap=True, **load_kwargs)
    except (TypeError, RuntimeError):
        # TypeError: this torch version has no ``mmap`` argument.
        # RuntimeError: torch refuses to mmap files saved in the legacy
        # (non-zipfile) format. A plain load handles both; a genuinely
        # unreadable file simply raises again from this second attempt.
        ckpt = torch.load(ckpt_path, **load_kwargs)

    if not isinstance(ckpt, dict):
        return ckpt, "raw"

    if use_ema and isinstance(ckpt.get("ema"), dict):
        return ckpt["ema"], "ema"
    for source in ("model", "state_dict"):
        if isinstance(ckpt.get(source), dict):
            return ckpt[source], source
    return ckpt, "root"
| def _strip_prefix(sd: dict, prefix: str) -> dict: | |
| if not any(k.startswith(prefix) for k in sd): | |
| return sd | |
| out = {} | |
| for key, value in sd.items(): | |
| out[key[len(prefix):] if key.startswith(prefix) else key] = value | |
| return out | |
def sanitize_state_dict(state_dict: dict) -> tuple[dict, list[str]]:
    """Strip wrapper prefixes and drop non-inference metadata keys.

    Prefix handling: the wrapper prefixes ``module.`` and ``_orig_mod.`` are
    removed repeatedly until none is left at the front of any key, then a
    leading ``model.`` is removed once, then the wrapper prefixes are removed
    again. This makes the result independent of the order in which the
    wrappers were nested (``module._orig_mod.x`` and ``_orig_mod.module.x``
    both become ``x``), which a single fixed-order pass does not guarantee.

    Args:
        state_dict: Mapping of parameter names to values. It is not mutated.

    Returns:
        ``(cleaned, dropped)``: the cleaned mapping, and the list of keys that
        were removed because they appear in ``IGNORED_EXACT_KEYS`` or end
        with one of ``IGNORED_KEY_SUFFIXES``.
    """

    def strip_wrappers(sd: dict) -> dict:
        # Loop to a fixed point; every successful strip shortens at least one
        # key, so this always terminates.
        progressed = True
        while progressed:
            progressed = False
            for wrapper in ("module.", "_orig_mod."):
                if any(name.startswith(wrapper) for name in sd):
                    sd = _strip_prefix(sd, wrapper)
                    progressed = True
        return sd

    state_dict = strip_wrappers(state_dict)
    # ``model.`` is stripped only once so that a genuine nested ``model``
    # attribute below the container prefix is left alone.
    state_dict = _strip_prefix(state_dict, "model.")
    state_dict = strip_wrappers(state_dict)

    dropped: list[str] = []
    cleaned: dict = {}
    for key, value in state_dict.items():
        is_metadata = key in IGNORED_EXACT_KEYS or any(
            key.endswith(suffix) for suffix in IGNORED_KEY_SUFFIXES
        )
        if is_metadata:
            dropped.append(key)
        else:
            cleaned[key] = value
    return cleaned, dropped
def resolve_dtype(dtype_name: str, device: torch.device):
    """Translate a dtype name into ``(amp_dtype, use_amp, model_dtype)``.

    * ``"fp32"``: ``(float32, False, float32)`` on every device.
    * CUDA: autocast is enabled and the weights use the autocast dtype.
    * MPS with ``"fp16"``: float16 weights, no autocast.
    * Anything else: float32 weights, no autocast.

    Names other than ``"fp16"`` / ``"fp32"`` (including unknown ones) map to
    bfloat16 as the autocast dtype.
    """
    autocast_dtype = {
        "fp16": torch.float16,
        "fp32": torch.float32,
    }.get(dtype_name, torch.bfloat16)

    if dtype_name == "fp32":
        return autocast_dtype, False, torch.float32

    backend = device.type
    if backend == "cuda":
        return autocast_dtype, True, autocast_dtype

    half_on_mps = backend == "mps" and dtype_name == "fp16"
    weights_dtype = torch.float16 if half_on_mps else torch.float32
    return autocast_dtype, False, weights_dtype
# ── ModelWrapper ──────────────────────────────────────────────────────────────
class ModelWrapper:
    """
    Wraps LangDiT into a standard ``(x [1,L], t [1]) -> logits [1,L,V]``
    callable. Handles autocast internally — callers never deal with AMP.

    Attributes set by ``__init__`` (stored as given, not validated):
        model, vocab_size, mask_token_id, device, use_amp, amp_dtype
    """

    def __init__(
        self,
        model: LangDiT,
        vocab_size: int,
        mask_token_id: int,
        device: torch.device,
        use_amp: bool,
        amp_dtype: torch.dtype,
    ):
        """Store the model together with the settings used on every call.

        Args:
            model: Callable invoked as ``model(x, t)``.
            vocab_size: Vocabulary size, kept for callers.
            mask_token_id: Mask token id, kept for callers.
            device: Device the inputs are moved to before the forward pass.
            use_amp: Whether to run the forward pass under autocast
                (only honoured on CUDA devices).
            amp_dtype: dtype handed to ``torch.autocast``.
        """
        self.model = model
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        self.device = device
        self.use_amp = use_amp
        self.amp_dtype = amp_dtype

    def __call__(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        x: [1, L] int64
        t: [1] float
        Returns: [1, L, V] float32 logits (raw — no softmax)
        """
        x = x.to(self.device)
        t = t.to(self.device)
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=self.amp_dtype)
            if self.use_amp and self.device.type == "cuda"
            else nullcontext()
        )
        with amp_ctx:
            logits = self.model(x, t)
        # The model may emit reduced-precision logits (e.g. under autocast);
        # upcast so callers always get the documented float32. This is a
        # no-op when the logits are already float32.
        return logits.float()
def load_checkpoint(
    ckpt_path: str,
    config: dict,
    device: Optional[torch.device] = None,
    dtype: str = "bf16",
    use_ema: bool = True,
    strict: bool = False,
) -> ModelWrapper:
    """
    Full pipeline: resolve path -> load state dict -> build model -> wrap.

    Args:
        ckpt_path: Checkpoint file, or a directory handed to
            ``resolve_checkpoint``.
        config: Configuration dict. ``config["model"]["vocab_size"]`` is
            required; ``config["diffusion"]["mask_token_id"]`` is optional
            and defaults to 14.
        device: Target device. A ``torch.device`` or a device string such as
            ``"cuda"``; when omitted, CUDA is used if available, else CPU.
        dtype: Precision name passed to ``resolve_dtype``.
        use_ema: Passed to ``load_state_dict``.
        strict: Passed to ``model.load_state_dict``. With the default
            ``False``, missing/unexpected keys are only reported on stdout.

    Returns:
        A ready-to-call ModelWrapper (model already in ``eval()`` mode).
    """
    # Normalise so that plain strings work too; ``resolve_dtype`` reads
    # ``device.type``, which a bare ``str`` does not have.
    device = torch.device(
        device or ("cuda" if torch.cuda.is_available() else "cpu")
    )
    amp_dtype, use_amp, model_dtype = resolve_dtype(dtype, device)

    resolved = resolve_checkpoint(ckpt_path)
    state_dict, source = load_state_dict(resolved, use_ema=use_ema)
    state_dict, dropped = sanitize_state_dict(state_dict)

    model = create_model(config).to(device=device, dtype=model_dtype)
    model.eval()
    missing, unexpected = model.load_state_dict(state_dict, strict=strict)
    del state_dict  # release the CPU copy of the weights as early as possible

    if missing:
        print(f"[loader] missing keys: {len(missing)} — sample: {missing[:3]}")
    if unexpected:
        print(f"[loader] unexpected keys: {len(unexpected)} — sample: {unexpected[:3]}")
    if dropped:
        print(f"[loader] dropped non-inference keys: {len(dropped)} — sample: {dropped[:3]}")
    print(f"[loader] loaded {resolved!r} (source={source}, dtype={model_dtype})")

    # ``or {}`` also covers an explicit ``diffusion: null`` in the config.
    diff_cfg = config.get("diffusion") or {}
    vocab_size = int(config["model"]["vocab_size"])
    mask_token_id = int(diff_cfg.get("mask_token_id", 14))

    return ModelWrapper(
        model=model,
        vocab_size=vocab_size,
        mask_token_id=mask_token_id,
        device=device,
        use_amp=use_amp,
        amp_dtype=amp_dtype,
    )