Overview

The nanochat.optim module provides combined optimizers that use Muon for 2D matrix parameters and AdamW for all other parameters. Two versions are available:
  • MuonAdamW - Single GPU version
  • DistMuonAdamW - Distributed multi-GPU version with ZeRO-2 style sharding

MuonAdamW

Combined optimizer for single GPU training.
class MuonAdamW(param_groups: list[dict])

Parameters

param_groups
list[dict]
required
List of parameter groups. Each group is a dict with the fields below.
Common fields:
  • params (list): List of parameters
  • kind (str): Either 'adamw' or 'muon'
For AdamW groups:
  • lr (float): Learning rate
  • betas (tuple): Coefficients (beta1, beta2) for the running averages of the gradient and its square
  • eps (float): Term added to denominator for numerical stability
  • weight_decay (float): Weight decay coefficient
For Muon groups:
  • lr (float): Learning rate
  • momentum (float): Momentum coefficient
  • ns_steps (int): Number of Newton-Schulz/Polar Express iterations
  • beta2 (float): Decay rate for the second-moment estimate
  • weight_decay (float): Weight decay coefficient
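The following plain-Python sketch makes the group schema above concrete by checking a group dict for the fields listed for its kind. `REQUIRED_KEYS` and `missing_keys` are hypothetical helpers for illustration, not part of nanochat.optim. The real optimizer may also accept extra fields or supply defaults.

```python
# Hypothetical helper (not part of nanochat.optim): fields listed in the docs per group kind.
REQUIRED_KEYS = {
    "adamw": {"params", "kind", "lr", "betas", "eps", "weight_decay"},
    "muon": {"params", "kind", "lr", "momentum", "ns_steps", "beta2", "weight_decay"},
}


def missing_keys(group):
    """Return the documented fields that `group` lacks for its declared kind."""
    kind = group.get("kind")
    if kind not in REQUIRED_KEYS:
        raise ValueError(f"kind must be 'adamw' or 'muon', got {kind!r}")
    return REQUIRED_KEYS[kind] - set(group)


adamw_group = {"params": [], "kind": "adamw", "lr": 3e-4,
               "betas": (0.9, 0.999), "eps": 1e-8, "weight_decay": 0.01}
muon_group = {"params": [], "kind": "muon", "lr": 0.02,
              "momentum": 0.95, "beta2": 0.95, "weight_decay": 0.01}  # no ns_steps

print(missing_keys(adamw_group))  # set()
print(missing_keys(muon_group))   # {'ns_steps'}
```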

Methods

step

@torch.no_grad()
def step()
Performs a single optimization step.

Notes

  • AdamW: Uses a fused AdamW step for parameters that Muon should not handle (embeddings, the final layer, scalars, biases)
  • Muon: MomentUm Orthogonalized by Newton-Schulz, used for 2D matrix parameters
  • The Muon optimizer should not be used for:
    • Embedding layers
    • Final fully connected layer
    • Any 0-D or 1-D parameters
  • For 4D convolutional filters, flatten the last 3 dimensions before using Muon
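The routing rules in these notes can be sketched in plain Python. Given a parameter name and shape, the sketch decides whether Muon or AdamW should own the parameter, and it computes the 2D view a 4D filter would be flattened to. `route_param` and `muon_matrix_shape` are hypothetical helpers. Identifying embeddings and the final layer by the substrings "embed" and "lm_head" is an assumption; real code would select those layers directly from the model.

```python
from math import prod


def muon_matrix_shape(shape):
    """2D shape Muon would operate on, or None if Muon should not be used."""
    if len(shape) == 2:
        return tuple(shape)
    if len(shape) == 4:
        # (out_channels, in_channels, kh, kw) -> (out_channels, in_channels * kh * kw)
        out_channels, *rest = shape
        return (out_channels, prod(rest))
    return None  # 0-D, 1-D (and other) params go to AdamW


def route_param(name, shape, excluded=("embed", "lm_head")):
    """Hypothetical routing: embeddings/final layer -> AdamW, matrices -> Muon."""
    if any(tag in name for tag in excluded):
        return "adamw"
    return "muon" if muon_matrix_shape(shape) is not None else "adamw"


print(route_param("h.0.attn.c_q.weight", (64, 64)))  # muon
print(route_param("embed.weight", (1000, 64)))       # adamw
print(muon_matrix_shape((32, 16, 3, 3)))             # (32, 144)
```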

Algorithm Details

Muon Step:
  1. Nesterov momentum
  2. Polar Express orthogonalization (ns_steps iterations, e.g. 5)
  3. Variance reduction (NorMuon)
  4. Cautious weight decay + parameter update
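The NumPy sketch below illustrates steps 1, 2 and 4 on a single matrix. It makes several substitutions, so it approximates the steps above rather than reproducing them:
  • It swaps in the classic quintic Newton-Schulz iteration (coefficients 3.4445, -4.7750, 2.0315 from the original Muon reference code) in place of Polar Express. This iteration drives singular values into roughly 0.7 to 1.2, not exactly 1.
  • It omits NorMuon variance reduction.
  • It models cautious weight decay as a sign-alignment mask, which is an assumption about nanochat's exact rule.
`newton_schulz` and `muon_step` are hypothetical names.

```python
import numpy as np


def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)  # Frobenius norm bounds singular values by 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:  # iterate on the wide orientation so X @ X.T is the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X  # applies a*s + b*s^3 + c*s^5 to each singular value s
    return X.T if transposed else X


def muon_step(param, grad, buf, lr=0.02, momentum=0.95, ns_steps=5, weight_decay=0.0):
    """One Muon-style update on a 2D float array; param and buf change in place."""
    buf[:] = momentum * buf + (1 - momentum) * grad      # 1. momentum buffer
    g = (1 - momentum) * grad + momentum * buf           #    Nesterov look-ahead
    update = newton_schulz(g, ns_steps)                  # 2. orthogonalize
    mask = (update * param) >= 0                         # assumed cautious-WD mask
    param -= lr * (update + weight_decay * mask * param)  # 4. decay + update


rng = np.random.default_rng(0)
G = rng.standard_normal((16, 32))
O = newton_schulz(G)
print(np.linalg.cond(G), np.linalg.cond(O))  # conditioning improves sharply
```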
AdamW Step:
  1. Weight decay (decoupled)
  2. First- and second-moment (momentum and variance) updates
  3. Bias correction
  4. Parameter update
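The four AdamW steps above correspond to the standard AdamW update. The following minimal NumPy sketch applies it to one tensor in the order listed. `adamw_step` is a hypothetical name; the fused kernel nanochat uses computes the same math but is not shown here.

```python
import numpy as np


def adamw_step(p, g, m, v, t, lr=3e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW step on float arrays p, m, v (updated in place); t is the 1-based step count."""
    beta1, beta2 = betas
    p *= 1 - lr * weight_decay                # 1. decoupled weight decay
    m[:] = beta1 * m + (1 - beta1) * g        # 2. first moment (momentum)
    v[:] = beta2 * v + (1 - beta2) * g * g    #    second moment
    m_hat = m / (1 - beta1 ** t)              # 3. bias correction
    v_hat = v / (1 - beta2 ** t)
    p -= lr * m_hat / (np.sqrt(v_hat) + eps)  # 4. parameter update


p = np.array([1.0, -2.0])
adamw_step(p, np.array([0.5, -0.25]), np.zeros(2), np.zeros(2), t=1, lr=0.1, weight_decay=0.0)
print(p)  # about [0.9, -1.9]: on step 1 each coordinate moves by lr * sign(grad)
```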

DistMuonAdamW

Combined distributed optimizer for multi-GPU training.
class DistMuonAdamW(param_groups: list[dict])

Parameters

param_groups
list[dict]
required
List of parameter groups, in the same format as for MuonAdamW.
Additional requirement for Muon groups:
  • All params in a Muon group must have the same shape
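Because each Muon group must be shape-homogeneous, matrix parameters are typically bucketed by shape, with one Muon group per distinct shape. The sketch below does this bucketing in plain Python, using parameter names in place of tensors. `muon_groups_by_shape` is a hypothetical helper; real code would put the tensors themselves in "params".

```python
from collections import defaultdict


def muon_groups_by_shape(named_shapes, **hparams):
    """Build one Muon group per distinct shape (names stand in for tensors)."""
    buckets = defaultdict(list)
    for name, shape in named_shapes:
        buckets[tuple(shape)].append(name)
    return [{"params": names, "kind": "muon", **hparams}
            for _, names in sorted(buckets.items())]


groups = muon_groups_by_shape(
    [("attn.q", (64, 64)), ("mlp.fc", (64, 256)), ("attn.k", (64, 64))],
    lr=0.02, momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=0.01,
)
for g in groups:
    print(g["params"])  # ['attn.q', 'attn.k'] then ['mlp.fc']
```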

Methods

step

@torch.no_grad()
def step()
Performs a single distributed optimization step with 3-phase async communication.

Design Goals

  • Overlap communication with computation (async ops)
  • Minimize memory by sharding optimizer states across ranks (ZeRO-2 style)
  • Batch small tensors into single comm ops where possible

Communication Pattern

Phase 1: Launch all async reduce ops
  • Kick off all reduce_scatter/all_reduce operations
  • Don’t wait; let them run in the background
Phase 2: Wait for reduces, compute updates, launch gathers
  • For each group: wait for its reduce, compute the update, launch gather
  • Gathers for earlier groups run while later groups are still computing their updates
Phase 3: Wait for gathers, copy back
  • Wait for all gathers to complete
  • Copy updated params back to original tensors (Muon only)
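The control flow of the three phases can be sketched with Python futures. In this sketch a thread pool stands in for torch.distributed async work handles, which is an assumption made so the example runs on one machine. Each "group" is a list of per-rank gradient values, and `sum` plays the role of the reduce. `three_phase_step` is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor


def three_phase_step(groups, reduce_fn, update_fn, gather_fn, pool):
    # Phase 1: launch every reduce up front without waiting on any of them.
    reduce_futures = [pool.submit(reduce_fn, grads) for grads in groups]

    # Phase 2: for each group in order, wait for its reduce, compute the
    # update, and launch its gather so it overlaps later groups' compute.
    gather_futures = []
    for fut in reduce_futures:
        reduced = fut.result()
        update = update_fn(reduced)
        gather_futures.append(pool.submit(gather_fn, update))

    # Phase 3: wait for all gathers and collect the results.
    return [fut.result() for fut in gather_futures]


per_group_grads = [[1.0, 2.0, 3.0], [4.0, 5.0]]  # per-rank grads for two groups
with ThreadPoolExecutor(max_workers=4) as pool:
    out = three_phase_step(per_group_grads, sum, lambda s: -0.5 * s, lambda u: u, pool)
print(out)  # [-3.0, -4.5]
```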

AdamW Communication (ZeRO-2)

  • Small params (<1024 elements): all_reduce gradients, update full param on each rank
  • Large params: reduce_scatter gradients so each rank gets 1/N of the grad, update only that slice, then all_gather the updated slices
  • Optimizer state is sharded for large params
  • Large params require param.shape[0] to be divisible by world_size
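The following plain-Python simulation shows the ZeRO-2 path for a large parameter. Each rank's gradient is a list, the reduce_scatter averages across ranks and hands rank r its slice, rank r updates only that slice, and all_gather concatenates the slices. Two choices are assumptions made for brevity: gradients are averaged rather than summed, and plain SGD stands in for the per-shard AdamW update. All function names are hypothetical.

```python
SMALL_PARAM_NUMEL = 1024  # threshold from the text above


def comm_strategy(numel, dim0, world_size):
    """Pick the communication path for one AdamW param."""
    if numel < SMALL_PARAM_NUMEL:
        return "all_reduce"
    if dim0 % world_size != 0:
        raise ValueError(f"shape[0]={dim0} not divisible by world_size={world_size}")
    return "reduce_scatter"


def reduce_scatter_avg(per_rank_grads):
    """Average grads across ranks; rank r receives the r-th contiguous slice."""
    world_size = len(per_rank_grads)
    chunk = len(per_rank_grads[0]) // world_size
    avg = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return [avg[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards):
    """Concatenate every rank's slice into the full tensor."""
    return [x for shard in shards for x in shard]


def sharded_sgd_step(param, per_rank_grads, lr):
    world_size = len(per_rank_grads)
    chunk = len(param) // world_size
    new_shards = []
    for rank, g_shard in enumerate(reduce_scatter_avg(per_rank_grads)):
        p_shard = param[rank * chunk:(rank + 1) * chunk]  # only this slice is updated
        new_shards.append([p - lr * g for p, g in zip(p_shard, g_shard)])
    return all_gather(new_shards)


print(sharded_sgd_step([1.0] * 4, [[2.0, 0.0, 2.0, 0.0], [0.0, 2.0, 0.0, 2.0]], lr=0.5))
# [0.5, 0.5, 0.5, 0.5]
```

Only the slice owned by a rank is updated there, which is why that rank only needs optimizer state for its slice.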

Muon Communication

  • Stack all K params into a single (K, *shape) tensor
  • Divide K params across N ranks: each rank owns ceil(K/N) params
  • reduce_scatter the stacked grads so each rank gets its chunk
  • Each rank computes Muon update only for params it owns
  • all_gather the updated params back to all ranks
  • Optimizer state is sharded by chunk
  • If K is not divisible by N, the stacked tensor is zero-padded to N * ceil(K/N) entries
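The stacking, padding and ownership arithmetic above can be checked with a small NumPy sketch. It shows only the data layout; the cross-rank reduction itself is not simulated. `owned_indices` and `pad_and_chunk` are hypothetical names.

```python
import math

import numpy as np


def owned_indices(K, N, rank):
    """Indices of the K params that `rank` owns when each rank takes ceil(K/N)."""
    chunk = math.ceil(K / N)
    start = rank * chunk
    return list(range(start, min(start + chunk, K)))


def pad_and_chunk(grads, N):
    """Stack K same-shape grads, zero-pad to N * ceil(K/N), split into N chunks."""
    K = len(grads)
    chunk = math.ceil(K / N)
    stacked = np.stack(grads)  # (K, *shape)
    pad = chunk * N - K
    if pad:
        zeros = np.zeros((pad, *stacked.shape[1:]), dtype=stacked.dtype)
        stacked = np.concatenate([stacked, zeros])
    return np.split(stacked, N)  # chunk r is what rank r receives


chunks = pad_and_chunk([np.full((2, 3), float(i)) for i in range(10)], N=4)
print([c.shape for c in chunks])  # four chunks of shape (3, 2, 3)
print(owned_indices(10, 4, 3))    # [9]: the last rank owns one real param plus padding
```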

Example

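import torch
import torch.distributed as dist
from nanochat.optim import DistMuonAdamW

# DistMuonAdamW issues collective ops, so a process group must already exist
# (e.g. when launched with torchrun; the NCCL backend here is an assumption).
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

# `model` and `dataloader` are assumed to be defined elsewhere.
# Separate parameters into AdamW and Muon groups by filling these lists.
adamw_params = []  # embeddings, final layer, scalars, 1D params
muon_params = []   # 2D matrix params, all with the same shape

optimizer = DistMuonAdamW([
    {
        'params': adamw_params,
        'kind': 'adamw',
        'lr': 3e-4,
        'betas': (0.9, 0.999),
        'eps': 1e-8,
        'weight_decay': 0.01
    },
    {
        'params': muon_params,
        'kind': 'muon',
        'lr': 0.02,
        'momentum': 0.95,
        'ns_steps': 5,
        'beta2': 0.95,
        'weight_decay': 0.01
    }
])

# Training loop
for batch in dataloader:
    loss = model(batch)  # assumed to return a scalar loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()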
