Model-Based Policy Optimization (MBPO) is a hybrid algorithm that combines the sample efficiency of model-based RL with the asymptotic performance of a model-free policy optimizer, here Soft Actor-Critic (SAC). The core idea is simple. Learn a dynamics model from real environment interactions, use it to generate short synthetic rollouts, and add those synthetic transitions to the replay data used to train the SAC policy. The policy then receives many more gradient updates per unit of real environment interaction, which greatly improves sample efficiency.

The key insight that makes MBPO work is keeping synthetic rollouts short (typically 1–5 steps). Model errors compound over long rollouts and can destabilize training. Short rollouts stay close to real states, where the model is accurate.

Key Concepts

MBPO maintains two replay buffers: env_pool (real transitions) and model_pool (synthetic transitions). The SAC policy is updated on a mix of both. The training loop alternates between: (1) collect real data → (2) train model → (3) generate synthetic data → (4) update SAC policy.
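The four phases can be wired together as a simple schedule. The sketch below passes the work of each phase in as a callback. The function name run_mbpo_schedule and its parameters are hypothetical. The cadence in the usage lines (model refresh every 50 real steps, 10 policy updates per real step) mirrors the training loop later on this page.

```python
from typing import Callable


def run_mbpo_schedule(
    total_steps: int,
    model_every: int,
    updates_per_step: int,
    collect_real: Callable[[], None],
    train_model: Callable[[], None],
    generate_synthetic: Callable[[], None],
    update_policy: Callable[[], None],
) -> None:
    """Hypothetical driver that interleaves the four MBPO phases."""
    for step in range(total_steps):
        # Phases (2) and (3): periodically refit the model and refill the synthetic pool
        if step % model_every == 0:
            train_model()
            generate_synthetic()
        # Phase (1): one real environment step
        collect_real()
        # Phase (4): several policy updates per real step
        for _ in range(updates_per_step):
            update_policy()


# Usage: count how often each phase runs over one 200-step episode
calls = {"real": 0, "model": 0, "synthetic": 0, "policy": 0}


def counter(key: str) -> Callable[[], None]:
    def _inc() -> None:
        calls[key] += 1
    return _inc


run_mbpo_schedule(200, 50, 10, counter("real"), counter("model"),
                  counter("synthetic"), counter("policy"))
print(calls)  # {'real': 200, 'model': 4, 'synthetic': 4, 'policy': 2000}
```

With this cadence, each 200-step episode triggers 4 model refits and 2000 policy updates.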

Environment

Pendulum-v1 is used as the benchmark:
import gym

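class MyWrapper(gym.Wrapper):
    """Adapts Pendulum-v1 to the older 4-tuple step API used on this page."""
    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)  # stores env as self.env
        self.step_n = 0

    def reset(self, **kwargs):
        state, _ = self.env.reset(**kwargs)  # drop the info dict
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        # Merge terminated/truncated into a single done flag
        done = terminated or truncated
        self.step_n += 1
        # Hard cap at 200 steps (Pendulum-v1's built-in time limit is also 200)
        if self.step_n >= 200:
            done = True
        return state, reward, done, info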

env = MyWrapper()

Dual Replay Buffers

Two separate buffers hold real and synthetic data respectively:
import numpy as np
import random
import torch

class Pool:
    def __init__(self, limit):
        self.datas = []
        self.limit = limit

    def add(self, state, action, reward, next_state, over):
        if isinstance(state, (np.ndarray, torch.Tensor)):
            state = state.reshape(3).tolist()
        action = float(action)
        reward = float(reward)
        if isinstance(next_state, (np.ndarray, torch.Tensor)):
            next_state = next_state.reshape(3).tolist()
        over = bool(over)
        self.datas.append((state, action, reward, next_state, over))
        while len(self.datas) > self.limit:
            self.datas.pop(0)

    def get_sample(self, size=None):
        if size is None:
            size = len(self)
        size = min(size, len(self))
        samples = random.sample(self.datas, size)
        state      = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
        action     = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)
        reward     = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)
        over       = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)
        return state, action, reward, next_state, over

    def __len__(self):
        return len(self.datas)

# Real transitions: up to 10,000 steps
env_pool   = Pool(10000)
# Synthetic transitions: up to 1,000 steps (kept small to stay fresh)
model_pool = Pool(1000)
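Pool.add evicts old entries with list.pop(0), which costs O(n) per insertion once the buffer is full. That is unlikely to matter at these sizes, but a preallocated ring buffer makes insertion O(1). RingPool below is a hypothetical alternative, sketched with NumPy. It assumes Pendulum's 3-D state and 1-D action and returns NumPy arrays, which can be wrapped with torch.from_numpy if the rest of the code expects tensors.

```python
import numpy as np


class RingPool:
    """Hypothetical fixed-capacity replay buffer with O(1) insertion."""

    def __init__(self, limit, state_dim=3, action_dim=1, seed=None):
        self.limit = limit
        self.state = np.zeros((limit, state_dim), dtype=np.float32)
        self.action = np.zeros((limit, action_dim), dtype=np.float32)
        self.reward = np.zeros((limit, 1), dtype=np.float32)
        self.next_state = np.zeros((limit, state_dim), dtype=np.float32)
        self.over = np.zeros((limit, 1), dtype=np.int64)
        self.index = 0  # slot the next add() writes to
        self.size = 0   # number of filled slots
        self.rng = np.random.default_rng(seed)

    def add(self, state, action, reward, next_state, over):
        i = self.index
        self.state[i] = np.reshape(state, -1)
        self.action[i] = np.reshape(action, -1)
        self.reward[i] = float(reward)
        self.next_state[i] = np.reshape(next_state, -1)
        self.over[i] = int(over)
        # Advance and wrap around: once full, the oldest entry is overwritten
        self.index = (i + 1) % self.limit
        self.size = min(self.size + 1, self.limit)

    def get_sample(self, size=None):
        size = self.size if size is None else min(size, self.size)
        # Uniform sampling without replacement, like random.sample in Pool
        idx = self.rng.choice(self.size, size=size, replace=False)
        return (self.state[idx], self.action[idx], self.reward[idx],
                self.next_state[idx], self.over[idx])

    def __len__(self):
        return self.size


pool = RingPool(limit=3, seed=0)
for t in range(5):
    pool.add([t, t, t], 0.1 * t, -float(t), [t + 1] * 3, False)
print(len(pool))                          # 3
print(sorted(pool.state[:, 0].tolist()))  # [2.0, 3.0, 4.0]
```

After five insertions into a 3-slot pool, the two oldest transitions (t = 0 and t = 1) have been overwritten.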

SAC Policy

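SAC pairs a stochastic actor, whose actions are sampled from a tanh-squashed Gaussian, with two critic networks. Twin critics are typically used by taking the smaller of the two Q-estimates, which curbs overestimation. Only the network definitions and the action interface are shown; the update step, sac.train, is omitted: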
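import torch

class SAC:
    class ModelAction(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_state = torch.nn.Sequential(
                torch.nn.Linear(3, 128), torch.nn.ReLU()
            )
            self.fc_mu  = torch.nn.Linear(128, 1)
            self.fc_std = torch.nn.Sequential(
                torch.nn.Linear(128, 1), torch.nn.Softplus()
            )

        def forward(self, state):
            state  = self.fc_state(state)
            mu     = self.fc_mu(state)
            std    = self.fc_std(state)
            dist   = torch.distributions.Normal(mu, std)
            sample = dist.rsample()          # reparameterized pre-squash sample u
            action = torch.tanh(sample)      # squash into (-1, 1)
            log_prob = dist.log_prob(sample)
            # Change of variables for a = tanh(u): log pi(a) = log N(u) - log(1 - a^2).
            # `action` is already tanh(u), so it must not go through tanh again.
            entropy  = -(log_prob - (1 - action**2 + 1e-7).log())
            return action * 2, entropy       # rescale to Pendulum's [-2, 2] torque range

    class ModelValue(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sequential = torch.nn.Sequential(
                torch.nn.Linear(4, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 1),
            )

        def forward(self, state, action):
            # Q(s, a): concatenate the 3-D state and 1-D action
            return self.sequential(torch.cat([state, action], dim=1))

    def __init__(self):
        # Actor and twin critics; the rest of the SAC state used by train()
        # is not shown in this excerpt
        self.model_action = self.ModelAction()
        self.model_value1 = self.ModelValue()
        self.model_value2 = self.ModelValue()

    def get_action(self, state):
        # Accepts a list, NumPy array, or tensor holding one 3-D state
        state = torch.as_tensor(state, dtype=torch.float32).reshape(1, 3)
        with torch.no_grad():  # acting only, no gradients needed
            action, _ = self.model_action(state)
        return action.item()

    # train(state, action, reward, next_state, over): SAC update step, omitted here

sac = SAC()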
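ModelAction computes the log-density of the squashed action with the change-of-variables formula log pi(a) = log N(u; mu, sigma) - log(1 - tanh(u)^2), where a = tanh(u). The check below uses fixed example values chosen for illustration. It compares the manual formula against PyTorch's TransformedDistribution with a TanhTransform. It also shows that applying tanh a second time (tanh(a) instead of a) gives a different, incorrect value. The final `* 2` rescaling in ModelAction only shifts the log-density by the constant log 2, which leaves the actor's gradients unchanged.

```python
import torch

# Fixed example values: a Gaussian and a pre-squash sample u drawn from it
mu, std = torch.tensor([[0.3]]), torch.tensor([[0.8]])
base = torch.distributions.Normal(mu, std)
u = torch.tensor([[1.5]])
a = torch.tanh(u)  # squashed action in (-1, 1)

# Change of variables: log pi(a) = log N(u) - log(1 - a^2)
manual = base.log_prob(u) - torch.log(1 - a ** 2 + 1e-7)

# The same expression with tanh applied twice (the pitfall to avoid)
double_tanh = base.log_prob(u) - torch.log(1 - torch.tanh(a) ** 2 + 1e-7)

# Reference: PyTorch's built-in tanh-transformed Normal
squashed = torch.distributions.TransformedDistribution(
    base, [torch.distributions.transforms.TanhTransform()]
)
reference = squashed.log_prob(a)

print(torch.allclose(manual, reference, atol=1e-4))       # True
print(torch.allclose(double_tanh, reference, atol=1e-4))  # False
```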

Ensemble Dynamics Model

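The dynamics model is an ensemble of 5 networks. The members do not share parameters. Each FCLayer stores one weight slice per member in a single batched tensor, so all five members are evaluated in parallel with torch.bmm. Each member predicts the mean and log-variance (logvar) of a diagonal Gaussian over (reward, Δstate). For Pendulum's 1-D reward and 3-D state, that is 4 means and 4 log-variances: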
class Model(torch.nn.Module):

    class Swish(torch.nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

    class FCLayer(torch.nn.Module):
        # Linear layer for all 5 members at once: member k uses weight[k] and bias[k]
        def __init__(self, in_size, out_size):
            super().__init__()
            std    = 1 / (in_size**0.5 * 2)
            weight = torch.empty(5, in_size, out_size)
            torch.nn.init.normal_(weight, mean=0.0, std=std)
            self.weight = torch.nn.Parameter(weight)
            self.bias   = torch.nn.Parameter(torch.zeros(5, 1, out_size))

        def forward(self, x):
            # (5, batch, in_size) @ (5, in_size, out_size) -> (5, batch, out_size)
            return torch.bmm(x, self.weight) + self.bias

    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            self.FCLayer(4, 200), self.Swish(),    # input: (state, action)
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 8),                  # 4 means + 4 log-variances
        )
        self.softplus = torch.nn.Softplus()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        x      = self.sequential(x)
        mean   = x[..., :4]
        logvar = x[..., 4:]
        # Soft-clamp the log-variance into roughly [-10, 0.5] so the
        # predicted variance can neither collapse nor explode
        logvar = 0.5 - self.softplus(0.5 - logvar)
        logvar = self.softplus(logvar + 10) - 10
        return mean, logvar

    # Note: this overrides torch.nn.Module.train(mode), so model.eval() would fail.
    # The network has no dropout or batch norm, so mode switching is not needed here.
    def train(self):
        # Fit on all real transitions: input (s, a), target (r, s' - s)
        state, action, reward, next_state, _ = env_pool.get_sample()
        inputs = torch.cat([state, action], dim=1)
        labels = torch.cat([reward, next_state - state], dim=1)

        # About 20 passes over the data; each member draws its own 64-sample batch
        for _ in range(len(inputs) // 64 * 20):
            select       = torch.stack([torch.randperm(len(inputs))[:64] for _ in range(5)])
            input_select = inputs[select]   # (5, 64, 4)
            label_select = labels[select]   # (5, 64, 4)

            mean, logvar = self(input_select)
            # Gaussian negative log-likelihood (up to constants):
            # inverse-variance-weighted squared error plus log-variance
            mse_loss = ((mean - label_select)**2 * (-logvar).exp()).mean(dim=1).mean()
            var_loss = logvar.mean(dim=1).mean()
            loss = mse_loss + var_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

model = Model()
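The check below illustrates why a single FCLayer still holds five independent members. It compares one torch.bmm call against evaluating each member's weight slice on its own. The shapes match the first layer of Model (4 inputs, 200 outputs). The batch size of 8 and the random weights are arbitrary values chosen for illustration.

```python
import torch

torch.manual_seed(0)
members, batch, in_size, out_size = 5, 8, 4, 200

# One parameter tensor per layer with a slice per member, as in FCLayer
weight = torch.randn(members, in_size, out_size) / (2 * in_size ** 0.5)
bias = torch.randn(members, 1, out_size)

# Each member receives its own mini-batch, as in Model.train()
x = torch.randn(members, batch, in_size)

# Batched evaluation: one bmm call covers all five members
batched = torch.bmm(x, weight) + bias

# Reference: evaluate each member separately using only its own slice
separate = torch.stack([x[k] @ weight[k] + bias[k] for k in range(members)])

print(batched.shape)                                 # torch.Size([5, 8, 200])
print(torch.allclose(batched, separate, atol=1e-5))  # True
```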

MBPO: Generating Synthetic Rollouts

The MBPO class samples 1000 real states from env_pool and asks the SAC policy for an action in each. It then queries a randomly chosen ensemble member for the resulting transition and stores the one-step synthetic transitions in model_pool:
class MBPO:
    def _fake_step(self, state, action):
        state  = torch.as_tensor(state, dtype=torch.float32).reshape(-1, 3)
        action = torch.as_tensor(action, dtype=torch.float32).reshape(-1, 1)

        # Feed the same (s, a) batch to all 5 ensemble members: (5, N, 4)
        inputs = torch.cat([state, action], dim=1)
        inputs = inputs.unsqueeze(dim=0).repeat([5, 1, 1])

        with torch.no_grad():
            mean, logvar = model(inputs)
        std = (0.5 * logvar).exp()  # standard deviation from log-variance

        # The model predicts Δstate; add the current state to recover next_state
        mean[:, :, 1:] += state

        # Sample from each member's Gaussian
        sample = mean + torch.randn_like(mean) * std

        # Pick one ensemble member at random for each transition
        n = mean.shape[1]
        member = torch.randint(0, 5, (n,))
        sample = sample[member, torch.arange(n)]  # (N, 4)

        reward, next_state = sample[:, :1], sample[:, 1:]
        return reward, next_state

    def rollout(self):
        # One-step branched rollouts from 1000 real start states
        states, _, _, _, _ = env_pool.get_sample(1000)
        for state in states:
            action = sac.get_action(state)
            reward, next_state = self._fake_step(state, action)
            # Pendulum never terminates, so synthetic transitions use over=False
            model_pool.add(state, action, reward, next_state, False)

mbpo = MBPO()
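rollout() branches a single model step from each real start state. Longer rollouts of k > 1 steps feed the model's predicted next state back into the policy. The batched k-step version below is a sketch. rollout_k_steps, toy_policy, and toy_dynamics are hypothetical stand-ins, and the toy dynamics is a hand-written linear update rather than the learned ensemble. Termination handling is omitted because Pendulum never terminates.

```python
import numpy as np


def rollout_k_steps(start_states, policy, dynamics, k):
    """Branch k-step synthetic rollouts from a batch of real start states.

    policy(states) -> actions and dynamics(states, actions) -> (rewards, next_states)
    both operate on whole batches. Returns one (s, a, r, s') tuple of arrays per step.
    """
    transitions = []
    states = start_states
    for _ in range(k):
        actions = policy(states)
        rewards, next_states = dynamics(states, actions)
        transitions.append((states, actions, rewards, next_states))
        states = next_states  # the next step starts from the model's own prediction
    return transitions


# Hypothetical stand-ins: a clipped linear policy and toy linear dynamics
def toy_policy(states):
    return np.clip(-states[:, :1], -2.0, 2.0)


def toy_dynamics(states, actions):
    next_states = states + 0.1 * actions  # (N, 1) action broadcasts over 3 dims
    rewards = -(states ** 2).sum(axis=1, keepdims=True)
    return rewards, next_states


rng = np.random.default_rng(0)
starts = rng.normal(size=(1000, 3))
steps = rollout_k_steps(starts, toy_policy, toy_dynamics, k=3)
print(len(steps), sum(len(s) for s, _, _, _ in steps))  # 3 3000
```

Each refresh yields 1000 × k transitions. With k > 1, the 1000-entry model_pool would retain only a fraction of them, so its capacity would have to scale with k.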

Full Training Loop

Step 1: Seed the real pool

Run one episode with the untrained policy to populate env_pool with initial data:
def seed():
    state = env.reset()
    over  = False
    while not over:
        action = sac.get_action(state)
        next_state, reward, over, _ = env.step([action])
        env_pool.add(state, action, reward, next_state, over)
        state = next_state

seed()

Step 2: Interleaved real + synthetic training

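Every 50 real environment steps (starting at step 0), retrain the model on env_pool and regenerate the synthetic rollouts. After every real step, update SAC 10 times. Each update uses a batch of 32 real and 32 synthetic transitions (a 50/50 mix):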
for i in range(20):
    reward_sum = 0
    state = env.reset()
    over  = False
    step  = 0

    while not over:
        # Periodically retrain model and generate synthetic data
        if step % 50 == 0:
            model.train()
            mbpo.rollout()
        step += 1

        action = sac.get_action(state)
        next_state, reward, over, _ = env.step([action])
        reward_sum += reward
        env_pool.add(state, action, reward, next_state, over)
        state = next_state

        # Update SAC on mixed real + synthetic batch
        for _ in range(10):
            sample_env   = env_pool.get_sample(32)
            sample_model = model_pool.get_sample(32)
            sample = [
                torch.cat([i1, i2], dim=0)
                for i1, i2 in zip(sample_env, sample_model)
            ]
            sac.train(*sample)

    print(i, len(env_pool), len(model_pool), reward_sum)
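Each call to mbpo.rollout() adds exactly 1000 synthetic transitions, so the 1000-entry model_pool is completely replaced every time the model is retrained. This small capacity is intentional. Flushing stale synthetic data prevents the policy from overfitting to predictions made by an outdated model.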
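sac.train(*sample) receives five tensors: state (N, 3), action (N, 1), reward (N, 1), next_state (N, 3), and an integer over flag (N, 1), with N = 64 rows drawn from both pools. The SAC excerpt on this page does not show the update itself. The following is a minimal sketch of a standard SAC update with the same call signature, not this implementation's code. The names SACSketch, Actor, and mlp, the network sizes, and the hyperparameters are all assumptions, chosen as common defaults. The hyperparameters are γ = 0.99, τ = 0.005, learning rate 3e-4, and target entropy -1, the negative action dimension.

```python
import torch


def mlp(in_dim: int, out_dim: int) -> torch.nn.Sequential:
    """Small MLP used for the actor body and the critics (assumed sizes)."""
    return torch.nn.Sequential(
        torch.nn.Linear(in_dim, 128), torch.nn.ReLU(),
        torch.nn.Linear(128, out_dim),
    )


class Actor(torch.nn.Module):
    """Tanh-squashed Gaussian policy: 3-D state in, 1-D action in [-2, 2] out."""

    def __init__(self):
        super().__init__()
        self.body = mlp(3, 2)  # outputs mu and a pre-softplus std

    def forward(self, state):
        mu, std_raw = self.body(state).chunk(2, dim=1)
        dist = torch.distributions.Normal(mu, torch.nn.functional.softplus(std_raw) + 1e-4)
        u = dist.rsample()
        a = torch.tanh(u)
        log_prob = dist.log_prob(u) - torch.log(1 - a ** 2 + 1e-7)
        return 2 * a, log_prob


class SACSketch:
    def __init__(self, gamma=0.99, tau=0.005, lr=3e-4, target_entropy=-1.0):
        self.actor = Actor()
        self.q1, self.q2 = mlp(4, 1), mlp(4, 1)
        self.q1_target, self.q2_target = mlp(4, 1), mlp(4, 1)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        self.log_alpha = torch.zeros(1, requires_grad=True)  # learned entropy temperature
        self.opt_actor = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.opt_q = torch.optim.Adam([*self.q1.parameters(), *self.q2.parameters()], lr=lr)
        self.opt_alpha = torch.optim.Adam([self.log_alpha], lr=lr)
        self.gamma, self.tau, self.target_entropy = gamma, tau, target_entropy

    def train(self, state, action, reward, next_state, over):
        alpha = self.log_alpha.exp().detach()

        # Critic target: r + gamma * (1 - done) * (min target Q(s', a') - alpha * log pi(a'|s'))
        with torch.no_grad():
            next_action, next_log_prob = self.actor(next_state)
            sa_next = torch.cat([next_state, next_action], dim=1)
            next_q = torch.min(self.q1_target(sa_next), self.q2_target(sa_next))
            target = reward + self.gamma * (1 - over.float()) * (next_q - alpha * next_log_prob)
        sa = torch.cat([state, action], dim=1)
        q_loss = ((self.q1(sa) - target) ** 2).mean() + ((self.q2(sa) - target) ** 2).mean()
        self.opt_q.zero_grad()
        q_loss.backward()
        self.opt_q.step()

        # Actor: maximize min Q(s, a) plus alpha-weighted entropy, via reparameterized actions
        new_action, log_prob = self.actor(state)
        sa_new = torch.cat([state, new_action], dim=1)
        q_new = torch.min(self.q1(sa_new), self.q2(sa_new))
        actor_loss = (alpha * log_prob - q_new).mean()
        self.opt_actor.zero_grad()
        actor_loss.backward()  # also leaves grads on the critics; opt_q clears them next call
        self.opt_actor.step()

        # Temperature: raise alpha when entropy falls below target_entropy
        alpha_loss = -(self.log_alpha * (log_prob.detach() + self.target_entropy)).mean()
        self.opt_alpha.zero_grad()
        alpha_loss.backward()
        self.opt_alpha.step()

        # Soft (Polyak) update of the target critics
        with torch.no_grad():
            for net, tgt in ((self.q1, self.q1_target), (self.q2, self.q2_target)):
                for p, p_t in zip(net.parameters(), tgt.parameters()):
                    p_t.mul_(1 - self.tau).add_(self.tau * p)
        return q_loss.item(), actor_loss.item()


# Usage with a random batch shaped like the loop's 32 real + 32 synthetic rows
torch.manual_seed(0)
agent = SACSketch()
batch = [torch.randn(64, 3), torch.rand(64, 1) * 4 - 2, torch.randn(64, 1),
         torch.randn(64, 3), torch.zeros(64, 1, dtype=torch.long)]
for _ in range(5):
    q_loss, actor_loss = agent.train(*batch)
```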

Why Short Rollouts?

Model errors compound over long rollouts. In a simple model where each step inflates the error by a factor of 1.01, a 1% per-step error grows to about 10% after 10 steps and about 35% after 30 steps. MBPO typically uses rollouts of 1–5 steps (this implementation uses 1 step), staying within the region where the model is reliably accurate.
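Two toy error models are common, and both are illustrative assumptions rather than measured properties of the learned ensemble. In the first, each step multiplies the accumulated relative error by (1 + ε), giving (1 + ε)^k - 1 after k steps. In the second, only a fraction (1 - ε) of accuracy survives each step, giving a loss of 1 - (1 - ε)^k. The hypothetical helpers below evaluate both for ε = 1%. The second model grows more slowly: about 26% versus 35% at k = 30.

```python
def multiplicative_growth(eps: float, k: int) -> float:
    """Relative error after k steps if each step inflates it by (1 + eps)."""
    return (1 + eps) ** k - 1


def retained_accuracy_loss(eps: float, k: int) -> float:
    """Fraction of accuracy lost after k steps if (1 - eps) survives each step."""
    return 1 - (1 - eps) ** k


for k in (1, 5, 10, 30):
    print(k, round(multiplicative_growth(0.01, k), 3), round(retained_accuracy_loss(0.01, k), 3))
# 1 0.01 0.01
# 5 0.051 0.049
# 10 0.105 0.096
# 30 0.348 0.26
```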
| Component | Purpose |
|---|---|
| env_pool | Real experience for model training and policy grounding |
| model_pool | Synthetic experience for policy data augmentation |
| Ensemble of 5 | Captures epistemic uncertainty; random member selection during rollout |
| Short rollouts (1 step) | Avoids compounding model error |
| Mixed batch (50/50) | Balances real-data fidelity with synthetic-data quantity |
