Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
  # !pip install kernels
  from kernels import get_kernel
  kernel = get_kernel("Motif-Technologies/optimizer")
- Notebooks
- Google Colab
- Kaggle
"""CPU offloading tests for optimizer states.

Run with:
    torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_offload.py

Tests:
    1. Correctness: turn_on_cpu_offload() produces identical results to no
       offload
    2. Memory: GPU optimizer state storage is freed after offload
    3. AdamW: moment1/moment2 offloading works correctly
    4. Memory savings: allocated GPU memory is lower with offload than
       without
    5. Toggle: switching offload on/off between steps does not change
       results
    6. Leak: GPU memory and CPU RSS stay stable over many offload cycles
    7. State dict: checkpoint save/load (via DCP) requires offload to be
       turned off first and resumes bitwise-identically
    8. Checkpoint memory: turning offload off/on moves optimizer state
       between GPU memory and pinned CPU memory as expected
"""
| import copy | |
| import logging | |
| import pytest | |
| import torch | |
| import torch.distributed as dist | |
| from torch.distributed.tensor import DTensor, Shard, distribute_tensor | |
# Root logging config: INFO level, messages prefixed by their level name.
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s] %(message)s",
)
# Module logger used by every test for progress / PASSED messages.
logger = logging.getLogger(__name__)
def _setup():
    """Initialise the NCCL process group and bind this process to a GPU.

    The device index is taken from the ``LOCAL_RANK`` environment variable
    exported by ``torchrun`` and is bound *before* the process group is
    created, as the PyTorch distributed docs recommend for NCCL. When
    ``LOCAL_RANK`` is absent, the device is derived from the global rank
    after initialisation (the previous behaviour).

    Returns:
        Tuple ``(rank, world_size)`` of the default process group.
    """
    import os

    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        # Modulo keeps this valid when each process sees fewer GPUs than
        # the node has (e.g. a restricted CUDA_VISIBLE_DEVICES).
        torch.cuda.set_device(int(local_rank) % torch.cuda.device_count())
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    if local_rank is None:
        # Fallback for launchers that do not export LOCAL_RANK; assumes
        # every node exposes the same number of GPUs.
        torch.cuda.set_device(rank % torch.cuda.device_count())
    return rank, dist.get_world_size()
def _make_mesh(world_size):
    """Return a 1-D CUDA device mesh over ``world_size`` ranks, dim ``"dp"``."""
    mesh_shape = (world_size, )
    dim_names = ("dp", )
    return dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=dim_names)
def test_correctness(rank, world_size):
    """Check that CPU offload is bit-for-bit transparent to Muon updates.

    Two optimizers are built over identical row-sharded parameters, one
    plain and one with ``turn_on_cpu_offload()``. Both are fed the same
    gradients for a few steps. The gathered parameters must match exactly
    (atol=rtol=0) after every step.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)
    rows, cols = 64, 128
    n_params, n_steps = 4, 3

    # Every rank draws the same tensors because the seed is shared.
    init_values = [
        torch.randn(rows, cols, device="cuda") for _ in range(n_params)
    ]
    grads_per_step = [[
        torch.randn(rows, cols, device="cuda") for _ in range(n_params)
    ] for _ in range(n_steps)]

    def shard(tensor):
        # Copy first so the source tensor is never aliased by the DTensor.
        return distribute_tensor(tensor.clone(), mesh, [Shard(0)])

    def build(offload):
        weights = [torch.nn.Parameter(shard(v)) for v in init_values]
        group = {
            "params": weights,
            "names": [f"layer.{k}.weight" for k in range(len(weights))],
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }
        opt = Muon(params=[group], chunk_size=2, warmup_step=1)
        if offload:
            opt.turn_on_cpu_offload()
        return opt, weights

    plain_opt, plain_weights = build(False)
    offload_opt, offload_weights = build(True)

    for step, step_grads in enumerate(grads_per_step):
        for w_plain, w_off, grad in zip(plain_weights, offload_weights,
                                        step_grads):
            w_plain.grad = shard(grad)
            w_off.grad = shard(grad)
        plain_opt.step()
        offload_opt.step()
        for w_plain, w_off in zip(plain_weights, offload_weights):
            torch.testing.assert_close(w_plain.data.full_tensor(),
                                       w_off.data.full_tensor(),
                                       atol=0,
                                       rtol=0)
        if rank == 0:
            logger.info("Step %d: correctness OK", step)

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_correctness")
def test_memory(rank, world_size):
    """Verify that GPU storage of Muon momentum buffers is freed after offload.

    Two steps are run with offload enabled. After each step, every
    ``momentum_buffer`` must have a zero-sized local storage. At least one
    such buffer must exist, so the check cannot pass vacuously. After the
    first step, the CPU offload pool must also track tensors held in pinned
    host memory. The global Newton-Schulz compile flag is restored even if
    an assertion fails.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    try:
        torch.manual_seed(42)
        mesh = _make_mesh(world_size)
        dim0, dim1 = 512, 1024
        num_params = 8
        params, names = [], []
        for i in range(num_params):
            full = torch.randn(dim0, dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
            params.append(p)
            names.append(f"layer.{i}.weight")
        param_groups = [{
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        optim.turn_on_cpu_offload()

        def assert_momentum_offloaded():
            # Each momentum buffer's local shard must hold no GPU storage.
            # Count them so an optimizer without any buffers fails loudly.
            checked = 0
            for param in params:
                state = optim.state[param]
                if "momentum_buffer" not in state:
                    continue
                buf = state["momentum_buffer"]
                local_buf = (buf._local_tensor
                             if isinstance(buf, DTensor) else buf)
                nbytes = local_buf.untyped_storage().size()
                assert nbytes == 0, (
                    f"Expected freed GPU storage after offload, got "
                    f"{nbytes} bytes")
                checked += 1
            assert checked > 0, "No momentum_buffer found in optimizer state"

        optim.step()
        torch.cuda.synchronize()
        # After step + offload, all momentum buffer GPU storage should be freed.
        assert_momentum_offloaded()

        # Verify CPU pool has pinned buffers.
        pool = optim._cpu_offload_pool
        assert len(pool._managed) > 0, "No tensors tracked by CPU offload pool"
        for grp in pool._groups.values():
            assert grp["cpu_flat"].is_pinned(), (
                "CPU buffer must be pinned memory")

        # Run another step to verify reload + compute + offload cycle works.
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
        optim.step()
        torch.cuda.synchronize()
        # Storage should be freed again after second step.
        assert_momentum_offloaded()
    finally:
        # Always restore the global compile flag for subsequent tests.
        set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_memory")
def test_adamw_offload(rank, world_size):
    """Verify AdamW moment1/moment2 are offloaded without changing results.

    A reference optimizer and an offloaded one are built over the same mix
    of Muon (2-D) and AdamW (1-D) parameters. Both are stepped with
    identical gradients, and all parameters must match bitwise after every
    step. Afterwards, the offloaded optimizer's AdamW ``moment1``/``moment2``
    states must have freed GPU storage. At least one such state must exist,
    so the check is never vacuous. The global Newton-Schulz compile flag is
    restored even on failure.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    try:
        torch.manual_seed(42)
        mesh = _make_mesh(world_size)
        num_steps = 3
        num_muon, num_adamw = 4, 3
        dim0, dim1 = 64, 128
        # Create both Muon (2D) and AdamW (1D) params.
        muon_params, muon_names = [], []
        adamw_params = []
        for i in range(num_muon):
            full = torch.randn(dim0, dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            muon_params.append(torch.nn.Parameter(dt))
            muon_names.append(f"layer.{i}.weight")
        for _ in range(num_adamw):
            full = torch.randn(dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            adamw_params.append(torch.nn.Parameter(dt))
        # Pre-generate grads (same seed on all ranks -> same values).
        muon_grads = [[
            torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)
        ] for _ in range(num_steps)]
        adamw_grads = [[
            torch.randn(dim1, device="cuda") for _ in range(num_adamw)
        ] for _ in range(num_steps)]

        def clone_params(source):
            # Fresh sharded copies so both optimizers start from equal values.
            return [
                torch.nn.Parameter(
                    distribute_tensor(p.data.full_tensor().clone(), mesh,
                                      [Shard(0)])) for p in source
            ]

        def make_optimizer(cpu_offload):
            mp = clone_params(muon_params)
            ap = clone_params(adamw_params)
            param_groups = [
                {
                    "params": mp,
                    "names": list(muon_names),
                    "use_muon": True,
                    "lr": 0.02,
                    "weight_decay": 0.01,
                    "momentum": 0.95,
                    "nesterov": True,
                    "ns_steps": 5,
                    "none_grad": False,
                    "adamw_betas": (0.9, 0.95),
                    "adamw_eps": 1e-8,
                },
                {
                    "params": ap,
                    "use_muon": False,
                    "lr": 1e-3,
                    "weight_decay": 0.01,
                    "adamw_betas": (0.9, 0.95),
                    "adamw_eps": 1e-8,
                },
            ]
            optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
            if cpu_offload:
                optim.turn_on_cpu_offload()
            return optim, mp, ap

        optim_ref, mp_ref, ap_ref = make_optimizer(False)
        optim_off, mp_off, ap_off = make_optimizer(True)
        for step_idx in range(num_steps):
            for i in range(num_muon):
                g = muon_grads[step_idx][i]
                mp_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
                mp_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
            for i in range(num_adamw):
                g = adamw_grads[step_idx][i]
                ap_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
                ap_off[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
            optim_ref.step()
            optim_off.step()
            # Compare Muon params, then AdamW params.
            for ref_p, off_p in zip(mp_ref + ap_ref, mp_off + ap_off):
                ref_full = ref_p.data.full_tensor()
                off_full = off_p.data.full_tensor()
                torch.testing.assert_close(ref_full, off_full, atol=0, rtol=0)
            if rank == 0:
                logger.info("Step %d: AdamW offload correctness OK", step_idx)

        # Verify AdamW states are offloaded; count so the check is not vacuous.
        checked = 0
        for p in ap_off:
            state = optim_off.state[p]
            for key in ("moment1", "moment2"):
                if key not in state:
                    continue
                t = state[key]
                local_t = t._local_tensor if isinstance(t, DTensor) else t
                assert local_t.untyped_storage().size() == 0, (
                    f"AdamW {key} storage not freed after offload")
                checked += 1
        assert checked > 0, "No AdamW moment1/moment2 state found to check"
    finally:
        # Always restore the global compile flag for subsequent tests.
        set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_adamw_offload")
def test_memory_savings(rank, world_size):
    """Compare allocated GPU memory after one Muon step, with vs. without offload.

    The same seeded workload (sharded parameters plus gradients) is built
    twice. The run with CPU offload enabled must leave strictly less GPU
    memory allocated than the run without it.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    mesh = _make_mesh(world_size)
    rows, cols = 1024, 2048
    count = 8
    mib = 1024**2

    def allocated_after_one_step(cpu_offload):
        # Start each measurement from a clean allocator state.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.manual_seed(42)
        weights, labels = [], []
        for idx in range(count):
            dense = torch.randn(rows, cols, device="cuda")
            sharded = distribute_tensor(dense, mesh, [Shard(0)])
            weight = torch.nn.Parameter(sharded)
            weight.grad = distribute_tensor(
                torch.randn(rows, cols, device="cuda"), mesh, [Shard(0)])
            weights.append(weight)
            labels.append(f"layer.{idx}.weight")
        groups = [{
            "params": weights,
            "names": labels,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        opt = Muon(params=groups, chunk_size=2, warmup_step=1)
        if cpu_offload:
            opt.turn_on_cpu_offload()
        opt.step()
        torch.cuda.synchronize()
        allocated = torch.cuda.memory_allocated()
        # Release references so the next measurement is not polluted.
        del opt, weights, groups
        torch.cuda.empty_cache()
        return allocated

    mem_no_offload = allocated_after_one_step(False)
    mem_with_offload = allocated_after_one_step(True)
    if rank == 0:
        logger.info("Memory without offload: %.2f MB", mem_no_offload / mib)
        logger.info("Memory with offload: %.2f MB", mem_with_offload / mib)
        logger.info("Memory saved: %.2f MB",
                    (mem_no_offload - mem_with_offload) / mib)
    assert mem_with_offload < mem_no_offload, (
        f"Expected memory reduction with CPU offload. "
        f"Without: {mem_no_offload / mib:.2f} MB, "
        f"With: {mem_with_offload / mib:.2f} MB")
    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_memory_savings")
def test_toggle_correctness(rank, world_size):
    """Check that switching offload on and off between steps is transparent.

    One optimizer never offloads. The other has offload enabled for steps
    0-1, disabled for 2-3 and enabled again for 4-5. The gathered
    parameters must be bitwise identical after every step.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)
    rows, cols = 64, 128
    n_params, n_steps = 4, 6
    init_values = [
        torch.randn(rows, cols, device="cuda") for _ in range(n_params)
    ]
    grads = [[
        torch.randn(rows, cols, device="cuda") for _ in range(n_params)
    ] for _ in range(n_steps)]

    def shard(tensor):
        return distribute_tensor(tensor.clone(), mesh, [Shard(0)])

    def build():
        weights = [torch.nn.Parameter(shard(v)) for v in init_values]
        group = {
            "params": weights,
            "names": [f"layer.{k}.weight" for k in range(len(weights))],
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }
        return Muon(params=[group], chunk_size=2, warmup_step=1), weights

    # Reference optimizer: offload is never enabled.
    ref_opt, ref_weights = build()
    toggled_opt, toggled_weights = build()

    for step, step_grads in enumerate(grads):
        # Two steps on, two steps off, repeating: on for 0-1, off for 2-3, ...
        should_offload = step % 4 < 2
        if bool(toggled_opt.cpu_offload) != should_offload:
            if should_offload:
                toggled_opt.turn_on_cpu_offload()
            else:
                toggled_opt.turn_off_cpu_offload()
        for ref_w, tog_w, grad in zip(ref_weights, toggled_weights,
                                      step_grads):
            ref_w.grad = shard(grad)
            tog_w.grad = shard(grad)
        ref_opt.step()
        toggled_opt.step()
        for ref_w, tog_w in zip(ref_weights, toggled_weights):
            torch.testing.assert_close(ref_w.data.full_tensor(),
                                       tog_w.data.full_tensor(),
                                       atol=0,
                                       rtol=0)
        if rank == 0:
            logger.info(
                "Step %d (offload=%s): toggle correctness OK",
                step,
                toggled_opt.cpu_offload,
            )

    set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_toggle_correctness")
def test_leak(rank, world_size):
    """Run many offload cycles and verify no CPU/GPU memory leak.

    A baseline is recorded once the warmup steps are done. After all steps,
    allocated GPU memory must not exceed the baseline, and process RSS must
    not grow by ``max_cpu_growth_mb`` or more. RSS is read from
    ``/proc/self/statm``, so this test is Linux-only. The global
    Newton-Schulz compile flag is restored even on failure.
    """
    import os

    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    try:
        torch.manual_seed(42)
        mesh = _make_mesh(world_size)
        dim0, dim1 = 512, 1024
        num_params = 8
        num_steps = 50
        # Baseline is recorded after the step with this index has run:
        # step 0 creates states and step 1 does the first full
        # offload/reload cycle.
        warmup_idx = 2
        # Allows for minor Python/CUDA runtime overhead but catches real leaks.
        max_cpu_growth_mb = 50
        params, names = [], []
        for i in range(num_params):
            full = torch.randn(dim0, dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            params.append(torch.nn.Parameter(dt))
            names.append(f"layer.{i}.weight")
        param_groups = [{
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        optim.turn_on_cpu_offload()

        def get_cpu_rss_mb():
            """Get current process RSS in MB from /proc/self/statm."""
            with open("/proc/self/statm") as f:
                # Second field of statm is the resident set size in pages.
                pages = int(f.read().split()[1])
            return pages * os.sysconf("SC_PAGE_SIZE") / (1024**2)

        gpu_after_warmup = None
        cpu_after_warmup = None
        for step_idx in range(num_steps):
            for p in params:
                p.grad = distribute_tensor(
                    torch.randn(dim0, dim1, device="cuda"), mesh, [Shard(0)])
            optim.step()
            torch.cuda.synchronize()
            gpu_mem = torch.cuda.memory_allocated()
            cpu_mem = get_cpu_rss_mb()
            if step_idx == warmup_idx:
                gpu_after_warmup = gpu_mem
                cpu_after_warmup = cpu_mem
            if rank == 0 and step_idx % 10 == 0:
                logger.info(
                    "Step %d: GPU alloc=%.2f MB, CPU RSS=%.2f MB",
                    step_idx,
                    gpu_mem / (1024**2),
                    cpu_mem,
                )

        # Final measurements.
        torch.cuda.synchronize()
        gpu_final = torch.cuda.memory_allocated()
        cpu_final = get_cpu_rss_mb()
        if rank == 0:
            logger.info(
                "After %d steps: GPU alloc=%.2f MB, CPU RSS=%.2f MB",
                num_steps,
                gpu_final / (1024**2),
                cpu_final,
            )
            logger.info(
                "Warmup baseline: GPU alloc=%.2f MB, CPU RSS=%.2f MB",
                gpu_after_warmup / (1024**2),
                cpu_after_warmup,
            )
        # GPU memory should not grow beyond warmup baseline.
        assert gpu_final <= gpu_after_warmup, (
            f"GPU memory leak detected! Warmup: {gpu_after_warmup / 1024**2:.2f} MB, "
            f"Final: {gpu_final / 1024**2:.2f} MB")
        cpu_growth = cpu_final - cpu_after_warmup
        # Steps executed after the baseline (indices warmup_idx+1 .. end).
        steps_after_warmup = num_steps - warmup_idx - 1
        assert cpu_growth < max_cpu_growth_mb, (
            f"CPU memory leak detected! Growth: {cpu_growth:.2f} MB over "
            f"{steps_after_warmup} steps (warmup={cpu_after_warmup:.2f} MB, "
            f"final={cpu_final:.2f} MB)")
    finally:
        # Always restore the global compile flag for subsequent tests.
        set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_leak (GPU stable, CPU growth=%.2f MB)",
                    cpu_growth)
def test_state_dict_save_load(rank, world_size):
    """Verify state_dict() works after offload and load_state_dict() resumes correctly.

    Both ``state_dict()`` and ``load_state_dict()`` must raise while offload
    is enabled and succeed after ``turn_off_cpu_offload()``. The state is
    serialized with ``torch.distributed.checkpoint`` (DCP), matching the
    actual LLM training checkpoint flow. DCP handles DTensors natively, so
    the roundtrip is bitwise exact. An optimizer resumed from DCP must
    therefore track a reference optimizer (loaded from an in-memory
    deepcopy) bitwise, and must have re-offloaded its state after stepping.
    The global Newton-Schulz compile flag is restored even on failure.
    """
    import shutil
    import tempfile

    import torch.distributed.checkpoint as dcp
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    try:
        torch.manual_seed(42)
        mesh = _make_mesh(world_size)
        dim0, dim1 = 64, 128
        num_muon = 4
        num_adamw = 3
        num_steps = 3
        # Pre-generate all data (same seed on every rank -> same values).
        muon_init = [
            torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)
        ]
        adamw_init = [
            torch.randn(dim1, device="cuda") for _ in range(num_adamw)
        ]
        all_grads_muon = [[
            torch.randn(dim0, dim1, device="cuda") for _ in range(num_muon)
        ] for _ in range(num_steps * 2)]
        all_grads_adamw = [[
            torch.randn(dim1, device="cuda") for _ in range(num_adamw)
        ] for _ in range(num_steps * 2)]

        def make_optimizer(cpu_offload):
            mp = [
                torch.nn.Parameter(
                    distribute_tensor(muon_init[i].clone(), mesh, [Shard(0)]))
                for i in range(num_muon)
            ]
            ap = [
                torch.nn.Parameter(
                    distribute_tensor(adamw_init[i].clone(), mesh, [Shard(0)]))
                for i in range(num_adamw)
            ]
            param_groups = [
                {
                    "params": mp,
                    "names": [f"layer.{i}.weight" for i in range(num_muon)],
                    "use_muon": True,
                    "lr": 0.02,
                    "weight_decay": 0.01,
                    "momentum": 0.95,
                    "nesterov": True,
                    "ns_steps": 5,
                    "none_grad": False,
                    "adamw_betas": (0.9, 0.95),
                    "adamw_eps": 1e-8,
                },
                {
                    "params": ap,
                    "use_muon": False,
                    "lr": 1e-3,
                    "weight_decay": 0.01,
                    "adamw_betas": (0.9, 0.95),
                    "adamw_eps": 1e-8,
                },
            ]
            optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
            if cpu_offload:
                optim.turn_on_cpu_offload()
            return optim, mp, ap

        # --- Run one optimizer for first half, save state, then create TWO
        # fresh optimizers: ref loads via deepcopy, resumed loads via DCP.
        # Both are fresh → same internal cache state → isolates DCP fidelity.
        optim_off, mp_off, ap_off = make_optimizer(True)
        for step_idx in range(num_steps):
            for i in range(num_muon):
                mp_off[i].grad = distribute_tensor(
                    all_grads_muon[step_idx][i].clone(), mesh, [Shard(0)])
            for i in range(num_adamw):
                ap_off[i].grad = distribute_tensor(
                    all_grads_adamw[step_idx][i].clone(), mesh, [Shard(0)])
            optim_off.step()

        # Saving while offloaded must be refused; it works once turned off.
        with pytest.raises(
                RuntimeError,
                match="turn_off_cpu_offload\\(\\) before checkpoint save"):
            optim_off.state_dict()
        optim_off.turn_off_cpu_offload()
        sd_off = optim_off.state_dict()

        # Verify state tensors are NOT empty in the state_dict.
        for param_states in sd_off["state"].values():
            for key, val in param_states.items():
                if isinstance(val, torch.Tensor) and val.is_floating_point():
                    assert val.untyped_storage().size() > 0, (
                        f"state_dict() returned empty storage for key '{key}' — "
                        f"offload reload is broken")
        if rank == 0:
            logger.info("state_dict() contains valid (non-empty) tensors")

        # Save state tensors via DCP (matches real LLM training checkpoint
        # flow). Flatten state tensors with string keys for DCP compatibility.
        flat_state = {}
        for param_idx, param_state in sd_off["state"].items():
            for key, val in param_state.items():
                if isinstance(val, torch.Tensor):
                    flat_state[f"state.{param_idx}.{key}"] = val

        # All ranks must use the same checkpoint directory: rank 0 creates
        # it and broadcasts the path.
        if rank == 0:
            ckpt_dir = tempfile.mkdtemp(prefix="cpu_offload_test_")
        else:
            ckpt_dir = ""
        ckpt_dir_list = [ckpt_dir]
        dist.broadcast_object_list(ckpt_dir_list, src=0)
        ckpt_dir = ckpt_dir_list[0]
        try:
            dcp.save(flat_state, checkpoint_id=ckpt_dir)
            dist.barrier()
            if rank == 0:
                logger.info("DCP save completed to %s", ckpt_dir)

            # --- Reference: fresh optimizer, load via deepcopy (no
            # serialization).
            optim_ref, mp_ref, ap_ref = make_optimizer(True)
            for i in range(num_muon):
                mp_ref[i].data = mp_off[i].data.clone()
            for i in range(num_adamw):
                ap_ref[i].data = ap_off[i].data.clone()
            with pytest.raises(
                    RuntimeError,
                    match="turn_off_cpu_offload\\(\\) before checkpoint load"):
                optim_ref.load_state_dict(copy.deepcopy(sd_off))
            optim_ref.turn_off_cpu_offload()
            optim_ref.load_state_dict(copy.deepcopy(sd_off))
            optim_ref.turn_on_cpu_offload()

            # --- Resumed: fresh optimizer, load via DCP.
            optim_resumed, mp_resumed, ap_resumed = make_optimizer(True)
            for i in range(num_muon):
                mp_resumed[i].data = mp_off[i].data.clone()
            for i in range(num_adamw):
                ap_resumed[i].data = ap_off[i].data.clone()
            flat_target = {
                k: torch.zeros_like(v)
                for k, v in flat_state.items()
            }
            dcp.load(flat_target, checkpoint_id=ckpt_dir)
            dist.barrier()
            # Rebuild a state_dict whose tensors come from the DCP load.
            sd_loaded = copy.deepcopy(sd_off)
            for param_idx, param_state in sd_loaded["state"].items():
                for key in list(param_state.keys()):
                    flat_key = f"state.{param_idx}.{key}"
                    if flat_key in flat_target:
                        param_state[key] = flat_target[flat_key]
            with pytest.raises(
                    RuntimeError,
                    match="turn_off_cpu_offload\\(\\) before checkpoint load"):
                optim_resumed.load_state_dict(copy.deepcopy(sd_loaded))
            optim_resumed.turn_off_cpu_offload()
            optim_resumed.load_state_dict(sd_loaded)
            optim_resumed.turn_on_cpu_offload()
            if rank == 0:
                logger.info("Both optimizers loaded, starting comparison steps")
        finally:
            dist.barrier()
            if rank == 0:
                shutil.rmtree(ckpt_dir, ignore_errors=True)

        # Second half: reference continues, resumed uses loaded state.
        for step_idx in range(num_steps, num_steps * 2):
            for i in range(num_muon):
                g = all_grads_muon[step_idx][i]
                mp_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
                mp_resumed[i].grad = distribute_tensor(g.clone(), mesh,
                                                       [Shard(0)])
            for i in range(num_adamw):
                g = all_grads_adamw[step_idx][i]
                ap_ref[i].grad = distribute_tensor(g.clone(), mesh, [Shard(0)])
                ap_resumed[i].grad = distribute_tensor(g.clone(), mesh,
                                                       [Shard(0)])
            optim_ref.step()
            optim_resumed.step()

        # Compare final params: bitwise exact (DCP preserves DTensor identity).
        for ref_p, res_p in zip(mp_ref + ap_ref, mp_resumed + ap_resumed):
            ref_full = ref_p.data.full_tensor()
            res_full = res_p.data.full_tensor()
            torch.testing.assert_close(ref_full, res_full, atol=0, rtol=0)

        # Verify offload is active on the resumed optimizer; count checked
        # buffers so the verification cannot pass vacuously.
        checked = 0
        for p in mp_resumed:
            state = optim_resumed.state[p]
            if "momentum_buffer" not in state:
                continue
            buf = state["momentum_buffer"]
            local_buf = buf._local_tensor if isinstance(buf, DTensor) else buf
            assert local_buf.untyped_storage().size() == 0, (
                "Resumed optimizer should have offloaded state after step()")
            checked += 1
        assert checked > 0, (
            "Resumed optimizer has no momentum_buffer state to check")
    finally:
        # Always restore the global compile flag for subsequent tests.
        set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_state_dict_save_load")
def test_checkpoint_memory(rank, world_size):
    """Verify checkpoint APIs require offload to be disabled explicitly.

    The test also tracks allocated GPU memory across the offload lifecycle:
    - Turning offload off must reload state onto the GPU.
    - Turning it back on, or stepping with it enabled, must return to the
      offloaded level within ``tolerance``.
    - State loaded while offload is disabled must survive re-enabling
      offload and another step.

    The global Newton-Schulz compile flag is restored even on failure.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    # Slack for CUDA allocator fragmentation when comparing memory levels.
    tolerance = 4 * 1024 * 1024

    set_ns_compile(False)
    try:
        torch.manual_seed(42)
        mesh = _make_mesh(world_size)
        dim0, dim1 = 512, 1024
        num_params = 8
        params, names = [], []
        for i in range(num_params):
            full = torch.randn(dim0, dim1, device="cuda")
            dt = distribute_tensor(full, mesh, [Shard(0)])
            p = torch.nn.Parameter(dt)
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
            params.append(p)
            names.append(f"layer.{i}.weight")
        param_groups = [{
            "params": params,
            "names": names,
            "use_muon": True,
            "lr": 0.02,
            "weight_decay": 0.01,
            "momentum": 0.95,
            "nesterov": True,
            "ns_steps": 5,
            "none_grad": False,
        }]
        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        optim.turn_on_cpu_offload()

        # Step 1: run a step so offload initializes.
        optim.step()
        torch.cuda.synchronize()
        mem_after_step = torch.cuda.memory_allocated()

        # Expected state size: sum of the momentum buffers' sizes. Their GPU
        # storage is freed, so use the size tracked by the offload pool.
        state_bytes = 0
        for p in params:
            state = optim.state[p]
            if "momentum_buffer" in state:
                buf = state["momentum_buffer"]
                state_bytes += optim._cpu_offload_pool._storage_nbytes[id(buf)]
        if rank == 0:
            logger.info(
                "After step (offloaded): GPU alloc=%.2f MB, expected state size=%.2f MB",
                mem_after_step / 1024**2,
                state_bytes / 1024**2,
            )

        # Step 2: state_dict() is refused while offloaded; turning offload
        # off must bring the states back onto the GPU.
        with pytest.raises(
                RuntimeError,
                match="turn_off_cpu_offload\\(\\) before checkpoint save"):
            optim.state_dict()
        optim.turn_off_cpu_offload()
        torch.cuda.synchronize()
        mem_after_turn_off = torch.cuda.memory_allocated()
        # Snapshot used for the load test in step 4.
        sd_for_load = copy.deepcopy(optim.state_dict())
        if rank == 0:
            logger.info(
                "After turn_off_cpu_offload: GPU alloc=%.2f MB",
                mem_after_turn_off / 1024**2,
            )
        assert mem_after_turn_off > mem_after_step, (
            f"turn_off_cpu_offload() should reload states to GPU. "
            f"After offload: {mem_after_step / 1024**2:.2f} MB, "
            f"After turn_off: {mem_after_turn_off / 1024**2:.2f} MB")

        # Step 3: re-enabling offload, and then stepping with it, must
        # return memory to the offloaded level.
        optim.turn_on_cpu_offload()
        torch.cuda.synchronize()
        mem_after_turn_on = torch.cuda.memory_allocated()
        if rank == 0:
            logger.info("After turn_on_cpu_offload: GPU alloc=%.2f MB",
                        mem_after_turn_on / 1024**2)
        assert mem_after_turn_on <= mem_after_step + tolerance, (
            f"turn_on_cpu_offload() should return memory to offloaded level. "
            f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
            f"got {mem_after_turn_on / 1024**2:.2f} MB")
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
        optim.step()
        torch.cuda.synchronize()
        mem_after_next_step = torch.cuda.memory_allocated()
        if rank == 0:
            logger.info(
                "After next step (re-offloaded): GPU alloc=%.2f MB",
                mem_after_next_step / 1024**2,
            )
        assert mem_after_next_step <= mem_after_step + tolerance, (
            f"Memory should return to offloaded level after step(). "
            f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
            f"got {mem_after_next_step / 1024**2:.2f} MB")

        # Step 4: load_state_dict() is refused while offloaded and works once
        # offload is turned off; the loaded state stays on the GPU.
        with pytest.raises(
                RuntimeError,
                match="turn_off_cpu_offload\\(\\) before checkpoint load"):
            optim.load_state_dict(copy.deepcopy(sd_for_load))
        optim.turn_off_cpu_offload()
        optim.load_state_dict(sd_for_load)
        torch.cuda.synchronize()
        mem_after_load = torch.cuda.memory_allocated()
        if rank == 0:
            logger.info(
                "After load_state_dict with offload disabled: GPU alloc=%.2f MB",
                mem_after_load / 1024**2,
            )
        assert mem_after_load >= mem_after_turn_off, (
            "Loaded optimizer state should stay on GPU while offload is disabled")
        optim.turn_on_cpu_offload()
        torch.cuda.synchronize()
        pool = optim._cpu_offload_pool
        assert pool._initialized, (
            "Offload pool should be initialized after re-enabling offload")
        for grp in pool._groups.values():
            assert grp["cpu_flat"].is_pinned(), "CPU buffer must be pinned"

        # Step 5: verify the loaded optimizer can still step correctly.
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
        optim.step()
        torch.cuda.synchronize()
        mem_final = torch.cuda.memory_allocated()
        assert mem_final <= mem_after_step + tolerance, (
            f"Final memory should be at offloaded level. "
            f"Expected <= {mem_after_step / 1024**2:.2f} MB (+4 MB tolerance), "
            f"got {mem_final / 1024**2:.2f} MB")
    finally:
        # Always restore the global compile flag for subsequent tests.
        set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_checkpoint_memory")
def main():
    """Run every CPU-offload test, in order, inside one process group.

    The process group is torn down even if a test raises.
    """
    rank, world_size = _setup()
    try:
        suite = (
            test_correctness,
            test_memory,
            test_adamw_offload,
            test_memory_savings,
            test_toggle_correctness,
            test_leak,
            test_state_dict_save_load,
            test_checkpoint_memory,
        )
        for run_test in suite:
            run_test(rank, world_size)
        if rank == 0:
            banner = "=" * 50
            logger.info(banner)
            logger.info("ALL CPU OFFLOAD TESTS PASSED")
            logger.info(banner)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()