Zoobot ships several loss functions to cover the distinct supervision signals it supports. The flagship loss is the Dirichlet-Multinomial loss, designed specifically for the imbalanced vote-count structure of Galaxy Zoo citizen science data. Simpler wrappers cover standard classification and regression finetuning tasks. All wrappers use reduction='none' so that per-example losses can be accumulated via torchmetrics.MeanMetric in both single-device and distributed-training settings.
Dirichlet-Multinomial Loss
Background
Galaxy Zoo asks volunteers a branching decision tree of questions. Because volunteers only answer a question if they chose the relevant answer to the previous question, different questions receive wildly different numbers of votes. A model trained on raw vote fractions would be dominated by well-voted root questions and would largely ignore leaf questions. The Dirichlet-Multinomial loss addresses this by:
- Having the model predict Dirichlet concentration parameters (one positive scalar per answer) rather than raw fractions.
- Evaluating each question separately using the negative log-likelihood of the observed vote counts under the Dirichlet-Multinomial distribution.
- Handling questions with zero votes gracefully (they contribute ~zero loss regardless of the model’s prediction).
ZoobotTree and FinetuneableZoobotTree apply this loss automatically — you do not need to instantiate it yourself for standard workflows.
get_dirichlet_loss_func
Returns a CustomMultiQuestionLoss.forward callable configured for the Dirichlet-Multinomial loss.
Parameters
- question_answer_pairs: the same dict used to build a Schema. Maps question text to a list of answer suffixes.
Returns: a callable with signature (inputs: torch.Tensor, targets: dict) -> torch.Tensor, where:
- inputs has shape (batch, num_answers) and holds the model’s raw Dirichlet concentrations.
- targets is a dict with answer-column keys and (batch,) integer count tensors as values.
- The returned tensor has shape (batch, num_questions) and holds the per-question loss for each galaxy.
CustomMultiQuestionLoss
A torch.nn.Module that iterates over every question, slices the prediction and target tensors for that question’s answers, applies question_functional_loss to each slice, and stacks the per-question losses.
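The slice-apply-stack loop described above can be sketched roughly as follows. This is an illustration only, not Zoobot's implementation: plain Python lists stand in for tensors so that it runs without PyTorch, the question and answer names are made up, toy_question_loss is a meaningless placeholder, and the function name multi_question_loss and its sum_over_questions argument are hypothetical. It also assumes that answers are laid out in the prediction vector in the same order as the dict.

```python
# Illustrative sketch only: lists stand in for tensors, all names are hypothetical.

def multi_question_loss(inputs, targets, question_answer_pairs, question_loss,
                        sum_over_questions=False):
    """inputs: batch x num_answers nested list.
    targets: {answer_column: [count for each galaxy]}.
    """
    per_question = []  # one list of per-galaxy losses for each question
    start = 0
    for question, suffixes in question_answer_pairs.items():
        columns = [question + suffix for suffix in suffixes]
        end = start + len(columns)
        # Slice this question's answers out of the flat prediction vector.
        preds_for_q = [row[start:end] for row in inputs]
        # Gather the matching vote counts from the target dict.
        targets_for_q = [[targets[col][i] for col in columns]
                         for i in range(len(inputs))]
        per_question.append(question_loss(preds_for_q, targets_for_q))
        start = end
    # Transpose from num_questions x batch to batch x num_questions.
    stacked = [list(row) for row in zip(*per_question)]
    if sum_over_questions:
        return [sum(row) for row in stacked]
    return stacked


def toy_question_loss(preds_for_q, targets_for_q):
    """Placeholder loss: gap between predicted and observed totals, per galaxy."""
    return [abs(sum(p) - sum(t)) for p, t in zip(preds_for_q, targets_for_q)]


question_answer_pairs = {
    "smooth-or-featured": ["_smooth", "_featured"],
    "has-spiral-arms": ["_yes", "_no"],
}
inputs = [[2.0, 1.0, 1.0, 3.0], [1.0, 1.0, 5.0, 1.0]]  # 2 galaxies, 4 answers
targets = {
    "smooth-or-featured_smooth": [3, 0],
    "smooth-or-featured_featured": [1, 2],
    "has-spiral-arms_yes": [1, 0],
    "has-spiral-arms_no": [0, 0],
}

losses = multi_question_loss(inputs, targets, question_answer_pairs, toy_question_loss)
summed = multi_question_loss(inputs, targets, question_answer_pairs, toy_question_loss,
                             sum_over_questions=True)
```

Here losses has one row per galaxy and one column per question, while summed has one value per galaxy, mirroring the two output shapes the module can return.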
Parameters
- The question-to-answers mapping: maps question text to a list of answer suffix strings. Determines how the flat prediction vector is sliced.
- question_functional_loss: loss function applied to each question independently. For the Dirichlet-Multinomial case this is get_dirichlet_neg_log_prob. Must accept (predictions_for_q, targets_for_q) and return a (batch,) tensor.
- A deprecated flag: formerly masked NaN/Inf predictions. No longer needed. Raises AssertionError if set to True.
- A summing flag: if True, sums losses across questions and returns shape (batch,). If False (default), returns shape (batch, num_questions). ZoobotTree uses False and sums manually after per-question metric logging.
get_dirichlet_neg_log_prob
Computes the negative log-probability of one question’s observed vote counts under the Dirichlet-Multinomial distribution parameterized by concentrations_for_q. This is a manual implementation that removes the Pyro dependency.
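For intuition, the textbook Dirichlet-Multinomial negative log-probability for a single galaxy and a single question can be written with log-gamma terms alone. The sketch below is pure Python and scalar, not batched, and it is not taken from Zoobot's source: the function name is hypothetical, and whether Zoobot includes the multinomial coefficient term is an assumption here.

```python
import math


def dirichlet_multinomial_neg_log_prob(concentrations, counts):
    """Negative log-probability of `counts` under a Dirichlet-Multinomial.

    concentrations: positive floats, one per answer.
    counts: non-negative integer vote counts, one per answer.
    """
    total_votes = sum(counts)
    total_concentration = sum(concentrations)
    # log of the multinomial coefficient n! / prod(k_i!)
    log_coefficient = math.lgamma(total_votes + 1) - sum(
        math.lgamma(k + 1) for k in counts
    )
    # log of B(alpha + k) / B(alpha), expanded into log-gamma terms
    log_beta_ratio = (
        math.lgamma(total_concentration)
        - math.lgamma(total_votes + total_concentration)
        + sum(math.lgamma(k + a) - math.lgamma(a)
              for k, a in zip(counts, concentrations))
    )
    return -(log_coefficient + log_beta_ratio)


# One vote for the first of two answers, with uniform concentrations (1, 1):
# the probability is 1/2, so the loss is log(2).
loss_one_vote = dirichlet_multinomial_neg_log_prob([1.0, 1.0], [1, 0])

# A question nobody answered: every term cancels, whatever the concentrations.
loss_no_votes = dirichlet_multinomial_neg_log_prob([2.5, 0.7], [0, 0])
```

The second call illustrates why unanswered questions add nothing to the total loss: with zero votes, each log-gamma term is matched by an identical term of opposite sign.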
Cross-Entropy Loss
Wraps torch.nn.functional.cross_entropy with reduction='none'. Used automatically by FinetuneableZoobotClassifier.
Parameters
- Predictions: logits of shape (batch, num_classes).
- y: integer class targets of shape (batch,). Internally cast to .long().
- Label-smoothing factor. See torch.nn.functional.cross_entropy docs for details.
- Per-class loss weights. See torch.nn.functional.cross_entropy docs. If provided, they are moved to the same device as y automatically.
Returns: torch.Tensor of shape (batch,), the unreduced per-example cross-entropy loss.
MSE Loss
Wraps torch.nn.functional.mse_loss with reduction='none'. Used by FinetuneableZoobotRegressor when loss='mse' (the default). y is cast to .float() internally.
Returns: torch.Tensor of shape (batch,), the unreduced per-example squared error.
L1 / MAE Loss
Wraps torch.nn.functional.l1_loss with reduction='none'. Used by FinetuneableZoobotRegressor when loss='mae'.
Returns: torch.Tensor of shape (batch,), the unreduced per-example absolute error.
LinearHead
A torch.nn.Module used as the prediction head for all FinetuneableZoobot classes. Applies dropout, then a linear layer, then an optional activation.
Parameters
- Input dimension: dimensionality of the encoder output, i.e. the input feature size. Typically determined by get_encoder_dim(encoder) or encoder.num_features.
- output_dim: number of outputs. Use the number of classes for classification, 1 for regression.
- Dropout probability: probability of zeroing each element in the dropout layer.
- Activation: optional activation applied after the linear layer. For example, pass torch.nn.functional.sigmoid to constrain regression outputs to [0, 1] (used by FinetuneableZoobotRegressor with unit_interval=True). When None, the raw linear output (logits) is returned, which is recommended for classification to maintain gradient stability with cross_entropy_loss.
Forward pass
The forward pass runs dropout, then the linear layer, then the optional activation. When output_dim == 1, the output tensor is squeezed to shape (batch,) for compatibility with scalar regression targets.
Example
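The page's original example is not reproduced here. As a stand-in, the following is a minimal re-implementation sketch of the behaviour described above (dropout, linear layer, optional activation, squeeze when there is a single output). It is not Zoobot's LinearHead: the class name SketchLinearHead and the argument names input_dim, dropout_prob, and activation are assumptions made for illustration, and torch.randn stands in for real encoder features.

```python
import torch
from torch import nn


class SketchLinearHead(nn.Module):
    """Hypothetical stand-in: dropout -> linear -> optional activation."""

    def __init__(self, input_dim, output_dim, dropout_prob=0.5, activation=None):
        super().__init__()
        self.output_dim = output_dim
        self.dropout = nn.Dropout(p=dropout_prob)
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = activation  # e.g. torch.sigmoid, or None for raw logits

    def forward(self, x):
        x = self.linear(self.dropout(x))
        if self.activation is not None:
            x = self.activation(x)
        if self.output_dim == 1:
            x = x.squeeze(-1)  # (batch, 1) -> (batch,)
        return x


features = torch.randn(4, 8)  # pretend encoder output: 4 galaxies, 8 features

# Classification: raw logits, one column per class.
classifier_head = SketchLinearHead(input_dim=8, output_dim=3)
class_logits = classifier_head(features)

# Regression on the unit interval: sigmoid activation, single squeezed output.
regressor_head = SketchLinearHead(input_dim=8, output_dim=1, activation=torch.sigmoid)
unit_interval_preds = regressor_head(features)
```

The classifier head returns a (batch, num_classes) tensor of logits, while the regressor head returns a (batch,) tensor that lines up with scalar regression targets.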
Design Note: reduction='none'
All loss wrappers (cross_entropy_loss, mse_loss, l1_loss) return unreduced per-example losses. This is intentional: Zoobot aggregates losses using torchmetrics.MeanMetric, which correctly handles distributed training by accumulating numerators and denominators separately across devices before computing the final mean. Using reduction='mean' inside the loss function would make it incompatible with this aggregation pattern.
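A toy calculation shows why the distinction matters when devices see different numbers of examples. The sketch below uses plain Python rather than torchmetrics, with made-up loss values and hypothetical function names. It only illustrates the general sum-and-count idea and does not show how MeanMetric is implemented.

```python
def mean_of_device_means(per_device_losses):
    """Each device reduces with 'mean', then the device means are averaged."""
    device_means = [sum(losses) / len(losses) for losses in per_device_losses]
    return sum(device_means) / len(device_means)


def mean_from_sums_and_counts(per_device_losses):
    """Accumulate total loss and example count separately, then divide once."""
    total = sum(sum(losses) for losses in per_device_losses)
    count = sum(len(losses) for losses in per_device_losses)
    return total / count


# Device 0 holds three examples, device 1 holds one (e.g. a ragged final batch).
per_device_losses = [[1.0, 1.0, 1.0], [5.0]]

biased = mean_of_device_means(per_device_losses)      # (1.0 + 5.0) / 2
exact = mean_from_sums_and_counts(per_device_losses)  # 8.0 / 4
```

Averaging the per-device means gives 3.0 because the lone example on the second device is over-weighted, whereas dividing the summed loss by the total count gives the true per-example mean of 2.0. Keeping losses unreduced is what makes the second calculation possible.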