
This example demonstrates an end-to-end ML training workflow: building a Sentinel-2 collection from STAC, assigning train/val/test splits, persisting the split-annotated collection, and running a training loop with TorchGeo.

Overview

The workflow covers:
  1. Building a collection from STAC (cached after first run)
  2. Assigning train/val/test splits using PyArrow
  3. Saving the split-annotated collection as a shareable Parquet artifact
  4. Creating TorchGeo datasets for each split
  5. Running a standard training loop

Prerequisites

pip install "rasteret[torchgeo]"

Complete Example

from pathlib import Path
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import rasteret

# Configuration
workspace = Path.home() / "rasteret_workspace"
bbox = (77.55, 13.01, 77.58, 13.08)  # Bangalore, India

# Step 1: Build collection from STAC (cached after first run)
collection = rasteret.build_from_stac(
    name="bangalore",
    stac_api="https://earth-search.aws.element84.com/v1",
    collection="sentinel-2-l2a",
    bbox=bbox,
    date_range=("2024-01-01", "2024-06-30"),
    workspace_dir=workspace,
)
print(f"Collection: {collection.name}, rows={collection.dataset.count_rows()}")
Output:
Collection: bangalore, rows=142
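STAC APIs expect the bounding box as (min_lon, min_lat, max_lon, max_lat) in WGS84 degrees, i.e. west, south, east, north, and the Bangalore values above follow that order. Before querying, it can help to sanity-check the order and the size of the area. The sketch below uses a hypothetical helper (bbox_extent_km is not part of rasteret) with a spherical approximation of about 111.32 km per degree, so its output is approximate.

```python
import math

KM_PER_DEGREE = 111.32  # approximate length of one degree of latitude (spherical Earth)

def bbox_extent_km(bbox):
    """Approximate (width_km, height_km) of a (min_lon, min_lat, max_lon, max_lat) box."""
    min_lon, min_lat, max_lon, max_lat = bbox
    # Catches swapped min/max values; cannot detect a lat/lon swap that keeps min < max
    if min_lon >= max_lon or min_lat >= max_lat:
        raise ValueError("bbox must be (min_lon, min_lat, max_lon, max_lat)")
    mid_lat = math.radians((min_lat + max_lat) / 2)
    # A degree of longitude shrinks with the cosine of latitude
    width_km = (max_lon - min_lon) * KM_PER_DEGREE * math.cos(mid_lat)
    height_km = (max_lat - min_lat) * KM_PER_DEGREE
    return width_km, height_km

width_km, height_km = bbox_extent_km((77.55, 13.01, 77.58, 13.08))
print(f"{width_km:.2f} km x {height_km:.2f} km")  # 3.25 km x 7.79 km
```

At the 10 m resolution of B02, B03, B04, and B08, that is roughly 325 x 779 pixels, a useful reference when choosing chip_size and sampler lengths later.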

Assign Train/Val/Test Splits

Add a split column to the collection for reproducible data partitioning:
def assign_splits(
    collection: rasteret.Collection,
    output_path: Path,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    seed: int = 42,
) -> rasteret.Collection:
    """Add a 'split' column to a collection and save it.
    
    Uses seeded random assignment, so the same collection (with the same
    row order) always gets the same splits across runs and machines.
    """
    table = collection.dataset.to_table()
    n = len(table)

    # Deterministic random assignment
    rng = np.random.default_rng(seed)
    assignments = rng.random(n)
    splits = np.where(
        assignments < train_ratio,
        "train",
        np.where(assignments < train_ratio + val_ratio, "val", "test"),
    )

    table = table.append_column("split", pa.array(splits))

    # Persist to a new Parquet dataset
    output_path.mkdir(parents=True, exist_ok=True)
    partition_cols = [c for c in ("year", "month") if c in table.schema.names]
    ds.write_dataset(
        table,
        output_path,
        format="parquet",
        partitioning=partition_cols or None,
        existing_data_behavior="overwrite_or_ignore",
    )

    return rasteret.load(output_path, name=collection.name)

# Step 2: Assign and save splits
split_path = workspace / "bangalore_with_splits"
if (
    split_path.exists()
    and "split" in rasteret.load(split_path).dataset.schema.names
):
    print("Loading existing split-annotated collection...")
    collection = rasteret.load(split_path, name="bangalore")
else:
    print("Assigning train/val/test splits...")
    collection = assign_splits(
        collection,
        output_path=split_path,
        train_ratio=0.7,
        val_ratio=0.15,
    )
Output:
Assigning train/val/test splits...
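Because the split logic in assign_splits needs nothing but the row count and the seed, it can be checked without loading any imagery. The sketch below restates the same seeded thresholding with NumPy only and adds a ratio check that assign_splits leaves out (without it, train_ratio + val_ratio > 1 would silently leave the test split empty). split_labels is a hypothetical helper, not part of rasteret.

```python
import numpy as np

def split_labels(n, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Same seeded thresholding as assign_splits, returned as a NumPy array of labels."""
    # Whatever share remains after train and val goes to "test", so it must not be negative
    if not (0 < train_ratio and 0 <= val_ratio and train_ratio + val_ratio <= 1):
        raise ValueError("need 0 < train_ratio, 0 <= val_ratio, train_ratio + val_ratio <= 1")
    rng = np.random.default_rng(seed)
    u = rng.random(n)  # one uniform draw in [0, 1) per row, in row order
    return np.where(u < train_ratio, "train",
                    np.where(u < train_ratio + val_ratio, "val", "test"))

labels = split_labels(142)
# Same seed and same row count -> identical labels
assert (labels == split_labels(142)).all()
```

Label i depends only on row position i, so the splits stay stable only if the collection's rows come back in the same order.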

Verify Split Distribution

# Show split distribution
table = collection.dataset.to_table(columns=["split"])
for split_name in ["train", "val", "test"]:
    count = pc.sum(pc.equal(table.column("split"), split_name)).as_py()
    print(f"  {split_name}: {count} rows")
Output:
  train: 99 rows
  val: 21 rows
  test: 22 rows
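Because each row is assigned independently, the counts only approximate the requested 70/15/15 ratios (the expected values for 142 rows are 99.4, 21.3, and 21.3). If exact split sizes matter, one common alternative, sketched here and not part of rasteret, is to shuffle the row indices once with the seed and cut the permutation at fixed sizes. permutation_split_labels is a hypothetical helper.

```python
import numpy as np

def permutation_split_labels(n, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Exact-size splits: shuffle row indices once, then cut at fixed boundaries."""
    n_train = int(n * train_ratio)  # floor
    n_val = int(n * val_ratio)      # floor
    perm = np.random.default_rng(seed).permutation(n)
    labels = np.empty(n, dtype="<U5")
    labels[perm[:n_train]] = "train"
    labels[perm[n_train:n_train + n_val]] = "val"
    labels[perm[n_train + n_val:]] = "test"  # the remainder goes to test
    return labels

values, counts = np.unique(permutation_split_labels(142), return_counts=True)
print({str(v): int(c) for v, c in zip(values, counts)})
# {'test': 22, 'train': 99, 'val': 21}
```

The sizes depend only on n and the ratios, while the seed decides which rows land in each split. An array like this could stand in for splits inside assign_splits before the pa.array(splits) call.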

Create TorchGeo DataLoaders

Use the split argument of to_torchgeo_dataset to build one dataset per partition, starting with the training split:
from torch.utils.data import DataLoader
from torchgeo.datasets.utils import stack_samples
from torchgeo.samplers import RandomGeoSampler

# Step 3: Create TorchGeo dataset for training split
train_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],  # RGB + NIR
    geometries=bbox,
    split="train",
    chip_size=256,
)

# Create random sampler and dataloader
sampler = RandomGeoSampler(
    train_ds,
    size=256,
    length=32,  # Number of samples per epoch
)

loader = DataLoader(
    train_ds,
    sampler=sampler,
    batch_size=4,
    num_workers=0,
    collate_fn=stack_samples,
)
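With length=32 and batch_size=4, each epoch of this loader yields 32 / 4 = 8 batches, and stack_samples merges the per-sample dictionaries into one batch dictionary. The NumPy sketch below mirrors that collation as we understand TorchGeo's implementation (array-valued keys are stacked along a new leading batch axis, other values are kept as lists). stack_samples_np and batches_per_epoch are illustrative helpers, not TorchGeo APIs, and the "meta" key is made up.

```python
import math
import numpy as np

def stack_samples_np(samples):
    """NumPy analogue of TorchGeo's stack_samples (assumed behavior, see above)."""
    # List of dicts -> dict of lists (assumes every sample has the same keys)
    batch = {key: [sample[key] for sample in samples] for key in samples[0]}
    for key, values in batch.items():
        if isinstance(values[0], np.ndarray):
            batch[key] = np.stack(values)  # new leading batch axis
    return batch

def batches_per_epoch(length, batch_size, drop_last=False):
    """Batches a DataLoader yields when the sampler produces `length` chips per epoch."""
    return length // batch_size if drop_last else math.ceil(length / batch_size)

samples = [{"image": np.zeros((4, 256, 256), dtype=np.float32), "meta": i} for i in range(4)]
batch = stack_samples_np(samples)
print(batch["image"].shape, batch["meta"])  # (4, 4, 256, 256) [0, 1, 2, 3]
print(batches_per_epoch(length=32, batch_size=4))  # 8
```

drop_last mirrors the DataLoader argument of the same name, which defaults to False, so a final partial batch is kept.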

Training Loop

# Step 4: Training loop (placeholder, replace with your model)
for i, batch in enumerate(loader):
    img = batch["image"]
    print(f"  batch {i}: image={img.shape}, dtype={img.dtype}")
    # Your model training code here:
    # optimizer.zero_grad()
    # output = model(img)
    # loss = criterion(output, target)
    # loss.backward()
    # optimizer.step()
    if i >= 2:
        break
Output:
  batch 0: image=torch.Size([4, 4, 256, 256]), dtype=torch.float32
  batch 1: image=torch.Size([4, 4, 256, 256]), dtype=torch.float32
  batch 2: image=torch.Size([4, 4, 256, 256]), dtype=torch.float32
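The placeholder loss refers to a target that the loop never defines. For a quick smoke test before real labels exist, one option is to derive a per-pixel target from the batch itself, for example NDVI = (NIR - Red) / (NIR + Red) from B08 (near infrared) and B04 (red). The NumPy sketch below assumes that the channel axis of batch["image"] (shape [batch, channels, height, width], as printed above) follows the order of the bands list passed to to_torchgeo_dataset; confirm that against your data. ndvi_from_batch is a hypothetical helper.

```python
import numpy as np

# Assumed channel order: the bands list passed to to_torchgeo_dataset above
BANDS = ["B04", "B03", "B02", "B08"]

def ndvi_from_batch(image, bands=BANDS, eps=1e-6):
    """NDVI per pixel for a [B, C, H, W] batch; returns [B, H, W]."""
    red = image[:, bands.index("B04")]
    nir = image[:, bands.index("B08")]
    # eps avoids division by zero on no-data pixels where both bands are 0
    return (nir - red) / (nir + red + eps)

batch = np.zeros((4, 4, 8, 8), dtype=np.float32)
batch[:, 0] = 0.1  # red (B04)
batch[:, 3] = 0.5  # near infrared (B08)
target = ndvi_from_batch(batch)
print(target.shape, round(float(target[0, 0, 0]), 3))  # (4, 8, 8) 0.667
```

The same indexing and arithmetic also work directly on the PyTorch tensor img inside the loop.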

Validation Set

Access the validation split for model evaluation:
# Step 5: Create validation dataset
val_collection = collection.subset(split="val")
print(f"Val split: {val_collection.dataset.count_rows()} rows")

val_ds = collection.to_torchgeo_dataset(
    bands=["B04", "B03", "B02", "B08"],
    geometries=bbox,
    split="val",
    chip_size=256,
)

# Use for validation during training
# val_sampler = RandomGeoSampler(val_ds, size=256, length=16)
# val_loader = DataLoader(val_ds, sampler=val_sampler, batch_size=4, collate_fn=stack_samples)
Output:
Val split: 21 rows
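Random chips suit training, but evaluation usually benefits from deterministic, repeatable coverage; TorchGeo provides GridGeoSampler in torchgeo.samplers for that purpose. The plain-Python sketch below only illustrates the tiling arithmetic behind a grid sampler, in pixel units. grid_offsets is a hypothetical helper and does not claim to match TorchGeo's exact edge handling.

```python
def grid_offsets(height, width, size, stride):
    """Top-left (row, col) pixel offsets of size x size chips on a regular grid.

    A final row/column of chips is shifted flush with the far edge so the
    whole extent is covered without reading past it.
    """
    if size > height or size > width:
        raise ValueError("chip size exceeds the raster extent")
    rows = list(range(0, height - size + 1, stride))
    cols = list(range(0, width - size + 1, stride))
    if rows[-1] + size < height:
        rows.append(height - size)
    if cols[-1] + size < width:
        cols.append(width - size)
    return [(r, c) for r in rows for c in cols]

# Hypothetical extent: roughly the Bangalore bbox at 10 m resolution
offsets = grid_offsets(height=779, width=325, size=256, stride=256)
print(len(offsets))  # 8
```

Unlike random sampling, every run visits the same chips, which keeps validation metrics comparable across epochs.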

Share Split Annotations

The split-annotated collection is saved as a standard Parquet dataset and can be shared:
# Save location
print(f"Shareable artifact: {split_path}")

# Colleagues can load it directly
shared_collection = rasteret.load(split_path, name="bangalore")
assert "split" in shared_collection.dataset.schema.names
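Sharing the Parquet artifact is the simplest way to give collaborators identical splits. When only the source collection can be shared, or new scenes are appended later, a positional seed is fragile because rng.random(n) assigns labels by row position. A common alternative, sketched below and not rasteret's method, derives each label from a hash of a stable per-row identifier such as a STAC item ID. hash_split is a hypothetical helper and the IDs are made up.

```python
import hashlib
from collections import Counter

def hash_split(item_id, train_ratio=0.7, val_ratio=0.15, salt="splits-v1"):
    """Split label derived only from a stable ID, independent of row order."""
    digest = hashlib.sha256(f"{salt}:{item_id}".encode("utf-8")).digest()
    # Top 53 bits -> an exactly representable float in [0, 1)
    u = (int.from_bytes(digest[:8], "big") >> 11) / 2**53
    if u < train_ratio:
        return "train"
    if u < train_ratio + val_ratio:
        return "val"
    return "test"

ids = [f"scene-{i:04d}" for i in range(1000)]  # made-up IDs for illustration
labels = [hash_split(item_id) for item_id in ids]
print(Counter(labels))
```

Changing the salt reshuffles every assignment, so it should be fixed once and recorded with the artifact. Inside assign_splits, such labels would be computed from the collection's ID column (whatever it is named in your schema) instead of from rng.random(n).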

CLI Alternative

You can also build the base collection with the CLI:
# Build collection
rasteret build earthsearch/sentinel-2-l2a bangalore \
  --bbox 77.55,13.01,77.58,13.08 \
  --date-range 2024-01-01,2024-06-30

# Then load and assign splits in Python
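To keep the CLI call and the Python configuration in sync, the argument list can be generated from the same bbox and date_range values used earlier. The sketch below only reproduces the command and flags shown above; build_cli_args is a hypothetical helper, not part of rasteret.

```python
import shlex

def build_cli_args(source, name, bbox, date_range):
    """argv list for the `rasteret build` command shown above (same flags only)."""
    return [
        "rasteret", "build", source, name,
        "--bbox", ",".join(str(v) for v in bbox),
        "--date-range", ",".join(date_range),
    ]

args = build_cli_args(
    "earthsearch/sentinel-2-l2a",
    "bangalore",
    (77.55, 13.01, 77.58, 13.08),
    ("2024-01-01", "2024-06-30"),
)
print(shlex.join(args))
# rasteret build earthsearch/sentinel-2-l2a bangalore --bbox 77.55,13.01,77.58,13.08 --date-range 2024-01-01,2024-06-30
```

To run it from Python, pass args to subprocess.run(args, check=True). Where the CLI writes the collection, and therefore which path to give rasteret.load afterwards, is not covered here; consult the rasteret CLI documentation.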

Key Features

  • Cached builds: STAC query results are cached after the first build, so rebuilding with the same parameters reuses them instead of querying again
  • Reproducible splits: A fixed seed yields the same splits across runs
  • Shareable artifacts: Split-annotated collections are standard Parquet datasets
  • TorchGeo integration: Datasets plug directly into TorchGeo samplers and PyTorch DataLoaders
  • Lazy loading: Imagery is fetched only when chips are sampled (no bulk downloads)

Next Steps

Complete Script

Full example: ml_training_with_splits.py
