

Offline reinforcement learning (also called batch RL) trains a policy entirely from a pre-collected dataset of transitions (state, action, reward, next_state, done) without any further interaction with the environment. This is critical in domains such as healthcare, robotics, or autonomous driving, where collecting new data is expensive or dangerous. The central challenge is Q-value extrapolation error: a standard Q-learning algorithm may assign unrealistically high values to out-of-distribution (OOD) actions — actions that do not appear in the dataset — causing the policy to exploit these spurious values and perform poorly at deployment. This tutorial implements Conservative Q-Learning (CQL), which explicitly penalizes high Q-values on OOD actions to keep the learned values conservative and close to the data distribution.

Key Concepts

In standard online RL, the policy and replay buffer are updated together, so the distribution of visited states and actions stays relatively consistent. In offline RL, the policy is updated but the dataset is fixed, creating a growing gap between the states the policy wants to visit and the states actually in the dataset.
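This gap matters because a greedy policy picks the action with the highest estimated value, including actions the dataset never covers. A minimal tabular sketch, assuming a hypothetical one-state problem with three equally good actions (all names and numbers here are illustrative and not part of the tutorial code):

```python
# Hypothetical one-state problem: three actions, each truly worth 1.0.
TRUE_Q = [1.0, 1.0, 1.0]

# The dataset only contains actions 0 and 1, with noisy observed returns.
dataset = [(0, 0.9), (0, 1.1), (1, 1.2), (1, 0.8)]

# Q-table with an arbitrary, overly optimistic initial guess.
q = [5.0, 5.0, 5.0]

# Fit Q only where data exists, using a simple running-average update.
lr = 0.5
for _ in range(50):
    for action, ret in dataset:
        q[action] += lr * (ret - q[action])

seen = {a for a, _ in dataset}
greedy = max(range(len(q)), key=lambda a: q[a])
overestimate = q[greedy] - TRUE_Q[greedy]

print("greedy action:", greedy, "seen in data:", greedy in seen)
print("overestimate of the greedy action:", overestimate)
```

The estimate for action 2 is never corrected because no transition mentions it, so the greedy choice lands on exactly the action the data says nothing about. A neural critic behaves similarly, except that the unseen values come from generalization rather than from initialization.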

Environment and Dataset

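The offline dataset is collected from the Pendulum-v1 environment by a SAC teacher agent that is first trained online. The wrapper below merges the terminated and truncated flags into a single done flag and ends every episode after at most 200 steps: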
import gym

class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        self.step_n += 1
        if self.step_n >= 200:
            done = True
        return state, reward, done, info

env = MyWrapper()

The Replay Buffer

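The Data class stores transitions. While the teacher trains, update_data appends one episode per call and drops the oldest transitions beyond 100,000. Once the teacher is done, the buffer is left untouched and get_sample only draws random batches of 64 from it: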
import random
import torch

class Data:
    def __init__(self):
        self.datas = []

    def update_data(self, agent):
        state = env.reset()
        over = False
        while not over:
            action = agent.get_action(state)
            next_state, reward, over, _ = env.step([action])
            self.datas.append((state, action, reward, next_state, over))
            state = next_state
        # Cap dataset size
        while len(self.datas) > 100000:
            self.datas.pop(0)

    def get_sample(self):
        samples = random.sample(self.datas, 64)
        state      = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
        action     = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)
        reward     = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)
        over       = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)
        return state, action, reward, next_state, over
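
The size cap and the freeze step can also be made explicit. The sketch below is a hypothetical stand-in for the Data class, not the tutorial's implementation: FrozenBuffer, add, freeze, and sample are illustrative names, and it stores plain tuples instead of tensors.

```python
import random
from collections import deque


class FrozenBuffer:
    """Hypothetical replay buffer with a size cap and an explicit freeze step."""

    def __init__(self, capacity=100_000):
        # deque(maxlen=...) drops the oldest item in O(1) when full,
        # whereas list.pop(0) shifts every remaining element.
        self.datas = deque(maxlen=capacity)
        self.frozen = False

    def add(self, transition):
        if self.frozen:
            raise RuntimeError("dataset is frozen; offline training must not add data")
        self.datas.append(transition)

    def freeze(self):
        # Convert to a tuple so the contents cannot change and
        # random indexing during sampling is O(1).
        self.datas = tuple(self.datas)
        self.frozen = True

    def sample(self, batch_size=64):
        return random.sample(self.datas, batch_size)


buf = FrozenBuffer(capacity=3)
for i in range(5):
    buf.add((i,))
kept = list(buf.datas)  # only the three most recent transitions remain
buf.freeze()
batch = buf.sample(2)
```

Raising an error on writes after freeze() turns an accidental call to the collection code during offline training into an immediate failure instead of a silent change to the dataset.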

Online Baseline — SAC Teacher

A soft actor-critic (SAC) teacher is trained online to fill the dataset. Its core networks are a stochastic policy and two Q-value heads:
import math

class SAC:
    class ModelAction(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_state = torch.nn.Sequential(
                torch.nn.Linear(3, 128), torch.nn.ReLU()
            )
            self.fc_mu = torch.nn.Linear(128, 1)
            self.fc_std = torch.nn.Sequential(
                torch.nn.Linear(128, 1), torch.nn.Softplus()
            )

        def forward(self, state):
            state = self.fc_state(state)
            mu = self.fc_mu(state)
            std = self.fc_std(state)
            dist = torch.distributions.Normal(mu, std)
            sample = dist.rsample()
            action = torch.tanh(sample)
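            log_prob = dist.log_prob(sample)
            # Tanh change of variables: action is already tanh(sample),
            # so the Jacobian term is 1 - action**2 (no second tanh).
            log_prob = log_prob - (1 - action**2 + 1e-7).log()
            # Returned as "entropy": the negative log-probability of the sample.
            entropy = -log_prob
            return action * 2, entropy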

    class ModelValue(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.sequential = torch.nn.Sequential(
                torch.nn.Linear(4, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 1),
            )

        def forward(self, state, action):
            state = torch.cat([state, action], dim=1)
            return self.sequential(state)
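
The log-probability of a tanh-squashed Gaussian sample follows from a change of variables: if u is drawn from Normal(mu, std) and a = tanh(u), then log p(a) = log N(u; mu, std) - log(1 - a^2). A minimal scalar sketch, assuming a helper named squashed_log_prob that is not part of the tutorial code:

```python
import math


def squashed_log_prob(u, mu, std, eps=1e-7):
    """Log-density of a = tanh(u) when u ~ Normal(mu, std)."""
    normal_log_prob = (
        -0.5 * ((u - mu) / std) ** 2
        - math.log(std)
        - 0.5 * math.log(2 * math.pi)
    )
    a = math.tanh(u)
    # Change of variables: da/du = 1 - tanh(u)^2 = 1 - a^2.
    # eps guards against log(0) when |a| is very close to 1.
    return normal_log_prob - math.log(1 - a ** 2 + eps)


def squashed_cdf(a):
    """P(tanh(U) <= a) for U ~ Normal(0, 1), used only as a cross-check."""
    return 0.5 * (1 + math.erf(math.atanh(a) / math.sqrt(2)))


# Cross-check the density against a finite difference of the CDF.
u = 0.3
a = math.tanh(u)
h = 1e-5
density = math.exp(squashed_log_prob(u, 0.0, 1.0))
finite_diff = (squashed_cdf(a + h) - squashed_cdf(a - h)) / (2 * h)
```

Scaling the action by 2, as the policy network does for Pendulum, would shift this log-density by a constant -log 2. A constant offset does not change gradients, which is presumably why the network code leaves it out.
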
Train the teacher, collect data, then freeze the dataset (the train, test, and get_action methods of SAC are not shown in this excerpt):
teacher = SAC()

data = Data()

# Online training loop
for epoch in range(100):
    data.update_data(teacher)
    for i in range(200):
        teacher.train(*data.get_sample())

    if epoch % 10 == 0:
        test_result = sum([teacher.test(play=False) for _ in range(10)]) / 10
        print(epoch, test_result)

Conservative Q-Learning (CQL)

CQL extends SAC's value loss with a penalty term that pushes Q-values up for the (state, action) pairs in the dataset and pushes them down for actions sampled uniformly at random or from the current policy. The student inherits everything from SAC and overrides only _get_loss_value:
class CQL(SAC):
    def __init__(self):
        super().__init__()

    def _get_loss_value(self, model_value, target, state, action, next_state):
        # Standard SAC value loss
        value = model_value(state, action)
        loss_value = self.loss_fn(value, target)

        # --- CQL penalty ---
        # Expand states and next_states to sample 5 actions each
        state      = state.unsqueeze(1).repeat(1, 5, 1).reshape(-1, 3)
        next_state = next_state.unsqueeze(1).repeat(1, 5, 1).reshape(-1, 3)

        rand_action  = torch.empty([len(state), 1]).uniform_(-1, 1)
        curr_action, curr_entropy = self.model_action(state)
        next_action, next_entropy = self.model_action(next_state)

        value_rand = model_value(state, rand_action).reshape(-1, 5, 1)
        value_curr = model_value(state, curr_action).reshape(-1, 5, 1)
        value_next = model_value(state, next_action).reshape(-1, 5, 1)

        curr_entropy = curr_entropy.detach().reshape(-1, 5, 1)
        next_entropy = next_entropy.detach().reshape(-1, 5, 1)

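        # Importance-sampling correction: subtract each proposal's log-density.
        # Uniform on [-1, 1] has density 0.5. The policy returns -log_prob as
        # "entropy", so adding it subtracts log pi.
        value_rand -= math.log(0.5)
        value_curr += curr_entropy
        value_next += next_entropy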

        value_cat = torch.cat([value_rand, value_curr, value_next], dim=1)
        loss_cat  = torch.logsumexp(value_cat, dim=1).mean()

        # Add weighted CQL penalty
        loss_value += 5.0 * (loss_cat - value.mean())
        return loss_value
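
To see how the penalty reacts to the critic's outputs, the same arithmetic can be worked through for a single state in plain Python. This is a sketch under stated assumptions: cql_penalty and the Q-values below are hypothetical, and only uniform proposals on [-1, 1] (density 0.5) are used.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_penalty(q_sampled, log_density, q_data, weight=5.0):
    """weight * (logsumexp_i(Q_i - log mu_i) - Q_data) for one state."""
    corrected = [q - lp for q, lp in zip(q_sampled, log_density)]
    return weight * (logsumexp(corrected) - q_data)


# Three actions drawn uniformly from [-1, 1], so each has log-density log(0.5).
log_mu = [math.log(0.5)] * 3

# Critic that rates the sampled actions far above the dataset action.
high = cql_penalty([4.0, 3.5, 4.2], log_mu, q_data=1.0)

# Critic that rates the sampled actions below the dataset action.
low = cql_penalty([0.5, 0.2, 0.4], log_mu, q_data=1.0)

print("penalty with optimistic sampled actions:", high)
print("penalty with pessimistic sampled actions:", low)
```

The penalty is larger when sampled actions score higher than the dataset action, so minimizing it lowers the former and raises the latter. Because logsumexp is dominated by its largest term, the pressure falls mostly on the most overestimated action.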

Offline Training Loop

Once the dataset is frozen, train the student without any environment interaction:
student = CQL()

# Offline training — dataset never changes
for i in range(50000):
    student.train(*data.get_sample())

    if i % 2000 == 0:
        test_result = sum([student.test(play=False) for _ in range(10)]) / 10
        print(i, test_result)
Notice that data.update_data() is never called inside the offline training loop. The entire optimization happens on the fixed replay buffer collected by the teacher.

Why CQL Works

1. Standard SAC overestimates OOD values. Without constraints, the critic can assign high Q-values to actions not seen in the dataset, tempting the policy into those regions.

2. CQL penalizes high OOD Q-values. For each batch, random actions, current-policy actions, and next-state policy actions are sampled. The logsumexp of their Q-values is pushed down while the Q-value of the dataset action is pushed up.

3. The resulting policy stays near the data distribution. Because the Q-values of OOD actions are deliberately pushed down, the policy has little incentive to drift far from the actions covered by the dataset.
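
The steps above can be summarized as a single objective. It is written here in the standard notation of the CQL literature rather than taken from the tutorial's code: alpha is the penalty weight (5.0 in _get_loss_value), D is the fixed dataset, and L_TD stands for the usual SAC value loss.

```latex
\mathcal{L}_{\text{CQL}}(\theta)
  = \mathcal{L}_{\text{TD}}(\theta)
  + \alpha \Big(
      \mathbb{E}_{s \sim \mathcal{D}}\Big[\log \sum_{a} \exp Q_\theta(s, a)\Big]
      - \mathbb{E}_{(s, a) \sim \mathcal{D}}\big[Q_\theta(s, a)\big]
    \Big)
```

For continuous actions the sum over a becomes an integral, which the code estimates by importance sampling over five random, five current-policy, and five next-state policy actions per state.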

Distribution Shift Summary

Even with CQL, an agent trained on data from a weak policy is unlikely to perform much better than the best behavior in that data. Offline RL is most effective when the dataset contains a diverse mixture of behaviors, including near-optimal demonstrations.

| Challenge | Online RL | Offline RL |
| --- | --- | --- |
| Environment interaction | Required | Not allowed |
| Distribution shift | Handled automatically | Must be explicitly addressed |
| Q-value extrapolation | Benign (explored states) | Dangerous (unexplored states) |
| Key technique | Standard TD update | Conservative penalty (CQL) |
