
Overview

BaseDiffusion is an abstract base class that defines the core interface for diffusion models in Alpamayo R1. All diffusion implementations inherit from this class and implement the sample() method.
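The inheritance pattern can be sketched as follows. This is an illustrative stand-in, not the library's actual code: the real base class lives in `alpamayo_r1.diffusion.base`, and the subclass name here is hypothetical.

```python
from abc import ABC, abstractmethod

import torch


# Illustrative stand-in for the BaseDiffusion interface described on this
# page; the real class is alpamayo_r1.diffusion.base.BaseDiffusion.
class BaseDiffusionSketch(ABC):
    def __init__(self, x_dims):
        # x_dims may be a single int or a sequence; normalize to a tuple.
        self.x_dims = (x_dims,) if isinstance(x_dims, int) else tuple(x_dims)

    @abstractmethod
    def sample(
        self,
        batch_size,
        step_fn,
        device=torch.device("cpu"),
        return_all_steps=False,
        *args,
        **kwargs,
    ):
        """Concrete subclasses (e.g. FlowMatching) implement the
        iterative denoising loop here."""
```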

Class Definition

from alpamayo_r1.diffusion.base import BaseDiffusion

Constructor

x_dims
list[int] | tuple[int] | int
required
The dimensions of the sample tensor, excluding the batch dimension. Can be a single integer or a sequence of integers defining the shape.

Example

from alpamayo_r1.diffusion.flow_matching import FlowMatching

# Single dimension
diffusion = FlowMatching(x_dims=128)

# Multiple dimensions
diffusion = FlowMatching(x_dims=[64, 64, 3])

Methods

sample

@torch.no_grad()
def sample(
    batch_size: int,
    step_fn: StepFn,
    device: torch.device = torch.device("cpu"),
    return_all_steps: bool = False,
    *args,
    **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
Sample from the diffusion model using an iterative denoising process.
batch_size
int
required
The number of samples to generate in parallel.
step_fn
StepFn
required
The denoising step function. It takes a noisy tensor x and a timestep t and returns either a denoised tensor, a vector field, or noise, depending on the prediction type of the diffusion model. The step function should have the signature:
def step_fn(*, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor
device
torch.device
default:"torch.device('cpu')"
The PyTorch device to use for sampling (e.g., "cpu", "cuda").
return_all_steps
bool
default:"False"
Whether to return the outputs from all intermediate sampling steps.

Returns

output
torch.Tensor | tuple[torch.Tensor, torch.Tensor]
If return_all_steps=False: Returns the final sampled tensor with shape [B, *x_dims].
If return_all_steps=True: Returns a tuple of:
  • All sampled tensors with shape [B, T, *x_dims] where T is the number of steps
  • The time steps with shape [T]

StepFn Protocol

The StepFn protocol defines the interface for denoising step functions:
from alpamayo_r1.diffusion.base import StepFn

class StepFn(Protocol):
    def __call__(
        self,
        *,
        x: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        ...
x
torch.Tensor
required
The input tensor to denoise.
t
torch.Tensor
required
The timestep tensor indicating the current noise level.
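Any callable with the matching keyword-only signature satisfies the protocol. A minimal sketch, using a single linear layer as a hypothetical stand-in for a real denoising network:

```python
import torch

# Hypothetical toy "model": a single linear layer standing in for a real
# denoising network operating on 128-dimensional samples.
model = torch.nn.Linear(128, 128)


def step_fn(*, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # A real network would condition on t properly (e.g. via an embedding);
    # here t is simply broadcast-added to keep the sketch minimal.
    return model(x + t)
```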

Diffusion Sampling Process

The diffusion sampling process works as follows:
  1. Initialization: Start with random noise sampled from a standard normal distribution
  2. Iterative Denoising: Apply the step function repeatedly at different timesteps to gradually denoise the sample
  3. Output: Return the final denoised sample (and optionally all intermediate steps)
The specific denoising schedule and integration method depend on the concrete implementation (e.g., Flow Matching).
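The three steps above can be sketched as a simple Euler-integration loop, flow-matching style, where step_fn predicts a vector field. This is an illustration only, with a hypothetical `euler_sample` helper, not the library's implementation:

```python
import torch


def euler_sample(x_dims, num_steps, step_fn, batch_size=1,
                 device=torch.device("cpu")):
    # 1. Initialization: start from standard normal noise.
    x = torch.randn(batch_size, *x_dims, device=device)
    ts = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
    # 2. Iterative denoising: follow the predicted vector field.
    for i in range(num_steps):
        dt = ts[i + 1] - ts[i]
        v = step_fn(x=x, t=ts[i].expand(batch_size))
        x = x + v * dt
    # 3. Output: the final denoised sample, shape [batch_size, *x_dims].
    return x
```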

Usage Example

import torch
from alpamayo_r1.diffusion.flow_matching import FlowMatching

# Initialize diffusion model
diffusion = FlowMatching(
    x_dims=[64, 64, 3],  # Image dimensions
    num_inference_steps=20
)

# Define a step function (typically wraps your neural network)
def step_fn(*, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Your model predicts the vector field
    return model(x, t)

# Sample new data
samples = diffusion.sample(
    batch_size=4,
    step_fn=step_fn,
    device=torch.device("cuda")
)

print(samples.shape)  # [4, 64, 64, 3]

# Sample with all intermediate steps
all_samples, timesteps = diffusion.sample(
    batch_size=4,
    step_fn=step_fn,
    device=torch.device("cuda"),
    return_all_steps=True
)

print(all_samples.shape)  # [4, 21, 64, 64, 3]
print(timesteps.shape)    # [21]
