In this example, we’ll learn to use mlx.nn by implementing a simple multi-layer perceptron (MLP) to classify MNIST digits. This demonstrates how to create custom modules, define training loops, and use MLX’s built-in optimizers.
Setup
First, import the MLX packages we need: mlx.core, mlx.nn, and mlx.optimizers.
Define the Model
The model is defined as the MLP class, which inherits from mlx.nn.Module. We follow the standard idiom for creating a new module:
Define __init__
Set up parameters and submodules. The mlx.nn.Module base class automatically registers parameters.
Define __call__
Implement the forward computation. For this MLP, that means:
- Multiple hidden layers with ReLU activation (using mx.maximum)
- A final output layer without activation
Loss and Evaluation Functions
Define the loss function, which takes the mean of the per-example cross-entropy loss. The mlx.nn.losses subpackage provides implementations of commonly used loss functions.
We also need a function to compute accuracy on the validation set.
Setup and Data Loading
Configure the problem parameters and load the MNIST data. You’ll need the MNIST data loader from the mlx-examples repository.
Batch Iterator
Since we’re using SGD, we need an iterator that shuffles the data and constructs minibatches.
Training Loop
Put it all together by instantiating the model and the optimizer, then running the training loop.
The mlx.nn.value_and_grad() function is a convenience function that returns both the loss and its gradient with respect to the trainable parameters of a model. It should not be confused with mlx.core.value_and_grad(), which differentiates a plain function with respect to its arguments.