Zoobot uses the galaxy-datasets library to handle data loading. This guide covers the main ways to get your galaxy images into a format Zoobot can train on.

Loading from a Catalog of Image Paths

The most common approach is CatalogDataModule, which loads images from a table of file paths and labels.
# galaxy-datasets is a companion package, not part of Zoobot itself
# see github.com/mwalmsley/galaxy-datasets
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule

datamodule = CatalogDataModule(
    train_catalog=train_catalog,
    val_catalog=val_catalog,
    test_catalog=test_catalog,
    batch_size=batch_size,
    label_cols=['is_cool_galaxy']
    # ...many more options, see below for augmentations
)
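CatalogDataModule takes separate train, val and test catalogs. If you start from one table of galaxies, you first need to split it. The sketch below does a reproducible random split with pandas. The helper split_catalog, the example paths under /data/images/ and the default split fractions are hypothetical and are not part of galaxy-datasets.

```python
import pandas as pd


def split_catalog(catalog: pd.DataFrame, val_frac: float = 0.1, test_frac: float = 0.2, seed: int = 42):
    """Randomly split one catalog into disjoint train/val/test catalogs (hypothetical helper)."""
    if val_frac < 0 or test_frac < 0 or val_frac + test_frac >= 1:
        raise ValueError("need val_frac, test_frac >= 0 and val_frac + test_frac < 1")
    # Shuffle once with a fixed seed so the split is reproducible
    shuffled = catalog.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    n_test = round(len(shuffled) * test_frac)
    n_val = round(len(shuffled) * val_frac)
    test_catalog = shuffled.iloc[:n_test]
    val_catalog = shuffled.iloc[n_test:n_test + n_val]
    train_catalog = shuffled.iloc[n_test + n_val:]
    return train_catalog, val_catalog, test_catalog


# Hypothetical example catalog: 100 galaxies with one binary label
catalog = pd.DataFrame({
    "file_loc": [f"/data/images/galaxy_{i}.jpg" for i in range(100)],
    "id_str": [f"galaxy_{i}" for i in range(100)],
    "is_cool_galaxy": [i % 2 for i in range(100)],
})
train_catalog, val_catalog, test_catalog = split_catalog(catalog)
```

A purely random split ignores label balance. If one class is rare, a stratified split may be a better choice.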

Required Catalog Columns

Each catalog (train, val, test, or predict) must be a pandas DataFrame with these columns:
Column | Description
file_loc | Absolute path to the image file (jpg, png, or FITS)
id_str | A unique string identifier for each galaxy
(your label columns) | Any columns you specify in label_cols
You may pass any combination of train_catalog, val_catalog, test_catalog, and predict_catalog. For inference-only use cases, set label_cols=None to load without labels.
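Catching a malformed catalog before training starts saves a failed run. The sketch below checks the requirements listed above: required columns are present, id_str values are unique, file_loc paths are absolute, and, optionally, that every file exists. check_catalog is a hypothetical helper, not a galaxy-datasets function.

```python
import os

import pandas as pd


def check_catalog(catalog: pd.DataFrame, label_cols=None, check_files: bool = False) -> None:
    """Raise if a catalog does not meet the column requirements (hypothetical helper)."""
    required = ["file_loc", "id_str"] + list(label_cols or [])
    missing = [col for col in required if col not in catalog.columns]
    if missing:
        raise ValueError(f"catalog is missing columns: {missing}")
    if catalog["id_str"].duplicated().any():
        raise ValueError("id_str values must be unique")
    if not catalog["file_loc"].map(os.path.isabs).all():
        raise ValueError("file_loc must hold absolute paths")
    if check_files:
        # Optional and slower: confirm every image actually exists on disk
        absent = catalog.loc[~catalog["file_loc"].map(os.path.isfile), "file_loc"]
        if len(absent) > 0:
            raise FileNotFoundError(f"{len(absent)} image(s) not found, e.g. {absent.iloc[0]}")


catalog = pd.DataFrame({
    "file_loc": ["/data/images/a.jpg", "/data/images/b.jpg"],  # hypothetical paths
    "id_str": ["a", "b"],
    "is_cool_galaxy": [1, 0],
})
check_catalog(catalog, label_cols=["is_cool_galaxy"])  # passes silently
check_catalog(catalog.drop(columns="is_cool_galaxy"))  # label_cols=None, as for a predict catalog
```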

Batch Format

CatalogDataModule dataloaders yield batches as dictionaries:
{
    'image': tensor of shape (batch_size, channels, height, width),
    'id_str': list of batch_size strings,
    'is_cool_galaxy': tensor of shape (batch_size, 1)
}
PyTorch Lightning’s Trainer automatically calls .train_dataloader(), .val_dataloader(), and so on during training. See the Lightning DataModule docs for more.
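In your own training code, you usually pull the images and label columns back out of each batch dict by key. The minimal sketch below uses NumPy arrays as stand-ins for the tensors a real dataloader would yield. unpack_batch, the 64x64 image size and the second label column is_merger are all hypothetical.

```python
import numpy as np


def unpack_batch(batch: dict, label_cols: list):
    """Return (images, labels) with labels stacked to shape (batch_size, len(label_cols))."""
    images = batch["image"]
    batch_size = len(images)
    # Reshape each label column to (batch_size, 1) so both (batch_size,) and (batch_size, 1) work
    columns = [np.asarray(batch[col]).reshape(batch_size, 1) for col in label_cols]
    return images, np.concatenate(columns, axis=1)


batch = {
    "image": np.zeros((8, 3, 64, 64), dtype=np.float32),
    "is_cool_galaxy": np.ones((8, 1), dtype=np.float32),
    "is_merger": np.zeros(8, dtype=np.float32),  # hypothetical second label column
}
images, labels = unpack_batch(batch, ["is_cool_galaxy", "is_merger"])
```

With real torch tensors, torch.cat plays the same role as np.concatenate.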

Loading from HuggingFace

There is a HuggingFace-native equivalent called HuggingFaceDataModule. Many Galaxy Zoo datasets are available on HuggingFace at huggingface.co/mwalmsley.
from galaxy_datasets.pytorch.galaxy_datamodule import HuggingFaceDataModule
from datasets import load_dataset

ds_dict = load_dataset("mwalmsley/gz2")  # returns dict with 'train' and 'test' keys

datamodule = HuggingFaceDataModule(
    dataset_dict=ds_dict,  # must have 'train' and 'test' keys
    batch_size=32,
    iterable=False  # set True for IterableDataset (faster streaming, no indexed access)
    # many more options...
)

Standard Augmentations

Both CatalogDataModule and HuggingFaceDataModule accept train_transform and test_transform arguments. These are applied to each image before it is passed to the network. galaxy_datasets.transforms provides a standard set of augmentations:
from galaxy_datasets.transforms import default_view_config, minimal_view_config, get_galaxy_transform

# A config object describing which augmentations to apply
train_transform_cfg = default_view_config()
# Convert to a T.Compose object
train_transform = get_galaxy_transform(train_transform_cfg)

# Simpler augmentations for validation/test
test_transform_cfg = minimal_view_config()
test_transform = get_galaxy_transform(test_transform_cfg)

# Test your transform on a single image before plugging it in
# (im is one galaxy image, e.g. a PIL image loaded from one of your jpg files)
transformed = train_transform(im)

# Pass to the datamodule
datamodule = HuggingFaceDataModule(
    dataset_dict=ds_dict,
    batch_size=32,
    train_transform=train_transform,  # applied to training batches
    test_transform=test_transform     # applied to val and test batches
)

Loading FITS Images

Where possible, JPG images are recommended for scale and convenience. However, FITS files are also supported when you need the original flux data.
from galaxy_datasets.transforms import default_view_config, get_galaxy_transform

cfg = default_view_config()
cfg.flux_to_jpg_like_dynamic_range = {
    'arcsinh_q': 1.0, 'percentile_min': 0, 'percentile_max': 99.7
}
cfg.pil_to_tensor = False  # FITS files already load as tensors

transform = get_galaxy_transform(cfg)  # ready for datamodule

# Test the transform first (im is one FITS image, which loads as a tensor)
transformed = transform(im)

# Then use it in the datamodule as normal
datamodule = CatalogDataModule(
    train_catalog=train_catalog,
    val_catalog=val_catalog,
    test_catalog=test_catalog,
    batch_size=batch_size,
    label_cols=['is_cool_galaxy'],
    train_transform=transform,
    test_transform=transform,
)
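The flux_to_jpg_like_dynamic_range settings above compress the wide dynamic range of raw flux into something closer to a jpg. This page does not spell out the exact formula galaxy-datasets applies. The sketch below assumes one common approach: an arcsinh stretch of arcsinh_q * flux, then clipping to the given percentiles and rescaling to [0, 1]. arcsinh_rescale is a hypothetical helper that can be handy for eyeballing what a stretch does to your images.

```python
import numpy as np


def arcsinh_rescale(flux, arcsinh_q=1.0, percentile_min=0.0, percentile_max=99.7):
    """Arcsinh-stretch raw flux, clip to percentiles, rescale to [0, 1] (assumed formula)."""
    stretched = np.arcsinh(arcsinh_q * np.asarray(flux, dtype=np.float64))
    lo, hi = np.percentile(stretched, [percentile_min, percentile_max])
    if hi <= lo:
        # Flat image: nothing to stretch, so return zeros rather than divide by zero
        return np.zeros(stretched.shape, dtype=np.float32)
    clipped = np.clip(stretched, lo, hi)
    return ((clipped - lo) / (hi - lo)).astype(np.float32)


# Hypothetical single-channel flux image: bright central source plus Gaussian noise
yy, xx = np.mgrid[:64, :64]
source = 1000.0 * np.exp(-((yy - 32) ** 2 + (xx - 32) ** 2) / 20.0)
flux = source + np.random.default_rng(0).normal(0.0, 1.0, (64, 64))
scaled = arcsinh_rescale(flux)
```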
Only single-channel (greyscale) FITS images are supported. When loading FITS, make sure to also set greyscale=True on your Zoobot model so the pretrained encoder accepts single-channel input:
from zoobot.pytorch.training.finetune import FinetuneableZoobotClassifier

model = FinetuneableZoobotClassifier(
    name='hf_hub:mwalmsley/zoobot-encoder-convnext_nano',
    greyscale=True,  # converts pretrained model to accept single-channel images
    num_classes=2
)

Custom Augmentations

You’re not limited to the built-in transforms. Any torchvision T.Compose object works. To be compatible with Zoobot’s pretrained models, your transforms should produce:
  • PyTorch tensors of shape (channels, height, width) (the datamodule adds the batch dimension).
  • Float values normalized to [0, 1], although in practice Zoobot can handle other ranges when using end-to-end finetuning.
  • If using raw flux values (e.g. from FITS), apply a dynamic range rescaling (such as np.arcsinh) before normalizing to [0, 1].
  • Galaxies should appear large and centered in the image.
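As a starting point for your own transform, here is a minimal sketch of a callable that meets the requirements above. It crops the central region so the galaxy fills more of the frame, scales pixel values to [0, 1], and moves channels first. CenterCropToCHW, crop_fraction and max_value are hypothetical names. The sketch returns a NumPy array, so converting to a tensor (for example with torch.from_numpy) is left as a final step.

```python
import numpy as np


class CenterCropToCHW:
    """Hypothetical transform: central crop, scale to [0, 1], (H, W[, C]) -> (C, H, W)."""

    def __init__(self, crop_fraction: float = 0.5, max_value: float = 255.0):
        if not 0 < crop_fraction <= 1:
            raise ValueError("crop_fraction must be in (0, 1]")
        self.crop_fraction = crop_fraction
        self.max_value = max_value  # 255 for 8-bit jpg/png pixels

    def __call__(self, image):
        image = np.asarray(image, dtype=np.float32)
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # greyscale (H, W) -> (H, W, 1)
        h, w = image.shape[:2]
        crop_h = max(1, int(h * self.crop_fraction))
        crop_w = max(1, int(w * self.crop_fraction))
        top, left = (h - crop_h) // 2, (w - crop_w) // 2
        cropped = image[top:top + crop_h, left:left + crop_w]
        scaled = np.clip(cropped / self.max_value, 0.0, 1.0)
        return np.ascontiguousarray(scaled.transpose(2, 0, 1))  # channels first


transform = CenterCropToCHW(crop_fraction=0.5)
out = transform(np.full((100, 80, 3), 255, dtype=np.uint8))
```

T.Compose calls each step in order, so a pipeline like T.Compose([CenterCropToCHW(0.5), torch.from_numpy]) should yield a tensor, assuming your images reach the transform as PIL images or arrays.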

Bringing Your Own DataModule

Using galaxy-datasets is entirely optional. Zoobot is designed to work with any PyTorch Lightning LightningDataModule that returns batches of the form:
{'image': images, 'some_label': labels}
Advanced users can also pass data directly to Zoobot’s encoder however they like; see Advanced Finetuning for details.
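The sketch below shows the dataset side of such a DataModule, assuming your images and labels already fit in memory as arrays. PyTorch's DataLoader accepts any object with __len__ and __getitem__ as a map-style dataset. Its default collation combines each dict key across samples, so batches come out in the {'image': ..., 'some_label': ...} form above. DictGalaxyDataset and some_label are hypothetical names.

```python
import numpy as np


class DictGalaxyDataset:
    """Hypothetical in-memory dataset yielding {'image': ..., label_col: ...} samples."""

    def __init__(self, images, labels, label_col: str = "some_label"):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = np.asarray(images, dtype=np.float32)  # (N, C, H, W), values in [0, 1]
        self.labels = np.asarray(labels, dtype=np.float32)  # (N,)
        self.label_col = label_col

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        return {"image": self.images[idx], self.label_col: self.labels[idx]}


dataset = DictGalaxyDataset(np.zeros((4, 3, 32, 32)), [0, 1, 0, 1])
sample = dataset[1]
```

Inside a LightningDataModule, train_dataloader() could then return torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True).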
