The Trainer struct wraps a Session and drives the training loop over a DataLoader for you. It handles shuffling, batching, optimizer dispatch, and loss logging.
use meganeura::{Trainer, TrainConfig, Optimizer, build_session};

let session = build_session(&graph);
let config = TrainConfig {
    optimizer: Optimizer::adam(1e-3),
    log_interval: 100,
    data_input: "x".into(),
    label_input: "labels".into(),
    ..Default::default()
};
let mut trainer = Trainer::new(session, config);
let history = trainer.train(&mut loader, 10);
println!("final loss: {:.4}", history.final_loss().unwrap());
Optimizer
#[derive(Clone, Debug)]
pub enum Optimizer {
    Sgd { learning_rate: f32 },
    Adam {
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
    },
}
Selects the optimization algorithm and its hyperparameters.
Sgd — stochastic gradient descent; learning_rate is the step size applied to each gradient update.
Adam — adaptive moment estimation; beta1 is the first moment (mean) decay coefficient, beta2 the second moment (variance) decay coefficient, and epsilon a small constant for numerical stability.
Constructors
Optimizer::sgd
pub fn sgd(lr: f32) -> Self
Creates an SGD optimizer with the given learning rate.
let opt = Optimizer::sgd(0.01);
Optimizer::adam
pub fn adam(lr: f32) -> Self
Creates an Adam optimizer with standard defaults: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8.
let opt = Optimizer::adam(1e-3);
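The two variants correspond to the standard update rules. A minimal sketch of the math (illustrative only, not meganeura's internal implementation):

```rust
// SGD: theta <- theta - lr * g
fn sgd_step(theta: f32, g: f32, lr: f32) -> f32 {
    theta - lr * g
}

// Adam with bias correction; t is the 1-based step count.
// m and v carry the first/second moment estimates between steps.
fn adam_step(
    theta: f32, g: f32, m: &mut f32, v: &mut f32, t: i32,
    lr: f32, beta1: f32, beta2: f32, epsilon: f32,
) -> f32 {
    *m = beta1 * *m + (1.0 - beta1) * g;     // first moment (mean)
    *v = beta2 * *v + (1.0 - beta2) * g * g; // second moment (variance)
    let m_hat = *m / (1.0 - beta1.powi(t));  // bias correction
    let v_hat = *v / (1.0 - beta2.powi(t));
    theta - lr * m_hat / (v_hat.sqrt() + epsilon)
}
```

Note that on the very first step the bias-corrected Adam update is roughly lr in the direction of the gradient's sign, which is why Adam's early step size is largely insensitive to gradient scale.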
TrainConfig
pub struct TrainConfig {
    pub optimizer: Optimizer,
    pub learning_rate: f32,
    pub log_interval: usize,
    pub data_input: String,
    pub label_input: String,
}
Configuration for a Trainer. Implements Default.
optimizer — The optimizer to use. Defaults to Optimizer::Sgd { learning_rate: 0.01 }.
learning_rate — Backward-compatible SGD learning-rate alias; ignored when optimizer is set to Adam. Defaults to 0.01.
log_interval — Print a loss line every this many steps; set to 0 to disable step logging. Defaults to 100.
data_input — Name of the graph input node that receives sample data. Defaults to "x".
label_input — Name of the graph input node that receives labels. Defaults to "labels".
Default values
TrainConfig {
    optimizer: Optimizer::Sgd { learning_rate: 0.01 },
    learning_rate: 0.01,
    log_interval: 100,
    data_input: "x".into(),
    label_input: "labels".into(),
}
Trainer
pub struct Trainer {
    session: Session,
    config: TrainConfig,
}
Drives the training loop over a Session and DataLoader.
new
pub fn new(session: Session, config: TrainConfig) -> Self
Creates a trainer from a compiled session and a training configuration.
session — A Session built with build_session().
config — Training hyperparameters and input names.
train
pub fn train(&mut self, loader: &mut DataLoader, epochs: usize) -> TrainHistory
Runs epochs full passes over the data. On each epoch:
Shuffles the data loader with the epoch index as the RNG seed
Iterates over all batches, uploading data and labels, running step() + wait()
Logs the per-epoch average loss via log::info!
Returns a TrainHistory containing per-epoch statistics.
loader — Data loader to iterate over; reset and shuffled at the start of each epoch.
epochs — Number of full passes over the dataset.
let history = trainer.train(&mut loader, 20);
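The per-epoch behavior described above has roughly this shape. ToyLoader, the stand-in xorshift RNG, and toy_step are invented here purely for illustration; the real Trainer works against meganeura's Session and DataLoader:

```rust
// Toy stand-in for a data loader holding one value per "batch".
struct ToyLoader {
    samples: Vec<f32>,
}

impl ToyLoader {
    // Deterministic Fisher-Yates shuffle seeded by the epoch index,
    // mirroring "epoch index as the RNG seed".
    fn shuffle(&mut self, seed: u64) {
        let mut state = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        for i in (1..self.samples.len()).rev() {
            // xorshift64 step as a tiny stand-in RNG
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let j = (state as usize) % (i + 1);
            self.samples.swap(i, j);
        }
    }
}

// Pretend "loss" for one batch; stands in for step() + wait() + reading the loss.
fn toy_step(x: f32) -> f32 {
    x * x
}

// One epoch: shuffle, iterate all batches, return (avg_loss, steps) as in EpochStats.
fn train_epoch_sketch(loader: &mut ToyLoader, epoch: usize) -> (f32, usize) {
    loader.shuffle(epoch as u64);
    let mut total = 0.0_f32;
    let mut steps = 0_usize;
    for &x in &loader.samples {
        total += toy_step(x);
        steps += 1;
    }
    (total / steps as f32, steps)
}
```

Because the shuffle only reorders batches, the per-epoch average loss in this sketch is the same regardless of the epoch seed; in real training the order matters because parameters change between steps.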
train_epoch
pub fn train_epoch(&mut self, loader: &mut DataLoader, epoch: usize) -> EpochStats
Runs a single epoch and returns its statistics. Useful when you need to interleave training with validation or other custom logic between epochs.
loader — Data loader to iterate over for this epoch.
epoch — Epoch index (used as the RNG seed for shuffling).
for epoch in 0..num_epochs {
    let stats = trainer.train_epoch(&mut loader, epoch);
    let val_loss = evaluate(trainer.session_mut(), &val_data);
    println!("epoch {}: train={:.4} val={:.4}", epoch, stats.avg_loss, val_loss);
}
session
pub fn session(&self) -> &Session
Borrows the underlying session immutably (e.g. to read parameters).
session_mut
pub fn session_mut(&mut self) -> &mut Session
Borrows the underlying session mutably (e.g. to set parameters, save checkpoints, or read outputs).
trainer.session_mut().save_checkpoint(Path::new("ckpt.safetensors"))?;
into_session
pub fn into_session(self) -> Session
Consumes the trainer and returns the inner session. Use this when training is complete and you want to run inference without rebuilding a session.
let session = trainer.into_session();
// session is now available for inference
TrainHistory
#[derive(Clone, Debug, Default)]
pub struct TrainHistory {
    pub epochs: Vec<EpochStats>,
}
Accumulated training history returned by Trainer::train.
epochs — One EpochStats entry per completed epoch.
final_loss
pub fn final_loss(&self) -> Option<f32>
Returns the average loss from the last epoch, or None if no epochs ran.
if let Some(loss) = history.final_loss() {
    println!("training complete, final loss: {:.4}", loss);
}
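The semantics are simple enough to state in code. A sketch mirroring the struct definitions in this reference (the real method lives on meganeura's TrainHistory; these local copies exist only so the snippet stands alone):

```rust
#[derive(Clone, Debug)]
pub struct EpochStats {
    pub epoch: usize,
    pub avg_loss: f32,
    pub steps: usize,
}

#[derive(Clone, Debug, Default)]
pub struct TrainHistory {
    pub epochs: Vec<EpochStats>,
}

impl TrainHistory {
    // Average loss of the last completed epoch, or None if no epochs ran.
    pub fn final_loss(&self) -> Option<f32> {
        self.epochs.last().map(|e| e.avg_loss)
    }
}
```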
EpochStats
#[derive(Clone, Debug)]
pub struct EpochStats {
    pub epoch: usize,
    pub avg_loss: f32,
    pub steps: usize,
}
Statistics for a single completed epoch.
epoch — Index of this epoch.
avg_loss — Average loss across all steps in the epoch.
steps — Number of batches processed.
StepMetrics
#[derive(Clone, Debug)]
pub struct StepMetrics {
    pub epoch: usize,
    pub step: usize,
    pub loss: f32,
}
Per-step training metrics passed to MetricCallback::on_step.
epoch — Epoch index the step belongs to.
step — Step index within the epoch.
loss — Loss value for this step.
MetricCallback
pub trait MetricCallback {
    fn on_step(&mut self, _metrics: &StepMetrics) {}
    fn on_epoch(&mut self, _stats: &EpochStats) {}
}
A trait for receiving training events. Both methods have default no-op implementations, so you only need to implement what you need.
on_step
fn(&mut self, metrics: &StepMetrics)
Called after each training step with the step’s loss and position.
on_epoch
fn(&mut self, stats: &EpochStats)
Called at the end of each epoch with the epoch’s aggregated statistics.
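A custom callback only needs to implement the hooks it cares about. A sketch that tracks the best per-step loss, leaving on_epoch at its no-op default (the trait and metric structs from this reference are reproduced locally so the snippet stands alone; BestLoss is a hypothetical example type, not part of meganeura):

```rust
#[derive(Clone, Debug)]
pub struct StepMetrics {
    pub epoch: usize,
    pub step: usize,
    pub loss: f32,
}

#[derive(Clone, Debug)]
pub struct EpochStats {
    pub epoch: usize,
    pub avg_loss: f32,
    pub steps: usize,
}

pub trait MetricCallback {
    fn on_step(&mut self, _metrics: &StepMetrics) {}
    fn on_epoch(&mut self, _stats: &EpochStats) {}
}

// Records the lowest loss observed across all steps.
#[derive(Default)]
struct BestLoss {
    best: Option<f32>,
}

impl MetricCallback for BestLoss {
    fn on_step(&mut self, metrics: &StepMetrics) {
        if self.best.map_or(true, |b| metrics.loss < b) {
            self.best = Some(metrics.loss);
        }
    }
}
```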
LossHistory
#[derive(Default)]
pub struct LossHistory {
    pub losses: Vec<f32>,
}
A MetricCallback implementation that records every step’s loss for later analysis or plotting.
losses — Per-step loss values in the order they were recorded.
LossHistory implements MetricCallback by appending metrics.loss on each on_step call.
let mut history = LossHistory::default();

// Pass to a custom training loop
for step in 0..num_steps {
    session.set_input("x", &batch_data);
    session.set_learning_rate(lr);
    session.step();
    session.wait();
    let metrics = StepMetrics { epoch: 0, step, loss: session.read_loss() };
    history.on_step(&metrics);
}
println!("recorded {} loss values", history.losses.len());