ZoobotTree is the PyTorch Lightning module used to train a Zoobot foundation model from scratch. It wires together a timm-based encoder, a Dirichlet prediction head, and the CustomMultiQuestionLoss into a fully self-contained LightningModule.
Most users should not use ZoobotTree directly. If you want to apply Zoobot to your own data, use one of the finetuning classes (FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, or FinetuneableZoobotTree), which load a pretrained encoder and add a task-specific head. ZoobotTree is intended for pretraining new foundation models on large Galaxy Zoo vote-count datasets.
Import
Constructor
Parameters
Total number of output dimensions of the prediction head, equal to the number of answers across the entire decision tree (e.g. len(schema.label_cols)). For GZ DECaLS this is 34.
Dictionary mapping question text to a list of answer suffix strings. Used to construct a schemas.Schema internally and to build the CustomMultiQuestionLoss. See Schemas.
Dictionary mapping each question’s text to the answer text that triggers it. For example {'has-spiral-arms': 'smooth-or-featured_featured-or-disk'}. Required: an AssertionError is raised if None is passed at runtime.
Name of the timm architecture to use as the encoder backbone. Must be present in timm.list_models(). Passed to get_pytorch_encoder.
Number of input image channels. Use 3 for RGB or 1 for greyscale. Passed to get_pytorch_encoder.
If True, keeps dropout active at inference time (Monte Carlo dropout). Passed to get_pytorch_dirichlet_head.
If True, wraps the encoder with torch.compile for potential speed-ups on PyTorch ≥ 2.0. Logs a warning when enabled.
Additional keyword arguments forwarded to timm.create_model when building the encoder (e.g. drop_path_rate=0.2).
Dropout probability applied in the Dirichlet head.
Learning rate for the AdamW optimizer.
Beta coefficients for the AdamW optimizer. Matches the PyTorch default.
Weight-decay coefficient for the AdamW optimizer (AdamW applies it as decoupled weight decay rather than as an L2 penalty on the loss).
Optional learning rate scheduler configuration. Pass {'name': 'plateau', 'patience': 5} for ReduceLROnPlateau, or {'cosine_schedule': True, 'warmup_epochs': 5, 'max_cosine_epochs': 100, 'max_learning_rate_reduction_factor': 0.01} for cosine annealing with warm-up. An empty dict (default) disables the scheduler.
Key Attributes
After construction, the following attributes are available:

| Attribute | Type | Description |
|---|---|---|
| model.encoder | torch.nn.Module | The timm encoder backbone |
| model.head | torch.nn.Sequential | The Dirichlet prediction head |
| model.schema | schemas.Schema | Decision tree schema built from question_answer_pairs and dependencies |
| model.encoder_dim | int | Output feature dimensionality of the encoder |
| model.loss_func | callable | The CustomMultiQuestionLoss.forward bound method |
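
To make the relationship between question_answer_pairs, dependencies, and the output dimension concrete, here is a minimal pure-Python sketch. The two-question tree and its answer suffixes are hypothetical, chosen to match the dependencies example given in the parameter descriptions. It assumes that each label column is the question text concatenated with an answer suffix and that a root question maps to None; it does not call schemas.Schema itself.

```python
# Hypothetical two-question decision tree (illustrative, not a real survey).
question_answer_pairs = {
    "smooth-or-featured": ["_smooth", "_featured-or-disk", "_artifact"],
    "has-spiral-arms": ["_yes", "_no"],
}

# Each question maps to the answer column that must be chosen before it is asked.
# Assumption: a root question with no prerequisite maps to None.
dependencies = {
    "smooth-or-featured": None,
    "has-spiral-arms": "smooth-or-featured_featured-or-disk",
}

# Assumption: a label column is the question text followed by the answer suffix.
label_cols = [
    question + suffix
    for question, suffixes in question_answer_pairs.items()
    for suffix in suffixes
]

# One output dimension per answer across the whole tree.
output_dim = len(label_cols)

# Sanity check: every prerequisite answer must be an existing label column.
for question, parent_answer in dependencies.items():
    assert parent_answer is None or parent_answer in label_cols, question

print(label_cols)
print(output_dim)
```

A check like the loop at the end can catch a misspelled prerequisite answer before a long pretraining run starts.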
Loading from Checkpoint
Lightning saves all constructor arguments as hyperparameters, so checkpoints can be restored without specifying them again.
Optimizer and Scheduler
ZoobotTree uses torch.optim.AdamW with the learning_rate, betas, and weight_decay passed to the constructor. Two scheduler strategies are supported via scheduler_params:
- plateau: ReduceLROnPlateau monitoring validation/loss, with a minimum LR of 1e-6.
- cosine_schedule: CosineWarmupScheduler from zoobot.pytorch.training.schedulers, with configurable warmup and decay.

If scheduler_params is empty (the default), no scheduler is used.
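
As an illustration of the cosine option, the sketch below computes a per-epoch learning-rate multiplier with linear warm-up followed by cosine decay, using the keys from the scheduler_params example. It is not the CosineWarmupScheduler source. The helper name lr_multiplier is hypothetical, and the sketch assumes that max_cosine_epochs counts from epoch 0 and that max_learning_rate_reduction_factor is the fraction of the base learning rate reached at the end of the decay.

```python
import math

# Values taken from the cosine scheduler_params example on this page.
scheduler_params = {
    "cosine_schedule": True,
    "warmup_epochs": 5,
    "max_cosine_epochs": 100,
    "max_learning_rate_reduction_factor": 0.01,
}


def lr_multiplier(epoch, warmup_epochs, max_cosine_epochs, floor):
    """Hypothetical helper: factor applied to the base learning rate at `epoch`."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Linear warm-up: reaches 1.0 on the last warm-up epoch.
        return (epoch + 1) / warmup_epochs
    # Fraction of the decay phase completed, clamped to at most 1.
    decay_epochs = max(1, max_cosine_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    # Half-cosine from 1.0 down to `floor`.
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return floor + (1.0 - floor) * cosine


base_lr = 1e-3  # illustrative base learning rate, not a recommended value
schedule = [
    base_lr
    * lr_multiplier(
        epoch,
        scheduler_params["warmup_epochs"],
        scheduler_params["max_cosine_epochs"],
        scheduler_params["max_learning_rate_reduction_factor"],
    )
    for epoch in range(scheduler_params["max_cosine_epochs"] + 1)
]
```

Plotting schedule against the epoch index is a quick way to confirm that the warm-up length and final learning rate match what you intended before training.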
Typical Usage
ZoobotTree is normally invoked through train_default_zoobot_from_scratch rather than instantiated directly. Below is a minimal example of direct usage:
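
A hedged sketch of direct usage follows. The import path zoobot.pytorch.estimators.define_model and the keyword names output_dim and architecture_name are assumptions not confirmed on this page, as are the toy schema, the efficientnet_b0 choice, and the optimizer values; check them against your installed Zoobot version. The model is only built when build_model() is called, so the configuration can be inspected without Zoobot installed.

```python
# Toy schema (hypothetical); real runs would use a full Galaxy Zoo schema.
question_answer_pairs = {
    "smooth-or-featured": ["_smooth", "_featured-or-disk", "_artifact"],
    "has-spiral-arms": ["_yes", "_no"],
}
dependencies = {
    "smooth-or-featured": None,
    "has-spiral-arms": "smooth-or-featured_featured-or-disk",
}

# Constructor arguments collected in one place so they are easy to log or tweak.
model_kwargs = dict(
    # Assumed keyword name: one output per answer in the tree.
    output_dim=sum(len(answers) for answers in question_answer_pairs.values()),
    question_answer_pairs=question_answer_pairs,
    dependencies=dependencies,
    # Assumed keyword name: must be a name listed by timm.list_models().
    architecture_name="efficientnet_b0",
    learning_rate=1e-3,  # illustrative value
    weight_decay=0.05,  # illustrative value
    scheduler_params={"name": "plateau", "patience": 5},
)


def build_model():
    """Instantiate ZoobotTree lazily; requires Zoobot to be installed."""
    # Assumed import path; adjust if your Zoobot version differs.
    from zoobot.pytorch.estimators.define_model import ZoobotTree

    return ZoobotTree(**model_kwargs)


# model = build_model()
# The model can then be passed to a Lightning Trainer with your own datamodule (not shown).
```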
For training on a pre-existing Galaxy Zoo dataset, pre-built schemas are available in zoobot.shared.schemas (e.g. decals_dr5_ortho_schema, gz2_ortho_schema, desi_schema). See Schemas for the full list.