Overview
The CIFAR-10 diffusion classes extend the base implementation with architectural improvements for color image generation:
- DiffusionProcessCIFAR: Enhanced training with AdamW optimizer, gradient clipping, EMA (Exponential Moving Average), and a linear beta schedule
- DiffusionModelCIFAR: Wider U-Net with dropout, configurable attention at specific resolutions, and multiple residual blocks per level
DiffusionProcessCIFAR
Constructor
Parameters
- image_size: Height and width of square images (CIFAR-10 uses 32×32).
- channels: Number of image channels (3 for RGB images).
- Channel dimensions for each U-Net level. CIFAR-10 uses wider networks with base channels of 128 and multipliers [1, 2, 2, 2].
- beta_start: Initial noise variance in the linear noise schedule.
- beta_end: Final noise variance. Uses 0.02 following standard DDPM CIFAR-10 implementations.
- noise_steps: Total number of diffusion timesteps.
- Dropout probability applied in residual blocks to reduce overfitting.
- ema_decay: EMA decay rate for maintaining an exponential moving average of model weights. The EMA model is used for sampling to improve image quality.
- Device for computations. Defaults to CUDA if available, otherwise CPU.
Key differences from base class
Linear beta schedule: Uses a linear schedule, torch.linspace(beta_start, beta_end, noise_steps), instead of the cosine schedule. This matches standard DDPM CIFAR-10 implementations (sketched after this list).
Posterior variance coefficients: Precomputes posterior_variance, posterior_mean_coef1, and posterior_mean_coef2 for the posterior distribution q(x_{t-1} | x_t, x_0), improving sampling accuracy.
AdamW optimizer: Uses AdamW with learning rate 2e-4, weight decay 1e-5, and betas (0.9, 0.999) instead of Adam.
EMA model: Maintains a separate EMA copy of the model that is updated after each training step and used for sampling.
Gradient clipping: Clips gradients to max norm 1.0 to stabilize training of the larger network.
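For reference, the schedule and posterior precomputation can be sketched as follows. This is an illustrative sketch using the standard DDPM definitions, not the exact implementation; the beta_start value of 1e-4 is an assumed typical default.

```python
import torch

# Illustrative sketch of the linear schedule and posterior precomputation.
# beta_start = 1e-4 is an assumed typical default; beta_end and noise_steps
# follow the values documented above.
beta_start, beta_end, noise_steps = 1e-4, 0.02, 1000

beta = torch.linspace(beta_start, beta_end, noise_steps)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
alpha_bar_prev = torch.cat([torch.ones(1), alpha_bar[:-1]])  # alpha_bar_{t-1}

# For q(x_{t-1} | x_t, x_0) = N(mean, posterior_variance * I) with
# mean = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t:
posterior_variance = beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar)
posterior_mean_coef1 = beta * torch.sqrt(alpha_bar_prev) / (1.0 - alpha_bar)
posterior_mean_coef2 = (1.0 - alpha_bar_prev) * torch.sqrt(alpha) / (1.0 - alpha_bar)
```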
Attributes
- Main training model with dropout.
- ema_model: Exponential moving average model used for sampling; updated with the ema_decay rate after each training step.
- AdamW optimizer with lr=2e-4, weight_decay=1e-5.
- posterior_variance: Precomputed posterior variance for DDPM sampling.
- posterior_mean_coef1: Coefficient for the x_0 term in the posterior mean computation.
- posterior_mean_coef2: Coefficient for the x_t term in the posterior mean computation.
Methods
add_noise
Identical to the base class implementation; adds noise according to the forward diffusion process.
train_step
Enhanced training step with gradient clipping and EMA updates. Takes a tensor of clean images of shape [batch_size, channels, height, width] and returns the MSE loss between predicted and actual noise. The step (sketched after this list):
- Samples random timesteps and adds noise
- Predicts noise using the main model (not EMA)
- Computes MSE loss with optional mixed precision
- Clips gradients to max norm 1.0 (unscaled for AMP)
- Updates model parameters
- Updates the EMA model: ema_param ← ema_decay * ema_param + (1 - ema_decay) * param
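A minimal sketch of this flow, assuming the attribute names used on this page (model, ema_model, optimizer, ema_decay, noise_steps) and that add_noise returns the noisy images together with the sampled noise; mixed precision is omitted for brevity:

```python
import torch
import torch.nn.functional as F

def train_step(process, images):
    # Sample random timesteps and add noise (assumes add_noise returns (noisy, noise)).
    t = torch.randint(0, process.noise_steps, (images.shape[0],), device=images.device)
    noisy, noise = process.add_noise(images, t)

    pred = process.model(noisy, t)   # predict noise with the main model, not the EMA
    loss = F.mse_loss(pred, noise)

    process.optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(process.model.parameters(), max_norm=1.0)
    process.optimizer.step()

    # EMA update: ema_param <- ema_decay * ema_param + (1 - ema_decay) * param
    with torch.no_grad():
        for ema_p, p in zip(process.ema_model.parameters(), process.model.parameters()):
            ema_p.mul_(process.ema_decay).add_(p, alpha=1.0 - process.ema_decay)
    return loss.item()
```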
sample
DDPM sampling using the EMA model and posterior variance. Takes num_samples, the number of images to generate, and returns generated images of shape [num_samples, channels, image_size, image_size] with values in [-1, 1]. Key details (see the sketch below):
- Uses ema_model instead of the training model
- Reconstructs x_0 from the predicted noise and x_t, then clamps to [-1, 1]
- Computes the posterior mean using the precomputed coefficients posterior_mean_coef1 and posterior_mean_coef2
- Uses posterior_variance instead of recomputing the variance from beta
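A sketch of a single reverse step under these conventions. The alpha_bar attribute name is an assumption, and t is assumed to be a batch of identical timestep indices, as in a typical sampling loop:

```python
import torch

@torch.no_grad()
def ddpm_step(process, x_t, t):
    # Predict noise with the EMA model.
    eps = process.ema_model(x_t, t)

    # Reconstruct x_0 from x_t and the predicted noise, then clamp to the image range.
    ab_t = process.alpha_bar[t].view(-1, 1, 1, 1)
    x_0 = ((x_t - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)).clamp(-1.0, 1.0)

    # Posterior mean and variance from the precomputed coefficients.
    c1 = process.posterior_mean_coef1[t].view(-1, 1, 1, 1)
    c2 = process.posterior_mean_coef2[t].view(-1, 1, 1, 1)
    var = process.posterior_variance[t].view(-1, 1, 1, 1)

    noise = torch.randn_like(x_t) if t[0] > 0 else torch.zeros_like(x_t)
    return c1 * x_0 + c2 * x_t + torch.sqrt(var) * noise
```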
sample_ddim
DDIM sampling using the EMA model. Parameters:
- num_samples: Number of images to generate.
- ddim_steps: Number of denoising steps. Must be in (0, noise_steps].
- Stochasticity parameter: 0 = deterministic, 1 = DDPM-like.
Returns generated images of shape [num_samples, channels, image_size, image_size]. Key details (see the sketch below):
- Uses a uniform grid of timesteps via torch.linspace(0, noise_steps-1, steps=ddim_steps)
- Uses ema_model for all predictions
- Clamps at the final step only (not at each intermediate step)
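A sketch of one DDIM update on the uniform grid, with the stochasticity parameter written as eta (the conventional name; the actual argument name may differ) and alpha_bar again an assumed attribute name:

```python
import torch

@torch.no_grad()
def ddim_step(process, x_t, t, t_prev, eta=0.0):
    # t and t_prev are scalar indices from the torch.linspace grid; t_prev < 0 marks the final step.
    t_batch = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long)
    eps = process.ema_model(x_t, t_batch)

    ab_t = process.alpha_bar[t]
    ab_prev = process.alpha_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0)

    # Predicted x_0, clamped only at the final step per the notes above.
    x_0 = (x_t - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)
    if t_prev < 0:
        x_0 = x_0.clamp(-1.0, 1.0)

    # eta = 0 is deterministic DDIM; eta = 1 is DDPM-like.
    sigma = eta * torch.sqrt((1.0 - ab_prev) / (1.0 - ab_t) * (1.0 - ab_t / ab_prev))
    noise = sigma * torch.randn_like(x_t) if eta > 0 else 0.0
    return torch.sqrt(ab_prev) * x_0 + torch.sqrt(1.0 - ab_prev - sigma**2) * eps + noise
```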
DiffusionModelCIFAR
Constructor
Parameters
- Height and width of square input images.
- Number of image channels (typically 3 for CIFAR-10).
- Channel dimensions at each resolution level. CIFAR-10 uses 4 levels: 32×32, 16×16, 8×8, 4×4.
- Dimensionality of the time embeddings.
- Dropout probability in residual blocks.
Architecture features
Attention at 16×16 resolution: Self-attention is applied only at index 1 (the 16×16 resolution) in both the encoder and decoder, plus in the bottleneck. This balances computational cost against modeling long-range dependencies (see the level indexing sketch after this list).
Dropout regularization: All residual blocks use ResBlockWithDropout, applying 2D dropout after the first convolution to reduce overfitting on CIFAR-10.
Enhanced bottleneck: BottleneckWithAttention applies ResBlock → SelfAttention → ResBlock with dropout support.
Standard U-Net forward pass: Inherits from DiffusionModel but uses the enhanced blocks with dropout and selective attention.
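As a quick illustration of the level indexing, using the default values described above:

```python
# Levels, resolutions, and where attention is enabled (index 1 = 16x16):
channels = [128 * m for m in (1, 2, 2, 2)]    # [128, 256, 256, 256]
resolutions = [32 // 2**i for i in range(4)]  # [32, 16, 8, 4]
use_attention = [i == 1 for i in range(4)]    # [False, True, False, False]
```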
Methods
forward
Standard U-Net forward pass. Takes noisy images of shape [batch_size, channels, height, width] and timesteps of shape [batch_size]; returns predicted noise of shape [batch_size, channels, height, width].
Supporting classes
ResBlockWithDropout
Extends ResBlock with 2D dropout regularization. Block structure (sketched after this list):
- GroupNorm → SiLU → Conv2d
- Dropout2d(p=dropout_p)
- Add time embedding
- GroupNorm → SiLU → Conv2d
- Skip connection
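A self-contained sketch of this structure; group counts, the 1x1 skip convolution, and argument names are assumptions, not the exact implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlockWithDropout(nn.Module):
    """Illustrative sketch of the block structure listed above."""
    def __init__(self, in_ch, out_ch, time_emb_dim, dropout_p=0.1):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.drop = nn.Dropout2d(p=dropout_p)  # 2D dropout after the first conv
        self.time_proj = nn.Linear(time_emb_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))            # GroupNorm -> SiLU -> Conv2d
        h = self.drop(h)                                 # Dropout2d(p=dropout_p)
        h = h + self.time_proj(t_emb)[:, :, None, None]  # add time embedding
        h = self.conv2(F.silu(self.norm2(h)))            # GroupNorm -> SiLU -> Conv2d
        return h + self.skip(x)                          # skip connection
```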
BottleneckWithAttention
Bottleneck block with dropout-enabled residual blocks. Structure:
- ResBlockWithDropout
- SelfAttention
- ResBlockWithDropout
DownBlockWithAttention
Downsampling block with optional attention.
use_attention (bool): Whether to apply self-attention after the residual block. Structure:
- ResBlockWithDropout
- SelfAttention (if use_attention=True)
- Conv2d(4x4, stride=2) downsampling
UpBlockWithAttention
Upsampling block with optional attention. Structure (see the sketch after this list):
- ConvTranspose2d(4x4, stride=2) upsampling
- Concatenate skip connection
- ResBlockWithDropout
- SelfAttention (if use_attention=True)
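Building on the ResBlockWithDropout sketch above, the down/up blocks might look like this; the SelfAttention module here is a minimal single-head stand-in, not the actual one, and the skip-channel handling is an assumption:

```python
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Minimal single-head spatial self-attention (illustrative stand-in)."""
    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(8, ch)
        self.qkv = nn.Conv2d(ch, ch * 3, 1)
        self.proj = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        q = q.flatten(2).transpose(1, 2)            # [b, hw, c]
        k = k.flatten(2)                            # [b, c, hw]
        v = v.flatten(2).transpose(1, 2)            # [b, hw, c]
        attn = torch.softmax(q @ k / c**0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, c, h, w)
        return x + self.proj(out)

class DownBlockWithAttention(nn.Module):
    """Residual block -> optional attention -> strided-conv downsampling."""
    def __init__(self, in_ch, out_ch, time_emb_dim, dropout_p, use_attention=False):
        super().__init__()
        self.res = ResBlockWithDropout(in_ch, out_ch, time_emb_dim, dropout_p)
        self.attn = SelfAttention(out_ch) if use_attention else nn.Identity()
        self.down = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1)  # 4x4, stride 2

    def forward(self, x, t_emb):
        return self.down(self.attn(self.res(x, t_emb)))

class UpBlockWithAttention(nn.Module):
    """Transposed-conv upsampling -> skip concat -> residual block -> optional attention."""
    def __init__(self, in_ch, skip_ch, out_ch, time_emb_dim, dropout_p, use_attention=False):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch, 4, stride=2, padding=1)  # 4x4, stride 2
        self.res = ResBlockWithDropout(in_ch + skip_ch, out_ch, time_emb_dim, dropout_p)
        self.attn = SelfAttention(out_ch) if use_attention else nn.Identity()

    def forward(self, x, skip, t_emb):
        h = torch.cat([self.up(x), skip], dim=1)  # concatenate skip connection
        return self.attn(self.res(h, t_emb))
```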
Usage example
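A minimal usage sketch, assuming the constructor defaults and the method signatures described above (a device attribute on the process is also assumed); inputs are normalized to [-1, 1] to match the sampling range:

```python
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

# Assumes the classes above are importable; the exact module path may differ.
process = DiffusionProcessCIFAR()  # defaults: 32x32 RGB, 1000 steps, linear schedule

transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale images to [-1, 1]
])
dataset = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True)

for epoch in range(10):
    for images, _ in loader:
        loss = process.train_step(images.to(process.device))

samples = process.sample(num_samples=16)                   # full DDPM sampling (EMA model)
fast = process.sample_ddim(num_samples=16, ddim_steps=50)  # 50-step DDIM sampling
```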
Mathematical formulation
Despite the architectural differences, DiffusionProcessCIFAR uses the same DDPM equations:
Forward process:
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon, where epsilon ~ N(0, I)
Related classes
- DiffusionProcess - Base class with cosine schedule and simpler architecture
- DiffusionModel - Base U-Net without dropout or configurable attention