An inference session is a compiled GPU session that runs only the forward pass of a model. It skips autodiff entirely — no gradient buffers, no backward-pass dispatches, and no Adam state — which makes it faster to build and cheaper to run than a training session.
## Training sessions vs inference sessions
Both session types share the same Session type and the same execution model: you set inputs and parameters, call step(), wait for the GPU, and read outputs. The difference lies in how they are built.
| | Training session | Inference session |
|---|---|---|
| Builder | build_session(&graph) | build_inference_session(&graph) |
| Backward pass | Yes, via autodiff | No |
| Gradient buffers | Allocated | Not allocated |
| Adam state buffers | Allocated per parameter | None (adam_state_bytes = 0) |
| Optimizer fusions | SwiGLU, MatMul+Add, backward fusions | SwiGLU, MatMul+Add, GroupNorm+SiLU |
| Typical use | Fine-tuning, training from scratch | Deployment, evaluation, benchmarking |
Choose build_inference_session whenever you do not need to update weights. This reduces GPU memory allocation and removes the overhead of backward-pass kernel compilation.
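To get a feel for the savings, here is a back-of-the-envelope sketch (not an exact accounting of meganeura's allocator): a training session carries roughly three extra f32 buffers per parameter — one gradient buffer plus two Adam moment buffers, each the same size as the parameter itself.

```rust
/// Approximate training-only overhead: one gradient buffer plus two Adam
/// moment buffers per f32 parameter tensor. A rough sketch for intuition,
/// not an exact model of meganeura's buffer allocation.
fn training_overhead_bytes(param_shapes: &[&[usize]]) -> usize {
    let elems: usize = param_shapes
        .iter()
        .map(|shape| shape.iter().product::<usize>())
        .sum();
    // 3 extra f32 buffers (gradient + 2 Adam moments) per parameter element.
    3 * elems * std::mem::size_of::<f32>()
}

fn main() {
    // A 784x256 weight plus its 256-element bias: 200_960 parameters.
    let overhead = training_overhead_bytes(&[&[784, 256], &[256]]);
    assert_eq!(overhead, 2_411_520); // ≈ 2.4 MB that an inference session never allocates
}
```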
## Building an inference session
Call build_inference_session with a forward-only graph. The function runs egglog optimization, applies inference-specific fusions (such as GroupNorm+SiLU → GroupNormSilu), compiles the graph to an execution plan, and creates a GPU session.
```rust
use meganeura::{Graph, build_inference_session};

let mut g = Graph::new();
let x = g.input("x", &[1, 784]);
let w1 = g.parameter("w1", &[784, 256]);
let b1 = g.parameter("b1", &[256]);
let h1 = g.bias_add(g.matmul(x, w1), b1);
let h1 = g.relu(h1);
let w2 = g.parameter("w2", &[256, 10]);
let b2 = g.parameter("b2", &[10]);
let logits = g.bias_add(g.matmul(h1, w2), b2);
let probs = g.softmax(logits);
g.set_outputs(vec![probs]);

let mut session = build_inference_session(&g);
```
build_inference_session does not require that g.set_outputs has already been called, but a session with no outputs has nothing to read back — so in practice, always set at least one output before compiling.
## Loading weights
Use session.set_parameter to upload weight data from CPU memory to the GPU buffer for a named parameter. The name must match exactly the string passed to g.parameter(...) when building the graph.
```rust
// Upload a weight matrix and bias vector
session.set_parameter("w1", &weights_w1); // &[f32]
session.set_parameter("b1", &weights_b1);
```
Call set_parameter for every parameter in the graph before the first step(). Parameters retain their values between steps — you only need to call set_parameter again if you want to change the weights.
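The slice you pass should contain exactly as many elements as the parameter's shape implies. A small guard like the following can catch shape mismatches before upload — `expected_len` is a hypothetical helper for this guide, not part of the meganeura API, and the assumption is that set_parameter expects the element count to match the declared shape:

```rust
/// Number of f32 elements a parameter of the given shape holds.
/// A helper for this guide, not part of the meganeura API.
fn expected_len(shape: &[usize]) -> usize {
    shape.iter().product()
}

fn main() {
    let weights_w1 = vec![0.0f32; 784 * 256]; // stand-in weight data
    // Guard before uploading: the slice length should match the shape
    // passed to g.parameter("w1", &[784, 256]).
    assert_eq!(weights_w1.len(), expected_len(&[784, 256]));
    // session.set_parameter("w1", &weights_w1);
}
```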
## Running a forward pass
### Set the input
Upload the input tensor with set_input (for f32 data) or set_input_u32 (for token IDs and other integer inputs).

```rust
session.set_input("x", &image_data); // f32 slice
session.set_input_u32("token_ids", &ids); // u32 slice
```
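Inputs are passed as flat slices; for a multi-dimensional tensor like the [1, 784] image input above, that means flattening before upload. A minimal sketch, assuming row-major ordering (check your model's expected layout):

```rust
/// Flatten a 2-D image into the flat slice set_input expects.
/// Row-major ordering is an assumption for this sketch.
fn flatten(image: &[Vec<f32>]) -> Vec<f32> {
    image.iter().flat_map(|row| row.iter().copied()).collect()
}

fn main() {
    let image = vec![vec![0.0f32; 28]; 28]; // stand-in 28x28 image
    let flat = flatten(&image);
    assert_eq!(flat.len(), 784); // matches the [1, 784] input shape above
    // session.set_input("x", &flat);
}
```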
### Dispatch to the GPU
Call step() to submit the full dispatch sequence to the GPU. This returns immediately without blocking.

### Wait for completion
Call wait() to block until the GPU finishes all pending work.

### Read the output
Read back the primary graph output with read_output, passing the number of elements you expect.

```rust
let probs = session.read_output(10); // Vec<f32> with 10 class probabilities
```
If the graph has multiple outputs, use read_output_by_index to read by position, or check session.num_outputs() first.

```rust
let n = session.num_outputs();
let mut buf = vec![0.0f32; expected_len];
session.read_output_by_index(0, &mut buf);
```
## Inspecting memory usage
Call session.memory_summary() to get a breakdown of GPU buffer allocation. The returned MemorySummary struct has four fields:
```rust
pub struct MemorySummary {
    /// Total bytes across all GPU buffers.
    pub total_buffer_bytes: usize,
    /// Bytes used by Adam first/second moment buffers.
    /// Always 0 for inference sessions.
    pub adam_state_bytes: usize,
    /// Number of GPU buffers allocated.
    pub num_buffers: usize,
    /// Size of the largest single buffer, in bytes.
    pub largest_buffer_bytes: usize,
}
```
MemorySummary implements Display, so you can print it directly:
```rust
let summary = session.memory_summary();
println!("{}", summary);
// 47 buffers, 312.4 MB total (0.0 MB adam state), largest 141.6 MB
```
For an inference session, adam_state_bytes is always 0. For a training session built with build_session, it reflects the memory occupied by the per-parameter Adam first and second moment buffers (2 × parameter_size per parameter).
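Applying the 2 × parameter_size formula to the two-layer MLP built earlier (784×256 + 256 weights and biases in the first layer, 256×10 + 10 in the second — 203,530 f32 parameters in total):

```rust
/// Adam state size for a training session: two moment buffers per
/// parameter, each the size of the parameter itself (per the formula above).
fn adam_state_bytes(param_elems: &[usize]) -> usize {
    let param_bytes: usize = param_elems
        .iter()
        .map(|&n| n * std::mem::size_of::<f32>())
        .sum();
    2 * param_bytes
}

fn main() {
    // Parameter element counts for the 784 -> 256 -> 10 MLP.
    let bytes = adam_state_bytes(&[784 * 256, 256, 256 * 10, 10]);
    assert_eq!(bytes, 1_628_240); // ≈ 1.6 MB a training session would add
}
```

An inference session skips this allocation entirely, which is why its memory_summary() always reports 0 adam state bytes.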
## Complete example
The following example mirrors examples/huggingface.rs and shows the full flow from graph construction to reading predictions.
```rust
use meganeura::{Graph, build_inference_session};
use meganeura::data::safetensors::SafeTensorsModel;

// 1. Build the graph
let mut g = Graph::new();
let x = g.input("x", &[1, 784]);
let w1 = g.parameter("input_layer.weight", &[784, 256]);
let b1 = g.parameter("input_layer.bias", &[256]);
let h1 = g.relu(g.bias_add(g.matmul(x, w1), b1));
let w2 = g.parameter("mid_layer.weight", &[256, 256]);
let b2 = g.parameter("mid_layer.bias", &[256]);
let h2 = g.relu(g.bias_add(g.matmul(h1, w2), b2));
let w3 = g.parameter("output_layer.weight", &[256, 10]);
let b3 = g.parameter("output_layer.bias", &[10]);
let logits = g.bias_add(g.matmul(h2, w3), b3);
let probs = g.softmax(logits);
g.set_outputs(vec![probs]);

// 2. Compile
let mut session = build_inference_session(&g);
println!("{}", session.memory_summary());

// 3. Load weights from a safetensors file
let hf = SafeTensorsModel::download("dacorvo/mnist-mlp")
    .expect("failed to download model");
for name in ["input_layer.weight", "mid_layer.weight", "output_layer.weight"] {
    let data = hf.tensor_f32_transposed(name).unwrap();
    session.set_parameter(name, &data);
}
for name in ["input_layer.bias", "mid_layer.bias", "output_layer.bias"] {
    let data = hf.tensor_f32(name).unwrap();
    session.set_parameter(name, &data);
}

// 4. Run inference
let image: Vec<f32> = raw_pixels.iter()
    .map(|&v| (v - 0.1307) / 0.3081)
    .collect();
session.set_input("x", &image);
session.step();
session.wait();

let probs = session.read_output(10);
let predicted = probs.iter()
    .enumerate()
    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
    .unwrap()
    .0;
println!("predicted class: {}", predicted);
```
You can reuse a session across many inputs without rebuilding it. Set a new input, call step() and wait(), and read the output again. Parameters are preserved on the GPU between calls.
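When validating a deployment, it can help to cross-check the GPU probabilities against a CPU reference. Here is a numerically stabilized softmax sketch — assuming g.softmax computes a standard softmax over the output vector; small float differences from the GPU kernel are expected:

```rust
/// CPU reference softmax for sanity-checking read_output results.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max before exponentiating for numerical stability.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    let probs = softmax(&[0.0, 0.0, 0.0, 0.0]);
    // Uniform logits give uniform probabilities.
    assert!(probs.iter().all(|&p| (p - 0.25).abs() < 1e-6));
    // Any softmax output sums to 1 (up to float error).
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
}
```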