This guide explains how to use the Gr00tPolicy class to load and run inference with your trained model. After training, you’ll use this API to integrate your model with evaluation environments.

Loading the policy

Initialize a policy by providing the embodiment tag, model checkpoint path, and device:
from gr00t.policy import Gr00tPolicy
from gr00t.data.embodiment_tags import EmbodimentTag

# Load your trained model
policy = Gr00tPolicy(
    model_path="/path/to/your/checkpoint",
    embodiment_tag=EmbodimentTag.NEW_EMBODIMENT,
    device="cuda:0",  # or "cpu", or device index like 0
    strict=True  # Enable input/output validation (recommended during development)
)

Parameters

  • model_path: Path to your trained model checkpoint directory
  • embodiment_tag: The embodiment tag you used during training (e.g., EmbodimentTag.NEW_EMBODIMENT)
  • device: Device to run inference on ("cuda:0", "cpu", or integer device index)
  • strict: Whether to validate inputs/outputs (recommended during development, can disable in production)

Understanding the observation format

The policy expects observations as a nested dictionary with three modalities:
observation = {
    "video": {
        "camera_name": np.ndarray,  # Shape: (B, T, H, W, 3), dtype: uint8
        # ... one entry per camera
    },
    "state": {
        "state_name": np.ndarray,   # Shape: (B, T, D), dtype: float32
        # ... one entry per state stream
    },
    "language": {
        "task": [[str]],            # Shape: (B, 1), list of lists of strings
    }
}

Dimensions

  • B: Batch size (number of parallel environments)
  • T: Temporal horizon (number of historical observations)
  • H, W: Image height and width
  • D: State dimension
  • C: Number of channels (the trailing 3 in the video shape; must be 3 for RGB)

Data type requirements

  • Videos must be np.uint8 arrays with RGB pixel values in range [0, 255]
  • States must be np.float32 arrays
  • Language instructions are lists of lists of strings
The temporal horizon T is determined by your model’s training configuration. Different modalities may have different temporal horizons (query via get_modality_config()).
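To catch format mistakes early, a small structural check can be run before every call. This is a hypothetical helper, not part of the gr00t API, and the shapes in the dummy observation (B=1, T=1, 64x64 images, 7-dim state) are made up for illustration:

```python
import numpy as np

def validate_observation(obs):
    """Structural check for the observation dict described above.
    (Hypothetical helper, not part of the gr00t API.)"""
    for name, arr in obs["video"].items():
        assert arr.dtype == np.uint8, f"video '{name}' must be uint8"
        assert arr.ndim == 5 and arr.shape[-1] == 3, \
            f"video '{name}' must have shape (B, T, H, W, 3)"
    for name, arr in obs["state"].items():
        assert arr.dtype == np.float32, f"state '{name}' must be float32"
        assert arr.ndim == 3, f"state '{name}' must have shape (B, T, D)"
    tasks = obs["language"]["task"]
    assert isinstance(tasks, list) and all(isinstance(t, list) for t in tasks), \
        "language 'task' must be a list of lists of strings"

# Dummy observation with made-up sizes
obs = {
    "video": {"wrist_cam": np.zeros((1, 1, 64, 64, 3), dtype=np.uint8)},
    "state": {"joints": np.zeros((1, 1, 7), dtype=np.float32)},
    "language": {"task": [["pick up the cube"]]},
}
validate_observation(obs)  # raises AssertionError on any mismatch
```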

Understanding the action format

The policy returns actions in a similar nested structure:
action = {
    "action_name": np.ndarray,  # Shape: (B, T, D), dtype: float32
    # ... one entry per action stream
}

Dimensions

  • B: Batch size (matches input batch size)
  • T: Action horizon (number of future action steps to predict)
  • D: Action dimension (e.g., 7 for arm joints, 1 for gripper)
Actions are returned in physical units (e.g., joint positions in radians, velocities in rad/s), not normalized values, so they are ready to send to your robot controller.

Running inference

Use the get_action() method to compute actions from observations:
# Get action from current observation
action, info = policy.get_action(observation)

# Access the action array
arm_action = action["action_name"]  # Shape: (B, T, D)

# Extract the first action to execute
next_action = arm_action[:, 0, :]  # Shape: (B, D)
The method returns a tuple of:
  • action: Dictionary of action arrays
  • info: Dictionary of additional information (currently empty, reserved for future use)

Querying modality configurations

To understand what observations your policy expects and what actions it produces, query the modality configuration:
# Get modality configs for your embodiment
modality_configs = policy.get_modality_config()

# Check what camera keys are expected
video_keys = modality_configs["video"].modality_keys
print(f"Expected cameras: {video_keys}")

# Check video temporal horizon
video_horizon = len(modality_configs["video"].delta_indices)
print(f"Video frames needed: {video_horizon}")

# Check state keys and horizon
state_keys = modality_configs["state"].modality_keys
state_horizon = len(modality_configs["state"].delta_indices)
print(f"Expected states: {state_keys}, horizon: {state_horizon}")

# Check action keys and horizon
action_keys = modality_configs["action"].modality_keys
action_horizon = len(modality_configs["action"].delta_indices)
print(f"Action outputs: {action_keys}, horizon: {action_horizon}")
This is especially useful when:
  • You’re unsure what observations your trained model expects
  • You need to verify the temporal horizons for each modality
  • You’re debugging observation/action format mismatches
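The configs also make it easy to generate zero-filled dummy observations for smoke testing. The sketch below uses `SimpleNamespace` stand-ins for the config objects (the real ones expose `.modality_keys` and `.delta_indices` as queried above); the image size and state dimension are assumptions, so read yours from your dataset metadata:

```python
import numpy as np
from types import SimpleNamespace

# Stand-ins for the configs returned by get_modality_config()
modality_configs = {
    "video": SimpleNamespace(modality_keys=["wrist_cam"], delta_indices=[0]),
    "state": SimpleNamespace(modality_keys=["joints"], delta_indices=[0]),
}

def make_dummy_observation(configs, batch_size=1, img_hw=(64, 64), state_dim=7):
    """Build a zero-filled observation matching the configs.
    Image size and state dimension are assumptions here."""
    h, w = img_hw
    t_video = len(configs["video"].delta_indices)
    t_state = len(configs["state"].delta_indices)
    return {
        "video": {k: np.zeros((batch_size, t_video, h, w, 3), dtype=np.uint8)
                  for k in configs["video"].modality_keys},
        "state": {k: np.zeros((batch_size, t_state, state_dim), dtype=np.float32)
                  for k in configs["state"].modality_keys},
        "language": {"task": [["dummy task"]] * batch_size},
    }

dummy_obs = make_dummy_observation(modality_configs)
```

Feeding such a dummy observation through `policy.get_action()` verifies the pipeline end to end before any real sensors are involved.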

Resetting the policy

Reset the policy between episodes:
# Reset policy state (if any) between episodes
info = policy.reset()
Currently, the policy is stateless, but calling reset() is good practice for future compatibility.

Adapting the policy to your environment

Most environments use different observation/action formats than the Policy API expects. You’ll typically need to write a policy wrapper that:
  1. Transforms observations: Convert your environment’s observation format to the Policy API format
  2. Calls the policy: Use policy.get_action() to compute actions
  3. Transforms actions: Convert the policy’s actions back to your environment’s format

Example workflow

# In your environment loop
env_obs = env.reset()  # Environment-specific format

# Transform to Policy API format
policy_obs = transform_observation(env_obs)

# Get action from policy
policy_action, _ = policy.get_action(policy_obs)

# Transform back to environment format
env_action = transform_action(policy_action)

# Execute in environment
env_obs, reward, done, info = env.step(env_action)
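The three steps can also be packaged as a wrapper class. This is a sketch only: the camera, state, and action key names are assumptions, and it assumes a single environment with single-frame observations (see `Gr00tSimPolicyWrapper` in the repo for the reference version):

```python
import numpy as np

class PolicyWrapper:
    """Hypothetical environment adapter implementing the three steps above."""

    def __init__(self, policy, camera_key="wrist_cam", state_key="joints",
                 action_key="arm_action"):
        self.policy = policy
        self.camera_key = camera_key
        self.state_key = state_key
        self.action_key = action_key

    def get_env_action(self, env_obs):
        policy_obs = self._transform_observation(env_obs)   # step 1
        policy_action, _ = self.policy.get_action(policy_obs)  # step 2
        return self._transform_action(policy_action)        # step 3

    def _transform_observation(self, env_obs):
        # Add batch and time dimensions (B=1, T=1) and enforce dtypes.
        return {
            "video": {self.camera_key: env_obs["image"][None, None].astype(np.uint8)},
            "state": {self.state_key: env_obs["state"][None, None].astype(np.float32)},
            "language": {"task": [[env_obs["instruction"]]]},
        }

    def _transform_action(self, policy_action):
        # Take the first predicted step for the single environment.
        return policy_action[self.action_key][0, 0]

# Stub policy so the sketch runs without a trained model
class _DummyPolicy:
    def get_action(self, obs):
        return {"arm_action": np.zeros((1, 16, 7), dtype=np.float32)}, {}

wrapper = PolicyWrapper(_DummyPolicy())
env_obs = {"image": np.zeros((64, 64, 3)), "state": np.zeros(7),
           "instruction": "pick up the cube"}
env_action = wrapper.get_env_action(env_obs)
```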

Server-client architecture for remote inference

For many use cases, especially when working with real robots or distributed systems, you may want to run the policy on a separate machine (e.g., a GPU server) and send observations/actions over the network.

Why use server-client architecture?

  • Separate compute resources: Run policy inference on a GPU server while controlling the robot from a different machine
  • Dependency isolation: Keep the heavy inference dependencies on the server, so the robot-side client environment stays lightweight

Starting the policy server

python gr00t/eval/run_gr00t_server.py \
    --embodiment-tag NEW_EMBODIMENT \
    --model-path /path/to/your/checkpoint \
    --device cuda:0 \
    --host 0.0.0.0 \
    --port 5555 \
    --strict True

Parameters

  • --embodiment-tag: The embodiment tag for your robot (e.g., NEW_EMBODIMENT)
  • --model-path: Path to your trained model checkpoint directory
  • --device: Device to run inference on (cuda:0, cuda:1, cpu, etc.)
  • --host: Host address (127.0.0.1 for local only, 0.0.0.0 to accept external connections)
  • --port: Port number (default: 5555)
  • --strict: Enable input/output validation (default: True)

Using the policy client

On the client side, use PolicyClient to connect to the server:
from gr00t.policy.server_client import PolicyClient

# Connect to the policy server
policy = PolicyClient(
    host="localhost",  # or IP address of your GPU server
    port=5555,
    timeout_ms=15000,  # 15 second timeout for inference
)

# Verify connection
if not policy.ping():
    raise RuntimeError("Cannot connect to policy server!")

# Use just like a regular policy
observation = get_observation()  # Your observation in Policy API format
action, info = policy.get_action(observation)
The PolicyClient implements the same BasePolicy interface, so it’s a drop-in replacement for Gr00tPolicy.

Common patterns

Batched inference

The policy supports batched inference for efficiency:
# Run 4 environments in parallel
batch_size = 4
observation = {
    "video": {"wrist_cam": np.zeros((batch_size, T_video, H, W, 3), dtype=np.uint8)},
    "state": {"joints": np.zeros((batch_size, T_state, D_state), dtype=np.float32)},
    "language": {"task": [["pick up the cube"]] * batch_size},
}

action, _ = policy.get_action(observation)
# action["action_name"] has shape (batch_size, action_horizon, action_dim)

Single environment inference

For single environments, use batch size of 1:
# Add batch dimension (B=1)
observation = {
    "video": {"wrist_cam": video[np.newaxis, ...]},  # (1, T, H, W, 3)
    "state": {"joints": state[np.newaxis, ...]},     # (1, T, D)
    "language": {"task": [["pick up the cube"]]},    # List of length 1
}

action, _ = policy.get_action(observation)

# Remove batch dimension
single_action = action["action_name"][0]  # (action_horizon, action_dim)

Action chunking

When the action horizon T > 1, you can use action chunking:
action, _ = policy.get_action(observation)
action_chunk = action["action_name"]  # (B, T, D)

# Execute actions over multiple timesteps
for t in range(action_chunk.shape[1]):
    env.step(action_chunk[:, t, :])
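A common receding-horizon variant executes only the first few steps of each chunk and then re-queries the policy so it can react to fresh observations. This is a sketch, not a gr00t API; the number of executed steps is a tuning choice, and the stub policy exists only so the example runs:

```python
import numpy as np

EXECUTE_STEPS = 4  # how many steps of each chunk to execute; a tuning choice

def run_chunked(policy, get_observation, send_to_robot, n_chunks=2):
    """Execute the first EXECUTE_STEPS actions of each chunk, then replan."""
    for _ in range(n_chunks):
        obs = get_observation()
        action, _ = policy.get_action(obs)
        chunk = action["action_name"]  # (B, T, D)
        for t in range(min(EXECUTE_STEPS, chunk.shape[1])):
            send_to_robot(chunk[:, t, :])

# Stubs so the sketch runs without a trained model or robot
class _DummyPolicy:
    def get_action(self, obs):
        return {"action_name": np.zeros((1, 16, 7), dtype=np.float32)}, {}

sent = []
run_chunked(_DummyPolicy(), lambda: {}, sent.append, n_chunks=2)
# 2 chunks x 4 executed steps = 8 commands sent
```

Executing fewer steps per chunk trades extra inference calls for faster reaction to disturbances.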

Troubleshooting

  1. Enable strict mode during development: strict=True
  2. Print modality configs to understand expected formats
  3. Check shapes of your observations before calling get_action()
  4. Use the reference wrapper (Gr00tSimPolicyWrapper) as a template
  5. Validate incrementally: Test with dummy observations first before connecting to real environments
