This guide walks you through finetuning a pretrained Zoobot encoder to detect ringed galaxies — a classic binary classification task. The same pattern applies to any morphological classification or regression problem. By the end you will have a trained model and predictions saved to disk.
The fastest way to get started is the interactive Google Colab notebook, which provides a free GPU and requires no local setup.
Training on a GPU is strongly recommended: finetuning on a CPU is possible but significantly slower.

Step-by-Step Walkthrough

1. Install Zoobot

Install Zoobot and its PyTorch dependencies with a single pip command:
pip install zoobot[pytorch]
This installs Zoobot along with PyTorch (≥ 2.7.0), torchvision, Lightning (≥ 2.2.5), timm (≥ 1.0.15), and all other required packages.
For Google Colab (where PyTorch is pre-installed), use the lighter variant instead:
pip install zoobot[pytorch_colab]
See the Installation guide for GPU / CUDA setup and source installation.
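To confirm what the install step actually put in your environment, one option is to query package versions through the standard library. This is a minimal sketch: the helper name installed_versions is hypothetical and not part of Zoobot, and the distribution names in the example call are assumed to match the PyPI package names.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Distribution names are assumptions; adjust them if your environment differs.
    names = ["zoobot", "torch", "torchvision", "lightning", "timm"]
    for name, ver in installed_versions(names).items():
        print(f"{name}: {ver or 'not installed'}")
```

Compare the printed versions against the minimums listed above before moving on.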

2. Prepare Your Catalog

Zoobot reads galaxy images from a pandas DataFrame (or a CSV file loaded into one). Your DataFrame must contain at least these columns:
| Column | Type | Description |
| --- | --- | --- |
| id_str | str | Unique string identifier for each galaxy |
| file_loc | str | Absolute path to the image file (.jpg, .png, or .fits) |
| ring (or any label name) | int / float | Your label, e.g. 0 = not a ring, 1 = ring |
Example CSV structure:
id_str,file_loc,ring
galaxy_00001,/data/images/galaxy_00001.jpg,1
galaxy_00002,/data/images/galaxy_00002.jpg,0
galaxy_00003,/data/images/galaxy_00003.fits,1
For regression tasks, the label column should contain continuous float values. For vote-count tasks (FinetuneableZoobotTree), you will need one column per answer in your decision tree schema.
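Catalog mistakes (a missing column, a relative path, a non-numeric label) tend to surface only once training starts, so it can help to check the CSV up front. The sketch below uses only the standard library; check_catalog is a hypothetical helper, not a Zoobot function, and the rules it enforces are simply the column requirements described above.

```python
import csv
import os


def check_catalog(csv_path, label_cols=("ring",)):
    """Return a list of problems found in a catalog CSV (empty if none)."""
    required = ("id_str", "file_loc", *label_cols)
    problems = []
    seen_ids = set()
    with open(csv_path, newline="") as f:
        reader = csv.DictReader(f)
        missing = [c for c in required if c not in (reader.fieldnames or [])]
        if missing:
            # Row-level checks are meaningless without the expected columns.
            return [f"missing columns: {missing}"]
        for line_no, row in enumerate(reader, start=2):  # line 1 is the header
            galaxy_id = row["id_str"] or ""
            file_loc = row["file_loc"] or ""
            if galaxy_id in seen_ids:
                problems.append(f"line {line_no}: duplicate id_str {galaxy_id!r}")
            seen_ids.add(galaxy_id)
            if not os.path.isabs(file_loc):
                problems.append(f"line {line_no}: file_loc is not an absolute path")
            elif not os.path.isfile(file_loc):
                problems.append(f"line {line_no}: no file at {file_loc}")
            for col in label_cols:
                try:
                    float(row[col] or "")
                except ValueError:
                    problems.append(f"line {line_no}: {col} is not numeric")
    return problems
```

An empty list means the catalog passed every check. Pass a different label_cols tuple for regression or multi-column tasks.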

3. Load a Pretrained Model

Load a pretrained Zoobot encoder directly from HuggingFace Hub. The encoder weights are downloaded automatically and cached locally:
from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    num_classes=2
)
  • name — the HuggingFace Hub identifier for the pretrained encoder. See Pretrained Models for all available options.
  • num_classes=2 — binary classification (ring / not ring). Set to the number of classes in your problem.
For a regression task, use FinetuneableZoobotRegressor instead:
model = finetune.FinetuneableZoobotRegressor(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    label_col='sersic_index'
)

4. Create a Data Module

Zoobot uses CatalogDataModule from the companion galaxy-datasets package to handle image loading, augmentation, and batching:
import pandas as pd
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule

labelled_df = pd.read_csv('/your/path/labelled_galaxies.csv')

datamodule = CatalogDataModule(
    label_cols=['ring'],
    catalog=labelled_df,
    batch_size=32
    # Default augmentations (random flips, crops, colour jitter) are applied automatically
)
  • label_cols — list of column names containing your labels. Must match the label column(s) expected by your model.
  • catalog — the full DataFrame; CatalogDataModule handles the train/validation split automatically.
  • batch_size — number of images per batch. Reduce if you run out of GPU memory.
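Because the data module splits the catalog into training and validation sets for you, a separate test set has to be held out by hand if you want one. The sketch below shows one reproducible way to do that before building the data module. It does not describe how CatalogDataModule splits internally; the function name, the 20% fraction, and the seed are illustrative assumptions.

```python
import random


def hold_out_test_set(ids, test_fraction=0.2, seed=42):
    """Split galaxy identifiers into (train_val_ids, test_ids), reproducibly."""
    if not 0.0 < test_fraction < 1.0:
        raise ValueError("test_fraction must be between 0 and 1")
    shuffled = sorted(ids)  # sort first so the result ignores input order
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, round(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]


ids = [f"galaxy_{i:05d}" for i in range(1, 11)]
train_val_ids, test_ids = hold_out_test_set(ids)
print(len(train_val_ids), len(test_ids))  # prints: 8 2
```

With pandas, the two lists can then select rows, for example labelled_df[labelled_df['id_str'].isin(train_val_ids)], and only that subset is passed as catalog.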

5. Finetune the Model

Create a PyTorch Lightning trainer and start finetuning. The trainer saves checkpoints and stops early if validation loss stops improving:
trainer = finetune.get_trainer(save_dir='./results')
trainer.fit(model, datamodule)
get_trainer configures sensible defaults out of the box:
  • Early stopping — stops after 10 epochs with no improvement in validation loss (configurable via patience)
  • Model checkpointing — saves the best checkpoint to ./results/checkpoints/
  • Learning-rate monitoring — logs the LR per epoch
  • Auto device selection — automatically uses GPU if available
You can override these defaults, and pass any additional Lightning Trainer arguments, as keyword arguments to get_trainer:
trainer = finetune.get_trainer(
    save_dir='./results',
    max_epochs=50,
    patience=5
)
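To make the patience setting concrete, here is a pure-Python sketch of patience-based early stopping. It illustrates the idea only and is not Lightning's implementation, which has further options such as a minimum improvement threshold; the function name and the loss values are hypothetical.

```python
def epochs_run_with_patience(val_losses, patience=10):
    """Return the epoch at which training stops for a given validation-loss history."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best loss: reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    # Patience never ran out, so every epoch in the history was run.
    return len(val_losses)


# Illustrative loss values, not real training output.
losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.63]
print(epochs_run_with_patience(losses, patience=3))  # prints: 6
```

The best loss (0.6) occurs at epoch 3, and the three following epochs fail to beat it, so training stops at epoch 6. A smaller patience saves compute but risks stopping during a temporary plateau.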

6. Make Predictions on New Galaxies

After training, run the finetuned model on an unlabelled catalog to generate predictions:
import torch
import pandas as pd
import torchvision.transforms.v2 as T
from zoobot.pytorch.predictions import predict_on_catalog

unlabelled_df = pd.read_csv('/your/path/unlabelled_galaxies.csv')

inference_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True)
])

predictions = predict_on_catalog.predict(
    unlabelled_df,
    model,
    label_cols=['ring'],
    inference_transform=inference_transform,
    save_loc='./predictions.csv'
)
  • label_cols — used to name the output columns in the saved CSV.
  • inference_transform — deterministic (no augmentation) transform pipeline applied to each image before passing it to the model.
  • save_loc — path where the predictions CSV will be written. Each row corresponds to a galaxy in unlabelled_df, with softmax probabilities for each class.
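Class probabilities usually need one more step before they are useful: choosing a threshold and ranking candidates. The sketch below assumes you have already read the ring probability for each galaxy into a plain dict keyed by id_str. The column names of the saved CSV are not specified here, so that loading step is left out; select_rings and the probability values are hypothetical.

```python
def select_rings(ring_probabilities, threshold=0.5):
    """Return ids whose ring probability meets the threshold, most confident first."""
    chosen = [(gid, p) for gid, p in ring_probabilities.items() if p >= threshold]
    chosen.sort(key=lambda item: item[1], reverse=True)
    return [gid for gid, _ in chosen]


# Hypothetical probabilities for illustration only; not real model output.
example = {"galaxy_00004": 0.91, "galaxy_00005": 0.12, "galaxy_00006": 0.58}
print(select_rings(example, threshold=0.5))
# prints: ['galaxy_00004', 'galaxy_00006']
```

Raising the threshold gives a purer but smaller sample of ring candidates; lowering it gives a more complete but more contaminated one.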

Complete Script

Here is the full end-to-end finetuning script for reference:
import torch
import pandas as pd
import torchvision.transforms.v2 as T
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule
from zoobot.pytorch.training import finetune
from zoobot.pytorch.predictions import predict_on_catalog

# --- 1. Load labelled data ---
labelled_df = pd.read_csv('/your/path/labelled_galaxies.csv')

# --- 2. Build data module ---
datamodule = CatalogDataModule(
    label_cols=['ring'],
    catalog=labelled_df,
    batch_size=32
)

# --- 3. Load pretrained model ---
model = finetune.FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    num_classes=2
)

# --- 4. Finetune ---
trainer = finetune.get_trainer(save_dir='./results')
trainer.fit(model, datamodule)

# --- 5. Predict on unlabelled data ---
unlabelled_df = pd.read_csv('/your/path/unlabelled_galaxies.csv')

inference_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True)
])

predict_on_catalog.predict(
    unlabelled_df,
    model,
    label_cols=['ring'],
    inference_transform=inference_transform,
    save_loc='./predictions.csv'
)

Next Steps

Finetuning Guide

Learn about training modes, learning-rate schedules, class weights, and advanced finetuning strategies.

Choosing Parameters

Guidance on selecting the right encoder architecture, batch size, and learning rate for your dataset size.

Pretrained Models

Full list of available encoder architectures and their HuggingFace Hub names.

Loading Data

How to structure your catalog, handle FITS files, and use custom data augmentations.
