Model-based reinforcement learning trains an explicit dynamics model of the environment (a world model) that the agent can query to imagine future outcomes without spending real environment steps. TorchRL supports model-based RL through a dedicated ModelBasedEnvBase hierarchy, RSSM (Recurrent State-Space Model) building blocks, and three loss modules (model, actor, and value) for each of Dreamer and DreamerV3. All components follow the standard TensorDict data model, so they plug directly into TorchRL collectors, replay buffers, and value estimators.

What is Model-Based RL?

In model-free RL the agent learns only from real environment transitions, so every improvement is paid for with real interaction and sample efficiency becomes the bottleneck. Model-based approaches instead maintain a learned transition model p(s_{t+1} | s_t, a_t) and a reward model r(s_t, a_t). Once the world model is trained, the agent can generate imagined rollouts of any chosen horizon inside it, which provides a dense, cheap training signal with far fewer real environment interactions.

Dreamer (Hafner et al. 2019) and DreamerV3 (Hafner et al. 2023) are the canonical deep model-based algorithms. Both learn a compact latent representation through a Recurrent State-Space Model and train the actor entirely on imagined trajectories. TorchRL implements the model, actor, and value loss modules for both versions.
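
The idea can be illustrated without any TorchRL code. The following is a minimal sketch in plain Python, assuming a hypothetical one-dimensional linear system: it fits a transition model to a few real transitions by least squares, then rolls a policy forward inside the fitted model. The helper names (fit_linear_model, imagine), the true dynamics, and the feedback gain are illustrative assumptions, not part of TorchRL.

```python
def fit_linear_model(transitions):
    """Least-squares fit of s_next = a * s + b * u from (s, u, s_next) triples."""
    # Accumulate the terms of the 2x2 normal equations for the unknowns (a, b).
    ss = sum(s * s for s, u, _ in transitions)
    su = sum(s * u for s, u, _ in transitions)
    uu = sum(u * u for s, u, _ in transitions)
    sy = sum(s * y for s, _, y in transitions)
    uy = sum(u * y for _, u, y in transitions)
    det = ss * uu - su * su
    a = (sy * uu - uy * su) / det
    b = (uy * ss - sy * su) / det
    return a, b


def imagine(model, policy, start_state, horizon):
    """Roll the policy forward inside the learned model (no real env calls)."""
    a, b = model
    states = [start_state]
    for _ in range(horizon):
        s = states[-1]
        states.append(a * s + b * policy(s))
    return states


# "Real" transitions from hypothetical dynamics s' = 0.9 * s + 0.5 * u.
real = [
    (s, u, 0.9 * s + 0.5 * u)
    for s, u in [(1.0, 0.0), (0.0, 1.0), (2.0, -1.0), (-1.0, 0.5)]
]
model = fit_linear_model(real)

# Imagined trajectory under a simple proportional policy u = -0.4 * s.
imagined = imagine(model, policy=lambda s: -0.4 * s, start_state=1.0, horizon=15)
```

Dreamer follows the same pattern, except that the model is a neural RSSM operating in a learned latent space and the policy is trained on the imagined trajectories.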

World Models and the RSSM

The RSSM splits the latent state into two complementary parts:
  • Deterministic belief h_t — a GRU hidden state that accumulates history without noise.
  • Stochastic state s_t — a small Gaussian random variable that captures the irreducible uncertainty about the current moment.
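
As a rough illustration of this split, the sketch below implements one prior step for a scalar toy model in plain Python. It is not TorchRL's RSSMPrior: a tanh recurrence stands in for the GRU cell, and all weights are made-up constants chosen only to show how the deterministic and stochastic paths interact.

```python
import math
import random


def rssm_prior_step(belief, state, action, rng):
    """One toy prior step: update the belief, then sample the next state."""
    # Deterministic path: a tanh recurrence stands in for the GRU cell.
    next_belief = math.tanh(0.9 * belief + 0.5 * state + 0.3 * action)
    # Stochastic path: the belief parameterises a Gaussian over the next state.
    mean = 0.8 * next_belief
    std = 0.1 + math.log1p(math.exp(next_belief))  # softplus plus a lower bound
    next_state = mean + std * rng.gauss(0.0, 1.0)  # reparameterised sample
    return next_belief, next_state, mean, std


rng = random.Random(0)
belief, state = 0.0, 0.0
for action in [1.0, -1.0, 0.5]:
    belief, state, mean, std = rssm_prior_step(belief, state, action, rng)
```

Given the same inputs, the belief update is identical regardless of the random draw; only the sampled state varies.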
TorchRL provides RSSMPrior and RSSMPosterior as standalone nn.Module components, and RSSMRollout as a TensorDictModuleBase that chains them across time.
from torchrl.modules.models.model_based import RSSMPrior, RSSMPosterior, RSSMRollout
from tensordict.nn import TensorDictModule

# Prior: p(s_{t+1} | s_t, a_t, h_t)
rssm_prior = TensorDictModule(
    RSSMPrior(action_spec=action_spec, hidden_dim=200, rnn_hidden_dim=200, state_dim=30),
    in_keys=["state", "belief", "action"],
    out_keys=["prior_mean", "prior_std", "state", "belief"],
)

# Posterior: q(s_t | h_t, o_t)
rssm_posterior = TensorDictModule(
    RSSMPosterior(hidden_dim=200, state_dim=30),
    in_keys=["belief", "encoded_obs"],
    out_keys=["posterior_mean", "posterior_std", "state"],
)

# RSSMRollout chains prior and posterior over a time dimension
rssm_rollout = RSSMRollout(rssm_prior=rssm_prior, rssm_posterior=rssm_posterior)
RSSMRollout supports three execution modes, selected with the use_scan flag and the compile_step option:

| Mode | Description |
| --- | --- |
| use_scan=False (default) | Standard Python loop over time steps |
| use_scan=True | Uses torch._higher_order_ops.scan, which is more torch.compile-friendly |
| compile_step=True | Compiles the individual step function with Inductor |

Dreamer Loss Modules

Dreamer training alternates between two phases: world-model learning from real data and actor/value learning from imagined rollouts. TorchRL maps each phase to a dedicated loss class.
Phase 1: World Model Loss

DreamerModelLoss trains the RSSM encoder, decoder, and reward head jointly:
  • KL loss: divergence between prior p(s | h, a) and posterior q(s | h, o).
  • Reconstruction loss: pixel or feature reconstruction from the decoded belief.
  • Reward loss: predicted vs. observed reward.
    from torchrl.objectives.dreamer import DreamerModelLoss
    
    world_model_loss = DreamerModelLoss(
        world_model=world_model,        # TensorDictModule wrapping encoder + RSSM + decoder
        lambda_kl=1.0,                  # weight for KL divergence term
        lambda_reco=1.0,                # weight for reconstruction term
        lambda_reward=1.0,              # weight for reward prediction term
        free_nats=3,                    # KL clamping floor (nats)
        reco_loss="l2",
        reward_loss="l2",
        delayed_clamp=False,            # clamp before (False) or after (True) averaging
    )
    
    # Forward returns (loss_td, updated_tensordict)
    loss_td, _ = world_model_loss(batch)
    # loss_td has keys: "loss_model_kl", "loss_model_reco", "loss_model_reward"
    total_model_loss = loss_td["loss_model_kl"] + loss_td["loss_model_reco"] + loss_td["loss_model_reward"]
    total_model_loss.backward()
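
The effect of free_nats and delayed_clamp on the KL term can be sketched in plain Python. This is an illustration under stated assumptions rather than TorchRL's implementation: the KL is taken between diagonal Gaussians in the direction KL(posterior || prior) and summed over state dimensions, and the helper names are hypothetical.

```python
import math


def gaussian_kl(post_mean, post_std, prior_mean, prior_std):
    """KL(posterior || prior) for diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(post_mean, post_std, prior_mean, prior_std):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


def kl_with_free_nats(kls, free_nats=3.0, delayed_clamp=False):
    """Average per-sample KL values with a free-nats floor."""
    if delayed_clamp:
        # Average first, then apply the floor to the batch mean.
        return max(sum(kls) / len(kls), free_nats)
    # Apply the floor to every sample, then average.
    return sum(max(kl, free_nats) for kl in kls) / len(kls)


per_sample_kl = [0.0, 10.0]
early = kl_with_free_nats(per_sample_kl, free_nats=3.0, delayed_clamp=False)
late = kl_with_free_nats(per_sample_kl, free_nats=3.0, delayed_clamp=True)
```

With these two samples the early clamp gives 6.5 and the delayed clamp gives 5.0, so the choice changes how strongly samples with small KL contribute to the loss.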
    
Phase 2: Actor and Value Losses

DreamerActorLoss imagines imagination_horizon steps from the starting latent states, computes lambda-return targets, and maximises them. DreamerValueLoss then fits the value network on the imagined data:
    from torchrl.objectives.dreamer import DreamerActorLoss, DreamerValueLoss
    from torchrl.envs.model_based import DreamerEnv
    
    model_env = DreamerEnv(
        world_model=world_model,
        prior_shape=(30,),
        belief_shape=(200,),
        device="cuda",
    )
    
    actor_loss = DreamerActorLoss(
        actor_model=actor,
        value_model=value_net,
        model_based_env=model_env,
        imagination_horizon=15,   # rollout length in latent space
        discount_loss=True,       # discount lambda targets by gamma^t
    )
    
    value_loss = DreamerValueLoss(
        value_model=value_net,
    )
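
The lambda-return targets mentioned above follow a simple backward recursion. The sketch below shows that recursion in plain Python for a single imagined trajectory; it only illustrates the math and does not reproduce the value estimator the loss module uses internally. The function name and the gamma and lmbda defaults are assumptions.

```python
def lambda_returns(rewards, next_values, gamma=0.99, lmbda=0.95):
    """TD(lambda) targets: G_t = r_t + gamma * ((1 - lmbda) * V_{t+1} + lmbda * G_{t+1}).

    rewards[t] is the reward at step t and next_values[t] is V(s_{t+1}).
    The recursion is bootstrapped with the last value estimate.
    """
    targets = [0.0] * len(rewards)
    running = next_values[-1]  # bootstrap beyond the imagination horizon
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * ((1 - lmbda) * next_values[t] + lmbda * running)
        targets[t] = running
    return targets


# lmbda=0 reduces to one-step TD targets; lmbda=1 to bootstrapped Monte Carlo returns.
td_targets = lambda_returns([1.0, 1.0, 1.0], [2.0, 3.0, 4.0], gamma=0.5, lmbda=0.0)
mc_targets = lambda_returns([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], gamma=0.5, lmbda=1.0)
```

Intermediate values of lmbda trade off the bias of one-step bootstrapping against the variance of long imagined returns.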
    

DreamerV3 Loss Modules

DreamerV3 introduces KL balancing, symlog-compressed reconstruction, and a two-hot categorical value distribution. TorchRL exposes these through three parallel classes in torchrl.objectives.dreamer_v3:
    from torchrl.objectives.dreamer_v3 import (
        DreamerV3ModelLoss,
        DreamerV3ActorLoss,
        DreamerV3ValueLoss,
        symlog,
        symexp,
        two_hot_encode,
        two_hot_decode,
    )
    
The symmetric log/exp transforms and two-hot utilities are also available as standalone functions that you can use in custom models:
    import torch
    from torchrl.objectives.dreamer_v3 import symlog, symexp
    
    x = torch.tensor([-100.0, 0.0, 100.0])
    compressed = symlog(x)        # sign(x) * log(|x| + 1)
    restored = symexp(compressed) # inverse transform
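
For intuition, the transforms and the two-hot encoding can be written out for scalars in plain Python. This is a sketch of the underlying math only; the function names below are hypothetical and the signatures of TorchRL's two_hot_encode and two_hot_decode may differ.

```python
import math


def symlog_scalar(x):
    """sign(x) * log(|x| + 1)."""
    return math.copysign(math.log1p(abs(x)), x)


def symexp_scalar(x):
    """Inverse of symlog: sign(x) * (exp(|x|) - 1)."""
    return math.copysign(math.expm1(abs(x)), x)


def two_hot(value, bins):
    """Spread a scalar over its two neighbouring bins so the expectation is exact."""
    value = min(max(value, bins[0]), bins[-1])  # clamp into the supported range
    weights = [0.0] * len(bins)
    k = max(i for i, b in enumerate(bins) if b <= value)  # last edge <= value
    if k == len(bins) - 1:
        weights[k] = 1.0
        return weights
    upper = (value - bins[k]) / (bins[k + 1] - bins[k])
    weights[k] = 1.0 - upper
    weights[k + 1] = upper
    return weights


def two_hot_expectation(weights, bins):
    """Decode by taking the expected bin value."""
    return sum(w * b for w, b in zip(weights, bins))


bins = [-2.0, -1.0, 0.0, 1.0, 2.0]
encoded = two_hot(0.25, bins)
decoded = two_hot_expectation(encoded, bins)
roundtrip = symexp_scalar(symlog_scalar(100.0))
```

In DreamerV3 the two ideas are combined: targets are compressed with symlog, encoded as two-hot vectors over bins in symlog space, and decoded predictions are mapped back with symexp.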
    

Comparison: Dreamer vs. DreamerV3

| Feature | Dreamer | DreamerV3 |
| --- | --- | --- |
| KL regularisation | KL with free-nats floor | KL balancing (free bits) |
| Reconstruction | L2 pixel loss | symlog MSE |
| Value targets | Lambda return | Two-hot categorical CE |
| Actor gradient | Backpropagation through imagined dynamics | REINFORCE + entropy |
| Import path | torchrl.objectives.dreamer | torchrl.objectives.dreamer_v3 |

ModelBasedEnvBase and the Imagination API

Environments derived from ModelBasedEnvBase behave like any other EnvBase but execute transitions through the world model instead of a real simulator, so TorchRL collectors and value estimators work unchanged.
    from torchrl.envs.model_based import ModelBasedEnvBase, DreamerEnv, WorldModelEnv
    
    # DreamerEnv is pre-wired for Dreamer-style imagination
    env = DreamerEnv(
        world_model=world_model,
        prior_shape=(30,),
        belief_shape=(200,),
        device="cuda",
        batch_size=[32],
    )
    
    # A rollout inside latent space — no real simulator calls
    fake_data = env.rollout(
        max_steps=15,
        policy=actor,
        tensordict=initial_latent_state,
        auto_reset=False,
    )
    
WorldModelEnv is a more generic variant that wraps any callable world model:
    from torchrl.envs.model_based import WorldModelEnv
    
    generic_env = WorldModelEnv(
        world_model=my_world_model,
        observation_spec=obs_spec,
        action_spec=action_spec,
        reward_spec=reward_spec,
        device="cuda",
    )
    

RSSM Latent Imagination in Practice

A full Dreamer training step has three stages.

Step 1: Collect Real Transitions

    from torchrl.collectors import Collector
    
    collector = Collector(
        env=real_env,
        policy=actor,
        frames_per_batch=1000,
    )
    real_batch = next(iter(collector))
    replay_buffer.extend(real_batch)
    
Step 2: Train the World Model

    # Sample from replay buffer
    batch = replay_buffer.sample(batch_size=32)
    
    # Encode observations
    encoded = obs_encoder(batch)
    
    # Unroll RSSM over the trajectory time dimension
    latent_batch = rssm_rollout(encoded)
    
    # Compute world model losses
    loss_td, _ = world_model_loss(latent_batch)
    wm_optimizer.zero_grad()
    (loss_td["loss_model_kl"] + loss_td["loss_model_reco"] + loss_td["loss_model_reward"]).backward()
    wm_optimizer.step()
    
Step 3: Train Actor and Value in Imagination

    # Start from posterior states obtained during world-model training
    posterior_states = latent_batch.select("state", "belief")
    
    # DreamerActorLoss internally rolls out imagination_horizon steps
    actor_loss_td, fake_data = actor_loss(posterior_states)
    value_loss_td, _ = value_loss(fake_data)
    
    actor_optimizer.zero_grad()
    actor_loss_td["loss_actor"].backward()
    actor_optimizer.step()
    
    value_optimizer.zero_grad()
    value_loss_td["loss_value"].backward()
    value_optimizer.step()
    

PILCO: Gaussian Process World Models

For lower-dimensional problems TorchRL also provides a Gaussian Process world model based on PILCO (Deisenroth & Rasmussen, 2011). GPWorldModel fits one independent GP per state dimension and propagates Gaussian beliefs via analytic moment matching, with no neural network required.
    import torch

    from torchrl.modules.models.gp import GPWorldModel
    from torchrl.objectives.pilco import ExponentialQuadraticCost
    
    # Requires botorch and gpytorch
    gp_model = GPWorldModel(obs_dim=4, action_dim=1)
    gp_model.fit(dataset)   # dataset is a TensorDict with obs/action/next_obs keys
    
    # Smooth saturating cost for analytic policy gradients
    cost = ExponentialQuadraticCost(
        target=torch.zeros(4),
        weights=torch.eye(4),
    )
    
ExponentialQuadraticCost computes E_{x~N(m,s)}[c(x)] analytically (Eq. 24-25 in the PILCO paper), enabling gradient-based policy search without stochastic sampling.

Note: GPWorldModel requires gpytorch and botorch as optional dependencies. Install them with pip install gpytorch botorch. The PILCO path scales well to continuous control tasks with low state dimensionality (≤ 20) but becomes computationally expensive for high-dimensional observations.
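
The closed-form expectation can be checked in one dimension with plain Python. The sketch assumes PILCO's saturating cost c(x) = 1 - exp(-0.5 * w * (x - target)^2) and a Gaussian state with the given mean and variance; the function names are hypothetical and this is not the ExponentialQuadraticCost implementation, only the scalar special case of the same integral.

```python
import math


def saturating_cost(x, target, w):
    """Scalar saturating cost in [0, 1)."""
    return 1.0 - math.exp(-0.5 * w * (x - target) ** 2)


def expected_saturating_cost(mean, var, target, w):
    """Closed-form E[c(x)] for x ~ N(mean, var), scalar case."""
    scale = 1.0 + var * w
    return 1.0 - math.exp(-0.5 * w * (mean - target) ** 2 / scale) / math.sqrt(scale)


def numerical_expectation(mean, var, target, w, n=20001, width=10.0):
    """Trapezoidal quadrature of c(x) * N(x; mean, var), used as a cross-check."""
    std = math.sqrt(var)
    lo, hi = mean - width * std, mean + width * std
    step = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        x = lo + i * step
        pdf = math.exp(-0.5 * (x - mean) ** 2 / var) / math.sqrt(2 * math.pi * var)
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * saturating_cost(x, target, w) * pdf * step
    return total


analytic = expected_saturating_cost(mean=0.5, var=0.2, target=0.0, w=2.0)
numeric = numerical_expectation(mean=0.5, var=0.2, target=0.0, w=2.0)
```

Because the expectation is an explicit, differentiable function of the mean and variance, its gradient with respect to policy parameters can be taken directly instead of being estimated from samples.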

Key Imports Reference

    # Loss modules
    from torchrl.objectives.dreamer import (
        DreamerModelLoss,
        DreamerActorLoss,
        DreamerValueLoss,
    )
    from torchrl.objectives.dreamer_v3 import (
        DreamerV3ModelLoss,
        DreamerV3ActorLoss,
        DreamerV3ValueLoss,
        symlog, symexp,
        two_hot_encode, two_hot_decode,
    )
    from torchrl.objectives.pilco import ExponentialQuadraticCost
    
    # Environment API
    from torchrl.envs.model_based import (
        ModelBasedEnvBase,
        DreamerEnv,
        DreamerDecoder,
        ImaginedEnv,
        WorldModelEnv,
    )
    
    # Network building blocks
    from torchrl.modules.models.model_based import (
        DreamerActor,
        ObsEncoder,
        RSSMPrior,
        RSSMPosterior,
        RSSMRollout,
    )
    from torchrl.modules.models.gp import GPWorldModel
    
