Instructions to use Motif-Technologies/optimizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use Motif-Technologies/optimizer with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("Motif-Technologies/optimizer")
```
- Notebooks
- Google Colab
- Kaggle
| """CPU memory peak vs offloaded tensor size verification. | |
| Compares CPU memory usage with turn_on_cpu_offload() vs no offload to isolate | |
| the actual CPU cost of offloading, separating it from CUDA runtime, | |
| NCCL, and DTensor overhead. | |
| Run with: | |
| torchrun --nproc-per-node=8 --local-ranks-filter=0 test/test_cpu_memory_peak.py | |
| """ | |
| import gc | |
| import logging | |
| import os | |
| import torch | |
| import torch.distributed as dist | |
| from torch.distributed.tensor import DTensor, Shard, distribute_tensor | |
# Module-level logger. basicConfig configures the root logger (a no-op if it
# already has handlers) so INFO messages from this script are printed.
logger = logging.getLogger(__name__)
logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)
def _setup():
    """Initialise the NCCL process group and bind this process to a GPU.

    The CUDA device index is the global rank modulo the number of visible
    devices.

    Returns:
        ``(rank, world_size)`` of the default process group.
    """
    dist.init_process_group(backend="nccl")
    global_rank = dist.get_rank()
    device_index = global_rank % torch.cuda.device_count()
    torch.cuda.set_device(device_index)
    return global_rank, dist.get_world_size()
def _make_mesh(world_size):
    """Build a 1-D CUDA device mesh over ``world_size`` ranks, dim named "dp"."""
    mesh_shape = (world_size,)
    return dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("dp",))
def get_cpu_rss_bytes():
    """Return this process's current resident set size (RSS) in bytes.

    Linux-only: reads the second field of ``/proc/self/statm`` (resident
    pages) and multiplies it by the system page size.
    """
    page_size = os.sysconf("SC_PAGE_SIZE")
    with open("/proc/self/statm") as statm:
        fields = statm.read().split()
    resident_pages = int(fields[1])
    return resident_pages * page_size
def get_pinned_pool_bytes(pool):
    """Return the total byte size of the ``cpu_flat`` buffers in ``pool``.

    Iterates the values of the (private) ``pool._groups`` mapping and sums
    ``numel() * element_size()`` of each group's ``"cpu_flat"`` tensor.
    Returns 0 for a pool with no groups.
    """
    return sum(
        group["cpu_flat"].numel() * group["cpu_flat"].element_size()
        for group in pool._groups.values())
def _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload):
    """Run Muon optimizer steps on fresh sharded params and measure CPU RSS.

    Creates ``num_params`` random ``(dim0, dim1)`` parameters sharded along
    dim 0 over ``mesh`` and builds a single-group Muon optimizer, optionally
    with ``turn_on_cpu_offload()``. It then runs ``num_steps`` steps with fresh
    random gradients. Newton-Schulz compilation is disabled for the run and
    re-enabled afterwards, even if the run raises.

    Args:
        mesh: 1-D device mesh that parameters and gradients are sharded over.
        dim0, dim1: Global (unsharded) parameter shape.
        num_params: Number of parameters to create.
        num_steps: Number of ``optim.step()`` calls.
        cpu_offload: Whether to call ``turn_on_cpu_offload()``.

    Returns:
        ``(cpu_rss, pinned_bytes)``. ``cpu_rss`` is the process RSS in bytes,
        measured after the last step and before cleanup. ``pinned_bytes`` is
        the total size of the offload pool's ``cpu_flat`` buffers; it is 0
        when offload is off or no pool exists.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    try:
        return _muon_steps_and_measure(Muon, mesh, dim0, dim1, num_params,
                                       num_steps, cpu_offload)
    finally:
        # The helper's frame is gone, so its params, grads and optimizer state
        # (including any pinned pool) are unreachable. Collect them and release
        # cached CUDA blocks before the caller's next measurement. The original
        # `del` left loop variables pinning the last param alive.
        gc.collect()
        torch.cuda.empty_cache()
        set_ns_compile(True)


def _muon_steps_and_measure(muon_cls, mesh, dim0, dim1, num_params, num_steps,
                            cpu_offload):
    """Body of ``_run_muon_steps``; everything it allocates dies on return."""
    torch.manual_seed(42)
    gc.collect()
    torch.cuda.empty_cache()

    params = [
        torch.nn.Parameter(
            distribute_tensor(torch.randn(dim0, dim1, device="cuda"), mesh,
                              [Shard(0)])) for _ in range(num_params)
    ]
    names = [f"layer.{i}.weight" for i in range(num_params)]
    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = muon_cls(params=param_groups, chunk_size=2, warmup_step=1)
    if cpu_offload:
        optim.turn_on_cpu_offload()

    for _ in range(num_steps):
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])
        optim.step()
    torch.cuda.synchronize()
    gc.collect()
    cpu_rss = get_cpu_rss_bytes()

    pinned_bytes = 0
    if cpu_offload and optim._cpu_offload_pool is not None:
        pinned_bytes = get_pinned_pool_bytes(optim._cpu_offload_pool)
    return cpu_rss, pinned_bytes
def test_offload_cpu_cost_isolation(rank, world_size):
    """A/B test: measure CPU cost of offload by comparing ON vs OFF.

    After a discarded warm-up, ``_run_muon_steps`` is run once without and
    once with ``turn_on_cpu_offload()``. The difference in RSS growth between
    the two runs is attributed to offloading and compared against the pinned
    pool size. A warning is logged when the ratio exceeds 1.5x, and rank 0
    fails the assertion when the ratio is >= 3.0x.

    Args:
        rank: Global rank; only rank 0 logs and asserts.
        world_size: Number of ranks; sizes the 1-D device mesh.
    """
    mesh = _make_mesh(world_size)
    dim0, dim1 = 2048, 4096
    num_params = 8
    num_steps = 3
    to_mb = 1024**2
    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU MEMORY COST OF OFFLOAD (ON vs OFF)")
        logger.info("=" * 70)
        logger.info("Config: %d params of shape (%d, %d), %d ranks, %d steps",
                    num_params, dim0, dim1, world_size, num_steps)
        logger.info("Local param shape per rank: (%d, %d)", dim0 // world_size,
                    dim1)
        logger.info("-" * 70)

    def rss_growth(cpu_offload):
        # RSS growth across one _run_muon_steps call, plus its pinned size.
        gc.collect()
        torch.cuda.empty_cache()
        before = get_cpu_rss_bytes()
        after, pinned = _run_muon_steps(mesh, dim0, dim1, num_params,
                                        num_steps, cpu_offload=cpu_offload)
        return after - before, pinned

    # Discarded warm-up. _run_muon_steps imports the optimizer lazily on its
    # first call, and first use of the runtime may allocate one-time host-side
    # state. Without a warm-up, those costs are charged to the baseline run
    # below. That inflates its growth and shrinks the measured offload delta,
    # which makes the result depend on which tests ran before this one.
    _run_muon_steps(mesh, dim0, dim1, num_params, num_steps, cpu_offload=False)

    # Run WITHOUT offload first (baseline), then WITH offload.
    cpu_growth_no_offload, _ = rss_growth(cpu_offload=False)
    cpu_growth_offload, pinned_bytes = rss_growth(cpu_offload=True)

    # Delta = additional CPU cost from offloading.
    offload_delta = cpu_growth_offload - cpu_growth_no_offload
    ratio = offload_delta / pinned_bytes if pinned_bytes > 0 else None

    # Only log/assert on rank 0 to avoid multi-rank assertion mismatches.
    if rank != 0:
        return
    logger.info("CPU growth WITHOUT offload: %.2f MB",
                cpu_growth_no_offload / to_mb)
    logger.info("CPU growth WITH offload: %.2f MB",
                cpu_growth_offload / to_mb)
    logger.info("-" * 70)
    logger.info("Pinned buffer size (expected): %.2f MB",
                pinned_bytes / to_mb)
    logger.info("Offload delta (WITH - WITHOUT): %.2f MB",
                offload_delta / to_mb)
    if ratio is not None:
        logger.info("Ratio (delta / pinned buffer): %.2fx", ratio)
        if ratio > 1.5:
            logger.warning(
                "Offload adds %.2f MB CPU memory but pinned buffer is "
                "only %.2f MB (%.1f%% overhead beyond expected)",
                offload_delta / to_mb,
                pinned_bytes / to_mb,
                (offload_delta - pinned_bytes) / pinned_bytes * 100,
            )
        else:
            logger.info("Offload CPU cost is within expected range.")
        assert ratio < 3.0, (
            f"Offload CPU cost ({offload_delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB). "
            f"Expected < 3.0x.")
    logger.info("PASSED: test_offload_cpu_cost_isolation")
def test_cpu_memory_peak_detailed(rank, world_size):
    """Detailed per-phase CPU memory tracking for offload.

    Logs RSS (rank 0 only) at each phase:
    0. baseline;
    1. after creating sharded params;
    2. after building Muon with ``turn_on_cpu_offload()``;
    3. after creating grads;
    4-6. after each of three ``step()`` calls.

    Growth is reported next to the offload pool's pinned buffer size. The test
    makes no assertions. The printed overhead is not offload-only, as the
    summary note it logs explains.

    Args:
        rank: Global rank of this process; only rank 0 logs.
        world_size: Number of ranks; sizes the 1-D device mesh.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    set_ns_compile(False)
    try:
        _log_detailed_phases(Muon, rank, world_size)
    finally:
        # Restore the global flag even if a phase raises, so later tests in
        # this process do not silently run with compilation disabled.
        set_ns_compile(True)
    if rank == 0:
        logger.info("PASSED: test_cpu_memory_peak_detailed")


def _log_detailed_phases(muon_cls, rank, world_size):
    """Run and log the phases of ``test_cpu_memory_peak_detailed``."""
    torch.manual_seed(42)
    mesh = _make_mesh(world_size)
    dim0, dim1 = 2048, 4096
    num_params = 8
    to_mb = 1024**2

    def log(msg, *args):
        # Only rank 0 reports, so torchrun prints a single copy.
        if rank == 0:
            logger.info(msg, *args)

    def sampled_rss():
        # Collect first so unreachable Python objects are not counted.
        gc.collect()
        return get_cpu_rss_bytes()

    def set_random_grads(params):
        for p in params:
            p.grad = distribute_tensor(torch.randn(dim0, dim1, device="cuda"),
                                       mesh, [Shard(0)])

    def step_and_sample(optim):
        optim.step()
        torch.cuda.synchronize()
        return sampled_rss()

    gc.collect()
    torch.cuda.empty_cache()
    log("=" * 70)
    log("DETAILED PER-PHASE CPU MEMORY TRACKING")
    log("=" * 70)
    cpu_0 = get_cpu_rss_bytes()
    log("[Phase 0] Baseline RSS: %.2f MB", cpu_0 / to_mb)

    # Create params.
    params = [
        torch.nn.Parameter(
            distribute_tensor(torch.randn(dim0, dim1, device="cuda"), mesh,
                              [Shard(0)])) for _ in range(num_params)
    ]
    names = [f"layer.{i}.weight" for i in range(num_params)]
    cpu_1 = sampled_rss()
    log("[Phase 1] After param creation: %.2f MB (+%.2f MB)",
        cpu_1 / to_mb, (cpu_1 - cpu_0) / to_mb)

    # Create optimizer.
    param_groups = [{
        "params": params,
        "names": names,
        "use_muon": True,
        "lr": 0.02,
        "weight_decay": 0.01,
        "momentum": 0.95,
        "nesterov": True,
        "ns_steps": 5,
        "none_grad": False,
    }]
    optim = muon_cls(params=param_groups, chunk_size=2, warmup_step=1)
    optim.turn_on_cpu_offload()
    cpu_2 = sampled_rss()
    log("[Phase 2] After optimizer creation: %.2f MB (+%.2f MB)",
        cpu_2 / to_mb, (cpu_2 - cpu_1) / to_mb)

    # Set grads.
    set_random_grads(params)
    cpu_3 = sampled_rss()
    log("[Phase 3] After grad creation: %.2f MB (+%.2f MB)",
        cpu_3 / to_mb, (cpu_3 - cpu_2) / to_mb)

    # Step 1 (creates states + first offload).
    cpu_4 = step_and_sample(optim)
    pool = optim._cpu_offload_pool
    # Treat a missing pool as 0 pinned bytes (as the A/B tests do) instead of
    # failing with AttributeError inside get_pinned_pool_bytes.
    pinned_bytes = get_pinned_pool_bytes(pool) if pool is not None else 0
    log("[Phase 4] After step 1 (init+offload): %.2f MB (+%.2f MB)",
        cpu_4 / to_mb, (cpu_4 - cpu_3) / to_mb)
    log(" Pinned buffer size: %.2f MB", pinned_bytes / to_mb)
    log(" Step 1 growth vs pinned: %.2f MB extra",
        (cpu_4 - cpu_3 - pinned_bytes) / to_mb)

    # Step 2 (reload + compute + offload).
    set_random_grads(params)
    cpu_5 = step_and_sample(optim)
    log("[Phase 5] After step 2: %.2f MB (+%.2f MB)",
        cpu_5 / to_mb, (cpu_5 - cpu_4) / to_mb)

    # Step 3.
    set_random_grads(params)
    cpu_6 = step_and_sample(optim)
    log("[Phase 6] After step 3: %.2f MB (+%.2f MB)",
        cpu_6 / to_mb, (cpu_6 - cpu_5) / to_mb)

    # Summary.
    total_growth = cpu_6 - cpu_0
    log("-" * 70)
    log("SUMMARY:")
    log(" Total CPU growth: %.2f MB", total_growth / to_mb)
    log(" Pinned buffer: %.2f MB", pinned_bytes / to_mb)
    log(" Overhead: %.2f MB", (total_growth - pinned_bytes) / to_mb)
    if pinned_bytes > 0:
        log(" Ratio: %.2fx", total_growth / pinned_bytes)
    log("")
    log(" NOTE: Overhead includes CUDA runtime, NCCL buffers,")
    log(" DTensor metadata, and optimizer internals — NOT just")
    log(" offload cost. Use A/B test for isolated measurement.")
def test_offload_cpu_cost_mixed(rank, world_size):
    """A/B test for mixed Muon + AdamW offload CPU cost.

    Builds one ``use_muon=True`` group of 2-D params and one
    ``use_muon=False`` (AdamW) group of 1-D params. After a discarded warm-up,
    the optimizer is run with and without ``turn_on_cpu_offload()``, and the
    difference in RSS growth is attributed to offloading. Rank 0 fails if
    that delta is >= 3.0x the pinned pool size.

    Args:
        rank: Global rank; only rank 0 logs and asserts.
        world_size: Number of ranks; sizes the 1-D device mesh.
    """
    from optimizer.muon import Muon
    from optimizer.newton_schulz import set_ns_compile

    mesh = _make_mesh(world_size)
    muon_dim0, muon_dim1 = 2048, 4096
    num_muon = 8
    adamw_dim = 4096
    num_adamw = 8
    num_steps = 3
    to_mb = 1024**2

    def sharded_randn(*shape):
        # Random full tensor on the current GPU, sharded along dim 0.
        return distribute_tensor(torch.randn(*shape, device="cuda"), mesh,
                                 [Shard(0)])

    def build_and_step(cpu_offload):
        # Everything allocated here is local, so it becomes unreachable as
        # soon as this function returns.
        torch.manual_seed(42)
        gc.collect()
        torch.cuda.empty_cache()
        muon_params = [
            torch.nn.Parameter(sharded_randn(muon_dim0, muon_dim1))
            for _ in range(num_muon)
        ]
        muon_names = [f"layer.{i}.weight" for i in range(num_muon)]
        adamw_params = [
            torch.nn.Parameter(sharded_randn(adamw_dim))
            for _ in range(num_adamw)
        ]
        param_groups = [
            {
                "params": muon_params,
                "names": muon_names,
                "use_muon": True,
                "lr": 0.02,
                "weight_decay": 0.01,
                "momentum": 0.95,
                "nesterov": True,
                "ns_steps": 5,
                "none_grad": False,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
            {
                "params": adamw_params,
                "use_muon": False,
                "lr": 1e-3,
                "weight_decay": 0.01,
                "adamw_betas": (0.9, 0.95),
                "adamw_eps": 1e-8,
            },
        ]
        optim = Muon(params=param_groups, chunk_size=2, warmup_step=1)
        if cpu_offload:
            optim.turn_on_cpu_offload()
        for _ in range(num_steps):
            for p in muon_params:
                p.grad = sharded_randn(muon_dim0, muon_dim1)
            for p in adamw_params:
                p.grad = sharded_randn(adamw_dim)
            optim.step()
        torch.cuda.synchronize()
        gc.collect()
        cpu_rss = get_cpu_rss_bytes()
        pinned_bytes = 0
        if cpu_offload and optim._cpu_offload_pool is not None:
            pinned_bytes = get_pinned_pool_bytes(optim._cpu_offload_pool)
        return cpu_rss, pinned_bytes

    def run_mixed(cpu_offload):
        set_ns_compile(False)
        try:
            return build_and_step(cpu_offload)
        finally:
            # build_and_step's tensors are unreachable now; release them and
            # restore the global flag even if the run raised.
            gc.collect()
            torch.cuda.empty_cache()
            set_ns_compile(True)

    def rss_growth(cpu_offload):
        # RSS growth across one run_mixed call, plus its pinned size.
        gc.collect()
        torch.cuda.empty_cache()
        before = get_cpu_rss_bytes()
        after, pinned = run_mixed(cpu_offload)
        return after - before, pinned

    if rank == 0:
        logger.info("=" * 70)
        logger.info("A/B TEST: CPU COST OF MIXED OFFLOAD (Muon + AdamW)")
        logger.info("=" * 70)

    # Discarded warm-up, so one-time first-use costs are not charged to the
    # baseline run (which would under-report the offload delta).
    run_mixed(False)

    growth_no, _ = rss_growth(False)
    growth_yes, pinned_bytes = rss_growth(True)
    delta = growth_yes - growth_no

    if rank != 0:
        return
    logger.info("CPU growth WITHOUT offload: %.2f MB", growth_no / to_mb)
    logger.info("CPU growth WITH offload: %.2f MB", growth_yes / to_mb)
    logger.info("Pinned buffer size: %.2f MB", pinned_bytes / to_mb)
    logger.info("Offload delta: %.2f MB", delta / to_mb)
    if pinned_bytes > 0:
        ratio = delta / pinned_bytes
        logger.info("Ratio (delta / pinned): %.2fx", ratio)
        assert ratio < 3.0, (
            f"Mixed offload CPU cost ({delta / 1024**2:.2f} MB) is "
            f"{ratio:.2f}x the pinned buffer ({pinned_bytes / 1024**2:.2f} MB)."
        )
    logger.info("PASSED: test_offload_cpu_cost_mixed")
def test_pinned_memory_rss_overhead(rank, world_size):
    """Isolate: does cudaHostAlloc itself cause 2x RSS overhead?

    For each size in ``sizes_mb``, logs (rank 0 only) the RSS growth of three
    allocations, each also as a multiple of the requested size:

    1. ``torch.empty(..., pin_memory=True)``;
    2. ``torch.empty(...).pin_memory()``;
    3. a regular ``torch.empty(...)`` whose pages are touched with ``fill_``.

    Nothing is asserted. ``world_size`` is unused and kept for a uniform test
    signature.
    """
    sizes_mb = [8, 16, 32, 64, 128]
    if rank == 0:
        logger.info("=" * 70)
        logger.info("ISOLATED TEST: PINNED MEMORY RSS OVERHEAD")
        logger.info("=" * 70)

    def settle():
        gc.collect()
        torch.cuda.empty_cache()

    for size_mb in sizes_mb:
        size_bytes = size_mb * 1024 * 1024
        numel = size_bytes // 4  # float32

        # Test 1: pin_memory=True (direct allocation).
        settle()
        rss_before = get_cpu_rss_bytes()
        t1 = torch.empty(numel,
                         dtype=torch.float32,
                         device="cpu",
                         pin_memory=True)
        rss_growth_direct = get_cpu_rss_bytes() - rss_before

        # Test 2: .pin_memory() (copy-based).
        # t1 is deliberately kept alive here. PyTorch's pinned host allocator
        # caches freed blocks for reuse, and torch.cuda.empty_cache() only
        # releases the device cache. Freeing t1 first would therefore let
        # this same-sized request reuse t1's already-resident block, and
        # Test 2 would report ~0 growth.
        # NOTE(review): confirm the caching behaviour for the torch version in
        # use. Sizes strictly increase, so blocks cached from earlier
        # iterations are too small to satisfy this request.
        settle()
        rss_before = get_cpu_rss_bytes()
        t2 = torch.empty(numel, dtype=torch.float32, device="cpu").pin_memory()
        rss_growth_copy = get_cpu_rss_bytes() - rss_before
        del t1, t2
        gc.collect()

        # Test 3: regular (non-pinned) CPU allocation.
        settle()
        rss_before = get_cpu_rss_bytes()
        t3 = torch.empty(numel, dtype=torch.float32, device="cpu")
        # Touch all pages to ensure RSS reflects actual allocation.
        t3.fill_(1.0)
        rss_growth_regular = get_cpu_rss_bytes() - rss_before
        del t3
        gc.collect()

        if rank == 0:
            logger.info(
                "%3d MB: pin_memory=True → RSS +%.1f MB (%.2fx) | "
                ".pin_memory() → RSS +%.1f MB (%.2fx) | "
                "regular → RSS +%.1f MB (%.2fx)",
                size_mb,
                rss_growth_direct / 1024**2,
                rss_growth_direct / size_bytes,
                rss_growth_copy / 1024**2,
                rss_growth_copy / size_bytes,
                rss_growth_regular / 1024**2,
                rss_growth_regular / size_bytes,
            )
    if rank == 0:
        logger.info("PASSED: test_pinned_memory_rss_overhead")
def main():
    """Set up distributed state, run every CPU-memory test, then tear down.

    The process group is destroyed even if a test raises.
    """
    rank, world_size = _setup()
    try:
        test_fns = (
            test_pinned_memory_rss_overhead,
            test_cpu_memory_peak_detailed,
            test_offload_cpu_cost_isolation,
            test_offload_cpu_cost_mixed,
        )
        for test_fn in test_fns:
            test_fn(rank, world_size)
        if rank == 0:
            banner = "=" * 50
            logger.info(banner)
            logger.info("ALL CPU MEMORY PEAK TESTS PASSED")
            logger.info(banner)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()