The Graph struct is the central abstraction in Meganeura. You define a model by adding nodes to a graph, then compile the graph into a GPU session for training or inference.
use meganeura::Graph;

let mut g = Graph::new();
let x = g.input("x", &[4, 784]);
let w = g.parameter("w", &[784, 128]);
let y = g.matmul(x, w);
let h = g.relu(y);
g.set_outputs(vec![h]);

Types

NodeId

pub type NodeId = u32;
A handle to a node in the graph. All graph builder methods return a NodeId and accept NodeId values as inputs.

DType

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DType {
    F32,
    U32,
}
The element type of a tensor.
F32
variant
32-bit floating point. Used for activations, weights, and gradients.
U32
variant
32-bit unsigned integer. Used for token indices, position offsets, and KV cache positions.

Methods

size_bytes
fn(self) -> usize
Returns the byte width of one element. Both F32 and U32 return 4.

TensorType

pub struct TensorType {
    pub shape: Vec<usize>,
    pub dtype: DType,
}
Describes the shape and element type of a tensor node.
shape
Vec<usize>
Dimension sizes, from outermost to innermost (e.g. [batch, features]).
dtype
DType
Element type of the tensor.

Methods

new
fn(shape: Vec<usize>, dtype: DType) -> TensorType
Constructs a TensorType with the given shape and dtype.
f32
fn(shape: Vec<usize>) -> TensorType
Shorthand for TensorType::new(shape, DType::F32).
num_elements
fn(&self) -> usize
Product of all shape dimensions. A [32, 784] tensor has 25088 elements.
size_bytes
fn(&self) -> usize
Total byte size: num_elements() * dtype.size_bytes().
rank
fn(&self) -> usize
Number of dimensions. TensorType::f32(vec![4, 3]).rank() returns 2.

Node

pub struct Node {
    pub id: NodeId,
    pub op: Op,
    pub inputs: Vec<NodeId>,
    pub ty: TensorType,
}
A single node in the computation graph.
id
NodeId
The node’s unique index within its graph.
op
Op
The operation this node computes.
inputs
Vec<NodeId>
Ordered list of input node IDs.
ty
TensorType
Shape and dtype of the output tensor produced by this node.

Graph

pub struct Graph { /* ... */ }

Constructor

new
fn() -> Graph
Creates an empty graph with no nodes and no outputs.
let mut g = Graph::new();

Leaf nodes

Leaf nodes have no inputs. They represent data entering the graph from outside: runtime inputs, trainable weights, or fixed constants.

input

name
&str
required
Name used to identify this input when calling session.set_input().
shape
&[usize]
required
Shape of the input tensor. Creates a DType::F32 tensor.
Returns a NodeId. The session must supply data for this name before each forward pass.
let x = g.input("x", &[4, 784]);

input_u32

Like input, but creates a DType::U32 tensor. Used for token IDs and position indices.
name
&str
required
Name used with session.set_input_u32().
shape
&[usize]
required
Shape of the input tensor.
let tokens = g.input_u32("tokens", &[32]);

parameter

name
&str
required
Name used with session.set_parameter() and session.get_parameter().
shape
&[usize]
required
Shape of the parameter tensor. Always DType::F32.
Creates a trainable parameter node. The optimizer updates this buffer after each backward pass.
let w = g.parameter("w", &[784, 128]);
let b = g.parameter("b", &[128]);

constant

data
Vec<f32>
required
Initial values. Length must equal the product of all shape dimensions.
shape
&[usize]
required
Shape of the constant tensor.
Creates a fixed-value node. The data is uploaded to GPU at session creation and never changes.
let c = g.constant(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);

scalar

value
f32
required
The scalar value.
Shorthand for constant(vec![value], &[1]).
let scale = g.scalar(0.5);

Binary ops

matmul

Matrix multiplication: C = A × B.
a
NodeId
required
Left matrix [M, K].
b
NodeId
required
Right matrix [K, N].
Returns [M, N]. Both inputs must be 2D and their inner dimensions must match.
let y = g.matmul(x, w); // x: [4, 784], w: [784, 128] → [4, 128]

matmul_at

Transposed-A matrix multiplication: C = A^T × B.
a
NodeId
required
Matrix stored as [K, M].
b
NodeId
required
Matrix stored as [K, N].
Returns [M, N]. K dimensions of A and B must match.

matmul_bt

Transposed-B matrix multiplication: C = A × B^T.
a
NodeId
required
Matrix [M, K].
b
NodeId
required
Matrix stored as [N, K].
Returns [M, N]. K dimensions of A and B must match.
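The transposed variants avoid materializing an explicit transpose node. For example, if a linear layer's weights are stored row-major as [out_features, in_features] (a common checkpoint layout; an assumption here, not a requirement of the API), matmul_bt applies them directly:

```rust
// x: [4, 784]; w stored as [128, 784] (out, in).
// matmul_bt computes x × w^T → [4, 128] with no transpose op.
let x = g.input("x", &[4, 784]);
let w = g.parameter("w", &[128, 784]);
let y = g.matmul_bt(x, w); // [4, 128]
```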

add

Element-wise addition. Both inputs must have identical shapes.
a
NodeId
required
Tensor.
b
NodeId
required
Tensor with same shape as a.

mul

Element-wise multiplication. Both inputs must have identical shapes.
a
NodeId
required
Tensor.
b
NodeId
required
Tensor with same shape as a.
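Because add requires identical shapes, a residual connection needs the branch to preserve its input's shape. A minimal sketch with a square weight matrix:

```rust
// Residual block: out = x + relu(x × w). w is square so the
// branch output matches x, as add requires.
let x = g.input("x", &[4, 128]);
let w = g.parameter("w", &[128, 128]);
let mm = g.matmul(x, w);   // [4, 128]
let h = g.relu(mm);
let out = g.add(x, h);     // [4, 128]
```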

bias_add

Adds a 1D bias to a 2D input: [M, N] + [N] → [M, N].
a
NodeId
required
2D input tensor [M, N].
bias
NodeId
required
1D bias tensor [N].
let h = g.bias_add(mm, b); // mm: [4, 128], b: [128] → [4, 128]

broadcast_add

Adds a [1, N] tensor across a [M, N] tensor. Uses the same underlying shader as bias_add.
a
NodeId
required
2D input [M, N].
b
NodeId
required
2D addend [1, N].

greater

Element-wise greater-than comparison. Returns 1.0 where a > b, 0.0 otherwise. Both inputs must have identical shapes.
a
NodeId
required
Tensor.
b
NodeId
required
Tensor with same shape as a.
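greater composes with reductions to compute simple on-graph statistics, e.g. the fraction of positive activations. A sketch using only ops documented here (h is an assumed [4, 128] node):

```rust
// Compare against a same-shaped zero constant, then average the
// resulting 0.0/1.0 mask to get the fraction of elements > 0.
let zeros = g.constant(vec![0.0; 4 * 128], &[4, 128]);
let mask = g.greater(h, zeros);  // 1.0 where h > 0
let frac = g.mean_all(mask);     // scalar [1]
```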

Unary ops

All unary ops preserve the shape and dtype of their input.

relu

Rectified linear unit: max(0, x).

sigmoid

Logistic sigmoid: 1 / (1 + exp(-x)).

neg

Element-wise negation: -x.

abs

Element-wise absolute value: |x|.

log

Element-wise natural logarithm: ln(x).

recip

Element-wise reciprocal: 1 / x.

div

Element-wise division: a / b. Implemented as mul(a, recip(b)).
a
NodeId
required
Numerator tensor.
b
NodeId
required
Denominator tensor. Must have the same shape as a.

Reductions

sum_all

Reduces all elements to a scalar: output shape is [1].
x
NodeId
required
Input tensor of any shape.

mean_all

Reduces all elements to their mean: output shape is [1].
x
NodeId
required
Input tensor of any shape.

softmax

Row-wise softmax. Output has the same shape as the input.
x
NodeId
required
2D input tensor [batch, features].

log_softmax

Numerically stable log-softmax. Output has the same shape as the input.
x
NodeId
required
2D input tensor [batch, features].
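A typical classifier head applies one of these to raw logits. Sketch; h and w_out are assumed to be defined upstream:

```rust
let logits = g.matmul(h, w_out);  // [batch, num_classes]
let probs = g.softmax(logits);    // rows sum to 1
let logp = g.log_softmax(logits); // stable log-probabilities
```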

Loss functions

All loss functions return a scalar [1] node.

cross_entropy_loss

logits
NodeId
required
Raw (pre-softmax) scores [batch, num_classes].
labels
NodeId
required
Target distribution [batch, num_classes]. Must match the shape of logits.
let loss = g.cross_entropy_loss(logits, labels);
g.set_outputs(vec![loss]);

bce_loss

Binary cross-entropy: -mean(t * log(p) + (1 - t) * log(1 - p)).
pred
NodeId
required
Predictions in (0, 1), typically after sigmoid.
labels
NodeId
required
Binary targets. Must match the shape of pred.
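A binary classifier typically pairs sigmoid with bce_loss so predictions land in (0, 1). Sketch; x, w, and labels are assumed to be defined:

```rust
let logits = g.matmul(x, w);         // [batch, 1]
let pred = g.sigmoid(logits);        // squashed into (0, 1)
let loss = g.bce_loss(pred, labels); // scalar [1]
g.set_outputs(vec![loss]);
```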

mse_loss

Mean squared error: mean((pred - target)²).
pred
NodeId
required
Predictions.
target
NodeId
required
Targets. Must match the shape of pred.

l1_loss

Mean absolute error: mean(|pred - target|).
pred
NodeId
required
Predictions.
target
NodeId
required
Targets. Must match the shape of pred.
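For regression, either loss attaches directly to a linear model's output. Sketch; x, w, b, and target are assumed to be defined:

```rust
let mm = g.matmul(x, w);
let pred = g.bias_add(mm, b);         // [batch, out]
let loss = g.mse_loss(pred, target);  // mean((pred - target)²), [1]
```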

Transformer ops

silu

SiLU activation: x * sigmoid(x). Preserves shape.

swiglu

Fused SwiGLU: silu(gate) * up. Both inputs must have the same shape.
gate
NodeId
required
Gate tensor.
up
NodeId
required
Up-projection tensor. Same shape as gate.
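A SwiGLU feed-forward block in the style of modern transformer MLPs. A sketch; the parameter names and the hidden/ffn dimensions are illustrative assumptions:

```rust
// x: [seq, hidden]. Gate and up projections share shape [hidden, ffn].
let w_gate = g.parameter("ffn.gate", &[hidden, ffn]);
let w_up = g.parameter("ffn.up", &[hidden, ffn]);
let w_down = g.parameter("ffn.down", &[ffn, hidden]);
let gate = g.matmul(x, w_gate);
let up = g.matmul(x, w_up);
let act = g.swiglu(gate, up);    // silu(gate) * up → [seq, ffn]
let out = g.matmul(act, w_down); // back to [seq, hidden]
```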

rms_norm

RMS normalization: x / sqrt(mean(x²) + eps) * weight.
x
NodeId
required
2D input [M, N].
weight
NodeId
required
1D scale parameter [N].
eps
f32
required
Small constant for numerical stability (e.g. 1e-5).
let w_norm = g.parameter("norm.weight", &[hidden_dim]);
let normed = g.rms_norm(x, w_norm, 1e-5);

rope

Rotary position embeddings applied to a 2D tensor. Position index for row i is i.
x
NodeId
required
2D input [seq, dim]. dim must be even and divisible by head_dim.
theta
f32
required
Rotary base frequency (e.g. 10000.0).
head_dim
u32
required
Size of each attention head. RoPE rotations are applied independently per head.
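RoPE is typically applied to the query and key projections just before attention. Sketch; q and k are assumed [seq, dim] nodes, and theta = 10000.0 with head_dim = 64 are conventional values, not requirements:

```rust
let q_rot = g.rope(q, 10000.0, 64);
let k_rot = g.rope(k, 10000.0, 64);
```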

rope_with_offset

RoPE with a static position offset. Position for row i is i + pos_offset.
x
NodeId
required
2D input [seq, dim].
theta
f32
required
Rotary base frequency.
pos_offset
u32
required
Static offset added to each row’s position.
head_dim
u32
required
Attention head dimension.

rope_dynamic_offset

RoPE with a dynamic position offset read from a U32 input buffer at runtime. Position for row i is i + offset_buf[0].
x
NodeId
required
2D input [seq, dim].
theta
f32
required
Rotary base frequency.
offset_input
NodeId
required
A U32 input node whose first element is added to each row’s position index.
head_dim
u32
required
Attention head dimension.

causal_attention

Fused causal multi-head attention with optional GQA. Applies a causal mask so each token attends only to itself and earlier tokens.
q
NodeId
required
[seq, num_heads * head_dim]
k
NodeId
required
[seq, num_kv_heads * head_dim]
v
NodeId
required
[seq, num_kv_heads * head_dim]
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads. Set equal to num_heads for standard MHA.
head_dim
u32
required
Dimension per head.
Returns [seq, num_heads * head_dim].
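Putting projections, RoPE, and the fused kernel together into one decoder-style self-attention sketch. The sizes are illustrative: 8 query heads and 4 KV heads (GQA) with head_dim = 64:

```rust
// x: [32, 512]; query dim 8*64 = 512, kv dim 4*64 = 256.
let x = g.input("x", &[32, 512]);
let wq = g.parameter("wq", &[512, 512]);
let wk = g.parameter("wk", &[512, 256]);
let wv = g.parameter("wv", &[512, 256]);
let q = g.matmul(x, wq);
let k = g.matmul(x, wk);
let v = g.matmul(x, wv);
let q = g.rope(q, 10000.0, 64);
let k = g.rope(k, 10000.0, 64);
let attn = g.causal_attention(q, k, v, 8, 4, 64); // [32, 512]
```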

full_attention

Non-causal (bidirectional) multi-head attention. Same signature as causal_attention but without the causal mask.

cross_attention

Cross-attention where query attends to key/value from a different sequence.
q
NodeId
required
[q_seq, num_heads * head_dim]
k
NodeId
required
[kv_seq, num_kv_heads * head_dim]
v
NodeId
required
[kv_seq, num_kv_heads * head_dim]
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
Returns [q_seq, num_heads * head_dim].

multi_head_attn

Differentiable multi-head attention that saves the log-sum-exp (LSE) tensor needed for the backward pass. Supports both self-attention and cross-attention via is_cross.
q
NodeId
required
[q_seq, num_heads * head_dim]
k
NodeId
required
[kv_seq, num_kv_heads * head_dim]
v
NodeId
required
[kv_seq, num_kv_heads * head_dim]
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
is_cross
bool
required
true for cross-attention (q_seq ≠ kv_seq), false for self-attention.
Returns [q_seq, num_heads * head_dim].

embedding

Looks up rows in an embedding table by index.
indices
NodeId
required
1D U32 tensor [seq_len] containing row indices.
table
NodeId
required
2D F32 parameter [vocab_size, embed_dim].
Returns [seq_len, embed_dim].
let tokens = g.input_u32("tokens", &[seq_len]);
let embed_table = g.parameter("embed.weight", &[vocab_size, hidden]);
let embeds = g.embedding(tokens, embed_table);

scatter_add

Accumulates source rows into an output tensor indexed by indices: output[indices[i]] += src[i].
indices
NodeId
required
1D U32 index tensor.
src
NodeId
required
2D source tensor [seq_len, embed_dim].
vocab_size
usize
required
Number of rows in the output tensor.
Returns [vocab_size, embed_dim].

Normalization

layer_norm

Standard layer normalization: (x - mean) / sqrt(var + eps) * weight + bias.
x
NodeId
required
2D input [M, N].
weight
NodeId
required
1D scale parameter [N].
bias
NodeId
required
1D bias parameter [N].
eps
f32
required
Stability constant (e.g. 1e-5).
Returns [M, N].
let w = g.parameter("ln.weight", &[hidden]);
let b = g.parameter("ln.bias", &[hidden]);
let out = g.layer_norm(x, w, b, 1e-5);

Convolution

conv2d

2D convolution. Input and output tensors are flat 1D arrays in NCHW order.
input
NodeId
required
Flat input [N * C_in * H * W].
kernel
NodeId
required
Flat kernel [C_out * C_in * kH * kW].
batch
u32
required
Batch size N.
in_channels
u32
required
Input channel count.
in_h
u32
required
Input height.
in_w
u32
required
Input width.
out_channels
u32
required
Output channel count.
kernel_h
u32
required
Kernel height.
kernel_w
u32
required
Kernel width.
stride
u32
required
Convolution stride.
padding
u32
required
Zero-padding on each spatial edge.
Returns a flat [N * C_out * out_H * out_W] tensor.

group_norm

Group normalization over a flat NCHW tensor.
x
NodeId
required
Flat input [N * C * H * W].
weight
NodeId
required
Scale [C].
bias
NodeId
required
Bias [C].
_batch
u32
required
Batch size (passed for dispatch sizing).
channels
u32
required
Channel count C.
spatial
u32
required
H × W.
num_groups
u32
required
Number of groups to normalize over.
eps
f32
required
Stability constant.

concat

Concatenates two tensors along the channel dimension (NCHW layout).
a
NodeId
required
First flat tensor [N * Ca * H * W].
b
NodeId
required
Second flat tensor [N * Cb * H * W].
batch
u32
required
Batch size N.
channels_a
u32
required
Channel count of a.
channels_b
u32
required
Channel count of b.
spatial
u32
required
H × W.
Returns [N * (Ca + Cb) * H * W].
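In a U-Net-style decoder, concat joins upsampled decoder features with the matching encoder skip connection before the next convolution. Sketch; dec, enc, batch, and spatial are assumed to be defined, with 64 channels each:

```rust
// dec, enc: [N * 64 * H * W] each → cat: [N * 128 * H * W].
let cat = g.concat(dec, enc, batch, 64, 64, spatial);
```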

split_a

Extracts the first channels_a channels from a concatenated NCHW tensor.
x
NodeId
required
Flat tensor [N * (Ca + Cb) * H * W].
batch
u32
required
Batch size N.
channels_a
u32
required
Channels to extract.
channels_b
u32
required
Remaining channels.
spatial
u32
required
H × W.
Returns [N * Ca * H * W].

split_b

Extracts the last channels_b channels from a concatenated NCHW tensor. Same parameters as split_a.

upsample_2x

Nearest-neighbor 2× upsampling: [N, C, H, W] → [N, C, 2H, 2W].
x
NodeId
required
Flat input tensor.
batch
u32
required
Batch size N.
channels
u32
required
Channel count C.
in_h
u32
required
Input height H.
in_w
u32
required
Input width W.

KV cache

cache_write

Writes a single-row tensor new_kv [1, dim] into a cache buffer cache [max_seq, dim], at the row given by the kv_pos input node.
new_kv
NodeId
required
New key or value [1, dim].
cache
NodeId
required
Cache buffer [max_seq, dim].
kv_pos
NodeId
required
U32 scalar input node indicating the position to write.
Returns a node referencing the updated cache buffer.

cached_attention

Attention where Q comes from the current token and K/V come from a pre-filled cache.
q
NodeId
required
[1, num_heads * head_dim] — current query.
k_cache
NodeId
required
[max_seq, kv_dim] — key cache.
v_cache
NodeId
required
[max_seq, kv_dim] — value cache.
kv_pos
NodeId
required
U32 scalar node containing the number of valid positions in the cache.
num_heads
u32
required
Number of query heads.
num_kv_heads
u32
required
Number of key/value heads.
head_dim
u32
required
Dimension per head.
Returns [1, num_heads * head_dim].
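A single decode step wires these together: write the current token's K/V row at the current position, then attend over the valid prefix. Sketch; q, k_new, v_new, and the pre-allocated k_cache/v_cache nodes are assumed (cache allocation is not shown), with 8 query heads, 4 KV heads, and head_dim 64:

```rust
// kv_pos: U32 input holding the current sequence position.
let kv_pos = g.input_u32("kv_pos", &[1]);
let k_cache = g.cache_write(k_new, k_cache, kv_pos);
let v_cache = g.cache_write(v_new, v_cache, kv_pos);
let out = g.cached_attention(q, k_cache, v_cache, kv_pos, 8, 4, 64);
// out: [1, 512]
```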

Utility methods

set_outputs

outputs
Vec<NodeId>
required
Node IDs to mark as graph outputs. The session will make these buffers readable.
g.set_outputs(vec![loss]);

outputs

Returns &[NodeId] — the list of output node IDs previously set with set_outputs.

nodes

Returns &[Node] — all nodes in the graph, indexed by their NodeId.

node

id
NodeId
required
Node to retrieve.
Returns &Node.

toposort

Rebuilds the graph in topological order, removing Nop (dead) nodes and compacting IDs. Returns a new Graph where every node’s input IDs are strictly less than its own ID.
let sorted = graph.toposort();
This is called automatically by build_session before running autodiff.
