Run batch inference across an entire galaxy catalog in a single call. predict wraps PyTorch Lightning’s Trainer.predict loop, handles the CatalogDataModule setup automatically, and returns a tidy pd.DataFrame with one row per galaxy and one column per requested label.

Import

from zoobot.pytorch.predictions.predict_on_catalog import predict

predict

predict(
    catalog,
    model,
    label_cols,
    inference_transform,
    save_loc=None,
    datamodule_kwargs={},
    trainer_kwargs={}
) -> pd.DataFrame
Use a trained model to make predictions on a catalog of galaxies.

Parameters

catalog
pd.DataFrame
required
Catalog of galaxies to make predictions on. Must include a file_loc column (absolute path to each image file) and an id_str column (unique string identifier per galaxy). The id_str values are echoed in the output DataFrame.
model
L.LightningModule
required
Any trained Zoobot Lightning module. Common choices include ZoobotTree, FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, and FinetuneableZoobotTree. The model’s predict_step is called by the Trainer.
label_cols
List[str]
required
Column names used to label the prediction output. For a classifier trained to predict ['ring'], pass ['ring']. For a full decision tree, pass schema.label_cols. These names do not affect which columns are loaded from catalog — they only name the columns in the returned DataFrame.
inference_transform
torchvision.transforms.v2.Compose
required
Transform pipeline applied to each image before it is passed to the model. Must produce a tensor of the shape the model expects (typically (C, H, W)). Passed as the test_transform argument of CatalogDataModule.
save_loc
str
default:"None"
If provided, the prediction DataFrame is written to this path as a CSV file. The function returns the same DataFrame regardless of whether save_loc is set.
datamodule_kwargs
dict
default:"{}"
Extra keyword arguments forwarded to CatalogDataModule (from the galaxy-datasets package). Common uses include setting batch_size and num_workers.
trainer_kwargs
dict
default:"{}"
Extra keyword arguments forwarded to L.Trainer. Use to configure the accelerator, number of devices, precision, etc.
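
As a rough sketch of preparing the catalog argument, the snippet below builds a DataFrame with the two required columns from a folder of images. The helper name build_catalog, the *.jpg pattern, and the use of the file stem as id_str are illustrative assumptions and not part of Zoobot.

```python
from pathlib import Path

import pandas as pd


def build_catalog(image_dir, pattern="*.jpg"):
    """Hypothetical helper: list images under image_dir as a predict-ready catalog."""
    paths = sorted(Path(image_dir).glob(pattern))
    catalog = pd.DataFrame({
        # absolute path to each image, as required for file_loc
        "file_loc": [str(p.resolve()) for p in paths],
        # assumption: the file stem is a suitable unique identifier
        "id_str": [p.stem for p in paths],
    })
    # fail early if identifiers are not unique
    if catalog["id_str"].duplicated().any():
        raise ValueError("id_str values must be unique")
    return catalog
```

The resulting DataFrame can be passed as catalog. Checking uniqueness up front is a design choice here: it surfaces identifier clashes before any inference time is spent.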

Returns

pd.DataFrame — one row per galaxy. Contains one column for each entry in label_cols (the model’s numeric predictions) plus an id_str column that echoes the identifiers from the input catalog.
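
Because the output carries id_str, predictions can be joined back onto the original catalog. The following is a minimal pandas sketch that uses invented values in place of real model output; the redshift column and the file paths are hypothetical.

```python
import pandas as pd

# stand-in for the input catalog; 'redshift' is a hypothetical extra column
catalog = pd.DataFrame({
    "id_str": ["galaxy_00001", "galaxy_00002"],
    "file_loc": ["/data/galaxy_00001.jpg", "/data/galaxy_00002.jpg"],
    "redshift": [0.03, 0.07],
})

# stand-in for the DataFrame returned by predict (values are invented)
predictions = pd.DataFrame({
    "ring": [0.82, 0.11],
    "id_str": ["galaxy_00001", "galaxy_00002"],
})

# validate='one_to_one' raises if id_str is duplicated on either side
merged = catalog.merge(predictions, on="id_str", how="left", validate="one_to_one")
print(merged[["id_str", "redshift", "ring"]])
```

If the catalog stores identifiers as integers, casting id_str to str on both sides before merging avoids dtype mismatches.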

Example

import pandas as pd
import torch
import torchvision.transforms.v2 as T
from zoobot.pytorch.predictions import predict_on_catalog
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

# Load your finetuned model
model = FinetuneableZoobotClassifier.load_from_checkpoint('results/checkpoints/epoch=10.ckpt')

# Catalog with file_loc and id_str columns
unlabelled_df = pd.read_csv('/path/to/unlabelled_galaxies.csv')

inference_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True)
])

predictions = predict_on_catalog.predict(
    catalog=unlabelled_df,
    model=model,
    label_cols=['ring'],
    inference_transform=inference_transform,
    save_loc='./predictions.csv',
    trainer_kwargs={'accelerator': 'gpu', 'devices': 1}
)
print(predictions.head())
The CSV (and the returned DataFrame) will look like:

ring    id_str
0.82    galaxy_00001
0.11    galaxy_00002
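
A saved predictions file can be reloaded later for downstream selection. The sketch below reuses the two example rows shown above and assumes the CSV has no index column and that a higher ring value means a more ring-like galaxy; the 0.5 threshold is arbitrary.

```python
import io

import pandas as pd

# stand-in for open('./predictions.csv'); rows copied from the example output
csv_text = "ring,id_str\n0.82,galaxy_00001\n0.11,galaxy_00002\n"

# dtype keeps identifiers as strings even if they look numeric
predictions = pd.read_csv(io.StringIO(csv_text), dtype={"id_str": str})

threshold = 0.5  # assumption: an arbitrary cut, tune it for your use case
ring_candidates = predictions[predictions["ring"] > threshold]
print(ring_candidates["id_str"].tolist())  # ['galaxy_00001']
```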

get_trainer

from zoobot.pytorch.training.finetune import get_trainer

get_trainer(
    save_dir,
    file_template="{epoch}",
    save_top_k=1,
    max_epochs=100,
    patience=10,
    devices="auto",
    accelerator="auto",
    logger=None,
    **trainer_kwargs
) -> L.Trainer
Convenience factory that creates an L.Trainer pre-configured for Zoobot finetuning. Call trainer.fit(model, datamodule) on the returned object. get_trainer is also useful for predictions: a trainer created by this function can be reused with trainer.predict(model, datamodule).

Default Callbacks

Callback               Configuration
ModelCheckpoint        Monitors finetuning/val_loss; saves to <save_dir>/checkpoints/; keeps only the top save_top_k checkpoints
EarlyStopping          Monitors finetuning/val_loss in min mode; stops training if no improvement after patience epochs
LearningRateMonitor    Logs the learning rate every epoch; useful when using a scheduler
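
To build intuition for how patience and save_top_k interact, here is a pure-Python mental model of min-mode early stopping and top-k checkpoint retention. It is an illustration only, not Lightning's implementation (the real callbacks have more options, such as a minimum improvement delta), and the function name simulate_callbacks is hypothetical.

```python
def simulate_callbacks(val_losses, patience=10, save_top_k=1):
    """Illustrative only: mimic min-mode early stopping and top-k retention."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    stopped_epoch = None

    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # new best validation loss: reset the patience counter
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                stopped_epoch = epoch
                break

    # number of epochs actually run before stopping
    n_run = len(val_losses) if stopped_epoch is None else stopped_epoch + 1
    # keep the save_top_k epochs with the lowest validation loss
    ranked = sorted(range(n_run), key=lambda e: val_losses[e])
    kept_epochs = sorted(ranked[:save_top_k])
    return stopped_epoch, kept_epochs


losses = [1.00, 0.80, 0.90, 0.85, 0.81, 0.70]
print(simulate_callbacks(losses, patience=3, save_top_k=1))  # (4, [1])
```

In this toy run the loss would have improved again at epoch 5, but training stops at epoch 4 because three consecutive epochs failed to beat 0.80. A larger patience trades longer training for a lower chance of stopping early.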

Parameters

save_dir
str
required
Directory where checkpoints and logs are written. Checkpoints are placed in <save_dir>/checkpoints/.
file_template
str
default:"{epoch}"
Filename template for saved checkpoints. Accepts Lightning format strings. Defaults to "{epoch}".
save_top_k
int
default:"1"
Keep only the top-k best checkpoints (ranked by finetuning/val_loss).
max_epochs
int
default:"100"
Maximum number of training epochs. Training may stop earlier if EarlyStopping triggers.
patience
int
default:"10"
Number of epochs with no improvement in finetuning/val_loss before training is stopped.
devices
str | int
default:"auto"
Number of devices to use (typically number of GPUs). Passed directly to L.Trainer.
accelerator
str
default:"auto"
Which device type to target — typically 'gpu' or 'cpu'. Passed directly to L.Trainer.
logger
L.pytorch.loggers.Logger | None
default:"None"
Optional Lightning logger. Pass a WandbLogger to track training on Weights & Biases.
**trainer_kwargs
dict
Any remaining keyword arguments are forwarded directly to L.Trainer. See the Lightning Trainer docs for the full list of options.

Returns

L.Trainer — a configured Lightning Trainer ready to call .fit(model, datamodule) or .predict(model, datamodule).

Example

from zoobot.pytorch.training.finetune import get_trainer

trainer = get_trainer(
    save_dir='results/my_finetune_run',
    max_epochs=50,
    patience=5,
    accelerator='gpu',
    devices=1
)

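# model and datamodule must already exist, e.g. a FinetuneableZoobotClassifier and a CatalogDataModule
trainer.fit(model, datamodule)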
