ZoobotTree is the PyTorch Lightning module used to train a Zoobot foundation model from scratch. It wires together a timm-based encoder, a Dirichlet prediction head, and the CustomMultiQuestionLoss into a fully self-contained LightningModule.
Most users should not use ZoobotTree directly. To apply Zoobot to your own data, use one of the finetuning classes (FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, or FinetuneableZoobotTree), which load a pretrained encoder and add a task-specific head. ZoobotTree is intended for pretraining new foundation models on large Galaxy Zoo vote-count datasets.

Import

from zoobot.pytorch.estimators.define_model import ZoobotTree

Constructor

ZoobotTree(
    output_dim,
    question_answer_pairs=None,
    dependencies=None,
    architecture_name="convnext_nano",
    channels=3,
    test_time_dropout=False,
    compile_encoder=False,
    timm_kwargs={},
    dropout_rate=0.2,
    learning_rate=1e-3,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    scheduler_params={}
)

Parameters

output_dim
int
required
Total number of output dimensions of the prediction head — equal to the number of answers across the entire decision tree (e.g. len(schema.label_cols)). For GZ DECaLS this is 34.
question_answer_pairs
dict
default:"None"
Dictionary mapping question text to a list of answer suffix strings. For example:
{
    'smooth-or-featured': ['_smooth', '_featured-or-disk', '_artifact'],
    'has-spiral-arms': ['_yes', '_no']
}
Used to construct a schemas.Schema internally and to build the CustomMultiQuestionLoss. See Schemas.
dependencies
dict
default:"None"
Dictionary mapping each question's text to the answer text that triggers it. For example {'has-spiral-arms': 'smooth-or-featured_featured-or-disk'}. Although the signature default is None, this argument is required in practice: an AssertionError is raised if None is passed.
architecture_name
str
default:"\"convnext_nano\""
Name of the timm architecture to use as the encoder backbone. Must be present in timm.list_models(). Passed to get_pytorch_encoder.
channels
int
default:"3"
Number of input image channels. Use 3 for RGB or 1 for greyscale. Passed to get_pytorch_encoder.
test_time_dropout
bool
default:"False"
If True, keeps dropout active at inference time (Monte Carlo dropout). Passed to get_pytorch_dirichlet_head.
compile_encoder
bool
default:"False"
If True, wraps the encoder with torch.compile for potential speed-ups on PyTorch ≥ 2.0. Logs a warning when enabled.
timm_kwargs
dict
default:"{}"
Additional keyword arguments forwarded to timm.create_model when building the encoder (e.g. drop_path_rate=0.2).
dropout_rate
float
default:"0.2"
Dropout probability applied in the Dirichlet head.
learning_rate
float
default:"1e-3"
Learning rate for the AdamW optimizer.
betas
tuple
default:"(0.9, 0.999)"
Beta coefficients for the AdamW optimizer. Matches the PyTorch default.
weight_decay
float
default:"0.01"
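Weight-decay coefficient for the AdamW optimizer. AdamW applies decoupled weight decay rather than adding an L2 penalty to the loss. Matches the PyTorch default.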
scheduler_params
dict
default:"{}"
Optional learning rate scheduler configuration. Pass {'name': 'plateau', 'patience': 5} for ReduceLROnPlateau, or {'cosine_schedule': True, 'warmup_epochs': 5, 'max_cosine_epochs': 100, 'max_learning_rate_reduction_factor': 0.01} for cosine annealing with warm-up. An empty dict (default) disables the scheduler.
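
The relationship between output_dim, question_answer_pairs, and dependencies can be checked before constructing the model. The following is a minimal sketch in plain Python, not part of zoobot. It assumes that each label column is the question text concatenated with an answer suffix (as the dependencies example suggests) and that a question with no parent answer maps to None, following the convention in this page's usage example. The helper names label_columns and check_dependencies are hypothetical.

```python
def label_columns(question_answer_pairs):
    """Build one column name per answer: question text + answer suffix (assumed convention)."""
    return [
        question + suffix
        for question, suffixes in question_answer_pairs.items()
        for suffix in suffixes
    ]


def check_dependencies(question_answer_pairs, dependencies):
    """Raise ValueError if a question lacks an entry or points at an unknown answer."""
    columns = set(label_columns(question_answer_pairs))
    for question in question_answer_pairs:
        if question not in dependencies:
            raise ValueError(f"no dependency entry for question {question!r}")
        parent_answer = dependencies[question]
        # None marks a question that is always asked (no parent answer).
        if parent_answer is not None and parent_answer not in columns:
            raise ValueError(f"{question!r} depends on unknown answer {parent_answer!r}")


question_answer_pairs = {
    'smooth-or-featured': ['_smooth', '_featured-or-disk', '_artifact'],
    'has-spiral-arms': ['_yes', '_no'],
}
dependencies = {
    'smooth-or-featured': None,
    'has-spiral-arms': 'smooth-or-featured_featured-or-disk',
}

check_dependencies(question_answer_pairs, dependencies)
# Number of answers across all questions: the value to pass as output_dim.
output_dim = len(label_columns(question_answer_pairs))
```

Running a check like this first turns a mismatched output_dim or a misspelled dependency into an immediate, readable error rather than a failure inside model construction.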

Key Attributes

After construction, the following attributes are available:
| Attribute | Type | Description |
| --- | --- | --- |
| model.encoder | torch.nn.Module | The timm encoder backbone |
| model.head | torch.nn.Sequential | The Dirichlet prediction head |
| model.schema | schemas.Schema | Decision tree schema built from question_answer_pairs and dependencies |
| model.encoder_dim | int | Output feature dimensionality of the encoder |
| model.loss_func | callable | The CustomMultiQuestionLoss.forward bound method |
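
Because the head is a Dirichlet head with one output per answer, one common way to read its outputs is question by question. As a rough sketch (plain Python, not zoobot's API), the expected vote fraction for each answer is its concentration divided by the sum of concentrations within the same question. The example assumes the output vector is ordered like the answers in question_answer_pairs; the helper name expected_vote_fractions and the concentration values are made up for illustration.

```python
def expected_vote_fractions(concentrations, question_answer_pairs):
    """Split a flat concentration vector by question and normalise each slice.

    For a Dirichlet distribution with concentrations a_1..a_k, the mean of
    component i is a_i / sum(a).
    """
    fractions = {}
    start = 0
    for question, suffixes in question_answer_pairs.items():
        stop = start + len(suffixes)
        question_slice = concentrations[start:stop]
        total = sum(question_slice)
        for suffix, alpha in zip(suffixes, question_slice):
            fractions[question + suffix] = alpha / total
        start = stop
    if start != len(concentrations):
        raise ValueError("concentration vector length does not match the answers")
    return fractions


question_answer_pairs = {
    'smooth-or-featured': ['_smooth', '_featured-or-disk', '_artifact'],
    'has-spiral-arms': ['_yes', '_no'],
}
# Hypothetical head output for one galaxy (5 answers in total).
concentrations = [2.0, 6.0, 2.0, 9.0, 3.0]
fractions = expected_vote_fractions(concentrations, question_answer_pairs)
```

The fractions sum to 1 within each question, not across the whole vector, which is why the slicing by question matters.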

Loading from Checkpoint

Lightning saves all constructor arguments as hyperparameters, so checkpoints can be restored without specifying them again:
from zoobot.pytorch.estimators.define_model import ZoobotTree

# Restore the full model
model = ZoobotTree.load_from_checkpoint('path/to/checkpoint.ckpt')

# Extract just the encoder for downstream use
encoder = model.encoder
To load only the encoder from a checkpoint (e.g. before finetuning), use the convenience function:
from zoobot.pytorch.training.finetune import load_pretrained_zoobot

encoder = load_pretrained_zoobot('path/to/checkpoint.ckpt')

Optimizer and Scheduler

ZoobotTree uses torch.optim.AdamW with the learning_rate, betas, and weight_decay passed to the constructor. Two scheduler strategies are supported via scheduler_params:
  • plateau: ReduceLROnPlateau monitoring validation/loss, with a minimum LR of 1e-6.
  • cosine_schedule: CosineWarmupScheduler from zoobot.pytorch.training.schedulers, with configurable warmup and decay.
If scheduler_params is empty (the default), no scheduler is used.
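
To make the cosine option concrete, here is a sketch of a generic warm-up plus cosine decay curve in plain Python. It is not the CosineWarmupScheduler implementation, whose exact formula is not given on this page. It assumes linear warm-up over warmup_epochs, cosine decay until max_cosine_epochs, and a floor of learning_rate * max_learning_rate_reduction_factor. The function name lr_at_epoch is hypothetical.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-3, warmup_epochs=5, max_cosine_epochs=100,
                max_learning_rate_reduction_factor=0.01):
    """Return an illustrative learning rate for a given epoch (0-indexed)."""
    min_lr = base_lr * max_learning_rate_reduction_factor
    if epoch < warmup_epochs:
        # Linear warm-up: reaches base_lr on the last warm-up epoch.
        return base_lr * (epoch + 1) / warmup_epochs
    if epoch >= max_cosine_epochs:
        # After the cosine phase, hold the floor value.
        return min_lr
    # Cosine decay from base_lr down to min_lr.
    progress = (epoch - warmup_epochs) / (max_cosine_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


# Learning rate for each of 110 epochs under the assumed schedule.
schedule = [lr_at_epoch(epoch) for epoch in range(110)]
```

Plotting or printing a curve like this before a long pretraining run is a cheap way to confirm that the warm-up length and final learning rate are what you intended.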

Typical Usage

ZoobotTree is normally invoked through train_default_zoobot_from_scratch rather than instantiated directly. Below is a minimal example of direct usage:
from zoobot.pytorch.estimators.define_model import ZoobotTree
import lightning as L
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule

question_answer_pairs = {
    'smooth-or-featured': ['_smooth', '_featured-or-disk', '_artifact'],
    'has-spiral-arms': ['_yes', '_no']
}
dependencies = {
    'smooth-or-featured': None,                                      # first question
    'has-spiral-arms': 'smooth-or-featured_featured-or-disk'
}

model = ZoobotTree(
    output_dim=5,  # 3 + 2 answers
    question_answer_pairs=question_answer_pairs,
    dependencies=dependencies,
    architecture_name='convnext_nano',
    channels=3,
)

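# my_datamodule is a placeholder: construct a Lightning datamodule
# (e.g. the CatalogDataModule imported above) for your own catalog first.
trainer = L.Trainer(max_epochs=100, accelerator='gpu', devices=1)
trainer.fit(model, datamodule=my_datamodule)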
For training on a pre-existing Galaxy Zoo dataset, pre-built schemas are available in zoobot.shared.schemas (e.g. decals_dr5_ortho_schema, gz2_ortho_schema, desi_schema). See Schemas for the full list.
See the Training from Scratch guide for a complete worked example.
