The Zoobot package has many classes and methods that work together. This guide is a map of the key moving parts: how they relate, what each one is responsible for, and which file to look in when you need to dig deeper.
Two Roles, One Package
Zoobot has two distinct roles:
- Finetuning: pytorch/training/finetune.py is the heart of the package. You will use its classes to load pretrained models and adapt them to your own prediction tasks. This is what most users need.
- Training from Scratch: pytorch/estimators/define_model.py and pytorch/training/train_with_pytorch_lightning.py create and train Zoobot foundation models from scratch. These are not required for finetuning and will eventually be migrated to a separate dedicated pretraining repository.
Finetuning: The FinetuneableZoobot Classes
There are three user-facing classes for finetuning, each targeting a different type of prediction task:
FinetuneableZoobotClassifier
Classification tasks: binary, multi-class, or multi-label.
Uses cross-entropy loss. Accepts num_classes and optionally label_smoothing. Best for tasks like “is this galaxy merging?” or “which Hubble type?”.
FinetuneableZoobotRegressor
Regression tasks: continuous or fractional outputs.
Uses MSE or MAE loss. Best for predicting a continuous quantity such as a morphology fraction or photometric measurement from galaxy images.
FinetuneableZoobotTree
Decision tree vote counts: Galaxy Zoo-style multi-question labels.
Uses Dirichlet-Multinomial loss. Best for finetuning on a new set of Galaxy Zoo vote count columns following the same branching question structure.
All three inherit from FinetuneableZoobotAbstract, which is the non-user-facing base class. FinetuneableZoobotAbstract handles the common mechanics:
- Loading a pretrained encoder from a checkpoint or HuggingFace Hub
- Accepting arguments that control the finetuning process (learning rate, frozen layers, scheduler, etc.)
- Providing training_step, validation_step, and test_step implementations
- Setting up the optimiser with optional layer-wise learning rate decay
Each user-facing subclass then adds its own head and loss. For example, FinetuneableZoobotClassifier accepts num_classes and attaches a linear classification head with the appropriate cross-entropy loss.
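To make the classifier's loss concrete, here is a minimal pure-Python sketch of cross-entropy with label smoothing for a single example. It assumes the common convention (the one used by torch.nn.CrossEntropyLoss), in which the one-hot target is mixed with a uniform distribution over classes. Whether Zoobot applies smoothing in exactly this way is an assumption, and smoothed_cross_entropy is a hypothetical helper, not a Zoobot function.

```python
import math


def smoothed_cross_entropy(logits, target, label_smoothing=0.0):
    """Cross-entropy for one example, with optional label smoothing.

    Hypothetical helper for illustration; not part of Zoobot.
    """
    num_classes = len(logits)
    # Numerically stable log-softmax: subtract the max logit first.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_norm for z in logits]

    # Smoothed target: (1 - eps) on the true class plus eps / K on every class.
    uniform = label_smoothing / num_classes
    targets = [
        uniform + (1.0 - label_smoothing) * (i == target)
        for i in range(num_classes)
    ]
    return -sum(t * lp for t, lp in zip(targets, log_probs))


# A confident, correct prediction for class 0 of 2.
plain = smoothed_cross_entropy([2.0, 0.0], target=0)
smoothed = smoothed_cross_entropy([2.0, 0.0], target=0, label_smoothing=0.1)
```

Smoothing raises the loss on confident predictions, which discourages the head from becoming overconfident on small finetuning datasets.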
How PyTorch Lightning Fits In
Zoobot is built on PyTorch Lightning, a framework that handles the boilerplate of distributed training so that model code stays clean and framework-agnostic. Every Zoobot model is a LightningModule subclass. LightningModules define what to do at each stage of training:
- training_step: forward pass and loss on a training batch
- validation_step: forward pass and loss on a validation batch
- test_step: forward pass and loss on a test batch
- predict_step: forward pass only, no loss, used for inference
- configure_optimizers: returns the optimiser (and optional scheduler)
The Trainer then takes a LightningModule and a DataModule and handles how to run training: distributing across devices, managing epochs, running callbacks, writing logs, and saving checkpoints. The Trainer is obtained via finetune.get_trainer(...) for finetuning or constructed directly inside train_default_zoobot_from_scratch for pretraining.
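The division of labour between module and trainer can be sketched with a toy example in plain Python. ToyModule and ToyTrainer are hypothetical stand-ins that only mimic the roles described above; they are not the Lightning API. In particular, real Lightning computes gradients with autograd, whereas this toy returns a hand-derived gradient from training_step.

```python
class ToyModule:
    """Defines WHAT happens on each batch (the LightningModule role)."""

    def __init__(self, weight=0.0, lr=0.1):
        self.weight = weight
        self.lr = lr

    def forward(self, x):
        return self.weight * x

    def training_step(self, batch):
        # Mean squared error and its gradient with respect to the weight.
        n = len(batch)
        loss = sum((self.forward(x) - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.forward(x) - y) * x for x, y in batch) / n
        return loss, grad

    def validation_step(self, batch):
        loss, _ = self.training_step(batch)
        return loss

    def predict_step(self, batch):
        # Forward pass only, no loss.
        return [self.forward(x) for x in batch]


class ToyTrainer:
    """Decides HOW training runs (the Trainer role): epochs, loops, logging."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.history = []  # validation loss per epoch

    def fit(self, module, train_batches, val_batches):
        for _ in range(self.max_epochs):
            for batch in train_batches:
                _, grad = module.training_step(batch)
                module.weight -= module.lr * grad  # plain gradient descent
            val_losses = [module.validation_step(b) for b in val_batches]
            self.history.append(sum(val_losses) / len(val_losses))

    def predict(self, module, batches):
        return [p for batch in batches for p in module.predict_step(batch)]


# Toy data following y = 2x, already split into batches.
train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
val_batches = [[(4.0, 8.0)]]

module = ToyModule()
trainer = ToyTrainer(max_epochs=20)
trainer.fit(module, train_batches, val_batches)
predictions = trainer.predict(module, [[5.0]])
```

The module never loops over epochs and the trainer never computes a loss, which is why the same trainer can drive any of the Zoobot model classes.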
The Training Flow
The Trainer.predict pattern is used by zoobot.pytorch.predictions.predict_on_catalog.predict to make predictions on large unlabelled catalogs.
ZoobotTree: Training from Scratch
ZoobotTree (in define_model.py) plays a role analogous to FinetuneableZoobotAbstract, but for training from scratch rather than finetuning. It is also a LightningModule, and it uses the same encoder + head structure internally.
The key differences from the finetuning classes are:
- Initialisation: ZoobotTree builds the encoder from a timm architecture name and randomly initialises weights, rather than loading a pretrained checkpoint.
- Optimiser: Uses AdamW across all parameters with optional plateau or cosine learning rate scheduling, tuned for long pretraining runs.
- Head: Always uses a DirichletHead (predicting Dirichlet concentration parameters for each decision tree answer).
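To show how predicted concentration parameters relate to vote counts, here is a minimal sketch of the textbook Dirichlet-Multinomial negative log-likelihood for a single question. It is not Zoobot's implementation (that lives in shared/losses.py and covers the whole decision tree), and dirichlet_multinomial_nll is a hypothetical name.

```python
import math


def dirichlet_multinomial_nll(counts, concentrations):
    """Negative log-likelihood of vote counts for one question.

    counts: votes for each answer, e.g. [2, 1]
    concentrations: positive Dirichlet concentrations, one per answer
    Hypothetical helper for illustration; not Zoobot's implementation.
    """
    total_votes = sum(counts)
    total_conc = sum(concentrations)
    # log of n! * Gamma(A) / Gamma(n + A)
    log_prob = (
        math.lgamma(total_votes + 1)
        + math.lgamma(total_conc)
        - math.lgamma(total_votes + total_conc)
    )
    # plus the sum over answers of log[Gamma(k + a) / (k! * Gamma(a))]
    for k, a in zip(counts, concentrations):
        log_prob += math.lgamma(k + a) - math.lgamma(k + 1) - math.lgamma(a)
    return -log_prob


# With concentrations (1, 1), every split of 3 votes is equally likely.
nll_uniform = dirichlet_multinomial_nll([2, 1], [1.0, 1.0])

# Concentrations that match the observed vote fractions score better
# (lower loss) than concentrations that contradict them.
nll_matched = dirichlet_multinomial_nll([8, 2], [8.0, 2.0])
nll_mismatched = dirichlet_multinomial_nll([8, 2], [2.0, 8.0])
```

Because the likelihood depends on the total number of votes, a galaxy with few votes constrains the prediction less than one with many.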
Shared utilities (metrics, predict_step) are defined in define_model.py’s GenericLightningModule base class, and are inherited by both ZoobotTree and FinetuneableZoobotAbstract.
The Encoder: timm Under the Hood
Both ZoobotTree and the FinetuneableZoobot classes use timm (PyTorch Image Models) to provide the convolutional or transformer encoder backbone.
Setting num_classes=0 causes timm to return the pooled representation vector from the encoder (the global average-pooled feature map), which Zoobot’s head then takes as input. The dimensionality of this vector is detected automatically at init time by running a dummy forward pass.
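The dummy-forward-pass idea can be sketched in plain Python. Here toy_encoder is a hypothetical stand-in for a timm backbone created with num_classes=0, and detect_encoder_dim is an illustrative name; neither reflects the actual code in define_model.py.

```python
def toy_encoder(image, num_features=8):
    """Hypothetical stand-in for a headless backbone.

    Builds num_features fake feature maps from a 2D image, then
    global-average-pools each map to one number, so the output is a
    vector of length num_features.
    """
    pooled = []
    for channel in range(num_features):
        # Fake "feature map": scale every pixel by a per-channel factor.
        feature_map = [[(channel + 1) * pixel for pixel in row] for row in image]
        values = [v for row in feature_map for v in row]
        pooled.append(sum(values) / len(values))  # global average pooling
    return pooled


def detect_encoder_dim(encoder, image_size=4):
    """Pass an all-zeros dummy image through the encoder and measure the output."""
    dummy_image = [[0.0] * image_size for _ in range(image_size)]
    return len(encoder(dummy_image))


encoder_dim = detect_encoder_dim(toy_encoder)
features = toy_encoder([[1.0, 3.0], [5.0, 7.0]], num_features=2)
```

Measuring the output this way means the head can be sized correctly without a lookup table of per-architecture feature dimensions.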
When finetuning from a pretrained Zoobot checkpoint, the encoder is extracted directly. The resulting encoder is a plain nn.Module and can be used in any downstream PyTorch code.
Complete Code Flow
Module Map
pytorch/training/finetune.py
Main finetuning classes.
FinetuneableZoobotAbstract, FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree, and the get_trainer() helper. Start here for all finetuning work.
pytorch/estimators/define_model.py
Model factory and from-scratch training class.
ZoobotTree, GenericLightningModule, get_pytorch_encoder(), get_pytorch_dirichlet_head(), get_encoder_dim(). Shared utilities (metrics, predict_step) are also defined here.
pytorch/training/train_with_pytorch_lightning.py
From-scratch training entry point.
train_default_zoobot_from_scratch() wires together ZoobotTree, a data module, and a Lightning Trainer with DDP, checkpointing, and early stopping. Also contains get_default_callbacks().
pytorch/predictions/predict_on_catalog.py
Batch prediction utility.
predict() takes a trained model, an unlabelled galaxy catalog, and a save path, and writes a CSV of predictions using trainer.predict() under the hood.
shared/schemas.py
Decision tree schema.
Schema, Question, and Answer classes that encode the Galaxy Zoo question-and-answer tree, its branch dependencies, and the corresponding label_cols used in catalogs and loss functions.
shared/losses.py
Loss functions.
CustomMultiQuestionLoss and get_dirichlet_neg_log_prob: the Dirichlet-Multinomial negative log-likelihood loss used for training on volunteer vote counts across the decision tree.
Summary
The layered design has a clear separation of concerns:
- timm provides the encoder architecture; defining complicated convnets and transformers is someone else’s job.
- define_model.py assembles encoder + head into a LightningModule with a training step.
- finetune.py extends that pattern to loading pretrained weights and finetuning on new tasks.
- L.Trainer handles everything at scale: devices, epochs, logging, checkpoints.
- predict_on_catalog wraps Trainer.predict for user-friendly batch inference.