The Graph struct is the central abstraction in Meganeura. You define a model by adding nodes to a graph, then compile the graph into a GPU session for training or inference.
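A hypothetical end-to-end sketch of that flow (method names are taken from this reference; the constructor name, `build_session` usage, and exact signatures are assumptions):

```rust
let mut g = Graph::new();                   // constructor name assumed
let x = g.input("x", vec![32, 784]);        // runtime input, F32
let w = g.parameter("w", vec![784, 10]);    // trainable weight
let logits = g.matmul(x, w);
g.set_outputs(&[logits]);
let g = g.toposort();                       // returns a compacted Graph
// let mut session = build_session(&g)?;    // compile for the GPU
// session.set_input("x", &batch);
```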
Types
NodeId
An opaque index identifying a node within its graph. Graph-construction methods return NodeId and accept NodeId values as inputs.
DType
- F32: 32-bit floating point. Used for activations, weights, and gradients.
- U32: 32-bit unsigned integer. Used for token indices, position offsets, and KV cache positions.
Methods
size_bytes
Returns the byte width of one element. Both F32 and U32 return 4.
TensorType
- shape: Dimension sizes, from outermost to innermost (e.g. [batch, features]).
- dtype: Element type of the tensor.
Methods
new
Constructs a TensorType with the given shape and dtype.
f32
Shorthand for TensorType::new(shape, DType::F32).
num_elements
Product of all shape dimensions. A [32, 784] tensor has 25088 elements.
size_bytes
Total byte size: num_elements() * dtype.size_bytes().
rank
Number of dimensions. TensorType::f32(vec![4, 3]).rank() returns 2.
Node
Fields:
- The node's unique index within its graph.
- The operation this node computes.
- Ordered list of input node IDs.
- Shape and dtype of the output tensor produced by this node.
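The size arithmetic of TensorType can be sketched on the CPU. This is a minimal mimic for illustration, not the real Meganeura struct:

```rust
// Minimal mimic of TensorType's size arithmetic (illustrative only).
#[derive(Debug)]
struct TensorType {
    shape: Vec<usize>,
    elem_bytes: usize,
}

impl TensorType {
    // F32 elements are 4 bytes wide, per DType::size_bytes above.
    fn f32(shape: Vec<usize>) -> Self {
        TensorType { shape, elem_bytes: 4 }
    }
    // Product of all shape dimensions.
    fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }
    // num_elements() * dtype.size_bytes().
    fn size_bytes(&self) -> usize {
        self.num_elements() * self.elem_bytes
    }
    // Number of dimensions.
    fn rank(&self) -> usize {
        self.shape.len()
    }
}

fn main() {
    let t = TensorType::f32(vec![32, 784]);
    assert_eq!(t.num_elements(), 25088);
    assert_eq!(t.size_bytes(), 100_352); // 25088 * 4
    assert_eq!(t.rank(), 2);
}
```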
Graph
Constructor
Creates an empty graph with no nodes and no outputs.
Leaf nodes
Leaf nodes have no inputs. They represent data entering the graph from outside: runtime inputs, trainable weights, or fixed constants.
input
- Name used to identify this input when calling session.set_input().
- Shape of the input tensor. Creates a DType::F32 tensor.
Returns: NodeId. The session must supply data for this name before each forward pass.
input_u32
Like input, but creates a DType::U32 tensor. Used for token IDs and position indices.
- Name used with session.set_input_u32().
- Shape of the input tensor.
parameter
- Name used with session.set_parameter() and session.get_parameter().
- Shape of the parameter tensor. Always DType::F32.
constant
- Initial values. Length must equal the product of all shape dimensions.
- Shape of the constant tensor.
scalar
- The scalar value.
Equivalent to constant(vec![value], &[1]).
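The four leaf constructors might be used like this (a sketch; exact signatures are assumptions based on the descriptions above):

```rust
let mut g = Graph::new();
let tokens = g.input_u32("tokens", vec![128]);      // U32 indices, set via set_input_u32
let embed = g.parameter("embed", vec![50000, 512]); // trainable, always F32
let mask = g.constant(vec![1.0; 6], &[2, 3]);       // data length 6 == 2 * 3
let half = g.scalar(0.5);                           // constant(vec![0.5], &[1])
```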
Binary ops
matmul
Matrix multiplication: C = A × B.
- Left matrix [M, K].
- Right matrix [K, N].
Returns: [M, N]. Both inputs must be 2D and their inner dimensions must match.
matmul_at
Transposed-A matrix multiplication: C = A^T × B.
- Matrix A stored as [K, M].
- Matrix B stored as [K, N].
Returns: [M, N]. The K dimensions of A and B must match.
matmul_bt
Transposed-B matrix multiplication: C = A × B^T.
- Matrix A [M, K].
- Matrix B stored as [N, K].
Returns: [M, N]. The K dimensions of A and B must match.
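The transposed layouts can be checked against the plain layout on the CPU. A reference sketch (illustrative, not the GPU shader):

```rust
// Plain layout: C = A x B with A [M, K], B [K, N].
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    c
}

// matmul_at layout: A is stored transposed as [K, M],
// so element (i, p) of A lives at at[p * m + i].
fn matmul_at(at: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += at[p * m + i] * b[p * n + j];
            }
        }
    }
    c
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];  // [M=2, K=3]
    let at = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; // same matrix stored [K=3, M=2]
    let b = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];  // [K=3, N=2]
    // Both layouts produce the same [M, N] = [2, 2] result.
    assert_eq!(matmul(&a, &b, 2, 3, 2), matmul_at(&at, &b, 2, 3, 2));
}
```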
add
Element-wise addition. Both inputs must have identical shapes.
- a: Tensor.
- b: Tensor with the same shape as a.
mul
Element-wise multiplication. Both inputs must have identical shapes.
- a: Tensor.
- b: Tensor with the same shape as a.
bias_add
Adds a 1D bias to a 2D input: [M, N] + [N] → [M, N].
- 2D input tensor [M, N].
- 1D bias tensor [N].
broadcast_add
Adds a [1, N] tensor across a [M, N] tensor. Uses the same underlying shader as bias_add.
- 2D input [M, N].
- 2D addend [1, N].
greater
Element-wise greater-than comparison. Returns 1.0 where a > b, 0.0 otherwise. Both inputs must have identical shapes.
- a: Tensor.
- b: Tensor with the same shape as a.
Unary ops
All unary ops preserve the shape and dtype of their input.
relu
Rectified linear unit: max(0, x).
sigmoid
Logistic sigmoid: 1 / (1 + exp(-x)).
neg
Element-wise negation: -x.
abs
Element-wise absolute value: |x|.
log
Element-wise natural logarithm: ln(x).
recip
Element-wise reciprocal: 1 / x.
div
Element-wise division: a / b. Implemented as mul(a, recip(b)).
- a: Numerator tensor.
- b: Denominator tensor. Must have the same shape as a.
Reductions
sum_all
Reduces all elements to a scalar: output shape is [1].
- Input tensor of any shape.
mean_all
Reduces all elements to their mean: output shape is [1].
- Input tensor of any shape.
softmax
Row-wise softmax. Output has the same shape as the input.
- 2D input tensor [batch, features].
log_softmax
Numerically stable log-softmax. Output has the same shape as the input.
- 2D input tensor [batch, features].
Loss functions
All loss functions return a scalar [1] node.
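As a CPU reference for the reductions these losses compute, here are mse_loss and l1_loss (illustrative; the real ops run as GPU kernels):

```rust
// mean((pred - target)^2)
fn mse(pred: &[f32], target: &[f32]) -> f32 {
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / n
}

// mean(|pred - target|)
fn l1(pred: &[f32], target: &[f32]) -> f32 {
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f32>() / n
}

fn main() {
    let pred = [1.0, 2.0, 3.0];
    let target = [1.0, 0.0, 3.0];
    // diffs are [0, 2, 0]: mse = 4/3, l1 = 2/3
    assert!((mse(&pred, &target) - 4.0 / 3.0).abs() < 1e-6);
    assert!((l1(&pred, &target) - 2.0 / 3.0).abs() < 1e-6);
}
```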
cross_entropy_loss
- logits: Raw (pre-softmax) scores [batch, num_classes].
- Target distribution [batch, num_classes]. Must match the shape of logits.
bce_loss
Binary cross-entropy: -mean(t * log(p) + (1 - t) * log(1 - p)).
- pred: Predictions in (0, 1), typically after sigmoid.
- Binary targets. Must match the shape of pred.
mse_loss
Mean squared error: mean((pred - target)²).
- pred: Predictions.
- Targets. Must match the shape of pred.
l1_loss
Mean absolute error: mean(|pred - target|).
- pred: Predictions.
- Targets. Must match the shape of pred.
Transformer ops
silu
SiLU activation: x * sigmoid(x). Preserves shape.
swiglu
Fused SwiGLU: silu(gate) * up. Both inputs must have the same shape.
- gate: Gate tensor.
- Up-projection tensor. Same shape as gate.
rms_norm
RMS normalization: x / sqrt(mean(x²) + eps) * weight.
- 2D input [M, N].
- 1D scale parameter [N].
- eps: Small constant for numerical stability (e.g. 1e-5).
rope
Rotary position embeddings applied to a 2D tensor. Position index for row i is i.
- 2D input [seq, dim]. dim must be even and divisible by head_dim.
- Rotary base frequency (e.g. 10000.0).
- head_dim: Size of each attention head. RoPE rotations are applied independently per head.
rope_with_offset
RoPE with a static position offset. Position for row i is i + pos_offset.
- 2D input [seq, dim].
- Rotary base frequency.
- pos_offset: Static offset added to each row's position.
- head_dim: Attention head dimension.
rope_dynamic_offset
RoPE with a dynamic position offset read from a U32 input buffer at runtime. Position for row i is i + offset_buf[0].
- 2D input [seq, dim].
- Rotary base frequency.
- offset_buf: A U32 input node whose first element is added to each row's position index.
- head_dim: Attention head dimension.
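The per-row RoPE math can be sketched on the CPU: each pair within a head is rotated by the angle pos * base^(-2i/head_dim). The adjacent-pair layout below is an assumption (the shader's exact pairing may differ), but the rotation angles and norm-preservation property are standard:

```rust
// Apply rotary position embedding to one row, head by head.
fn rope_row(x: &mut [f32], pos: usize, base: f32, head_dim: usize) {
    for head in x.chunks_mut(head_dim) {
        for i in 0..head_dim / 2 {
            // frequency falls off with pair index i
            let theta = pos as f32 * base.powf(-2.0 * i as f32 / head_dim as f32);
            let (s, c) = theta.sin_cos();
            let (a, b) = (head[2 * i], head[2 * i + 1]);
            head[2 * i] = a * c - b * s;
            head[2 * i + 1] = a * s + b * c;
        }
    }
}

fn main() {
    // dim = 8, head_dim = 4 -> two heads, rotated independently.
    let mut row = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let before: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
    rope_row(&mut row, 3, 10000.0, 4);
    let after: f32 = row.iter().map(|v| v * v).sum::<f32>().sqrt();
    assert!((before - after).abs() < 1e-4); // rotations preserve the norm

    // Position 0 is the identity (theta = 0 for every pair).
    let mut same = vec![1.0f32, 2.0, 3.0, 4.0];
    rope_row(&mut same, 0, 10000.0, 4);
    assert_eq!(same, vec![1.0, 2.0, 3.0, 4.0]);
}
```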
causal_attention
Fused causal multi-head attention with optional GQA. Applies a causal mask so each token attends only to itself and earlier tokens.
- Query [seq, num_heads * head_dim].
- Key [seq, num_kv_heads * head_dim].
- Value [seq, num_kv_heads * head_dim].
- num_heads: Number of query heads.
- num_kv_heads: Number of key/value heads. Set equal to num_heads for standard MHA.
- head_dim: Dimension per head.
Returns: [seq, num_heads * head_dim].
full_attention
Non-causal (bidirectional) multi-head attention. Same signature as causal_attention but without the causal mask.
cross_attention
Cross-attention where query attends to key/value from a different sequence.
- Query [q_seq, num_heads * head_dim].
- Key [kv_seq, num_kv_heads * head_dim].
- Value [kv_seq, num_kv_heads * head_dim].
- num_heads: Number of query heads.
- num_kv_heads: Number of key/value heads.
- head_dim: Dimension per head.
Returns: [q_seq, num_heads * head_dim].
multi_head_attn
Differentiable multi-head attention that saves the log-sum-exp (LSE) tensor needed for the backward pass. Supports both self-attention and cross-attention via is_cross.
- Query [q_seq, num_heads * head_dim].
- Key [kv_seq, num_kv_heads * head_dim].
- Value [kv_seq, num_kv_heads * head_dim].
- num_heads: Number of query heads.
- num_kv_heads: Number of key/value heads.
- head_dim: Dimension per head.
- is_cross: true for cross-attention (q_seq ≠ kv_seq), false for self-attention.
Returns: [q_seq, num_heads * head_dim].
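A decoder block can be wired from the transformer ops above. This is a sketch only: method names come from this reference, but the parameter order and the weight/norm node names (wq, wk, wv, wo, w_gate, w_up, w_down, attn_norm_w, mlp_norm_w) are assumptions:

```rust
let h1 = g.rms_norm(x, attn_norm_w, 1e-5);
let q_proj = g.matmul(h1, wq);
let k_proj = g.matmul(h1, wk);
let q = g.rope(q_proj, 10000.0, head_dim);
let k = g.rope(k_proj, 10000.0, head_dim);
let v = g.matmul(h1, wv);
let attn = g.causal_attention(q, k, v, num_heads, num_kv_heads, head_dim);
let attn_out = g.matmul(attn, wo);
let x1 = g.add(x, attn_out);                 // residual connection
let h2 = g.rms_norm(x1, mlp_norm_w, 1e-5);
let gate = g.matmul(h2, w_gate);
let up = g.matmul(h2, w_up);
let mlp = g.swiglu(gate, up);
let mlp_out = g.matmul(mlp, w_down);
let x2 = g.add(x1, mlp_out);                 // block output
```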
embedding
Looks up rows in an embedding table by index.
- 1D U32 tensor [seq_len] containing row indices.
- 2D F32 parameter [vocab_size, embed_dim].
Returns: [seq_len, embed_dim].
scatter_add
Accumulates source rows into an output tensor indexed by indices: output[indices[i]] += src[i].
- indices: 1D U32 index tensor.
- src: 2D source tensor [seq_len, embed_dim].
- vocab_size: Number of rows in the output tensor.
Returns: [vocab_size, embed_dim].
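The accumulation rule output[indices[i]] += src[i] can be written out row-wise as a CPU reference (the shapes mirror embedding's table, which suits accumulating embedding gradients):

```rust
// Row-wise scatter-add: rows of src accumulate into out at the given indices.
fn scatter_add(indices: &[u32], src: &[f32], vocab_size: usize, embed_dim: usize) -> Vec<f32> {
    let mut out = vec![0.0; vocab_size * embed_dim];
    for (i, &row) in indices.iter().enumerate() {
        for j in 0..embed_dim {
            out[row as usize * embed_dim + j] += src[i * embed_dim + j];
        }
    }
    out
}

fn main() {
    // Two source rows scatter to index 1 and accumulate; index 2 stays zero.
    let out = scatter_add(&[1, 1, 0], &[1.0, 1.0, 2.0, 2.0, 5.0, 5.0], 3, 2);
    assert_eq!(out, vec![5.0, 5.0, 3.0, 3.0, 0.0, 0.0]);
}
```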
Normalization
layer_norm
Standard layer normalization: (x - mean) / sqrt(var + eps) * weight + bias.
- 2D input [M, N].
- 1D scale parameter [N].
- 1D bias parameter [N].
- eps: Stability constant (e.g. 1e-5).
Returns: [M, N].
Convolution
conv2d
2D convolution. Input and output tensors are flat 1D arrays in NCHW order.
- Flat input [N * C_in * H * W].
- Flat kernel [C_out * C_in * kH * kW].
- Batch size N.
- Input channel count.
- Input height.
- Input width.
- Output channel count.
- Kernel height.
- Kernel width.
- Convolution stride.
- Zero-padding on each spatial edge.
Returns: [N * C_out * out_H * out_W] tensor.
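The reference above does not state how out_H and out_W are derived; assuming the standard convolution arithmetic, each spatial dimension shrinks as:

```rust
// Standard conv output size: out = (in + 2*padding - kernel) / stride + 1.
// This formula is an assumption about conv2d's semantics, not taken from
// the reference itself.
fn conv_out(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

fn main() {
    assert_eq!(conv_out(28, 3, 1, 1), 28); // "same" conv: 3x3, stride 1, pad 1
    assert_eq!(conv_out(32, 4, 2, 1), 16); // downsampling conv
}
```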
group_norm
Group normalization over a flat NCHW tensor.
- Flat input [N * C * H * W].
- Scale [C].
- Bias [C].
- Batch size (passed for dispatch sizing).
- Channel count C.
- Spatial size H × W.
- Number of groups to normalize over.
- Stability constant.
concat
Concatenates two tensors along the channel dimension (NCHW layout).
- a: First flat tensor [N * Ca * H * W].
- b: Second flat tensor [N * Cb * H * W].
- Batch size N.
- Channel count of a.
- Channel count of b.
- Spatial size H × W.
Returns: [N * (Ca + Cb) * H * W].
split_a
Extracts the first channels_a channels from a concatenated NCHW tensor.
- Flat tensor [N * (Ca + Cb) * H * W].
- Batch size N.
- channels_a: Channels to extract.
- channels_b: Remaining channels.
- Spatial size H × W.
Returns: [N * Ca * H * W].
split_b
Extracts the last channels_b channels from a concatenated NCHW tensor. Same parameters as split_a.
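The NCHW flat layout makes concat/split a per-batch block copy; a CPU reference showing that split_a recovers the first input of concat (illustrative, not the GPU shaders):

```rust
// Channel concat in flat NCHW: for each batch, copy a's channels then b's.
fn concat(a: &[f32], b: &[f32], n: usize, ca: usize, cb: usize, hw: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(n * (ca + cb) * hw);
    for i in 0..n {
        out.extend_from_slice(&a[i * ca * hw..(i + 1) * ca * hw]);
        out.extend_from_slice(&b[i * cb * hw..(i + 1) * cb * hw]);
    }
    out
}

// Extract the first ca channels of each batch from the concatenated tensor.
fn split_a(x: &[f32], n: usize, ca: usize, cb: usize, hw: usize) -> Vec<f32> {
    let c = (ca + cb) * hw;
    let mut out = Vec::with_capacity(n * ca * hw);
    for i in 0..n {
        out.extend_from_slice(&x[i * c..i * c + ca * hw]);
    }
    out
}

fn main() {
    let a = vec![1.0, 2.0, 3.0, 4.0];                        // N=2, Ca=1, HW=2
    let b = vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; // N=2, Cb=2, HW=2
    let cat = concat(&a, &b, 2, 1, 2, 2);
    assert_eq!(cat.len(), 2 * 3 * 2);         // N * (Ca + Cb) * HW
    assert_eq!(split_a(&cat, 2, 1, 2, 2), a); // split recovers the input
}
```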
upsample_2x
Nearest-neighbor 2× upsampling: [N, C, H, W] → [N, C, 2H, 2W].
- Flat input tensor.
- Batch size N.
- Channel count C.
- Input height H.
- Input width W.
KV cache
cache_write
Writes a single-row tensor new_kv [1, dim] into row kv_pos of a cache buffer cache [max_seq, dim].
- new_kv: New key or value [1, dim].
- cache: Cache buffer [max_seq, dim].
- kv_pos: U32 scalar input node indicating the position to write.
cached_attention
Attention where Q comes from the current token and K/V come from a pre-filled cache.
- Current query [1, num_heads * head_dim].
- Key cache [max_seq, kv_dim].
- Value cache [max_seq, kv_dim].
- U32 scalar node containing the number of valid positions in the cache.
- num_heads: Number of query heads.
- num_kv_heads: Number of key/value heads.
- head_dim: Dimension per head.
Returns: [1, num_heads * head_dim].
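Single-token decode with the two KV-cache ops might look like the sketch below. The op names come from this reference, but how cache buffers are created, the input names "pos"/"len", and the weight nodes are assumptions:

```rust
let pos = g.input_u32("pos", vec![1]);   // write position for this step
let len = g.input_u32("len", vec![1]);   // number of valid cache rows
let k_cache = g.parameter("k_cache", vec![max_seq, kv_dim]); // cache buffers;
let v_cache = g.parameter("v_cache", vec![max_seq, kv_dim]); // creation method assumed
let k_new = g.matmul(h, wk);             // [1, kv_dim] for the current token
let v_new = g.matmul(h, wv);
g.cache_write(k_new, k_cache, pos);
g.cache_write(v_new, v_cache, pos);
let q = g.matmul(h, wq);                 // [1, num_heads * head_dim]
let out = g.cached_attention(q, k_cache, v_cache, len, num_heads, num_kv_heads, head_dim);
```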
Utility methods
set_outputs
- Node IDs to mark as graph outputs. The session will make these buffers readable.
outputs
Returns &[NodeId] — the list of output node IDs previously set with set_outputs.
nodes
Returns &[Node] — all nodes in the graph, indexed by their NodeId.
node
- Node to retrieve.
Returns: &Node.
toposort
Rebuilds the graph in topological order, removing Nop (dead) nodes and compacting IDs. Returns a new Graph where every node’s input IDs are strictly less than its own ID.
Run this before build_session and before running autodiff.