## Configuration
- Number of images in a batch.
- Number of input/output channels. Use 4 for latent-space diffusion models.
- Base channel width. Doubles at each downsampling level.
- Number of encoder/decoder levels (downsampling stages).
- Spatial resolution of the input (square: H = W = resolution).
- Number of groups for GroupNorm.
- GroupNorm epsilon. Typically `1e-5`.

## Architecture
The U-Net follows the standard encoder-bottleneck-decoder layout, all in NCHW flat tensor format. Each ResBlock is:

- GroupNorm → SiLU → Conv 3×3
- GroupNorm → SiLU → Conv 3×3
- Residual projection (1×1 Conv if channel count changes) + residual add
The GroupNorm → SiLU sequence is executed as a fused GroupNormSilu kernel.
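A minimal sketch of what the fused kernel computes, per (batch, group): mean and variance over the group's channels and spatial positions, normalization with a per-channel scale/shift, then SiLU. The function name and signature here are illustrative, not the library's API.

```rust
// Illustrative fused GroupNorm + SiLU over a flat NCHW buffer (not the library API).
// `x` has length n * c * h * w; `gamma` and `beta` have length c.
fn group_norm_silu(
    x: &[f32], n: usize, c: usize, h: usize, w: usize,
    groups: usize, eps: f32, gamma: &[f32], beta: &[f32],
) -> Vec<f32> {
    assert!(c % groups == 0);
    let cpg = c / groups; // channels per group
    let hw = h * w;
    let mut out = vec![0.0f32; x.len()];
    for b in 0..n {
        for g in 0..groups {
            // Statistics over the group's channels and all spatial positions.
            let start = (b * c + g * cpg) * hw;
            let len = cpg * hw;
            let slice = &x[start..start + len];
            let mean = slice.iter().sum::<f32>() / len as f32;
            let var = slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / len as f32;
            let inv_std = 1.0 / (var + eps).sqrt();
            for cc in 0..cpg {
                let ch = g * cpg + cc;
                for i in 0..hw {
                    let idx = (b * c + ch) * hw + i;
                    let y = (x[idx] - mean) * inv_std * gamma[ch] + beta[ch];
                    out[idx] = y * (1.0 / (1.0 + (-y).exp())); // SiLU: y * sigmoid(y)
                }
            }
        }
    }
    out
}
```

Fusing the two ops avoids materializing the normalized intermediate, which is why frameworks commonly ship it as one kernel.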
## Building and training
`build_training_graph` constructs the full forward pass ending with an MSE loss. It returns the single loss `NodeId`. The graph expects two inputs:

- `"x"`: F32 flat tensor of shape `[batch * in_channels * H * W]` (NCHW layout)
- `"target"`: F32 flat tensor of the same shape (the noise target to regress against)
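The loss at the end of the graph is mean squared error between the predicted and target noise tensors. A minimal sketch of that computation on flat F32 buffers (the helper name is illustrative):

```rust
// MSE between a prediction and its target, both flat F32 buffers of equal length.
// Illustrative helper, not the library API.
fn mse_loss(pred: &[f32], target: &[f32]) -> f32 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f32;
    pred.iter()
        .zip(target)
        .map(|(p, t)| (p - t) * (p - t))
        .sum::<f32>()
        / n
}
```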
## Running the benchmark
The repository includes a full training benchmark. Set `MEGANEURA_TRACE=trace.pftrace` to capture a Perfetto profile of the training run. See Profiling for details.
## Key operations used

The SD UNet exercises the full set of convolutional and normalization ops:

| Operation | Used for |
|---|---|
| `g.conv2d()` | 3×3 and 1×1 convolutions in ResBlocks |
| `g.group_norm()` | Normalization within each ResBlock |
| `g.silu()` | Activation after each GroupNorm |
| `g.upsample_2x()` | 2× nearest-neighbor upsampling in decoder |
| `g.concat()` | Skip connection merging in decoder |
| `g.split_a()` / `g.split_b()` | Backward of concat |
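On a flat NCHW buffer, 2× nearest-neighbor upsampling just replicates each input pixel into a 2×2 output block. A self-contained sketch of the semantics (illustrative function, not the `g.upsample_2x()` implementation):

```rust
// 2x nearest-neighbor upsampling on a flat NCHW buffer:
// each input pixel is replicated into a 2x2 block of the output.
// Illustrative sketch, not the library's g.upsample_2x() implementation.
fn upsample_2x_nearest(x: &[f32], n: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    let (oh, ow) = (2 * h, 2 * w);
    let mut out = vec![0.0f32; n * c * oh * ow];
    for b in 0..n {
        for ch in 0..c {
            for y in 0..oh {
                for xx in 0..ow {
                    // Integer division by 2 maps each output pixel to its source pixel.
                    let src = ((b * c + ch) * h + y / 2) * w + xx / 2;
                    let dst = ((b * c + ch) * oh + y) * ow + xx;
                    out[dst] = x[src];
                }
            }
        }
    }
    out
}
```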
All tensors are stored as flat 1D arrays in NCHW order. Spatial metadata (batch, channels, height, width) is encoded in the op parameters rather than the tensor shape.
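The flat NCHW convention above means every op computes offsets from the (batch, channels, height, width) parameters. A one-line helper showing the addressing scheme (the function itself is illustrative):

```rust
// Offset of element (b, c, y, x) in a flat NCHW buffer.
// `channels`, `h`, `w` come from op parameters, not the tensor itself.
// Illustrative helper showing the addressing scheme.
fn nchw_index(channels: usize, h: usize, w: usize, b: usize, c: usize, y: usize, x: usize) -> usize {
    ((b * channels + c) * h + y) * w + x
}
```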