The train module provides a high-level training loop that handles the epoch → batch → step → parameter-update cycle. You configure it with a TrainConfig, hand it a DataLoader, and call trainer.train(&mut loader, epochs).

Building a session

Before creating a Trainer, you need a compiled Session. Four build functions are available (the last one is for inference only):

build_session

The standard path. Runs autodiff, e-graph optimization, and GPU compilation. Logs the optimization report.

build_session_with_report

Same as build_session but returns the OptimizeReport alongside the session for programmatic inspection.

build_session_cached

Loads a compiled ExecutionPlan from disk if the graph hash matches. Otherwise, it falls back to the full pipeline and saves the resulting plan to the given path.

build_inference_session

Skips autodiff entirely; intended for inference-only workloads.

For example, using the standard path:

```rust
use meganeura::build_session;

// Forward graph you built with Graph::new() + set_outputs(...)
let mut session = build_session(&g);
```

Use build_session_cached when you want to avoid recompiling the same graph on repeated runs:

```rust
use meganeura::build_session_cached;
use std::path::Path;

let session = build_session_cached(&g, Path::new("cache/my_model.ron"));
```
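The cache check described for build_session_cached can be pictured as a hash-keyed file cache. The sketch below is a standalone illustration, not meganeura's implementation. A description string stands in for the graph, a hypothetical compile function produces a string in place of the compiled plan, and the on-disk layout (hash on the first line, plan after it) is invented for the example.

```rust
use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::path::Path;

/// Hypothetical stand-in for hashing a computation graph.
fn graph_hash(graph_desc: &str) -> u64 {
    let mut h = DefaultHasher::new();
    graph_desc.hash(&mut h);
    h.finish()
}

/// Hypothetical stand-in for the expensive autodiff + optimization + compilation pipeline.
fn compile(graph_desc: &str) -> String {
    format!("plan-for:{graph_desc}")
}

/// Returns (plan, cache_hit). The cached plan is used only if the stored hash matches;
/// otherwise the graph is compiled and the cache file is overwritten.
fn build_cached(graph_desc: &str, path: &Path) -> std::io::Result<(String, bool)> {
    let hash = graph_hash(graph_desc);
    if let Ok(contents) = fs::read_to_string(path) {
        if let Some((stored, plan)) = contents.split_once('\n') {
            if stored.parse::<u64>().ok() == Some(hash) {
                return Ok((plan.to_string(), true));
            }
        }
    }
    let plan = compile(graph_desc);
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)?;
    }
    fs::write(path, format!("{hash}\n{plan}"))?;
    Ok((plan, false))
}

fn main() -> std::io::Result<()> {
    let path = std::env::temp_dir().join("graph_cache_sketch").join("my_model.cache");
    let _ = fs::remove_file(&path); // start from an empty cache
    let (_, hit1) = build_cached("matmul->relu->matmul", &path)?;
    let (_, hit2) = build_cached("matmul->relu->matmul", &path)?;
    let (_, hit3) = build_cached("matmul->tanh->matmul", &path)?;
    println!("{hit1} {hit2} {hit3}"); // false true false
    Ok(())
}
```

DefaultHasher's algorithm is not guaranteed to stay the same across Rust releases. A cache that must survive toolchain upgrades should therefore key on a hash with a fixed specification.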

Optimizer

Optimizer selects which parameter update rule to apply each step:
```rust
pub enum Optimizer {
    Sgd { learning_rate: f32 },
    Adam {
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
    },
}
```

Use the convenience constructors to create optimizers with sensible defaults:

```rust
// SGD with learning rate 0.01
let opt = Optimizer::sgd(0.01);

// Adam with lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8
let opt = Optimizer::adam(1e-3);
```
The default Optimizer (when you use TrainConfig::default()) is Sgd { learning_rate: 0.01 }.
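For reference, the sketch below applies the textbook form of each update rule to a flat parameter slice. It mirrors the documented enum and constructor defaults but is not meganeura's code, which runs on the GPU. The Adam state layout and the bias-corrected update follow the standard formulation and are assumed, not confirmed, to match the library.

```rust
/// Mirror of the documented enum, for illustration only.
#[derive(Clone, Copy)]
enum Optimizer {
    Sgd { learning_rate: f32 },
    Adam { learning_rate: f32, beta1: f32, beta2: f32, epsilon: f32 },
}

impl Optimizer {
    fn sgd(learning_rate: f32) -> Self {
        Optimizer::Sgd { learning_rate }
    }
    /// Defaults as documented: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
    fn adam(learning_rate: f32) -> Self {
        Optimizer::Adam { learning_rate, beta1: 0.9, beta2: 0.999, epsilon: 1e-8 }
    }
}

/// Per-parameter optimizer state (only Adam uses it).
struct AdamState {
    m: Vec<f32>, // first-moment (mean) estimate
    v: Vec<f32>, // second-moment (uncentered variance) estimate
    t: i32,      // number of Adam steps taken
}

impl AdamState {
    fn new(n: usize) -> Self {
        AdamState { m: vec![0.0; n], v: vec![0.0; n], t: 0 }
    }
}

/// Applies one update step to `params` in place, given `grads` of the same length.
fn step(opt: Optimizer, params: &mut [f32], grads: &[f32], state: &mut AdamState) {
    match opt {
        Optimizer::Sgd { learning_rate } => {
            // theta <- theta - lr * g
            for (p, g) in params.iter_mut().zip(grads) {
                *p -= learning_rate * g;
            }
        }
        Optimizer::Adam { learning_rate, beta1, beta2, epsilon } => {
            state.t += 1;
            // Bias corrections compensate for m and v starting at zero.
            let bc1 = 1.0 - beta1.powi(state.t);
            let bc2 = 1.0 - beta2.powi(state.t);
            for i in 0..params.len() {
                let g = grads[i];
                state.m[i] = beta1 * state.m[i] + (1.0 - beta1) * g;
                state.v[i] = beta2 * state.v[i] + (1.0 - beta2) * g * g;
                let m_hat = state.m[i] / bc1;
                let v_hat = state.v[i] / bc2;
                params[i] -= learning_rate * m_hat / (v_hat.sqrt() + epsilon);
            }
        }
    }
}

fn main() {
    let grads = [0.5_f32, -2.0];

    let mut w = [1.0_f32, 1.0];
    let mut unused = AdamState::new(2);
    step(Optimizer::sgd(0.01), &mut w, &grads, &mut unused);
    println!("sgd:  {:?}", w); // approximately [0.995, 1.02]

    let mut w = [1.0_f32, 1.0];
    let mut state = AdamState::new(2);
    step(Optimizer::adam(1e-3), &mut w, &grads, &mut state);
    println!("adam: {:?}", w); // approximately [0.999, 1.001]
}
```

On the first Adam step the bias-corrected ratio m_hat / sqrt(v_hat) is close to the sign of the gradient. Each weight therefore moves by roughly the learning rate, whatever the gradient's magnitude, while SGD's step scales with the gradient.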

TrainConfig

TrainConfig groups all training hyperparameters:
```rust
pub struct TrainConfig {
    pub optimizer: Optimizer,
    pub learning_rate: f32,
    pub log_interval: usize,
    pub data_input: String,
    pub label_input: String,
}
```

| Field | Type | Default | Description |
|---|---|---|---|
| optimizer | Optimizer | Optimizer::Sgd { learning_rate: 0.01 } | Which optimizer to use. When set to Adam, it overrides learning_rate. |
| learning_rate | f32 | 0.01 | Backward-compatible alias. Sets the SGD learning rate when optimizer is not explicitly set to Adam. |
| log_interval | usize | 100 | Log the loss every N steps using log::info!. Set to 0 to disable step-level logging. |
| data_input | String | "x" | Name of the graph input node that receives sample data. Must match the name used in g.input(...). |
| label_input | String | "labels" | Name of the graph input node that receives labels. Must match the name used in g.input(...). |
A typical configuration from examples/mnist.rs:
```rust
let config = TrainConfig {
    learning_rate: 0.01,
    log_interval: 50,
    ..TrainConfig::default()
};
```
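The field table describes how optimizer and learning_rate interact. That description can be read as the resolution rule below. The sketch is our interpretation of the documented behaviour, written against a mirrored TrainConfig that keeps only the two relevant fields, and meganeura's internal handling may differ.

```rust
/// Mirror of the documented enum, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Optimizer {
    Sgd { learning_rate: f32 },
    Adam { learning_rate: f32, beta1: f32, beta2: f32, epsilon: f32 },
}

/// Mirror of the two TrainConfig fields that interact; the other fields are omitted.
struct TrainConfig {
    optimizer: Optimizer,
    learning_rate: f32,
}

impl Default for TrainConfig {
    fn default() -> Self {
        TrainConfig {
            optimizer: Optimizer::Sgd { learning_rate: 0.01 },
            learning_rate: 0.01,
        }
    }
}

/// Assumed rule: an explicit Adam optimizer is used as-is (its own learning_rate wins);
/// otherwise the top-level learning_rate becomes the SGD learning rate.
fn effective_optimizer(cfg: &TrainConfig) -> Optimizer {
    match cfg.optimizer {
        adam @ Optimizer::Adam { .. } => adam,
        Optimizer::Sgd { .. } => Optimizer::Sgd { learning_rate: cfg.learning_rate },
    }
}

fn main() {
    // Only learning_rate is set, as in the mnist configuration, so SGD picks it up.
    let cfg = TrainConfig { learning_rate: 0.05, ..TrainConfig::default() };
    println!("{:?}", effective_optimizer(&cfg)); // Sgd { learning_rate: 0.05 }

    // Explicit Adam: the top-level learning_rate is ignored.
    let adam = Optimizer::Adam { learning_rate: 1e-3, beta1: 0.9, beta2: 0.999, epsilon: 1e-8 };
    let cfg = TrainConfig { optimizer: adam, learning_rate: 0.5 };
    println!("{:?}", effective_optimizer(&cfg));
}
```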

Trainer

Trainer owns the Session and drives the training loop:
```rust
pub struct Trainer {
    session: Session,
    config: TrainConfig,
}
```

Creating a trainer

```rust
use meganeura::{Trainer, TrainConfig};

let mut trainer = Trainer::new(session, config);
```

Running training

train(&mut loader, epochs) makes epochs full passes over the data and returns a TrainHistory:

```rust
let history = trainer.train(&mut loader, 10);

if let Some(loss) = history.final_loss() {
    println!("final avg_loss = {:.4}", loss);
}
```

To run a single epoch manually, use train_epoch:

```rust
let stats = trainer.train_epoch(&mut loader, /*epoch=*/ 0);
println!("epoch {}: avg_loss={:.4} ({} steps)", stats.epoch, stats.avg_loss, stats.steps);
```

train_epoch shuffles the loader using the epoch number as the seed, so the batch order is reproducible from run to run. It then iterates over batches until the epoch is exhausted.
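The control flow of train and train_epoch can be sketched without a GPU. Everything here is hypothetical:

- samples are just indices;
- the shuffle is a Fisher-Yates pass driven by a small xorshift generator seeded with the epoch number (the library's actual shuffle is not documented here);
- step_fn stands in for one forward/backward/update step and returns its loss;
- logging goes to stdout instead of log::info!.

```rust
#[derive(Debug, Clone, PartialEq)]
struct EpochStats {
    epoch: usize,
    avg_loss: f32,
    steps: usize,
}

/// Small deterministic PRNG (xorshift64): the same seed always yields the same sequence.
struct XorShift(u64);

impl XorShift {
    fn new(seed: u64) -> Self {
        // xorshift must never hold 0, so scramble the seed and force the low bit on.
        XorShift(seed.wrapping_mul(0x9E37_79B9_7F4A_7C15) | 1)
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Fisher-Yates shuffle of sample indices, seeded by the epoch number.
fn shuffled_indices(n: usize, epoch: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..n).collect();
    let mut rng = XorShift::new(epoch as u64);
    for i in (1..n).rev() {
        // Modulo bias is negligible for a sketch.
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        idx.swap(i, j);
    }
    idx
}

/// One epoch: shuffle, run one step per full batch, average the step losses.
/// A trailing partial batch is dropped (an assumption: the graph has a fixed batch size).
fn train_epoch<F>(n: usize, batch: usize, epoch: usize, log_interval: usize, mut step_fn: F) -> EpochStats
where
    F: FnMut(&[usize]) -> f32,
{
    let order = shuffled_indices(n, epoch);
    let (mut total, mut steps) = (0.0_f32, 0_usize);
    for chunk in order.chunks_exact(batch) {
        let loss = step_fn(chunk);
        total += loss;
        steps += 1;
        // log_interval == 0 disables step-level logging.
        if log_interval > 0 && steps % log_interval == 0 {
            println!("epoch {epoch} step {steps}: loss={loss:.4}");
        }
    }
    let avg_loss = if steps > 0 { total / steps as f32 } else { 0.0 };
    EpochStats { epoch, avg_loss, steps }
}

/// `epochs` full passes, collecting per-epoch stats the way TrainHistory does.
fn train<F>(n: usize, batch: usize, epochs: usize, log_interval: usize, mut step_fn: F) -> Vec<EpochStats>
where
    F: FnMut(&[usize]) -> f32,
{
    (0..epochs)
        .map(|e| train_epoch(n, batch, e, log_interval, &mut step_fn))
        .collect()
}

fn main() {
    // Fake step: loss decays as 1/step; a real step would run forward, backward and update.
    let mut global_step = 0;
    let history = train(10, 4, 3, 1, |_batch_indices| {
        global_step += 1;
        1.0 / global_step as f32
    });
    for stats in &history {
        println!("{stats:?}");
    }
}
```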

Accessing the session

After training you can extract the session to read parameters or run inference:
```rust
// Borrow for reading
let plan = trainer.session().plan();

// Consume and return the session
let session = trainer.into_session();
```
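session() and into_session() follow the usual Rust accessor pattern. The first is a &self borrow that leaves the trainer usable. The second takes self by value and moves the Session out. The mirror below uses placeholder Session and Trainer types, not meganeura's, to show the ownership consequences.

```rust
/// Placeholder for a compiled session; `plan` stands in for the ExecutionPlan.
struct Session {
    plan: String,
}

impl Session {
    fn plan(&self) -> &str {
        &self.plan
    }
}

/// Placeholder trainer that owns its session.
struct Trainer {
    session: Session,
}

impl Trainer {
    /// Shared borrow: the trainer stays usable afterwards.
    fn session(&self) -> &Session {
        &self.session
    }

    /// Takes `self` by value: the trainer is moved and can no longer be used.
    fn into_session(self) -> Session {
        self.session
    }
}

fn main() {
    let trainer = Trainer { session: Session { plan: "compiled-plan".to_string() } };

    // Borrow for reading; `plan_len` is a copied value, so the borrow ends here.
    let plan_len = trainer.session().plan().len();

    // Consume the trainer. Calling `trainer.session()` after this line would
    // fail to compile with "borrow of moved value: `trainer`".
    let session = trainer.into_session();
    println!("{} {}", plan_len, session.plan());
}
```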

TrainHistory and EpochStats

train returns a TrainHistory that accumulates per-epoch statistics:
```rust
pub struct TrainHistory {
    pub epochs: Vec<EpochStats>,
}

pub struct EpochStats {
    pub epoch: usize,
    pub avg_loss: f32,
    pub steps: usize,
}
```

TrainHistory::final_loss() returns the avg_loss of the last epoch, or None if no epochs ran:

```rust
assert_eq!(history.final_loss(), history.epochs.last().map(|e| e.avg_loss));
```
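A standalone mirror of these structs, with final_loss written to satisfy the equivalence above. The method bodies are a sketch rather than the library source. best_epoch is a hypothetical helper that only illustrates reading history.epochs directly.

```rust
#[derive(Debug, Clone, PartialEq)]
pub struct EpochStats {
    pub epoch: usize,
    pub avg_loss: f32,
    pub steps: usize,
}

#[derive(Debug, Default)]
pub struct TrainHistory {
    pub epochs: Vec<EpochStats>,
}

impl TrainHistory {
    /// avg_loss of the last epoch, or None if no epochs ran.
    pub fn final_loss(&self) -> Option<f32> {
        self.epochs.last().map(|e| e.avg_loss)
    }

    /// Hypothetical helper, not part of the documented API: the epoch with the lowest avg_loss.
    pub fn best_epoch(&self) -> Option<&EpochStats> {
        self.epochs.iter().min_by(|a, b| a.avg_loss.total_cmp(&b.avg_loss))
    }
}

fn main() {
    let mut history = TrainHistory::default();
    println!("{:?}", history.final_loss()); // None
    for (epoch, avg_loss) in [(0, 0.9_f32), (1, 0.4), (2, 0.5)] {
        history.epochs.push(EpochStats { epoch, avg_loss, steps: 100 });
    }
    println!("{:?}", history.final_loss()); // Some(0.5)
    println!("{:?}", history.best_epoch().map(|e| e.epoch)); // Some(1)
}
```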

MetricCallback and LossHistory

Implement MetricCallback to receive per-step and per-epoch events:
```rust
pub trait MetricCallback {
    fn on_step(&mut self, metrics: &StepMetrics) {}
    fn on_epoch(&mut self, stats: &EpochStats) {}
}

pub struct StepMetrics {
    pub epoch: usize,
    pub step: usize,
    pub loss: f32,
}
```

LossHistory is a built-in implementation that collects every step’s loss into a Vec<f32>:

```rust
pub struct LossHistory {
    pub losses: Vec<f32>,
}
```

MetricCallback hooks are separate from Trainer: the trainer does not invoke them. If you need callbacks alongside the built-in logging, instrument your training loop yourself by wrapping the train_epoch call. For example, you can pass the EpochStats it returns to on_epoch.
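Since the trainer does not call these hooks, a callback has to be driven from your own loop. The sketch mirrors the documented trait and the StepMetrics, EpochStats and LossHistory types. The LossHistory impl body, the EpochPrinter callback, the run helper and the hard-coded per-step losses are all assumptions for illustration. In real code the losses would come from training.

```rust
pub struct StepMetrics {
    pub epoch: usize,
    pub step: usize,
    pub loss: f32,
}

pub struct EpochStats {
    pub epoch: usize,
    pub avg_loss: f32,
    pub steps: usize,
}

/// Mirror of the documented trait; both hooks default to no-ops.
pub trait MetricCallback {
    fn on_step(&mut self, _metrics: &StepMetrics) {}
    fn on_epoch(&mut self, _stats: &EpochStats) {}
}

/// Collects every step's loss, like the built-in LossHistory.
#[derive(Default)]
pub struct LossHistory {
    pub losses: Vec<f32>,
}

impl MetricCallback for LossHistory {
    fn on_step(&mut self, metrics: &StepMetrics) {
        self.losses.push(metrics.loss);
    }
}

/// Hypothetical user callback that only reports epoch summaries.
struct EpochPrinter;

impl MetricCallback for EpochPrinter {
    fn on_epoch(&mut self, stats: &EpochStats) {
        println!("epoch {}: avg_loss={:.4} ({} steps)", stats.epoch, stats.avg_loss, stats.steps);
    }
}

/// Hypothetical driver: fans each step and epoch event out to every registered callback.
fn run(per_epoch_losses: &[Vec<f32>], callbacks: &mut [&mut dyn MetricCallback]) {
    for (epoch, losses) in per_epoch_losses.iter().enumerate() {
        for (step, &loss) in losses.iter().enumerate() {
            let m = StepMetrics { epoch, step, loss };
            for cb in callbacks.iter_mut() {
                cb.on_step(&m);
            }
        }
        let stats = EpochStats {
            epoch,
            avg_loss: losses.iter().sum::<f32>() / losses.len() as f32,
            steps: losses.len(),
        };
        for cb in callbacks.iter_mut() {
            cb.on_epoch(&stats);
        }
    }
}

fn main() {
    let mut history = LossHistory::default();
    let mut printer = EpochPrinter;
    let fake = vec![vec![1.0, 0.8], vec![0.6, 0.4]];
    let mut callbacks: [&mut dyn MetricCallback; 2] = [&mut history, &mut printer];
    run(&fake, &mut callbacks);
    println!("{:?}", history.losses); // [1.0, 0.8, 0.6, 0.4]
}
```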

Full training loop example

The following is the complete training setup from examples/mnist.rs:
```rust
use meganeura::{
    DataLoader, Graph, MnistDataset, TrainConfig, Trainer, build_session,
};
use std::path::Path;

// batch, input_dim, hidden, classes and the xavier_init helper are assumed
// to be defined earlier in examples/mnist.rs (not shown here).

// 1. Build the graph (see Building computation graphs)
let mut g = Graph::new();
let x = g.input("x", &[batch, input_dim]);
let labels = g.input("labels", &[batch, classes]);
// ... add layers, loss, set_outputs ...

// 2. Compile the session
let mut session = build_session(&g);

// 3. Initialize parameters
session.set_parameter("w1", &xavier_init(input_dim, hidden));
session.set_parameter("b1", &vec![0.0_f32; hidden]);
session.set_parameter("w2", &xavier_init(hidden, classes));
session.set_parameter("b2", &vec![0.0_f32; classes]);

// 4. Load data
let mut loader = MnistDataset::load_gz(
    Path::new("data/train-images-idx3-ubyte.gz"),
    Path::new("data/train-labels-idx1-ubyte.gz"),
)
.expect("failed to load MNIST")
.loader(batch);

// 5. Configure and run training
let config = TrainConfig {
    learning_rate: 0.01,
    log_interval: 50,
    ..TrainConfig::default()
};
let mut trainer = Trainer::new(session, config);
let history = trainer.train(&mut loader, 10);

println!("done! final avg_loss = {:.4}", history.final_loss().unwrap());
```
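The example calls xavier_init, whose definition is not shown on this page. A plausible version is Glorot (Xavier) uniform initialization, which samples each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). The signature below returns a flat Vec<f32> of fan_in × fan_out values, which is consistent with how the example passes it to set_parameter. That signature and the built-in PRNG are assumptions, so check examples/mnist.rs for the actual helper.

```rust
/// Hypothetical xavier_init: Glorot uniform over a flat [fan_in, fan_out] buffer.
/// The seed is derived from the shape so the sketch needs no external crates;
/// a real helper would take an RNG or seed so that equal-shaped layers differ.
fn xavier_init(fan_in: usize, fan_out: usize) -> Vec<f32> {
    let limit = (6.0_f64 / (fan_in + fan_out) as f64).sqrt();
    let mut state: u64 = (0x2545_F491_4F6C_DD1D ^ (((fan_in as u64) << 32) | fan_out as u64)) | 1;
    (0..fan_in * fan_out)
        .map(|_| {
            // xorshift64 step
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            // Top 53 bits mapped to a uniform value in [0, 1).
            let u = (state >> 11) as f64 / (1u64 << 53) as f64;
            // Rescale to [-limit, limit).
            ((2.0 * u - 1.0) * limit) as f32
        })
        .collect()
}

fn main() {
    let (input_dim, hidden) = (784, 128);
    let w1 = xavier_init(input_dim, hidden);
    let limit = (6.0_f64 / (input_dim + hidden) as f64).sqrt();
    let max_abs = w1.iter().fold(0.0_f32, |m, &w| m.max(w.abs()));
    println!("{} weights, max |w| = {:.4} (limit {:.4})", w1.len(), max_abs, limit);
}
```

Scaling the range by fan_in + fan_out keeps the variance of activations and gradients roughly constant from layer to layer. That is the usual motivation for this initializer.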
