gr00t.data.dataset — API Reference
The dataset module provides classes for loading, sharding, and mixing multi-modal robotics datasets.

ShardedDataset

from gr00t.data.interfaces import ShardedDataset
Abstract base class for sharded datasets.

Constructor

dataset_path
str
required
Path to the dataset directory.

Methods

__len__

def __len__(self) -> int
Return the number of shards in the dataset.

get_shard_length

def get_shard_length(self, idx: int) -> int
Get the number of samples in a specific shard.
idx
int
required
Shard index.

get_shard

def get_shard(self, idx: int) -> list
Load and return all samples from a specific shard.
idx
int
required
Shard index to load.

set_processor

def set_processor(self, processor: BaseProcessor) -> None
Set the data processor for this dataset.
processor
BaseProcessor
required
Processor instance to use for data processing.

get_dataset_statistics

def get_dataset_statistics(self) -> dict[str, Any]
Get dataset statistics for normalization.
return
dict[str, Any]
Dictionary containing statistics for each modality and joint group.

ShardedSingleStepDataset

from gr00t.data.dataset import ShardedSingleStepDataset
Single-step dataset that creates shards from individual timesteps across episodes. This dataset provides step-level data access for VLA training by:
  1. Loading episodes using LeRobotEpisodeLoader
  2. Splitting episodes into individual timesteps
  3. Organizing timesteps into balanced shards for efficient loading
  4. Supporting episode subsampling for data efficiency
The sharding strategy ensures balanced shard sizes while maintaining randomization across episodes and timesteps within episodes.

Constructor

def __init__(
    self,
    dataset_path: str | Path,
    embodiment_tag: EmbodimentTag,
    modality_configs: dict[str, ModalityConfig],
    video_backend: str = "torchcodec",
    video_backend_kwargs: dict[str, Any] | None = None,
    shard_size: int = 1024,
    episode_sampling_rate: float = 0.1,
    seed: int = 42,
    allow_padding: bool = False,
)
dataset_path
str | Path
required
Path to LeRobot format dataset directory.
embodiment_tag
EmbodimentTag
required
Embodiment identifier for cross-embodiment training.
modality_configs
dict[str, ModalityConfig]
required
Configuration for each modality (sampling, keys). Should include "video", "state", "action", and optionally "language".
video_backend
str
default:"torchcodec"
Video decoding backend ('torchcodec', 'decord', etc.).
video_backend_kwargs
dict[str, Any] | None
default:"None"
Additional arguments for video backend.
shard_size
int
default:"1024"
Target number of timesteps per shard.
episode_sampling_rate
float
default:"0.1"
Fraction of episode timesteps to use (for efficiency). Value of 0.1 means 10% of timesteps are sampled.
seed
int
default:"42"
Random seed for reproducible sharding and sampling.
allow_padding
bool
default:"False"
Whether to allow padding of indices to valid range [0, max_length - 1].

Usage example

from gr00t.data.dataset import ShardedSingleStepDataset
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.data.types import ModalityConfig

dataset = ShardedSingleStepDataset(
    dataset_path="/path/to/lerobot_dataset",
    embodiment_tag=EmbodimentTag.UNITREE_G1,
    modality_configs={
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["front_cam"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=["joint_positions"],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(16)),
            modality_keys=["joint_velocities"],
        ),
    },
    shard_size=1024,
    episode_sampling_rate=0.1,
)

# Get first shard of processed timesteps
shard_data = dataset.get_shard(0)

ShardedMixtureDataset

from gr00t.data.dataset import ShardedMixtureDataset
Iterable dataset that combines multiple sharded datasets with configurable mixing ratios. This dataset provides the core functionality for multi-dataset training:
  1. Combines multiple ShardedDataset instances with specified mixing weights
  2. Implements intelligent shard sampling that accounts for dataset sizes
  3. Provides efficient background shard caching for continuous data loading
  4. Handles distributed training across multiple workers and processes
  5. Merges dataset statistics for consistent normalization
The sampling strategy ensures that datasets are sampled proportionally to their weights while accounting for differences in shard sizes.

Constructor

def __init__(
    self,
    datasets: list[ShardedDataset],
    weights: list[float],
    processor: BaseProcessor,
    seed: int = 42,
    training: bool = True,
    num_shards_per_epoch: int = 100000,
    override_pretraining_statistics: bool = False,
)
datasets
list[ShardedDataset]
required
List of ShardedDataset instances to combine.
weights
list[float]
required
Mixing weights for each dataset (will be normalized to sum to 1.0).
processor
BaseProcessor
required
Data processor to apply to all datasets.
seed
int
default:"42"
Random seed for reproducible sampling.
training
bool
default:"True"
Whether in training mode (affects sampling strategy). In training mode, samples shards randomly. In eval mode, samples every shard once.
num_shards_per_epoch
int
default:"100000"
Number of shards to sample per epoch during training.
override_pretraining_statistics
bool
default:"False"
Whether to override pretraining statistics with merged statistics.

Methods

merge_statistics

def merge_statistics(self) -> None
Merge dataset statistics across all datasets, grouped by embodiment. Combines statistics from datasets with the same embodiment tag using weighted averaging, then configures the processor with merged statistics.

get_dataset_statistics

def get_dataset_statistics(self) -> dict[str, dict[str, dict[str, dict[str, list[float]]]]]
Get the merged dataset statistics.
return
dict[str, dict[str, dict[str, dict[str, list[float]]]]]
Nested dictionary: {embodiment_tag: {modality: {joint_group: {stat_type: values}}}}

reset_seed

def reset_seed(self, seed: int) -> None
Reset the random seed and regenerate sampling schedules.
seed
int
required
New random seed to use.
print_dataset_statistics

def print_dataset_statistics(self) -> None
Print formatted dataset statistics for debugging and monitoring.

get_initial_actions

def get_initial_actions(self) -> list
Collect initial actions from all datasets.
return
list
Combined list of initial actions from all constituent datasets.

Usage example

from gr00t.data.dataset import (
    ShardedSingleStepDataset,
    ShardedMixtureDataset,
)
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.data.types import ModalityConfig

# Create individual datasets
dataset1 = ShardedSingleStepDataset(
    dataset_path="/path/to/dataset1",
    embodiment_tag=EmbodimentTag.UNITREE_G1,
    modality_configs=modality_configs,
    shard_size=1024,
)

dataset2 = ShardedSingleStepDataset(
    dataset_path="/path/to/dataset2",
    embodiment_tag=EmbodimentTag.UNITREE_G1,
    modality_configs=modality_configs,
    shard_size=1024,
)

dataset3 = ShardedSingleStepDataset(
    dataset_path="/path/to/dataset3",
    embodiment_tag=EmbodimentTag.GR1,
    modality_configs=modality_configs,
    shard_size=1024,
)

# Combine with mixture dataset
mixture = ShardedMixtureDataset(
    datasets=[dataset1, dataset2, dataset3],
    weights=[0.5, 0.3, 0.2],  # 50% dataset1, 30% dataset2, 20% dataset3
    processor=my_processor,
    num_shards_per_epoch=10000,
)

# Print dataset info
mixture.print_dataset_statistics()

# Iterate over batches
for batch in mixture:
    # batch contains processed data from mixed datasets
    model_output = model(**batch)

LeRobotEpisodeLoader

from gr00t.data.dataset import LeRobotEpisodeLoader
Episode-level data loader for LeRobot format datasets. This class handles the loading and preprocessing of individual episodes from LeRobot datasets. It manages metadata parsing, video decoding, and data extraction across multiple modalities (video, state, action, language).

Constructor

def __init__(
    self,
    dataset_path: str | Path,
    modality_configs: dict[str, ModalityConfig],
    video_backend: str = "torchcodec",
    video_backend_kwargs: dict[str, Any] | None = None,
)
dataset_path
str | Path
required
Path to dataset root directory containing meta/ and data files.
modality_configs
dict[str, ModalityConfig]
required
Dictionary mapping modality names to ModalityConfig objects that specify temporal sampling and data keys to load.
video_backend
str
default:"torchcodec"
Video decoding backend ('torchcodec', 'decord', etc.).
video_backend_kwargs
dict[str, Any] | None
default:"None"
Additional arguments for the video backend.

Methods

__getitem__

def __getitem__(self, idx: int) -> pd.DataFrame
Load complete episode data as a processed DataFrame.
idx
int
required
Episode index to load.
return
pd.DataFrame
DataFrame with columns for all modalities and timestamps, with video frames as PIL Images.

get_dataset_statistics

def get_dataset_statistics(self) -> dict[str, Any]
Extract dataset statistics for normalization from loaded metadata.
return
dict[str, Any]
Nested dictionary: {modality: {joint_group: {stat_type: values}}}

get_initial_actions

def get_initial_actions(self) -> list
Load initial actions for policy initialization if available.
return
list
List containing initial action dictionaries, or empty list if not available.

Usage example

from gr00t.data.dataset import LeRobotEpisodeLoader
from gr00t.data.types import ModalityConfig

loader = LeRobotEpisodeLoader(
    dataset_path="/path/to/lerobot_dataset",
    modality_configs={
        "video": ModalityConfig(
            delta_indices=[0],
            modality_keys=["front_cam"],
        ),
        "state": ModalityConfig(
            delta_indices=[0],
            modality_keys=["joint_positions"],
        ),
        "action": ModalityConfig(
            delta_indices=list(range(16)),
            modality_keys=["joint_velocities"],
        ),
    },
)

# Load first episode as DataFrame
episode_data = loader[0]
print(episode_data.columns)
# ['state.joint_positions', 'action.joint_velocities', 'video.front_cam', ...]

# Get dataset statistics
stats = loader.get_dataset_statistics()

Build docs developers (and LLMs) love