The rfx.nn module provides simple, transparent policy implementations using tinygrad tensors. It follows the tinygrad philosophy: simple to start, powerful enough for production.
from rfx.nn import MLP, go2_mlp
from tinygrad import Tensor

policy = go2_mlp()
obs = Tensor.randn(1, 48)
actions = policy(obs)  # JIT compiled on the second call

Base Classes

Policy

Base policy class for neural network policies. Users can subclass this to create custom architectures.
from rfx.nn import Policy
from tinygrad import Tensor
from tinygrad.nn import Linear

class CustomPolicy(Policy):
    def __init__(self):
        self.l1 = Linear(48, 256)
        self.l2 = Linear(256, 12)

    def forward(self, obs: Tensor) -> Tensor:
        x = self.l1(obs).tanh()
        return self.l2(x).tanh()

Methods

forward(obs) Forward pass: observations → actions. Subclasses must implement this.
  • obs (Tensor): Observation tensor of shape (batch, obs_dim).
Returns: Tensor - Action tensor of shape (batch, act_dim).
__call__(obs) Run inference (JIT compiled after first call).
  • obs (Tensor): Observation tensor.
Returns: Tensor - Action tensor.
save(path, *, robot_config=None, normalizer=None, training_info=None) Save policy as a self-describing directory.
  • path (str | Path, required): Directory path to save into.
  • robot_config (RobotConfig | None): Optional RobotConfig to bundle.
  • normalizer (ObservationNormalizer | None): Optional ObservationNormalizer to bundle.
  • training_info (dict[str, Any] | None): Optional training metadata dict.
Returns: Path - The directory path.
Policy.load(path) (classmethod) Load a self-describing policy from a directory.
  • path (str | Path, required): Directory path or legacy .safetensors file.
Returns: Policy - Policy instance with loaded weights.

Policy Networks

MLP

Multi-layer perceptron policy with tanh activations. Suitable for most locomotion tasks.
from rfx.nn import MLP
from tinygrad import Tensor

policy = MLP(obs_dim=48, act_dim=12, hidden=[256, 256])
obs = Tensor.randn(1, 48)
actions = policy(obs)
print(actions.shape)  # (1, 12)
  • obs_dim (int, required): Observation dimension.
  • act_dim (int, required): Action dimension.
  • hidden (list[int] | None, default [256, 256]): List of hidden layer sizes.

JitPolicy

A policy wrapper that enables TinyJit compilation for faster inference. The first call traces the computation graph, subsequent calls are fast.
from rfx.nn import MLP, JitPolicy

mlp = MLP(48, 12)
jit_policy = JitPolicy(mlp)

# First call: traces graph
actions = jit_policy(obs)

# Second call: runs compiled kernel
actions = jit_policy(obs)
  • policy (Policy, required): The policy to wrap.
Note: Save/load delegates to the inner policy — JIT compilation state is re-created on load automatically.

ActorCritic

Actor-critic network for PPO training. Shares a backbone between actor (policy) and critic (value function).
from rfx.nn import ActorCritic
from tinygrad import Tensor

ac = ActorCritic(48, 12)
obs = Tensor.randn(32, 48)
actions, values = ac.forward_actor_critic(obs)
  • obs_dim (int, required): Observation dimension.
  • act_dim (int, required): Action dimension.
  • hidden (list[int] | None, default [256, 256]): Hidden layer sizes for the shared backbone.

Methods

forward_actor_critic(obs) Get both actions and values (for training).
  • obs (Tensor): Observation tensor.
Returns: tuple[Tensor, Tensor] - Tuple of (action_mean, value).
get_action_and_value(obs, action=None) Sample an action and compute its log prob and entropy (for the PPO update).
  • obs (Tensor): Observations.
  • action (Tensor | None): Optional pre-sampled action (for computing log prob).
Returns: tuple[Tensor, Tensor, Tensor, Tensor] - Tuple of (action, log_prob, entropy, value).

TorchJitPolicy

Policy wrapper for PyTorch TorchScript (.pt) models. Loads a model saved with torch.jit.save() and exposes it through the standard rfx Policy interface.
from rfx.nn import TorchJitPolicy
import torch

policy = TorchJitPolicy(
    model_path="model.pt",
    obs_keys=["state"],
    device="cpu"
)

obs = {"state": torch.randn(1, 48)}
actions = policy(obs)
  • model_path (str | Path | None): Path to the .pt TorchScript file.
  • obs_keys (list[str] | None, default ['state']): Observation dict keys to concatenate as model input.
  • device (str, default 'cpu'): Torch device string.
Note: Accepts dict[str, torch.Tensor] observations directly — no tinygrad conversion.

Convenience Constructors

go2_mlp

Create an MLP policy sized for the Go2 robot.
  • Go2 observation space: 48 dimensions
  • Go2 action space: 12 dimensions (joint positions)
from rfx.nn import go2_mlp

policy = go2_mlp()  # Uses [256, 256] hidden layers
policy = go2_mlp(hidden=[512, 512])  # Custom hidden sizes
  • hidden (list[int] | None, default [256, 256]): Hidden layer sizes.
Returns: MLP - MLP policy for Go2.

go2_actor_critic

Create an ActorCritic network sized for the Go2 robot.
from rfx.nn import go2_actor_critic

ac = go2_actor_critic()
ac = go2_actor_critic(hidden=[512, 512])  # Custom hidden sizes
  • hidden (list[int] | None, default [256, 256]): Hidden layer sizes.
Returns: ActorCritic - ActorCritic network for Go2.

Utilities

register_policy

Register a policy class for auto-detection during load.
from rfx.nn import Policy, register_policy

@register_policy
class CustomPolicy(Policy):
    def __init__(self, obs_dim, act_dim):
        ...

    def forward(self, obs):
        ...
  • cls (type[Policy], required): Policy class to register.
Returns: type[Policy] - The same class (decorator).
