Zoobot uses the galaxy-datasets library to handle data loading. This guide covers the main ways to get your galaxy images into a format Zoobot can train on.
Loading from a Catalog of Image Paths
The most common approach is CatalogDataModule, which loads images from a table of file paths and labels.
```python
# galaxy-datasets is a companion package, not part of Zoobot itself
# see github.com/mwalmsley/galaxy-datasets
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule

# train_catalog, val_catalog and test_catalog are pandas DataFrames (see Required Catalog Columns)
datamodule = CatalogDataModule(
    train_catalog=train_catalog,
    val_catalog=val_catalog,
    test_catalog=test_catalog,
    batch_size=batch_size,
    label_cols=['is_cool_galaxy'],
    # ...many more options, see below for augmentations
)
```
Required Catalog Columns
Each catalog (train, val, test, or predict) must be a pandas DataFrame with these columns:
| Column | Description |
|---|---|
| file_loc | Absolute path to the image file (jpg, png, or FITS) |
| id_str | A unique string identifier for each galaxy |
| (your label columns) | Any columns you specify in label_cols |
You may pass any combination of train_catalog, val_catalog, test_catalog, and predict_catalog. For inference-only use cases, set label_cols=None to load without labels.
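As a sketch of preparing these catalogs with pandas, the snippet below builds a catalog from a mapping of filenames to labels, checks that it has the columns listed above, and splits it into train, val and test catalogs. The helpers make_catalog, check_catalog and split_catalog, the images directory, and the 70/10/20 split are hypothetical choices for illustration; they are not part of galaxy-datasets.

```python
import os

import pandas as pd


def make_catalog(image_dir, labels):
    """Hypothetical helper: build a catalog from a {filename: label} dict."""
    rows = []
    for filename, label in labels.items():
        rows.append({
            # file_loc must be an absolute path
            'file_loc': os.path.abspath(os.path.join(image_dir, filename)),
            # use the filename stem as the unique identifier
            'id_str': os.path.splitext(filename)[0],
            'is_cool_galaxy': label,
        })
    return pd.DataFrame(rows)


def check_catalog(catalog, label_cols=None):
    """Hypothetical helper: raise ValueError if required columns are missing or malformed."""
    required = ['file_loc', 'id_str'] + list(label_cols or [])
    missing = [col for col in required if col not in catalog.columns]
    if missing:
        raise ValueError(f'catalog is missing columns: {missing}')
    if not catalog['id_str'].is_unique:
        raise ValueError('id_str values must be unique')
    if not all(os.path.isabs(path) for path in catalog['file_loc']):
        raise ValueError('file_loc values must be absolute paths')


def split_catalog(catalog, val_frac=0.1, test_frac=0.2, seed=42):
    """Shuffle the catalog, then slice it into train/val/test catalogs."""
    shuffled = catalog.sample(frac=1, random_state=seed).reset_index(drop=True)
    n_val = int(len(shuffled) * val_frac)
    n_test = int(len(shuffled) * test_frac)
    val_catalog = shuffled.iloc[:n_val]
    test_catalog = shuffled.iloc[n_val:n_val + n_test]
    train_catalog = shuffled.iloc[n_val + n_test:]
    return train_catalog, val_catalog, test_catalog


labels = {f'galaxy_{i}.jpg': i % 2 for i in range(10)}
catalog = make_catalog('images', labels)
check_catalog(catalog, label_cols=['is_cool_galaxy'])
train_catalog, val_catalog, test_catalog = split_catalog(catalog)
print(len(train_catalog), len(val_catalog), len(test_catalog))  # 7 1 2
```

check_catalog does not test whether the image files exist on disk, which you may want to add before launching a long training run.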
CatalogDataModule dataloaders yield batches as dictionaries:
```
{
    'image': tensor of shape (batch_size, channels, height, width),
    'id_str': tensor of shape (batch_size, 1),
    'is_cool_galaxy': tensor of shape (batch_size, 1)
}
```
PyTorch Lightning’s Trainer automatically calls .train_dataloader(), .val_dataloader(), and so on when you run fit, test, or predict, so you do not need to call them yourself. See the Lightning DataModule docs for more.
Loading from HuggingFace
There is a HuggingFace-native equivalent called HuggingFaceDataModule. Many Galaxy Zoo datasets are available on HuggingFace at huggingface.co/mwalmsley.
```python
from datasets import load_dataset
from galaxy_datasets.pytorch.galaxy_datamodule import HuggingFaceDataModule

ds_dict = load_dataset("mwalmsley/gz2")  # returns dict with 'train' and 'test' keys
datamodule = HuggingFaceDataModule(
    dataset_dict=ds_dict,  # must have 'train' and 'test' keys
    batch_size=32,
    iterable=False,  # set True for IterableDataset (faster streaming, no indexed access)
    # many more options...
)
```
Standard Augmentations
Both CatalogDataModule and HuggingFaceDataModule accept train_transform and test_transform arguments. These are applied to each image before it is passed to the network.
galaxy_datasets.transforms provides a standard set of augmentations:
```python
from PIL import Image
from galaxy_datasets.pytorch.galaxy_datamodule import HuggingFaceDataModule
from galaxy_datasets.transforms import default_view_config, minimal_view_config, get_galaxy_transform

# A config describing which augmentations to apply
train_transform_cfg = default_view_config()
# Convert to a T.Compose object
train_transform = get_galaxy_transform(train_transform_cfg)

# Simpler augmentations for validation/test
test_transform_cfg = minimal_view_config()
test_transform = get_galaxy_transform(test_transform_cfg)

# Test your transform on a single image before plugging it in
im = Image.open('example_galaxy.jpg')  # placeholder path: use one of your own images
transformed = train_transform(im)

# Pass to the datamodule (ds_dict as in the HuggingFace example)
datamodule = HuggingFaceDataModule(
    dataset_dict=ds_dict,
    batch_size=32,
    train_transform=train_transform,  # applied to training batches
    test_transform=test_transform,  # applied to val and test batches
)
```
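The single-image check in the example above can be made stricter. A minimal sketch, assuming torch is installed: check_transform_output is a hypothetical helper, not part of galaxy-datasets, that tests a transformed image against the input expectations listed under Custom Augmentations. It only warns about values outside [0, 1], since end-to-end finetuning can tolerate other ranges.

```python
import torch


def check_transform_output(x):
    """Hypothetical helper: check one transformed image and return its shape.

    Expects a float tensor of shape (channels, height, width), ideally in [0, 1].
    """
    if not isinstance(x, torch.Tensor):
        raise TypeError(f'expected a torch.Tensor, got {type(x).__name__}')
    if x.ndim != 3:
        raise ValueError(f'expected shape (channels, height, width), got {tuple(x.shape)}')
    if not x.is_floating_point():
        raise ValueError(f'expected a float tensor, got {x.dtype}')
    lo, hi = x.min().item(), x.max().item()
    if lo < 0 or hi > 1:
        # not fatal for end-to-end finetuning, but worth knowing about
        print(f'warning: values span [{lo:.3f}, {hi:.3f}], outside [0, 1]')
    return tuple(x.shape)


# Stand-in for train_transform(im): a random 3-channel image
print(check_transform_output(torch.rand(3, 224, 224)))  # (3, 224, 224)
```

With a real transform, check_transform_output(train_transform(im)) returns the (channels, height, width) shape of the transformed image or raises on the first problem it finds.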
Loading FITS Images
JPG images are recommended where possible, for scale and convenience. FITS files are also fully supported if you need the original flux data.
```python
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule
from galaxy_datasets.transforms import default_view_config, get_galaxy_transform

cfg = default_view_config()
cfg.flux_to_jpg_like_dynamic_range = {
    'arcsinh_q': 1.0, 'percentile_min': 0, 'percentile_max': 99.7
}
cfg.pil_to_tensor = False  # FITS files already load as tensors
transform = get_galaxy_transform(cfg)  # ready for datamodule

# Test the transform first (im: a single FITS image, loaded as a tensor)
transformed = transform(im)

# Then use it in the datamodule as normal
datamodule = CatalogDataModule(
    train_catalog=train_catalog,
    val_catalog=val_catalog,
    test_catalog=test_catalog,
    batch_size=batch_size,
    label_cols=['is_cool_galaxy'],
    train_transform=transform,
    test_transform=transform,
)
```
Only single-channel (greyscale) FITS images are supported. When loading FITS, make sure to also set greyscale=True on your Zoobot model so the pretrained encoder accepts single-channel input:

```python
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

model = FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    greyscale=True,  # converts pretrained model to accept single-channel images
    num_classes=2,
)
```
Custom Augmentations
You’re not limited to the built-in transforms. Any torchvision T.Compose object works. To be compatible with Zoobot’s pretrained models, your transforms should produce:
- PyTorch tensors of shape (channels, height, width); the dataloader adds the batch dimension.
- Float values normalized to [0, 1], although in practice Zoobot can handle other ranges when using end-to-end finetuning.
- If presenting raw flux values (e.g. from FITS), apply a dynamic range rescaling (such as np.arcsinh) before normalizing to [0, 1].
- Galaxies should appear large and centered in the image.
Bringing Your Own DataModule
Using galaxy-datasets is entirely optional. Zoobot is designed to work with any PyTorch Lightning LightningDataModule whose dataloaders yield batches of the form:

```python
{'image': images, 'some_label': labels}
```

Advanced users can also pass data directly to Zoobot’s encoder however they like; see Advanced Finetuning for details.