Run batch inference across an entire galaxy catalog in a single call.
predict wraps PyTorch Lightning’s Trainer.predict loop, handles the CatalogDataModule setup automatically, and returns a tidy pd.DataFrame with one row per galaxy and one column per requested label.
Import
predict
Parameters
Catalog of galaxies to make predictions on. Must include a file_loc column (absolute path to each image file) and an id_str column (unique string identifier per galaxy). The id_str values are copied into the output DataFrame.
Any trained Zoobot Lightning module. Common choices include ZoobotTree, FinetuneableZoobotClassifier, FinetuneableZoobotRegressor, and FinetuneableZoobotTree. The Trainer calls the model's predict_step on each batch.
Column names used to label the prediction output. For a classifier trained to predict ['ring'], pass ['ring']. For a full decision tree, pass schema.label_cols. These names do not affect which columns are loaded from catalog; they only name the columns in the returned DataFrame.
Transform pipeline applied to each image before it is passed to the model. Must produce a tensor of the shape the model expects (typically (C, H, W)). Passed as the test_transform argument of CatalogDataModule.
If provided, the prediction DataFrame is written to this path as a CSV file. The function returns the same DataFrame regardless of whether save_loc is set.
Extra keyword arguments forwarded to CatalogDataModule (from the galaxy-datasets package). Common uses include setting batch_size and num_workers.
Extra keyword arguments forwarded to L.Trainer. Use them to configure the accelerator, number of devices, precision, and so on.
Returns
pd.DataFrame — one row per galaxy. Contains one column for each entry in label_cols (the model’s numeric predictions) plus an id_str column that echoes the identifiers from the input catalog.
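The layout of the returned table can be sketched without Zoobot: predictions arrive batch by batch, are stacked in catalog order, labelled with label_cols, and paired with id_str, with an optional CSV copy when save_loc is given. This is a minimal stdlib sketch, assuming plain lists in place of tensors and a list of row dicts in place of pd.DataFrame; assemble_predictions is a hypothetical name and this is not Zoobot's implementation.

```python
import csv
from typing import Optional, Sequence


def assemble_predictions(
    batches: Sequence[Sequence[Sequence[float]]],
    id_strs: Sequence[str],
    label_cols: Sequence[str],
    save_loc: Optional[str] = None,
) -> list[dict]:
    """Hypothetical stand-in for the final step of predict (not Zoobot code).

    batches: per-batch predictions, each shaped (batch_size, len(label_cols)),
    in the same order as id_strs.
    """
    # Flatten the batches into one prediction row per galaxy.
    preds = [row for batch in batches for row in batch]
    if len(preds) != len(id_strs):
        raise ValueError("number of predictions does not match number of galaxies")
    rows = []
    for id_str, pred in zip(id_strs, preds):
        if len(pred) != len(label_cols):
            raise ValueError("each prediction needs one value per label column")
        record = dict(zip(label_cols, pred))  # one column per label
        record["id_str"] = id_str  # echo the catalog identifier
        rows.append(record)
    if save_loc is not None:
        # Optional CSV copy; the returned rows are identical either way.
        with open(save_loc, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=[*label_cols, "id_str"])
            writer.writeheader()
            writer.writerows(rows)
    return rows
```

The width check mirrors the note on label_cols: the names only label columns, so their count has to match the number of values the model emits per galaxy.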
Example
| ring | id_str |
|---|---|
| 0.82 | galaxy_00001 |
| 0.11 | galaxy_00002 |
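Because predict relies on the file_loc and id_str columns described under Parameters, a quick pre-flight check on the catalog can catch problems before any images are loaded. This is a minimal sketch, assuming the catalog is passed as a list of row dicts (for a pd.DataFrame, catalog.to_dict("records") produces that shape); check_catalog is a hypothetical helper, not part of Zoobot.

```python
import os
from typing import Iterable, Mapping


def check_catalog(rows: Iterable[Mapping[str, str]]) -> list[str]:
    """Hypothetical pre-flight check (not part of Zoobot).

    Returns a list of problems; an empty list means every row has an
    absolute file_loc and a unique id_str.
    """
    problems = []
    seen_ids = set()
    for i, row in enumerate(rows):
        # Both columns are required by predict.
        for col in ("file_loc", "id_str"):
            if col not in row:
                problems.append(f"row {i}: missing {col!r}")
        file_loc = row.get("file_loc")
        if file_loc is not None and not os.path.isabs(file_loc):
            problems.append(f"row {i}: file_loc is not an absolute path")
        id_str = row.get("id_str")
        if id_str is not None:
            if id_str in seen_ids:
                problems.append(f"row {i}: duplicate id_str {id_str!r}")
            seen_ids.add(id_str)
    return problems
```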
get_trainer
Returns an L.Trainer pre-configured for Zoobot finetuning. Call trainer.fit(model, datamodule) on the returned object.
get_trainer is also useful for prediction: a trainer created with this function can be reused via trainer.predict(model, datamodule).
Default Callbacks
| Callback | Configuration |
|---|---|
| ModelCheckpoint | Monitors finetuning/val_loss; saves to <save_dir>/checkpoints/; keeps only the top save_top_k checkpoints |
| EarlyStopping | Monitors finetuning/val_loss in min mode; stops training if there is no improvement for patience epochs |
| LearningRateMonitor | Logs the learning rate every epoch; useful when using a scheduler |
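The EarlyStopping row amounts to a simple rule: track the best finetuning/val_loss so far, count epochs without a strict improvement, and stop once that count reaches patience. This plain-Python sketch assumes Lightning's min mode with its default min_delta of 0 and ignores the callback's other options (such as check_finite); stopping_epoch is a hypothetical helper operating on a precomputed list of losses.

```python
import math
from typing import Optional, Sequence


def stopping_epoch(val_losses: Sequence[float], patience: int) -> Optional[int]:
    """Return the 0-based epoch after which min-mode early stopping would
    halt training, or None if it never triggers (hypothetical helper)."""
    best = math.inf
    wait = 0  # epochs since the last strict improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:  # min mode: a lower val_loss counts as improvement
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

For example, with patience=3 the losses [1.0, 0.8, 0.9, 0.85, 0.95] stop after epoch 4, three epochs after the best value at epoch 1.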
Parameters
Directory where checkpoints and logs are written. Checkpoints are placed in <save_dir>/checkpoints/.
Filename template for saved checkpoints. Accepts Lightning format strings. Defaults to "{epoch}".
Keep only the top-k best checkpoints (ranked by finetuning/val_loss).
Maximum number of training epochs. Training may stop earlier if EarlyStopping triggers.
Number of epochs with no improvement in finetuning/val_loss before training is stopped.
Number of devices to use (typically the number of GPUs). Passed directly to L.Trainer.
Device type to target, typically 'gpu' or 'cpu'. Passed directly to L.Trainer.
Optional Lightning logger. Pass a WandbLogger to track training on Weights & Biases.
Any remaining keyword arguments are forwarded directly to L.Trainer. See the Lightning Trainer docs for the full list of options.
Returns
L.Trainer — a configured Lightning Trainer ready to call .fit(model, datamodule) or .predict(model, datamodule).
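How save_top_k prunes checkpoints can be sketched as: after each validation, record the new checkpoint's finetuning/val_loss and, if more than k are kept, evict the one with the highest loss. This is an illustration, not Lightning's ModelCheckpoint code; keep_top_k is a hypothetical helper, the file names are illustrative (Lightning typically renders the "{epoch}" template as names like epoch=3), and the sketch assumes k >= 1, whereas Lightning also accepts special values such as -1 to keep every checkpoint.

```python
from typing import Optional


def keep_top_k(kept: dict[str, float], name: str, val_loss: float, k: int) -> Optional[str]:
    """Add a checkpoint and evict the worst one if more than k remain.

    kept maps checkpoint file name -> finetuning/val_loss (lower is better).
    Returns the evicted file name, or None if nothing was evicted.
    Hypothetical helper illustrating save_top_k; assumes k >= 1.
    """
    kept[name] = val_loss
    if len(kept) <= k:
        return None
    worst = max(kept, key=kept.get)  # highest val_loss is the worst checkpoint
    del kept[worst]
    return worst
```

Lightning skips writing a checkpoint that would be evicted immediately rather than saving and deleting it, but the set of checkpoints left on disk matches this sketch.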