Most users should finetune a pretrained model rather than train from scratch. Zoobot ships with pretrained weights trained on tens of millions of Galaxy Zoo volunteer labels, and finetuning those weights typically yields excellent results even with a few hundred labelled galaxies. Training from scratch is reserved for pretraining entirely new foundation models — it requires large, diverse galaxy datasets and cluster-grade GPU resources.

When to Train from Scratch vs Finetune

| Scenario | Recommendation |
|---|---|
| You have a labelled dataset of any size | ✅ Finetune a pretrained model |
| You want to predict a new morphology task | ✅ Finetune a pretrained model |
| You need representations for a new survey | ✅ Finetune a pretrained model |
| You are building a new foundation model for a new domain | Train from scratch |
| You are replicating GZ DECaLS / GZ DESI training | Train from scratch |
| You need to pretrain on a proprietary dataset from scratch | Train from scratch |
In nearly every science use case, finetuning is faster, cheaper, and more data-efficient than pretraining from scratch. Use train_default_zoobot_from_scratch only when you genuinely need to train new foundation model weights.

train_default_zoobot_from_scratch

train_default_zoobot_from_scratch is the single entry point for all from-scratch pretraining. It creates a ZoobotTree LightningModule, wires up a CatalogDataModule or WebDataModule, configures a Lightning Trainer (using DDP when gpus > 1), and runs the full training loop, including early stopping and checkpointing.
```python
from zoobot.pytorch.training.train_with_pytorch_lightning import train_default_zoobot_from_scratch
from zoobot.pytorch.estimators.define_model import ZoobotTree
```
The function returns a Tuple[ZoobotTree, L.Trainer] — the trained model and the trainer with which it was trained.

Key Arguments

Data inputs

  • save_dir (str, required): Folder in which to save training logs and trained model checkpoints. Created if it does not exist.
  • schema (Schema, required): A zoobot.shared.schemas.Schema object describing the Galaxy Zoo decision tree: the questions, their answers, and the dependencies between them. Defines label_cols, question_answer_pairs, and dependencies.
  • transform_cfg (object, required unless train_transform_cfg and test_transform_cfg are supplied): An augmentation configuration object passed to galaxy_datasets.transforms.get_galaxy_transform(). The same transform is applied for both training and inference. Mutually exclusive with train_transform_cfg/test_transform_cfg: you must supply either transform_cfg or both train_transform_cfg and test_transform_cfg.
  • train_transform_cfg (object): Augmentation config object for the training set. Use it when you want separate transforms for training (heavy augmentation) and inference (minimal augmentation). Must be paired with test_transform_cfg.
  • test_transform_cfg (object): Augmentation config object for validation, test, and prediction. Must be paired with train_transform_cfg.
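The rule above (either transform_cfg, or both train_transform_cfg and test_transform_cfg) can be written as a small check. This is a minimal sketch of the validation logic only. resolve_transform_cfgs is a hypothetical helper and is not part of Zoobot's API. It treats the config objects as opaque values.

```python
from typing import Any, Optional, Tuple


def resolve_transform_cfgs(
    transform_cfg: Optional[Any] = None,
    train_transform_cfg: Optional[Any] = None,
    test_transform_cfg: Optional[Any] = None,
) -> Tuple[Any, Any]:
    """Return (train_cfg, test_cfg) following the documented rule.

    Hypothetical helper: illustrates the mutual-exclusivity rule only.
    """
    has_split_cfg = train_transform_cfg is not None or test_transform_cfg is not None
    if transform_cfg is not None and has_split_cfg:
        raise ValueError("Pass transform_cfg OR train_transform_cfg/test_transform_cfg, not both")
    if transform_cfg is not None:
        # One config shared by training and inference
        return transform_cfg, transform_cfg
    if train_transform_cfg is None or test_transform_cfg is None:
        raise ValueError("Supply transform_cfg, or both train_transform_cfg and test_transform_cfg")
    return train_transform_cfg, test_transform_cfg
```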
  • catalog (pd.DataFrame, default None): A single galaxy catalog with columns id_str and file_loc. It is split automatically into train and validation sets. Mutually exclusive with train_catalog/val_catalog.
  • train_catalog (pd.DataFrame, default None): Pre-split training catalog with columns id_str and file_loc.
  • val_catalog (pd.DataFrame, default None): Pre-split validation catalog with columns id_str and file_loc.
  • test_catalog (pd.DataFrame, default None): Optional pre-split test catalog. If provided, test metrics are logged after training.
  • train_urls (list, default None): List of WebDataset shard URLs for training. Mutually exclusive with the catalog inputs. See WebDataset Support for Large-Scale Training below.
  • val_urls (list, default None): List of WebDataset shard URLs for validation.
  • test_urls (list, default None): List of WebDataset shard URLs for testing and prediction.
  • cache_dir (str, default None): Directory in which to cache downloaded WebDataset shards locally. Only applies when using WebDataset URLs.
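If you prefer to control the split yourself instead of passing a single catalog, you can build and split the DataFrame in advance. The sketch below makes several assumptions. It uses pandas and made-up file paths. It adds two hypothetical vote-count columns (smooth, featured) as stand-ins, on the assumption that the catalog must also carry whatever columns your schema lists in label_cols. The 80/20 split is arbitrary, and split_catalog is a hypothetical helper.

```python
import pandas as pd


def split_catalog(catalog: pd.DataFrame, val_fraction: float = 0.2, seed: int = 42):
    """Randomly split a catalog into (train, val) DataFrames.

    Sketch only: assumes one row per galaxy and a unique index.
    """
    if not {"id_str", "file_loc"}.issubset(catalog.columns):
        raise ValueError("catalog needs id_str and file_loc columns")
    val = catalog.sample(frac=val_fraction, random_state=seed)
    train = catalog.drop(index=val.index)
    return train.reset_index(drop=True), val.reset_index(drop=True)


# Hypothetical catalog: paths and vote-count columns are placeholders
full_df = pd.DataFrame({
    "id_str": [f"galaxy_{i}" for i in range(10)],
    "file_loc": [f"/data/images/galaxy_{i}.jpg" for i in range(10)],
    "smooth": [3] * 10,    # hypothetical vote-count column
    "featured": [7] * 10,  # hypothetical vote-count column
})
train_df, val_df = split_catalog(full_df)
```

You can then pass train_df and val_df as train_catalog and val_catalog.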

Training schedule

  • epochs (int, default 1000): Maximum number of epochs to train for. Training stops earlier if patience epochs pass without any improvement in the validation loss.
  • patience (int, default 8): Number of epochs to wait for an improvement in validation/supervised_loss before triggering early stopping.
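The interaction between epochs and patience can be modelled as a simple counter. The sketch below is a simplified illustration of early stopping on validation/supervised_loss. It assumes strict improvement with no minimum delta. Zoobot's actual behaviour comes from Lightning's EarlyStopping callback and its settings, and stopping_epoch is a hypothetical helper.

```python
import math
from typing import Sequence


def stopping_epoch(val_losses: Sequence[float], patience: int = 8, max_epochs: int = 1000) -> int:
    """Return the 0-based epoch after which training would stop.

    Simplified sketch: an epoch counts as an improvement only if its
    loss is strictly lower than the best seen so far.
    """
    best = math.inf
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return epoch  # patience exhausted: stop early
    return min(len(val_losses), max_epochs) - 1  # ran to the end


# Loss stops improving after epoch 2; with patience=3, training stops after epoch 5
print(stopping_epoch([1.0, 0.9, 0.8, 0.85, 0.86, 0.84, 0.83], patience=3))
```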

Model architecture

  • architecture_name (str, default 'efficientnet_b0'): Architecture to use for the encoder. Must be a valid model name in timm.list_models(). Popular choices include efficientnet_b0, convnext_nano, convnext_small, maxvit_tiny_tf_224, and efficientnetv2_s.
  • timm_kwargs (dict, default {}): Additional keyword arguments forwarded directly to timm.create_model(). For example, {'drop_path_rate': 0.2} enables stochastic depth on EfficientNet.
  • dropout_rate (float, default 0.2): Dropout probability applied in the Dirichlet head before the output layer.
  • compile_encoder (bool, default False): Compile the encoder with torch.compile (requires PyTorch ≥ 2.0). Can meaningfully increase training throughput during pretraining, but is not recommended for finetuning.

Optimiser

  • learning_rate (float, default 1e-3): Base learning rate for the AdamW optimiser.
  • betas (tuple, default (0.9, 0.999)): Beta coefficients for AdamW, i.e. the decay rates of the running averages of the gradient and of its square.
  • weight_decay (float, default 0.01): Decoupled weight decay coefficient for AdamW. AdamW applies the decay directly to the weights rather than adding an L2 penalty to the loss.
  • scheduler_params (dict, default {}): Optional learning rate scheduler configuration. Pass {'name': 'plateau', 'patience': 5} for ReduceLROnPlateau, or {'cosine_schedule': True, 'warmup_epochs': 5, 'max_cosine_epochs': 200, 'max_learning_rate_reduction_factor': 0.01} for a cosine schedule with warmup. Defaults to no scheduler.
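To make the cosine option concrete, the sketch below computes a per-epoch learning rate from the same keys. Its shape is an assumption. It uses linear warmup up to learning_rate over warmup_epochs. It then applies cosine decay to learning_rate * max_learning_rate_reduction_factor at max_cosine_epochs, counted from the start of training, and holds a constant floor after that. Zoobot's own scheduler may interpret these keys differently, so treat this as an illustration rather than a reference implementation. cosine_warmup_lr is a hypothetical helper.

```python
import math


def cosine_warmup_lr(
    epoch: int,
    learning_rate: float = 1e-3,
    warmup_epochs: int = 5,
    max_cosine_epochs: int = 200,
    max_learning_rate_reduction_factor: float = 0.01,
) -> float:
    """Learning rate for a 0-based epoch under an assumed warmup + cosine schedule."""
    min_lr = learning_rate * max_learning_rate_reduction_factor
    if epoch < warmup_epochs:
        # Linear warmup: reaches learning_rate in the last warmup epoch
        return learning_rate * (epoch + 1) / warmup_epochs
    decay_epochs = max(max_cosine_epochs - warmup_epochs, 1)
    progress = min((epoch - warmup_epochs) / decay_epochs, 1.0)
    # Cosine decay from learning_rate down to min_lr, then held constant
    return min_lr + 0.5 * (learning_rate - min_lr) * (1 + math.cos(math.pi * progress))


for epoch in (0, 4, 5, 100, 200, 300):
    print(epoch, f"{cosine_warmup_lr(epoch):.2e}")
```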

Hardware

  • gpus (int, default 2): Number of GPUs to use per node. When gpus > 1, DDP (DDPStrategy) is used automatically. Set to 0 to train on CPU.
  • nodes (int, default 1): Number of compute nodes for multi-node distributed training. Multi-node runs may require cluster-specific configuration.
  • mixed_precision (bool, default False): Enable automatic mixed precision (16-mixed). Can substantially reduce GPU memory use, by up to roughly half. May cause training instability on some architectures, such as ResNet.
  • batch_size (int, default 128): Per-GPU batch size. With DDP, each GPU processes a full batch of this size: the dataset is divided between GPUs, not the batch.
  • num_workers (int, default 4): Number of CPU worker processes per dataloader. Should be less than the number of available CPU cores.
  • prefetch_factor (int, default 4): Number of batches each dataloader worker loads into memory in advance. Increase it if GPU utilisation is low because of data loading. See the PyTorch DataLoader docs.
  • accumulate_gradients (int, default 1): Number of batches over which to accumulate gradients before each optimiser step. The effective batch size becomes batch_size * accumulate_gradients * gpus * nodes.
  • sync_batchnorm (bool, default False): Synchronise batch normalisation statistics across GPUs. Useful when per-GPU batch sizes are very small.
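A quick way to sanity-check these hardware settings is to compute the effective (global) batch size, counting every GPU on every node. This is plain arithmetic rather than a Zoobot function. effective_batch_size is a hypothetical helper. Treating gpus=0 (CPU training) as a single device is an assumption.

```python
def effective_batch_size(batch_size: int, gpus: int = 2, nodes: int = 1, accumulate_gradients: int = 1) -> int:
    """Number of images contributing to each optimiser step across all devices."""
    devices_per_node = max(gpus, 1)  # gpus=0 means CPU: assume one device
    return batch_size * devices_per_node * nodes * accumulate_gradients


# Defaults: 128 images per GPU on 2 GPUs -> 256 per optimiser step
print(effective_batch_size(128))
# 256 per GPU, 8 GPUs per node, 4 nodes -> 8192 per optimiser step
print(effective_batch_size(256, gpus=8, nodes=4))
```

If you scale up the number of GPUs, the effective batch grows with them. You may therefore want to revisit learning_rate.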

Logging and checkpointing

  • wandb_logger (WandbLogger, default None): A lightning.pytorch.loggers.WandbLogger instance for experiment tracking on Weights & Biases. If None, training falls back to CSV logging in save_dir.
  • save_top_k (int, default 3): Number of best checkpoints to keep, ranked by validation/supervised_loss.
  • random_state (int, default 42): Global random seed passed to lightning.seed_everything.
  • extra_callbacks (list, default None): Additional Lightning callbacks appended to the default set, which includes ModelCheckpoint and EarlyStopping. Useful for custom logging, learning rate monitoring, or profiling.

Example Usage

```python
from zoobot.pytorch.training.train_with_pytorch_lightning import train_default_zoobot_from_scratch
from zoobot.shared.schemas import Schema

# Build or load a schema for your decision tree
# gz_decals_schema = Schema(question_answer_pairs, dependencies)
# train_df, val_df and transform_cfg are assumed to be defined already

model, trainer = train_default_zoobot_from_scratch(
    save_dir='./my_zoobot_run',
    schema=gz_decals_schema,
    transform_cfg=transform_cfg,   # augmentation config (see transform_cfg above)
    train_catalog=train_df,        # pd.DataFrame with id_str, file_loc columns
    val_catalog=val_df,
    architecture_name='convnext_nano',
    gpus=2,
    epochs=200,
    patience=8,
    batch_size=128,
    learning_rate=1e-3,
    dropout_rate=0.2,
    mixed_precision=False
)

# model   -> ZoobotTree (best checkpoint weights loaded)
# trainer -> L.Trainer used for training
```
You can also pass a single catalog instead of pre-split train_catalog/val_catalog, and Zoobot will split it automatically:
```python
model, trainer = train_default_zoobot_from_scratch(
    save_dir='./my_zoobot_run',
    schema=gz_decals_schema,
    transform_cfg=transform_cfg,  # augmentation config (see transform_cfg above)
    catalog=full_df,              # will be automatically split
    architecture_name='efficientnet_b0',
    gpus=1,
    epochs=1000,
    batch_size=128
)
```

WebDataset Support for Large-Scale Training

For large-scale distributed pretraining, Zoobot supports WebDataset shards instead of on-disk image catalogs. Pass a list of shard URLs to train_urls, val_urls, and test_urls:
```python
model, trainer = train_default_zoobot_from_scratch(
    save_dir='./my_zoobot_run',
    schema=gz_decals_schema,
    transform_cfg=transform_cfg,  # augmentation config (see transform_cfg above)
    train_urls=['s3://my-bucket/train/shard-{000000..000199}.tar'],
    val_urls=['s3://my-bucket/val/shard-{000000..000019}.tar'],
    test_urls=['s3://my-bucket/test/shard-{000000..000009}.tar'],
    cache_dir='/scratch/webdataset_cache',
    gpus=8,
    nodes=4,
    batch_size=256
)
```
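The example passes each split as a one-element list containing a brace pattern. Whether a pattern inside a list gets expanded depends on the data pipeline. The webdataset library brace-expands plain string URLs, but if your version does not expand list entries you can build the explicit list yourself. An explicit list also makes it easy to check that you have enough shards to spread across ranks and dataloader workers. shard_urls is a hypothetical helper. The six-digit zero padding simply matches the pattern in the example above.

```python
from typing import List


def shard_urls(prefix: str, first: int, last: int, suffix: str = ".tar", width: int = 6) -> List[str]:
    """Expand prefix + '{000000..000199}' + suffix style patterns (inclusive range)."""
    return [f"{prefix}{i:0{width}d}{suffix}" for i in range(first, last + 1)]


train_urls = shard_urls("s3://my-bucket/train/shard-", 0, 199)
val_urls = shard_urls("s3://my-bucket/val/shard-", 0, 19)
test_urls = shard_urls("s3://my-bucket/test/shard-", 0, 9)
print(len(train_urls), train_urls[0], train_urls[-1])
```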
When WebDataset URLs are used and test_urls is provided, the function also runs inference after training and saves per-galaxy predictions to save_dir/test_predictions_<rank>.csv, with one file per DDP rank.
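Because each rank writes its own file, you will usually want to combine them afterwards. Below is a minimal pandas sketch. It assumes the files follow the test_predictions_<rank>.csv naming above and share the same columns. merge_rank_predictions is a hypothetical helper. The id_str deduplication guards against a galaxy appearing in more than one rank's output, which is an assumption about your data rather than documented behaviour.

```python
from pathlib import Path

import pandas as pd


def merge_rank_predictions(save_dir: str) -> pd.DataFrame:
    """Concatenate save_dir/test_predictions_<rank>.csv files into one DataFrame."""
    files = sorted(Path(save_dir).glob("test_predictions_*.csv"))
    if not files:
        raise FileNotFoundError(f"no test_predictions_*.csv files in {save_dir}")
    merged = pd.concat((pd.read_csv(f) for f in files), ignore_index=True)
    if "id_str" in merged.columns:
        # Keep the first prediction per galaxy in case any were written twice
        merged = merged.drop_duplicates(subset="id_str").reset_index(drop=True)
    return merged


# Usage (save_dir as passed to train_default_zoobot_from_scratch):
# predictions = merge_rank_predictions('./my_zoobot_run')
```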
The WebDataset shard format is recommended for training on many GPUs across multiple nodes, where reading individual image files becomes an I/O bottleneck. For single-node training on up to ~8 GPUs, a standard CatalogDataModule with file_loc paths performs well.

Multi-GPU Training with DDP

When gpus > 1, Zoobot automatically uses PyTorch Lightning’s DDPStrategy (Distributed Data Parallel). Under DDP:
  • Each GPU receives a separate slice of the overall dataset each epoch.
  • Each GPU computes forward and backward passes on a full local batch of batch_size images.
  • Gradients are synchronised across GPUs before each optimiser step.
  • Throughput typically scales close to linearly with the number of GPUs, provided data loading keeps pace.
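The first bullet can be illustrated with the index assignment used by PyTorch's DistributedSampler, which Lightning usually injects automatically under DDP. The sampler pads the index list to a multiple of the number of processes (world_size = gpus * nodes). Each rank then takes every world_size-th index, starting at its own rank. The sketch below reproduces that pattern without shuffling and is for intuition only. rank_indices is a hypothetical helper, and WebDataset pipelines typically split data by shard instead.

```python
import math
from typing import List


def rank_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Indices seen by one DDP process, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / world_size) * world_size
    padding = total_size - dataset_len
    # Pad by wrapping around so every rank receives the same number of samples
    indices += (indices * math.ceil(padding / dataset_len))[:padding] if padding else []
    return indices[rank:total_size:world_size]


# 10 galaxies shared between 4 processes (e.g. 2 GPUs x 2 nodes)
for rank in range(4):
    print(rank, rank_indices(10, rank, world_size=4))
```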
```python
# 2-GPU single-node training
model, trainer = train_default_zoobot_from_scratch(
    ...
    gpus=2,     # DDPStrategy activated automatically
    nodes=1
)

# 8 GPUs per node x 2 nodes = 16 GPUs in total
model, trainer = train_default_zoobot_from_scratch(
    ...
    gpus=8,
    nodes=2     # requires matching SLURM / cluster configuration
)
```
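Choose num_workers with the whole node in mind. With DDP, Lightning starts one training process per GPU, and each of those processes spawns its own num_workers dataloader workers. Roughly num_workers * gpus worker processes are therefore active on each node, and this total should stay below the number of CPU cores available per node.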
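As a rough starting point, you can divide a node's cores between its GPUs, leaving a little headroom for each training process. This heuristic is an assumption, not a Zoobot default, and suggest_num_workers is hypothetical. The right value also depends on how expensive your image loading and augmentations are. On shared clusters, os.cpu_count() can report more cores than your job was allocated, so passing cpu_cores explicitly is safer.

```python
import os
from typing import Optional


def suggest_num_workers(gpus: int, cpu_cores: Optional[int] = None, reserve_per_process: int = 1) -> int:
    """Split a node's CPU cores between GPUs so that num_workers * gpus stays below the core count."""
    cores = cpu_cores if cpu_cores is not None else (os.cpu_count() or 1)
    processes = max(gpus, 1)  # one training process per GPU (or one on CPU)
    # Leave reserve_per_process core(s) for each training process itself
    return max(cores // processes - reserve_per_process, 0)


# A 32-core node with 4 GPUs -> 7 workers per GPU, 28 worker processes in total
print(suggest_num_workers(gpus=4, cpu_cores=32))
```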

Benchmarks and Replication Scripts

The benchmarks/ folder in the Zoobot GitHub repository contains SLURM submission scripts and Python entry points to replicate official Zoobot pretraining runs:
  • GZ DECaLS: 314,000 galaxies, EfficientNet-B0 backbone
  • GZ DESI: 8.7 million galaxies (GZD-1/2/5), larger architectures
  • GZ Evo: combined GZD-1/2/5, Hubble, CANDELS, and GZ2 data
These scripts are the canonical reference for reproducing the models described in the Walmsley et al. 2023 and Scaling Laws for Galaxy Images papers.
train_default_zoobot_from_scratch is a complex, research-grade function that has accumulated many arguments over time as pretraining methodology has evolved. It is being gradually migrated to a dedicated foundation model repository. If you need to train a new foundation model, please check the Zoobot GitHub issues and mailing list for the latest guidance on which repository to use.
