The Zoobot package has many classes and methods that work together. This guide is a map of the key moving parts — how they relate, what each one is responsible for, and which file to look in when you need to dig deeper.

Two Roles, One Package

Zoobot has two distinct roles:
  1. Finetuning: pytorch/training/finetune.py is the heart of the package. You will use its classes to load pretrained models and adapt them to your own prediction tasks. This is what most users need.
  2. Training from Scratch: pytorch/estimators/define_model.py and pytorch/training/train_with_pytorch_lightning.py create and train Zoobot foundation models from scratch. These are not required for finetuning and will eventually be migrated to a separate dedicated pretraining repository.
Zoobot’s pretraining code is being actively refactored into a dedicated foundation model repository. The finetuning classes in finetune.py are stable and will remain the long-term focus of this package.

Finetuning: The FinetuneableZoobot Classes

There are three user-facing classes for finetuning, each targeting a different type of prediction task:

FinetuneableZoobotClassifier

Classification tasks: binary, multi-class, or multi-label. Uses cross-entropy loss. Accepts num_classes and optionally label_smoothing. Best for tasks like “is this galaxy merging?” or “which Hubble type?”.

FinetuneableZoobotRegressor

Regression tasks: continuous or fractional outputs. Uses MSE or MAE loss. Best for predicting a continuous quantity such as a morphology fraction or photometric measurement from galaxy images.

FinetuneableZoobotTree

Decision tree vote counts: Galaxy Zoo–style multi-question labels. Uses Dirichlet-Multinomial loss. Best for finetuning on a new set of Galaxy Zoo vote count columns following the same branching question structure.

All three classes are subclasses of FinetuneableZoobotAbstract, which is the non-user-facing base class. FinetuneableZoobotAbstract handles the common mechanics:
  • Loading a pretrained encoder from a checkpoint or HuggingFace Hub
  • Accepting arguments that control the finetuning process (learning rate, frozen layers, scheduler, etc.)
  • Providing training_step, validation_step, and test_step implementations
  • Setting up the optimiser with optional layer-wise learning rate decay
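As a rough illustration of the layer-wise learning rate decay mentioned in the list above, the sketch below assigns the full learning rate to the head and a geometrically smaller one to each encoder block further from it. The function name, the argument names, and the exact decay rule are assumptions made for illustration, not Zoobot's actual optimiser setup.

```python
def layerwise_learning_rates(base_lr: float, decay: float, n_blocks: int) -> dict:
    """Return a learning rate for the head and for the top n_blocks encoder blocks.

    Block 0 is the encoder block closest to the head; each block further from
    the head has its rate multiplied by `decay` once more.
    """
    if not 0.0 < decay <= 1.0:
        raise ValueError("decay must be in (0, 1]")
    rates = {"head": base_lr}
    for depth in range(n_blocks):
        rates[f"encoder_block_{depth}"] = base_lr * decay ** (depth + 1)
    return rates


# Hypothetical settings: the head trains fastest, deeper blocks change least.
rates = layerwise_learning_rates(base_lr=1e-4, decay=0.75, n_blocks=3)
for name, lr in rates.items():
    print(f"{name}: {lr:.2e}")
```

The idea behind a scheme like this is that early encoder blocks hold generic features that should move only slightly, while the freshly initialised head needs the largest updates.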
Each user-facing subclass then adds what is specific to its task — for example, FinetuneableZoobotClassifier accepts num_classes and attaches a linear classification head with the appropriate cross-entropy loss.
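To make the role of label_smoothing concrete, here is a minimal pure-Python sketch of cross-entropy for a single example. It uses the standard definition in which the one-hot target is mixed with a uniform distribution over classes (the same mixing that torch.nn.CrossEntropyLoss documents for its label_smoothing argument). Whether the classifier computes its loss in exactly this way internally is an assumption.

```python
import math


def cross_entropy(logits, target, label_smoothing=0.0):
    """Cross-entropy for one example, with optional label smoothing.

    The target distribution is q = (1 - eps) * one_hot + eps / num_classes.
    """
    num_classes = len(logits)
    # Log-softmax, subtracting the max logit for numerical stability.
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in logits))
    log_probs = [z - log_norm for z in logits]
    loss = 0.0
    for k, log_p in enumerate(log_probs):
        q = label_smoothing / num_classes + (1.0 - label_smoothing) * (k == target)
        loss -= q * log_p
    return loss


logits = [2.0, 0.0, -1.0]  # hypothetical head outputs for num_classes=3
plain = cross_entropy(logits, target=0)
smoothed = cross_entropy(logits, target=0, label_smoothing=0.1)
print(f"plain={plain:.4f} smoothed={smoothed:.4f}")
```

With smoothing, a confident correct prediction is penalised slightly, which discourages the head from producing extreme logits on small finetuning datasets.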

How PyTorch Lightning Fits In

Zoobot is built on PyTorch Lightning, a framework that handles the engineering boilerplate of training (device placement, distributed training, logging) so that model code stays clean and hardware-agnostic. Every Zoobot model is a LightningModule subclass. LightningModules define what to do at each stage of training:
  • training_step — forward pass and loss on a training batch
  • validation_step — forward pass and loss on a validation batch
  • test_step — forward pass and loss on a test batch
  • predict_step — forward pass only, no loss, used for inference
  • configure_optimizers — returns the optimiser (and optional scheduler)
A Lightning Trainer then takes a LightningModule and a DataModule and handles how to run training — distributing across devices, managing epochs, running callbacks, writing logs, and saving checkpoints. The Trainer is obtained via finetune.get_trainer(...) for finetuning or constructed directly inside train_default_zoobot_from_scratch for pretraining.
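The division of labour described above can be illustrated with a toy in plain Python. ToyModule and ToyTrainer are hypothetical stand-ins written for this guide, not Lightning classes: the module only says what to compute for one batch, while the trainer owns the loop, the parameter updates, and the logs.

```python
class ToyModule:
    """Stand-in for a LightningModule: says what to compute for one batch."""

    def __init__(self):
        self.weight = 0.0  # single parameter of the model y = weight * x

    def training_step(self, batch):
        x, y = batch
        error = self.weight * x - y
        loss = error ** 2
        grad = 2.0 * error * x  # written by hand; Lightning would use autograd
        return loss, grad

    def validation_step(self, batch):
        x, y = batch
        return (self.weight * x - y) ** 2


class ToyTrainer:
    """Stand-in for a Trainer: owns the loop, the updates, and the logs."""

    def __init__(self, max_epochs, learning_rate=0.05):
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.history = []

    def fit(self, module, train_batches, val_batches):
        for epoch in range(self.max_epochs):
            for batch in train_batches:
                _, grad = module.training_step(batch)
                module.weight -= self.learning_rate * grad
            val_loss = sum(module.validation_step(b) for b in val_batches) / len(val_batches)
            self.history.append({"epoch": epoch, "val_loss": val_loss})


module = ToyModule()
trainer = ToyTrainer(max_epochs=20)
trainer.fit(module, train_batches=[(1.0, 2.0), (2.0, 4.0)], val_batches=[(3.0, 6.0)])
print(round(module.weight, 3), trainer.history[-1])
```

Because the module never touches the loop, the same module could be run by a different trainer (more epochs, other logging) without any change to its code, which is the property the real Trainer and LightningModule split is designed to give.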

The Training Flow

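# --- Finetuning flow ---
# (my_schema, save_dir and datamodule are assumed to be defined elsewhere)
from zoobot.pytorch.training.finetune import FinetuneableZoobotTree, get_trainer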

# 1. Define the model: loads a pretrained timm encoder + attaches a new head
#    FinetuneableZoobotTree also specifies how to train (LightningModule)
model = FinetuneableZoobotTree(
    checkpoint_loc='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    schema=my_schema
)

# 2. Configure the Trainer: devices, max epochs, early stopping, checkpointing
trainer = get_trainer(save_dir, max_epochs=50)

# 3. Run training: Lightning calls training_step / validation_step each epoch
trainer.fit(model, datamodule)

# 4. Run inference: Lightning calls predict_step on every batch
trainer.predict(model, datamodule)
The same Trainer.predict pattern is used by zoobot.pytorch.predictions.predict_on_catalog.predict to make predictions on large unlabelled catalogs.
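The flow above mentions early stopping as one of the things the Trainer is configured with. The sketch below shows the usual patience-based logic in plain Python. EarlyStopper is a hypothetical class for illustration, and the patience and validation losses are made up; it is not the callback that get_trainer actually installs.

```python
class EarlyStopper:
    """Stop when the monitored loss has not improved for `patience` checks in a row."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.checks_without_improvement = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # New best: remember it and reset the counter.
            self.best = val_loss
            self.checks_without_improvement = 0
        else:
            self.checks_without_improvement += 1
        return self.checks_without_improvement >= self.patience


val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.40]  # made-up validation curve
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

Note that training halts before the final, lower loss is ever seen: patience trades a small risk of stopping too early against wasted epochs, which is why the best checkpoint is usually saved alongside.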

ZoobotTree: Training from Scratch

ZoobotTree (in define_model.py) plays a role analogous to FinetuneableZoobotAbstract, but for training from scratch rather than finetuning. It is also a LightningModule, and it uses the same encoder + head structure internally. The key differences from the finetuning classes are:
  • Initialisation: ZoobotTree builds the encoder from a timm architecture name and randomly initialises weights, rather than loading a pretrained checkpoint.
  • Optimiser: uses AdamW across all parameters with optional plateau or cosine learning rate scheduling, tuned for long pretraining runs.
  • Head: always uses a DirichletHead (predicting Dirichlet concentration parameters for each decision tree answer).
Some generic utilities (metric logging, predict_step) are defined in define_model.py’s GenericLightningModule base class, and are inherited by both ZoobotTree and FinetuneableZoobotAbstract.
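Dirichlet concentration parameters must be strictly positive, which is why a head of this kind ends in a positivity-enforcing activation. The sketch below shows one way to turn raw linear outputs into per-question concentrations using softplus. The function names and the grouping by answers_per_question are assumptions for illustration, and any extra scaling or offset the real DirichletHead may apply is not shown.

```python
import math


def softplus(x):
    """Smooth, strictly positive version of max(0, x), stable for large |x|."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def concentrations_by_question(raw_outputs, answers_per_question):
    """Map raw head outputs to positive values and split them per question."""
    positive = [softplus(v) for v in raw_outputs]
    grouped, start = [], 0
    for n_answers in answers_per_question:
        grouped.append(positive[start:start + n_answers])
        start += n_answers
    return grouped


raw = [-1.2, 0.3, 2.5, 0.0, -4.0]  # hypothetical linear-layer outputs
alphas = concentrations_by_question(raw, answers_per_question=[3, 2])
print(alphas)
```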

The Encoder: timm Under the Hood

Both ZoobotTree and the FinetuneableZoobot classes use timm (PyTorch Image Models) to provide the convolutional or transformer encoder backbone.
# Inside define_model.get_pytorch_encoder():
import timm

encoder = timm.create_model(
    architecture_name,   # e.g. 'convnext_nano', 'efficientnet_b0', 'maxvit_tiny_tf_224'
    in_chans=channels,   # 1 for greyscale, 3 for RGB
    num_classes=0        # num_classes=0 returns the pooled feature vector, not logits
)
Using num_classes=0 causes timm to return the pooled representation vector from the encoder (the global average-pooled feature map), which Zoobot’s head then takes as input. The dimensionality of this vector is detected automatically at init time by running a dummy forward pass.

When finetuning from a pretrained Zoobot checkpoint, the encoder is extracted directly:
encoder = ZoobotTree.load_from_checkpoint(checkpoint_loc).encoder
This encoder is a plain nn.Module and can be used in any downstream PyTorch code.

Complete Code Flow

# === Training from scratch ===
#
# ZoobotTree.__init__:
#   encoder = get_pytorch_encoder(architecture_name)   # timm model, num_classes=0
#   head    = get_pytorch_dirichlet_head(encoder_dim)  # Dropout + linear → softplus
#   loss    = CustomMultiQuestionLoss(Dirichlet NLL)
#
# train_default_zoobot_from_scratch:
#   datamodule = CatalogDataModule(...) or WebDataModule(...)
#   trainer    = L.Trainer(DDPStrategy, EarlyStopping, ModelCheckpoint, ...)
#   trainer.fit(lightning_model, datamodule)
#   return lightning_model, trainer

# === Finetuning ===
#
# FinetuneableZoobotTree.__init__:
#   encoder = ZoobotTree.load_from_checkpoint(checkpoint_loc).encoder  (frozen or unfrozen)
#   head    = get_pytorch_dirichlet_head(encoder_dim)
#   loss    = CustomMultiQuestionLoss(Dirichlet NLL)
#
# get_trainer(save_dir):
#   trainer = L.Trainer(accelerator, EarlyStopping, ModelCheckpoint, ...)
#   return trainer
#
# trainer.fit(model, datamodule)    → training loop
# trainer.predict(model, datamodule) → inference loop (used by predict_on_catalog)

Module Map

pytorch/training/finetune.py

Main finetuning classes. FinetuneableZoobotAbstract, FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, FinetuneableZoobotTree, and the get_trainer() helper. Start here for all finetuning work.

pytorch/estimators/define_model.py

Model factory and from-scratch training class. ZoobotTree, GenericLightningModule, get_pytorch_encoder(), get_pytorch_dirichlet_head(), get_encoder_dim(). Shared utilities (metrics, predict_step) are also defined here.

pytorch/training/train_with_pytorch_lightning.py

From-scratch training entry point. train_default_zoobot_from_scratch() wires together ZoobotTree, a data module, and a Lightning Trainer with DDP, checkpointing, and early stopping. Also contains get_default_callbacks().

pytorch/predictions/predict_on_catalog.py

Batch prediction utility. predict() takes a trained model, an unlabelled galaxy catalog, and a save path, and writes a CSV of predictions using trainer.predict() under the hood.

shared/schemas.py

Decision tree schema. Schema, Question, and Answer classes that encode the Galaxy Zoo question-and-answer tree, its branch dependencies, and the corresponding label_cols used in catalogs and loss functions.
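To give a feel for what such a schema encodes, the sketch below builds a two-question tree and derives its label columns. ToyQuestion, build_questions, the question names, and the column naming are simplified stand-ins chosen for illustration; the real Schema, Question, and Answer classes may be constructed differently.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyQuestion:
    text: str
    answer_cols: List[str]       # one vote-count column per answer
    asked_after: Optional[str]   # answer column that leads to this question, if any


def build_questions(question_answer_pairs, dependencies):
    """Turn {question: [answer suffixes]} plus {question: parent answer} into questions."""
    return [
        ToyQuestion(q, [q + suffix for suffix in suffixes], dependencies[q])
        for q, suffixes in question_answer_pairs.items()
    ]


# Hypothetical two-question tree: spiral arms are only asked about featured galaxies.
pairs = {
    "smooth-or-featured": ["_smooth", "_featured-or-disk"],
    "has-spiral-arms": ["_yes", "_no"],
}
deps = {
    "smooth-or-featured": None,
    "has-spiral-arms": "smooth-or-featured_featured-or-disk",
}
questions = build_questions(pairs, deps)
label_cols = [col for q in questions for col in q.answer_cols]
print(label_cols)
```

Keeping the dependency as the name of a parent answer column means the same structure can drive both catalog column lookup and any loss that needs to know which answers belong to which question.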

shared/losses.py

Loss functions. CustomMultiQuestionLoss and get_dirichlet_neg_log_prob implement the Dirichlet-Multinomial negative log-likelihood loss used for training on volunteer vote counts across the decision tree.
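For readers who want to see the arithmetic, here is a minimal pure-Python sketch of a Dirichlet-Multinomial negative log-likelihood, summed over questions. It follows the textbook form of the distribution; the function names are hypothetical and this is not the code in shared/losses.py, which operates on batched tensors.

```python
import math


def dirichlet_multinomial_nll(votes, concentrations):
    """Negative log-likelihood of one question's vote counts."""
    total_votes = sum(votes)
    total_conc = sum(concentrations)
    log_prob = (
        math.lgamma(total_votes + 1)
        + math.lgamma(total_conc)
        - math.lgamma(total_votes + total_conc)
    )
    for k, a in zip(votes, concentrations):
        log_prob += math.lgamma(k + a) - math.lgamma(a) - math.lgamma(k + 1)
    return -log_prob


def multi_question_nll(votes_by_question, concentrations_by_question):
    """Sum the per-question losses; a question with zero votes contributes 0."""
    return sum(
        dirichlet_multinomial_nll(v, a)
        for v, a in zip(votes_by_question, concentrations_by_question)
    )


votes = [[8, 2], [1, 1]]             # made-up vote counts for two questions
alphas = [[4.0, 1.0], [1.0, 1.0]]    # made-up predicted concentrations
loss = multi_question_nll(votes, alphas)
print(f"{loss:.4f}")
```

A useful property of this form is that a question nobody answered has a likelihood of exactly 1, so unasked branches of the decision tree add nothing to the loss.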

Summary

The layered design has a clear separation of concerns:
  • timm provides the encoder architecture — defining complicated convnets and transformers is someone else’s job.
  • define_model.py assembles encoder + head into a LightningModule with a training step.
  • finetune.py extends that pattern to loading pretrained weights and finetuning on new tasks.
  • L.Trainer handles everything at scale: devices, epochs, logging, checkpoints.
  • predict_on_catalog wraps Trainer.predict for user-friendly batch inference.
As you can see, there are quite a few layers (pun intended) to training Zoobot models. The goal of this design is to keep each layer simple to use and easy to extend, whether you are finetuning for a new morphology task, extracting representations, or building a new foundation model.
