# Muon Optimizer: Implementation Guide

This document explains the internal architecture of the Muon optimizer for reviewers and new contributors. It covers the execution paths, the parallel pipeline design, and the distributed sharding utilities.

## Table of Contents

1. [Overview](#overview)
2. [Entry Point and Parameter Routing](#entry-point-and-parameter-routing)
3. [Execution Paths](#execution-paths)
4. [Parallel Pipeline (the core feature)](#parallel-pipeline)
5. [MoE Expert Weight Support](#moe-expert-weight-support-expert_keys)
6. [Distributed Utilities](#distributed-utilities)
7. [Newton-Schulz Orthogonalization](#newton-schulz-orthogonalization)
8. [QK Clipping](#qk-clipping)
9. [AdamW for Non-Muon Parameters](#adamw-for-non-muon-parameters)
10. [Source File Map](#source-file-map)

---

## Overview

Muon (MomentUm Orthogonalized by Newton-schulz) applies standard SGD-momentum and then replaces each 2D parameter's update with the nearest semi-orthogonal matrix via a Newton-Schulz iteration. The iteration runs stably in bfloat16 on GPU.

The optimizer supports arbitrary N-D sharding configurations: FSDP2, TP, or hybrid setups like `2 TP x 2 DP-Replicate x 2 DP-Shard`. This generality is what drives most of the code complexity.

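For concreteness, a hybrid layout like the one above corresponds to a 3D device mesh. A minimal sketch, assuming 8 GPUs; the mesh dimension names here are illustrative, not something this codebase defines:

```python
# Illustrative only: a 3D mesh for 2 TP x 2 DP-Replicate x 2 DP-Shard on 8 GPUs.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(
    "cuda", (2, 2, 2), mesh_dim_names=("dp_replicate", "dp_shard", "tp")
)
```
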
## Entry Point and Parameter Routing

**File:** `muon.py` — `Muon.step()` / `Muon._step_muon()`

Users must provide parameter groups with `use_muon=True/False` flags (via `get_default_muon_param_groups()`). At each step:

1. **Non-Muon groups** → `step_adamw()` (fused AdamW).
2. **Muon groups** → `_step_muon()`, which further classifies each parameter:

```
_step_muon(group)
|
+-- momentum update (batched _foreach_* ops)
+-- _expand_expert_params() -- 3D expert params → per-expert 2D views (cached)
|
+-- DTensor, all Replicate placements --> base() (no sharding)
+-- DTensor, sharded --> parallel() (pipelined all-to-all)
+-- plain Tensor --> base() (single device)
```

Parameters are classified by their DTensor placements:
- **Fully replicated** DTensors and plain tensors use `base()` — standard single-device Muon.
- **Sharded** DTensors use `parallel()` — the pipelined all-to-all approach described below.
- `distributed_muon()` exists as a **test-only reference implementation** for correctness verification.

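Stepping back to the entry point: the `use_muon` convention might be built roughly as follows. This is a hedged sketch — the actual selection logic in `get_default_muon_param_groups()` (e.g. how embeddings are detected) may differ:

```python
# Hypothetical sketch of a use_muon split, not the library's implementation.
def sketch_param_groups(model):
    muon, other = [], []
    for name, p in model.named_parameters():
        # Muon targets 2D+ weight matrices; 1D params and embeddings go to AdamW.
        (muon if p.ndim >= 2 and "embed" not in name else other).append(p)
    return [
        {"params": muon, "use_muon": True},
        {"params": other, "use_muon": False},
    ]
```
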
## Execution Paths

### base() — Single Device

Straightforward per-parameter loop: momentum update → Newton-Schulz orthogonalization → parameter update → optional QK clipping.

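In pseudocode terms, the loop body looks roughly like the sketch below. The hyperparameter names and the Nesterov convention are assumptions, not taken from the source:

```python
# Hedged sketch of the base() loop body for one 2D parameter `p`.
buf = state["momentum_buffer"]
buf.lerp_(p.grad, 1 - momentum)                   # SGD momentum
update = p.grad.lerp(buf, momentum) if nesterov else buf
update = zeropower_via_newtonschulz5(update)      # orthogonalize
p.mul_(1 - lr * weight_decay)                     # decoupled weight decay
p.add_(update, alpha=-lr)
```
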
### distributed_muon() — Full Gather (test-only)

Reference implementation for correctness verification. Uses batched all-gather to reconstruct full tensors, computes Newton-Schulz on the full gradient, then slices back to local shards. Simple but communication-heavy — not used in production.

### parallel() — Pipelined All-to-All

This is the main advanced feature. Instead of all-gathering the full parameter, it uses **all-to-all** to distribute work: each rank "owns" a subset of parameters and is responsible for their Newton-Schulz computation.

## Parallel Pipeline

### Design Motivation

Newton-Schulz is compute-intensive. The key insight is that each rank only needs to orthogonalize the parameters it "owns" — not all parameters. So the flow is:

1. **Gather**: Each rank sends its local gradient shard to the owning rank via all-to-all.
2. **Compute**: The owning rank runs Newton-Schulz on the full (gathered) gradient.
3. **Scatter**: The owning rank sends the orthogonalized update back to all ranks via all-to-all.
4. **Update**: Each rank applies weight decay and the update to its local shard.

To overlap communication and computation, parameters are split into **chunks**, and multiple chunks are processed concurrently.

### Architecture

```
muon.py: parallel()
|
+-- init_state_and_assign_params() -- assigns ownership, precomputes indices
|
+-- pipelines() generator -- yields muon_chunk_pipeline() per chunk
|
+-- run_pipeline(pipelines, max_concurrent=warmup_step+1)
        |
        +-- interleaves chunks at yield boundaries
```

### The Chunk Pipeline Generator

**File:** `pipeline.py` — `muon_chunk_pipeline()`

Each chunk is a generator that yields **twice**, creating three stages separated by async communication:

```
                              YIELD 1                                      YIELD 2
                                 |                                            |
[Build bufs + async gather a2a] --> [wait + NS compute + async scatter a2a] --> [wait + Update params]
```

- **Async communication**: `dist.all_to_all_single(..., async_op=True)` launches non-blocking communication. The generator yields immediately after, allowing other chunks to run. `work.wait()` completes the operation after the yield.
- **Chunk-level overlap**: `run_pipeline()` interleaves multiple chunks at yield boundaries, so while chunk N waits for its communication, chunk N+1 can launch its own.

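A simplified sketch of the generator's control flow. The helper names (`build_gather_buffers`, `reconstruct_full`, `build_scatter_buffers`, `apply_update`) are hypothetical stand-ins for the real buffer-management code in `pipeline.py`:

```python
import torch.distributed as dist

def muon_chunk_pipeline(chunk, state, group):
    send, recv = build_gather_buffers(chunk, state)       # hypothetical helper
    work = dist.all_to_all_single(recv, send, group=group, async_op=True)
    yield                                 # YIELD 1: other chunks may launch comms
    work.wait()
    full_grad = reconstruct_full(recv, state)             # hypothetical helper
    U = _zeropower_via_newtonschulz5(full_grad)           # owning rank only
    send2, recv2 = build_scatter_buffers(U, state)        # hypothetical helper
    work = dist.all_to_all_single(recv2, send2, group=group, async_op=True)
    yield                                 # YIELD 2: overlap scatter with others
    work.wait()
    apply_update(recv2, state)            # weight decay + parameter update
```
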
### The Pipeline Scheduler

**File:** `async_utils.py` — `run_pipeline()`

A simple round-robin scheduler, shown here in simplified, runnable form:

```python
def run_pipeline(pipelines, max_concurrent):
    previous_tasks, have_new = [], True
    while have_new or previous_tasks:
        # Admit one new pipeline if below the concurrency limit
        if have_new and len(previous_tasks) < max_concurrent:
            try:
                previous_tasks.append(next(pipelines))  # new chunk generator
            except StopIteration:
                have_new = False
        # Advance every in-flight task to its next yield
        for task in list(previous_tasks):
            try:
                next(task)
            except StopIteration:
                previous_tasks.remove(task)  # chunk finished
```

`max_concurrent = warmup_step + 1` controls how many chunks can be in flight simultaneously. Higher values increase memory usage but improve communication/computation overlap.

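A toy usage example (not from the test suite) showing how two-yield generators interleave under the simplified scheduler above:

```python
# Toy pipelines with the same two-yield shape as muon_chunk_pipeline().
def toy_chunk(i):
    print(f"chunk {i}: launch gather")
    yield
    print(f"chunk {i}: compute + launch scatter")
    yield
    print(f"chunk {i}: update")

# With max_concurrent=2, chunk 1's gather launches while chunk 0 is mid-flight.
run_pipeline((toy_chunk(i) for i in range(4)), max_concurrent=2)
```
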
### Ownership Assignment

**File:** `muon.py` — `init_state_and_assign_params()`

Parameters are sorted by FLOP cost (descending) and assigned to ranks in round-robin order across the shard mesh. This balances compute load across ranks.

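In sketch form — `flop_cost`, `states`, and `shard_world_size` are assumed names, not the actual identifiers in `muon.py`:

```python
# Hedged sketch: sort by descending cost, then deal out round-robin.
params_sorted = sorted(params, key=flop_cost, reverse=True)
for i, p in enumerate(params_sorted):
    states[p].worker_rank = i % shard_world_size
```
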
### Precomputed Shard Indices

Instead of computing per-rank shard indices on every step, they are precomputed once during `init_state_and_assign_params()` and stored in `_muon_state`:

```python
@dataclass
class _muon_state:
    worker_rank: int                # which rank owns this param's computation
    process_group: ProcessGroup     # the all-to-all communication group
    rank_indices: dict[int, tuple]  # rank -> per-dim indices into full tensor
    rank_numels: dict[int, int]     # rank -> number of elements in shard
    name: str
    qk_clip_state: QKClipInfo | None
```

`rank_indices[r]` is a tuple of `slice` or `torch.Tensor` per dimension, describing which elements of the full tensor rank `r` owns. `rank_numels[r]` is the total number of elements in that shard. These are used directly in the pipeline's gather and scatter stages.

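For example, for a `(16, 2048)` weight sharded `Shard(0)` over 4 ranks, the precomputed state might hold values like these (illustrative only):

```python
import torch

# Illustrative values for Shard(0) over 4 ranks on a (16, 2048) tensor.
rank_indices = {r: (slice(4 * r, 4 * (r + 1)), slice(None)) for r in range(4)}
rank_numels = {r: 4 * 2048 for r in range(4)}

full_grad = torch.randn(16, 2048)
block = full_grad[rank_indices[1]]  # the (4, 2048) block owned by rank 1
assert block.shape == (4, 2048) and block.numel() == rank_numels[1]
```
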
### Pipeline Stages in Detail

#### Stages 1-2: Gather

1. **Allocate** receive buffers for gathered gradients (only on owning ranks).
2. **Build send buffer**: Each rank flattens its local gradient shard for each destination rank.
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches the gather.
4. **Yield 1**: Other chunks can launch their gather while this one waits.
5. **`work.wait()`**: Complete the gather.
6. **Reconstruct**: The owning rank places received shards into the full gradient using `rank_indices` (see the sketch below).

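A hedged sketch of step 6, assuming the receive buffer concatenates each rank's flattened shard in rank order; `shard_shape()` is a hypothetical helper that recovers a shard's original shape:

```python
# Owning rank scatters the flat receive buffer back into the full gradient.
offset = 0
for r in range(world_size):
    n = state.rank_numels[r]
    shard = recv_buf[offset:offset + n]
    full_grad[state.rank_indices[r]] = shard.view(shard_shape(r))
    offset += n
```
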
#### Stage 3: Compute

The owning rank runs `_zeropower_via_newtonschulz5()` on the full gathered gradient. This is the most compute-intensive stage. It runs inline (no yield) since it is synchronous GPU work.

#### Stages 4-5: Scatter

The inverse of the gather:

1. **Allocate** receive buffers for the orthogonalized update `U`.
2. **Build send buffer**: The owning rank slices `U` using `rank_indices` for each destination rank (see the sketch below).
3. **Async all-to-all**: `dist.all_to_all_single(..., async_op=True)` launches the scatter.
4. **Yield 2**: Other chunks can launch their scatter while this one waits.
5. **`work.wait()`**: Complete the scatter.
6. **Copy** received shards into local update buffers.

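A one-line sketch of step 2, under the same buffer-layout assumption as the gather sketch above:

```python
# The owning rank packs each destination rank's slice of U, flattened, in rank order.
send_buf = torch.cat([U[state.rank_indices[r]].reshape(-1) for r in range(world_size)])
```
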
#### Stage 6: Update

Each rank applies weight decay and the Muon update to its local parameter shard. It also applies QK clipping if configured.

## MoE Expert Weight Support (`expert_keys`)

**File:** `muon.py` — `_expand_expert_params()`

MoE models have 3D expert weights with shape `(num_experts, out_dim, in_dim)`. Since Muon operates on 2D matrices, expert params need special handling.

### Configuration

Pass `expert_keys` to both `get_default_muon_param_groups()` and `Muon()`:

```python
params = get_default_muon_param_groups(model, expert_keys=["experts"])
optim = Muon(params, expert_keys=["experts"], ...)
```

Any parameter whose name contains a string in `expert_keys` is treated as an expert-parallel parameter. Parameters with 3 or more dimensions that match no key raise an `AssertionError`, to catch misconfiguration.

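The matching rule is plain substring containment; roughly (a sketch, not the source):

```python
# Sketch of the routing rule described above.
is_expert = any(key in param_name for key in expert_keys)
assert is_expert or param.ndim <= 2, (
    f"{param_name}: parameters with 3+ dims must match an expert key"
)
```
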
### How It Works

`_expand_expert_params()` runs after the momentum update and before routing to `base()`/`parallel()`/`distributed_muon()`:

1. **Split on dim 0**: A 3D `(E, out, in)` tensor becomes `E` separate 2D `(out, in)` `nn.Parameter` views. Views share storage with the original, so in-place updates propagate back (see the demonstration below).
2. **Placement remapping**: When the original is a DTensor, `Shard(k)` on dim `k > 0` becomes `Shard(k-1)` on the 2D slice (since dim 0 is consumed by the split).
3. **Submesh wrapping**: Non-dim-0 shard placements are preserved by wrapping each 2D slice as a DTensor on the corresponding submesh. This is **placement-agnostic** — the same logic handles TP `Shard(1)`/`Shard(2)`, EFSDP `Shard(1)`, or any other non-dim-0 sharding.

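The storage-sharing property in step 1 is plain PyTorch view semantics; a self-contained demonstration (not code from `muon.py`):

```python
import torch
from torch import nn

W = nn.Parameter(torch.randn(8, 256, 512))  # (E, out, in) expert weight
views = [W[e] for e in range(W.shape[0])]   # E separate 2D views, shared storage
with torch.no_grad():
    views[0].mul_(0.5)                      # in-place update on the view...
assert torch.equal(W.data[0], views[0])     # ...is visible through the original
```
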
### Placement-Agnostic Design

The expansion logic does not care *why* a dimension is sharded — only whether it's on dim 0 (consumed by the split) or not (preserved on a submesh):

| Original Placement | After Expansion |
|-------------------|-----------------|
| `Shard(0)` (EP) | Consumed by split → plain tensor |
| `Shard(1)` (TP or EFSDP) | `Shard(0)` on submesh → 2D DTensor |
| `Shard(2)` (TP row-wise) | `Shard(1)` on submesh → 2D DTensor |
| `Replicate` | Ignored (not a shard) |
| `_StridedShard(0)` (EFSDP) | Consumed by split → plain tensor |

After expansion, the 2D params flow through the standard routing: DTensors with shard placements go to `parallel()`, plain tensors go to `base()`.

For EP/EFSDP background and torchtitan integration details, see [`docs/expert_parallel.md`](expert_parallel.md).

## Distributed Utilities

**File:** `distributed/utils.py`

These utilities solve the problem of mapping from a DTensor's arbitrary sharding configuration to the concrete indices each rank owns.

### `construct_shard_mesh(placements, mesh)`

Given a DTensor's placements and device mesh, this function:

1. **Sorts** placements: Replicate dims first, then Shard dims by dimension (with `_StridedShard` before regular `Shard` on the same dim, so the outer sharding is applied first).
2. **Permutes** the mesh accordingly.
3. **Separates** replicate dims from shard dims — each replicate group gets its own shard sub-mesh.
4. **Creates** a ProcessGroup for the current rank's shard mesh.

Returns `(shard_mesh, process_group, shard_placements)` — used for all-to-all communication.

**Why this is needed:** A model might use `[Replicate, Shard(0), _StridedShard(0)]` across a 3D mesh. The optimizer needs to identify which ranks participate in the same shard group (share the same data) and create a ProcessGroup for them.

### `get_slices_of_dtensor(target, local_rank, shard_mesh, shard_placements)`

Computes the exact indices that a given rank owns in the full tensor. Handles both contiguous (`Shard`) and strided (`_StridedShard`) sharding, including composed multi-level sharding on the same dimension.

Returns a tuple of `slice` (contiguous) or `torch.LongTensor` (strided) per dimension.

**Example:** With `[Shard(0), _StridedShard(0)]` on a (16, 2048) tensor across 4 ranks:
- Rank 0 might own rows `[0, 4, 8, 12]` (strided)
- Rank 1 might own rows `[1, 5, 9, 13]`
- etc.

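This interleaved pattern is easy to reproduce with a reshape and transpose (illustration only; the actual index computation lives in `get_slices_of_dtensor()`):

```python
import torch

# 16 rows distributed across 4 ranks in the interleaved pattern above.
rows = torch.arange(16).view(4, 4).t()  # rows[r] = global rows owned by rank r
print(rows[0])  # tensor([ 0,  4,  8, 12])
print(rows[1])  # tensor([ 1,  5,  9, 13])
```
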
### PyTorch 2.10 Compatibility

In PyTorch 2.10, `_StridedShard` no longer inherits from `Shard`. The helper `_is_shard()` handles both the old and new hierarchies:

```python
def _is_shard(placement):
    return isinstance(placement, (Shard, _StridedShard))
```

## Newton-Schulz Orthogonalization

**File:** `newton_schulz.py`

`_zeropower_via_newtonschulz5()` computes the polar factor of a matrix using the Polar Express method: quintic Newton-Schulz iterations with analytically optimal (minimax/Remez) coefficients precomputed by `_optimal_composition()`. The default configuration uses 10 iterations with `l=1e-3`, converging all singular values to 1 and producing the exact polar factor `UV^T`. It is wrapped by `zeropower_via_newtonschulz5()`, which adds per-shape `torch.compile` caching with CUDA graph support.

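Structurally, each quintic iteration applies a degree-5 odd polynomial to the current iterate. A hedged sketch of the overall shape — the normalization and orientation details are assumptions, and the coefficients come from the precomputed schedule:

```python
import torch

def newton_schulz_sketch(G: torch.Tensor, coeffs) -> torch.Tensor:
    """One possible shape of the iteration: X <- a*X + (b*A + c*A@A) @ X,
    with A = X @ X^T and (a, b, c) varying per step."""
    X = G.bfloat16()
    X = X / (X.norm() + 1e-7)   # bring singular values into the convergent range
    tall = X.shape[0] > X.shape[1]
    if tall:
        X = X.mT                # iterate on the wide orientation
    for a, b, c in coeffs:      # per-step coefficients from _optimal_composition()
        A = X @ X.mT            # the real code uses the Triton kernel below
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.mT if tall else X
```
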
Each iteration uses `matmul_transpose_assign()` (a Triton kernel for `X @ X^T`) for efficiency.

**File:** `matmul_transpose_triton.py`

The `matmul_transpose_assign(d_in, d_out)` kernel computes `d_out = d_in @ d_in^T`, writing the result into the preallocated `d_out`. It exploits symmetry by computing only upper-triangle blocks and mirroring them.

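The calling pattern implied by that signature (usage assumed, not taken from the kernel's docs; the import path depends on the package layout):

```python
import torch

X = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
# Preallocate the symmetric (n, n) output once and reuse it across iterations.
A = torch.empty(X.shape[0], X.shape[0], device=X.device, dtype=X.dtype)
matmul_transpose_assign(X, A)  # A = X @ X^T
```
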
## QK Clipping

**File:** `qk_clip.py`

Optional dynamic clipping for attention head projections (Q and K weight matrices). When the maximum QK logit for a head exceeds a threshold, the corresponding rows of the weight matrix are scaled down by `sqrt(threshold / logit)`.

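The per-head scale rule, written out (a sketch of the formula above, not code from `qk_clip.py`):

```python
import torch

def qk_head_scales(max_logits: torch.Tensor, threshold: float) -> torch.Tensor:
    """Per-head scale: sqrt(threshold / logit) where the logit exceeds the
    threshold, 1.0 otherwise."""
    return torch.where(
        max_logits > threshold,
        (threshold / max_logits).sqrt(),
        torch.ones_like(max_logits),
    )
```
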
**In the parallel pipeline:** QK clipping is applied per-row using each row's global head index. This correctly handles strided sharding, where local rows may be interleaved across multiple heads:

```python
# pipeline.py: _update_params()
# Note: idx0 here is the per-row global index (the strided, torch.Tensor case
# of rank_indices); integer division maps each global row to its head.
ratio = p.shape[0] // scales_full.shape[0]  # rows per head
idx0 = state.rank_indices[rank][0]          # which global rows this rank owns
row_scales = scales_full[idx0 // ratio]     # map each row to its head's scale
p._local_tensor.mul_(row_scales.view(-1, 1))
```

## AdamW for Non-Muon Parameters

**File:** `adamw.py`

Parameters not eligible for Muon (1D parameters, embeddings, the LM head) are optimized with fused AdamW via `torch._fused_adamw_`. Parameters are grouped by device/dtype and DTensor placement before the fused call.

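The grouping step might look roughly like this (a hedged sketch; the bucket key structure is assumed):

```python
from collections import defaultdict
from torch.distributed.tensor import DTensor

def bucket_params(params):
    """Group params so each fused call sees one device/dtype/placement combo."""
    buckets = defaultdict(list)
    for p in params:
        placements = p.placements if isinstance(p, DTensor) else None
        buckets[(p.device, p.dtype, placements)].append(p)
    return buckets
```
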
## Source File Map

| File | Lines | Purpose |
|------|-------|---------|
| `muon.py` | ~815 | Optimizer class, parameter routing, 3 execution paths, MoE expert expansion + caching |
| `pipeline.py` | ~400 | Generator-based parallel pipeline (gather/compute/scatter/update) |
| `async_utils.py` | ~75 | Pipeline scheduler with bounded concurrency |
| `core.py` | ~175 | `_muon_state` dataclass, batched momentum/update helpers, param grouping |
| `distributed/utils.py` | ~230 | Shard mesh construction, DTensor index computation |
| `newton_schulz.py` | ~190 | Polar Express coefficients, Newton-Schulz iteration + compile/CUDA graph |
| `matmul_transpose_triton.py` | ~130 | Triton kernel for symmetric matmul |
| `qk_clip.py` | ~135 | QK logit clipping |
| `adamw.py` | ~170 | Fused AdamW for non-Muon params |

### Dependency Graph

```
matmul_transpose_triton.py (leaf)
         |
newton_schulz.py (leaf + triton)
         |
core.py ---- qk_clip.py (leaf, distributed/utils)
  |     |         |
  |     pipeline.py --- async_utils.py
  |         |
  |     adamw.py
  |         |
muon.py (all above)
  |
__init__.py
```