Model-Based Policy Optimization (MBPO) is a hybrid algorithm that combines the sample efficiency of model-based RL with the asymptotic performance of a model-free policy optimizer. The core idea is simple: learn a dynamics model from real environment interactions, use it to generate synthetic short rollouts, and add those synthetic transitions to the replay buffer used to train a SAC policy. This way, the policy receives many more gradient updates per unit of real environment interaction, dramatically improving sample efficiency.
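The key insight that makes MBPO work is keeping synthetic rollouts short, often only one or a few steps, and branching them from real states. Model errors compound over long rollouts and can destabilize training. Short rollouts that start from states in the real replay buffer stay close to the data the model was trained on, where its predictions are most accurate. The implementation on this page uses single-step rollouts.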
Key Concepts
MBPO maintains two replay buffers: env_pool (real transitions) and model_pool (synthetic transitions). The SAC policy is updated on a mix of both. The training loop alternates between: (1) collect real data → (2) train model → (3) generate synthetic data → (4) update SAC policy.
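Updating SAC on a mix of both pools means building each minibatch partly from real samples and partly from synthetic ones. The sketch below is minimal and assumes that each pool exposes get_sample(n), returning a tuple of tensors (state, action, reward, next_state, over) as the Pool class on this page does. The helper name mixed_batch and its real_ratio parameter are hypothetical. The training loop on this page corresponds to batch_size=64 and real_ratio=0.5.

```python
import torch

def mixed_batch(env_pool, model_pool, batch_size=64, real_ratio=0.5):
    """Concatenate real and synthetic samples into one SAC minibatch.

    Hypothetical helper: both pools must expose get_sample(n) returning a
    tuple of tensors (state, action, reward, next_state, over).
    """
    n_real = int(round(batch_size * real_ratio))
    n_model = batch_size - n_real
    real = env_pool.get_sample(n_real)
    fake = model_pool.get_sample(n_model)
    # Stack field by field: states with states, actions with actions, and so on
    return tuple(torch.cat([r, f], dim=0) for r, f in zip(real, fake))
```

Lowering real_ratio shifts each update toward synthetic data. Raising it grounds the policy more firmly in real experience.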
Environment
Pendulum-v1 is used as the benchmark:
```python
import gym

class MyWrapper(gym.Wrapper):
    """Adapts Pendulum-v1 (gym >= 0.26 API) to a 4-tuple step interface."""

    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        # Merge the two end-of-episode flags into a single `done`
        done = terminated or truncated
        self.step_n += 1
        # Hard cap at 200 steps (the same as Pendulum-v1's built-in time limit)
        if self.step_n >= 200:
            done = True
        return state, reward, done, info

env = MyWrapper()
```
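This is a reference for what the dynamics model has to learn. Pendulum-v1 observations are [cos θ, sin θ, θ̇], and the per-step reward is −(θ² + 0.1·θ̇² + 0.001·u²). Here θ is the pole angle normalized to [−π, π] (0 is upright), and u is the torque clipped to [−2, 2]. The sketch below recomputes that reward from an observation and a torque. It follows the cost computation in Gym's Pendulum source. The function name pendulum_reward is hypothetical.

```python
import math

def pendulum_reward(obs, torque, max_torque=2.0):
    """Reward for applying `torque` in the state described by `obs`.

    obs = [cos(theta), sin(theta), theta_dot] as returned by Pendulum-v1.
    The environment clips the torque to [-2, 2] before computing the cost.
    """
    cos_th, sin_th, th_dot = obs
    theta = math.atan2(sin_th, cos_th)  # angle in (-pi, pi], 0 = upright
    u = max(-max_torque, min(max_torque, torque))
    return -(theta**2 + 0.1 * th_dot**2 + 0.001 * u**2)
```

On this page the ensemble learns the reward instead of calling a formula like this one. That keeps the method independent of the environment's reward code.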
Dual Replay Buffers
Two separate buffers hold real and synthetic data respectively:
```python
import random

import numpy as np
import torch

class Pool:
    """FIFO replay buffer of (state, action, reward, next_state, over) tuples."""

    def __init__(self, limit):
        self.datas = []
        self.limit = limit

    def add(self, state, action, reward, next_state, over):
        # Store plain Python values: 3-D Pendulum state, scalar action and reward
        if isinstance(state, (np.ndarray, torch.Tensor)):
            state = state.reshape(3).tolist()
        action = float(action)
        reward = float(reward)
        if isinstance(next_state, (np.ndarray, torch.Tensor)):
            next_state = next_state.reshape(3).tolist()
        over = bool(over)
        self.datas.append((state, action, reward, next_state, over))
        # Evict the oldest transitions once the capacity is exceeded
        while len(self.datas) > self.limit:
            self.datas.pop(0)

    def get_sample(self, size=None):
        # Sample without replacement; size=None returns the whole pool, shuffled
        if size is None:
            size = len(self)
        size = min(size, len(self))
        samples = random.sample(self.datas, size)
        state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
        action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)
        reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)
        over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)
        return state, action, reward, next_state, over

    def __len__(self):
        return len(self.datas)

# Real transitions: up to 10,000 steps
env_pool = Pool(10000)
# Synthetic transitions: up to 1,000 steps (kept small to stay fresh)
model_pool = Pool(1000)
```
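Pool evicts entries with list.pop(0). That call shifts every remaining element, so each insertion costs O(n) once the pool is full. A fixed-size ring buffer avoids this cost. The sketch below is a hypothetical drop-in alternative; the class name RingPool is illustrative. It keeps the same add / get_sample / len interface and tensor shapes, stores data in preallocated NumPy arrays, and samples indices without replacement.

```python
import numpy as np
import torch

class RingPool:
    """Fixed-capacity FIFO buffer backed by preallocated NumPy arrays."""

    def __init__(self, limit, state_dim=3):
        self.limit = limit
        self.state = np.zeros((limit, state_dim), dtype=np.float32)
        self.action = np.zeros((limit, 1), dtype=np.float32)
        self.reward = np.zeros((limit, 1), dtype=np.float32)
        self.next_state = np.zeros((limit, state_dim), dtype=np.float32)
        self.over = np.zeros((limit, 1), dtype=np.int64)
        self.ptr = 0   # next slot to write
        self.size = 0  # number of valid entries

    def add(self, state, action, reward, next_state, over):
        i = self.ptr
        self.state[i] = np.asarray(state, dtype=np.float32).reshape(-1)
        self.action[i] = float(action)
        self.reward[i] = float(reward)
        self.next_state[i] = np.asarray(next_state, dtype=np.float32).reshape(-1)
        self.over[i] = int(bool(over))
        # Advance the write pointer; when full, the oldest entry is overwritten
        self.ptr = (self.ptr + 1) % self.limit
        self.size = min(self.size + 1, self.limit)

    def get_sample(self, size=None):
        size = self.size if size is None else min(size, self.size)
        idx = np.random.choice(self.size, size, replace=False)
        return (torch.from_numpy(self.state[idx]),
                torch.from_numpy(self.action[idx]),
                torch.from_numpy(self.reward[idx]),
                torch.from_numpy(self.next_state[idx]),
                torch.from_numpy(self.over[idx]))

    def __len__(self):
        return self.size
```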
SAC Policy
The policy is a stochastic actor that outputs a tanh-squashed Gaussian. It is paired with two critic networks; twin critics curb Q-value overestimation. Only the key interfaces are shown, and the train method called by the training loop is omitted:
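```python
import torch

class SAC:
    class ModelAction(torch.nn.Module):
        """Stochastic actor: squashed Gaussian policy for a 1-D torque in [-2, 2]."""

        def __init__(self):
            super().__init__()
            self.fc_state = torch.nn.Sequential(
                torch.nn.Linear(3, 128), torch.nn.ReLU()
            )
            self.fc_mu = torch.nn.Linear(128, 1)
            self.fc_std = torch.nn.Sequential(
                torch.nn.Linear(128, 1), torch.nn.Softplus()
            )

        def forward(self, state):
            state = self.fc_state(state)
            mu = self.fc_mu(state)
            std = self.fc_std(state)
            dist = torch.distributions.Normal(mu, std)
            # Reparameterized sample, so gradients flow back into mu and std
            sample = dist.rsample()
            action = torch.tanh(sample)
            log_prob = dist.log_prob(sample)
            # tanh change of variables: `action` is already tanh(sample),
            # so the correction is log(1 - action^2), not log(1 - tanh(action)^2)
            entropy = -(log_prob - (1 - action**2 + 1e-7).log())
            # Rescale to Pendulum's torque range [-2, 2]
            return action * 2, entropy

    class ModelValue(torch.nn.Module):
        """Critic Q(s, a): input is the 3-D state concatenated with the action."""

        def __init__(self):
            super().__init__()
            self.sequential = torch.nn.Sequential(
                torch.nn.Linear(4, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 1),
            )

        def forward(self, state, action):
            return self.sequential(torch.cat([state, action], dim=1))

    def __init__(self):
        self.model_action = self.ModelAction()
        # Twin critics (attribute names here are illustrative); target critics,
        # optimizers, the entropy temperature, and train() are omitted
        self.model_value1 = self.ModelValue()
        self.model_value2 = self.ModelValue()

    def get_action(self, state):
        # Accepts a NumPy array or a tensor holding one 3-D state
        state = torch.as_tensor(state, dtype=torch.float32).reshape(1, 3)
        with torch.no_grad():
            action, _ = self.model_action(state)
        return action.item()

sac = SAC()
```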
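The training loop calls sac.train(state, action, reward, next_state, over), but the SAC section does not show its body. Below is a minimal sketch of a standard SAC update with twin critics, soft-updated target critics, and automatic entropy-temperature tuning. It is written as a standalone function so that it runs on its own. The function names, the argument list, and the hyperparameters (gamma, tau, target_entropy) are assumptions, not the original implementation. The sketch expects actor(state) to return (action, entropy), like SAC.ModelAction, and each critic to be called as q(state, action), like SAC.ModelValue.

```python
import torch
import torch.nn.functional as F

def soft_update(net, target_net, tau):
    """Polyak averaging: target <- (1 - tau) * target + tau * net."""
    with torch.no_grad():
        for p, tp in zip(net.parameters(), target_net.parameters()):
            tp.mul_(1 - tau).add_(p, alpha=tau)

def sac_update(actor, q1, q2, q1_target, q2_target,
               actor_opt, critic_opt, log_alpha, alpha_opt,
               state, action, reward, next_state, over,
               gamma=0.99, tau=0.005, target_entropy=-1.0):
    """One SAC step on a minibatch; returns (critic_loss, actor_loss, alpha)."""
    alpha = log_alpha.exp().detach()

    # 1) Critic target: r + gamma * (1 - over) * (min target Q(s', a') + alpha * H)
    with torch.no_grad():
        next_action, next_entropy = actor(next_state)
        next_q = torch.min(q1_target(next_state, next_action),
                           q2_target(next_state, next_action))
        target = reward + gamma * (1 - over.float()) * (next_q + alpha * next_entropy)

    critic_loss = (F.mse_loss(q1(state, action), target)
                   + F.mse_loss(q2(state, action), target))
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # 2) Actor: maximize min Q(s, a ~ pi) + alpha * entropy
    new_action, entropy = actor(state)
    q_new = torch.min(q1(state, new_action), q2(state, new_action))
    actor_loss = -(q_new + alpha * entropy).mean()
    actor_opt.zero_grad()
    actor_loss.backward()
    actor_opt.step()

    # 3) Temperature: alpha shrinks when entropy exceeds target_entropy, grows otherwise
    alpha_loss = (log_alpha.exp() * (entropy.detach() - target_entropy)).mean()
    alpha_opt.zero_grad()
    alpha_loss.backward()
    alpha_opt.step()

    # 4) Move the target critics slowly toward the online critics
    soft_update(q1, q1_target, tau)
    soft_update(q2, q2_target, tau)
    return critic_loss.item(), actor_loss.item(), log_alpha.exp().item()
```

With the SAC class above, actor would be sac.model_action and q1/q2 the two ModelValue critics. The target critics can start as copies of those critics, for example via copy.deepcopy.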
Ensemble Dynamics Model
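The model is an ensemble of 5 networks. Their weights are stored as stacked tensors in a batched FCLayer, so all members run in a single torch.bmm call; the members do not share parameters. Each member predicts the mean and log-variance (mean, logvar) of a Gaussian over (reward, Δstate):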
```python
class Model(torch.nn.Module):
    class Swish(torch.nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

    class FCLayer(torch.nn.Module):
        """Linear layer for all 5 members at once: (5, B, in) -> (5, B, out)."""

        def __init__(self, in_size, out_size):
            super().__init__()
            std = 1 / (in_size**0.5 * 2)
            # One independent weight matrix and bias per ensemble member
            weight = torch.empty(5, in_size, out_size)
            torch.nn.init.normal_(weight, mean=0.0, std=std)
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(torch.zeros(5, 1, out_size))

        def forward(self, x):
            return torch.bmm(x, self.weight) + self.bias

    def __init__(self):
        super().__init__()
        # Input: state (3) + action (1).
        # Output: mean (4) and logvar (4) over (reward, delta_state)
        self.sequential = torch.nn.Sequential(
            self.FCLayer(4, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 8),
        )
        self.softplus = torch.nn.Softplus()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        x = self.sequential(x)
        mean = x[..., :4]
        logvar = x[..., 4:]
        # Softly clamp logvar to roughly (-10, 0.5) to keep the variance well-behaved
        logvar = 0.5 - self.softplus(0.5 - logvar)
        logvar = self.softplus(logvar + 10) - 10
        return mean, logvar

    # Note: this shadows torch.nn.Module.train(mode); do not call model.eval() on it
    def train(self):
        # Fit the ensemble on every real transition currently in env_pool
        state, action, reward, next_state, _ = env_pool.get_sample()
        inputs = torch.cat([state, action], dim=1)
        labels = torch.cat([reward, next_state - state], dim=1)
        for _ in range(len(inputs) // 64 * 20):
            # Each member draws its own random minibatch of 64: shape (5, 64)
            select = torch.stack([torch.randperm(len(inputs))[:64] for _ in range(5)])
            mean, logvar = self(inputs[select])
            # Gaussian negative log-likelihood (up to a constant factor and offset)
            mse_loss = ((mean - labels[select])**2 * (-logvar).exp()).mean(dim=1).mean()
            var_loss = logvar.mean(dim=1).mean()
            loss = mse_loss + var_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

model = Model()
```
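Two properties of this model can be checked in isolation. First, torch.bmm multiplies each member's slice of the input by that member's own weight matrix, so the five members are independent networks evaluated in one call. Second, the training loss mse_loss + var_loss equals twice PyTorch's torch.nn.functional.gaussian_nll_loss with full=False. This holds as long as the variance stays above that function's eps clamp (1e-6 by default), which the logvar floor of -10 guarantees. The sketch below checks both properties on random tensors shaped like those in Model.train. The specific sizes and value ranges are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
E, B = 5, 64  # ensemble size and per-member batch size, as in Model.train

# (1) One batched matmul equals five independent per-member linear layers
x = torch.randn(E, B, 4)
weight = torch.randn(E, 4, 8)
bias = torch.randn(E, 1, 8)
batched = torch.bmm(x, weight) + bias
per_member = torch.stack([x[e] @ weight[e] + bias[e] for e in range(E)])
assert torch.allclose(batched, per_member, atol=1e-5)

# (2) The loss in Model.train equals 2 x Gaussian NLL without the constant term
mean = torch.randn(E, B, 4)
label = torch.randn(E, B, 4)
logvar = torch.empty(E, B, 4).uniform_(-3.0, 0.5)
mse_loss = ((mean - label)**2 * (-logvar).exp()).mean(dim=1).mean()
var_loss = logvar.mean(dim=1).mean()
nll = F.gaussian_nll_loss(mean, label, logvar.exp(), full=False, reduction='mean')
assert torch.allclose(mse_loss + var_loss, 2 * nll, rtol=1e-4)
```

Because `.mean(dim=1).mean()` averages over equally sized groups, it equals a plain mean over all elements. That is why the two losses match up to the factor of 2.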
MBPO: Generating Synthetic Rollouts
The MBPO class samples real states from env_pool, asks the SAC policy for actions, queries the model to predict transitions, and stores them in model_pool:
```python
class MBPO:
    def _fake_step(self, state, action):
        state = torch.as_tensor(state, dtype=torch.float32).reshape(-1, 3)
        action = torch.FloatTensor([action]).reshape(-1, 1)
        inputs = torch.cat([state, action], dim=1)
        # Feed the same input to every ensemble member: shape (5, N, 4)
        inputs = inputs.unsqueeze(dim=0).repeat([5, 1, 1])
        with torch.no_grad():
            mean, logvar = model(inputs)
        std = logvar.exp().sqrt()
        # Add the state delta to recover the absolute next_state (column 0 is reward)
        mean[:, :, 1:] += state
        # Draw one sample from each member's Gaussian
        sample = mean + torch.randn_like(mean) * std
        # Pick one ensemble member at random for each transition
        n = mean.shape[1]
        member = torch.randint(0, 5, (n,))
        sample = sample[member, torch.arange(n)]
        reward, next_state = sample[:, :1], sample[:, 1:]
        return reward, next_state

    def rollout(self):
        # Sample up to 1000 real states as rollout starting points
        states, _, _, _, _ = env_pool.get_sample(1000)
        for state in states:
            # One model step per starting state (rollout length 1)
            action = sac.get_action(state)
            reward, next_state = self._fake_step(state, action)
            # Pendulum has no terminal states, so `over` is always False
            model_pool.add(state, action, reward, next_state, False)

mbpo = MBPO()
```
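rollout() takes a single model step from each start state and queries the model once per state. The sketch below shows a batched variant that also supports longer horizons (MBPO's rollout length k). It is an illustration under assumptions: policy and model_step are hypothetical callables that stand in for a batched sac.get_action and for MBPO._fake_step, and rollout_length is a parameter this page does not use.

```python
import torch

def batched_rollout(start_states, policy, model_step, rollout_length=1):
    """Branch synthetic rollouts of `rollout_length` steps from real start states.

    start_states: (N, state_dim) tensor of real states.
    policy(states) -> (N, action_dim) actions.
    model_step(states, actions) -> (rewards (N, 1), next_states (N, state_dim)).
    Returns one (state, action, reward, next_state) tuple of tensors per step.
    """
    transitions = []
    states = start_states
    with torch.no_grad():
        for _ in range(rollout_length):
            actions = policy(states)
            rewards, next_states = model_step(states, actions)
            transitions.append((states, actions, rewards, next_states))
            # Continue every branch from its predicted next state
            # (no termination check: Pendulum has no terminal states)
            states = next_states
    return transitions
```

Longer horizons produce more synthetic data per start state, but they also accumulate more model error.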
Full Training Loop
Seed the real pool
Run one episode with the untrained policy to populate env_pool with initial data:

```python
def seed():
    state = env.reset()
    over = False
    while not over:
        action = sac.get_action(state)
        next_state, reward, over, _ = env.step([action])
        env_pool.add(state, action, reward, next_state, over)
        state = next_state

seed()
```
Interleaved real + synthetic training
Every 50 real environment steps, refit the model and regenerate the synthetic rollouts. After each real step, run 10 SAC updates, each on 32 real plus 32 synthetic transitions (a 50/50 mix):

```python
for i in range(20):
    reward_sum = 0
    state = env.reset()
    over = False
    step = 0
    while not over:
        # Every 50 real steps (including step 0): refit the model, refresh model_pool
        if step % 50 == 0:
            model.train()
            mbpo.rollout()
        step += 1

        action = sac.get_action(state)
        next_state, reward, over, _ = env.step([action])
        reward_sum += reward
        env_pool.add(state, action, reward, next_state, over)
        state = next_state

        # Update SAC on mixed real + synthetic batches
        for _ in range(10):
            sample_env = env_pool.get_sample(32)
            sample_model = model_pool.get_sample(32)
            sample = [
                torch.cat([i1, i2], dim=0)
                for i1, i2 in zip(sample_env, sample_model)
            ]
            sac.train(*sample)
    print(i, len(env_pool), len(model_pool), reward_sum)
```
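The model_pool capacity (1000) is intentionally small. Each rollout() call adds up to 1000 transitions, so every model refit replaces the synthetic data entirely. Flushing stale synthetic data this way keeps the policy from overfitting to predictions made by an outdated model.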
Why Short Rollouts?
Model errors compound over long rollouts. Under a simple multiplicative model, a 1% prediction error per step grows to about 10.5% after 10 steps (1.01^10 ≈ 1.105) and about 35% after 30 steps (1.01^30 ≈ 1.348). MBPO therefore keeps rollouts short, staying within the region where the model is reliably accurate. This implementation uses single-step rollouts.
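The compounding figures can be reproduced under the simplifying assumption that a fixed relative error multiplies at every step. Real model errors need not behave this way. The helper name compounded_error is illustrative.

```python
def compounded_error(per_step_error, steps):
    """Relative error after `steps` steps if each step multiplies it by (1 + e)."""
    return (1 + per_step_error) ** steps - 1

for k in (1, 5, 10, 30):
    print(k, f"{compounded_error(0.01, k):.1%}")
# 1 1.0% / 5 5.1% / 10 10.5% / 30 34.8%
```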
| Component | Purpose |
|---|---|
| env_pool | Real experience for model training and policy grounding |
| model_pool | Synthetic experience for policy data augmentation |
| Ensemble of 5 | Captures epistemic uncertainty; a random member is selected for each rollout transition |
| Short rollouts (1 step) | Avoids compounding model error |
| Mixed batch (50/50) | Balances real-data fidelity with synthetic-data quantity |