Most users should finetune a pretrained model rather than train from scratch. Zoobot ships with pretrained weights trained on tens of millions of Galaxy Zoo volunteer labels, and finetuning those weights typically yields excellent results even with a few hundred labelled galaxies. Training from scratch is reserved for pretraining entirely new foundation models, which requires large, diverse galaxy datasets and cluster-grade GPU resources.
When to Train from Scratch vs Finetune
| Scenario | Recommendation |
|---|---|
| You have a labelled dataset of any size | ✅ Finetune a pretrained model |
| You want to predict a new morphology task | ✅ Finetune a pretrained model |
| You need representations for a new survey | ✅ Finetune a pretrained model |
| You are building a new foundation model for a new domain | Train from scratch |
| You are replicating GZ DECaLS / GZ DESI training | Train from scratch |
| You need to pretrain on a proprietary dataset from scratch | Train from scratch |
Use train_default_zoobot_from_scratch only when you genuinely need to train new foundation model weights.
train_default_zoobot_from_scratch
train_default_zoobot_from_scratch is the single entry point for all from-scratch pretraining. It creates a ZoobotTree LightningModule, wires up a CatalogDataModule or WebDataModule, configures a Lightning Trainer (using DDP when training on more than one GPU), and runs the full training loop, including early stopping and checkpointing.
Returns Tuple[ZoobotTree, L.Trainer]: the trained model and the Lightning Trainer used to train it.
Key Arguments
Data inputs
- save_dir: Folder to save training logs and trained model checkpoints. Will be created if it does not exist.
- A zoobot.shared.schemas.Schema object describing the Galaxy Zoo decision tree: the questions, answers, and their dependencies. Defines label_cols, question_answer_pairs, and dependencies.
- transform_cfg: An augmentation configuration object passed to galaxy_datasets.transforms.get_galaxy_transform(). The same transform is applied for both training and inference. Mutually exclusive with train_transform_cfg/test_transform_cfg; you must supply either transform_cfg or the pair train_transform_cfg and test_transform_cfg.
- train_transform_cfg: Augmentation config object for the training set. Use this when you want separate transforms for training (heavy augmentation) and inference (minimal augmentation). Must be paired with test_transform_cfg.
- test_transform_cfg: Augmentation config object for validation, test, and prediction. Must be paired with train_transform_cfg.
- catalog: A single galaxy catalog with columns id_str and file_loc, plus the label columns defined by the schema. Will be split automatically into train and validation sets. Mutually exclusive with train_catalog/val_catalog.
- train_catalog: Pre-split training catalog with columns id_str and file_loc, plus the schema's label columns.
- val_catalog: Pre-split validation catalog with the same columns as train_catalog.
- Optional pre-split test catalog. If provided, test metrics will be logged after training.
- train_urls: List of WebDataset shard URLs for training. Mutually exclusive with catalog inputs. See WebDataset Support for Large-Scale Training below.
- val_urls: List of WebDataset shard URLs for validation.
- test_urls: List of WebDataset shard URLs for testing and prediction.
- Directory to cache downloaded WebDataset shards locally. Only applies when using WebDataset URLs.
Training schedule
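The epoch limit and patience arguments in this section combine as a simple rule: train until the epoch limit, unless the validation loss fails to improve for patience consecutive epochs. The sketch below models that rule in plain Python, assuming "improvement" means a strictly lower validation/supervised_loss. Lightning's EarlyStopping callback has more options (such as min_delta), and epochs_until_early_stop is a hypothetical helper, not part of Zoobot.

```python
from typing import Iterable


def epochs_until_early_stop(val_losses: Iterable[float], patience: int, max_epochs: int) -> int:
    """Return how many epochs run before patience-based early stopping (a simplified sketch)."""
    best = float("inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for loss in val_losses:
        if epochs_run >= max_epochs:
            break  # hit the maximum number of epochs
        epochs_run += 1
        if loss < best:
            best = loss  # new best validation loss: reset the patience counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # patience exhausted: stop training
    return epochs_run
```

For example, with patience=3 and per-epoch validation losses [1.0, 0.9, 0.95, 0.96, 0.97, ...], training stops after epoch 5, because epochs 3 to 5 never beat the best loss of 0.9.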
- Maximum number of epochs to train for. Training will stop earlier if patience epochs pass without any improvement in the validation loss.
- patience: Number of epochs to wait for an improvement in validation/supervised_loss before triggering early stopping.
Model architecture
- Architecture to use for the encoder. Must be a valid model name in timm.list_models(). Popular choices include efficientnet_b0, convnext_nano, convnext_small, maxvit_tiny_tf_224, and efficientnetv2_s.
- Additional keyword arguments forwarded directly to timm.create_model(). For example, {'drop_path_rate': 0.2} enables stochastic depth on EfficientNet.
- Dropout probability applied in the Dirichlet head before the output layer.
- Compile the encoder with torch.compile (requires PyTorch ≥ 2.0). This can meaningfully speed up training throughput during pretraining, but is not recommended for finetuning.
Optimiser
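The cosine option of the scheduler configuration described in this section can be pictured as a pure function of the epoch. This is a sketch only: Zoobot's actual scheduler code is not shown here, and the linear warmup shape, the assumption that max_cosine_epochs counts from the end of warmup, and the reading of max_learning_rate_reduction_factor as the ratio of final to base learning rate are all assumptions. cosine_warmup_lr is a hypothetical helper.

```python
import math


def cosine_warmup_lr(
    epoch: int,
    base_lr: float,
    warmup_epochs: int = 5,
    max_cosine_epochs: int = 200,
    max_learning_rate_reduction_factor: float = 0.01,
) -> float:
    """Learning rate at a 0-indexed epoch: linear warmup, then cosine decay to a floor (assumed semantics)."""
    min_lr = base_lr * max_learning_rate_reduction_factor
    if epoch < warmup_epochs:
        # Linear warmup from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the cosine phase completed, clipped so the rate stays at min_lr afterwards.
    progress = min((epoch - warmup_epochs) / max_cosine_epochs, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

With base_lr=1e-3 and the example values from the scheduler configuration, this sketch ramps up over the first five epochs and then decays towards 1e-5, i.e. 1% of the base rate.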
- Base learning rate for the AdamW optimiser.
- Beta coefficients (momentum parameters) for AdamW.
- Decoupled weight decay coefficient for AdamW.
- Optional learning rate scheduler configuration. Pass {'name': 'plateau', 'patience': 5} for ReduceLROnPlateau, or {'cosine_schedule': True, 'warmup_epochs': 5, 'max_cosine_epochs': 200, 'max_learning_rate_reduction_factor': 0.01} for a cosine warmup schedule. Defaults to no scheduler.
Hardware
- gpus: Number of GPUs to use per node. When gpus > 1, DDP (DDPStrategy) is used automatically. Set to 0 to train on CPU.
- Number of compute nodes for multi-node distributed training. Multi-node support may require cluster-specific configuration.
- Enable automatic mixed precision (16-mixed). This can substantially reduce GPU memory usage, but may cause training instability on some architectures such as ResNet.
- batch_size: Per-GPU batch size. With DDP, each GPU processes a full batch of this size (the dataset is divided across GPUs, not the batch).
- Number of CPU worker processes per dataloader. Should be less than the number of available CPU cores.
- Number of batches each dataloader worker loads in advance. Increase this if GPU utilisation is low because data loading cannot keep up. See the PyTorch DataLoader docs.
- accumulate_gradients: Number of batches over which to accumulate gradients before each optimiser step. The effective batch size becomes batch_size * accumulate_gradients * gpus, multiplied further by the number of nodes in multi-node training.
- Synchronise batch normalisation statistics across GPUs. Useful when per-GPU batch sizes are very small.
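The effective batch size rule above is simple arithmetic, but it is easy to forget a factor when changing hardware. A minimal sketch, using a hypothetical helper name; treating gpus=0 (CPU training) as a single process is an assumption made so the formula stays meaningful.

```python
def effective_batch_size(batch_size: int, accumulate_gradients: int = 1, gpus: int = 1, nodes: int = 1) -> int:
    """Images contributing to each optimiser step under DDP with gradient accumulation."""
    # Under DDP every process (one per GPU, per node) handles a full batch_size batch.
    # Assumption: CPU training (gpus=0) counts as a single process.
    processes = max(gpus, 1) * nodes
    return batch_size * accumulate_gradients * processes
```

For example, batch_size=128 with accumulate_gradients=2 on 4 GPUs gives an effective batch size of 1024. If you change the GPU count, adjusting accumulate_gradients is one way to keep the effective batch size (and hence the optimiser's behaviour) comparable between runs.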
Logging and checkpointing
- A lightning.pytorch.loggers.WandbLogger instance for experiment tracking on Weights & Biases. If None, falls back to CSV logging in save_dir.
- Number of best checkpoints to retain, ranked by validation/supervised_loss.
- Global random seed passed to lightning.seed_everything.
- Additional Lightning callbacks to append to the default set (which includes ModelCheckpoint and EarlyStopping). Useful for custom logging, learning rate monitoring, or profiling.
Example Usage
You can pass either pre-split train_catalog/val_catalog tables or a single catalog, which Zoobot will split into training and validation sets automatically.
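To picture what an automatic split does, here is a minimal seeded train/validation split over catalog rows with id_str and file_loc columns. It is a sketch, not Zoobot's implementation: the 80/20 fraction, the seed, and the helper name split_catalog are assumptions, and the catalog is represented as a list of dicts to stay dependency-free (with pandas, DataFrame.sample(frac=..., random_state=...) plays the same role).

```python
import random
from typing import Dict, List, Sequence, Tuple

Row = Dict[str, str]


def split_catalog(rows: Sequence[Row], val_fraction: float = 0.2, seed: int = 42) -> Tuple[List[Row], List[Row]]:
    """Shuffle catalog rows with a fixed seed and split them into (train, val) lists."""
    shuffled = list(rows)
    # A local RNG keeps the split reproducible without touching global random state.
    random.Random(seed).shuffle(shuffled)
    n_val = int(round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


# Hypothetical catalog: ten galaxies with placeholder image paths.
rows = [{"id_str": f"galaxy_{i}", "file_loc": f"images/galaxy_{i}.jpg"} for i in range(10)]
train_rows, val_rows = split_catalog(rows)
```

Fixing the seed matters here: without it, the validation set would change between runs, making validation losses from different experiments hard to compare.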
WebDataset Support for Large-Scale Training
For large-scale distributed pretraining, Zoobot supports WebDataset shards instead of on-disk image catalogs. Pass lists of shard URLs to train_urls, val_urls, and test_urls.
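Since these arguments are documented as lists of URLs, it can help to generate them programmatically. A minimal sketch, assuming a hypothetical zero-padded naming scheme such as train-000000.tar (shard_urls is not part of Zoobot). WebDataset's own loaders also understand brace patterns like shard-{000000..000099}.tar, but this sketch expands the list explicitly to match the list-typed arguments.

```python
from typing import List


def shard_urls(prefix: str, n_shards: int, start: int = 0, width: int = 6) -> List[str]:
    """Build zero-padded WebDataset shard URLs such as <prefix>-000000.tar."""
    return [f"{prefix}-{i:0{width}d}.tar" for i in range(start, start + n_shards)]
```

For instance, shard_urls("/data/gz/train", 2) returns ["/data/gz/train-000000.tar", "/data/gz/train-000001.tar"]; the paths are placeholders.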
If test_urls is provided, the function also runs inference after training and saves per-galaxy predictions to save_dir/test_predictions_<rank>.csv.
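Because each process writes its own test_predictions_<rank>.csv, you will usually want a single table afterwards. The sketch below concatenates the per-rank files with the standard library, ordering them numerically by rank. It assumes every file has a header row and the same columns, which is not documented here; merge_rank_predictions is a hypothetical helper, and with pandas, pd.concat over pd.read_csv results would do the same job.

```python
import csv
import glob
import os
from typing import Dict, List


def merge_rank_predictions(save_dir: str) -> List[Dict[str, str]]:
    """Concatenate test_predictions_<rank>.csv files from save_dir into one list of rows."""
    prefix, suffix = "test_predictions_", ".csv"
    paths = glob.glob(os.path.join(save_dir, f"{prefix}*{suffix}"))
    # Sort numerically by rank so rank 10 comes after rank 2 (plain string sorting would not).
    paths.sort(key=lambda p: int(os.path.basename(p)[len(prefix):-len(suffix)]))
    rows: List[Dict[str, str]] = []
    for path in paths:
        with open(path, newline="") as f:
            rows.extend(csv.DictReader(f))  # assumes a header row in every file
    return rows
```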
The WebDataset shard format is recommended for training on many GPUs across multiple nodes, where reading individual image files becomes an I/O bottleneck. For single-node training on up to ~8 GPUs, a standard CatalogDataModule with file_loc paths performs well.
Multi-GPU Training with DDP
When gpus > 1, Zoobot automatically uses PyTorch Lightning's DDPStrategy (Distributed Data Parallel). Under DDP:
- Each GPU receives a separate slice of the overall dataset each epoch.
- Each GPU computes forward and backward passes on a full local batch of batch_size images.
- Gradients are averaged across GPUs before each optimiser step, so every replica applies the same update.
- Throughput can scale close to linearly with the number of GPUs, provided data loading keeps up.
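For map-style (catalog-backed) datasets, Lightning typically handles the per-GPU slicing by wrapping the dataset in torch.utils.data.DistributedSampler. The sketch below mimics that sampler's unshuffled behaviour in plain Python: each rank takes a strided slice, and a few indices wrap around to pad uneven splits so every rank sees the same number of samples. rank_indices is a hypothetical helper; the real sampler also shuffles with a per-epoch seed, and iterable pipelines such as WebDataset are not split this way.

```python
from typing import List


def rank_indices(dataset_size: int, rank: int, world_size: int) -> List[int]:
    """Indices one rank sees in an epoch, mimicking DistributedSampler with shuffle=False."""
    if dataset_size == 0:
        return []
    per_rank = -(-dataset_size // world_size)  # ceiling division
    total = per_rank * world_size
    # Pad by wrapping around to the start so every rank gets per_rank samples.
    padded = [i % dataset_size for i in range(total)]
    # Strided assignment: rank r takes positions r, r + world_size, r + 2 * world_size, ...
    return padded[rank:total:world_size]
```

With 10 galaxies on 4 GPUs, each rank receives 3 indices, so galaxies 0 and 1 are each seen twice per epoch. This padding is negligible for large catalogs but worth remembering when evaluating on small ones.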
Benchmarks and Replication Scripts
The benchmarks/ folder in the Zoobot GitHub repository contains SLURM submission scripts and Python entry points to replicate official Zoobot pretraining runs:
- GZ DECaLS: 314,000 galaxies, EfficientNet-B0 backbone
- GZ DESI: 8.7 million galaxies (GZD-1/2/5), larger architectures
- GZ Evo: combined GZD-1/2/5, Hubble, CANDELS, and GZ2 data