Each struct holds parameter NodeIds and provides a forward() method that appends operations to the Graph. These are thin wrappers over the low-level graph API — no trait hierarchy, no dynamic dispatch.
let mut g = Graph::new();
let x = g.input("x", &[batch, 784]);
let labels = g.input("labels", &[batch, 10]);

let fc1 = nn::Linear::new(&mut g, "fc1", 784, 128);
let fc2 = nn::Linear::new(&mut g, "fc2", 128, 10);

let h = fc1.forward(&mut g, x);
let h = g.relu(h);
let logits = fc2.forward(&mut g, h);
let loss = g.cross_entropy_loss(logits, labels);

nn::Linear

Fully connected linear layer: y = x @ weight + bias.

Fields:
- weight (NodeId): Weight parameter of shape [in_features, out_features].
- bias (Option<NodeId>): Optional bias parameter of shape [out_features]. None when constructed with no_bias.

Linear::new

Creates a linear layer with a bias term.
Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name prefix for the parameters. Registers {name}.weight and {name}.bias.
- in_features (usize, required): Number of input features.
- out_features (usize, required): Number of output features.
let fc = nn::Linear::new(&mut g, "fc", 784, 128);

Linear::no_bias

Creates a linear layer without a bias term.
Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name prefix. Registers only {name}.weight.
- in_features (usize, required): Number of input features.
- out_features (usize, required): Number of output features.
let proj = nn::Linear::no_bias(&mut g, "q_proj", 512, 512);

forward

Appends a matrix multiply and optional bias add to the graph.
Parameters:
- g (&mut Graph, required): The computation graph to append ops to.
- x (NodeId, required): Input tensor of shape [batch, in_features].

Returns NodeId: Output tensor of shape [batch, out_features].
let mut g = Graph::new();
let x = g.input("x", &[4, 8]);
let fc = nn::Linear::new(&mut g, "fc", 8, 3);
let y = fc.forward(&mut g, x);
// y shape: [4, 3]
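For intuition, here is a minimal plain-Rust sketch of the math that forward() records in the graph: a row-major matmul plus bias add. This is an illustration only, not the library's kernel; the function name and row-major layout are assumptions.

```rust
// Sketch of y = x @ weight + bias on plain slices.
// x is [batch, n_in] row-major, w is [n_in, n_out] row-major, b is [n_out].
fn linear_forward(x: &[f32], w: &[f32], b: &[f32], batch: usize, n_in: usize, n_out: usize) -> Vec<f32> {
    let mut y = vec![0.0f32; batch * n_out];
    for i in 0..batch {
        for j in 0..n_out {
            let mut acc = b[j]; // bias add
            for k in 0..n_in {
                acc += x[i * n_in + k] * w[k * n_out + j];
            }
            y[i * n_out + j] = acc;
        }
    }
    y
}
```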

nn::Embedding

Token embedding lookup table. Maps integer token indices to dense vectors.

Fields:
- weight (NodeId): Embedding table parameter of shape [vocab_size, embed_dim].

Embedding::new

Creates an embedding table and registers it as a parameter.

Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name for the embedding weight parameter.
- vocab_size (usize, required): Number of tokens in the vocabulary.
- embed_dim (usize, required): Dimensionality of each token embedding.
let embed = nn::Embedding::new(&mut g, "token_embeddings", 32000, 512);

forward

Appends an embedding lookup to the graph.

Parameters:
- g (&mut Graph, required): The computation graph to append ops to.
- indices (NodeId, required): 1D U32 tensor of shape [seq_len] containing token indices.

Returns NodeId: Output tensor of shape [seq_len, embed_dim].
let indices = g.input_u32("tokens", &[seq_len]);
let embed = nn::Embedding::new(&mut g, "tok_emb", 32000, 512);
let h = embed.forward(&mut g, indices);
// h shape: [seq_len, 512]
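Conceptually the lookup is a row gather: each index selects one row of the [vocab_size, embed_dim] table. A plain-Rust sketch (not the graph op itself; names are illustrative):

```rust
// Gather one [embed_dim] row from a row-major table per token index.
fn embedding_lookup(table: &[f32], embed_dim: usize, indices: &[u32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(indices.len() * embed_dim);
    for &idx in indices {
        let row = idx as usize * embed_dim;
        out.extend_from_slice(&table[row..row + embed_dim]);
    }
    out
}
```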

nn::SwiGluFfn

SwiGLU feed-forward network: silu(gate(x)) * up(x), then down-projected back to the hidden dimension. Internally uses three bias-free linear projections registered as {name}.gate_proj, {name}.up_proj, and {name}.down_proj.

Fields:
- gate (Linear): Gate projection: hidden → intermediate.
- up (Linear): Up projection: hidden → intermediate.
- down (Linear): Down projection: intermediate → hidden.
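The gating math itself is elementwise. A plain-Rust sketch of silu(gate(x)) * up(x) applied to already-projected vectors, assuming the standard silu(v) = v * sigmoid(v); the down projection would then map the result back to the hidden dimension:

```rust
// silu(v) = v * sigmoid(v) = v / (1 + e^-v)
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

// Elementwise SwiGLU gate: out_i = silu(gate_i) * up_i.
fn swiglu_gate(gate: &[f32], up: &[f32]) -> Vec<f32> {
    gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
}
```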

SwiGluFfn::new

Creates a SwiGLU feed-forward network.

Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name prefix. Registers {name}.gate_proj, {name}.up_proj, {name}.down_proj.
- hidden (usize, required): Hidden (input and output) dimension.
- intermediate (usize, required): Intermediate (expanded) dimension.
let ffn = nn::SwiGluFfn::new(&mut g, "model.layers.0.mlp", 512, 1365);

forward

Appends the SwiGLU computation to the graph.

Parameters:
- g (&mut Graph, required): The computation graph to append ops to.
- x (NodeId, required): Input tensor of shape [seq, hidden].

Returns NodeId: Output tensor of shape [seq, hidden].
let ffn = nn::SwiGluFfn::new(&mut g, "ffn", 512, 1365);
let out = ffn.forward(&mut g, x);
// out shape: [seq, 512]

nn::Mlp

Standard two-layer MLP: fc2(activation(fc1(x))).

Fields:
- fc1 (Linear): First linear layer: in_dim → hidden_dim (with bias).
- fc2 (Linear): Second linear layer: hidden_dim → out_dim (with bias).
- activation (Activation): Activation function applied between the two layers.

Activation

pub enum Activation {
    Relu,
    Gelu,
    Silu,
    Sigmoid,
}
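For reference, a self-contained sketch of the pointwise math each variant applies, mirroring the enum above. The GELU here uses the common tanh approximation; the library's exact formula may differ.

```rust
#[derive(Clone, Copy)]
enum Activation { Relu, Gelu, Silu, Sigmoid }

fn sigmoid(v: f32) -> f32 {
    1.0 / (1.0 + (-v).exp())
}

// Apply one activation to a scalar; the library applies the graph-op
// equivalent elementwise between fc1 and fc2.
fn activate(a: Activation, v: f32) -> f32 {
    match a {
        Activation::Relu => v.max(0.0),
        Activation::Silu => v * sigmoid(v),
        Activation::Sigmoid => sigmoid(v),
        Activation::Gelu => {
            // tanh approximation of GELU
            let inner = 0.797_884_56 * (v + 0.044_715 * v * v * v);
            0.5 * v * (1.0 + inner.tanh())
        }
    }
}
```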

Mlp::new

Creates a two-layer MLP.

Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name prefix. Registers {name}.fc1.weight, {name}.fc1.bias, {name}.fc2.weight, and {name}.fc2.bias.
- in_dim (usize, required): Input feature dimension.
- hidden_dim (usize, required): Hidden layer dimension.
- out_dim (usize, required): Output feature dimension.
- activation (Activation, required): Activation function to use between layers.
let mlp = nn::Mlp::new(&mut g, "mlp", 8, 16, 3, nn::Activation::Relu);

forward

Appends the MLP ops to the graph.

Parameters:
- g (&mut Graph, required): The computation graph to append ops to.
- x (NodeId, required): Input tensor of shape [batch, in_dim].

Returns NodeId: Output tensor of shape [batch, out_dim].
let mut g = Graph::new();
let x = g.input("x", &[4, 8]);
let mlp = nn::Mlp::new(&mut g, "mlp", 8, 16, 3, nn::Activation::Relu);
let y = mlp.forward(&mut g, x);
// y shape: [4, 3]

nn::Conv2d

2D convolution layer: y = conv2d(x, weight) + bias. Input and output tensors are flat 1D arrays in NCHW layout.

Fields:
- weight (NodeId): Kernel parameter of shape [out_channels * in_channels * kernel_h * kernel_w] (flat 1D).
- bias (Option<NodeId>): Optional bias. Currently None after construction; set manually if needed.
- in_channels (u32): Number of input channels.
- in_h (u32): Input spatial height.
- in_w (u32): Input spatial width.
- out_channels (u32): Number of output channels.
- kernel_h (u32): Kernel height (equal to the kernel_size passed to new).
- kernel_w (u32): Kernel width (equal to the kernel_size passed to new).
- stride (u32): Convolution stride.
- padding (u32): Zero-padding added to each spatial edge.

Conv2d::new

Creates a 2D convolution layer.

Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name prefix. Registers {name}.weight as a flat 1D parameter.
- in_channels (u32, required): Number of input channels.
- out_channels (u32, required): Number of output channels.
- kernel_size (u32, required): Square kernel side length (sets both kernel_h and kernel_w).
- in_h (u32, required): Input spatial height.
- in_w (u32, required): Input spatial width.
- stride (u32, required): Convolution stride.
- padding (u32, required): Zero-padding on each edge.
let conv = nn::Conv2d::new(&mut g, "conv1", 3, 64, 3, 224, 224, 1, 1);

forward

Appends a 2D convolution to the graph.

Parameters:
- g (&mut Graph, required): The computation graph to append ops to.
- x (NodeId, required): Flat input tensor of shape [N * in_channels * in_h * in_w] in NCHW order.
- batch (u32, required): Batch size N.

Returns NodeId: Flat output tensor of shape [N * out_channels * out_h * out_w] in NCHW order, where out_h = (in_h + 2*padding - kernel_h) / stride + 1 and out_w is computed analogously.
let conv = nn::Conv2d::new(&mut g, "conv1", 3, 64, 3, 224, 224, 1, 1);
let y = conv.forward(&mut g, x, batch);
// y shape: [N * 64 * 224 * 224] (with padding=1, stride=1, kernel=3)
All tensors for Conv2d are stored as flat 1D arrays. Spatial metadata (channels, height, width, kernel dimensions, stride, padding) is encoded in the op and used by the GPU kernel. There is no explicit reshape needed.
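The output-size formula above can be checked with a small helper. This function is not part of the library, just the arithmetic from the Returns description:

```rust
// out = (in + 2*padding - kernel) / stride + 1, per spatial dimension.
fn conv_out_dim(in_dim: u32, kernel: u32, stride: u32, padding: u32) -> u32 {
    (in_dim + 2 * padding - kernel) / stride + 1
}
```

For the example above, conv_out_dim(224, 3, 1, 1) gives 224: a 3x3 kernel with padding 1 and stride 1 preserves spatial size.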

nn::TransformerBlock

A single transformer decoder block combining pre-norm attention, a residual connection, and a SwiGLU feed-forward network with a second residual connection. Forward pass:
x = x + attn(attn_norm(x))
x = x + ffn(ffn_norm(x))
Fields:
- attn_norm (RmsNorm): RMS normalization applied before attention. Parameter name: {name}.input_layernorm.weight.
- attn (CausalSelfAttention): Causal self-attention module. Parameter names prefixed with {name}.self_attn.
- ffn_norm (RmsNorm): RMS normalization applied before the feed-forward network. Parameter name: {name}.post_attention_layernorm.weight.
- ffn (SwiGluFfn): SwiGLU feed-forward network. Parameter names prefixed with {name}.mlp.

TransformerBlockConfig

pub struct TransformerBlockConfig {
    pub hidden: usize,
    pub intermediate: usize,
    pub kv_dim: usize,
    pub num_heads: u32,
    pub num_kv_heads: u32,
    pub head_dim: u32,
    pub rms_eps: f32,
    pub rope_theta: f32,
}
Fields:
- hidden (usize): Hidden dimension (model width).
- intermediate (usize): Intermediate dimension for the SwiGLU FFN.
- kv_dim (usize): Key/value projection dimension (num_kv_heads * head_dim).
- num_heads (u32): Number of query attention heads.
- num_kv_heads (u32): Number of key/value heads (for grouped-query attention).
- head_dim (u32): Dimension per attention head.
- rms_eps (f32): Epsilon for RMS normalization numerical stability.
- rope_theta (f32): Base frequency for rotary position embeddings.
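To make rms_eps concrete, here is a plain-Rust sketch of RMS normalization as commonly defined: each element is divided by the root-mean-square of the vector, with rms_eps added for numerical stability, then scaled by a learned weight. This illustrates the formula, not the library's kernel.

```rust
// x is one [hidden] vector; weight is the learned per-channel scale.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(&v, &w)| v * inv_rms * w).collect()
}
```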

TransformerBlock::new

Creates a transformer block, registering all of its parameters.

Parameters:
- g (&mut Graph, required): The computation graph to register parameters into.
- name (&str, required): Name prefix, typically "model.layers.{i}".
- cfg (&TransformerBlockConfig, required): Block configuration.

forward

Appends the block's attention and feed-forward ops to the graph.

Parameters:
- g (&mut Graph, required): The computation graph to append ops to.
- x (NodeId, required): Input tensor of shape [seq, hidden].

Returns NodeId: Output tensor of shape [seq, hidden].
let mut g = Graph::new();
let x = g.input("x", &[16, 64]);
let cfg = nn::TransformerBlockConfig {
    hidden: 64,
    intermediate: 128,
    kv_dim: 32,
    num_heads: 4,
    num_kv_heads: 2,
    head_dim: 16,
    rms_eps: 1e-5,
    rope_theta: 10000.0,
};
let block = nn::TransformerBlock::new(&mut g, "model.layers.0", &cfg);
let y = block.forward(&mut g, x);
// y shape: [16, 64]
