The rfx.nn module provides simple, transparent policy implementations built on tinygrad tensors. It follows the tinygrad philosophy: simple to start with, powerful enough for production.
Base Classes
Policy
Base policy class for neural network policies. Users can subclass this to create custom architectures.
Methods
forward(obs)
Forward pass: observations → actions. Subclasses must implement this.
obs: Observation tensor of shape (batch, obs_dim).
Returns: Tensor - Action tensor of shape (batch, act_dim).
__call__(obs)
Run inference (JIT compiled after first call).
obs: Observation tensor.
Returns: Tensor - Action tensor.
save(path, *, robot_config=None, normalizer=None, training_info=None)
Save policy as a self-describing directory.
path: Directory path to save into.
robot_config: Optional RobotConfig to bundle.
normalizer: Optional ObservationNormalizer to bundle.
training_info: Optional training metadata dict.
Returns: Path - The directory path.
Policy.load(path) (classmethod)
Load a self-describing policy from a directory.
path: Directory path or legacy .safetensors file.
Returns: Policy - Policy instance with loaded weights.
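To make the "self-describing directory" idea concrete, here is a minimal stdlib-only sketch of a save/load round trip. The file names (weights.json, metadata.json) and JSON layout are assumptions for illustration, not rfx's actual on-disk format, which bundles safetensors weights alongside the metadata.

```python
import json
import tempfile
from pathlib import Path

def save_policy_dir(path, weights, robot_config=None, training_info=None):
    """Sketch: write weights plus a metadata file that records everything
    needed to reload. File names and layout are illustrative assumptions."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    (path / "weights.json").write_text(json.dumps(weights))
    meta = {"robot_config": robot_config, "training_info": training_info}
    (path / "metadata.json").write_text(json.dumps(meta))
    return path

def load_policy_dir(path):
    """Sketch: read back weights and metadata from the directory."""
    path = Path(path)
    weights = json.loads((path / "weights.json").read_text())
    meta = json.loads((path / "metadata.json").read_text())
    return weights, meta

with tempfile.TemporaryDirectory() as d:
    save_policy_dir(Path(d) / "policy", {"w": [1.0, 2.0]},
                    training_info={"steps": 1000})
    weights, meta = load_policy_dir(Path(d) / "policy")
```

The advantage of a directory over a bare weights file is that the robot config, normalizer, and training metadata travel with the policy, so load() can reconstruct everything from the path alone.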
Policy Networks
MLP
Multi-layer perceptron policy with tanh activations. Suitable for most locomotion tasks.
Observation dimension.
Action dimension.
List of hidden layer sizes.
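The computation an MLP policy performs can be sketched in pure Python (using plain lists in place of tinygrad tensors, purely for illustration). The weight layout and the choice of a linear final layer are assumptions, not the rfx.nn implementation.

```python
import math
import random

def mlp_forward(obs, weights, biases):
    """Sketch of a tanh MLP forward pass: obs -> hidden layers (tanh) -> action.
    weights[i] is an (out, in) matrix as nested lists; the final layer is
    linear (no tanh), a common choice for continuous-action policies."""
    x = obs
    for i, (W, b) in enumerate(zip(weights, biases)):
        x = [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]
        if i < len(weights) - 1:  # tanh on hidden layers only
            x = [math.tanh(v) for v in x]
    return x

def rand_layer(n_out, n_in):
    # Small random weights, zero biases (illustrative initialization).
    return ([[random.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)],
            [0.0] * n_out)

# Illustrative sizes: obs_dim=4, hidden=[8], act_dim=2.
random.seed(0)
(W1, b1), (W2, b2) = rand_layer(8, 4), rand_layer(2, 8)
action = mlp_forward([0.5, -0.2, 0.1, 0.0], [W1, W2], [b1, b2])
```

In the real module the same structure runs on batched tinygrad tensors, so a single matmul per layer replaces the inner loops.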
JitPolicy
A policy wrapper that enables TinyJit compilation for faster inference. The first call traces the computation graph; subsequent calls are fast.
The policy to wrap.
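The trace-on-first-call pattern behind this wrapper can be sketched generically. This is not TinyJit's internals, just the shape of the idea: expensive setup happens exactly once, and later calls reuse the result.

```python
class CompileOnceWrapper:
    """Sketch of the JIT wrapper pattern: the first call does expensive
    setup (standing in for graph tracing), later calls hit a cached fast
    path. Names here are illustrative, not the TinyJit API."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled = None
        self.trace_count = 0

    def __call__(self, x):
        if self.compiled is None:
            self.trace_count += 1   # a real JIT would trace and emit kernels here
            self.compiled = self.fn
        return self.compiled(x)

wrapped = CompileOnceWrapper(lambda x: x * 2)
results = [wrapped(i) for i in range(3)]
```

One practical consequence of tracing: the first call is slower than an uncompiled call, so benchmarks should warm the wrapper up before timing.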
ActorCritic
Actor-critic network for PPO training. Shares a backbone between the actor (policy) and the critic (value function).
Observation dimension.
Action dimension.
Hidden layer sizes for shared backbone.
Methods
forward_actor_critic(obs)
Get both actions and values (for training).
obs: Observation tensor.
Returns: tuple[Tensor, Tensor] - Tuple of (action_mean, value).
get_action_and_value(obs, action=None)
Sample action and compute log prob + entropy (for PPO update).
obs: Observations.
action: Optional pre-sampled action (for computing log prob).
Returns: tuple[Tensor, Tensor, Tensor, Tensor] - Tuple of (action, log_prob, entropy, value).
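Since forward_actor_critic returns an action_mean, the policy head is presumably a Gaussian over actions, which is the standard setup for PPO with continuous control. Under that assumption, the log_prob and entropy terms are the diagonal-Gaussian formulas, sketched here in scalar Python:

```python
import math

def gaussian_log_prob(action, mean, std):
    # Log density of an independent (diagonal) Gaussian: sum over dimensions
    # of -((a - mu)^2 / (2 sigma^2)) - ln(sigma) - 0.5 ln(2 pi).
    return sum(
        -((a - m) ** 2) / (2 * s ** 2) - math.log(s) - 0.5 * math.log(2 * math.pi)
        for a, m, s in zip(action, mean, std)
    )

def gaussian_entropy(std):
    # Entropy depends only on std: sum over dimensions of
    # 0.5 ln(2 pi e) + ln(sigma).
    return sum(0.5 * math.log(2 * math.pi * math.e) + math.log(s) for s in std)

# Sanity check against the 1-D standard normal.
lp = gaussian_log_prob([0.0], [0.0], [1.0])   # log N(0; 0, 1) = -0.5 ln(2 pi)
ent = gaussian_entropy([1.0])                 # 0.5 ln(2 pi e) ~ 1.4189
```

Passing a pre-sampled action recomputes its log prob under the current policy, which is exactly what the PPO ratio pi_new(a|s) / pi_old(a|s) needs during the update.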
TorchJitPolicy
Policy wrapper for PyTorch TorchScript (.pt) models. Loads a model saved with torch.jit.save() and exposes it through the standard rfx Policy interface.
Path to the .pt TorchScript file.
Observation dict keys to concatenate as model input.
Torch device string.
Accepts dict[str, torch.Tensor] observations directly; no tinygrad conversion is performed.
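The "keys to concatenate" parameter implies a deterministic flattening of the observation dict into one input vector. A stdlib sketch of that step (key names and shapes are hypothetical, and plain lists stand in for tensors):

```python
def concat_obs(obs_dict, keys):
    """Sketch: flatten selected observation entries, in the given key
    order, into a single flat input vector. Key names are hypothetical."""
    out = []
    for k in keys:
        out.extend(obs_dict[k])
    return out

obs = {
    "joint_pos": [0.1, 0.2],
    "joint_vel": [0.0, -0.1],
    "gravity": [0.0, 0.0, -1.0],
}
x = concat_obs(obs, ["joint_pos", "joint_vel", "gravity"])
```

Making the key order explicit matters: the TorchScript model was traced with a fixed input layout, so reordering keys would silently feed it scrambled observations.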
Convenience Constructors
go2_mlp
Create an MLP policy sized for the Go2 robot.
- Go2 observation space: 48 dimensions
- Go2 action space: 12 dimensions (joint positions)
Hidden layer sizes.
Returns: MLP - MLP policy for Go2.
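With the observation and action dimensions fixed at 48 and 12, the model size follows directly from the hidden layer sizes. A small worked example, using an assumed hidden configuration of [256, 128] (illustrative, not necessarily the default):

```python
def mlp_param_count(obs_dim, hidden_sizes, act_dim):
    """Count parameters of a fully connected MLP: each linear layer
    contributes out*in weights plus out biases."""
    sizes = [obs_dim] + list(hidden_sizes) + [act_dim]
    return sum(n_out * n_in + n_out for n_in, n_out in zip(sizes, sizes[1:]))

# Go2: 48 observations -> hidden layers -> 12 joint-position actions.
count = mlp_param_count(48, [256, 128], 12)
# 48*256+256 = 12544, 256*128+128 = 32896, 128*12+12 = 1548 -> 46988 total
```

Networks this small are why plain MLPs remain the default for locomotion: inference cost is negligible even on embedded hardware.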
go2_actor_critic
Create an ActorCritic network sized for the Go2 robot.
Hidden layer sizes.
Returns: ActorCritic - ActorCritic network for Go2.
Utilities
register_policy
Register a policy class for auto-detection during load.
Policy class to register.
Returns: type[Policy] - The same class (decorator).