
define_model.py contains the fundamental building blocks used by all Zoobot models: a timm-based encoder factory, a Dirichlet head factory, the ZoobotTree LightningModule, and the GenericLightningModule base class. Most users interact with the higher-level FinetuneableZoobot classes rather than these functions directly, but the factories are useful when integrating Zoobot encoders into custom pipelines.

Import

```python
from zoobot.pytorch.estimators.define_model import (
    get_pytorch_encoder,
    get_pytorch_dirichlet_head,
    get_encoder_dim,
)
```

get_pytorch_encoder

```python
get_pytorch_encoder(
    architecture_name="convnext_nano",
    channels=3,
    **timm_kwargs
) -> torch.nn.Module
```
Create a trainable timm model to use as an image encoder. This is a thin wrapper around timm.create_model that sets num_classes=0, so the model returns a globally pooled feature vector rather than class logits.

Parameters

architecture_name (str, default: "convnext_nano")
Name of the timm architecture to instantiate. Must be present in timm.list_models(). Common choices in Zoobot include 'convnext_nano', 'efficientnet_b0', and 'resnet50'. The legacy alias 'efficientnet' is still accepted but deprecated; pass 'efficientnet_b0' explicitly instead.

channels (int, default: 3)
Number of input channels: 3 for RGB images, 1 for greyscale. Passed to timm as in_chans.

**timm_kwargs (dict)
Additional keyword arguments forwarded directly to timm.create_model, for example drop_path_rate=0.2 (stochastic depth, supported by EfficientNet, ConvNeXt, and many other timm architectures), pretrained=True (ImageNet weights), or any other model-level timm argument.

Returns

torch.nn.Module — a timm model configured as a feature extractor. The forward pass returns a (batch, encoder_dim) feature tensor.

Example

```python
import torch

from zoobot.pytorch.estimators.define_model import get_pytorch_encoder

encoder = get_pytorch_encoder(
    architecture_name='convnext_nano',
    channels=3
)

# Pass a batch of images; get back feature vectors
x = torch.randn(4, 3, 224, 224)
features = encoder(x)  # shape: (4, encoder_dim)
print(features.shape)
```

get_pytorch_dirichlet_head

```python
get_pytorch_dirichlet_head(
    encoder_dim: int,
    output_dim: int,
    test_time_dropout: bool,
    dropout_rate: float
) -> torch.nn.Sequential
```
Build the prediction head used for Galaxy Zoo decision-tree outputs. The head predicts Dirichlet concentration parameters, one per answer in the decision tree. The returned Sequential module contains:
  1. A Dropout layer (or PermaDropout when test_time_dropout=True) with probability dropout_rate.
  2. A custom linear layer that maps the encoder output to output_dim Dirichlet concentrations (custom_top_dirichlet).
This head is also used when finetuning on a new decision tree via FinetuneableZoobotTree.

Parameters

encoder_dim (int, required)
Dimensionality of the encoder output, i.e. the input size expected by this head. Obtain it via get_encoder_dim(encoder).

output_dim (int, required)
Number of output dimensions. Must equal the total number of answers in the decision tree (e.g. 34 for GZ DECaLS, or len(schema.label_cols) for a custom schema).

test_time_dropout (bool, required)
If True, uses PermaDropout so that dropout remains active at inference time, enabling Monte Carlo dropout uncertainty estimates. Set to False for deterministic predictions.

dropout_rate (float, required)
Probability of an element being zeroed during dropout, as in torch.nn.Dropout.

Returns

torch.nn.Sequential — a PyTorch module expecting a (batch, encoder_dim) tensor and returning (batch, output_dim) Dirichlet concentrations.
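
The two-layer structure listed above can be sketched in plain PyTorch. AlwaysOnDropout and PositiveConcentrations are hypothetical stand-ins for PermaDropout and custom_top_dirichlet; in particular, the 1 + softplus activation is an assumption that merely guarantees positive concentrations and is not necessarily what Zoobot uses. The usage part contrasts Monte Carlo dropout (test_time_dropout=True) with ordinary dropout, which eval() switches off.

```python
import torch
import torch.nn.functional as F
from torch import nn


class AlwaysOnDropout(nn.Dropout):
    """Hypothetical stand-in for PermaDropout: stays active even after .eval()."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # training=True forces dropout regardless of the module's train/eval mode
        return F.dropout(x, p=self.p, training=True)


class PositiveConcentrations(nn.Module):
    """Assumption: 1 + softplus keeps every concentration >= 1 (Zoobot may differ)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 1.0 + F.softplus(x)


def make_dirichlet_head(encoder_dim: int, output_dim: int,
                        test_time_dropout: bool, dropout_rate: float) -> nn.Sequential:
    dropout_cls = AlwaysOnDropout if test_time_dropout else nn.Dropout
    return nn.Sequential(
        dropout_cls(dropout_rate),           # 1. dropout (always on for MC dropout)
        nn.Linear(encoder_dim, output_dim),  # 2. one output per answer...
        PositiveConcentrations(),            #    ...mapped to valid concentrations
    )


torch.manual_seed(0)
features = torch.randn(4, 640)  # illustrative encoder_dim of 640

# Monte Carlo dropout: repeated passes in eval mode give different samples.
mc_head = make_dirichlet_head(640, 34, test_time_dropout=True, dropout_rate=0.2)
mc_head.eval()
with torch.no_grad():
    samples = torch.stack([mc_head(features) for _ in range(5)])  # (5, 4, 34)

# Ordinary dropout is disabled by eval(), so predictions are deterministic.
det_head = make_dirichlet_head(640, 34, test_time_dropout=False, dropout_rate=0.2)
det_head.eval()
with torch.no_grad():
    deterministic = torch.equal(det_head(features), det_head(features))

print(samples.shape, samples.std(dim=0).mean().item(), deterministic)
```

The spread of the stacked samples along the first dimension is what Monte Carlo dropout uses as an uncertainty signal; with the deterministic head that spread would be zero.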

get_encoder_dim

```python
get_encoder_dim(encoder, channels=3) -> int
```
Utility that determines the output feature dimension of an encoder by running a small dummy forward pass. Handles the edge case where the encoder expects a single-channel input by automatically retrying with channels=1.

Parameters

encoder (torch.nn.Module, required)
The encoder whose output dimension you want to discover.

channels (int, default: 3)
Number of input channels to use in the test forward pass. Falls back to 1 automatically if a channel-mismatch RuntimeError is raised.

Returns

int — the last dimension of the encoder’s output tensor (i.e. the feature vector length).
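
The dummy-forward-pass approach described above can be sketched as follows. infer_encoder_dim, its image_size argument, and the batch of two 224x224 images are illustrative assumptions, as is detecting a channel mismatch by looking for "channels" in the RuntimeError message; this is not Zoobot's implementation.

```python
import torch
from torch import nn


def infer_encoder_dim(encoder: nn.Module, channels: int = 3, image_size: int = 224) -> int:
    """Hypothetical sketch of a dimension probe with a greyscale fallback."""

    def probe(n_channels: int) -> int:
        dummy = torch.randn(2, n_channels, image_size, image_size)  # (batch, channels, H, W)
        with torch.no_grad():
            return encoder(dummy).shape[-1]  # feature-vector length

    try:
        return probe(channels)
    except RuntimeError as err:
        # Assumption: a channel mismatch raises a RuntimeError mentioning "channels".
        if "channels" not in str(err):
            raise
        return probe(1)


# Tiny stand-in encoders: one expects RGB input, the other greyscale.
rgb_encoder = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten())
grey_encoder = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten())

print(infer_encoder_dim(rgb_encoder))   # 16
print(infer_encoder_dim(grey_encoder))  # 8, via the channels=1 retry
```

Running the probe under torch.no_grad() avoids building an autograd graph for a throwaway input.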

load_pretrained_zoobot

```python
from zoobot.pytorch.training.finetune import load_pretrained_zoobot

load_pretrained_zoobot(checkpoint_loc: str) -> torch.nn.Module
```
Load only the encoder from a saved ZoobotTree or FinetuneableZoobotTree Lightning checkpoint. The rest of the module (head, loss, optimizer state) is discarded. Internally, the function first tries to load the checkpoint as a ZoobotTree; if that raises a TypeError it falls back to FinetuneableZoobotTree. On CPU-only machines the function automatically sets map_location so GPU-trained checkpoints load correctly.
Most users do not need to call this directly. FinetuneableZoobotAbstract.__init__ calls it automatically when you pass zoobot_checkpoint_loc=....

Parameters

checkpoint_loc (str, required)
Path to the Lightning .ckpt file. The checkpoint must have been saved from a module that exposes an .encoder attribute (i.e. ZoobotTree, FinetuneableZoobotClassifier, or FinetuneableZoobotTree).

Returns

torch.nn.Module — the pretrained timm encoder extracted from the checkpoint.

Example

```python
import torch

from zoobot.pytorch.training.finetune import load_pretrained_zoobot

encoder = load_pretrained_zoobot('path/to/zoobot_pretrained.ckpt')

# Use as a frozen feature extractor
encoder.eval()
with torch.no_grad():
    features = encoder(torch.randn(1, 3, 224, 224))
```
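
The loading strategy described above (choose map_location, try one class, fall back to another on TypeError, keep only .encoder) can be sketched with any classes that expose Lightning's load_from_checkpoint classmethod. load_encoder_with_fallback and the two stand-in classes are hypothetical and exist only so the control flow runs without a real checkpoint; catching every TypeError is a simplification, and the real function may be more selective.

```python
import torch
from torch import nn


def load_encoder_with_fallback(checkpoint_loc: str, primary_cls, fallback_cls) -> nn.Module:
    """Hypothetical sketch of the loading logic described above (not Zoobot's source)."""
    # CPU-only machine: remap tensors saved on a GPU to the CPU.
    map_location = None if torch.cuda.is_available() else "cpu"
    try:
        module = primary_cls.load_from_checkpoint(checkpoint_loc, map_location=map_location)
    except TypeError:
        # e.g. the saved hyperparameters do not fit primary_cls.__init__
        module = fallback_cls.load_from_checkpoint(checkpoint_loc, map_location=map_location)
    return module.encoder  # head, loss and optimizer state are discarded


# Stand-in classes so the control flow can run without a real checkpoint.
class IncompatibleModule:
    @classmethod
    def load_from_checkpoint(cls, path, map_location=None):
        raise TypeError("__init__() missing required positional arguments")


class CompatibleModule:
    def __init__(self):
        self.encoder = nn.Identity()

    @classmethod
    def load_from_checkpoint(cls, path, map_location=None):
        return cls()


encoder = load_encoder_with_fallback("unused.ckpt", IncompatibleModule, CompatibleModule)
print(type(encoder).__name__)  # Identity
```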

GenericLightningModule

GenericLightningModule is the abstract base class shared by all Zoobot LightningModules. It defines the common training loop structure (training_step, validation_step, test_step, predict_step) and metric-logging utilities. You should not instantiate it directly — use ZoobotTree or one of the FinetuneableZoobot subclasses.
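
The shared training-loop structure can be illustrated without a Lightning dependency. In this hypothetical sketch, SharedStepModule, compute_loss, and TinyRegressor are illustrative names rather than Zoobot's API: the base class routes training, validation, and test batches through one shared step, and a concrete subclass supplies only forward() and the loss.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SharedStepModule(nn.Module):
    """Hypothetical, Lightning-free sketch of a base class with a shared step."""

    def compute_loss(self, predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("each subclass defines its own loss")

    def _shared_step(self, batch):
        images, labels = batch
        return self.compute_loss(self(images), labels)

    # Every split reuses the same logic; a LightningModule would also log metrics here.
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch)

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch)


class TinyRegressor(SharedStepModule):
    """Concrete subclass: supplies forward() and compute_loss() only."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def compute_loss(self, predictions, labels):
        return F.mse_loss(predictions, labels)


torch.manual_seed(0)
batch = (torch.randn(8, 4), torch.randn(8, 1))
loss = TinyRegressor().training_step(batch, batch_idx=0)
print(loss.item())
```
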
Most users interact with FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, or FinetuneableZoobotTree rather than the factory functions on this page. The factories described here are most relevant when building custom architectures or integrating a Zoobot encoder into an external pipeline.
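
A common way to integrate a pretrained encoder into an external pipeline is to freeze it and train a new head on top. The sketch below is a hypothetical illustration of that pattern, not how FinetuneableZoobotClassifier is implemented. In practice the encoder would come from load_pretrained_zoobot or get_pytorch_encoder, and encoder_dim from get_encoder_dim; here a tiny convolutional stand-in is used so the code runs without Zoobot or a checkpoint.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FrozenEncoderClassifier(nn.Module):
    """Hypothetical pipeline: a frozen pretrained encoder plus a new trainable linear head."""

    def __init__(self, encoder: nn.Module, encoder_dim: int, num_classes: int):
        super().__init__()
        self.encoder = encoder
        for param in self.encoder.parameters():
            param.requires_grad = False  # only the head will be trained
        self.encoder.eval()
        self.head = nn.Linear(encoder_dim, num_classes)

    def train(self, mode: bool = True):
        super().train(mode)
        self.encoder.eval()  # keep encoder BatchNorm/dropout in inference behaviour
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            features = self.encoder(x)  # (batch, encoder_dim)
        return self.head(features)     # (batch, num_classes)


# Stand-in for a Zoobot encoder so the example runs without a checkpoint.
stand_in = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.AdaptiveAvgPool2d(1), nn.Flatten())
model = FrozenEncoderClassifier(stand_in, encoder_dim=8, num_classes=2)
model.train()

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
logits = model(torch.randn(4, 3, 64, 64))
loss = F.cross_entropy(logits, torch.tensor([0, 1, 0, 1]))
loss.backward()
optimizer.step()
print(logits.shape)
```

Overriding train() keeps the frozen encoder in evaluation mode even when a training loop calls model.train(), so layers such as BatchNorm or dropout inside the encoder behave as they would at inference time.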
