After the graph is optimized and compiled, Meganeura hands the ExecutionPlan to a Session. The session owns all GPU resources and replays the pre-compiled dispatch sequence on every training step.

blade-graphics

Meganeura uses blade-graphics as its GPU abstraction layer. blade-graphics wraps Vulkan (Linux / Windows) and Metal (macOS) behind a single API. All buffer creation, shader compilation, and command encoding go through blade-graphics — you never write Vulkan or Metal code directly. The GPU context is created inside Session::new:
let gpu = unsafe {
    blade_graphics::Context::init(blade_graphics::ContextDesc {
        validation: cfg!(debug_assertions),
        timing: true,
        ..Default::default()
    })
}
.expect("failed to initialize blade GPU context");
Set the MEGANEURA_DEVICE_ID environment variable to select a specific GPU on multi-adapter systems.

ExecutionPlan

compile::compile() produces an ExecutionPlan — a fully static description of everything the GPU needs to do:
pub struct ExecutionPlan {
    /// Buffer sizes in bytes, indexed by BufferRef.
    pub buffers: Vec<usize>,
    /// Which buffers hold parameters (need initialization).
    pub param_buffers: Vec<(String, BufferRef)>,
    /// Which buffers hold inputs (filled each step).
    pub input_buffers: Vec<(String, BufferRef)>,
    /// Constant buffers with their initial data (uploaded once).
    pub constant_buffers: Vec<(BufferRef, Vec<f32>)>,
    /// The dispatch sequence.
    pub dispatches: Vec<Dispatch>,
    /// Index of the loss buffer (first graph output).
    pub loss_buffer: Option<BufferRef>,
    /// All graph output buffers.
    pub output_buffers: Vec<BufferRef>,
    /// Parameter buffer → gradient buffer mapping (for SGD/Adam).
    pub param_grad_pairs: Vec<(BufferRef, BufferRef)>,
    /// LSE buffers for MultiHeadAttn forward nodes.
    pub lse_buffers: Vec<(NodeId, BufferRef)>,
    /// Derived parameters (e.g. fused gate+up projections).
    pub derived_params: Vec<(BufferRef, Vec<(String, usize)>)>,
}
BufferRef is a typed u32 index into buffers. Every node in the graph gets one buffer; leaf nodes (Input, Parameter, Constant) have their buffers filled before dispatch execution begins.
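To make the "typed index" idea concrete, here is a minimal sketch of how such a reference could be modelled and used to look up buffer sizes. The `BufferRef` definition and the `PlanSketch` type with its helper methods are hypothetical stand-ins written for this illustration; they are not Meganeura's actual definitions.

```rust
/// Hypothetical stand-in for Meganeura's `BufferRef`: a newtype over `u32`,
/// so a buffer index cannot be mixed up with an ordinary integer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct BufferRef(pub u32);

/// Trimmed-down plan (illustration only) with two of the fields shown above.
pub struct PlanSketch {
    /// Buffer sizes in bytes, indexed by `BufferRef`.
    pub buffers: Vec<usize>,
    /// Named parameter buffers.
    pub param_buffers: Vec<(String, BufferRef)>,
}

impl PlanSketch {
    /// Size in bytes of the buffer behind `r`, or `None` if `r` is out of range.
    pub fn size_of(&self, r: BufferRef) -> Option<usize> {
        self.buffers.get(r.0 as usize).copied()
    }

    /// Find a parameter's buffer reference by name.
    pub fn param(&self, name: &str) -> Option<BufferRef> {
        self.param_buffers
            .iter()
            .find(|(n, _)| n.as_str() == name)
            .map(|(_, r)| *r)
    }
}
```

With a plan whose `buffers` is `vec![1024, 4096]` and whose only parameter is `("w", BufferRef(1))`, calling `param("w")` and then `size_of` on the result yields `Some(4096)`.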

Session

Session::new(plan) allocates all GPU buffers, compiles the required shaders, and reorders dispatches for optimal barrier grouping:
pub struct Session {
    // GPU context and buffers (blade_graphics types)
    // Compiled compute pipelines (scalar + cooperative + small-tile variants)
    // Pre-computed barrier groups
    // Adam state buffers (m, v per parameter)
    // ...
}
The primary interface:
| Method | Description |
| --- | --- |
| set_parameter(name, data) | Upload f32 data to a named parameter buffer |
| set_input(name, data) | Upload f32 data to a named input buffer |
| set_input_u32(name, data) | Upload u32 data (e.g. token IDs) |
| set_learning_rate(lr) | Schedule an SGD update on the next step() |
| set_adam(lr, b1, b2, eps) | Schedule an Adam update on the next step() |
| step() | Submit all dispatches (forward + backward + optimizer update) |
| wait() | Block until the GPU submission completes |
| read_loss() | Read back the scalar loss value |
| read_output(len) | Read back the primary output tensor |

Cooperative matrix operations

Modern GPUs expose hardware matrix-multiply units that operate on small tiles (e.g. 16×16 or 32×32 f16 matrices) directly in shader registers. On Vulkan these are surfaced via the VK_KHR_cooperative_matrix extension; blade-graphics exposes them through CooperativeMatrix capabilities. Meganeura queries the GPU at session creation time and selects the best tile configuration:
fn select_coop_config(
    caps: &blade_graphics::CooperativeMatrix,
) -> Option<crate::codegen::CoopConfig> {
    if caps.f16_tile > 0 {
        Some(CoopConfig { tile_size: caps.f16_tile, use_f16_input: true })
    } else if caps.f32_tile > 0 {
        Some(CoopConfig { tile_size: caps.f32_tile, use_f16_input: false })
    } else {
        None
    }
}
f16 input is preferred when available because it provides higher throughput. A correctness self-test (test_coop_matmul) runs immediately after selection, because some drivers (e.g. AMD RADV) advertise the extension but reject specific matrix shapes at shader creation time.

MIN_COOP_WORKGROUPS: scalar vs. cooperative path

The cooperative path is not unconditionally faster. It uses a 2×2 tile layout, so each wave covers a 2*tile_size × 2*tile_size output region. For small matrices the workgroup count is too low to saturate the GPU, and the scalar tiled kernel runs faster. Meganeura evaluates each dispatch individually:
/// Minimum workgroup count below which the cooperative-matrix path is skipped.
const MIN_COOP_WORKGROUPS: u32 = 128;
/// Lower threshold when K ≥ 1024 (high arithmetic intensity amortises coop overhead).
const MIN_COOP_WORKGROUPS_HIGH_K: u32 = 32;
For a dispatch with output shape [M, N], where output_tile = 2 * tile_size (the edge length of the 2×2 tile region described above):
coop_wgs = ceil(M / output_tile) * ceil(N / output_tile)
If coop_wgs >= MIN_COOP_WORKGROUPS (or MIN_COOP_WORKGROUPS_HIGH_K when K ≥ 1024), the dispatch is marked use_coop = true and its workgroup count is recomputed for the coop tile size. Otherwise the scalar 64×64 tiled kernel is used.
For SmolVLA with chunk_size=50, the largest matmuls have M=50, N=2048. With output_tile = 32 this gives only ceil(50/32) * ceil(2048/32) = 2 * 64 = 128 workgroups, exactly at the MIN_COOP_WORKGROUPS threshold. For these shapes with K ≥ 1024 (the high-K backward gradient matmuls), enabling the coop path improves throughput by about 50% on discrete GPUs.
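The selection rule above can be written out as a small, self-contained sketch. The two constants are the ones quoted in the text; the function names `coop_workgroups` and `use_coop` and the helper `ceil_div` are hypothetical, chosen for this illustration, and `tile_size` is assumed to be non-zero.

```rust
/// Thresholds quoted in the text above.
const MIN_COOP_WORKGROUPS: u32 = 128;
const MIN_COOP_WORKGROUPS_HIGH_K: u32 = 32;

/// Integer ceiling division; `b` must be non-zero.
fn ceil_div(a: u32, b: u32) -> u32 {
    (a + b - 1) / b
}

/// Workgroups the cooperative path would launch for an [M, N] output,
/// given that each one covers a (2 * tile_size) x (2 * tile_size) region.
pub fn coop_workgroups(m: u32, n: u32, tile_size: u32) -> u32 {
    let output_tile = 2 * tile_size;
    ceil_div(m, output_tile) * ceil_div(n, output_tile)
}

/// Hypothetical per-dispatch decision: take the cooperative path only when
/// enough workgroups would be launched, with a lower bar when K >= 1024.
pub fn use_coop(m: u32, n: u32, k: u32, tile_size: u32) -> bool {
    let threshold = if k >= 1024 {
        MIN_COOP_WORKGROUPS_HIGH_K
    } else {
        MIN_COOP_WORKGROUPS
    };
    coop_workgroups(m, n, tile_size) >= threshold
}
```

With `tile_size = 16`, `coop_workgroups(50, 2048, 16)` evaluates to 128, matching the SmolVLA arithmetic, so `use_coop` returns `true` for that shape at any K. A narrower [50, 512] output produces 32 workgroups and qualifies only when K ≥ 1024.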

Dispatch barrier groups

Within a single training step, many dispatches are independent (e.g. the Q, K, and V projection matmuls can all execute concurrently). Meganeura groups dispatches into barrier groups to maximise concurrency:
  1. Dispatches are sorted by dependency level (Kahn’s algorithm over buffer write→read edges).
  2. All dispatches at the same level form one compute pass with no internal barriers.
  3. A pass boundary (ALL_COMMANDS barrier in blade-graphics) separates levels.
This means independent parallel branches automatically cluster together without any manual annotation.
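The three steps above can be sketched as a level-by-level variant of Kahn's algorithm. This is an illustration under stated assumptions, not Meganeura's implementation: `DispatchSketch` and `barrier_groups` are hypothetical names, buffers are plain `u32` ids, dispatches are assumed to arrive in a valid execution order, and only write→read edges are tracked (as in the description above).

```rust
/// Hypothetical, simplified dispatch: just the buffer ids it reads and writes.
pub struct DispatchSketch {
    pub reads: Vec<u32>,
    pub writes: Vec<u32>,
}

/// Group dispatch indices into dependency levels. Every dispatch in a group
/// depends only on dispatches in earlier groups, so a group needs no
/// internal barriers.
pub fn barrier_groups(dispatches: &[DispatchSketch]) -> Vec<Vec<usize>> {
    let n = dispatches.len();
    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut in_degree = vec![0usize; n];

    // Edge j -> i when an earlier dispatch j writes a buffer that i reads.
    for i in 0..n {
        for j in 0..i {
            let depends = dispatches[j]
                .writes
                .iter()
                .any(|b| dispatches[i].reads.contains(b));
            if depends {
                successors[j].push(i);
                in_degree[i] += 1;
            }
        }
    }

    // Kahn's algorithm, peeling off one whole level per iteration.
    let mut groups = Vec::new();
    let mut ready: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
    while !ready.is_empty() {
        let mut next = Vec::new();
        for &d in &ready {
            for &s in &successors[d] {
                in_degree[s] -= 1;
                if in_degree[s] == 0 {
                    next.push(s);
                }
            }
        }
        next.sort_unstable(); // deterministic order within a level
        groups.push(ready);
        ready = next;
    }
    groups
}
```

For three projection dispatches that all read buffer 0 and write buffers 1, 2, and 3, followed by one dispatch that reads all three results, the function returns `[[0, 1, 2], [3]]`: the projections share one pass and the consumer runs after a single barrier.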

MemorySummary

Session::memory_summary() returns a breakdown of GPU memory usage:
pub struct MemorySummary {
    pub total_buffer_bytes: usize,
    pub adam_state_bytes: usize,
    pub num_buffers: usize,
    pub largest_buffer_bytes: usize,
}
Printing it gives a human-readable summary:
42 buffers, 128.4 MB total (64.2 MB adam state), largest 32.0 MB
adam_state_bytes is the combined size of all first- and second-moment buffers (m and v in the Adam optimizer). Each parameter gets two state buffers, each the same size as the parameter itself.
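As a worked example of the Adam accounting, the sketch below computes the state size from a list of parameter sizes and formats a summary line in the style shown above. `MemorySummarySketch` is a local copy of the struct for illustration, and `adam_state_bytes` here is a hypothetical free function. The `Display` implementation assumes 1 MB = 1,000,000 bytes, which may differ from what Meganeura actually prints.

```rust
use std::fmt;

/// Local stand-in mirroring the fields of `MemorySummary` shown above.
pub struct MemorySummarySketch {
    pub total_buffer_bytes: usize,
    pub adam_state_bytes: usize,
    pub num_buffers: usize,
    pub largest_buffer_bytes: usize,
}

/// Adam keeps two moment buffers (m and v) per parameter, each the same
/// size as the parameter, so the state is twice the total parameter size.
pub fn adam_state_bytes(param_sizes_bytes: &[usize]) -> usize {
    2 * param_sizes_bytes.iter().sum::<usize>()
}

impl fmt::Display for MemorySummarySketch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Assumption: decimal megabytes (1 MB = 1_000_000 bytes).
        let mb = |bytes: usize| bytes as f64 / 1_000_000.0;
        write!(
            f,
            "{} buffers, {:.1} MB total ({:.1} MB adam state), largest {:.1} MB",
            self.num_buffers,
            mb(self.total_buffer_bytes),
            mb(self.adam_state_bytes),
            mb(self.largest_buffer_bytes),
        )
    }
}
```

Two parameters of 32,000,000 and 100,000 bytes give 64,200,000 bytes of Adam state, which is the 64.2 MB figure in the sample line under the decimal-megabyte assumption.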

Perfetto trace output

Session integrates with the meganeura::profiler module to produce Perfetto binary traces (.pftrace). CPU-side work appears as nested spans on the CPU track. GPU pass durations come from blade-graphics hardware timestamp queries and appear on a separate GPU track. Enable profiling before calling step():
meganeura::profiler::init();   // install tracing subscriber

session.set_profiling(true);
session.step();
session.wait();
session.step();  // timings from the previous step are collected here
session.wait();

session.dump_gpu_timings();    // prints per-pass aggregated timings to stderr

meganeura::profiler::save("trace.pftrace").unwrap();
Open trace.pftrace in the Perfetto UI to see the full timeline.

Next steps

Computation graph

How to build and inspect the graph that the session executes.

Profiling

How to capture and interpret Perfetto traces for performance analysis.
