The `train` module provides a high-level training loop that handles the epoch → batch → step → parameter-update cycle. You configure it with `TrainConfig`, hand it a `DataLoader`, and call `trainer.train(loader, epochs)`.
## Building a session
Before creating a `Trainer`, you need a compiled `Session`. Four build functions are available:
- `build_session`: The standard path. Runs autodiff, e-graph optimization, and GPU compilation, and logs the optimization report.
- `build_session_with_report`: Same as `build_session`, but also returns the `OptimizeReport` alongside the session for programmatic inspection.
- `build_session_cached`: Loads a compiled `ExecutionPlan` from disk if the graph hash matches; otherwise falls back to the full pipeline and saves the result.
- `build_inference_session`: Skips autodiff entirely, for inference-only workloads.
Use `build_session_cached` when you want to avoid re-compiling the same graph on repeated runs.
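A sketch of the cached path; the exact signature and the cache location argument are assumptions, not the crate's documented API:

```rust
// Sketch only: signature and cache path are assumptions.
// Reuses a saved ExecutionPlan when the graph hash matches;
// otherwise runs the full pipeline and saves the result.
let session = build_session_cached(&graph, "target/plan-cache")?;
```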
## Optimizer
`Optimizer` selects which parameter-update rule to apply at each step. The default (when you use `TrainConfig::default()`) is `Sgd { learning_rate: 0.01 }`.
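The SGD rule itself is just "parameter minus learning rate times gradient". A minimal standalone sketch of one update step, independent of the library's own implementation:

```rust
/// One SGD update: each parameter moves against its gradient,
/// scaled by the learning rate.
fn sgd_step(params: &mut [f32], grads: &[f32], learning_rate: f32) {
    for (p, g) in params.iter_mut().zip(grads.iter()) {
        *p -= learning_rate * g;
    }
}

fn main() {
    let mut params = vec![1.0_f32, -2.0];
    let grads = vec![0.5_f32, -0.5];
    sgd_step(&mut params, &grads, 0.01); // the TrainConfig::default() rate
    println!("{:?}", params); // [0.995, -1.995]
}
```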
## TrainConfig
`TrainConfig` groups all training hyperparameters:

- `optimizer`: Which optimizer to use. Overrides `learning_rate` when set to `Adam`.
- `learning_rate`: Backward-compatible alias. Sets the SGD learning rate when `optimizer` is not explicitly set to `Adam`.
- The step-logging interval: log loss every N steps using `log::info!`. Set to 0 to disable step-level logging.
- The data-input name: the graph input node that receives sample data. Must match the name used in `g.input(...)`.
- The label-input name: the graph input node that receives labels. Must match the name used in `g.input(...)`.

For a concrete configuration, see `examples/mnist.rs`.
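A sketch of such a configuration; field names other than `optimizer` and `learning_rate` are not documented above, so `log_every`, `input_name`, and `label_name` here are hypothetical:

```rust
// Sketch only: `log_every`, `input_name`, and `label_name` are
// hypothetical names for the fields described above.
let config = TrainConfig {
    optimizer: Optimizer::Sgd { learning_rate: 0.01 },
    log_every: 100,                    // log loss every 100 steps; 0 disables
    input_name: "images".to_string(),  // must match g.input("images")
    label_name: "labels".to_string(),  // must match g.input("labels")
    ..TrainConfig::default()
};
```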
## Trainer

`Trainer` owns the `Session` and drives the training loop.
### Creating a trainer
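Construction pairs the compiled session with the config. A minimal sketch; the constructor name `Trainer::new` and its argument order are assumptions:

```rust
// Sketch only: constructor name and argument order are assumptions.
let mut trainer = Trainer::new(session, config);
```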
### Running training
`train` runs `epochs` full passes over the data and returns a `TrainHistory`.
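A sketch of a run, assuming `train` takes the loader and an epoch count as in the `trainer.train(loader, epochs)` call shown at the top of this page:

```rust
// Sketch: 10 full passes over the data.
let history = trainer.train(loader, 10);
println!("final loss: {:?}", history.final_loss());
```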
Each pass is delegated to `train_epoch`, which shuffles the loader with the epoch number as seed, then iterates batches until the epoch is exhausted.
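The effect of seeding the shuffle with the epoch number is that each epoch sees a different but fully reproducible order. A standalone illustration using a tiny xorshift generator; the actual `DataLoader` shuffle may differ:

```rust
/// Fisher-Yates shuffle driven by a small xorshift PRNG seeded with
/// the epoch number: same epoch gives the same order, different
/// epochs give different orders.
fn shuffle_indices(n: usize, epoch: u64) -> Vec<usize> {
    let mut state = epoch.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
    let mut next = move || {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        state
    };
    let mut idx: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        let j = (next() % (i as u64 + 1)) as usize;
        idx.swap(i, j);
    }
    idx
}

fn main() {
    // Deterministic for a fixed epoch seed.
    assert_eq!(shuffle_indices(8, 0), shuffle_indices(8, 0));
    // Always a permutation of 0..n.
    let mut order = shuffle_indices(8, 1);
    order.sort();
    assert_eq!(order, (0..8).collect::<Vec<usize>>());
    println!("{:?}", shuffle_indices(8, 1));
}
```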
### Accessing the session
After training you can extract the session to read parameters or run inference.

## TrainHistory and EpochStats
`train` returns a `TrainHistory` that accumulates per-epoch statistics. `TrainHistory::final_loss()` returns the `avg_loss` of the last epoch, or `None` if no epochs ran.
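The shape of these types can be sketched in miniature; apart from `avg_loss` and `final_loss`, which the text above names, the field and struct layout here is an assumption:

```rust
/// Per-epoch summary. `avg_loss` is named in the docs above;
/// the other details are illustrative assumptions.
#[derive(Debug, Clone)]
struct EpochStats {
    epoch: usize,
    avg_loss: f32,
}

#[derive(Debug, Default)]
struct TrainHistory {
    epochs: Vec<EpochStats>,
}

impl TrainHistory {
    /// avg_loss of the last epoch, or None if no epochs ran.
    fn final_loss(&self) -> Option<f32> {
        self.epochs.last().map(|e| e.avg_loss)
    }
}

fn main() {
    let mut history = TrainHistory::default();
    assert_eq!(history.final_loss(), None); // no epochs ran yet
    history.epochs.push(EpochStats { epoch: 0, avg_loss: 0.9 });
    history.epochs.push(EpochStats { epoch: 1, avg_loss: 0.4 });
    assert_eq!(history.final_loss(), Some(0.4)); // last epoch wins
    println!("final loss: {:?}", history.final_loss());
}
```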
## MetricCallback and LossHistory
Implement `MetricCallback` to receive per-step and per-epoch events.
`LossHistory` is a built-in implementation that collects every step's loss into a `Vec<f32>`.
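A standalone sketch of the pattern; the hook names `on_step` and `on_epoch_end` are assumptions about the trait's actual shape:

```rust
/// Per-step and per-epoch event hooks. The method names here are
/// assumptions, not the crate's documented trait.
trait MetricCallback {
    fn on_step(&mut self, step: usize, loss: f32);
    fn on_epoch_end(&mut self, epoch: usize, avg_loss: f32);
}

/// Collects every step's loss into a Vec<f32>, mirroring what the
/// built-in LossHistory is described as doing.
#[derive(Default)]
struct LossHistory {
    losses: Vec<f32>,
}

impl MetricCallback for LossHistory {
    fn on_step(&mut self, _step: usize, loss: f32) {
        self.losses.push(loss);
    }
    fn on_epoch_end(&mut self, _epoch: usize, _avg_loss: f32) {}
}

fn main() {
    let mut history = LossHistory::default();
    for (step, loss) in [0.9_f32, 0.7, 0.5].iter().enumerate() {
        history.on_step(step, *loss);
    }
    history.on_epoch_end(0, 0.7);
    assert_eq!(history.losses, vec![0.9, 0.7, 0.5]);
    println!("{:?}", history.losses);
}
```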
`MetricCallback` hooks are separate from `Trainer`. If you need per-step callbacks alongside the built-in logging, instrument your training loop by wrapping the `train_epoch` call.

## Full training loop example
The complete training setup lives in `examples/mnist.rs`.
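In outline, the pieces above compose like this; helper names and signatures beyond those documented on this page are assumptions, not the actual example code:

```rust
// Sketch of the end-to-end flow; identifiers not named in this
// document are hypothetical.
let graph = build_mnist_graph()?;        // defines g.input(...) nodes
let loader = mnist_loader("data/")?;     // hypothetical DataLoader source
let session = build_session(&graph)?;    // autodiff + e-graph opt + GPU compile
let config = TrainConfig::default();     // Sgd { learning_rate: 0.01 }
let mut trainer = Trainer::new(session, config);
let history = trainer.train(loader, 5);  // 5 full passes
println!("final loss: {:?}", history.final_loss());
// Extract the session afterwards to read parameters or run inference.
let session = trainer.into_session();    // accessor name is an assumption
```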