The mlx.optimizers module provides optimizers for training neural networks. All optimizers work with both mlx.nn modules and pure mlx.core functions.

Quick Start

Here’s a typical training loop with an optimizer:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Create the model (MLP, loss_fn, and batch_iterate are assumed to be defined elsewhere)
model = MLP(num_layers, input_dims, hidden_dim, output_dims)
mx.eval(model.parameters())

# Create gradient function and optimizer
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
optimizer = optim.SGD(learning_rate=0.01)

for epoch in range(num_epochs):
    for X, y in batch_iterate(batch_size, train_images, train_labels):
        loss, grads = loss_and_grad_fn(model, X, y)
        
        # Update model with gradients
        optimizer.update(model, grads)
        
        # Evaluate parameters and optimizer state
        mx.eval(model.parameters(), optimizer.state)

Base Optimizer

optim.Optimizer
class
Base class for all optimizers. Allows implementing optimizers on a per-parameter basis and applying them to parameter trees.
Key Methods:
  • update(model, gradients): Apply gradients to model parameters
  • init(parameters): Initialize optimizer state
  • apply_gradients(gradients, parameters): Apply gradients and return updated parameters

Optimizer Methods

update(model, gradients)
method
Apply the gradients to the parameters of the model and update the model.
Parameters:
  • model (nn.Module): An MLX module to be updated
  • gradients (dict): Python tree of gradients, typically from nn.value_and_grad
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)
init(parameters)
method
Initialize the optimizer’s state. Calling this is optional: the optimizer initializes itself on the first update if it has not been initialized explicitly.
Parameters:
  • parameters (dict): Python tree of parameters
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9)
model = nn.Linear(2, 2)
optimizer.init(model.trainable_parameters())
print(optimizer.state.keys())
# dict_keys(['step', 'learning_rate', 'weight', 'bias'])
state
property
The optimizer’s state dictionary. Contains the step count, the learning rate, and optimizer-specific state (e.g., momentum).
print(optimizer.state)
print(f"Step: {optimizer.step}")
print(f"Learning rate: {optimizer.learning_rate}")

Common Optimizers

optim.SGD
class
Stochastic Gradient Descent optimizer.
Updates: $v_{t+1} = \mu v_t + (1 - \tau) g_t$ and $w_{t+1} = w_t - \lambda v_{t+1}$
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • momentum (float): The momentum strength μ. Default: 0
  • weight_decay (float): The weight decay (L2 penalty). Default: 0
  • dampening (float): Dampening for momentum τ. Default: 0
  • nesterov (bool): Enables Nesterov momentum. Default: False
# Basic SGD
optimizer = optim.SGD(learning_rate=0.01)

# SGD with momentum
optimizer = optim.SGD(learning_rate=0.01, momentum=0.9)

# SGD with Nesterov momentum
optimizer = optim.SGD(learning_rate=0.01, momentum=0.9, nesterov=True)
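The update rule stated for optim.SGD can be traced by hand on a single scalar weight. The following is a minimal pure-Python sketch, not the library's implementation: the function name sgd_momentum_step is hypothetical, and weight decay and Nesterov momentum are deliberately left out.

```python
def sgd_momentum_step(w, g, v, lr, momentum=0.0, dampening=0.0):
    """One scalar SGD step following the documented update rule (illustrative only)."""
    v_next = momentum * v + (1.0 - dampening) * g  # v_{t+1} = mu * v_t + (1 - tau) * g_t
    w_next = w - lr * v_next                       # w_{t+1} = w_t - lambda * v_{t+1}
    return w_next, v_next


# Two steps with a constant gradient of 0.5
w, v = 1.0, 0.0
for _ in range(2):
    w, v = sgd_momentum_step(w, g=0.5, v=v, lr=0.1, momentum=0.9)
print(w, v)
```

With momentum=0.9 the second step moves the weight further than the first, because the velocity accumulates the previous gradient.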
optim.Adam
class
Adam optimizer.
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • betas (Tuple[float, float]): Coefficients (β₁, β₂) for running averages. Default: (0.9, 0.999)
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
  • bias_correction (bool): If True, apply bias correction. Default: False
optimizer = optim.Adam(learning_rate=1e-3)

# With custom betas
optimizer = optim.Adam(learning_rate=1e-3, betas=(0.9, 0.999))
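To see what the bias_correction flag changes, it can help to run one Adam step by hand. The sketch below uses the standard Adam formulas on a scalar; the function name adam_step is hypothetical and this is an illustration under that assumption, not MLX's code.

```python
import math


def adam_step(w, g, m, v, step, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              bias_correction=False):
    """One scalar Adam step (illustrative sketch of the standard update)."""
    b1, b2 = betas
    m = b1 * m + (1.0 - b1) * g      # running average of the gradient
    v = b2 * v + (1.0 - b2) * g * g  # running average of the squared gradient
    if bias_correction:
        # Undo the bias toward zero caused by starting m and v at 0
        m_hat = m / (1.0 - b1 ** step)
        v_hat = v / (1.0 - b2 ** step)
    else:
        m_hat, v_hat = m, v
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# First step from w = 0 with gradient 1, with and without bias correction
w_plain, _, _ = adam_step(0.0, 1.0, 0.0, 0.0, step=1)
w_corrected, _, _ = adam_step(0.0, 1.0, 0.0, 0.0, step=1, bias_correction=True)
print(w_plain, w_corrected)
```

With the default betas, the uncorrected first step is roughly 3.16 times the learning rate, while the corrected first step is roughly equal to the learning rate.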
optim.AdamW
class
AdamW optimizer with decoupled weight decay.
Parameters:
  • learning_rate (float or callable): The learning rate α
  • betas (Tuple[float, float]): Coefficients (β₁, β₂) for running averages. Default: (0.9, 0.999)
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
  • weight_decay (float): The weight decay λ. Default: 0.01
  • bias_correction (bool): If True, apply bias correction. Default: False
optimizer = optim.AdamW(learning_rate=1e-3, weight_decay=0.01)
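"Decoupled" means the decay shrinks the weights directly instead of being added to the gradient. A minimal scalar sketch, assuming the decay is applied as a multiplicative factor (1 - lr * weight_decay) before the Adam step and without bias correction; adamw_step is a hypothetical name.

```python
import math


def adamw_step(w, g, m, v, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
               weight_decay=0.01):
    """One scalar AdamW step without bias correction (illustrative sketch)."""
    b1, b2 = betas
    w = w * (1.0 - lr * weight_decay)  # decoupled decay: independent of the gradient
    m = b1 * m + (1.0 - b1) * g
    v = b2 * v + (1.0 - b2) * g * g
    w = w - lr * m / (math.sqrt(v) + eps)
    return w, m, v


# With a zero gradient only the decay acts on the weight
w, m, v = adamw_step(1.0, 0.0, 0.0, 0.0)
print(w)
```

Even with a zero gradient the weight shrinks slightly, which is the visible effect of the decoupled term.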
optim.Adamax
class
Adamax optimizer, a variant of Adam based on the infinity norm.
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • betas (Tuple[float, float]): Coefficients (β₁, β₂). Default: (0.9, 0.999)
  • eps (float): Term ε added to denominator. Default: 1e-8
optimizer = optim.Adamax(learning_rate=1e-3)
optim.Lion
class
Lion optimizer. It is recommended to use a learning rate 3-10x smaller than with AdamW and a weight decay 3-10x larger.
Parameters:
  • learning_rate (float or callable): The learning rate η
  • betas (Tuple[float, float]): Coefficients (β₁, β₂). Default: (0.9, 0.99)
  • weight_decay (float): The weight decay λ. Default: 0.0
# Lion typically needs smaller learning rate than AdamW
optimizer = optim.Lion(learning_rate=1e-4, weight_decay=0.1)
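The learning-rate advice follows from how Lion forms its update: only the sign of an interpolated momentum is used, so every coordinate moves by exactly the learning rate. A scalar sketch of the published Lion rule, with hypothetical names lion_step and sign, offered as an illustration rather than the library's code:

```python
def sign(x):
    """Return -1, 0, or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def lion_step(w, g, m, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
    """One scalar Lion step (illustrative sketch)."""
    b1, b2 = betas
    c = b1 * m + (1.0 - b1) * g            # interpolation used for the update direction
    m_next = b2 * m + (1.0 - b2) * g       # momentum stored for the next step
    w = w * (1.0 - lr * weight_decay)      # decoupled weight decay
    w = w - lr * sign(c)                   # step size is lr regardless of |g|
    return w, m_next


w, m = lion_step(1.0, g=0.3, m=0.0, lr=1e-4, weight_decay=0.1)
print(w, m)
```

Because the step magnitude does not depend on the gradient scale, updates tend to be larger than Adam-style updates at the same learning rate.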
optim.Adagrad
class
Adagrad optimizer.
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
optimizer = optim.Adagrad(learning_rate=0.01)
optim.AdaDelta
class
AdaDelta optimizer with a learning rate.
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • rho (float): Coefficient ρ for computing running average of squared gradients. Default: 0.9
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-6
optimizer = optim.AdaDelta(learning_rate=1.0, rho=0.9)
optim.RMSprop
class
RMSprop optimizer.
Parameters:
  • learning_rate (float or callable): The learning rate λ
  • alpha (float): The smoothing constant α. Default: 0.99
  • eps (float): Term ε added to denominator for numerical stability. Default: 1e-8
optimizer = optim.RMSprop(learning_rate=0.01, alpha=0.99)
optim.Adafactor
class
Adafactor optimizer with adaptive learning rates and sublinear memory cost.
Parameters:
  • learning_rate (float or callable): The learning rate. Default: None
  • eps (tuple): (ε₁, ε₂) for numerical stability and parameter scaling. Default: (1e-30, 1e-3)
  • clip_threshold (float): Clips unscaled update at this threshold. Default: 1.0
  • decay_rate (float): Coefficient for running average of squared gradient. Default: -0.8
  • beta_1 (float): If set, use first moment. Default: None
  • weight_decay (float): The weight decay λ. Default: 0.0
  • scale_parameter (bool): If True, scale learning rate by RMS of parameters. Default: True
  • relative_step (bool): If True, use relative step size. Default: True
  • warmup_init (bool): If True, calculate step size by current step. Default: False
optimizer = optim.Adafactor(learning_rate=None, relative_step=True)
optim.Muon
class
Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer.
Note: Muon may be sub-optimal for embedding layers, final fully connected layers, or 0D/1D parameters. Use a different optimizer (e.g., AdamW) for those.
Parameters:
  • learning_rate (float or callable): The learning rate
  • momentum (float): The momentum strength. Default: 0.95
  • weight_decay (float): The weight decay (L2 penalty). Default: 0.01
  • nesterov (bool): Enables Nesterov momentum. Default: True
  • ns_steps (int): Number of Newton-Schulz iteration steps. Default: 5
optimizer = optim.Muon(learning_rate=0.01, momentum=0.95, weight_decay=0.01)

Multi-Optimizer

optim.MultiOptimizer
class
Wraps multiple optimizers with weight predicates to use different optimizers for different parameters.
Parameters:
  • optimizers (list[Optimizer]): List of optimizers to delegate to
  • filters (list[Callable]): List of predicates, one fewer than the number of optimizers. Each predicate selects the parameters for the optimizer at the same position; the last optimizer is the fallback for everything else.
# Use AdamW for most parameters, but SGD for biases
optimizer = optim.MultiOptimizer(
    optimizers=[
        optim.SGD(learning_rate=0.01),
        optim.AdamW(learning_rate=1e-3)
    ],
    filters=[
        lambda k, v: "bias" in k  # Use SGD for biases
        # AdamW is fallback for everything else
    ]
)
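The routing rule can be pictured as a first-match dispatch over parameter names. The following pure-Python sketch assumes that the first predicate that returns True wins and that unmatched parameters go to the last optimizer; route_parameters is a hypothetical helper, not part of the library.

```python
def route_parameters(named_params, filters, num_optimizers):
    """Group parameters by the index of the optimizer that would handle them."""
    if len(filters) != num_optimizers - 1:
        raise ValueError("expected one fewer filter than optimizers")
    groups = [dict() for _ in range(num_optimizers)]
    for key, value in named_params.items():
        for index, predicate in enumerate(filters):
            if predicate(key, value):
                groups[index][key] = value  # first matching predicate wins
                break
        else:
            groups[-1][key] = value         # fallback: the last optimizer
    return groups


params = {"layers.0.weight": 1, "layers.0.bias": 2, "layers.2.weight": 3}
groups = route_parameters(params, [lambda k, v: "bias" in k], num_optimizers=2)
print(groups)
```

In this example the bias lands in the first group and both weight entries fall through to the second.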

Learning Rate Schedulers

Learning rate schedulers can be passed directly to optimizers:
optim.exponential_decay
function
Make an exponential decay scheduler.
Parameters:
  • init (float): Initial value
  • decay_rate (float): Multiplicative factor to decay by
lr_schedule = optim.exponential_decay(1e-1, 0.9)
optimizer = optim.SGD(learning_rate=lr_schedule)

# Learning rate decays exponentially with each step
print(optimizer.learning_rate)  # 0.1
for _ in range(5):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # 0.06561
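The printed value can be checked by hand. This pure-Python sketch assumes the schedule is init * decay_rate ** step and that the reported learning rate is the schedule evaluated at the step index used by the most recent update (index 4 after five updates); both are assumptions made for illustration.

```python
def exponential_decay(init, decay_rate):
    """Return a schedule mapping a step index to init * decay_rate ** step."""
    return lambda step: init * decay_rate ** step


schedule = exponential_decay(1e-1, 0.9)
for step in range(6):
    print(step, round(schedule(step), 6))
```

Under these assumptions, step index 4 gives 0.1 * 0.9 ** 4 = 0.06561, which matches the value shown in the example.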
optim.step_decay
function
Make a step decay scheduler.
Parameters:
  • init (float): Initial value
  • decay_rate (float): Multiplicative factor to decay by
  • step_size (int): Decay every step_size steps
lr_schedule = optim.step_decay(1e-1, 0.9, step_size=10)
optimizer = optim.SGD(learning_rate=lr_schedule)

# Learning rate stays constant for 10 steps, then decays
for _ in range(21):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # 0.081
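The staircase shape is easier to see in a few lines of plain Python. The sketch assumes the schedule is init * decay_rate ** (step // step_size), with the last update of the loop above evaluated at step index 20.

```python
def step_decay(init, decay_rate, step_size):
    """Return a schedule that multiplies by decay_rate once every step_size steps."""
    return lambda step: init * decay_rate ** (step // step_size)


schedule = step_decay(1e-1, 0.9, step_size=10)
for step in (0, 9, 10, 19, 20):
    print(step, round(schedule(step), 6))
```

Step indices 0 to 9 share one value, 10 to 19 the next, and index 20 has been decayed twice: 0.1 * 0.9 ** 2 = 0.081.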
optim.cosine_decay
function
Make a cosine decay scheduler.
Parameters:
  • init (float): Initial value
  • decay_steps (int): Number of steps to decay over
  • end (float): Final value to decay to. Default: 0.0
lr_schedule = optim.cosine_decay(1e-1, decay_steps=1000)
optimizer = optim.SGD(learning_rate=lr_schedule)

# Learning rate follows cosine curve from init to end
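A pure-Python sketch of the curve, assuming the usual half-cosine form that holds the end value once decay_steps is reached; this is an illustration, not the library's code.

```python
import math


def cosine_decay(init, decay_steps, end=0.0):
    """Return a schedule that follows half a cosine period from init to end."""
    def schedule(step):
        s = min(step, decay_steps)  # stay at `end` after decay_steps
        decay = 0.5 * (1.0 + math.cos(math.pi * s / decay_steps))
        return end + decay * (init - end)
    return schedule


schedule = cosine_decay(1e-1, decay_steps=1000)
for step in (0, 250, 500, 750, 1000, 2000):
    print(step, round(schedule(step), 6))
```

The value is halfway between init and end at the midpoint and changes slowly near both ends of the schedule.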
optim.linear_schedule
function
Make a linear scheduler.
Parameters:
  • init (float): Initial value
  • end (float): Final value
  • steps (int): Number of steps to apply schedule over
lr_schedule = optim.linear_schedule(0, 1e-1, steps=100)
optimizer = optim.Adam(learning_rate=lr_schedule)

# Learning rate linearly increases from 0 to 0.1 over 100 steps
print(optimizer.learning_rate)  # 0.0
for _ in range(101):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # 0.1
optim.join_schedules
function
Join multiple schedules to create a new schedule.
Parameters:
  • schedules (list[Callable]): List of schedules
  • boundaries (list[int]): Boundaries indicating when to transition between schedules
# Warmup with linear schedule, then cosine decay
linear = optim.linear_schedule(0, 1e-1, steps=10)
cosine = optim.cosine_decay(1e-1, decay_steps=200)
lr_schedule = optim.join_schedules([linear, cosine], boundaries=[10])

optimizer = optim.Adam(learning_rate=lr_schedule)
print(optimizer.learning_rate)  # 0.0 (linear warmup)

for _ in range(12):
    optimizer.update({}, {})
print(optimizer.learning_rate)  # ~0.0999 (cosine decay)
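The hand-off at a boundary is the part worth spelling out. The sketch below assumes that, once a boundary is passed, the next schedule is evaluated at step - boundary so that it starts from its own step 0; the helper schedules are simplified pure-Python stand-ins written for this example.

```python
import math


def linear_schedule(init, end, steps):
    """Linear ramp from init to end over `steps` steps, then constant."""
    return lambda step: init + (end - init) * min(step, steps) / steps


def cosine_decay(init, decay_steps, end=0.0):
    """Half-cosine decay from init to end over decay_steps steps."""
    def schedule(step):
        s = min(step, decay_steps)
        return end + 0.5 * (1.0 + math.cos(math.pi * s / decay_steps)) * (init - end)
    return schedule


def join_schedules(schedules, boundaries):
    """Switch to schedules[i + 1] once step >= boundaries[i], restarting its clock."""
    def schedule(step):
        value = schedules[0](step)
        for boundary, following in zip(boundaries, schedules[1:]):
            if step >= boundary:
                value = following(step - boundary)
        return value
    return schedule


joined = join_schedules(
    [linear_schedule(0, 1e-1, steps=10), cosine_decay(1e-1, decay_steps=200)],
    boundaries=[10],
)
for step in (0, 5, 10, 11):
    print(step, round(joined(step), 7))
```

Under these assumptions, step index 11 evaluates the cosine schedule at its own step 1, which gives a value just under 0.1, consistent with the figure in the example above.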

Gradient Clipping

optim.clip_grad_norm
function
Clips the global norm of the gradients. Ensures that the global norm of the gradients does not exceed max_norm, scaling all gradients down proportionally if needed.
Parameters:
  • grads (dict): Dictionary containing gradient arrays
  • max_norm (float): Maximum allowed global norm of gradients
Returns:
  • (dict, float): Clipped gradients and original gradient norm
loss, grads = loss_and_grad_fn(model, x, y)

# Clip gradients to max norm of 1.0
clipped_grads, total_norm = optim.clip_grad_norm(grads, max_norm=1.0)

optimizer.update(model, clipped_grads)
mx.eval(model.parameters(), optimizer.state)

print(f"Gradient norm: {total_norm}")
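"Global norm" means a single norm computed over all gradient entries together, not one per array. A pure-Python sketch on a flat dict of lists follows; the small eps in the denominator is an assumption added for numerical safety, and the function shadows the library name only for readability.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for values in grads.values() for g in values))
    scale = min(max_norm / (total_norm + eps), 1.0)  # never scale gradients up
    clipped = {name: [g * scale for g in values] for name, values in grads.items()}
    return clipped, total_norm


grads = {"w": [3.0, 0.0], "b": [4.0]}
clipped, total_norm = clip_grad_norm(grads, max_norm=1.0)
print(total_norm, clipped)
```

Because every entry is multiplied by the same factor, the direction of the overall gradient is preserved and only its length changes.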

Saving and Loading

To serialize an optimizer, save its state. To load an optimizer, load and set the saved state.
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten
import mlx.optimizers as optim

# Create and use optimizer
optimizer = optim.Adam(learning_rate=1e-2)
model = {"w": mx.zeros((5, 5))}
grads = {"w": mx.ones((5, 5))}
optimizer.update(model, grads)

# Save the state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer.safetensors", state)

# Later: recreate optimizer and load state
optimizer = optim.Adam(learning_rate=1e-2)
state = tree_unflatten(mx.load("optimizer.safetensors"))
optimizer.state = state
Note: Not every optimizer configuration parameter is saved in the state. For example, for Adam the learning rate is saved but betas and eps are not. As a rule of thumb, if a parameter can be scheduled, it will be included in the optimizer state.
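The flatten and unflatten round trip is what makes the nested state storable as a flat key-to-array mapping. Below is a dict-only sketch of the idea using dot-separated keys; flatten and unflatten are hypothetical stand-ins, and the real mlx.utils helpers also handle other container types.

```python
def flatten(tree, prefix=""):
    """Flatten a nested dict into {"a.b.c": leaf} form."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


def unflatten(flat):
    """Rebuild the nested dict from dot-separated keys."""
    tree = {}
    for name, value in flat.items():
        *parents, leaf = name.split(".")
        node = tree
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = value
    return tree


# Placeholder values stand in for the arrays held in a real optimizer state
state = {"step": 1, "learning_rate": 0.01, "w": {"m": [0.1], "v": [0.001]}}
flat = flatten(state)
print(sorted(flat))
restored = unflatten(flat)
```

The restored tree has the same structure as the original, which is why assigning the unflattened result back to the optimizer state works.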

Complete Training Example

Here’s a complete example showing optimizer usage in a training loop:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten

# Define model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        ]
    
    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Initialize model
model = MLP(784, 128, 10)
mx.eval(model.parameters())

# Define loss function
def loss_fn(model, x, y):
    logits = model(x)
    # Reduce the per-example losses to a scalar so the loss can be differentiated
    return nn.losses.cross_entropy(logits, y, reduction="mean")

# Create optimizer with learning rate schedule
lr_schedule = optim.join_schedules(
    [
        optim.linear_schedule(0, 1e-3, steps=100),  # Warmup
        optim.cosine_decay(1e-3, decay_steps=1000)  # Decay
    ],
    boundaries=[100]
)
optimizer = optim.AdamW(learning_rate=lr_schedule, weight_decay=0.01)

# Create gradient function
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training loop
# `dataloader` is assumed to be a sized iterable of (x_batch, y_batch) pairs defined elsewhere
num_epochs = 10
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for x_batch, y_batch in dataloader:
        # Compute loss and gradients
        loss, grads = loss_and_grad_fn(model, x_batch, y_batch)
        
        # Clip gradients
        grads, grad_norm = optim.clip_grad_norm(grads, max_norm=1.0)
        
        # Update model
        optimizer.update(model, grads)
        
        # Evaluate
        mx.eval(model.parameters(), optimizer.state, loss)
        
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch}: Loss = {epoch_loss / len(dataloader):.4f}, "
          f"LR = {optimizer.learning_rate.item():.6f}")

# Save optimizer state
state = tree_flatten(optimizer.state, destination={})
mx.save_safetensors("optimizer_checkpoint.safetensors", state)

Using Multiple Optimizers

For advanced use cases, you can use different optimizers for different parts of your model:
# Use different optimizers for different parameter groups
def is_embedding(key, value):
    return "embedding" in key

def is_output_layer(key, value):
    return "output" in key

optimizer = optim.MultiOptimizer(
    optimizers=[
        optim.SGD(learning_rate=0.001),      # For embeddings
        optim.SGD(learning_rate=0.01),       # For output layer  
        optim.AdamW(learning_rate=1e-3)      # For everything else (fallback)
    ],
    filters=[
        is_embedding,
        is_output_layer
    ]
)
