Meganeura includes a U-Net graph definition modelled on the Stable Diffusion 1.5 architecture. It is the primary example for training with convolutional ops, group normalization, skip connections, and upsampling — all with full backward support.

Configuration

| Field | Type | Description |
|---|---|---|
| `batch_size` | `u32` | Number of images in a batch. |
| `in_channels` | `u32` | Number of input/output channels. Use 4 for latent-space diffusion models. |
| `base_channels` | `u32` | Base channel width. Doubles at each downsampling level. |
| `num_levels` | `usize` | Number of encoder/decoder levels (downsampling stages). |
| `resolution` | `u32` | Spatial resolution of the input (square: H = W = `resolution`). |
| `num_groups` | `u32` | Number of groups for GroupNorm. |
| `gn_eps` | `f32` | GroupNorm epsilon. Typically `1e-5`. |
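The channel and spatial arithmetic implied by `base_channels`, `resolution`, and `num_levels` can be checked in plain Rust. This is a standalone sketch, not library code, and the concrete values below are illustrative assumptions rather than the actual preset fields:

```rust
// Channel width doubles and spatial size halves at each downsampling level,
// as described for `base_channels` and `num_levels` above.
fn level_shapes(base_channels: u32, resolution: u32, num_levels: usize) -> Vec<(u32, u32)> {
    (0..=num_levels)
        .map(|l| (base_channels << l, resolution >> l))
        .collect()
}

fn main() {
    // Illustrative values: base_channels = 32, resolution = 32, num_levels = 2.
    for (channels, size) in level_shapes(32, 32, 2) {
        println!("{} channels at {}x{}", channels, size, size);
    }
}
```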
Two presets are provided:
```rust
use meganeura::models::sd_unet::SDUNetConfig;

// ~120K parameters, 32×32 latent, batch 4 — good for benchmarking
let config = SDUNetConfig::tiny();

// ~2M parameters, 32×32 latent, batch 2 — closer to real SD dimensions
let config = SDUNetConfig::small();
```

Architecture

The U-Net follows the standard encoder-bottleneck-decoder layout, all in NCHW flat tensor format:

1. Input conv (`in_channels` → `base_channels`)
2. Encoder: [ResBlock → Downsample (2×)] × `num_levels`
3. Bottleneck ResBlock
4. Decoder: [ResBlock → Upsample (2×) → Concat with skip] × `num_levels`
5. Output conv (`base_channels` → `in_channels`)
Each ResBlock applies:
  1. GroupNorm → SiLU → Conv 3×3
  2. GroupNorm → SiLU → Conv 3×3
  3. Residual projection (1×1 Conv if channel count changes) + residual add
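Under the assumption that each graph op takes an input `NodeId` plus shape parameters and returns a new `NodeId`, the three steps above might be assembled roughly like this. The exact signatures are not documented here, so the argument lists (and the `g.add` residual op) are guesses:

```rust
// Hypothetical sketch — not the library's actual builder. Argument lists
// beyond the input NodeId are assumptions, elided as comments.
fn res_block(g: &mut Graph, x: NodeId, c_in: u32, c_out: u32) -> NodeId {
    // 1. GroupNorm → SiLU → Conv 3×3
    let h = g.group_norm(x /* num_groups, gn_eps, spatial dims… */);
    let h = g.silu(h);
    let h = g.conv2d(h /* c_in → c_out, 3×3 kernel… */);
    // 2. GroupNorm → SiLU → Conv 3×3
    let h = g.group_norm(h);
    let h = g.silu(h);
    let h = g.conv2d(h /* c_out → c_out, 3×3 kernel… */);
    // 3. 1×1 projection only when the channel count changes, then residual add
    let skip = if c_in != c_out {
        g.conv2d(x /* c_in → c_out, 1×1 kernel… */)
    } else {
        x
    };
    g.add(skip, h) // assumed elementwise-add op
}
```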
At inference time, the optimizer automatically fuses GroupNorm + SiLU into a single GroupNormSilu kernel.

Building and training

```rust
use meganeura::{Graph, build_session};
use meganeura::models::sd_unet::{SDUNetConfig, build_training_graph};

let config = SDUNetConfig::tiny();

let mut g = Graph::new();
let loss = build_training_graph(&mut g, &config);
g.set_outputs(vec![loss]);

let mut session = build_session(&g);
```
`build_training_graph` constructs the full forward pass ending with an MSE loss and returns the `NodeId` of that single loss node. The graph expects two inputs:
  • "x" — F32 flat tensor of shape [batch * in_channels * H * W] (NCHW layout)
  • "target" — F32 flat tensor of the same shape (the noise target to regress against)

Running the benchmark

The repository includes a full training example:

```shell
cargo run --example sd_unet_train
```

You can also run the dedicated benchmark:

```shell
cargo run --example bench_sd_unet_train
```

Set `MEGANEURA_TRACE=trace.pftrace` to capture a Perfetto profile of the training run. See Profiling for details.

Key operations used

The SD UNet exercises the full set of convolutional and normalization ops:
| Operation | Used for |
|---|---|
| `g.conv2d()` | 3×3 and 1×1 convolutions in ResBlocks |
| `g.group_norm()` | Normalization within each ResBlock |
| `g.silu()` | Activation after each GroupNorm |
| `g.upsample_2x()` | 2× nearest-neighbor upsampling in the decoder |
| `g.concat()` | Skip-connection merging in the decoder |
| `g.split_a()` / `g.split_b()` | Backward of concat |
All tensors are stored as flat 1D arrays in NCHW order. Spatial metadata (batch, channels, height, width) is encoded in the op parameters rather than the tensor shape.
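Because tensors are flat 1D arrays, addressing element `(n, c, y, x)` has to be done by hand. A minimal helper for computing the flat NCHW offset (not part of the library, just the standard row-major formula):

```rust
/// Flat offset of element (n, c, y, x) in an NCHW tensor of shape [N, C, H, W].
fn nchw_index(n: usize, c: usize, y: usize, x: usize,
              channels: usize, height: usize, width: usize) -> usize {
    ((n * channels + c) * height + y) * width + x
}

fn main() {
    // For a [2, 4, 32, 32] tensor, element (1, 2, 0, 5):
    let idx = nchw_index(1, 2, 0, 5, 4, 32, 32);
    println!("{}", idx); // (1*4 + 2) * 32 * 32 + 0 * 32 + 5 = 6149
}
```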
