Model Predictive Control (MPC) is a planning-based approach to reinforcement learning. Rather than learning a policy directly, MPC maintains a learned dynamics model of the environment and, at every time step, uses that model to search for the best sequence of actions over a short planning horizon. Only the first action of the best sequence is executed; planning then restarts from the newly observed state. This “receding horizon” strategy lets the agent continually correct for model error, because no plan is followed for more than one step before it is checked against the real environment. MPC is appealing for two reasons: it is sample efficient, since a relatively small amount of real environment data is enough to train a useful dynamics model, and because planning happens online, it can use new information immediately.
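The receding-horizon loop can be sketched without any learning at all. The example below is a minimal illustration under stated assumptions: `toy_model` is a hypothetical, exactly known 1-D system standing in for the learned dynamics model, and the planner is plain random shooting rather than the CEM planner used later on this page.

```python
import random

def toy_model(state, action):
    """Hypothetical known dynamics: a 1-D integrator rewarded for staying near 0."""
    next_state = state + 0.1 * action
    reward = -next_state ** 2
    return reward, next_state

def plan(state, rng, horizon=10, n_candidates=100):
    """Random shooting: score random action sequences with the model, return the best."""
    best_score, best_seq = float("-inf"), None
    for _ in range(n_candidates):
        seq = [rng.uniform(-1.0, 1.0) for _ in range(horizon)]
        s, score = state, 0.0
        for a in seq:
            r, s = toy_model(s, a)
            score += r
        if score > best_score:
            best_score, best_seq = score, seq
    return best_seq

def run_mpc(state=1.0, steps=40, seed=0):
    """Receding horizon: replan at every step, execute only the first action."""
    rng = random.Random(seed)
    for _ in range(steps):
        seq = plan(state, rng)
        # Here the "real" environment is the same toy system as the model.
        _, state = toy_model(state, seq[0])
    return state

final_state = run_mpc()
```

Starting from 1.0, the state is driven toward 0 even though each individual plan is only a rough random guess, because every step replans from where the system actually is.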

Key Concepts

MPC separates two sub-problems: (1) model learning, which fits a function f(s, a) → (r, s') to observed transitions, and (2) planning, which, given the current state, searches for the action sequence that maximizes the predicted cumulative reward over a horizon H. This page uses the Cross-Entropy Method (CEM), a sampling-based, derivative-free optimizer, as the planner, with H = 25.
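Written out, the planning problem solved at time t looks roughly as follows. The hat notation for model predictions is an assumption introduced here for illustration; the sum is undiscounted, matching the way rewards are accumulated in the planner on this page.

```latex
\[
a^{*}_{t:t+H-1}
  \;=\; \arg\max_{a_t,\dots,a_{t+H-1}} \; \sum_{k=0}^{H-1} \hat r_{t+k},
\qquad
(\hat r_{t+k},\, \hat s_{t+k+1}) = f(\hat s_{t+k},\, a_{t+k}),
\quad
\hat s_t = s_t .
\]
```

Only the first action of the maximizing sequence is executed before the problem is solved again from the next observed state.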

Environment

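The Pendulum-v1 environment is used throughout. The wrapper returns only the observation from reset(), collapses gym's terminated and truncated flags into a single done flag, and caps episodes at 200 steps: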
import gym

class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        self.step_n += 1
        if self.step_n >= 200:
            done = True
        return state, reward, done, info

env = MyWrapper()

Experience Pool

Transitions are stored in a Pool. Despite its name, get_sample() returns the entire buffer rather than a random subset: inputs of shape [b, 4] (3 state dimensions plus 1 action) and labels of shape [b, 4] (1 reward plus 3 state-delta dimensions):
import numpy as np
import torch

class Pool:
    def __init__(self, limit):
        self.datas = []
        self.limit = limit

    def add(self, state, action, reward, next_state, over):
        if isinstance(state, np.ndarray) or isinstance(state, torch.Tensor):
            state = state.reshape(3).tolist()
        action = float(action)
        reward = float(reward)
        if isinstance(next_state, np.ndarray) or isinstance(next_state, torch.Tensor):
            next_state = next_state.reshape(3).tolist()
        over = bool(over)
        self.datas.append((state, action, reward, next_state, over))
        while len(self.datas) > self.limit:
            self.datas.pop(0)

    def get_sample(self):
        samples = self.datas
        state      = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
        action     = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)
        reward     = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)

        # Input: (state, action) concatenated
        input = torch.cat([state, action], dim=1)
        # Label: (reward, next_state - state) — predict delta, not absolute
        label = torch.cat([reward, next_state - state], dim=1)
        return input, label

    def __len__(self):
        return len(self.datas)

pool = Pool(100000)
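The input and label layout is easier to see on a single transition. The sketch below uses plain Python lists and two hypothetical helper names (`make_training_row`, `recover_next_state`) that are not part of the Pool class; it only mirrors the concatenation and delta logic of get_sample().

```python
def make_training_row(state, action, reward, next_state):
    """Build one (input, label) pair in the same layout as Pool.get_sample()."""
    model_input = list(state) + [action]                 # 3 state dims + 1 action
    delta = [n - s for s, n in zip(state, next_state)]   # next_state - state
    label = [reward] + delta                             # 1 reward + 3 delta dims
    return model_input, label

def recover_next_state(state, label):
    """Invert the delta encoding: absolute next state = state + predicted delta."""
    return [s + d for s, d in zip(state, label[1:])]

x, y = make_training_row(
    state=[1.0, 0.0, 0.5], action=2.0, reward=-1.5, next_state=[0.75, 0.25, 1.0]
)
```

Predicting the delta rather than the absolute next state keeps the targets small and centred near zero, which is why the planner later adds the current state back to the model output.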

Ensemble Dynamics Model

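The dynamics model is an ensemble of 5 networks. The members do not share weights: each FCLayer stores one weight matrix per member in a single [5, in, out] tensor, so all five forward passes run as one batched matrix multiply (torch.bmm). Each network independently predicts a Gaussian over the 4 outputs (reward and 3 state deltas), returned as (mean, logvar). During planning, each candidate's prediction is sampled from a randomly selected ensemble member, so disagreement between members carries model uncertainty into the rollouts: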
import random

class Model(torch.nn.Module):

    class Swish(torch.nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

    class FCLayer(torch.nn.Module):
        def __init__(self, in_size, out_size):
            super().__init__()
            std = 1 / (in_size**0.5 * 2)
            weight = torch.empty(5, in_size, out_size)
            torch.nn.init.normal_(weight, mean=0.0, std=std)
            self.weight = torch.nn.Parameter(weight)
            self.bias   = torch.nn.Parameter(torch.zeros(5, 1, out_size))

        def forward(self, x):
            # x: [5, b, in] -> [5, b, out]
            x = torch.bmm(x, self.weight)
            x = x + self.bias
            return x

    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            self.FCLayer(4, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 200), self.Swish(),
            self.FCLayer(200, 8),
            torch.nn.Identity(),
        )
        self.softplus = torch.nn.Softplus()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        # x: [5, b, 4] -> [5, b, 8]
        x = self.sequential(x)
        mean   = x[..., :4]
        logvar = x[..., 4:]
        logvar = 0.5 - self.softplus(0.5 - logvar)
        logvar = self.softplus(logvar + 10) - 10
        return mean, logvar

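    def train(self, input, label):
        # NOTE: this name shadows torch.nn.Module.train(mode), so model.eval()
        # would fail; the name is kept because the training loop calls it.
        for _ in range(len(input) // 64 * 20):
            # Each ensemble member draws its own random minibatch of 64 rows
            select = torch.stack([torch.randperm(len(input))[:64] for _ in range(5)])
            input_select = input[select]   # [5, 64, 4]
            label_select = label[select]   # [5, 64, 4]

            # Call self, not the global `model`, so the method works on any instance
            mean, logvar = self(input_select)
            # Gaussian negative log-likelihood up to constants:
            # squared error weighted by 1/var, plus log-variance as a penalty
            mse_loss = ((mean - label_select)**2 * (-logvar).exp()).mean(dim=1).mean()
            var_loss = logvar.mean(dim=1).mean()
            loss = mse_loss + var_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()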

model = Model()
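Two details of the model are easy to miss: the pair of softplus calls in forward() softly clamps logvar to roughly the interval (-10, 0.5), and the training loss is a per-dimension Gaussian negative log-likelihood without its constant terms. The scalar sketch below reproduces both with the standard library only; the function names and the `lo`/`hi` parameters are illustrative, not part of the Model class.

```python
import math

def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def soft_clamp_logvar(logvar, lo=-10.0, hi=0.5):
    """Mirror of the two softplus lines in Model.forward()."""
    logvar = hi - softplus(hi - logvar)   # smooth upper bound near hi
    logvar = softplus(logvar - lo) + lo   # smooth lower bound at lo
    return logvar

def gaussian_loss(mean, logvar, target):
    """Per-dimension loss: squared error scaled by 1/var, plus logvar."""
    return (mean - target) ** 2 * math.exp(-logvar) + logvar

clamped_high = soft_clamp_logvar(1000.0)   # very close to 0.5
clamped_low = soft_clamp_logvar(-1000.0)   # very close to -10
```

Without the clamp, the network could lower the loss on hard examples by inflating logvar without limit, or produce a near-zero variance that makes the 1/var weight explode; the bounds keep both terms of the loss well behaved.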

MPC with Cross-Entropy Method (CEM)

The MPC class wraps two routines:
  • _fake_step(state, action) — queries the model to predict (reward, next_state) for a batch of (state, action) pairs.
  • _cem_optimize(state, mean) — iteratively refines a distribution over 25-step action sequences using CEM: sample 50 candidates, evaluate their cumulative reward, keep the top 10, and update the distribution.
class MPC:
    def _fake_step(self, state, action):
        # state: [b, 3], action: [b, 1]
        input = torch.cat([state, action], dim=1)
        # Broadcast to all 5 ensemble members
        input = input.unsqueeze(dim=0).repeat([5, 1, 1])

        with torch.no_grad():
            # The model returns log-variance, not standard deviation
            mean, logvar = model(input)
        std = logvar.exp().sqrt()

        # Add state delta to get absolute next_state
        mean[:, :, 1:] += state

        sample = torch.distributions.Normal(0, 1).sample(mean.shape)
        sample = mean + sample * std

        # Randomly select one ensemble member per sample
        select = [random.choice(range(5)) for _ in range(mean.shape[1])]
        sample = sample[select, range(mean.shape[1])]

        reward, next_state = sample[:, :1], sample[:, 1:]
        return reward, next_state

    def _cem_optimize(self, state, mean):
        state = torch.FloatTensor(state).reshape(1, 3)
        var   = torch.ones(25)

        # Expand current state to batch of 50
        init_state = state.repeat(50, 1)

        for _ in range(5):   # 5 CEM iterations
            actions = torch.distributions.Normal(0, 1).sample([50, 25])
            actions = actions * var**0.5 + mean

            # Every CEM iteration must roll out from the current real state,
            # not from wherever the previous iteration's rollout ended
            state = init_state
            reward_sum = torch.zeros(50, 1)
            for i in range(25):
                action = actions[:, i].unsqueeze(dim=1)
                reward, state = self._fake_step(state, action)
                reward_sum += reward

            # Keep top-10 action sequences
            select  = torch.sort(reward_sum.squeeze(dim=1)).indices
            actions = actions[select][-10:]

            new_mean = actions.mean(dim=0)
            new_var  = actions.var(dim=0)

            mean = 0.1 * mean + 0.9 * new_mean
            var  = 0.1 * var  + 0.9 * new_var

        return mean

    def mpc(self):
        mean = torch.zeros(25)
        reward_sum = 0
        state = env.reset()
        over  = False

        while not over:
            # Plan 25 actions from current state
            actions = self._cem_optimize(state, mean)
            action  = actions[0].item()

            next_state, reward, over, _ = env.step([action])
            pool.add(state, action, reward, next_state, over)

            state       = next_state
            reward_sum += reward

            # Shift the plan: reuse actions 1..24, zero-pad the last slot
            mean        = torch.empty(actions.shape)
            mean[:-1]   = actions[1:]
            mean[-1]    = 0

        return reward_sum

mpc = MPC()
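The CEM update is easier to follow in isolation. The sketch below applies the same recipe as _cem_optimize (50 samples, top 10 elites, 5 iterations, 0.1/0.9 smoothing of mean and variance) to a hypothetical toy objective instead of model rollouts. It uses only the standard library; `cem_maximize` and `toy_score` are illustrative names, not part of the MPC class.

```python
import random
import statistics

def cem_maximize(score, dim, n_samples=50, n_elite=10, n_iters=5, seed=0):
    """Cross-Entropy Method over a diagonal Gaussian search distribution."""
    rng = random.Random(seed)
    mean = [0.0] * dim
    var = [1.0] * dim
    for _ in range(n_iters):
        # 1. Sample candidates from the current distribution
        candidates = [
            [rng.gauss(mean[d], var[d] ** 0.5) for d in range(dim)]
            for _ in range(n_samples)
        ]
        # 2. Keep the highest-scoring candidates (the elites)
        elite = sorted(candidates, key=score)[-n_elite:]
        # 3. Move the distribution toward the elites, with smoothing
        for d in range(dim):
            column = [c[d] for c in elite]
            mean[d] = 0.1 * mean[d] + 0.9 * statistics.fmean(column)
            var[d] = 0.1 * var[d] + 0.9 * statistics.variance(column)
    return mean

def toy_score(x):
    """Hypothetical objective with its maximum at x = (1, 1, 1)."""
    return -sum((xi - 1.0) ** 2 for xi in x)

best = cem_maximize(toy_score, dim=3)
```

In the MPC setting, each candidate is a 25-step action sequence and the score is the cumulative reward predicted by rolling that sequence through the ensemble model.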

Training Loop

Alternate between planning (to gather real data) and model training (to improve predictions):
# Seed the pool with one random episode
def seed_pool():
    state = env.reset()
    over  = False
    while not over:
        action = env.action_space.sample()[0]
        next_state, reward, over, _ = env.step([action])
        pool.add(state, action, reward, next_state, over)
        state = next_state

seed_pool()

for i in range(10):
    # 1. Train dynamics model on current pool
    input, label = pool.get_sample()
    model.train(input, label)

    # 2. Collect real data with MPC planning
    reward_sum = mpc.mpc()
    print(i, len(pool), reward_sum)
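
The warm-starting trick (mean[:-1] = actions[1:]) reuses the previously optimized action plan shifted by one step. In this implementation it does not shorten the CEM loop, which always runs 5 iterations with the variance reset to 1; instead it gives that fixed budget a starting mean that is already close to a good plan, so the result after 5 iterations is better than it would be from a zero mean.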
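The shift applied to the plan between real steps can be sketched on a short list. `shift_plan` is a hypothetical helper that mirrors the three assignment lines at the end of the mpc() loop; the `fill` parameter is an illustrative addition.

```python
def shift_plan(plan, fill=0.0):
    """Drop the action that was just executed and pad the end of the horizon."""
    return list(plan[1:]) + [fill]

previous_plan = [0.3, -1.2, 0.8]          # a 3-step plan for illustration
next_mean = shift_plan(previous_plan)     # becomes the CEM mean for the next step
```

The padded slot is the only part of the new plan with no prior information, which is why it starts at zero.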

Advantages of MPC

1. Sample efficiency: A useful dynamics model can often be learned from far fewer environment interactions than a comparable model-free policy, because every observed transition is a supervised training example and the model generalizes across states.

2. No policy network required: Planning is done online using the model, so there is no separate policy to train. The planning algorithm itself is the policy.

3. Adaptability: Because planning happens at every step, the agent picks up changes in the environment state or an updated model immediately, without retraining a policy.

Warning: MPC is computationally expensive at inference time because the CEM loop runs at every real-world step. For real-time control with tight latency constraints, consider caching plans or amortizing planning into a learned policy, as MBPO (Model-Based Policy Optimization) does by training a policy on model-generated rollouts.
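A rough count using the hyperparameters from this page shows where the cost comes from. The function name and dictionary keys below are illustrative only; the numbers follow directly from 5 CEM iterations, 50 candidates, a 25-step horizon, 5 ensemble members, and 200-step episodes.

```python
def planning_cost(cem_iters=5, candidates=50, horizon=25, ensemble=5, episode_steps=200):
    """Count model work per real environment step and per episode."""
    # One batched model call per simulated step of each CEM iteration
    model_calls_per_step = cem_iters * horizon
    # Each call simulates one transition for every candidate sequence
    transitions_per_step = model_calls_per_step * candidates
    # Every ensemble member evaluates every candidate before one is selected
    member_predictions_per_step = transitions_per_step * ensemble
    return {
        "model_calls_per_step": model_calls_per_step,
        "transitions_per_step": transitions_per_step,
        "member_predictions_per_step": member_predictions_per_step,
        "model_calls_per_episode": model_calls_per_step * episode_steps,
        "transitions_per_episode": transitions_per_step * episode_steps,
    }

cost = planning_cost()
```

Every real action therefore costs 125 batched forward passes through the ensemble, which is the overhead that plan caching or a learned policy aims to avoid.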
