Overview
Exponential Moving Average (EMA) maintains a slowly-updating copy of model weights by averaging parameters over training steps. In diffusion models, EMA provides more stable and higher-quality samples compared to using the raw training weights.
Why EMA for diffusion models?
Diffusion model training can exhibit high variance in parameter updates, especially with:
- Large batch sizes
- High learning rates
- Complex architectures and harder datasets (CIFAR-10 and beyond)
- Long training runs
Using the raw training weights for sampling can produce:
- Inconsistent quality between epochs
- Artifacts from recent gradient updates
- High sensitivity to hyperparameter choices
EMA smooths out these fluctuations by maintaining a temporally-averaged version of the weights.
EMA is commonly used in state-of-the-art diffusion models including Stable Diffusion, DALL-E 2, and Imagen.
Given:
- θ: Current training model parameters
- θ_ema: EMA model parameters
- γ: Decay rate (typically 0.999 or 0.9999)
The EMA update after each training step is:
θ_ema ← γ · θ_ema + (1 - γ) · θ
This can be rewritten as:
θ_ema ← θ_ema + (1 - γ) · (θ - θ_ema)
The EMA weights are a weighted average of all past parameter values, with exponentially decreasing influence from older steps.
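To make that concrete, here is a minimal sketch in plain Python (the scalar "parameters" are hypothetical, not from the codebase) that unrolls the recursion and checks it against the explicit exponentially weighted sum:

```python
# Minimal sketch: the EMA recursion equals an exponentially weighted
# average of all past values (hypothetical scalar "parameters").
gamma = 0.9
thetas = [1.0, 2.0, 4.0, 3.0, 5.0]  # parameter value after each step

# Recursive form: theta_ema <- gamma * theta_ema + (1 - gamma) * theta
ema = thetas[0]  # initialize EMA with the first value
for theta in thetas[1:]:
    ema = gamma * ema + (1 - gamma) * theta

# Unrolled form: each value's weight decays geometrically with its age
n = len(thetas) - 1
unrolled = gamma ** n * thetas[0] + sum(
    (1 - gamma) * gamma ** (n - 1 - k) * theta
    for k, theta in enumerate(thetas[1:])
)

assert abs(ema - unrolled) < 1e-12
```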
Implementation in CIFAR-10
The CIFAR-10 model uses EMA as shown in src/models/diffusion_cifar.py:
Initialization
```python
class DiffusionProcessCIFAR(DiffusionProcess):
    def __init__(self, ..., ema_decay=0.999, ...):
        # Initialize training model
        self.model = DiffusionModelCIFAR(...).to(self.device)

        # Create EMA copy with same architecture
        self.ema_model = DiffusionModelCIFAR(...).to(self.device)

        # Initialize EMA with training weights
        self.ema_model.load_state_dict(self.model.state_dict())

        # Set to eval mode (never trained directly)
        self.ema_model.eval()

        # Store decay rate
        self.ema_decay = ema_decay
```
The EMA model is never trained directly - it only receives updates from the training model’s weights.
Training step with EMA update
```python
def train_step(self, x):
    # Standard training step
    x = x.to(self.device)
    t = torch.randint(0, self.noise_steps, (x.shape[0],), device=self.device)
    x_noisy, noise = self.add_noise(x, t)

    self.optimizer.zero_grad(set_to_none=True)
    with self.autocast_ctx():
        predicted_noise = self.model(x_noisy, t)
        loss = F.mse_loss(predicted_noise, noise)

    # Backprop and optimizer step
    if self.grad_scaler.is_enabled():
        self.grad_scaler.scale(loss).backward()
        self.grad_scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

    # EMA update - happens after every training step
    with torch.no_grad():
        for ema_param, param in zip(self.ema_model.parameters(),
                                    self.model.parameters()):
            ema_param.data.mul_(self.ema_decay).add_(
                param.data, alpha=1 - self.ema_decay
            )

    return loss.item()
```
The EMA update:
- Happens after every optimizer step
- Uses `torch.no_grad()` for efficiency
- Updates all parameters in-place
- Follows the formula: θ_ema ← γ·θ_ema + (1-γ)·θ
Sampling with EMA
During sampling, the EMA model is used instead of the training model:
```python
def sample(self, num_samples=16):
    # Use EMA model for sampling, not training model
    model = self.ema_model
    was_training = model.training
    model.eval()

    with torch.no_grad():
        x_t = torch.randn(
            num_samples,
            self.model.channels,
            self.model.image_size,
            self.model.image_size,
            device=self.device,
        )
        for t in reversed(range(self.noise_steps)):
            t_batch = torch.full(
                (num_samples,), t, device=self.device, dtype=torch.long
            )
            eps_pred = model(x_t, t_batch)  # Use EMA model
            # ... DDPM sampling logic ...

    if was_training:
        model.train()
    return x_t
```
The training model is never used for sampling - only the EMA model generates final images.
Choosing the decay rate
The decay rate γ controls how quickly the EMA adapts:
| Decay Rate | Effective Window | Use Case |
|---|---|---|
| 0.99 | ~100 steps | Fast-changing models, short training |
| 0.999 | ~1,000 steps | Standard choice for diffusion models |
| 0.9999 | ~10,000 steps | Very long training, high stability |
| 0.99999 | ~100,000 steps | Extremely long runs, minimal adaptation |
The effective window can be approximated as 1 / (1 - γ).
CIFAR-10 uses γ = 0.999:
```python
ema_decay=0.999  # Averages over ~1000 training steps
```
This balances:
- Adaptation fast enough to track the improving training model
- Averaging slow enough to smooth out step-to-step training noise
Higher decay rates (closer to 1.0) provide more stability but slower adaptation. If your model is still improving rapidly, use a lower decay rate like 0.99.
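If you prefer to think in terms of an averaging window, the approximation above can be inverted. A small sketch (the helper name `decay_for_window` is illustrative, not part of the codebase):

```python
def decay_for_window(window_steps: int) -> float:
    """Pick an EMA decay whose effective window is roughly `window_steps`.

    Uses the approximation window ≈ 1 / (1 - gamma), i.e. gamma ≈ 1 - 1 / window.
    """
    return 1.0 - 1.0 / window_steps

print(decay_for_window(100))    # 0.99
print(decay_for_window(1000))   # 0.999
print(decay_for_window(10000))  # 0.9999
```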
EMA vs. no EMA comparison
| Aspect | Without EMA | With EMA |
|---|---|---|
| Sample quality | More variable | More consistent |
| Training stability | Sensitive to LR | More robust |
| Convergence speed | Same | Same |
| Memory usage | 1× model size | 2× model size |
| Computation | Baseline | +1% overhead |
When to use EMA:
- Complex datasets (CIFAR-10, ImageNet)
- Long training runs (>100k steps)
- High learning rates
- When sample quality matters more than speed
When to skip EMA:
- Simple datasets (MNIST)
- Memory-constrained settings
- Rapid prototyping
- When training model is already stable
Implementation tips
1. Initialize EMA early
Always initialize the EMA model with the training model’s initial weights:
```python
self.ema_model.load_state_dict(self.model.state_dict())
```
Don’t start EMA from random weights - this wastes roughly the first 1/(1-γ) steps while the average catches up to the training model.
2. Update after every step
EMA should be updated after every optimizer step, not once per epoch:
```python
# CORRECT: Update every step
for batch in dataloader:
    loss = train_step(batch)  # Includes EMA update

# INCORRECT: Update once per epoch
for batch in dataloader:
    loss = train_step(batch)
update_ema()  # Too infrequent!
```
3. Never train the EMA model
The EMA model should always be in eval mode and never receive gradients:
```python
self.ema_model.eval()  # Set once during init

# EMA update always uses torch.no_grad()
with torch.no_grad():
    for ema_param, param in zip(...):
        ema_param.data.mul_(...)
```
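One caveat: `parameters()` does not include buffers such as BatchNorm running statistics. If your model has buffers, copy them into the EMA model as well. A hedged sketch of a standalone helper (the name `update_ema` is illustrative, not from the codebase):

```python
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay):
    # EMA-average the learnable parameters
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1 - decay)
    # Buffers (e.g., BatchNorm running stats) are not in parameters();
    # copy them so the EMA model stays consistent with the training model
    for ema_buf, buf in zip(ema_model.buffers(), model.buffers()):
        ema_buf.copy_(buf)
```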
4. Use EMA for validation and sampling
Always use the EMA model for generating samples or computing validation metrics:
```python
# CORRECT: Use EMA for evaluation
samples = diffusion.sample(num_samples=64)  # Uses self.ema_model internally

# INCORRECT: Using training model
samples = diffusion.model(x, t)  # Don't do this for final samples
```
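The same rule applies to validation metrics: run the forward pass through the EMA model. A minimal sketch, assuming the same `add_noise` helper and `F` (torch.nn.functional) import used in `train_step`; the method name `validation_loss` is illustrative:

```python
@torch.no_grad()
def validation_loss(self, x):
    # Evaluate the denoising loss with the EMA weights, not the training weights
    x = x.to(self.device)
    t = torch.randint(0, self.noise_steps, (x.shape[0],), device=self.device)
    x_noisy, noise = self.add_noise(x, t)
    predicted_noise = self.ema_model(x_noisy, t)
    return F.mse_loss(predicted_noise, noise).item()
```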
5. Save both models in checkpoints
Save both the training and EMA model states:
```python
torch.save({
    'model_state_dict': self.model.state_dict(),
    'ema_model_state_dict': self.ema_model.state_dict(),
    'optimizer_state_dict': self.optimizer.state_dict(),
    'ema_decay': self.ema_decay,
}, 'checkpoint.pt')
```
This allows you to:
- Resume training with correct EMA state
- Use EMA model for inference without retraining
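As a sketch of the loading side (assuming the checkpoint layout shown above):

```python
# Restore both models and the optimizer from a checkpoint
checkpoint = torch.load('checkpoint.pt', map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.ema_decay = checkpoint['ema_decay']

# For inference only, the EMA weights alone are enough
self.ema_model.eval()
```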
Warmup strategies
Some implementations use EMA warmup to adapt the decay rate early in training:
```python
def get_ema_decay(self, step):
    # Linear warmup for first 1000 steps
    if step < 1000:
        return 0.99 + (self.ema_decay - 0.99) * (step / 1000)
    return self.ema_decay
```
This starts with faster adaptation (0.99) and gradually increases to the target decay rate (0.999).
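If you adopt warmup, the per-step decay replaces `self.ema_decay` in the EMA update. A sketch of how it could be wired in, assuming a step counter `self.global_step` is maintained elsewhere:

```python
# Inside train_step: use the warmed-up decay instead of the fixed one
decay = self.get_ema_decay(self.global_step)
with torch.no_grad():
    for ema_param, param in zip(self.ema_model.parameters(),
                                self.model.parameters()):
        ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)
self.global_step += 1
```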
The CIFAR-10 implementation uses a fixed decay rate of 0.999 without warmup, which works well in practice.
References
EMA is a standard technique in deep learning:
- Original DDPM paper (Ho et al., 2020): Uses EMA with decay 0.9999
- Improved DDPM (Nichol & Dhariwal, 2021): Confirms EMA improves sample quality
- PyTorch implementation: `torch.optim.swa_utils.AveragedModel` provides built-in support for EMA
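As a hedged sketch of that built-in route (not what the CIFAR-10 code does), `AveragedModel` accepts a custom averaging function, which can express the same EMA rule:

```python
import torch
from torch.optim.swa_utils import AveragedModel

decay = 0.999

# `model` is your training network (assumed defined elsewhere).
# avg_fn receives (averaged_param, current_param, num_averaged).
ema_model = AveragedModel(
    model,
    avg_fn=lambda ema_p, p, num: decay * ema_p + (1 - decay) * p,
)

# Call after every optimizer step:
ema_model.update_parameters(model)
```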
Usage example
```python
import torch

from src.models.diffusion_cifar import DiffusionProcessCIFAR

# Initialize with EMA
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    hidden_dims=[128, 256, 256, 256],
    ema_decay=0.999,  # Standard decay rate
    device=torch.device('cuda'),
)

# Training automatically updates EMA
for epoch in range(num_epochs):
    for batch in train_loader:
        loss = diffusion.train_step(batch)  # EMA updated here

# Sampling uses EMA model automatically
samples = diffusion.sample(num_samples=64)
```
The EMA is completely transparent to the user - it’s handled internally during training and sampling.