Skip to main content
The Trainer struct wraps a Session and a DataLoader and manages the training loop for you. It handles shuffling, batching, optimizer dispatch, and loss logging.
use meganeura::{Trainer, TrainConfig, Optimizer, build_session};

// Compile the computation graph into an executable session.
let session = build_session(&graph);

// Adam with lr = 1e-3; fields not listed fall back to their Default values
// (e.g. data_input = "x", label_input = "labels" are also the defaults).
let config = TrainConfig {
    optimizer: Optimizer::adam(1e-3),
    log_interval: 100,
    data_input: "x".into(),
    label_input: "labels".into(),
    ..Default::default()
};

// Run 10 full passes over the data loader; train() returns a TrainHistory.
let mut trainer = Trainer::new(session, config);
let history = trainer.train(&mut loader, 10);
println!("final loss: {:.4}", history.final_loss().unwrap());

Optimizer

/// Selects the optimization algorithm and its hyperparameters.
#[derive(Clone, Debug)]
pub enum Optimizer {
    /// Stochastic gradient descent with a fixed learning rate.
    Sgd { learning_rate: f32 },
    /// Adaptive moment estimation (Adam).
    Adam {
        learning_rate: f32,
        /// First-moment decay rate; `Optimizer::adam` sets 0.9.
        beta1: f32,
        /// Second-moment decay rate; `Optimizer::adam` sets 0.999.
        beta2: f32,
        /// Numerical-stability constant; `Optimizer::adam` sets 1e-8.
        epsilon: f32,
    },
}
Selects the optimization algorithm and its hyperparameters.
Sgd
variant
Stochastic gradient descent.
Adam
variant
Adaptive moment estimation optimizer.

Constructors

Optimizer::sgd

pub fn sgd(lr: f32) -> Self
Creates an SGD optimizer with the given learning rate.
let opt = Optimizer::sgd(0.01);

Optimizer::adam

pub fn adam(lr: f32) -> Self
Creates an Adam optimizer with standard defaults: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
let opt = Optimizer::adam(1e-3);

TrainConfig

/// Configuration for a `Trainer`. Implements `Default`.
pub struct TrainConfig {
    /// Optimizer to use. Defaults to `Optimizer::Sgd { learning_rate: 0.01 }`.
    pub optimizer: Optimizer,
    /// Backward-compatible SGD learning-rate alias; ignored when `optimizer`
    /// is set to `Adam`. Defaults to 0.01.
    pub learning_rate: f32,
    /// Print a loss line every this many steps; 0 disables step logging.
    /// Defaults to 100.
    pub log_interval: usize,
    /// Name of the graph input node that receives sample data. Defaults to "x".
    pub data_input: String,
    /// Name of the graph input node that receives labels. Defaults to "labels".
    pub label_input: String,
}
Configuration for a Trainer. Implements Default.
optimizer
Optimizer
The optimizer to use. Defaults to Optimizer::Sgd { learning_rate: 0.01 }.
learning_rate
f32
Backward-compatible SGD learning rate alias. Ignored when optimizer is set to Adam. Defaults to 0.01.
log_interval
usize
Print a loss line every this many steps. Set to 0 to disable step logging. Defaults to 100.
data_input
String
Name of the graph input node that receives sample data. Defaults to "x".
label_input
String
Name of the graph input node that receives labels. Defaults to "labels".

Default values

TrainConfig {
    optimizer: Optimizer::Sgd { learning_rate: 0.01 },
    learning_rate: 0.01,
    log_interval: 100,
    data_input: "x".into(),
    label_input: "labels".into(),
}

Trainer

/// Drives the training loop over a `Session` and `DataLoader`: shuffling,
/// batching, optimizer dispatch, and loss logging.
pub struct Trainer {
    // Compiled session being trained; exposed via session()/session_mut()/into_session().
    session: Session,
    // Hyperparameters and input-node names supplied to Trainer::new.
    config: TrainConfig,
}
Drives the training loop over a Session and DataLoader.

new

pub fn new(session: Session, config: TrainConfig) -> Self
Creates a trainer from a compiled session and a training configuration.
session
Session
required
A Session built with build_session().
config
TrainConfig
required
Training hyperparameters and input names.

train

pub fn train(&mut self, loader: &mut DataLoader, epochs: usize) -> TrainHistory
Runs epochs full passes over the data. On each epoch:
  1. Shuffles the data loader with the epoch index as the RNG seed
  2. Iterates over all batches, uploading data and labels, running step() + wait()
  3. Logs the per-epoch average loss via log::info!
Returns a TrainHistory containing per-epoch statistics.
loader
&mut DataLoader
required
Data loader to iterate over. Reset and shuffled at the start of each epoch.
epochs
usize
required
Number of full passes over the dataset.
let history = trainer.train(&mut loader, 20);

train_epoch

pub fn train_epoch(&mut self, loader: &mut DataLoader, epoch: usize) -> EpochStats
Runs a single epoch and returns its statistics. Useful when you need to interleave training with validation or other custom logic between epochs.
loader
&mut DataLoader
required
Data loader.
epoch
usize
required
Epoch index (used as the RNG seed for shuffling).
// Interleave per-epoch validation by driving epochs manually with train_epoch().
for epoch in 0..num_epochs {
    let stats = trainer.train_epoch(&mut loader, epoch);
    // session_mut() already returns `&mut Session`; taking `&mut` of it again
    // created a redundant `&mut &mut Session` that only worked via deref coercion.
    let val_loss = evaluate(trainer.session_mut(), &val_data);
    println!("epoch {}: train={:.4} val={:.4}", epoch, stats.avg_loss, val_loss);
}

session

pub fn session(&self) -> &Session
Borrows the underlying session immutably (e.g. to read parameters).

session_mut

pub fn session_mut(&mut self) -> &mut Session
Borrows the underlying session mutably (e.g. to set parameters, save checkpoints, or read outputs).
trainer.session_mut().save_checkpoint(Path::new("ckpt.safetensors"))?;

into_session

pub fn into_session(self) -> Session
Consumes the trainer and returns the inner session. Use this when training is complete and you want to run inference without rebuilding a session.
let session = trainer.into_session();
// session is now available for inference

TrainHistory

/// Accumulated training history returned by `Trainer::train`.
#[derive(Clone, Debug, Default)]
pub struct TrainHistory {
    /// One `EpochStats` entry per completed epoch, in order.
    pub epochs: Vec<EpochStats>,
}
Accumulated training history returned by Trainer::train.
epochs
Vec<EpochStats>
One EpochStats entry per completed epoch.

final_loss

pub fn final_loss(&self) -> Option<f32>
Returns the average loss from the last epoch, or None if no epochs ran.
if let Some(loss) = history.final_loss() {
    println!("training complete, final loss: {:.4}", loss);
}

EpochStats

/// Statistics for a single completed epoch.
#[derive(Clone, Debug)]
pub struct EpochStats {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Average loss across all steps in the epoch.
    pub avg_loss: f32,
    /// Number of batches processed.
    pub steps: usize,
}
Statistics for a single completed epoch.
epoch
usize
Zero-based epoch index.
avg_loss
f32
Average loss across all steps in the epoch.
steps
usize
Number of batches processed.

StepMetrics

/// Per-step training metrics passed to `MetricCallback::on_step`.
#[derive(Clone, Debug)]
pub struct StepMetrics {
    /// Current epoch index.
    pub epoch: usize,
    /// Step index within the epoch.
    pub step: usize,
    /// Loss value for this step.
    pub loss: f32,
}
Per-step training metrics passed to MetricCallback::on_step.
epoch
usize
Current epoch index.
step
usize
Step index within the epoch.
loss
f32
Loss value for this step.

MetricCallback

/// Receives training events. Both methods have default no-op bodies, so
/// implementors only override what they need.
pub trait MetricCallback {
    /// Called after each training step with that step's loss and position.
    fn on_step(&mut self, _metrics: &StepMetrics) {}
    /// Called at the end of each epoch with aggregated statistics.
    fn on_epoch(&mut self, _stats: &EpochStats) {}
}
A trait for receiving training events. Both methods have default no-op implementations, so you only need to implement what you need.
on_step
fn(&mut self, metrics: &StepMetrics)
Called after each training step with the step’s loss and position.
on_epoch
fn(&mut self, stats: &EpochStats)
Called at the end of each epoch with the epoch’s aggregated statistics.

LossHistory

/// A `MetricCallback` implementation that records every step's loss
/// (appends `metrics.loss` on each `on_step` call) for later analysis.
#[derive(Default)]
pub struct LossHistory {
    /// Per-step loss values, in the order they were recorded.
    pub losses: Vec<f32>,
}
A MetricCallback implementation that records every step’s loss for later analysis or plotting.
losses
Vec<f32>
Ordered list of per-step loss values in the order they were recorded.
LossHistory implements MetricCallback by appending metrics.loss on each on_step call.
let mut history = LossHistory::default();

// Pass to a custom training loop
for step in 0..num_steps {
    session.set_input("x", &batch_data);
    session.set_learning_rate(lr);
    session.step();
    session.wait();
    // Feed this step's loss through the MetricCallback interface;
    // LossHistory::on_step appends it to `losses`.
    let metrics = StepMetrics { epoch: 0, step, loss: session.read_loss() };
    history.on_step(&metrics);
}

println!("recorded {} loss values", history.losses.len());

Build docs developers (and LLMs) love