Graph is the central data structure in Meganeura. Every call to a builder method appends a node and returns its NodeId. Nodes are identified by a u32 handle and stored in definition order. The graph can be topologically sorted and compiled to a GPU program.
pub type NodeId = u32;
let mut g = Graph::new();
let x = g.input("x", &[4, 784]);
let w = g.parameter("w", &[784, 128]);
let y = g.matmul(x, w);
let h = g.relu(y);
g.set_outputs(vec![h]);

Tensor creation

g.input

Declares an f32 runtime input (e.g. data or activations passed at inference time).
name (&str, required): Unique name used to bind values at runtime.
shape (&[usize], required): Tensor shape.
Returns NodeId: Node of type f32 with the given shape.
let x = g.input("x", &[4, 784]);

g.input_u32

Declares a U32 runtime input. Required for token indices, position counters, and other integer data.
name (&str, required): Unique name used to bind values at runtime.
shape (&[usize], required): Tensor shape.
Returns NodeId: Node of type u32 with the given shape.
let tokens = g.input_u32("tokens", &[seq_len]);
let kv_pos = g.input_u32("kv_pos", &[1]);

g.parameter

Declares a learnable f32 parameter (weight or bias). Parameters are loaded from a checkpoint and updated by the optimizer.
name (&str, required): Unique name used to look up the parameter in the weight file.
shape (&[usize], required): Parameter shape.
Returns NodeId: Node of type f32 with the given shape.
let w = g.parameter("fc.weight", &[784, 128]);
let b = g.parameter("fc.bias", &[128]);

g.constant

Embeds a fixed f32 tensor whose values are known at graph construction time.
data (Vec<f32>, required): Flat data buffer. Length must equal the product of all shape dimensions.
shape (&[usize], required): Tensor shape.
Returns NodeId: Node holding the constant values.
let c = g.constant(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);

g.scalar

Convenience wrapper that creates a constant with shape [1].
value (f32, required): The scalar value.
Returns NodeId: Shape [1] constant node.
let one = g.scalar(1.0);

Matrix operations

All matrix ops require 2D tensors.

g.matmul

Standard matrix multiply: C = A @ B.
a (NodeId, required): Shape [M, K].
b (NodeId, required): Shape [K, N].
Returns NodeId: Shape [M, N].
let y = g.matmul(x, w);  // [M,K] @ [K,N] → [M,N]

g.matmul_at

Transposed-A matrix multiply: C = A^T @ B. A is stored as [K, M] (i.e. the transpose is implicit — no actual transpose is performed).
a (NodeId, required): Shape [K, M] (stored transposed).
b (NodeId, required): Shape [K, N].
Returns NodeId: Shape [M, N].
// Gradient accumulation: grad_w = x^T @ grad_y
let grad_w = g.matmul_at(x, grad_y);  // [K,M]^T @ [K,N] → [M,N]
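The implicit-transpose semantics can be sketched on the CPU (illustrative only; the actual op runs as a GPU kernel). With `a` stored row-major as [K, M], reading `a[k][m]` in the inner loop yields A^T @ B without ever materializing the transpose:

```rust
// CPU sketch of matmul_at semantics: C[m][n] = sum_k a[k][m] * b[k][n].
fn matmul_at(a: &[f32], b: &[f32], k: usize, m: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for mi in 0..m {
        for ni in 0..n {
            for ki in 0..k {
                // a is row-major [K, M], b is row-major [K, N]
                c[mi * n + ni] += a[ki * m + mi] * b[ki * n + ni];
            }
        }
    }
    c
}
```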

g.matmul_bt

Transposed-B matrix multiply: C = A @ B^T. B is stored as [N, K] (transposed layout).
a (NodeId, required): Shape [M, K].
b (NodeId, required): Shape [N, K] (stored transposed).
Returns NodeId: Shape [M, N].
// Attention scores: Q @ K^T
let scores = g.matmul_bt(q, k);  // [M,K] @ [N,K]^T → [M,N]

Elementwise ops

g.add

Element-wise addition. Both inputs must have the same shape.
a (NodeId, required): Shape [...].
b (NodeId, required): Same shape as a.
Returns NodeId: Same shape as inputs.
let out = g.add(x, residual);

g.mul

Element-wise multiplication. Both inputs must have the same shape.
a (NodeId, required): Shape [...].
b (NodeId, required): Same shape as a.
Returns NodeId: Same shape as inputs.
let gated = g.mul(gate, up);

g.bias_add

Adds a 1D bias to each row of a 2D tensor: out[i, j] = a[i, j] + bias[j].
a (NodeId, required): 2D tensor of shape [M, N].
bias (NodeId, required): 1D bias of shape [N].
Returns NodeId: Shape [M, N].
let out = g.bias_add(mm, b);

g.broadcast_add

Adds a [1, N] tensor to every row of a [M, N] tensor. Uses the same BiasAdd shader as bias_add.
a (NodeId, required): 2D tensor of shape [M, N].
b (NodeId, required): 2D tensor of shape [1, N].
Returns NodeId: Shape [M, N].
let out = g.broadcast_add(x, pos_bias);  // pos_bias shape: [1, N]
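The shared row-broadcast rule behind bias_add and broadcast_add can be sketched on the CPU (illustrative only, not the BiasAdd shader): the length-N vector is added to each of the M rows in turn.

```rust
// CPU sketch of bias_add / broadcast_add semantics: `b` (length N) is
// added to every row of the row-major [M, N] tensor `a`.
fn broadcast_add(a: &[f32], b: &[f32], m: usize, n: usize) -> Vec<f32> {
    (0..m * n).map(|i| a[i] + b[i % n]).collect()
}
```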

g.greater

Element-wise greater-than comparison for use in autodiff (e.g. ReLU gradient). Both inputs must have the same shape.
a (NodeId, required): Shape [...].
b (NodeId, required): Same shape as a.
Returns NodeId: Same shape. Values are 1.0 where a > b, 0.0 otherwise.
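A CPU sketch of the comparison (illustrative; the op itself is a GPU kernel). Composed with an element-wise multiply, this mask is exactly the ReLU gradient pattern mentioned above:

```rust
// CPU sketch of greater: 1.0 where a > b, else 0.0.
// With mul, this recovers the ReLU backward: dx = greater(x, 0) * dy.
fn greater(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| if x > y { 1.0 } else { 0.0 }).collect()
}
```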

g.neg

Element-wise negation: out = -x.
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.abs

Element-wise absolute value: out = |x|.
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.log

Element-wise natural logarithm: out = ln(x).
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.recip

Element-wise reciprocal: out = 1 / x.
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.div

Element-wise division: out = a / b. Implemented as a * recip(b).
a (NodeId, required): Any shape.
b (NodeId, required): Same shape as a.
Returns NodeId: Same shape.

Activations

g.relu

Rectified linear unit: out = max(0, x).
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.sigmoid

Logistic sigmoid: out = 1 / (1 + exp(-x)).
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.silu

Sigmoid linear unit: out = x * sigmoid(x).
x (NodeId, required): Any shape.
Returns NodeId: Same shape.
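The two formulas compose directly; a scalar CPU sketch (illustrative, not the GPU shader):

```rust
// Scalar reference for the sigmoid and silu formulas above.
fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

// silu(x) = x * sigmoid(x)
fn silu(x: f32) -> f32 {
    x * sigmoid(x)
}
```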

g.gelu

Gaussian error linear unit: out = x * 0.5 * (1 + erf(x / sqrt(2))).
x (NodeId, required): Any shape.
Returns NodeId: Same shape.

g.swiglu

Fused SwiGLU: out = silu(gate) * up. Both inputs must have the same shape.
gate (NodeId, required): Gate tensor of shape [M, N].
up (NodeId, required): Up tensor of shape [M, N].
Returns NodeId: Shape [M, N].
let gate = gate_proj.forward(&mut g, x);
let up = up_proj.forward(&mut g, x);
let h = g.swiglu(gate, up);

g.swiglu_concat

SwiGLU on a concatenated input of shape [M, 2*N]. Reads gate from the first half and up from the second half.
input (NodeId, required): 2D tensor of shape [M, 2*N]. The last dimension must be even.
Returns NodeId: Shape [M, N].
let h = g.swiglu_concat(fused_proj_output);  // [M, 2*N] → [M, N]
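The halving rule can be sketched on the CPU for a single row (illustrative only): gate comes from the first N columns, up from the last N, and the output is silu(gate) * up.

```rust
// CPU sketch of swiglu_concat on one row of a [M, 2*N] tensor.
fn swiglu_concat_row(row: &[f32]) -> Vec<f32> {
    assert!(row.len() % 2 == 0, "last dimension must be even");
    let n = row.len() / 2;
    let (gate, up) = row.split_at(n);
    gate.iter()
        .zip(up)
        .map(|(g, u)| g * (1.0 / (1.0 + (-g).exp())) * u) // silu(gate) * up
        .collect()
}
```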

Reductions

g.sum_all

Sums all elements to a scalar.
x (NodeId, required): Any shape.
Returns NodeId: Shape [1].

g.mean_all

Averages all elements to a scalar.
x (NodeId, required): Any shape.
Returns NodeId: Shape [1].

g.softmax

Row-wise softmax for 2D inputs.
x (NodeId, required): Any shape (applied row-wise for 2D inputs).
Returns NodeId: Same shape.

g.log_softmax

Numerically stable log-softmax (row-wise for 2D inputs).
x (NodeId, required): Any shape.
Returns NodeId: Same shape.
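The standard stability trick behind a numerically stable log-softmax can be sketched on the CPU (illustrative; the actual op is a GPU kernel): subtracting the row max before exponentiating keeps exp() from overflowing without changing the result.

```rust
// CPU sketch of a numerically stable row-wise log-softmax:
// log_softmax(x)_i = x_i - max(x) - ln(sum_j exp(x_j - max(x)))
fn log_softmax_row(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum: f32 = row.iter().map(|x| (x - max).exp()).sum::<f32>().ln();
    row.iter().map(|x| x - max - log_sum).collect()
}
```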

g.transpose

Swaps the two dimensions of a 2D tensor: [M, N] → [N, M].
x (NodeId, required): 2D tensor of shape [M, N].
Returns NodeId: Shape [N, M].
let xt = g.transpose(x);  // [4, 8] → [8, 4]

Embedding ops

g.embedding

Looks up rows from an embedding table.
indices (NodeId, required): 1D U32 tensor of shape [seq_len].
table (NodeId, required): 2D f32 parameter of shape [vocab_size, embed_dim].
Returns NodeId: Shape [seq_len, embed_dim].
let indices = g.input_u32("tokens", &[seq_len]);
let table = g.parameter("embed", &[32000, 512]);
let embeds = g.embedding(indices, table);
// embeds shape: [seq_len, 512]
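The lookup semantics can be sketched on the CPU (illustrative only; the real op is a GPU gather): row indices[i] of the table becomes row i of the output.

```rust
// CPU sketch of embedding: gather rows from a row-major
// [vocab_size, embed_dim] table.
fn embedding(indices: &[u32], table: &[f32], embed_dim: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(indices.len() * embed_dim);
    for &idx in indices {
        let start = idx as usize * embed_dim;
        out.extend_from_slice(&table[start..start + embed_dim]);
    }
    out
}
```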

g.scatter_add

Accumulates source rows into an output tensor indexed by indices. The backward of embedding.
indices (NodeId, required): 1D U32 tensor of shape [seq_len].
src (NodeId, required): 2D f32 tensor of shape [seq_len, embed_dim].
vocab_size (usize, required): Number of rows in the output accumulator.
Returns NodeId: Shape [vocab_size, embed_dim], where output[indices[i]] += src[i].
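A CPU sketch of the accumulation rule (illustrative only): each source row is added into the output row selected by its index, so repeated indices accumulate, which is exactly what the embedding backward needs.

```rust
// CPU sketch of scatter_add: row i of `src` is accumulated into row
// indices[i] of a zero-initialized [vocab_size, embed_dim] output.
fn scatter_add(indices: &[u32], src: &[f32], vocab_size: usize, embed_dim: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; vocab_size * embed_dim];
    for (i, &idx) in indices.iter().enumerate() {
        for j in 0..embed_dim {
            out[idx as usize * embed_dim + j] += src[i * embed_dim + j];
        }
    }
    out
}
```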

Spatial ops

All spatial ops work on tensors stored as flat 1D arrays in NCHW order.

g.conv2d

2D convolution: input[N, C_in, H, W] * kernel[C_out, C_in, kH, kW] → output[N, C_out, oH, oW].
input (NodeId, required): Flat tensor of size N * C_in * H * W.
kernel (NodeId, required): Flat kernel of size C_out * C_in * kH * kW.
batch (u32, required): Batch size N.
in_channels (u32, required): Input channels C_in.
in_h (u32, required): Input height H.
in_w (u32, required): Input width W.
out_channels (u32, required): Output channels C_out.
kernel_h (u32, required): Kernel height kH.
kernel_w (u32, required): Kernel width kW.
stride (u32, required): Convolution stride.
padding (u32, required): Zero-padding on each edge.
Returns NodeId: Flat tensor of size N * C_out * oH * oW, where oH = (H + 2*padding - kH) / stride + 1 and oW is computed the same way from W and kW.
let y = g.conv2d(x, kernel, 4, 3, 224, 224, 64, 3, 3, 1, 1);
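The output-size formula is worth checking directly. With the arguments in the example above (H = W = 224, 3x3 kernel, stride 1, padding 1), the spatial size is preserved:

```rust
// Output-size arithmetic for conv2d: (input + 2*padding - kernel) / stride + 1.
fn conv_out(input: u32, kernel: u32, stride: u32, padding: u32) -> u32 {
    (input + 2 * padding - kernel) / stride + 1
}
```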

g.concat

Concatenates two flat NCHW tensors along the channel dimension: [N, Ca, H, W] ++ [N, Cb, H, W] → [N, Ca+Cb, H, W].
a (NodeId, required): Flat tensor of size N * Ca * H * W.
b (NodeId, required): Flat tensor of size N * Cb * H * W.
batch (u32, required): Batch size N.
channels_a (u32, required): Number of channels in a.
channels_b (u32, required): Number of channels in b.
spatial (u32, required): Spatial size H * W.
Returns NodeId: Flat tensor of size N * (Ca + Cb) * H * W.
let merged = g.concat(skip, x, batch, 32, 64, h * w);

g.split_a

Extracts the first channels_a channels from a concatenated [N, Ca+Cb, H, W] tensor.
x (NodeId, required): Flat tensor of size N * (Ca + Cb) * H * W.
batch (u32, required): Batch size N.
channels_a (u32, required): Channels to extract (first Ca).
channels_b (u32, required): Remaining channels Cb.
spatial (u32, required): Spatial size H * W.
Returns NodeId: Flat tensor of size N * Ca * H * W.

g.split_b

Extracts the last channels_b channels from a concatenated [N, Ca+Cb, H, W] tensor.
x (NodeId, required): Flat tensor of size N * (Ca + Cb) * H * W.
batch (u32, required): Batch size N.
channels_a (u32, required): Leading channels Ca.
channels_b (u32, required): Channels to extract (last Cb).
spatial (u32, required): Spatial size H * W.
Returns NodeId: Flat tensor of size N * Cb * H * W.
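Both splits follow the same flat-NCHW indexing rule, sketched here on the CPU (illustrative only): within each batch element, split_a copies channels [0, Ca) and split_b copies channels [Ca, Ca+Cb).

```rust
// CPU sketch of split_a / split_b: slice a channel range out of a flat
// NCHW tensor. `first = true` models split_a, `first = false` split_b.
fn split_channels(
    x: &[f32], batch: usize, ca: usize, cb: usize, spatial: usize,
    first: bool,
) -> Vec<f32> {
    let (offset, ch) = if first { (0, ca) } else { (ca, cb) };
    let mut out = Vec::with_capacity(batch * ch * spatial);
    for n in 0..batch {
        // Channels are contiguous per batch element, so one slice suffices.
        let base = (n * (ca + cb) + offset) * spatial;
        out.extend_from_slice(&x[base..base + ch * spatial]);
    }
    out
}
```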

g.upsample_2x

Nearest-neighbor 2× upsampling: [N, C, H, W] → [N, C, 2H, 2W].
x (NodeId, required): Flat input tensor of size N * C * H * W.
batch (u32, required): Batch size N.
channels (u32, required): Number of channels C.
in_h (u32, required): Input height H.
in_w (u32, required): Input width W.
Returns NodeId: Flat tensor of size N * C * (2H) * (2W).
let up = g.upsample_2x(x, batch, 128, 16, 16);
// up size: N * 128 * 32 * 32
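The nearest-neighbor rule can be sketched on the CPU for a single [H, W] plane (illustrative only): each input pixel is replicated into a 2x2 block, so output pixel (oy, ox) reads input pixel (oy/2, ox/2).

```rust
// CPU sketch of nearest-neighbor 2x upsampling on one [H, W] plane.
fn upsample_2x_plane(x: &[f32], h: usize, w: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; 4 * h * w];
    for oy in 0..2 * h {
        for ox in 0..2 * w {
            out[oy * 2 * w + ox] = x[(oy / 2) * w + ox / 2];
        }
    }
    out
}
```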
