Skip to main content
The nn module provides thin, zero-overhead wrappers over the low-level Graph API. Each struct holds NodeIds for its parameters and exposes a forward() method that appends operations to a Graph. There is no trait hierarchy and no dynamic dispatch.
use meganeura::{Graph, nn};

let mut g = Graph::new();
let x = g.input("x", &[batch, 784]);
let labels = g.input("labels", &[batch, 10]);

let fc1 = nn::Linear::new(&mut g, "fc1", 784, 128);
let fc2 = nn::Linear::new(&mut g, "fc2", 128, 10);

let h = fc1.forward(&mut g, x);
let h = g.relu(h);
let logits = fc2.forward(&mut g, h);
let loss = g.cross_entropy_loss(logits, labels);

Linear layers

nn::Linear

Fully connected layer: y = x @ weight + bias.
pub struct Linear {
    pub weight: NodeId,
    pub bias: Option<NodeId>,
}
Constructors
// With bias — registers `name.weight` [in, out] and `name.bias` [out]
let fc = nn::Linear::new(&mut g, "fc1", 784, 128);

// Without bias — registers only `name.weight` [in, out]
let proj = nn::Linear::no_bias(&mut g, "q_proj", 512, 512);
Forward
let y = fc.forward(&mut g, x);  // x: [M, in] → y: [M, out]
When bias is present, forward calls g.bias_add after the matrix multiply. When absent, it returns the raw matmul result.

nn::Embedding

Token embedding lookup: maps integer indices to rows of a weight table.
pub struct Embedding {
    pub weight: NodeId,
}
// Registers `name` as a [vocab_size, embed_dim] parameter
let embed = nn::Embedding::new(&mut g, "model.embed_tokens.weight", 32000, 512);

// indices must be a U32 input: g.input_u32(...)
let tokens = g.input_u32("tokens", &[seq_len]);
let x = embed.forward(&mut g, tokens);  // → [seq_len, embed_dim]
The embedding weight name is not suffixed with .weight automatically. Pass the full parameter name (e.g. "model.embed_tokens.weight") to match HuggingFace checkpoint keys.

Normalization

nn::RmsNorm

Root-mean-square layer normalization: x / sqrt(mean(x²) + eps) * weight.
pub struct RmsNorm {
    pub weight: NodeId,
    pub eps: f32,
}
let norm = nn::RmsNorm::new(&mut g, "model.layers.0.input_layernorm.weight", 512, 1e-5);
let y = norm.forward(&mut g, x);  // x: [M, dim] → y: [M, dim]
RmsNorm::new registers a single [dim] scale parameter under the given name.

nn::LayerNorm

Standard layer normalization with weight and bias: (x - mean) / sqrt(var + eps) * weight + bias.
pub struct LayerNorm {
    pub weight: NodeId,
    pub bias: NodeId,
    pub eps: f32,
}
let ln = nn::LayerNorm::new(&mut g, "ln1", 512, 1e-5);
let y = ln.forward(&mut g, x);  // x: [M, dim] → y: [M, dim]
LayerNorm::new registers name.weight and name.bias, both [dim].

Activations and FFN

nn::Mlp

Two-layer MLP with a configurable activation function: fc2(act(fc1(x))).
pub struct Mlp {
    pub fc1: Linear,
    pub fc2: Linear,
    pub activation: Activation,
}

pub enum Activation {
    Relu,
    Gelu,
    Silu,
    Sigmoid,
}
let mlp = nn::Mlp::new(
    &mut g,
    "mlp",
    /*in_dim=*/   512,
    /*hidden_dim=*/ 2048,
    /*out_dim=*/  512,
    nn::Activation::Gelu,
);
let y = mlp.forward(&mut g, x);
Mlp::new registers name.fc1.weight, name.fc1.bias, name.fc2.weight, and name.fc2.bias.

nn::SwiGluFfn

SwiGLU feed-forward network used in LLaMA-style transformers: down(silu(gate(x)) * up(x)).
pub struct SwiGluFfn {
    pub gate: Linear,
    pub up: Linear,
    pub down: Linear,
}
let ffn = nn::SwiGluFfn::new(&mut g, "model.layers.0.mlp", 512, 1024);
let y = ffn.forward(&mut g, x);  // x: [M, hidden] → y: [M, hidden]
SwiGluFfn::new registers the following parameters:

  Parameter                Shape
  name.gate_proj.weight    [hidden, intermediate]
  name.up_proj.weight      [hidden, intermediate]
  name.down_proj.weight    [intermediate, hidden]
All three projections have no bias (Linear::no_bias). The fused SwiGLU op emitted by the optimizer combines gate and up into a single kernel.

Attention

nn::CausalSelfAttention

Grouped-query causal self-attention with RoPE positional encodings.
pub struct CausalSelfAttention {
    pub q_proj: Linear,
    pub k_proj: Linear,
    pub v_proj: Linear,
    pub o_proj: Linear,
    pub num_heads: u32,
    pub num_kv_heads: u32,
    pub head_dim: u32,
    pub rope_theta: f32,
}
Configure it with AttentionConfig:
let cfg = nn::AttentionConfig {
    hidden: 512,
    kv_dim: 256,      // num_kv_heads * head_dim
    num_heads: 8,
    num_kv_heads: 4,
    head_dim: 64,
    rope_theta: 10_000.0,
};
let attn = nn::CausalSelfAttention::new(&mut g, "model.layers.0.self_attn", &cfg);
let y = attn.forward(&mut g, x);  // x: [seq, hidden] → y: [seq, hidden]
forward projects Q, K, V, applies RoPE, runs fused causal attention, then applies the output projection.
AttentionConfig fields:

hidden (usize, required) — Total hidden dimension. Q and O projections are [hidden, hidden].

kv_dim (usize, required) — KV hidden dimension: num_kv_heads * head_dim. K and V projections are [hidden, kv_dim].

num_heads (u32, required) — Number of query heads.

num_kv_heads (u32, required) — Number of key/value heads. Set equal to num_heads for standard MHA; set lower for GQA.

head_dim (u32, required) — Dimension per head. Must satisfy num_heads * head_dim == hidden.

rope_theta (f32, required) — Base frequency for RoPE. Common values: 10_000.0 (LLaMA), 500_000.0 (LLaMA 3).

nn::TransformerBlock

A complete transformer decoder block: RmsNorm → attention → residual → RmsNorm → SwiGLU FFN → residual.
pub struct TransformerBlock {
    pub attn_norm: RmsNorm,
    pub attn: CausalSelfAttention,
    pub ffn_norm: RmsNorm,
    pub ffn: SwiGluFfn,
}
Configure it with TransformerBlockConfig:
let cfg = nn::TransformerBlockConfig {
    hidden: 512,
    intermediate: 1024,
    kv_dim: 256,
    num_heads: 8,
    num_kv_heads: 4,
    head_dim: 64,
    rms_eps: 1e-5,
    rope_theta: 10_000.0,
};
let block = nn::TransformerBlock::new(&mut g, "model.layers.0", &cfg);
let y = block.forward(&mut g, x);  // x: [seq, hidden] → y: [seq, hidden]
The parameter names emitted follow HuggingFace’s LLaMA naming, e.g.:
  • model.layers.0.input_layernorm.weight
  • model.layers.0.self_attn.q_proj.weight
  • model.layers.0.post_attention_layernorm.weight
  • model.layers.0.mlp.gate_proj.weight
Stack multiple TransformerBlocks by chaining their forward outputs:
let mut x = embed.forward(&mut g, tokens);
for block in &blocks {
    x = block.forward(&mut g, x);
}
let logits = lm_head.forward(&mut g, x);

Convolutional

nn::Conv2d

2D convolution over NCHW-layout tensors. Input and output are stored as flat 1D slices.
pub struct Conv2d {
    pub weight: NodeId,
    pub in_channels: u32,
    pub in_h: u32,
    pub in_w: u32,
    pub out_channels: u32,
    pub kernel_h: u32,
    pub kernel_w: u32,
    pub stride: u32,
    pub padding: u32,
}
let conv = nn::Conv2d::new(
    &mut g,
    "encoder.conv1",
    /*in_channels=*/  4,
    /*out_channels=*/ 32,
    /*kernel_size=*/  3,
    /*in_h=*/         32,
    /*in_w=*/         32,
    /*stride=*/       1,
    /*padding=*/      1,
);
// x: flat [N * in_channels * in_h * in_w]
let y = conv.forward(&mut g, x, batch);
The weight parameter is registered as a flat [out_channels * in_channels * kH * kW] 1D tensor named name.weight. There is no bias parameter — add one manually with g.bias_add if needed.
Conv2d tensors are stored in NCHW order as flat 1D arrays. The graph tracks shape metadata in the op, not as a multi-dimensional tensor type. Pass the spatial dimensions explicitly to forward.

Primitive operations

All nn:: structs are wrappers over these low-level Graph methods. You can call them directly when you need more control:

Arithmetic

g.matmul(a, b), g.add(a, b), g.mul(a, b), g.bias_add(x, bias), g.neg(x), g.recip(x), g.div(a, b)

Activations

g.relu(x), g.gelu(x), g.silu(x), g.sigmoid(x), g.softmax(x), g.log_softmax(x)

Normalization

g.rms_norm(x, weight, eps), g.layer_norm(x, weight, bias, eps), g.group_norm(...)

Reductions

g.sum_all(x), g.mean_all(x), g.transpose(x)

Build docs developers (and LLMs) love