The nn module provides thin, zero-overhead wrappers over the low-level Graph API. Each struct holds NodeIds for its parameters and exposes a forward() method that appends operations to a Graph. There is no trait hierarchy and no dynamic dispatch.
Linear layers
nn::Linear
Fully connected layer: y = x @ weight + bias.
When bias is present, forward calls g.bias_add after the matrix multiply. When absent, it returns the raw matmul result.
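The forward math can be checked against a plain-Rust reference. This standalone sketch (the function name and flat row-major layout are illustrative assumptions, not part of the nn API) mirrors the matmul-then-bias_add sequence:

```rust
// Reference for y = x @ weight + bias on flat row-major slices.
// Shapes: x is [m, k], weight is [k, n], bias is [n].
fn linear_forward(
    x: &[f32],
    weight: &[f32],
    bias: Option<&[f32]>,
    m: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    let mut y = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += x[i * k + p] * weight[p * n + j];
            }
            if let Some(b) = bias {
                acc += b[j]; // the bias_add step; skipped when bias is absent
            }
            y[i * n + j] = acc;
        }
    }
    y
}
```

For example, a 1x2 input against a 2x2 identity weight with bias [1, 1] yields [2, 3]; passing `None` mirrors the bias-absent case and returns the raw matmul result.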
nn::Embedding
Token embedding lookup: maps integer indices to rows of a weight table.
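The lookup itself is a row gather over the weight table. A minimal standalone sketch (function name and flat row-major [vocab, dim] layout are illustrative assumptions):

```rust
// Gather rows of a flat [vocab, dim] table by token id.
fn embedding_forward(weight: &[f32], dim: usize, ids: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(ids.len() * dim);
    for &id in ids {
        // row `id` occupies weight[id*dim .. (id+1)*dim]
        out.extend_from_slice(&weight[id * dim..(id + 1) * dim]);
    }
    out
}
```

With a vocab-3, dim-2 table `[0.0, 0.1, 1.0, 1.1, 2.0, 2.1]`, looking up ids `[2, 0]` returns rows 2 and 0 concatenated.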
The embedding weight name is not suffixed with .weight automatically. Pass the full parameter name (e.g. "model.embed_tokens.weight") to match HuggingFace checkpoint keys.
Normalization
nn::RmsNorm
Root-mean-square layer normalization: x / sqrt(mean(x²) + eps) * weight.
RmsNorm::new registers a single [dim] scale parameter under the given name.
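The formula maps to a few lines of plain Rust. A standalone reference sketch (not the Graph op; the function name is illustrative):

```rust
// x / sqrt(mean(x^2) + eps) * weight, applied over one feature vector.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let scale = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * scale * w).collect()
}
```

For instance, `rms_norm(&[1.0, 1.0], &[2.0, 2.0], 0.0)` has mean square 1, so the output is just the weight, `[2.0, 2.0]`.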
nn::LayerNorm
Standard layer normalization with weight and bias: (x - mean) / sqrt(var + eps) * weight + bias.
LayerNorm::new registers name.weight and name.bias, both [dim].
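The same formula in standalone form (reference sketch only; the function name is illustrative):

```rust
// (x - mean) / sqrt(var + eps) * weight + bias, over one feature vector.
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
    let inv = 1.0 / (var + eps).sqrt();
    x.iter()
        .zip(weight.iter().zip(bias))
        .map(|(v, (w, b))| (v - mean) * inv * w + b)
        .collect()
}
```

With `x = [1.0, 3.0]`, unit weight, and zero bias, the mean is 2 and the variance 1, giving `[-1.0, 1.0]`.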
Activations and FFN
nn::Mlp
Two-layer MLP with a configurable activation function: fc2(act(fc1(x))).
Mlp::new registers name.fc1.weight, name.fc1.bias, name.fc2.weight, and name.fc2.bias.
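As a numeric reference for fc2(act(fc1(x))), a standalone sketch using ReLU as the example activation (the activation is configurable in the real module; names and row-major layouts here are illustrative assumptions):

```rust
fn relu(v: f32) -> f32 {
    v.max(0.0)
}

// x: [k], w1: [k, h], b1: [h], w2: [h, n], b2: [n], all flat row-major.
fn mlp_forward(
    x: &[f32],
    w1: &[f32],
    b1: &[f32],
    w2: &[f32],
    b2: &[f32],
    k: usize,
    h: usize,
    n: usize,
) -> Vec<f32> {
    let mut hidden = vec![0.0f32; h];
    for j in 0..h {
        let mut acc = b1[j];
        for p in 0..k {
            acc += x[p] * w1[p * h + j];
        }
        hidden[j] = relu(acc); // act(fc1(x))
    }
    let mut out = vec![0.0f32; n];
    for j in 0..n {
        let mut acc = b2[j];
        for p in 0..h {
            acc += hidden[p] * w2[p * n + j];
        }
        out[j] = acc; // fc2(...)
    }
    out
}
```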
nn::SwiGluFfn
SwiGLU feed-forward network used in LLaMA-style transformers: down(silu(gate(x)) * up(x)).
SwiGluFfn::new registers:
| Parameter | Shape |
|---|---|
| name.gate_proj.weight | [hidden, intermediate] |
| name.up_proj.weight | [hidden, intermediate] |
| name.down_proj.weight | [intermediate, hidden] |
All three projections are bias-free (see Linear::no_bias). The fused SwiGLU op emitted by the optimizer combines gate and up into a single kernel.
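The down(silu(gate(x)) * up(x)) dataflow can be sketched directly on flat slices. This is a standalone reference, not the fused kernel; names and row-major layout are assumptions:

```rust
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

// gate, up: [hidden, inter]; down: [inter, hidden]; all bias-free.
fn swiglu_ffn(
    x: &[f32],
    gate: &[f32],
    up: &[f32],
    down: &[f32],
    hidden: usize,
    inter: usize,
) -> Vec<f32> {
    let mut mid = vec![0.0f32; inter];
    for j in 0..inter {
        let (mut g, mut u) = (0.0, 0.0);
        for p in 0..hidden {
            g += x[p] * gate[p * inter + j];
            u += x[p] * up[p * inter + j];
        }
        mid[j] = silu(g) * u; // silu(gate(x)) * up(x)
    }
    let mut out = vec![0.0f32; hidden];
    for j in 0..hidden {
        for p in 0..inter {
            out[j] += mid[p] * down[p * hidden + j]; // down(...)
        }
    }
    out
}
```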
Attention
nn::CausalSelfAttention
Grouped-query causal self-attention with RoPE positional encodings.
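The RoPE step can be illustrated with a standalone sketch, assuming the common rotate-adjacent-pairs formulation with per-pair frequency theta^(-2i/d); the exact pairing convention used by the fused op is not specified here:

```rust
// Rotate one head vector in place at sequence position `pos`.
// Pair (2i, 2i+1) is rotated by angle pos * theta^(-2i/d).
fn rope(x: &mut [f32], pos: usize, theta: f32) {
    let d = x.len();
    for i in 0..d / 2 {
        let freq = theta.powf(-2.0 * i as f32 / d as f32);
        let angle = pos as f32 * freq;
        let (sin, cos) = angle.sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * cos - b * sin;
        x[2 * i + 1] = a * sin + b * cos;
    }
}
```

At position 0 every angle is zero, so the rotation is the identity; later positions rotate low-index pairs fastest.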
AttentionConfig fields:
- hidden: total hidden dimension. Q and O projections are [hidden, hidden].
- kv_dim: KV hidden dimension, num_kv_heads * head_dim. K and V projections are [hidden, kv_dim].
- num_heads: number of query heads.
- num_kv_heads: number of key/value heads. Set equal to num_heads for standard MHA; set lower for GQA.
- head_dim: dimension per head. Must satisfy num_heads * head_dim == hidden.
- RoPE base frequency: common values are 10_000.0 (LLaMA), 500_000.0 (LLaMA 3).

forward projects Q, K, V, applies RoPE, runs fused causal attention, then applies the output projection.
nn::TransformerBlock
A complete transformer decoder block: RmsNorm → attention → residual → RmsNorm → SwiGLU FFN → residual.
TransformerBlockConfig:
Registered parameter names follow the HuggingFace layout, e.g.:
- model.layers.0.input_layernorm.weight
- model.layers.0.self_attn.q_proj.weight
- model.layers.0.post_attention_layernorm.weight
- model.layers.0.mlp.gate_proj.weight
Stack TransformerBlocks by chaining their forward outputs.
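The pre-norm residual wiring can be sketched with the sub-modules stubbed as closures (the real modules are CausalSelfAttention, RmsNorm, and SwiGluFfn operating on Graph NodeIds; this standalone version just shows the dataflow):

```rust
// RmsNorm -> attention -> residual -> RmsNorm -> FFN -> residual.
fn block_forward(
    x: &[f32],
    norm1: impl Fn(&[f32]) -> Vec<f32>,
    attn: impl Fn(&[f32]) -> Vec<f32>,
    norm2: impl Fn(&[f32]) -> Vec<f32>,
    ffn: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    // h = x + attn(norm1(x))
    let h: Vec<f32> = x.iter().zip(attn(&norm1(x))).map(|(a, b)| a + b).collect();
    // out = h + ffn(norm2(h))
    h.iter().zip(ffn(&norm2(&h))).map(|(a, b)| a + b).collect()
}
```

Stacking is then a fold over the blocks, each block's output feeding the next block's input.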
Convolutional
nn::Conv2d
2D convolution over NCHW-layout tensors. Input and output are stored as flat 1D slices.
Conv2d::new registers the kernel as a flat [out_channels * in_channels * kH * kW] 1D tensor named name.weight. There is no bias parameter: add one manually with g.bias_add if needed.
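As an indexing reference for the flat NCHW layout, a naive standalone conv2d (single image, stride 1, no padding; the real op's stride and padding handling are not specified here):

```rust
// x: [cin, h, w] flat; w: [cout, cin, kh, kw] flat; output: [cout, oh, ow] flat.
fn conv2d(
    x: &[f32],
    w: &[f32],
    cin: usize,
    cout: usize,
    h: usize,
    wd: usize,
    kh: usize,
    kw: usize,
) -> Vec<f32> {
    let (oh, ow) = (h - kh + 1, wd - kw + 1);
    let mut y = vec![0.0f32; cout * oh * ow];
    for oc in 0..cout {
        for oy in 0..oh {
            for ox in 0..ow {
                let mut acc = 0.0;
                for ic in 0..cin {
                    for ky in 0..kh {
                        for kx in 0..kw {
                            // flat indices into the 1D input and kernel slices
                            let xi = ic * h * wd + (oy + ky) * wd + (ox + kx);
                            let wi = oc * cin * kh * kw + ic * kh * kw + ky * kw + kx;
                            acc += x[xi] * w[wi];
                        }
                    }
                }
                y[oc * oh * ow + oy * ow + ox] = acc;
            }
        }
    }
    y
}
```

A 1x1 kernel with weight 2 over a single-channel 2x2 input simply doubles every element.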
Primitive operations
All nn:: structs are wrappers over these low-level Graph methods. You can call them directly when you need more control:
Arithmetic
g.matmul(a, b), g.add(a, b), g.mul(a, b), g.bias_add(x, bias), g.neg(x), g.recip(x), g.div(a, b)
Activations
g.relu(x), g.gelu(x), g.silu(x), g.sigmoid(x), g.softmax(x), g.log_softmax(x)
Normalization
g.rms_norm(x, weight, eps), g.layer_norm(x, weight, bias, eps), g.group_norm(...)
Reductions
g.sum_all(x), g.mean_all(x), g.transpose(x)
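As a numeric reference for g.softmax, a standalone, numerically stable implementation of the usual definition (subtracting the max before exponentiating avoids overflow without changing the result):

```rust
// softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
fn softmax(x: &[f32]) -> Vec<f32> {
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

A uniform input produces a uniform distribution: `softmax(&[0.0, 0.0])` is `[0.5, 0.5]`.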