FinetuneableZoobotClassifier adapts a pretrained Zoobot encoder for classification tasks such as finding rings, bars, mergers, or any other discrete morphological category. It attaches a linear head to the encoder output and trains with cross-entropy loss. For binary classification (num_classes=2) it automatically uses binary accuracy metrics; for more than two classes it uses micro-averaged multi-class accuracy.
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

Quick Example

finetune_rings.py
import pandas as pd
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule
from zoobot.pytorch.training import finetune

labelled_df = pd.read_csv('/path/to/labelled_galaxies.csv')  # needs 'ring', 'file_loc', 'id_str'

datamodule = CatalogDataModule(
    label_cols=['ring'],
    catalog=labelled_df,
    batch_size=32
)

model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    num_classes=2,
    label_col='ring'  # match the label_cols entry; the default is 'label'
)

trainer = finetune.get_trainer(save_dir='./results', max_epochs=50)
trainer.fit(model, datamodule)

Constructor Parameters

All parameters from FinetuneableZoobotAbstract are accepted. The following are specific to the classifier.
num_classes (int, required)
Number of target classes. Use 2 for binary classification (e.g. ring/not-ring), or more for multi-class problems (e.g. 4 Hubble types).

label_col (str, default: 'label')
Name of the column in the batch dictionary containing integer class labels. This should match a column in your catalog and an entry in the label_cols list passed to CatalogDataModule.

label_smoothing (float, default: 0.0)
Smoothing factor for cross-entropy loss. Redistributes a fraction of probability mass from the target class to other classes. Can improve calibration on noisy labels.

class_weights (arraylike)
Per-class weights for cross-entropy loss. Useful for highly imbalanced datasets. Example: [1.0, 5.0] weights the positive class 5x more heavily for a 5:1 negative:positive imbalance.

run_linear_sanity_check (bool, default: False)
If True, fits an sklearn.linear_model.LogisticRegression on frozen encoder features before finetuning begins. Logs linear evaluation accuracy to WandB/CSV. Useful to verify the pretrained features are informative for your task.
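
The two loss-related parameters above can be illustrated without any deep-learning library. The sketch below is an illustration under stated assumptions, not Zoobot's implementation: inverse_frequency_weights is a hypothetical helper showing one common heuristic for choosing class_weights, and smooth_targets assumes the usual label-smoothing convention (the one PyTorch's cross-entropy uses), where the one-hot target is mixed with a uniform distribution over classes.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Hypothetical helper: weight each class by (largest class count / its own count)."""
    counts = Counter(labels)
    # A class with no examples would divide by zero, so fail loudly instead.
    missing = [c for c in range(num_classes) if counts[c] == 0]
    if missing:
        raise ValueError(f"no examples for classes {missing}")
    largest = max(counts.values())
    return [largest / counts[c] for c in range(num_classes)]


def smooth_targets(label, num_classes, label_smoothing):
    """Mix a one-hot target with a uniform distribution over all classes."""
    uniform = label_smoothing / num_classes
    return [
        (1.0 - label_smoothing) * (1.0 if c == label else 0.0) + uniform
        for c in range(num_classes)
    ]


# 5:1 negative:positive imbalance, as in the class_weights description
labels = [0] * 50 + [1] * 10
weights = inverse_frequency_weights(labels, num_classes=2)  # [1.0, 5.0]

# With label_smoothing=0.1 and two classes, the target class keeps about 0.95
targets = smooth_targets(label=1, num_classes=2, label_smoothing=0.1)
```

A list like weights could then be passed as class_weights. Whether inverse-frequency weighting is the right choice depends on the dataset, so treat it as a starting point rather than a recommendation.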

Metrics Logged

| Metric | Description |
| --- | --- |
| finetuning/train_loss | Cross-entropy loss on training set (per epoch) |
| finetuning/val_loss | Cross-entropy loss on validation set (per epoch) |
| finetuning/train_acc | Accuracy on training set |
| finetuning/val_acc | Accuracy on validation set |
| finetuning/test_acc | Accuracy on test set (when trainer.test() is called) |
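
For reference, micro-averaged accuracy pools every example before dividing, so it is simply the fraction of all predictions that are correct, regardless of class. The sketch below is a plain-Python illustration with made-up labels; the helper name micro_accuracy is hypothetical and this is not the metric code Zoobot runs. It also shows why accuracy alone can look good on imbalanced data.

```python
def micro_accuracy(predicted, actual):
    """Fraction of all examples whose predicted class equals the true class."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not actual:
        raise ValueError("need at least one example")
    correct = sum(p == a for p, a in zip(predicted, actual))
    return correct / len(actual)


# Made-up labels: 8 non-rings (class 0) and 2 rings (class 1)
actual = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

# A model that always predicts "not ring" gets 8 of 10 right
always_negative = [0] * 10
baseline_acc = micro_accuracy(always_negative, actual)

# A model that finds one of the two rings, with no other mistakes, gets 9 of 10
one_ring_found = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
model_acc = micro_accuracy(one_ring_found, actual)
```

Comparing finetuning/val_acc against the always-majority baseline for your own class balance is a quick way to judge whether the reported accuracy is meaningful.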

Prediction

predict_step applies softmax to the head logits and returns class probabilities of shape (batch_size, num_classes).
import torch
import torchvision.transforms.v2 as T
from zoobot.pytorch.predictions import predict_on_catalog

# Load the saved checkpoint
best_model = FinetuneableZoobotClassifier.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path
)

inference_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True)
])

predictions = predict_on_catalog.predict(
    catalog=unlabelled_df,
    model=best_model,
    label_cols=['ring'],
    inference_transform=inference_transform,
    save_loc='./ring_predictions.csv'
)
The output DataFrame has columns ['ring', 'id_str'], where ring contains the softmax probability for class index 0. For binary problems you will likely want the probability for class index 1 (the positive class) instead; because the two softmax outputs sum to 1, that is 1 minus the class-0 probability.
For multi-class problems, the output DataFrame will have one column per class, named after each entry in label_cols.
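
To make these probabilities concrete, the sketch below reproduces the softmax step in plain Python and shows how a positive-class probability and a predicted class could be read from one row of output. The logits are made-up values and the helper is written for illustration only; it is not the code predict_step runs.

```python
import math


def softmax(logits):
    """Convert raw head outputs (logits) into probabilities that sum to 1."""
    # Subtracting the max keeps exp() from overflowing on large logits.
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for one galaxy from a binary (not-ring, ring) head
probs = softmax([0.0, math.log(3.0)])

p_not_ring, p_ring = probs  # class index 0, class index 1
predicted_class = probs.index(max(probs))  # index of the most probable class

# For two classes the probabilities are complementary
complement_gap = abs(p_ring - (1.0 - p_not_ring))
```

For a multi-class head the same function applies unchanged; the predicted class is still the index with the highest probability.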

Reloading After Training

# Load best checkpoint by path
model = FinetuneableZoobotClassifier.load_from_checkpoint(
    'results/checkpoints/epoch=42.ckpt'
)

# Or download a finetuned model from HuggingFace Hub
model = FinetuneableZoobotClassifier.load_from_name(
    'mwalmsley/my-ring-classifier',
    num_classes=2
)
See the Finetuning Guide for a complete end-to-end walkthrough, and Choosing Parameters for guidance on learning rate, layer decay, and training mode.
