This example demonstrates a complete production ML pipeline: building a Sentinel-2 collection from STAC, assigning train/val/test splits, persisting the split-annotated collection, and running a training loop with TorchGeo.
Overview
The workflow covers:
- Building a collection from STAC (cached after first run)
- Assigning train/val/test splits using PyArrow
- Saving the split-annotated collection as a shareable Parquet artifact
- Creating TorchGeo datasets for each split
- Running a standard training loop
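Before walking through the full pipeline, the split-assignment logic at its heart can be sketched in isolation. The function below is a toy standalone version (not part of the rasteret API) showing how a fixed seed makes the train/val/test partition deterministic:

```python
import numpy as np

def split_labels(n: int, train_ratio: float = 0.7, val_ratio: float = 0.15,
                 seed: int = 42) -> np.ndarray:
    """Deterministically map n rows to 'train'/'val'/'test' labels.

    The same seed always yields the same assignment, which is what makes
    saved splits reproducible across runs and machines.
    """
    u = np.random.default_rng(seed).random(n)
    return np.where(u < train_ratio, "train",
                    np.where(u < train_ratio + val_ratio, "val", "test"))

labels = split_labels(142)
```

Because the labels depend only on `n` and the seed, re-running the pipeline never silently reshuffles which scenes land in the test set.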
Prerequisites
pip install rasteret[torchgeo]
Complete Example
from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import rasteret
# Configuration
workspace = Path.home() / "rasteret_workspace"
bbox = (77.55, 13.01, 77.58, 13.08) # Bangalore, India
# Step 1: Build collection from STAC (cached after first run)
collection = rasteret.build_from_stac(
    name="bangalore",
    stac_api="https://earth-search.aws.element84.com/v1",
    collection="sentinel-2-l2a",
    bbox=bbox,
    date_range=("2024-01-01", "2024-06-30"),
    workspace_dir=workspace,
)
print(f"Collection: {collection.name}, rows={collection.dataset.count_rows()}")
Output:
Collection: bangalore, rows=142
Assign Train/Val/Test Splits
Add a split column to the collection for reproducible data partitioning:
def assign_splits(
    collection: rasteret.Collection,
    output_path: Path,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    seed: int = 42,
) -> rasteret.Collection:
    """Add a 'split' column to a collection and save it.

    Uses deterministic random assignment so the same collection always
    gets the same splits (reproducible across runs and machines).
    """
    table = collection.dataset.to_table()
    n = len(table)

    # Deterministic random assignment
    rng = np.random.default_rng(seed)
    assignments = rng.random(n)
    splits = np.where(
        assignments < train_ratio,
        "train",
        np.where(assignments < train_ratio + val_ratio, "val", "test"),
    )
    table = table.append_column("split", pa.array(splits))

    # Persist to a new Parquet dataset
    output_path.mkdir(parents=True, exist_ok=True)
    partition_cols = [c for c in ("year", "month") if c in table.schema.names]
    ds.write_dataset(
        table,
        output_path,
        format="parquet",
        partitioning=partition_cols or None,
        existing_data_behavior="overwrite_or_ignore",
    )
    return rasteret.load(output_path, name=collection.name)
# Step 2: Assign and save splits (reuse an existing artifact when present)
split_path = workspace / "bangalore_with_splits"
existing = rasteret.load(split_path, name="bangalore") if split_path.exists() else None

if existing is not None and "split" in existing.dataset.schema.names:
    print("Loading existing split-annotated collection...")
    collection = existing
else:
    print("Assigning train/val/test splits...")
    collection = assign_splits(
        collection,
        output_path=split_path,
        train_ratio=0.7,
        val_ratio=0.15,
    )
Output:
Assigning train/val/test splits...
Verify Split Distribution
# Show split distribution
table = collection.dataset.to_table(columns=["split"])
for split_name in ["train", "val", "test"]:
    count = pc.sum(pc.equal(table.column("split"), split_name)).as_py()
    print(f" {split_name}: {count} rows")
Output:
train: 99 rows
val: 21 rows
test: 22 rows
Create TorchGeo DataLoaders
Use the split column to create separate datasets for training and validation:
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler
# Step 3: Create TorchGeo dataset for training split
train_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],  # RGB + NIR
    geometries=bbox,
    split="train",
    chip_size=256,
)

# Create random sampler and dataloader
sampler = RandomGeoSampler(
    train_ds,
    size=256,
    length=32,  # Number of samples per epoch
)
loader = DataLoader(
    train_ds,
    sampler=sampler,
    batch_size=4,
    num_workers=0,
    collate_fn=stack_samples,
)
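TorchGeo's `stack_samples` collates a list of per-sample dicts into one batched dict of stacked tensors. A minimal NumPy analogue of that stacking behavior (a sketch, not the torchgeo implementation) makes the batch shapes in the output below easy to predict:

```python
import numpy as np

def stack_samples_np(samples):
    """Stack a list of {'image': (C, H, W) array} dicts into one
    {'image': (B, C, H, W) array} batch, analogous to what torchgeo's
    stack_samples does for torch tensors."""
    return {"image": np.stack([s["image"] for s in samples])}

# Four 4-band 256x256 chips -> one (4, 4, 256, 256) batch
batch = stack_samples_np([{"image": np.zeros((4, 256, 256))} for _ in range(4)])
print(batch["image"].shape)  # (4, 4, 256, 256)
```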
Training Loop
# Step 4: Training loop (placeholder, replace with your model)
for i, batch in enumerate(loader):
    img = batch["image"]
    print(f" batch {i}: image={img.shape}, dtype={img.dtype}")
    # Your model training code here:
    #   output = model(img)
    #   loss = criterion(output, target)
    #   loss.backward()
    #   optimizer.step()
    if i >= 2:
        break
Output:
batch 0: image=torch.Size([4, 4, 256, 256]), dtype=torch.float32
batch 1: image=torch.Size([4, 4, 256, 256]), dtype=torch.float32
batch 2: image=torch.Size([4, 4, 256, 256]), dtype=torch.float32
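Raw Sentinel-2 reflectance values are not in the range most models expect, so a normalization step usually sits between the loader and the forward pass. A minimal per-band standardization sketch (written with NumPy for clarity; the torch version is analogous with `dim=` instead of `axis=`):

```python
import numpy as np

def standardize(batch: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Zero-mean, unit-variance normalization per band over a (B, C, H, W) batch."""
    mean = batch.mean(axis=(0, 2, 3), keepdims=True)
    std = batch.std(axis=(0, 2, 3), keepdims=True)
    return (batch - mean) / np.maximum(std, eps)

# Fake reflectance-scale values, same shape as the batches above
x = standardize(np.random.default_rng(0).random((4, 4, 256, 256)) * 10000)
```

In practice you would compute band statistics once over the training split rather than per batch, so train and validation share the same normalization.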
Validation Set
Access the validation split for model evaluation:
# Step 5: Create validation dataset
val_collection = collection.subset(split="val")
print(f"Val split: {val_collection.dataset.count_rows()} rows")
val_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],
    geometries=bbox,
    split="val",
    chip_size=256,
)
# Use for validation during training
# val_sampler = RandomGeoSampler(val_ds, size=256, length=16)
# val_loader = DataLoader(val_ds, sampler=val_sampler, batch_size=4)
Output:
Val split: 21 rows
Share Split Annotations
The split-annotated collection is saved as a standard Parquet dataset and can be shared:
# Save location
print(f"Shareable artifact: {split_path}")
# Colleagues can load it directly
shared_collection = rasteret.load(split_path, name="bangalore")
assert "split" in shared_collection.dataset.schema.names
CLI Alternative
You can also build the base collection via CLI:
# Build collection
rasteret build earthsearch/sentinel-2-l2a bangalore \
  --bbox 77.55,13.01,77.58,13.08 \
  --date-range 2024-01-01,2024-06-30
# Then load and assign splits in Python
Key Features
- Cached builds: STAC queries are cached; rebuilds are instant
- Reproducible splits: Deterministic seed ensures same splits across runs
- Shareable artifacts: Split-annotated collections are standard Parquet
- TorchGeo integration: Direct integration with TorchGeo samplers and dataloaders
- Lazy loading: Only fetches imagery when sampling (no bulk downloads)
Complete Script
Full example: ml_training_with_splits.py