Overview

Exponential Moving Average (EMA) maintains a slowly-updating copy of model weights by averaging parameters over training steps. In diffusion models, EMA provides more stable and higher-quality samples compared to using the raw training weights.

Why EMA for diffusion models?

Diffusion model training can exhibit high variance in parameter updates, especially with:
  • Large batch sizes
  • High learning rates
  • Complex architectures (CIFAR-10-scale models and beyond)
  • Long training runs
Using the raw training weights for sampling can produce:
  • Inconsistent quality between epochs
  • Artifacts from recent gradient updates
  • High sensitivity to hyperparameter choices
EMA smooths out these fluctuations by maintaining a temporally-averaged version of the weights.
EMA is commonly used in state-of-the-art diffusion models including Stable Diffusion, DALL-E 2, and Imagen.

Mathematical formulation

Given:
  • θ: Current training model parameters
  • θ_ema: EMA model parameters
  • γ: Decay rate (typically 0.999 or 0.9999)
The EMA update after each training step is:
θ_ema ← γ · θ_ema + (1 - γ) · θ
This can be rewritten as:
θ_ema ← θ_ema + (1 - γ) · (θ - θ_ema)
The EMA weights are a weighted average of all past parameter values, with exponentially decreasing influence from older steps.
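As a quick sanity check, the two forms above can be verified to agree on a toy tensor (a standalone sketch, independent of the CIFAR-10 code):
import torch

gamma = 0.999                     # decay rate
theta = torch.randn(4)            # current training parameters (toy)
theta_ema = torch.randn(4)        # EMA parameters

# Form 1: weighted average of old EMA and current weights
updated_1 = gamma * theta_ema + (1 - gamma) * theta

# Form 2: move the EMA a small step toward the current weights
updated_2 = theta_ema + (1 - gamma) * (theta - theta_ema)

assert torch.allclose(updated_1, updated_2)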

Implementation in CIFAR-10

The CIFAR-10 model uses EMA as shown in src/models/diffusion_cifar.py:

Initialization

class DiffusionProcessCIFAR(DiffusionProcess):
    def __init__(self, ..., ema_decay=0.999, ...):
        # Initialize training model
        self.model = DiffusionModelCIFAR(...).to(self.device)
        
        # Create EMA copy with same architecture
        self.ema_model = DiffusionModelCIFAR(...).to(self.device)
        
        # Initialize EMA with training weights
        self.ema_model.load_state_dict(self.model.state_dict())
        
        # Set to eval mode (never trained directly)
        self.ema_model.eval()
        
        # Store decay rate
        self.ema_decay = ema_decay
The EMA model is never trained directly - it only receives updates from the training model’s weights.
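One optional way to make the "never trained" property explicit (an extra safeguard, not shown in src/models/diffusion_cifar.py) is to disable gradients on the EMA copy:
# Optional: the EMA weights never need gradients
for p in self.ema_model.parameters():
    p.requires_grad_(False)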

Training step with EMA update

def train_step(self, x):
    # Standard training step
    x = x.to(self.device)
    t = torch.randint(0, self.noise_steps, (x.shape[0],), device=self.device)
    x_noisy, noise = self.add_noise(x, t)
    
    self.optimizer.zero_grad(set_to_none=True)
    with self.autocast_ctx():
        predicted_noise = self.model(x_noisy, t)
        loss = F.mse_loss(predicted_noise, noise)
    
    # Backprop and optimizer step
    if self.grad_scaler.is_enabled():
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
    
    # EMA update - happens after every training step
    with torch.no_grad():
        for ema_param, param in zip(self.ema_model.parameters(), 
                                    self.model.parameters()):
            ema_param.data.mul_(self.ema_decay).add_(
                param.data, alpha=1 - self.ema_decay
            )
    
    return loss.item()
The EMA update:
  1. Happens after every optimizer step
  2. Uses torch.no_grad() for efficiency
  3. Updates all parameters in-place
  4. Follows the formula: θ_ema ← γ·θ_ema + (1-γ)·θ
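The same in-place update can also be written with Tensor.lerp_, which implements exactly the second form of the formula (an equivalent alternative, not what the CIFAR-10 code uses):
# Equivalent: θ_ema ← θ_ema + (1 - γ) · (θ - θ_ema)
with torch.no_grad():
    for ema_param, param in zip(self.ema_model.parameters(),
                                self.model.parameters()):
        ema_param.data.lerp_(param.data, 1 - self.ema_decay)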

Sampling with EMA

During sampling, the EMA model is used instead of the training model:
def sample(self, num_samples=16):
    # Use EMA model for sampling, not training model
    model = self.ema_model
    was_training = model.training
    model.eval()
    
    with torch.no_grad():
        x_t = torch.randn(
            num_samples,
            self.model.channels,
            self.model.image_size,
            self.model.image_size,
            device=self.device,
        )
        
        for t in reversed(range(self.noise_steps)):
            t_batch = torch.full(
                (num_samples,), t, device=self.device, dtype=torch.long
            )
            eps_pred = model(x_t, t_batch)  # Use EMA model
            
            # ... DDPM sampling logic ...
        
    if was_training:
        model.train()
    return x_t
The training model is never used for sampling - only the EMA model generates final images.

Choosing the decay rate

The decay rate γ controls how quickly the EMA adapts:
Decay Rate | Effective Window | Use Case
0.99       | ~100 steps       | Fast-changing models, short training
0.999      | ~1,000 steps     | Standard choice for diffusion models
0.9999     | ~10,000 steps    | Very long training, high stability
0.99999    | ~100,000 steps   | Extremely long runs, minimal adaptation
The effective window can be approximated as 1 / (1 - γ). CIFAR-10 uses γ = 0.999:
ema_decay=0.999  # Averages over ~1000 training steps
This balances:
  • Fast enough adaptation to improving training model
  • Slow enough to smooth out training noise
Higher decay rates (closer to 1.0) provide more stability but slower adaptation. If your model is still improving rapidly, use a lower decay rate like 0.99.
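The effective windows in the table above follow directly from the 1 / (1 - γ) approximation (a quick standalone check):
for gamma in (0.99, 0.999, 0.9999, 0.99999):
    window = 1 / (1 - gamma)
    print(f"decay={gamma}: effective window ~{window:,.0f} steps")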

EMA vs. no EMA comparison

Aspect             | Without EMA     | With EMA
Sample quality     | More variable   | More consistent
Training stability | Sensitive to LR | More robust
Convergence speed  | Same            | Same
Memory usage       | 1× model size   | 2× model size
Computation        | Baseline        | +1% overhead
When to use EMA:
  • Complex datasets (CIFAR-10, ImageNet)
  • Long training runs (>100k steps)
  • High learning rates
  • When sample quality matters more than speed
When to skip EMA:
  • Simple datasets (MNIST)
  • Memory-constrained settings
  • Rapid prototyping
  • When training model is already stable

Implementation tips

1. Initialize EMA early

Always initialize the EMA model with the training model’s initial weights:
self.ema_model.load_state_dict(self.model.state_dict())
Don’t start EMA from random weights - this wastes roughly the first 1/(1-γ) steps while the average catches up.

2. Update after every step

EMA should be updated after every optimizer step, not once per epoch:
# CORRECT: Update every step
for batch in dataloader:
    loss = train_step(batch)  # Includes EMA update

# INCORRECT: Update once per epoch
for batch in dataloader:
    loss = train_step(batch)
update_ema()  # Too infrequent!

3. Never train the EMA model

The EMA model should always be in eval mode and never receive gradients:
self.ema_model.eval()  # Set once during init

# EMA update always uses torch.no_grad()
with torch.no_grad():
    for ema_param, param in zip(...):
        ema_param.data.mul_(...)

4. Use EMA for validation and sampling

Always use the EMA model for generating samples or computing validation metrics:
# CORRECT: Use EMA for evaluation
samples = diffusion.sample(num_samples=64)  # Uses self.ema_model internally

# INCORRECT: Evaluating with the raw training model
noise_pred = diffusion.model(x, t)  # Don't use the training model for final samples or metrics

5. Save both models in checkpoints

Save both the training and EMA model states:
torch.save({
    'model_state_dict': self.model.state_dict(),
    'ema_model_state_dict': self.ema_model.state_dict(),
    'optimizer_state_dict': self.optimizer.state_dict(),
    'ema_decay': self.ema_decay,
}, 'checkpoint.pt')
This allows you to:
  • Resume training with correct EMA state
  • Use EMA model for inference without retraining
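A matching load path might look like this (a sketch that assumes the same file name and attribute names as the save call above):
checkpoint = torch.load('checkpoint.pt', map_location=self.device)

# Restore both weight copies plus the optimizer and decay rate
self.model.load_state_dict(checkpoint['model_state_dict'])
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.ema_decay = checkpoint['ema_decay']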

Warmup strategies

Some implementations use EMA warmup to adapt the decay rate early in training:
def get_ema_decay(self, step):
    # Linear warmup for first 1000 steps
    if step < 1000:
        return 0.99 + (self.ema_decay - 0.99) * (step / 1000)
    return self.ema_decay
This starts with faster adaptation (0.99) and gradually increases to the target decay rate (0.999).
The CIFAR-10 implementation uses a fixed decay rate of 0.999 without warmup, which works well in practice.
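If you did want warmup, the per-step decay would simply replace the fixed rate inside the EMA update loop (hypothetical sketch; self.global_step is an assumed counter maintained by the training loop):
# Hypothetical: step-dependent decay in the EMA update
decay = self.get_ema_decay(self.global_step)  # self.global_step is assumed
with torch.no_grad():
    for ema_param, param in zip(self.ema_model.parameters(),
                                self.model.parameters()):
        ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)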

References

EMA is a standard technique in deep learning:
  • Original DDPM paper (Ho et al., 2020): Uses EMA with decay 0.9999
  • Improved DDPM (Nichol & Dhariwal, 2021): Confirms EMA improves sample quality
  • PyTorch implementation: torch.optim.swa_utils.AveragedModel provides built-in EMA
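For reference, a minimal sketch of the built-in helper using the documented avg_fn hook (newer PyTorch versions also ship torch.optim.swa_utils.get_ema_avg_fn; availability depends on your version):
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

model = nn.Linear(8, 8)   # stand-in for the diffusion U-Net
decay = 0.999

# avg_fn receives (averaged_param, current_param, num_averaged)
def ema_avg(avg_param, param, num_averaged):
    return decay * avg_param + (1 - decay) * param

ema_model = AveragedModel(model, avg_fn=ema_avg)

# Call after each optimizer step during training
ema_model.update_parameters(model)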

Usage example

import torch

from src.models.diffusion_cifar import DiffusionProcessCIFAR

# Initialize with EMA
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    ema_decay=0.999,  # Standard decay rate
    device=torch.device('cuda')
)

# Training automatically updates EMA
for epoch in range(num_epochs):
    for batch in train_loader:
        loss = diffusion.train_step(batch)  # EMA updated here

# Sampling uses EMA model automatically
samples = diffusion.sample(num_samples=64)
The EMA is completely transparent to the user - it’s handled internally during training and sampling.
