Instructions to use WhaletechAI/W1-4B-dLLM-Base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use WhaletechAI/W1-4B-dLLM-Base with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("WhaletechAI/W1-4B-dLLM-Base", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """CLI: run one sampler on a prompt and print the result.""" | |
| from __future__ import annotations | |
| import argparse | |
| from pathlib import Path | |
| import torch | |
| from whale4b.core.runner import RunConfig, SamplingRunner | |
| from whale4b.samplers import list_samplers | |
def _at_least(convert, minimum):
    """Return an argparse ``type=`` callback that converts text and enforces ``>= minimum``.

    Args:
        convert: Converter applied to the raw string (e.g. ``int`` or ``float``).
        minimum: Smallest accepted value (inclusive).

    Returns:
        A one-argument callable suitable for ``add_argument(type=...)``. It raises
        ``argparse.ArgumentTypeError`` so argparse reports a clean usage error.
    """

    def parse(text: str):
        try:
            value = convert(text)
        except ValueError:
            # Mirrors argparse's own wording for a plain ``type=int``/``type=float``.
            raise argparse.ArgumentTypeError(
                f"invalid {convert.__name__} value: {text!r}"
            ) from None
        # Written as ``not value >= minimum`` so NaN (which fails every
        # comparison) is rejected as well as values below the bound.
        if not value >= minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {text!r}")
        return value

    return parse


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the sampler CLI parser and parse command-line flags.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what the zero-argument call always did.

    Returns:
        The parsed namespace whose fields ``main`` copies into ``RunConfig``.

    Numeric flags are range-checked here so that clearly invalid values
    (e.g. ``--steps 0`` or a negative ``--temperature``) fail with a usage
    error instead of being passed through to the runner.
    """
    here = Path(__file__).parent
    # NOTE(review): the description and default config file say "whale3b",
    # while the package imported by this script is ``whale4b`` — confirm
    # which name is intended before changing either.
    p = argparse.ArgumentParser(description="Whale3B diffusion LM sampler.")
    p.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pt")
    p.add_argument(
        "--config",
        default=str(here / "configs" / "whale3b.yaml"),
        help="Config YAML passed to the runner.",
    )
    p.add_argument(
        "--tokenizer",
        default=str(here / "whale-tokenizer"),
        help="Tokenizer path passed to the runner.",
    )
    p.add_argument("--prompt", default="", help="Prompt text to continue (default: empty).")
    p.add_argument(
        "--sampler",
        default="standard",
        choices=list_samplers(),
        help="Registered sampler to run.",
    )
    p.add_argument(
        "--steps", type=_at_least(int, 1), default=64,
        help="Number of sampling steps (>= 1).",
    )
    p.add_argument(
        "--max-new-tokens", type=_at_least(int, 1), default=256,
        help="Maximum number of new tokens to generate (>= 1).",
    )
    p.add_argument(
        "--temperature", type=_at_least(float, 0.0), default=0.0,
        help="Sampling temperature (>= 0).",
    )
    p.add_argument(
        "--top-k", type=_at_least(int, 0), default=0,
        help="Top-k value passed to the sampler (>= 0).",
    )
    p.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device string (default: cuda if available, else cpu).",
    )
    p.add_argument(
        "--dtype", default="bf16", choices=["bf16", "fp16", "fp32"],
        help="Precision passed to the runner.",
    )
    p.add_argument("--seed", type=int, default=1234, help="Random seed passed to the runner.")
    p.add_argument("--no-ema", action="store_true", help="Disable EMA weights (use_ema=False).")
    return p.parse_args(argv)
def main() -> None:
    """Parse CLI flags, run the selected sampler once, and print the result.

    Output is the generated continuation (``result.new_text``) followed by a
    one-line summary of the sampler name, steps run, generated token count,
    and elapsed seconds as reported by the runner's result object.
    """
    args = parse_args()
    cfg = RunConfig(
        ckpt_path=args.checkpoint,
        config_path=args.config,
        tokenizer_path=args.tokenizer,
        sampler=args.sampler,
        steps=args.steps,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        device=args.device,
        dtype=args.dtype,
        seed=args.seed,
        # --no-ema is an opt-out flag, so EMA is used unless the flag is given.
        use_ema=not args.no_ema,
    )
    runner = SamplingRunner(cfg)
    result = runner.run(prompt=args.prompt)
    print(f"\n=== CONTINUATION ===\n{result.new_text}")
    # Plain literal: there are no placeholders, so no f-string prefix is needed.
    print("\n=== STATS ===")
    print(
        f"sampler={result.sampler} | steps={result.steps_run} | "
        f"tokens={result.generated_tokens} | elapsed={result.elapsed_s:.2f}s"
    )
if __name__ == "__main__":
    # Script entry point. main() returns None, and SystemExit(None) exits with
    # status 0 — the same status as simply falling off the end of the script.
    raise SystemExit(main())