rfx.utils

Utility functions and helpers for multi-embodiment training, including DOF padding and observation/action transforms.

Padding Functions

DOF padding utilities for training models across robots with different numbers of joints.

PaddingConfig

Configuration for state/action padding.

from rfx.utils import PaddingConfig

config = PaddingConfig(
    state_dim=24,
    action_dim=12,
    max_state_dim=64,
    max_action_dim=64
)

Constructor

PaddingConfig(
    state_dim: int,
    action_dim: int,
    max_state_dim: int = 64,
    max_action_dim: int = 64
)
state_dim (int, required): Actual state dimension for this robot.
action_dim (int, required): Actual action dimension for this robot.
max_state_dim (int, default: 64): Maximum state dimension across all robots in the dataset.
max_action_dim (int, default: 64): Maximum action dimension across all robots in the dataset.
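PaddingConfig itself only stores dimensions; the padding behavior it drives can be sketched as follows. This is a minimal sketch, not the rfx implementation: `PaddingConfigSketch` and `pad_to` are illustrative names, and zero-padding on the right of the last dimension is an assumption.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class PaddingConfigSketch:
    """Illustrative stand-in for rfx.utils.PaddingConfig."""
    state_dim: int
    action_dim: int
    max_state_dim: int = 64
    max_action_dim: int = 64


def pad_to(x: torch.Tensor, dim: int, max_dim: int) -> torch.Tensor:
    """Zero-pad the last dimension from `dim` to `max_dim` (assumed behavior)."""
    assert x.shape[-1] == dim
    return F.pad(x, (0, max_dim - dim))  # pad only the right side of the last dim


config = PaddingConfigSketch(state_dim=24, action_dim=12)
state = torch.randn(32, config.state_dim)
padded = pad_to(state, config.state_dim, config.max_state_dim)
print(padded.shape)  # torch.Size([32, 64])
```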

pad_state

Pad state tensor to max_state_dim.

from rfx.utils import pad_state
import torch

state = torch.randn(32, 24)  # (batch, state_dim)
padded = pad_state(state, state_dim=24, max_state_dim=64)
print(padded.shape)  # (32, 64)

pad_state(
    state: torch.Tensor,
    state_dim: int,
    max_state_dim: int
) -> torch.Tensor

state (torch.Tensor, required): State tensor of shape (batch, state_dim) or (batch, seq_len, state_dim).
state_dim (int, required): Actual state dimension.
max_state_dim (int, required): Target padded dimension.

Returns (torch.Tensor): Padded state tensor of shape (batch, max_state_dim) or (batch, seq_len, max_state_dim).
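Because only the last dimension is padded, a single implementation covers both documented input shapes. A sketch under the assumption that pad_state zero-pads on the right (`pad_state_sketch` is illustrative, not the rfx function):

```python
import torch
import torch.nn.functional as F


def pad_state_sketch(state: torch.Tensor, state_dim: int, max_state_dim: int) -> torch.Tensor:
    """Illustrative reimplementation; assumes zero-padding on the right."""
    if state.shape[-1] != state_dim:
        raise ValueError(f"expected last dim {state_dim}, got {state.shape[-1]}")
    return F.pad(state, (0, max_state_dim - state_dim))


flat = pad_state_sketch(torch.randn(32, 24), 24, 64)      # (batch, state_dim)
seq = pad_state_sketch(torch.randn(32, 10, 24), 24, 64)   # (batch, seq_len, state_dim)
print(flat.shape, seq.shape)  # torch.Size([32, 64]) torch.Size([32, 10, 64])
```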

pad_action

Pad action tensor to max_action_dim.

from rfx.utils import pad_action
import torch

action = torch.randn(32, 12)  # (batch, action_dim)
padded = pad_action(action, action_dim=12, max_action_dim=64)
print(padded.shape)  # (32, 64)

pad_action(
    action: torch.Tensor,
    action_dim: int,
    max_action_dim: int
) -> torch.Tensor

action (torch.Tensor, required): Action tensor of shape (batch, action_dim) or (batch, seq_len, action_dim).
action_dim (int, required): Actual action dimension.
max_action_dim (int, required): Target padded dimension.

Returns (torch.Tensor): Padded action tensor of shape (batch, max_action_dim) or (batch, seq_len, max_action_dim).

unpad_action

Extract actual action from padded tensor.

from rfx.utils import unpad_action
import torch

padded_action = torch.randn(32, 64)  # (batch, max_action_dim)
action = unpad_action(padded_action, action_dim=12)
print(action.shape)  # (32, 12)

unpad_action(
    action: torch.Tensor,
    action_dim: int
) -> torch.Tensor

action (torch.Tensor, required): Padded action tensor.
action_dim (int, required): Actual action dimension to extract.

Returns (torch.Tensor): Unpadded action tensor of shape (..., action_dim).
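The documented shapes imply that unpadding slices the leading action_dim entries, so a pad/unpad round trip is lossless. A sketch with illustrative reimplementations (the zero-pad-right and slice-left behavior is an assumption; these are not the rfx functions):

```python
import torch
import torch.nn.functional as F


def pad_action_sketch(action: torch.Tensor, action_dim: int, max_action_dim: int) -> torch.Tensor:
    """Zero-pad the action dimension on the right (assumed behavior)."""
    assert action.shape[-1] == action_dim
    return F.pad(action, (0, max_action_dim - action_dim))


def unpad_action_sketch(action: torch.Tensor, action_dim: int) -> torch.Tensor:
    """Keep only the leading true-action entries (assumed behavior)."""
    return action[..., :action_dim]


action = torch.randn(32, 12)
padded = pad_action_sketch(action, 12, 64)       # (32, 64)
recovered = unpad_action_sketch(padded, 12)      # (32, 12)
print(torch.equal(recovered, action))  # True — the round trip is lossless
```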

Transform Classes

Observation and action transforms for training.

ObservationNormalizer

Running mean/std normalizer for observations using Welford's algorithm.

from rfx.utils import ObservationNormalizer
import torch

normalizer = ObservationNormalizer(state_dim=24, clip=10.0)

# Update statistics
states = torch.randn(1000, 24)
normalizer.update(states)

# Normalize observations
obs = {"state": torch.randn(32, 24)}
normalized = normalizer.normalize(obs)

Constructor

ObservationNormalizer(
    state_dim: int,
    clip: float = 10.0,
    eps: float = 1e-8
)
state_dim (int, required): Dimension of the state to normalize.
clip (float, default: 10.0): Maximum absolute value after normalization.
eps (float, default: 1e-8): Small constant for numerical stability.

Methods

update

normalizer.update(state: torch.Tensor) -> None

Update running statistics with a batch of states.

state (torch.Tensor, required): Batch of states, shape (batch_size, state_dim).

normalize

normalizer.normalize(obs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]

Normalize observations using running statistics.

obs (dict[str, torch.Tensor], required): Observation dictionary containing "state" and optionally "images" and "language".

Returns (dict[str, torch.Tensor]): Normalized observation dictionary.

to_dict

normalizer.to_dict() -> dict[str, Any]

Serialize normalizer state to a JSON-compatible dict.

Returns (dict[str, Any]): Dictionary containing all normalizer state.

from_dict

@classmethod
ObservationNormalizer.from_dict(d: dict[str, Any]) -> ObservationNormalizer

Reconstruct a normalizer from a serialized dict.

d (dict[str, Any], required): Serialized normalizer state.

Returns (ObservationNormalizer): Reconstructed normalizer.
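The running statistics can be maintained with the batched (parallel) variant of Welford's algorithm, which merges per-batch mean and variance into the running totals without storing past data. A minimal sketch — `WelfordNormalizer` is illustrative, not the rfx implementation, and the clipped z-score in normalize is an assumption based on the clip parameter:

```python
import torch


class WelfordNormalizer:
    """Illustrative running-stats normalizer using batched Welford updates."""

    def __init__(self, state_dim: int, clip: float = 10.0, eps: float = 1e-8):
        self.count = 0
        self.mean = torch.zeros(state_dim)
        self.m2 = torch.zeros(state_dim)  # running sum of squared deviations
        self.clip = clip
        self.eps = eps

    def update(self, state: torch.Tensor) -> None:
        # Merge batch statistics into running statistics (Chan et al. form).
        n = state.shape[0]
        batch_mean = state.mean(dim=0)
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean = self.mean + delta * n / total
        self.m2 = self.m2 + state.var(dim=0, unbiased=False) * n \
            + delta**2 * self.count * n / total
        self.count = total

    def normalize(self, obs: dict) -> dict:
        std = (self.m2 / max(self.count, 1)).sqrt()
        z = (obs["state"] - self.mean) / (std + self.eps)
        return {**obs, "state": z.clamp(-self.clip, self.clip)}


norm = WelfordNormalizer(state_dim=4)
data = torch.randn(1000, 4) * 3 + 5  # mean ~5, std ~3 per dimension
norm.update(data)
out = norm.normalize({"state": data})["state"]  # roughly zero-mean, unit-std
```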

ActionChunker

Action chunking for temporal abstraction. Ensembles multiple action predictions over a temporal horizon.

from rfx.utils import ActionChunker
import torch

chunker = ActionChunker(
    horizon=4,
    action_dim=12,
    ensemble_mode="exponential",
    temperature=0.5
)

# Add action chunks (e.g., from a model predicting 4 steps ahead)
for _ in range(4):
    chunk = torch.randn(1, 4, 12)  # (batch, horizon, action_dim)
    chunker.add_chunk(chunk)

# Get ensembled action for current timestep
action = chunker.get_action()  # (batch, action_dim)

Constructor

ActionChunker(
    horizon: int,
    action_dim: int,
    ensemble_mode: str = "exponential",
    temperature: float = 0.5
)
horizon (int, required): Number of timesteps to predict ahead.
action_dim (int, required): Dimension of each action.
ensemble_mode (str, default: "exponential"): How to ensemble predictions: "first", "average", or "exponential".
temperature (float, default: 0.5): Temperature for exponential weighting (only used if ensemble_mode="exponential").

Methods

add_chunk

chunker.add_chunk(chunk: torch.Tensor) -> None

Add a new action chunk prediction.

chunk (torch.Tensor, required): Action chunk of shape (batch, horizon, action_dim).

get_action

chunker.get_action() -> torch.Tensor

Get ensembled action for the current timestep.

Returns (torch.Tensor): Action tensor of shape (batch, action_dim).

reset

chunker.reset() -> None

Clear all stored chunks.

Ensemble Modes

  • "first": Use only the first prediction from the most recent chunk
  • "average": Average all predictions for the current timestep
  • "exponential": Weighted average with exponentially decaying weights based on prediction age
