Method Signature

Collection.to_torchgeo_dataset(
    *,
    bands,
    chip_size=None,
    is_image=True,
    allow_resample=False,
    split=None,
    split_column="split",
    label_field=None,
    geometries=None,
    geometries_crs=4326,
    transforms=None,
    max_concurrent=50,
    cloud_config=None,
    backend=None,
    time_series=False,
    target_crs=None
)
Create a TorchGeo GeoDataset backed by this Collection for PyTorch training workflows.
This integration requires torchgeo and its dependencies.

Parameters

bands
list of str
required
Band codes to load (e.g. ["B04", "B03", "B02"]).
chip_size
int
Spatial extent of each chip in pixels.
is_image
bool
default:"True"
If True, return chips as sample["image"]. If False, return chips as sample["mask"] (single-band data will have its channel dimension squeezed to match TorchGeo RasterDataset behavior).
allow_resample
bool
default:"False"
If True, Rasteret will resample bands to the dataset grid when requested bands have different resolutions. This is opt-in because it may change pixel values and can be slow.
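One way to tell ahead of time whether allow_resample is needed is to compare the native resolutions of the requested bands. The sketch below is not part of Rasteret: the helper name needs_resample and the lookup table are illustrative, with values taken from the published Sentinel-2 MSI band resolutions.

```python
# Native ground resolution in metres for Sentinel-2 MSI bands.
S2_RESOLUTION_M = {
    "B01": 60, "B02": 10, "B03": 10, "B04": 10,
    "B05": 20, "B06": 20, "B07": 20, "B08": 10,
    "B8A": 20, "B09": 60, "B10": 60, "B11": 20, "B12": 20,
}


def needs_resample(bands, resolutions=S2_RESOLUTION_M):
    """Return True if the requested bands do not share one native resolution."""
    return len({resolutions[band] for band in bands}) > 1


print(needs_resample(["B04", "B03", "B02"]))  # False: all 10 m
print(needs_resample(["B04", "B08", "B11"]))  # True: B11 is 20 m
```

A check like this could gate the allow_resample flag so that resampling is only enabled when the band list actually mixes resolutions.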
split
str or sequence of str
Filter to the given split(s) before creating the dataset (e.g. "train", ["train", "val"]).
split_column
str
default:"split"
Column holding split labels.
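The split argument accepts either one label or several. As a rough illustration of that behaviour (not Rasteret's implementation), the hypothetical helper below normalises the argument and filters plain dict rows on the split column.

```python
def filter_by_split(rows, split=None, split_column="split"):
    """Keep rows whose split label matches `split` (a str or a sequence of str)."""
    if split is None:
        return list(rows)
    # A bare string is one label, not a sequence of characters.
    wanted = {split} if isinstance(split, str) else set(split)
    return [row for row in rows if row.get(split_column) in wanted]


rows = [
    {"id": "a", "split": "train"},
    {"id": "b", "split": "val"},
    {"id": "c", "split": "test"},
]
print([r["id"] for r in filter_by_split(rows, "train")])           # ['a']
print([r["id"] for r in filter_by_split(rows, ["train", "val"])])  # ['a', 'b']
```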
label_field
str
Column name to include as sample["label"].
geometries
bbox tuple, pa.Array, Shapely, WKB bytes, or GeoJSON dict
Spatial extent for the dataset. Accepts (minx, miny, maxx, maxy) bbox tuples, Arrow arrays (e.g. from GeoParquet), Shapely objects, raw WKB bytes, or GeoJSON dicts.
geometries_crs
int
default:"4326"
EPSG code for the geometries parameter.
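Two of the accepted geometry forms, a bbox tuple and a GeoJSON dict, can describe the same area. The sketch below converts one into the other using only the standard library. The helper name bbox_to_geojson and the coordinates are illustrative, and it assumes a bbox that does not cross the antimeridian.

```python
def bbox_to_geojson(bbox):
    """Convert (minx, miny, maxx, maxy) into a GeoJSON Polygon dict."""
    minx, miny, maxx, maxy = bbox
    if minx >= maxx or miny >= maxy:
        raise ValueError("bbox must satisfy minx < maxx and miny < maxy")
    # Exterior ring, counterclockwise, closed by repeating the first vertex.
    ring = [
        [minx, miny],
        [maxx, miny],
        [maxx, maxy],
        [minx, maxy],
        [minx, miny],
    ]
    return {"type": "Polygon", "coordinates": [ring]}


aoi_bbox = (77.55, 13.01, 77.58, 13.08)  # lon/lat, i.e. EPSG:4326
aoi_geojson = bbox_to_geojson(aoi_bbox)
print(aoi_geojson["type"], len(aoi_geojson["coordinates"][0]))  # Polygon 5
```

Either aoi_bbox or aoi_geojson could then be passed as geometries, with geometries_crs left at its default of 4326.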
transforms
callable
TorchGeo-compatible transforms applied to each sample.
max_concurrent
int
default:"50"
Maximum concurrent HTTP requests.
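To show what a cap on concurrent requests means in practice, here is a generic asyncio sketch that uses a semaphore to bound the number of in-flight tasks. It is not Rasteret's implementation; fetch_all and the sleep that stands in for an HTTP request are hypothetical.

```python
import asyncio


async def fetch_all(n_requests, max_concurrent=50):
    """Run n_requests fake fetches with at most max_concurrent in flight."""
    semaphore = asyncio.Semaphore(max_concurrent)
    in_flight = 0
    peak = 0

    async def fetch(i):
        nonlocal in_flight, peak
        async with semaphore:  # waits here once the cap is reached
            in_flight += 1
            peak = max(peak, in_flight)
            await asyncio.sleep(0.001)  # stand-in for an HTTP range request
            in_flight -= 1
            return i

    results = await asyncio.gather(*(fetch(i) for i in range(n_requests)))
    return results, peak


results, peak = asyncio.run(fetch_all(200, max_concurrent=50))
print(len(results), peak <= 50)  # 200 True
```

A lower cap trades throughput for fewer simultaneous connections, which can help when the storage endpoint throttles clients.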
cloud_config
CloudConfig
Cloud configuration for URL rewriting.
backend
StorageBackend
Pluggable I/O backend (e.g. ObstoreBackend). See create_backend().
time_series
bool
default:"False"
When True, stack all timesteps as [T, C, H, W].
target_crs
int
Reproject all records to this EPSG code at read time.

Returns

dataset
RasteretGeoDataset
A standard TorchGeo GeoDataset. Pixel data is in the native COG dtype (e.g. uint16 for Sentinel-2).

Examples

Basic Training Dataset

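import rasteret
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
from torchgeo.samplers import RandomGeoSampler

collection = rasteret.load("my_collection")

# Create dataset for training
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    split="train"
)

# A GeoDataset is indexed by bounding box, so chips are drawn by a TorchGeo
# sampler rather than with shuffle=True. length (chips per epoch) is arbitrary.
sampler = RandomGeoSampler(dataset, size=256, length=1000)
loader = DataLoader(
    dataset,
    batch_size=16,
    sampler=sampler,
    collate_fn=stack_samples,  # Stacks per-sample dicts into a batch dict
)

for batch in loader:
    images = batch["image"]  # [16, 3, 256, 256]
    # Training loop here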

With Data Augmentation

from torchgeo.transforms import AugmentationSequential
from kornia.augmentation import RandomHorizontalFlip, RandomVerticalFlip

transforms = AugmentationSequential(
    RandomHorizontalFlip(p=0.5),
    RandomVerticalFlip(p=0.5),
    data_keys=["image"]
)

dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    split="train",
    transforms=transforms
)

Semantic Segmentation

# Image dataset
image_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=512,
    split="train"
)

# Label dataset (labels_collection is a separate Collection of mask rasters, loaded elsewhere)
label_ds = labels_collection.to_torchgeo_dataset(
    bands=["mask"],
    chip_size=512,
    is_image=False,  # Load as mask
    split="train"
)

# Combine datasets: the & operator builds a TorchGeo IntersectionDataset
dataset = image_ds & label_ds

With Label Field

# Collection has a "crop_type" column
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    split="train",
    label_field="crop_type"
)

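from torchgeo.samplers import RandomGeoSampler

# A GeoDataset is indexed by bounding box, so draw queries from a sampler
for query in RandomGeoSampler(dataset, size=256, length=10):
    sample = dataset[query]
    image = sample["image"]
    label = sample["label"]  # From crop_type column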

Time Series Mode

# Stack temporal dimension
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02"],
    chip_size=256,
    time_series=True
)

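from torchgeo.samplers import RandomGeoSampler

# A GeoDataset is indexed by bounding box, so draw queries from a sampler
for query in RandomGeoSampler(dataset, size=256, length=10):
    sample = dataset[query]
    images = sample["image"]  # [T, 3, 256, 256]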

Multi-Resolution Bands

# Allow resampling for mixed-resolution bands
dataset = collection.to_torchgeo_dataset(
    bands=["B04", "B08", "B11"],  # 10m, 10m, 20m
    chip_size=256,
    allow_resample=True,  # Resample B11 to 10m
    split="train"
)

Sample Structure

Each sample is a dictionary with:
{
    "image": torch.Tensor,                  # [C, H, W] or [T, C, H, W] if time_series=True
    "crs": rasterio.CRS,                    # Coordinate reference system
    "bbox": torchgeo.datasets.BoundingBox   # Spatial extent
}
If is_image=False, the chip is returned under "mask" instead of "image".
If label_field is specified:
{
    "image": torch.Tensor,
    "label": Any,                           # Value from label_field column
    "crs": rasterio.CRS,
    "bbox": torchgeo.datasets.BoundingBox
}
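The layouts described above can be summarised as a small shape calculator, which is handy for sanity checks in tests. This is a sketch under stated assumptions, not a Rasteret function: expected_shape is a hypothetical helper, it assumes square chips of chip_size by chip_size pixels, and it deliberately does not cover masks combined with time_series=True because the page does not describe that layout.

```python
def expected_shape(n_bands, chip_size, *, is_image=True, n_timesteps=None):
    """Return (sample key, tensor shape) for one chip.

    n_timesteps should only be set when time_series=True.
    """
    key = "image" if is_image else "mask"
    shape = (n_bands, chip_size, chip_size)  # assumes square chips
    if n_timesteps is not None:
        if not is_image:
            raise ValueError("mask layout with time_series=True is not covered here")
        return key, (n_timesteps,) + shape
    if not is_image and n_bands == 1:
        shape = shape[1:]  # single-band masks drop the channel dimension
    return key, shape


print(expected_shape(3, 256))                   # ('image', (3, 256, 256))
print(expected_shape(3, 256, n_timesteps=12))   # ('image', (12, 3, 256, 256))
print(expected_shape(1, 512, is_image=False))   # ('mask', (512, 512))
```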

Notes

  • Data is in native COG dtype (typically uint16 for Sentinel-2)
  • Compatible with all TorchGeo samplers and datasets
  • Use is_image=False for mask/label datasets to match TorchGeo conventions
  • Set allow_resample=True when working with multi-resolution sensors
  • Supports train/val/test splits via the split parameter