An inference session is a compiled GPU session that runs only the forward pass of a model. It skips autodiff entirely — no gradient buffers, no backward-pass dispatches, and no Adam state — which makes it faster to build and cheaper to run than a training session.

Training sessions vs inference sessions

Both session types share the same Session type and the same execution model: you set inputs and parameters, call step(), wait for the GPU, and read outputs. The difference lies in how they are built.
|                     | Training session                      | Inference session                      |
| ------------------- | ------------------------------------- | -------------------------------------- |
| Builder             | build_session(&graph)                 | build_inference_session(&graph)        |
| Backward pass       | Yes, via autodiff                     | No                                     |
| Gradient buffers    | Allocated                             | Not allocated                          |
| Adam state buffers  | Allocated per parameter               | None (adam_state_bytes = 0)            |
| Optimizer fusions   | SwiGLU, MatMul+Add, backward fusions  | SwiGLU, MatMul+Add, GroupNorm+SiLU     |
| Typical use         | Fine-tuning, training from scratch    | Deployment, evaluation, benchmarking   |
Choose build_inference_session whenever you do not need to update weights. This reduces GPU memory allocation and removes the overhead of backward-pass kernel compilation.
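To make the savings concrete, here is a back-of-the-envelope sketch (not an API call) of the training-only buffers an inference session skips. It assumes f32 storage, one gradient buffer the same size as each parameter, and the two Adam moment buffers listed in the table above:

```rust
// Rough estimate of GPU memory a training session allocates but an inference
// session skips: one f32 gradient buffer per parameter plus two f32 Adam
// moment buffers. f32 storage and gradient-buffer sizing are assumptions.
fn training_only_bytes(param_elements: &[usize]) -> usize {
    param_elements
        .iter()
        .map(|&n| n * 4 + 2 * n * 4) // gradient + Adam m and v, 4 bytes each
        .sum()
}
```

For the small MLP built in the next section (w1, b1, w2, b2), `training_only_bytes(&[784 * 256, 256, 256 * 10, 10])` comes to roughly 2.4 MB that the inference session never allocates.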

Building an inference session

Call build_inference_session with a forward-only graph. The function runs egglog optimization, applies inference-specific fusions (such as GroupNorm+SiLU → GroupNormSilu), compiles the graph to an execution plan, and creates a GPU session.
use meganeura::{Graph, build_inference_session};

let mut g = Graph::new();
let x  = g.input("x", &[1, 784]);
let w1 = g.parameter("w1", &[784, 256]);
let b1 = g.parameter("b1", &[256]);
let h1 = g.bias_add(g.matmul(x, w1), b1);
let h1 = g.relu(h1);
let w2 = g.parameter("w2", &[256, 10]);
let b2 = g.parameter("b2", &[10]);
let logits = g.bias_add(g.matmul(h1, w2), b2);
let probs  = g.softmax(logits);
g.set_outputs(vec![probs]);

let mut session = build_inference_session(&g);
build_inference_session itself does not require g.set_outputs to have been called, but a session built from a graph with no outputs has nothing to read back — so set at least one output before building.

Loading weights

Use session.set_parameter to upload weight data from CPU memory to the GPU buffer for a named parameter. The name must match exactly the string passed to g.parameter(...) when building the graph.
// Upload a weight matrix and bias vector
session.set_parameter("w1", &weights_w1);  // &[f32]
session.set_parameter("b1", &weights_b1);
Call set_parameter for every parameter in the graph before the first step(). Parameters retain their values between steps — you only need to call set_parameter again if you want to change the weights.
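A size mismatch between a CPU slice and the shape declared via g.parameter is easiest to catch before any upload. The following pre-flight check is a hypothetical helper, not part of the meganeura API:

```rust
// Hypothetical pre-flight check (not part of the meganeura API): verify that
// each CPU-side weight slice has the element count its declared shape implies,
// so a mismatched upload is caught before set_parameter is ever called.
fn check_parameter_sizes(
    expected: &[(&str, &[usize])], // (name, shape) as declared via g.parameter
    provided: &[(&str, usize)],    // (name, data.len()) for each CPU slice
) -> Vec<String> {
    let mut errors = Vec::new();
    for &(name, shape) in expected {
        let want: usize = shape.iter().product();
        match provided.iter().find(|&&(n, _)| n == name) {
            None => errors.push(format!("{name}: no data provided")),
            Some(&(_, got)) if got != want => {
                errors.push(format!("{name}: expected {want} elements, got {got}"))
            }
            _ => {}
        }
    }
    errors
}
```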

Running a forward pass

1. Set the input

Upload the input tensor with set_input (for f32 data) or set_input_u32 (for token IDs and other integer inputs).
session.set_input("x", &image_data);      // f32 slice
session.set_input_u32("token_ids", &ids); // u32 slice
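set_input takes a flat slice, so multi-dimensional data must be flattened first. For the [1, 784] input used in this page's examples, a 28×28 image flattens like this (row-major layout is an assumption about the expected ordering):

```rust
// Flatten a 28x28 image into the flat f32 slice that set_input expects for a
// [1, 784] input. Row-major layout is assumed: element (row, col) lands at
// index row * 28 + col.
fn flatten_image(image: &[[f32; 28]; 28]) -> Vec<f32> {
    image.iter().flat_map(|row| row.iter().copied()).collect()
}
```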
2. Dispatch to the GPU

Call step() to submit the full dispatch sequence to the GPU. This returns immediately without blocking.
session.step();
3. Wait for completion

Call wait() to block until the GPU finishes all pending work.
session.wait();
4. Read the output

Read back the primary graph output with read_output, passing the number of elements you expect.
let probs = session.read_output(10); // Vec<f32> with 10 class probabilities
If the graph has multiple outputs, use read_output_by_index to read by position, or check session.num_outputs() first.
let n = session.num_outputs();
let mut buf = vec![0.0f32; expected_len]; // expected_len: element count of output 0
session.read_output_by_index(0, &mut buf);
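The element count to pass is just the product of the output tensor's dimensions. A tiny helper (not part of the API) makes that explicit:

```rust
// Element count of a tensor, for sizing read_output / read_output_by_index
// buffers: the product of the shape's dimensions.
fn num_elements(shape: &[usize]) -> usize {
    shape.iter().product()
}
```

For the [1, 10] softmax output above, `num_elements(&[1, 10])` is 10 — the value passed to read_output.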

Inspecting memory usage

Call session.memory_summary() to get a breakdown of GPU buffer allocation. The returned MemorySummary struct has four fields:
pub struct MemorySummary {
    /// Total bytes across all GPU buffers.
    pub total_buffer_bytes: usize,
    /// Bytes used by Adam first/second moment buffers.
    /// Always 0 for inference sessions.
    pub adam_state_bytes: usize,
    /// Number of GPU buffers allocated.
    pub num_buffers: usize,
    /// Size of the largest single buffer, in bytes.
    pub largest_buffer_bytes: usize,
}
MemorySummary implements Display, so you can print it directly:
let summary = session.memory_summary();
println!("{}", summary);
// 47 buffers, 312.4 MB total (0.0 MB adam state), largest 141.6 MB
For an inference session, adam_state_bytes is always 0. For a training session built with build_session, it reflects the memory occupied by the per-parameter Adam first and second moment buffers (2 × parameter_size per parameter).
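Following that rule, the expected adam_state_bytes for a training session can be computed by hand. This sketch assumes the moment buffers are f32:

```rust
// Expected adam_state_bytes for a training session: two f32 moment buffers
// (first and second moment) per parameter, per the 2 x parameter_size rule.
// f32 storage is an assumption here.
fn adam_state_bytes(param_elements: &[usize]) -> usize {
    param_elements
        .iter()
        .map(|&n| 2 * n * std::mem::size_of::<f32>())
        .sum()
}
```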

Complete example

The following example mirrors examples/huggingface.rs and shows the full flow from graph construction to reading predictions.
use meganeura::{Graph, build_inference_session};
use meganeura::data::safetensors::SafeTensorsModel;

// 1. Build the graph
let mut g = Graph::new();
let x      = g.input("x", &[1, 784]);
let w1     = g.parameter("input_layer.weight", &[784, 256]);
let b1     = g.parameter("input_layer.bias",   &[256]);
let h1     = g.relu(g.bias_add(g.matmul(x, w1), b1));
let w2     = g.parameter("mid_layer.weight", &[256, 256]);
let b2     = g.parameter("mid_layer.bias",   &[256]);
let h2     = g.relu(g.bias_add(g.matmul(h1, w2), b2));
let w3     = g.parameter("output_layer.weight", &[256, 10]);
let b3     = g.parameter("output_layer.bias",   &[10]);
let logits = g.bias_add(g.matmul(h2, w3), b3);
let probs  = g.softmax(logits);
g.set_outputs(vec![probs]);

// 2. Compile
let mut session = build_inference_session(&g);
println!("{}", session.memory_summary());

// 3. Load weights from a safetensors file
let hf = SafeTensorsModel::download("dacorvo/mnist-mlp")
    .expect("failed to download model");
for name in ["input_layer.weight", "mid_layer.weight", "output_layer.weight"] {
    let data = hf.tensor_f32_transposed(name).unwrap();
    session.set_parameter(name, &data);
}
for name in ["input_layer.bias", "mid_layer.bias", "output_layer.bias"] {
    let data = hf.tensor_f32(name).unwrap();
    session.set_parameter(name, &data);
}

// 4. Run inference
// raw_pixels: Vec<f32> of 784 pixel values in [0, 1], normalized here with
// the standard MNIST mean and standard deviation.
let image: Vec<f32> = raw_pixels.iter()
    .map(|&v| (v - 0.1307) / 0.3081)
    .collect();
session.set_input("x", &image);
session.step();
session.wait();

let probs = session.read_output(10);
let predicted = probs.iter()
    .enumerate()
    .max_by(|a, b| a.1.total_cmp(b.1)) // total_cmp: no panic, NaN-safe ordering
    .unwrap()
    .0;
println!("predicted class: {}", predicted);
You can reuse a session across many inputs without rebuilding it. Set a new input, call step() and wait(), and read the output again. Parameters are preserved on the GPU between calls.
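The weight matrices above are read with tensor_f32_transposed because linear-layer weights in safetensors checkpoints are commonly stored as [out_features, in_features], while the graph's matmul expects [in_features, out_features] — the storage convention is an assumption about this particular checkpoint. A standalone sketch of that transpose:

```rust
// Transpose a row-major [rows, cols] matrix stored as a flat slice into
// [cols, rows]. This mirrors what tensor_f32_transposed is used for above,
// assuming the checkpoint stores linear weights as [out_features, in_features].
fn transpose(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols);
    let mut out = vec![0.0f32; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}
```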
