A Session is the runtime object produced after compiling a Graph. It owns the GPU context, allocates all buffers, compiles GPU pipelines, and executes the pre-built dispatch sequence on each call to step(). You do not construct a Session directly. Use one of the builder functions:
use meganeura::{Graph, build_session, build_inference_session};

// Training session (forward + backward + parameter updates)
let session = build_session(&graph);

// Inference session (forward pass only)
let session = build_inference_session(&graph);

Builder functions

build_session

pub fn build_session(forward_graph: &Graph) -> Session
Builds a full training session from a forward-pass graph. The pipeline:
  1. Optimizes the forward graph (SwiGLU fusion, MatMul+Add fusion, etc.)
  2. Runs autodiff on the optimized forward graph to produce the full forward+backward graph
  3. Optimizes the full graph (fuses backward MatMul+Add, etc.)
  4. Compiles the optimized graph to an ExecutionPlan
  5. Allocates all GPU buffers and compiles GPU pipelines
let session = build_session(&g);
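Putting the pipeline together, a minimal training loop looks like the sketch below. The graph construction, the names "x" and "w1", and the data slices are placeholders for illustration, not part of the API:

```rust
use meganeura::{Graph, build_session};

// `make_graph` is a hypothetical helper returning a forward-pass Graph
// with an input named "x" and a parameter named "w1".
let graph: Graph = make_graph();
let mut session = build_session(&graph);

session.set_parameter("w1", &init_weights); // initialize weights first
session.set_learning_rate(0.01);            // fuse SGD updates into step()

for batch in &batches {
    session.set_input("x", batch);
    session.step();                         // forward + backward + update
    session.wait();                         // block until GPU work finishes
    println!("loss = {}", session.read_loss());
}
```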

build_inference_session

pub fn build_inference_session(forward_graph: &Graph) -> Session
Builds an inference-only session. Skips autodiff — no gradient buffers are allocated. Applies inference-specific fusions (e.g. GroupNorm + SiLU → GroupNormSilu).
let session = build_inference_session(&g);

build_session_cached

pub fn build_session_cached(forward_graph: &Graph, cache_path: &Path) -> Session
Like build_session, but loads the compiled ExecutionPlan from a cache file if the graph hash matches. Saves the plan after compilation if no cache exists. Useful for skipping the compilation overhead on repeated runs.
forward_graph
&Graph
required
The forward-pass computation graph.
cache_path
&Path
required
Path to read from / write to (e.g. "model.plan.ron").
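A typical use is to pay the compilation cost only once per graph. Sketch (the file name is illustrative):

```rust
use std::path::Path;
use meganeura::build_session_cached;

// First run: compiles the graph and writes "model.plan.ron".
// Later runs with an unchanged graph (same hash) load the cached
// ExecutionPlan instead of recompiling.
let session = build_session_cached(&g, Path::new("model.plan.ron"));
```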

build_session_with_report

pub fn build_session_with_report(forward_graph: &Graph) -> (Session, OptimizeReport)
Like build_session, but also returns an OptimizeReport describing what the e-graph optimizer did.
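A sketch of inspecting the optimizer's work; printing the report with `{:?}` assumes OptimizeReport implements Debug, which is not stated here:

```rust
use meganeura::build_session_with_report;

let (session, report) = build_session_with_report(&g);
// See which fusions and rewrites the e-graph optimizer applied.
println!("{:?}", report);
```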

MemorySummary

pub struct MemorySummary {
    pub total_buffer_bytes: usize,
    pub adam_state_bytes: usize,
    pub num_buffers: usize,
    pub largest_buffer_bytes: usize,
}
A snapshot of GPU memory usage for a session.
total_buffer_bytes
usize
Total bytes across all allocated GPU buffers.
adam_state_bytes
usize
Bytes used by Adam first- and second-moment buffers (0 if using SGD).
num_buffers
usize
Number of individual GPU buffers.
largest_buffer_bytes
usize
Size of the largest single buffer.
MemorySummary implements Display:
4 buffers, 12.3 MB total (8.2 MB adam state), largest 4.1 MB

Session methods

Data upload

set_input

pub fn set_input(&mut self, name: &str, data: &[f32])
Uploads f32 data into the named input buffer. Call this before each step().
name
&str
required
Must match the name passed to graph.input() when building the graph.
data
&[f32]
required
Flat slice of values. Length must match the number of elements in the named input.
session.set_input("x", &batch_data);
session.set_input("labels", &batch_labels);

set_parameter

pub fn set_parameter(&mut self, name: &str, data: &[f32])
Uploads f32 data into a named parameter buffer. Use this to initialize weights before training, or to restore a checkpoint manually.
name
&str
required
Must match the name passed to graph.parameter().
data
&[f32]
required
Flat slice of initial values.
session.set_parameter("w1", &kaiming_init);
If the parameter feeds a derived (concatenated) weight created by the optimizer, set_parameter writes the source data into the correct column offset of the derived buffer automatically.

get_parameter / read_param

pub fn read_param(&self, name: &str, out: &mut [f32])
Reads the current value of a named parameter buffer back from GPU memory.
name
&str
required
Parameter name.
out
&mut [f32]
required
Destination slice. Must be large enough to hold all parameter elements.
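For example, to read a parameter back for logging or export (the name "w1" and the element count 1024 are illustrative; the slice length must match the parameter's actual size):

```rust
// Pre-allocate a destination slice large enough for every element.
let mut w1 = vec![0.0f32; 1024];
session.read_param("w1", &mut w1);
```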

Optimizer configuration

set_learning_rate

pub fn set_learning_rate(&mut self, lr: f32)
Sets the SGD learning rate for updates fused into the next step(). When set, step() appends all SGD parameter updates to the same GPU submission as the forward+backward pass — eliminating the overhead of a separate sgd_step() call.
lr
f32
required
Learning rate.
session.set_learning_rate(0.01);
session.step();
session.wait();

set_adam

pub fn set_adam(&mut self, lr: f32, beta1: f32, beta2: f32, epsilon: f32)
Sets Adam optimizer parameters for updates fused into the next step(). Analogous to set_learning_rate for SGD.
lr
f32
required
Learning rate.
beta1
f32
required
First moment decay (typically 0.9).
beta2
f32
required
Second moment decay (typically 0.999).
epsilon
f32
required
Numerical stability constant (typically 1e-8).
session.set_adam(1e-3, 0.9, 0.999, 1e-8);
session.step();
session.wait();

Execution

step

pub fn step(&mut self)
Executes the full GPU dispatch sequence: forward pass, backward pass, and (if an optimizer was configured) parameter updates — all in a single GPU submission. Call wait() before reading results.
session.set_input("x", &data);
session.set_learning_rate(0.01);
session.step();
session.wait();
let loss = session.read_loss();

wait

pub fn wait(&mut self)
Blocks the CPU until any pending GPU work completes. Safe to call even if no work is in flight.

run

For inference graphs, use step() as the forward execution method. The session does not have a separate run() method — step() works for both training and inference sessions.
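An inference pass therefore follows the same pattern as training, minus the optimizer configuration. Sketch, with placeholder names and sizes:

```rust
use meganeura::build_inference_session;

let mut session = build_inference_session(&graph);
session.set_input("x", &features);
session.step();  // forward pass only; no gradients, no parameter updates
session.wait();

// `output_len` is a placeholder for the primary output's element count.
let mut out = vec![0.0f32; output_len];
session.read_output_by_index(0, &mut out);
```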

Reading outputs

read_loss

pub fn read_loss(&self) -> f32
Reads the scalar loss value from GPU memory. Must be called after wait(). Returns 0.0 if the graph has no loss output.

read_output_by_index

pub fn read_output_by_index(&self, index: usize, out: &mut [f32])
Reads a graph output by index into a pre-allocated slice. Index 0 is the primary output. Higher indices correspond to additional outputs set via graph.set_outputs().
index
usize
required
Output index (0-based).
out
&mut [f32]
required
Destination buffer.

num_outputs

pub fn num_outputs(&self) -> usize
Returns the number of graph outputs.
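Combined with read_output_by_index, this lets you drain every output. Sketch; `output_len(i)` is a hypothetical helper, since the element count for each output must come from your own graph metadata:

```rust
for i in 0..session.num_outputs() {
    // Allocate a destination sized for output `i`, then read it back.
    let mut out = vec![0.0f32; output_len(i)];
    session.read_output_by_index(i, &mut out);
}
```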

Inspection

plan

pub fn plan(&self) -> &ExecutionPlan
Returns a reference to the compiled execution plan. Useful for inspecting dispatch counts and buffer layouts.

memory_summary

pub fn memory_summary(&self) -> MemorySummary
Returns a MemorySummary with GPU memory statistics for this session.
let summary = session.memory_summary();
println!("{}", summary);
// 4 buffers, 12.3 MB total (8.2 MB adam state), largest 4.1 MB

Checkpointing

save_checkpoint

pub fn save_checkpoint(&mut self, path: &std::path::Path) -> std::io::Result<()>
Saves all parameter values, Adam first/second moment buffers, and the Adam step counter to a safetensors file.
session.save_checkpoint(Path::new("checkpoint.safetensors"))?;

load_checkpoint

pub fn load_checkpoint(&mut self, path: &std::path::Path) -> std::io::Result<()>
Restores parameter values and Adam optimizer state from a safetensors file. The session must have been built from the same graph (same parameter names and sizes).
session.load_checkpoint(Path::new("checkpoint.safetensors"))?;
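A common pattern is to checkpoint periodically during training, then resume later from a fresh session built from the same graph. Sketch (file name illustrative; assumes a surrounding function that returns std::io::Result):

```rust
use std::path::Path;
use meganeura::build_session;

// During training: persist parameters and Adam state.
session.save_checkpoint(Path::new("checkpoint.safetensors"))?;

// Later: rebuild from the same graph, then restore.
let mut resumed = build_session(&graph);
resumed.load_checkpoint(Path::new("checkpoint.safetensors"))?;
```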
