This guide walks you through training a two-layer MLP on MNIST using the complete example from examples/mnist.rs. By the end you will have a working model that trains for ten epochs on the GPU.
If you don’t have MNIST data files, Meganeura falls back to synthetic data automatically. You can complete this quickstart without downloading anything.
1. Add Meganeura to your project

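Add Meganeura as a git dependency in your Cargo.toml. The code below also calls env_logger::init(), so add the env_logger crate as well:
Cargo.toml
[dependencies]
meganeura = { git = "https://github.com/kvark/meganeura" }
env_logger = "0.11"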
Meganeura requires Rust 1.88 or later. Check your version with:
rustc --version
2. Set up your imports and hyperparameters

Create src/main.rs and add the imports and training hyperparameters:
src/main.rs
use meganeura::{DataLoader, Graph, MnistDataset, TrainConfig, Trainer, build_session};
use std::path::Path;

fn main() {
    env_logger::init();

    let batch = 32;
    let input_dim = 784; // 28x28
    let hidden = 128;
    let classes = 10;
    let epochs = 10;
    let lr = 0.01_f32;
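These hyperparameters fix the length of a run. As a quick sanity check, the standalone sketch below (standard library only; steps_per_epoch is a hypothetical helper, not a Meganeura API) derives input_dim from the 28x28 image size and counts training steps for the 3200-sample synthetic dataset reported in the expected output. It assumes one update per batch and that a trailing partial batch is dropped, since the graph inputs have a fixed batch dimension.

```rust
/// Hypothetical helper: full batches per epoch.
/// Assumption: a trailing partial batch is dropped because the graph's
/// inputs have a fixed batch dimension (3200 divides evenly by 32 anyway).
fn steps_per_epoch(samples: usize, batch: usize) -> usize {
    samples / batch
}

fn main() {
    let batch = 32;
    let input_dim = 28 * 28; // each 28x28 image is flattened into one row
    let epochs = 10;
    let samples = 3200; // synthetic sample count reported in the expected output

    let per_epoch = steps_per_epoch(samples, batch);
    println!("input_dim = {input_dim}");
    println!(
        "{per_epoch} batches/epoch, {} training steps in total",
        per_epoch * epochs
    );
}
```

With these values the sketch reports 100 batches per epoch, matching the "3200 samples, 100 batches/epoch" line in the expected output.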
3. Build the computation graph

Define the model as a declarative graph of operations. Meganeura records these operations but does not execute them yet:
    // --- Build the model graph ---
    let mut g = Graph::new();

    // Inputs
    let x = g.input("x", &[batch, input_dim]);
    let labels = g.input("labels", &[batch, classes]);

    // Layer 1: linear + relu
    let w1 = g.parameter("w1", &[input_dim, hidden]);
    let b1 = g.parameter("b1", &[hidden]);
    let mm1 = g.matmul(x, w1);
    let h1 = g.bias_add(mm1, b1);
    let a1 = g.relu(h1);

    // Layer 2: linear → logits
    let w2 = g.parameter("w2", &[hidden, classes]);
    let b2 = g.parameter("b2", &[classes]);
    let mm2 = g.matmul(a1, w2);
    let logits = g.bias_add(mm2, b2);

    // Loss
    let loss = g.cross_entropy_loss(logits, labels);
    g.set_outputs(vec![loss]);
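Each call to g.input(name, shape) declares a named tensor that receives data, and each call to g.parameter(name, shape) creates a named trainable tensor. The labels input has shape [batch, classes], so each label is a row of per-class targets (typically a one-hot vector). g.set_outputs marks which nodes produce the values you care about; in this case, the scalar loss.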
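To make the recorded graph concrete, here is a plain-Rust CPU reference of what it computes, written independently of Meganeura. It assumes row-major storage matching the declared shapes and assumes cross_entropy_loss is softmax cross-entropy averaged over the batch; Mlp and mlp_loss are hypothetical names used only for this sketch.

```rust
/// Parameters of the two-layer MLP, row-major like the graph shapes (assumed).
struct Mlp {
    w1: Vec<f32>, // [input_dim, hidden]
    b1: Vec<f32>, // [hidden]
    w2: Vec<f32>, // [hidden, classes]
    b2: Vec<f32>, // [classes]
}

/// Row-major matmul: a is [m, k], b is [k, n], result is [m, n].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    out
}

/// Adds a length-n bias vector to every row of a row-major [m, n] matrix.
fn bias_add(x: &mut [f32], bias: &[f32]) {
    for row in x.chunks_mut(bias.len()) {
        for (v, b) in row.iter_mut().zip(bias) {
            *v += b;
        }
    }
}

/// Elementwise max(x, 0).
fn relu(x: &mut [f32]) {
    for v in x.iter_mut() {
        *v = v.max(0.0);
    }
}

/// Assumed semantics: mean over rows of -sum_j labels[j] * log_softmax(logits)[j].
fn cross_entropy(logits: &[f32], labels: &[f32], classes: usize) -> f32 {
    let rows = logits.len() / classes;
    let mut total = 0.0;
    for (z, y) in logits.chunks(classes).zip(labels.chunks(classes)) {
        // Subtract the row max before exp() for numerical stability.
        let max = z.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = z.iter().map(|v| (v - max).exp()).sum::<f32>().ln() + max;
        total += z.iter().zip(y).map(|(v, t)| t * (log_sum_exp - v)).sum::<f32>();
    }
    total / rows as f32
}

/// Mirrors the graph: matmul -> bias_add -> relu -> matmul -> bias_add -> loss.
fn mlp_loss(p: &Mlp, x: &[f32], labels: &[f32], batch: usize) -> f32 {
    let hidden = p.b1.len();
    let classes = p.b2.len();
    let input_dim = x.len() / batch;
    let mut a1 = matmul(x, &p.w1, batch, input_dim, hidden);
    bias_add(&mut a1, &p.b1);
    relu(&mut a1);
    let mut logits = matmul(&a1, &p.w2, batch, hidden, classes);
    bias_add(&mut logits, &p.b2);
    cross_entropy(&logits, labels, classes)
}

fn main() {
    let (batch, input_dim, hidden, classes) = (2, 4, 3, 10);
    // All-zero parameters give all-zero logits, i.e. a uniform prediction.
    let p = Mlp {
        w1: vec![0.0; input_dim * hidden],
        b1: vec![0.0; hidden],
        w2: vec![0.0; hidden * classes],
        b2: vec![0.0; classes],
    };
    let x = vec![0.5; batch * input_dim];
    let mut labels = vec![0.0; batch * classes];
    labels[3] = 1.0; // sample 0 is class 3
    labels[classes + 7] = 1.0; // sample 1 is class 7
    println!("loss = {:.4}", mlp_loss(&p, &x, &labels, batch)); // ln(10)
}
```

Meganeura runs the same computation as GPU dispatches; a CPU reference like this is mainly useful for spot-checking small inputs.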
4. Build the training session

build_session runs the full compilation pipeline in one call:
  1. Extends the graph with backward-pass operations (autodiff)
  2. Optimizes the combined forward+backward graph with egglog
  3. Compiles to WGSL shaders and initializes GPU buffers
    // --- Build training session ---
    // This runs: autodiff → egglog optimize → compile → GPU init
    println!("building session (autodiff + egglog + compile)...");
    let mut session = build_session(&g);
    println!(
        "session ready: {} buffers, {} dispatches",
        session.plan().buffers.len(),
        session.plan().dispatches.len()
    );
build_session compiles once. The resulting session holds a static dispatch plan, so subsequent training steps reuse the same compiled shaders and GPU buffers without recompiling.
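The autodiff stage of build_session adds backward operations for every forward node. For the loss at the end of this graph, the gradient flowing back into logits has a well-known closed form, (softmax(logits) - labels) / batch, assuming cross_entropy_loss averages over the batch and each label row sums to 1. The sketch below is not Meganeura's autodiff; it only checks that formula against central finite differences in plain Rust (ce_loss, ce_grad, and max_fd_error are hypothetical names).

```rust
/// Mean softmax cross-entropy over rows of `classes` logits.
fn ce_loss(logits: &[f64], labels: &[f64], classes: usize) -> f64 {
    let rows = (logits.len() / classes) as f64;
    logits
        .chunks(classes)
        .zip(labels.chunks(classes))
        .map(|(z, y)| {
            let max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            let lse = z.iter().map(|v| (v - max).exp()).sum::<f64>().ln() + max;
            z.iter().zip(y).map(|(v, t)| t * (lse - v)).sum::<f64>()
        })
        .sum::<f64>()
        / rows
}

/// Closed-form gradient w.r.t. the logits: (softmax(z) - y) / rows.
/// Valid when each label row sums to 1 (e.g. one-hot).
fn ce_grad(logits: &[f64], labels: &[f64], classes: usize) -> Vec<f64> {
    let rows = (logits.len() / classes) as f64;
    let mut grad = Vec::with_capacity(logits.len());
    for (z, y) in logits.chunks(classes).zip(labels.chunks(classes)) {
        let max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = z.iter().map(|v| (v - max).exp()).collect();
        let sum: f64 = exps.iter().sum();
        grad.extend(exps.iter().zip(y).map(|(e, t)| (e / sum - t) / rows));
    }
    grad
}

/// Largest absolute difference between the closed form and a central difference.
fn max_fd_error(logits: &[f64], labels: &[f64], classes: usize) -> f64 {
    let eps = 1e-6;
    let analytic = ce_grad(logits, labels, classes);
    let mut z = logits.to_vec();
    let mut worst: f64 = 0.0;
    for i in 0..z.len() {
        let orig = z[i];
        z[i] = orig + eps;
        let up = ce_loss(&z, labels, classes);
        z[i] = orig - eps;
        let down = ce_loss(&z, labels, classes);
        z[i] = orig; // restore before perturbing the next logit
        let numeric = (up - down) / (2.0 * eps);
        worst = worst.max((numeric - analytic[i]).abs());
    }
    worst
}

fn main() {
    let logits = [0.2, -1.3, 0.7, 2.0, 0.1, -0.4]; // 2 rows x 3 classes
    let labels = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]; // one-hot rows
    println!("max |analytic - numeric| = {:e}", max_fd_error(&logits, &labels, 3));
}
```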
5. Initialize parameters

Set initial values for all named parameters. This example uses Xavier initialization for the weight matrices and zeros for the biases:
    // --- Initialize parameters ---
    // Xavier initialization (xavier_init is a helper defined in examples/mnist.rs)
    let w1_data = xavier_init(input_dim, hidden);
    let b1_data = vec![0.0_f32; hidden];
    let w2_data = xavier_init(hidden, classes);
    let b2_data = vec![0.0_f32; classes];

    session.set_parameter("w1", &w1_data);
    session.set_parameter("b1", &b1_data);
    session.set_parameter("w2", &w2_data);
    session.set_parameter("b2", &b2_data);
Parameters are referenced by the same names you gave them in the graph. The set_parameter call copies the data to the GPU buffer.
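The xavier_init helper lives in examples/mnist.rs and is not shown in this guide. The sketch below is a hypothetical standalone version of Xavier (Glorot) uniform initialization, drawing each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). It uses a small xorshift generator with a fixed seed to avoid external crates, so the real helper's RNG, seed, and distribution may differ. It also counts the elements each set_parameter call supplies, assuming each slice must hold the product of the declared shape.

```rust
/// Minimal xorshift64 PRNG so the sketch needs no external crates.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform sample in [0, 1) built from the top 24 bits.
    fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Hypothetical Xavier-uniform init for a row-major [fan_in, fan_out] matrix.
fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
    let limit = (6.0 / (fan_in + fan_out) as f32).sqrt();
    let mut rng = XorShift64(0x9E37_79B9_7F4A_7C15); // fixed, nonzero seed
    (0..fan_in * fan_out)
        .map(|_| (rng.next_f32() * 2.0 - 1.0) * limit) // map [0, 1) to [-limit, limit)
        .collect()
}

fn main() {
    let (input_dim, hidden, classes) = (784, 128, 10);
    let w1 = xavier_init(input_dim, hidden);
    let w2 = xavier_init(hidden, classes);
    // Element counts: products of the shapes declared with g.parameter.
    let total = w1.len() + hidden + w2.len() + classes;
    println!("w1: {} values, w2: {} values", w1.len(), w2.len());
    println!("total parameters: {total}"); // 101770
}
```

The fixed seed makes runs reproducible; taking the seed as a parameter would allow different initializations per run.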
6. Run the training loop

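Create a Trainer with a TrainConfig and call train with a data loader and the number of epochs. The trainer handles batching, gradient updates, and logging. The loader used below is a DataLoader over an MnistDataset; its construction is part of the full example in examples/mnist.rs: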
    // --- Training loop ---
    println!("training...");
    let config = TrainConfig {
        learning_rate: lr,
        log_interval: 50,
        ..TrainConfig::default()
    };
    let mut trainer = Trainer::new(session, config);
    let history = trainer.train(&mut loader, epochs);

    if let Some(final_loss) = history.final_loss() {
        println!("done! final avg_loss = {:.4}", final_loss);
    } else {
        println!("done! (no epochs ran)");
    }
}
train returns a TrainHistory with per-epoch loss statistics. final_loss() gives the average loss over the last epoch.
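TrainHistory's internal layout is not shown in this guide. As a hypothetical illustration of the documented behavior, the sketch below stores step losses per epoch and returns None when no epoch ran, which is the case the else branch above handles. History is a stand-in type, not Meganeura's.

```rust
/// Hypothetical stand-in for TrainHistory: step losses grouped by epoch.
struct History {
    epochs: Vec<Vec<f32>>,
}

impl History {
    /// Average loss over the last epoch, or None if no epoch recorded any loss.
    fn final_loss(&self) -> Option<f32> {
        let last = self.epochs.last()?;
        if last.is_empty() {
            return None;
        }
        Some(last.iter().sum::<f32>() / last.len() as f32)
    }
}

fn main() {
    let history = History {
        epochs: vec![vec![2.3, 1.9], vec![0.5, 0.25]],
    };
    match history.final_loss() {
        Some(loss) => println!("done! final avg_loss = {:.4}", loss),
        None => println!("done! (no epochs ran)"),
    }
}
```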

Complete example

The full source including data loading and Xavier initialization is in examples/mnist.rs. Run it directly with:
cargo run --example mnist
To use real MNIST data, download the training files and place them in a data/ directory:
data/train-images-idx3-ubyte.gz
data/train-labels-idx1-ubyte.gz
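To confirm the files are in place before a run, a quick check with std::path::Path (already imported in src/main.rs) is enough. mnist_available is a hypothetical helper; Meganeura's own detection may apply further checks, such as reading the file contents.

```rust
use std::path::Path;

/// The two training files this guide expects under data/.
const MNIST_TRAIN_FILES: [&str; 2] = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
];

/// Hypothetical helper: true only if every training file exists in `dir`.
fn mnist_available(dir: &Path) -> bool {
    MNIST_TRAIN_FILES.iter().all(|f| dir.join(f).is_file())
}

fn main() {
    let dir = Path::new("data");
    if mnist_available(dir) {
        println!("found MNIST in {}/", dir.display());
    } else {
        println!("MNIST not found in {}/, expect synthetic data", dir.display());
    }
}
```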

Expected output

MNIST not found in data/, using synthetic data
3200 samples, 100 batches/epoch
forward graph:
...
building session (autodiff + egglog + compile)...
session ready: 14 buffers, 22 dispatches
training...
done! final avg_loss = 0.1842
Loss values differ between real MNIST data and synthetic data. With real MNIST, expect the loss to decrease from around 2.3 to below 0.5 over ten epochs.
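The starting value of about 2.3 is the cross-entropy of a uniform guess over ten classes: -ln(1/10) = ln 10 ≈ 2.3026. Xavier-initialized weights produce small but nonzero logits, so the first measured loss is near this value rather than exactly equal to it. The short sketch below shows that arithmetic and also converts a mean loss into the geometric-mean probability the model assigns to the correct class, assuming one-hot labels (both function names are hypothetical):

```rust
/// Cross-entropy of a uniform prediction over `classes` classes: ln(classes).
fn uniform_cross_entropy(classes: usize) -> f64 {
    -(1.0 / classes as f64).ln()
}

/// With one-hot labels, a mean cross-entropy L corresponds to a
/// geometric-mean probability of exp(-L) on the correct class.
fn geometric_mean_prob(mean_loss: f64) -> f64 {
    (-mean_loss).exp()
}

fn main() {
    println!("initial loss ~ {:.4}", uniform_cross_entropy(10)); // prints 2.3026
    println!("loss 0.5 -> p_correct ~ {:.3}", geometric_mean_prob(0.5)); // prints 0.607
}
```

So a loss below 0.5 means that, on average in the geometric sense, the model puts more than about 61% probability on the correct digit.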

Next steps

HuggingFace integration

Load pre-trained safetensors weights from the HuggingFace Hub for inference.

Trainers and optimizers

Configure learning rate schedules, gradient clipping, and optimizer settings.

Layers and operations

Full reference for all graph operations: attention, normalization, activations, and more.

Profiling

Capture Perfetto traces to analyze GPU execution timelines.
