Zoobot ships several loss functions to cover the distinct supervision signals it supports. The flagship loss is the Dirichlet-Multinomial loss, designed specifically for the imbalanced vote-count structure of Galaxy Zoo citizen science data. Simpler wrappers cover standard classification and regression finetuning tasks. All wrappers use reduction='none' so that per-example losses can be accumulated via torchmetrics.MeanMetric in both single-device and distributed-training settings.

Dirichlet-Multinomial Loss

Background

Galaxy Zoo asks volunteers a branching decision tree of questions. Because volunteers only answer a question if they chose the relevant answer to the previous question, different questions receive wildly different numbers of votes. A model trained on raw vote fractions would be dominated by well-voted root questions and would largely ignore leaf questions. The Dirichlet-Multinomial loss addresses this by:
  1. Having the model predict Dirichlet concentration parameters — one positive scalar per answer — rather than raw fractions.
  2. Evaluating each question separately using the negative log-likelihood of the observed vote counts under the Dirichlet-Multinomial distribution.
  3. Handling questions with zero votes gracefully (they contribute ~zero loss regardless of the model’s prediction).
ZoobotTree and FinetuneableZoobotTree apply this loss automatically — you do not need to instantiate it yourself for standard workflows.

get_dirichlet_loss_func

from zoobot.pytorch.estimators.define_model import get_dirichlet_loss_func

loss_func = get_dirichlet_loss_func(question_answer_pairs: dict)
Factory function that builds a CustomMultiQuestionLoss configured for the Dirichlet-Multinomial loss and returns its bound forward method as a callable.
question_answer_pairs
dict
required
The same question_answer_pairs dict used to build a Schema. Maps question text to a list of answer suffixes.
Returns: callable with signature (inputs: torch.Tensor, targets: dict) -> torch.Tensor, where:
  • inputs has shape (batch, num_answers): the model's raw Dirichlet concentrations.
  • targets is a dict with answer-column keys and (batch,) integer count tensors as values.
  • the output has shape (batch, num_questions): the per-question loss for each galaxy.

CustomMultiQuestionLoss

from zoobot.pytorch.training.losses import CustomMultiQuestionLoss

loss_module = CustomMultiQuestionLoss(
    question_answer_pairs,
    question_functional_loss,
    careful=False,
    sum_over_questions=False
)
torch.nn.Module that iterates over every question, slices the prediction and target tensors for that question’s answers, applies question_functional_loss to each, and stacks the per-question losses.
question_answer_pairs
dict
required
Maps question text to a list of answer suffix strings. Determines how the flat prediction vector is sliced.
question_functional_loss
callable
required
Loss function applied to each question independently. For the Dirichlet-Multinomial case this is get_dirichlet_neg_log_prob. Must accept (predictions_for_q, targets_for_q) and return a (batch,) tensor.
careful
bool
default:"False"
Deprecated. Formerly masked NaN/Inf predictions. No longer needed. Raises AssertionError if set to True.
sum_over_questions
bool
default:"False"
If True, sums losses across questions and returns shape (batch,). If False (default), returns shape (batch, num_questions). ZoobotTree uses False and sums manually after per-question metric logging.
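The slice-apply-stack pattern described above can be sketched in plain PyTorch. This is a minimal illustration and not Zoobot's implementation: the schema, toy_question_loss, and multi_question_loss_sketch are hypothetical names, and the sketch assumes that targets arrive as a (batch, num_answers) tensor whose columns follow the order of question_answer_pairs.

```python
import torch

# Hypothetical schema for illustration: two questions with 3 and 2 answers.
question_answer_pairs = {
    "smooth-or-featured": ["_smooth", "_featured-or-disk", "_artifact"],
    "has-spiral-arms": ["_yes", "_no"],
}

def toy_question_loss(preds_for_q, targets_for_q):
    # Stand-in for question_functional_loss: (batch, answers) -> (batch,)
    return ((preds_for_q - targets_for_q) ** 2).sum(dim=1)

def multi_question_loss_sketch(preds, targets, qa_pairs, question_loss,
                               sum_over_questions=False):
    per_question = []
    start = 0
    # Assumption: answers are laid out contiguously, in dict order.
    for answers in qa_pairs.values():
        end = start + len(answers)
        per_question.append(
            question_loss(preds[:, start:end], targets[:, start:end])
        )
        start = end
    stacked = torch.stack(per_question, dim=1)  # (batch, num_questions)
    return stacked.sum(dim=1) if sum_over_questions else stacked

preds = torch.ones(4, 5)
targets = torch.zeros(4, 5)
per_q = multi_question_loss_sketch(
    preds, targets, question_answer_pairs, toy_question_loss
)
summed = multi_question_loss_sketch(
    preds, targets, question_answer_pairs, toy_question_loss,
    sum_over_questions=True,
)
print(per_q.shape, summed.shape)
```

Keeping the (batch, num_questions) shape until the last moment is what allows each question's loss to be logged separately before the sum is taken.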

get_dirichlet_neg_log_prob

from zoobot.pytorch.training.losses import get_dirichlet_neg_log_prob

loss = get_dirichlet_neg_log_prob(
    concentrations_for_q: torch.Tensor,   # (batch, answers)
    labels_for_q: torch.Tensor            # (batch, answers) — integer vote counts
) -> torch.Tensor                          # (batch,)
The core per-question loss. Computes the negative log-likelihood of the observed vote counts under the Dirichlet-Multinomial distribution parametrized by concentrations_for_q. This is a manual implementation that removes the Pyro dependency.
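For readers who want to see the quantity being computed, here is a minimal sketch of a Dirichlet-Multinomial negative log-likelihood written with torch.lgamma. It follows the standard textbook formula and is not copied from Zoobot; the function name dirichlet_multinomial_nll is hypothetical, and the library's version may differ in details such as numerical safeguards.

```python
import torch

def dirichlet_multinomial_nll(concentrations, counts):
    """Sketch of the Dirichlet-Multinomial negative log-likelihood.

    concentrations: (batch, answers), strictly positive
    counts:         (batch, answers), non-negative vote counts
    returns:        (batch,)
    """
    counts = counts.to(concentrations.dtype)
    total_count = counts.sum(dim=1)         # n
    total_conc = concentrations.sum(dim=1)  # alpha_0
    # log P = lgamma(n + 1) + lgamma(alpha_0) - lgamma(n + alpha_0)
    #         + sum_k [lgamma(x_k + alpha_k) - lgamma(alpha_k) - lgamma(x_k + 1)]
    log_prob = (
        torch.lgamma(total_count + 1)
        + torch.lgamma(total_conc)
        - torch.lgamma(total_count + total_conc)
        + (
            torch.lgamma(counts + concentrations)
            - torch.lgamma(concentrations)
            - torch.lgamma(counts + 1)
        ).sum(dim=1)
    )
    return -log_prob

concentrations = torch.tensor([[1.0, 1.0], [1.0, 1.0], [3.0, 7.0]])
counts = torch.tensor([[1, 0], [2, 1], [0, 0]])
nll = dirichlet_multinomial_nll(concentrations, counts)
print(nll)
```

In the last row no votes were cast, so every term cancels and the loss is zero whatever the concentrations are. This is the behaviour that lets unanswered questions drop out of the total.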

Cross-Entropy Loss

from zoobot.pytorch.training.finetune import cross_entropy_loss

loss = cross_entropy_loss(
    y_pred: torch.Tensor,
    y: torch.Tensor,
    label_smoothing: float = 0.0,
    weight=None
) -> torch.Tensor
Trivial wrapper around torch.nn.functional.cross_entropy with reduction='none'. Used automatically by FinetuneableZoobotClassifier.
y_pred
torch.Tensor
required
Logit predictions of shape (batch, num_classes).
y
torch.Tensor
required
Integer class targets of shape (batch,). Internally cast to .long().
label_smoothing
float
default:"0.0"
Label-smoothing factor. See torch.nn.functional.cross_entropy docs for details.
weight
torch.Tensor | None
default:"None"
Per-class loss weights. See torch.nn.functional.cross_entropy docs. If provided, it is moved to the same device as y automatically.
Returns: torch.Tensor of shape (batch,), the unreduced per-example cross-entropy loss.
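The effect of reduction='none' can be checked directly against torch.nn.functional.cross_entropy. The sketch below calls the PyTorch function rather than the Zoobot wrapper, so it only illustrates the behaviour the wrapper is described as having; the tensors are made-up values.

```python
import torch
import torch.nn.functional as F

# Two examples, three classes. The second row has uniform logits.
y_pred = torch.tensor([[2.0, 0.0, 0.0],
                       [0.0, 0.0, 0.0]])
y = torch.tensor([0, 2])

# One loss value per example, shape (batch,)
per_example = F.cross_entropy(y_pred, y.long(), reduction="none")

# Averaging the per-example values recovers the default 'mean' reduction.
reduced = F.cross_entropy(y_pred, y.long(), reduction="mean")
print(per_example.shape, per_example.mean(), reduced)
```

Uniform logits over three classes give a loss of log(3) for the second example, which is a quick sanity check that raw logits are the expected input.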

MSE Loss

from zoobot.pytorch.training.finetune import mse_loss

loss = mse_loss(
    y_pred: torch.Tensor,
    y: torch.Tensor
) -> torch.Tensor
Trivial wrapper around torch.nn.functional.mse_loss with reduction='none'. Used by FinetuneableZoobotRegressor when loss='mse' (the default). y is cast to .float() internally.
Returns: torch.Tensor of shape (batch,), the unreduced per-example squared error.

L1 / MAE Loss

from zoobot.pytorch.training.finetune import l1_loss

loss = l1_loss(
    y_pred: torch.Tensor,
    y: torch.Tensor
) -> torch.Tensor
Trivial wrapper around torch.nn.functional.l1_loss with reduction='none'. Used by FinetuneableZoobotRegressor when loss='mae'.
Returns: torch.Tensor of shape (batch,), the unreduced per-example absolute error.
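As a rough illustration of what the two regression wrappers compute, the sketch below calls the underlying PyTorch functions with reduction='none' on made-up values. The explicit .float() cast on the L1 targets is an assumption carried over from the MSE description, not something stated for l1_loss.

```python
import torch
import torch.nn.functional as F

y_pred = torch.tensor([0.5, 2.0, -1.0])  # (batch,) scalar predictions
y = torch.tensor([1, 2, 1])              # integer targets, cast to float below

# Per-example squared error, shape (batch,)
squared_error = F.mse_loss(y_pred, y.float(), reduction="none")

# Per-example absolute error, shape (batch,)
absolute_error = F.l1_loss(y_pred, y.float(), reduction="none")

print(squared_error, absolute_error)
```

The third example shows why the choice matters: an error of 2 contributes 4 under MSE but only 2 under MAE, so MAE is less sensitive to outliers.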

LinearHead

from zoobot.pytorch.training.finetune import LinearHead

head = LinearHead(
    input_dim: int,
    output_dim: int,
    head_dropout_prob: float = 0.5,
    activation=None
)
Small torch.nn.Module used as the prediction head for all FinetuneableZoobot classes. Applies dropout, then a linear layer, then an optional activation.

Parameters

input_dim
int
required
Dimensionality of the encoder output — the input feature size. Typically determined by get_encoder_dim(encoder) or encoder.num_features.
output_dim
int
required
Number of outputs. Use the number of classes for classification, 1 for regression.
head_dropout_prob
float
default:"0.5"
Probability of zeroing each element in the dropout layer.
activation
callable | None
default:"None"
Optional activation applied after the linear layer. For example, pass torch.nn.functional.sigmoid to constrain regression outputs to [0, 1] (used by FinetuneableZoobotRegressor with unit_interval=True). When None, the raw linear output (logits) is returned. This is recommended for classification, because torch.nn.functional.cross_entropy (which cross_entropy_loss wraps) expects unnormalized logits and applies log-softmax internally.

Forward pass

The forward pass runs:
dropout(x) → linear(x) → activation(x)  [if activation is not None]
When output_dim == 1, the output tensor is squeezed to shape (batch,) for compatibility with scalar regression targets.
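The order of operations and the squeeze behaviour can be sketched as a small standalone module. SketchLinearHead is a hypothetical re-implementation written from the description above, not the class shipped in zoobot.pytorch.training.finetune, and it assumes the squeeze is applied to the last dimension only.

```python
import torch
from torch import nn

class SketchLinearHead(nn.Module):
    """Hypothetical stand-in for LinearHead: dropout, linear, activation."""

    def __init__(self, input_dim, output_dim, head_dropout_prob=0.5,
                 activation=None):
        super().__init__()
        self.output_dim = output_dim
        self.dropout = nn.Dropout(p=head_dropout_prob)
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = activation

    def forward(self, x):
        x = self.dropout(x)
        x = self.linear(x)
        if self.activation is not None:
            x = self.activation(x)
        if self.output_dim == 1:
            # (batch, 1) -> (batch,) to match scalar regression targets
            x = x.squeeze(dim=1)
        return x

# Regression-style head: one output, squashed to the unit interval.
head = SketchLinearHead(640, 1, activation=torch.sigmoid).eval()
out = head(torch.randn(8, 640))
print(out.shape)
```

Calling .eval() disables dropout, so repeated forward passes on the same features give the same output.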

Example

from zoobot.pytorch.training.finetune import LinearHead
import torch

head = LinearHead(input_dim=640, output_dim=2, head_dropout_prob=0.5)

features = torch.randn(8, 640)   # batch of 8 encoder outputs
logits = head(features)           # shape: (8, 2)

Design Note: reduction='none'

All loss wrappers (cross_entropy_loss, mse_loss, l1_loss) return unreduced per-example losses. This is intentional: Zoobot aggregates losses using torchmetrics.MeanMetric, which correctly handles distributed training by accumulating numerators and denominators separately across devices before computing the final mean. Using reduction='mean' inside the loss function would make it incompatible with this aggregation pattern.
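The difference between averaging batch means and accumulating sums and counts shows up as soon as batches have unequal sizes. The sketch below uses plain PyTorch with made-up values to mimic the accumulate-then-divide idea; it does not use torchmetrics itself.

```python
import torch

# Per-example losses from two batches (or two devices) of unequal size.
batch_losses = [
    torch.tensor([1.0, 1.0, 1.0, 1.0]),
    torch.tensor([5.0]),
]

# Reducing inside the loss and then averaging gives every batch equal weight.
mean_of_means = torch.stack([b.mean() for b in batch_losses]).mean()

# Accumulating numerator and denominator separately weights every example equally.
total = sum(b.sum() for b in batch_losses)
count = sum(b.numel() for b in batch_losses)
true_mean = total / count

print(mean_of_means.item(), true_mean.item())
```

Here the single example in the second batch dominates the mean of means, while the sum-and-count version recovers the average over all five examples.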
