Model Predictive Control (MPC) is a planning-based approach to reinforcement learning. Rather than learning a policy directly, MPC maintains a learned dynamics model of the environment and, at every time step, uses that model to search for the best sequence of actions over a short planning horizon. Only the first action in the best sequence is executed; planning is then repeated from the new state. This “receding horizon” strategy allows the agent to continuously correct for model error.
MPC is appealing for two reasons. It is sample efficient: a relatively small amount of real environment data is enough to train a useful dynamics model. And because planning happens online, the agent can act on new information immediately.
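The receding-horizon loop can be sketched on a toy problem. This is an illustration only: it assumes a known 1-D dynamics function (`toy_model`, a hypothetical stand-in for a learned model) and uses brute-force search over a tiny action set in place of CEM.

```python
import itertools


def toy_model(state, action):
    """Hypothetical dynamics: the action shifts the state; reward penalizes distance from 0."""
    next_state = state + action
    reward = -abs(next_state)
    return reward, next_state


def plan(state, horizon=3, choices=(-1.0, 0.0, 1.0)):
    """Return the action sequence with the highest predicted return over the horizon."""
    best_seq, best_return = None, float('-inf')
    for seq in itertools.product(choices, repeat=horizon):
        s, total = state, 0.0
        for a in seq:
            r, s = toy_model(s, a)
            total += r
        if total > best_return:
            best_seq, best_return = seq, total
    return best_seq


state = 4.0
trajectory = [state]
for _ in range(6):
    first_action = plan(state)[0]               # execute only the first planned action
    _, state = toy_model(state, first_action)   # "real" step (same toy dynamics here)
    trajectory.append(state)                    # then re-plan from the new state
print(trajectory)
```

In a real setting the model used inside `plan` would be imperfect, which is exactly why only the first action is trusted before re-planning.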
Key Concepts
MPC separates two sub-problems: (1) model learning — fit a function f(s, a) → (r, s') from observed transitions, and (2) planning — given the current state, search for the action sequence that maximizes predicted cumulative reward over a horizon H. The Cross-Entropy Method (CEM) is used for planning.
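The core CEM loop (sample, score, keep the elites, refit) can be shown on a one-dimensional toy problem. The sketch below uses NumPy rather than PyTorch, and the objective `score` is a hypothetical stand-in for the predicted cumulative reward; it also refits the distribution directly rather than blending old and new estimates.

```python
import numpy as np

rng = np.random.default_rng(0)


def score(x):
    """Hypothetical objective with its maximum at x = 2."""
    return -(x - 2.0) ** 2


mean, var = 0.0, 4.0
for _ in range(10):
    candidates = rng.normal(mean, np.sqrt(var), size=50)       # sample 50 candidates
    elites = candidates[np.argsort(score(candidates))[-10:]]   # keep the 10 best
    mean, var = elites.mean(), elites.var()                    # refit the distribution
print(mean, var)
```

The initial variance matters: if it is too small, the distribution can collapse before it reaches the optimum, which is one reason to start with a broad prior over actions.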
Environment
The Pendulum-v1 environment is used throughout. The wrapper caps episodes at 200 steps:
import gym


class MyWrapper(gym.Wrapper):

    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        self.step_n += 1
        # Cap every episode at 200 steps
        if self.step_n >= 200:
            done = True
        return state, reward, done, info


env = MyWrapper()
Experience Pool
Transitions are stored in a Pool whose get_sample() method returns inputs of shape [b, 4] (state + action) and labels of shape [b, 4] (reward + state delta):
import numpy as np
import torch


class Pool:

    def __init__(self, limit):
        self.datas = []
        self.limit = limit

    def add(self, state, action, reward, next_state, over):
        # Normalize everything to plain Python types before storing
        if isinstance(state, (np.ndarray, torch.Tensor)):
            state = state.reshape(3).tolist()
        action = float(action)
        reward = float(reward)
        if isinstance(next_state, (np.ndarray, torch.Tensor)):
            next_state = next_state.reshape(3).tolist()
        over = bool(over)

        self.datas.append((state, action, reward, next_state, over))

        # Drop the oldest transitions once the pool is full
        while len(self.datas) > self.limit:
            self.datas.pop(0)

    def get_sample(self):
        samples = self.datas
        state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
        action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)
        reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)

        # Input: (state, action) concatenated -> [b, 4]
        input = torch.cat([state, action], dim=1)

        # Label: (reward, next_state - state): predict the delta, not the absolute state -> [b, 4]
        label = torch.cat([reward, next_state - state], dim=1)

        return input, label

    def __len__(self):
        return len(self.datas)


pool = Pool(100000)
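The input and label layout can be checked on a couple of hand-written transitions. The numbers below are made up for illustration, and NumPy is used in place of PyTorch to keep the sketch dependency-light; the concatenation mirrors what `get_sample()` does.

```python
import numpy as np

# Two made-up Pendulum-like transitions: state = (cos θ, sin θ, angular velocity)
state = np.array([[1.0, 0.0, 0.5], [0.0, 1.0, -1.0]])
action = np.array([[0.3], [-2.0]])
reward = np.array([[-0.1], [-2.5]])
next_state = np.array([[0.99, 0.05, 0.7], [0.1, 0.99, -1.4]])

model_input = np.concatenate([state, action], axis=1)          # [b, 4]
label = np.concatenate([reward, next_state - state], axis=1)   # [b, 4]

# A model that predicts the delta recovers the absolute next state by adding it back
recovered = state + label[:, 1:]
print(model_input.shape, label.shape)
```

Predicting the delta keeps the regression targets small and centered near zero, since consecutive states differ only slightly.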
Ensemble Dynamics Model
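The dynamics model uses an ensemble of 5 networks. The members do not share weights: each has its own parameters, but they are stacked into one tensor inside a batched FCLayer so that all five are evaluated with a single batched matrix multiply. Each network independently predicts a Gaussian output (mean, logvar). During planning, predictions are sampled from a randomly selected ensemble member to capture model uncertainty: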
import random


class Model(torch.nn.Module):

    class Swish(torch.nn.Module):

        def forward(self, x):
            return x * torch.sigmoid(x)

    class FCLayer(torch.nn.Module):
        """Linear layer holding separate weights for each of the 5 ensemble members."""

        def __init__(self, in_size, out_size):
            super().__init__()
            std = 1 / (in_size**0.5 * 2)
            weight = torch.empty(5, in_size, out_size)
            torch.nn.init.normal_(weight, mean=0.0, std=std)
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(torch.zeros(5, 1, out_size))

        def forward(self, x):
            # x: [5, b, in] -> [5, b, out]
            x = torch.bmm(x, self.weight)
            x = x + self.bias
            return x

    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            self.FCLayer(4, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 8),
            torch.nn.Identity(),
        )
        self.softplus = torch.nn.Softplus()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        # x: [5, b, 4] -> [5, b, 8]
        x = self.sequential(x)
        mean = x[..., :4]
        logvar = x[..., 4:]

        # Soft-clamp logvar to roughly [-10, 0.5]
        logvar = 0.5 - self.softplus(0.5 - logvar)
        logvar = self.softplus(logvar + 10) - 10

        return mean, logvar

    # Note: this overrides torch.nn.Module.train(mode), so model.eval() would fail.
    # The network has no dropout or batch norm, so the mode switch is not needed here.
    def train(self, input, label):
        # About 20 passes over the pool in minibatches of 64
        for _ in range(len(input) // 64 * 20):
            # Each ensemble member gets its own random minibatch
            select = torch.stack([torch.randperm(len(input))[:64] for _ in range(5)])
            input_select = input[select]  # [5, 64, 4]
            label_select = label[select]  # [5, 64, 4]

            # Call through self rather than the global `model`
            mean, logvar = self(input_select)

            # Gaussian negative log-likelihood, up to constants
            mse_loss = ((mean - label_select)**2 * (-logvar).exp()).mean(dim=1).mean()
            var_loss = logvar.mean(dim=1).mean()
            loss = mse_loss + var_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()


model = Model()
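The two softplus lines in `forward` act as a smooth clamp on the log-variance, and the training loss is a Gaussian negative log-likelihood up to constants. The NumPy sketch below reproduces both outside the network so the bounds can be inspected; the helper names `soft_clamp` and `gaussian_nll` are introduced here for illustration and are not part of the original code.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)


def soft_clamp(logvar, lower=-10.0, upper=0.5):
    logvar = upper - softplus(upper - logvar)   # smooth upper bound
    logvar = softplus(logvar - lower) + lower   # smooth lower bound
    return logvar


def gaussian_nll(mean, logvar, target):
    """Per-element loss, up to constants: squared error / variance + log-variance."""
    return (mean - target) ** 2 * np.exp(-logvar) + logvar


raw = np.array([-100.0, -5.0, 0.0, 5.0, 100.0])
clamped = soft_clamp(raw)
print(clamped)
```

The clamp keeps the predicted variance from collapsing to zero or growing without bound, either of which would destabilize the loss. The log-variance term penalizes the model for claiming high uncertainty everywhere to shrink the weighted squared error.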
MPC with Cross-Entropy Method (CEM)
The MPC class wraps two planning routines, which its mpc() method calls once per real environment step:
_fake_step(state, action) — queries the model to predict (reward, next_state) for a batch of (state, action) pairs.
_cem_optimize(state, mean) — iteratively refines a distribution over 25-step action sequences using CEM: sample 50 candidates, evaluate their cumulative reward, keep the top 10, and update the distribution.
class MPC:

    def _fake_step(self, state, action):
        # state: [b, 3], action: [b, 1]
        input = torch.cat([state, action], dim=1)

        # Broadcast to all 5 ensemble members: [5, b, 4]
        input = input.unsqueeze(dim=0).repeat([5, 1, 1])

        with torch.no_grad():
            mean, logvar = model(input)
        std = logvar.exp().sqrt()

        # Add state delta to get absolute next_state
        mean[:, :, 1:] += state

        sample = torch.distributions.Normal(0, 1).sample(mean.shape)
        sample = mean + sample * std

        # Randomly select one ensemble member per sample
        select = [random.choice(range(5)) for _ in range(mean.shape[1])]
        sample = sample[select, torch.arange(mean.shape[1])]

        reward, next_state = sample[:, :1], sample[:, 1:]
        return reward, next_state

    def _cem_optimize(self, state, mean):
        state = torch.FloatTensor(state).reshape(1, 3)
        var = torch.ones(25)

        # Expand current state to batch of 50
        state = state.repeat(50, 1)

        for _ in range(5):  # 5 CEM iterations
            actions = torch.distributions.Normal(0, 1).sample([50, 25])
            actions = actions * var**0.5 + mean

            reward_sum = torch.zeros(50, 1)

            # Every CEM iteration must roll out from the real current state,
            # so keep the rollout state separate from `state`
            rollout_state = state
            for i in range(25):
                action = actions[:, i].unsqueeze(dim=1)
                reward, rollout_state = self._fake_step(rollout_state, action)
                reward_sum += reward

            # Keep top-10 action sequences
            select = torch.sort(reward_sum.squeeze(dim=1)).indices
            actions = actions[select][-10:]

            # Blend the elite statistics into the sampling distribution
            new_mean = actions.mean(dim=0)
            new_var = actions.var(dim=0)
            mean = 0.1 * mean + 0.9 * new_mean
            var = 0.1 * var + 0.9 * new_var

        return mean

    def mpc(self):
        mean = torch.zeros(25)
        reward_sum = 0
        state = env.reset()
        over = False

        while not over:
            # Plan 25 actions from current state
            actions = self._cem_optimize(state, mean)
            action = actions[0].item()

            next_state, reward, over, _ = env.step([action])
            pool.add(state, action, reward, next_state, over)

            state = next_state
            reward_sum += reward

            # Shift the plan: reuse actions 1..24, zero-pad the last slot
            mean = torch.empty(actions.shape)
            mean[:-1] = actions[1:]
            mean[-1] = 0

        return reward_sum


mpc = MPC()
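The per-sample ensemble selection in `_fake_step` relies on paired advanced indexing: one index picks the ensemble member and a matching index picks the batch row. The NumPy sketch below isolates that step using fake predictions whose values equal the member index, so the result is easy to verify by eye. The shapes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_members, batch, out_dim = 5, 4, 2

# Fake predictions: member m fills its whole slice with the value m -> [5, 4, 2]
predictions = np.stack(
    [np.full((batch, out_dim), float(m)) for m in range(n_members)])

# Pick one ensemble member independently for each batch row
member = rng.integers(0, n_members, size=batch)    # [4]
chosen = predictions[member, np.arange(batch)]     # [4, 2]
print(member, chosen[:, 0])
```

Because a different member can be drawn at every simulated step, a long rollout mixes the ensemble's disagreements into the predicted return, which discourages plans that only look good under one member's errors.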
Training Loop
Alternate between planning (to gather real data) and model training (to improve predictions):
# Seed the pool with one random episode
def seed_pool():
    state = env.reset()
    over = False
    while not over:
        action = env.action_space.sample()[0]
        next_state, reward, over, _ = env.step([action])
        pool.add(state, action, reward, next_state, over)
        state = next_state


seed_pool()

for i in range(10):
    # 1. Train dynamics model on current pool
    input, label = pool.get_sample()
    model.train(input, label)

    # 2. Collect real data with MPC planning
    reward_sum = mpc.mpc()
    print(i, len(pool), reward_sum)
The warm-starting trick (mean[:-1] = actions[1:]) reuses the previously optimized action plan shifted by one step. The number of CEM iterations per real-world timestep stays fixed at five, but each search starts from a plan that is already close to optimal, so that fixed budget goes into refining a good plan rather than rediscovering one.
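The shift itself is a one-slot roll with a zero fill, shown here on a short hypothetical plan (horizon 4 instead of 25, NumPy instead of PyTorch):

```python
import numpy as np

plan = np.array([0.8, -0.3, 1.5, 0.2])   # hypothetical optimized plan, horizon 4

executed = plan[0]                # only the first action is sent to the environment
next_mean = np.empty_like(plan)
next_mean[:-1] = plan[1:]         # the remaining actions move up one slot
next_mean[-1] = 0.0               # the new final slot has no prior, so start it at zero
print(executed, next_mean)
```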
Advantages of MPC
Sample efficiency
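A good dynamics model can often be learned from far fewer environment interactions than a comparable model-free policy. Model learning is supervised, so every observed transition is a training example regardless of the reward it produced, and the fitted model generalizes across states.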
No policy network required
Planning is done online using the model, so there is no separate policy to train. The planning algorithm itself is the policy.
Adaptability
Because planning happens at every step, the agent responds immediately to changes in the environment or to an updated model, with no policy to retrain.
MPC is computationally expensive at inference time because the CEM loop runs at every real-world step. For real-time control with tight latency constraints, consider caching plans or amortizing the planning cost into a learned policy trained on model-generated rollouts, as in model-based policy optimization (MBPO).
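A rough count of the planning work follows directly from the constants used in the code above (5 CEM iterations, 50 candidates, a 25-step horizon, 200-step episodes). This is simple arithmetic, not a benchmark; each batched forward pass also evaluates all 5 ensemble members.

```python
cem_iterations = 5     # CEM refinement rounds per real step
candidates = 50        # action sequences sampled per round
horizon = 25           # planned steps per sequence
episode_steps = 200    # real steps per episode (wrapper cap)

# One batched model call advances all 50 candidates by one simulated step
model_calls_per_step = cem_iterations * horizon
simulated_transitions_per_step = model_calls_per_step * candidates
simulated_transitions_per_episode = simulated_transitions_per_step * episode_steps

print(model_calls_per_step,
      simulated_transitions_per_step,
      simulated_transitions_per_episode)
# prints: 125 6250 1250000
```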