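MaxDiffusion models manage their configuration through ConfigMixin, defined in maxdiffusion.configuration_utils and adapted from the Hugging Face Diffusers class of the same name (it is not part of Flax). Every model that inherits from it gets a config attribute and methods for loading and saving its configuration.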

ConfigMixin

ConfigMixin is the base configuration class. It provides:
  • Configuration storage and access via the config attribute
  • Saving configurations to JSON
  • Loading configurations from pretrained models

Key methods

from_config

Instantiates a model from a configuration dictionary.

Parameters:
  • config (dict): Configuration dictionary containing the model parameters.
  • **kwargs (any): Additional keyword arguments that override the corresponding values in config.

save_config

Saves the model configuration as a config.json file.

Parameters:
  • save_directory (str): Directory in which config.json is written.
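To make the load/save contract concrete, here is a minimal, self-contained sketch of what a config mixin does. It is a simplified stand-in written for illustration, not MaxDiffusion's ConfigMixin: the class names SimpleConfigMixin and ToyModel, and the exact behavior of register_to_config shown here, are assumptions. It stores constructor arguments, writes them to config.json, and rebuilds an instance from a dictionary while letting keyword arguments override stored values.

```python
import json
import os
import tempfile


class SimpleConfigMixin:
    """Hypothetical, simplified stand-in for a config mixin (not MaxDiffusion's class)."""

    config_name = "config.json"

    def register_to_config(self, **kwargs):
        # Store a copy so later changes to the caller's dict have no effect
        self._internal_dict = dict(kwargs)

    @property
    def config(self):
        return dict(self._internal_dict)

    def save_config(self, save_directory):
        # Write the stored configuration to <save_directory>/config.json
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, self.config_name)
        with open(path, "w") as f:
            json.dump(self._internal_dict, f, indent=2, sort_keys=True)
        return path

    @classmethod
    def from_config(cls, config, **kwargs):
        # Keyword arguments take precedence over values stored in the config
        init_kwargs = {**config, **kwargs}
        return cls(**init_kwargs)


class ToyModel(SimpleConfigMixin):
    """Hypothetical model that records its constructor arguments."""

    def __init__(self, in_channels=4, out_channels=4, attention_kernel="dot_product"):
        self.register_to_config(
            in_channels=in_channels,
            out_channels=out_channels,
            attention_kernel=attention_kernel,
        )


model = ToyModel(in_channels=3)
with tempfile.TemporaryDirectory() as tmp:
    path = model.save_config(tmp)
    with open(path) as f:
        loaded = json.load(f)

# Rebuild from the saved dict, overriding one value
restored = ToyModel.from_config(loaded, attention_kernel="flash")
print(restored.config)  # {'in_channels': 3, 'out_channels': 4, 'attention_kernel': 'flash'}
```

The override-by-kwargs behavior is what lets a pretrained configuration be reused with a different attention kernel or dtype without editing the saved file.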

Common model parameters

These parameters are common across most MaxDiffusion models:

Data types

  • dtype (jnp.dtype, default: jnp.float32): The dtype for activations and intermediate computations.
  • weights_dtype (jnp.dtype, default: jnp.float32): The dtype for model weights (parameters).
  • precision (jax.lax.Precision): JAX precision for matrix multiplications. Options: None, jax.lax.Precision.DEFAULT, jax.lax.Precision.HIGH, jax.lax.Precision.HIGHEST.
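The split between dtype and weights_dtype follows a common mixed-precision pattern: parameters are stored in one dtype and cast to the compute dtype before the matmul, while precision controls how the matmul itself is carried out. The sketch below shows that pattern with plain jax.numpy; the function name dense_forward is hypothetical and this is not MaxDiffusion's layer code.

```python
import jax
import jax.numpy as jnp

# Hypothetical settings mirroring the config fields described above
weights_dtype = jnp.float32  # storage dtype for parameters
dtype = jnp.bfloat16  # compute dtype for activations
precision = jax.lax.Precision.DEFAULT


def dense_forward(x, kernel, bias):
    """Hypothetical dense layer: params stored in weights_dtype, math done in dtype."""
    x = x.astype(dtype)
    kernel = kernel.astype(dtype)  # cast weights to the compute dtype
    bias = bias.astype(dtype)
    return jnp.dot(x, kernel, precision=precision) + bias


key = jax.random.PRNGKey(0)
kernel = jax.random.normal(key, (8, 16), dtype=weights_dtype)
bias = jnp.zeros((16,), dtype=weights_dtype)
x = jnp.ones((2, 8), dtype=jnp.float32)

y = dense_forward(x, kernel, bias)
print(kernel.dtype, y.dtype, y.shape)  # float32 bfloat16 (2, 16)
```

In Flax linen, layers such as nn.Dense expose the same split through their dtype and param_dtype arguments; how MaxDiffusion maps weights_dtype onto those arguments is not shown here.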

Architecture

  • in_channels (int): Number of input channels.
  • out_channels (int): Number of output channels.
  • sample_size (int): Height and width of the input samples.
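These three fields fix the model's tensor shapes. As a sketch, assuming a UNet that takes channel-first (batch, channels, height, width) input as the Diffusers-derived Flax UNet does, a dummy input for shape checks or parameter initialization can be built directly from the config. The values below are illustrative, taken from the Stable Diffusion 1.5 UNet, and dummy_sample is a hypothetical helper.

```python
import jax.numpy as jnp

# Illustrative values from the Stable Diffusion 1.5 UNet config
config = {"in_channels": 4, "out_channels": 4, "sample_size": 64}


def dummy_sample(config, batch_size=1):
    """Hypothetical helper: zero input of shape (batch, in_channels, sample_size, sample_size)."""
    shape = (
        batch_size,
        config["in_channels"],
        config["sample_size"],
        config["sample_size"],
    )
    return jnp.zeros(shape, dtype=jnp.float32)


sample = dummy_sample(config, batch_size=2)
# A denoising UNet predicts a tensor with the same spatial size and out_channels channels
expected_out_shape = (2, config["out_channels"], config["sample_size"], config["sample_size"])
print(sample.shape, expected_out_shape)  # (2, 4, 64, 64) (2, 4, 64, 64)
```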

Attention

  • attention_kernel (str, default: "dot_product"): Attention implementation to use. Options: dot_product (standard scaled dot-product attention) or flash (flash attention, which processes attention in blocks instead of materializing the full attention matrix, reducing memory use on long sequences).
  • flash_min_seq_length (int, default: 4096): Minimum sequence length at which flash attention is applied; shorter sequences use dot-product attention.
  • flash_block_sizes (BlockSizes): Block sizes for the flash attention kernel. When set, these override the default block sizes.
  • use_memory_efficient_attention (bool, default: False): Enables memory-efficient attention, an alternative to flash attention.
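The sketch below shows standard scaled dot-product attention, softmax(QKᵀ/√d)·V, in plain jax.numpy, together with a hypothetical select_attention_kernel helper that applies the flash_min_seq_length rule described above: flash only when the sequence is long enough, otherwise dot-product. The helper's name and its exact fallback logic are assumptions for illustration; MaxDiffusion's own attention dispatch may check additional conditions.

```python
import jax
import jax.numpy as jnp


def dot_product_attention(q, k, v):
    """Scaled dot-product attention. q: [B, Lq, D], k and v: [B, Lk, D]."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqd,bkd->bqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)  # each query's weights sum to 1
    return jnp.einsum("bqk,bkd->bqd", weights, v)


def select_attention_kernel(attention_kernel, seq_len, flash_min_seq_length=4096):
    """Hypothetical dispatch: use flash only for sufficiently long sequences."""
    if attention_kernel == "flash" and seq_len >= flash_min_seq_length:
        return "flash"
    return "dot_product"


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (2, 16, 8))
k = jax.random.normal(key_k, (2, 32, 8))
v = jax.random.normal(key_v, (2, 32, 8))

out = dot_product_attention(q, k, v)
print(out.shape)  # (2, 16, 8)
print(select_attention_kernel("flash", seq_len=1024))  # dot_product
print(select_attention_kernel("flash", seq_len=4096))  # flash
```

Flash attention computes the same result as dot_product_attention but tile by tile, so the full [Lq, Lk] score matrix never has to be held in memory at once; for short sequences that saving is small, which is why a minimum length threshold makes sense.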

Normalization

  • norm_num_groups (int, default: 32): Number of groups for the group normalization layers.
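Group normalization splits the channels into norm_num_groups groups and normalizes each group over its channels and spatial positions, so the channel count must be divisible by norm_num_groups (with the default of 32, a 320-channel block has 10 channels per group). The function below is a minimal jax.numpy sketch of the technique on channel-last (NHWC) tensors, without the learned scale and bias; it is not MaxDiffusion's layer implementation.

```python
import jax
import jax.numpy as jnp


def group_norm(x, num_groups=32, eps=1e-5):
    """Group norm over an NHWC tensor, without learned scale/bias."""
    n, h, w, c = x.shape
    if c % num_groups != 0:
        raise ValueError(f"channels ({c}) must be divisible by num_groups ({num_groups})")
    # Split channels into groups: [N, H, W, G, C // G]
    xg = x.reshape(n, h, w, num_groups, c // num_groups)
    # Statistics per sample and per group, over H, W and the group's channels
    mean = xg.mean(axis=(1, 2, 4), keepdims=True)
    var = xg.var(axis=(1, 2, 4), keepdims=True)
    xg = (xg - mean) / jnp.sqrt(var + eps)
    return xg.reshape(n, h, w, c)


x = jax.random.normal(jax.random.PRNGKey(0), (2, 8, 8, 320)) * 3.0 + 1.0
y = group_norm(x, num_groups=32)
groups = y.reshape(2, 8, 8, 32, 10)
print(y.shape)  # (2, 8, 8, 320)
```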

Regularization

  • dropout (float, default: 0.0): Dropout probability for regularization.

Model-specific configurations

UNet configuration

See the UNet model reference for UNet-specific parameters, including:
  • down_block_types, up_block_types
  • block_out_channels
  • layers_per_block
  • cross_attention_dim

VAE configuration

See the VAE model reference for VAE-specific parameters, including:
  • latent_channels
  • scaling_factor
  • block_out_channels
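The VAE fields tie pixel space to latent space. Diffusers-style pipelines derive the spatial downsampling factor as 2 ** (len(block_out_channels) - 1), because every encoder block except the last halves the resolution, and multiply encoded latents by scaling_factor before diffusion (dividing again before decoding). The sketch works through that arithmetic; the config values are illustrative, modeled on the Stable Diffusion 1.5 VAE, and latent_shape is a hypothetical helper.

```python
import jax.numpy as jnp

# Illustrative VAE config values (modeled on Stable Diffusion 1.5)
vae_config = {
    "block_out_channels": [128, 256, 512, 512],
    "latent_channels": 4,
    "scaling_factor": 0.18215,
}


def latent_shape(vae_config, height, width, batch_size=1):
    """Hypothetical helper: (B, C, H, W) latent shape for a given image size."""
    # Every encoder block except the last halves the spatial resolution
    vae_scale_factor = 2 ** (len(vae_config["block_out_channels"]) - 1)
    return (
        batch_size,
        vae_config["latent_channels"],
        height // vae_scale_factor,
        width // vae_scale_factor,
    )


shape = latent_shape(vae_config, height=512, width=512)
print(shape)  # (1, 4, 64, 64)

# Latents are scaled before diffusion and unscaled before decoding
raw_latents = jnp.ones(shape)
scaled = raw_latents * vae_config["scaling_factor"]
unscaled = scaled / vae_config["scaling_factor"]
```

With these values a 512×512 image maps to a 64×64 latent with 4 channels, which matches the in_channels and sample_size of the Stable Diffusion 1.5 UNet.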

Transformer configuration

See the Transformer model reference for transformer-specific parameters, including:
  • num_layers, num_single_layers
  • num_attention_heads
  • attention_head_dim
  • joint_attention_dim
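For the transformer, the head fields combine into the model's hidden width, inner_dim = num_attention_heads * attention_head_dim, and attention reshapes activations between [batch, seq, inner_dim] and [batch, heads, seq, head_dim]. The sketch uses illustrative values modeled on the public FLUX.1 transformer config; check the config.json of the checkpoint you actually load. split_heads is a hypothetical helper.

```python
import jax.numpy as jnp

# Illustrative values modeled on the FLUX.1 transformer config
config = {
    "num_layers": 19,  # dual-stream (joint text/image) blocks
    "num_single_layers": 38,  # single-stream blocks
    "num_attention_heads": 24,
    "attention_head_dim": 128,
    "joint_attention_dim": 4096,  # width of the text-encoder hidden states
}

inner_dim = config["num_attention_heads"] * config["attention_head_dim"]
total_blocks = config["num_layers"] + config["num_single_layers"]


def split_heads(x, num_heads):
    """Hypothetical helper: [B, L, inner_dim] -> [B, num_heads, L, head_dim]."""
    b, l, d = x.shape
    return x.reshape(b, l, num_heads, d // num_heads).transpose(0, 2, 1, 3)


hidden = jnp.zeros((1, 256, inner_dim))
heads = split_heads(hidden, config["num_attention_heads"])
print(inner_dim, total_blocks, heads.shape)  # 3072 57 (1, 24, 256, 128)
```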

Configuration decorators

@flax_register_to_config

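Class decorator that connects a Flax model to ConfigMixin. When the model is constructed, the decorator records its dataclass fields (declared defaults merged with the arguments actually passed) in the configuration, so they are available through self.config and are written out by save_config. The decorated class must also inherit from ConfigMixin; the built-in MaxDiffusion models additionally inherit from FlaxModelMixin, as in this example:

```python
import flax.linen as nn

from maxdiffusion.configuration_utils import ConfigMixin, flax_register_to_config
# FlaxModelMixin adds Flax-specific loading/saving (from_pretrained, etc.);
# built-in models such as FlaxUNet2DConditionModel inherit from it as well.
from maxdiffusion.models.modeling_flax_utils import FlaxModelMixin


@flax_register_to_config
class MyModel(nn.Module, FlaxModelMixin, ConfigMixin):
    in_channels: int = 3
    out_channels: int = 3

    def setup(self):
        # Values recorded by the decorator are available via self.config
        self.proj = nn.Dense(features=self.config.out_channels)

    def __call__(self, x):
        return self.proj(x)
```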
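To show the mechanism behind such a decorator, here is a simplified pure-Python sketch built on standard dataclasses. The names register_fields_to_config, TinyConfig, and MyConfigurable are hypothetical, and the real decorator also has to handle Flax-specific details (for example, skipping Flax-internal fields); the core idea is the same: after construction, read every declared field, which now holds either its default or the value passed in, and record the result as the config.

```python
import dataclasses
import functools


class TinyConfig:
    """Hypothetical minimal base providing a config attribute."""

    @property
    def config(self):
        return dict(self._internal_dict)


def register_fields_to_config(cls):
    """Hypothetical decorator: record dataclass field values in self._internal_dict."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # After the dataclass __init__ runs, each field holds its default or the passed value
        values = {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}
        # The dataclass is frozen (like a Flax module), so bypass __setattr__
        object.__setattr__(self, "_internal_dict", values)

    cls.__init__ = init
    return cls


@register_fields_to_config
@dataclasses.dataclass(frozen=True)
class MyConfigurable(TinyConfig):
    in_channels: int = 3
    out_channels: int = 3


m = MyConfigurable(out_channels=8)
print(m.config)  # {'in_channels': 3, 'out_channels': 8}
```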

Example: Loading a model config

```python
import jax.numpy as jnp
from maxdiffusion.models import FlaxUNet2DConditionModel

# Load from pretrained. Flax models return a (model, params) tuple:
# the module carries the config, the params hold the weights.
unet, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
)

# Access configuration
print(unet.config.in_channels)  # 4
print(unet.config.attention_head_dim)  # 8

# Override config values at load time
unet, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    attention_kernel="flash",  # Use flash attention instead of the default
    dtype=jnp.bfloat16,  # Compute activations in bfloat16
)
```
