Offline reinforcement learning (also called batch RL) trains a policy entirely from a pre-collected dataset of transitions (state, action, reward, next_state, done) without any further interaction with the environment. This is critical in domains such as healthcare, robotics, or autonomous driving, where collecting new data is expensive or dangerous.
The central challenge is Q-value extrapolation error: a standard Q-learning algorithm may assign unrealistically high values to out-of-distribution (OOD) actions — actions that do not appear in the dataset — causing the policy to exploit these spurious values and perform poorly at deployment. This tutorial implements Conservative Q-Learning (CQL), which explicitly penalizes high Q-values on OOD actions to keep the learned values conservative and close to the data distribution.
Key Concepts
In standard online RL, the policy and replay buffer are updated together, so the buffer keeps tracking the state-action distribution of the current policy. In offline RL, the policy is updated but the dataset is fixed. This creates a growing gap between the state-action pairs the learned policy favors and those actually covered by the dataset.
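A toy sketch makes the extrapolation error concrete. It is not part of the tutorial's code. A linear fit stands in for the critic, and the true Q-function Q(a) = -a² and the narrow slice of dataset actions are assumptions chosen only for illustration:

```python
import numpy as np

# Hypothetical 1-D "critic": the true Q(a) = -a**2 is best at a = 0,
# but the dataset only contains actions in [-1, -0.5].
data_actions = np.linspace(-1.0, -0.5, 11)
data_q = -data_actions**2

# Fit a linear model to the in-dataset Q-values (least squares)
slope, intercept = np.polyfit(data_actions, data_q, deg=1)

# A policy that maximizes the fitted Q over the full Pendulum range [-2, 2]
grid = np.linspace(-2.0, 2.0, 41)
fitted_q = slope * grid + intercept
best = grid[np.argmax(fitted_q)]

print(f"chosen action: {best:.1f}")
print(f"fitted Q there: {fitted_q.max():.2f}, true Q there: {-best**2:.2f}")
```

Every dataset action lies where the true Q is increasing, so the fitted slope is positive. As a result, the maximizer lands on the edge of the range, a = 2. There the fit predicts a large positive value, but the true Q is -4. A neural critic extrapolates less predictably, but nothing in the data stops it from making the same kind of mistake.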
Environment and Dataset
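The offline dataset is collected in the Pendulum-v1 environment by a teacher SAC agent that is first trained online to produce expert-quality data. The wrapper below converts the 5-tuple `step` API of gym>=0.26 back to `(state, reward, done, info)` and caps every episode at 200 steps: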
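```python
import gym

class MyWrapper(gym.Wrapper):
    """Pendulum-v1 with a 4-tuple step API and a hard 200-step cap."""

    def __init__(self):
        env = gym.make('Pendulum-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        # gym>=0.26 returns (observation, info); keep only the observation
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        # Pendulum never terminates; every episode ends by truncation,
        # which is treated as a terminal "done" here
        done = terminated or truncated
        self.step_n += 1
        if self.step_n >= 200:
            done = True
        return state, reward, done, info

env = MyWrapper()
```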
The Replay Buffer
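The Data class is a FIFO replay buffer capped at 100,000 transitions. While the teacher trains, each call to `update_data()` appends one episode. Afterwards the buffer is no longer modified and serves as the offline dataset, from which `get_sample()` draws random minibatches of 64 transitions: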
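```python
import random

import numpy as np
import torch

class Data:
    def __init__(self):
        self.datas = []

    def update_data(self, agent):
        # Play one episode with `agent` in the global `env` and store every transition
        state = env.reset()
        over = False
        while not over:
            action = agent.get_action(state)
            next_state, reward, over, _ = env.step([action])
            self.datas.append((state, action, reward, next_state, over))
            state = next_state

        # Cap dataset size at 100,000 transitions, evicting the oldest first
        while len(self.datas) > 100000:
            self.datas.pop(0)

    def get_sample(self):
        # Uniformly sample a minibatch of 64 transitions
        samples = random.sample(self.datas, 64)
        # Stack with NumPy first: building a tensor from a list of arrays is slow
        state = torch.as_tensor(np.array([i[0] for i in samples]), dtype=torch.float32).reshape(-1, 3)
        action = torch.as_tensor(np.array([i[1] for i in samples]), dtype=torch.float32).reshape(-1, 1)
        reward = torch.as_tensor(np.array([i[2] for i in samples]), dtype=torch.float32).reshape(-1, 1)
        next_state = torch.as_tensor(np.array([i[3] for i in samples]), dtype=torch.float32).reshape(-1, 3)
        over = torch.as_tensor(np.array([i[4] for i in samples]), dtype=torch.long).reshape(-1, 1)
        return state, action, reward, next_state, over
```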
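You can check the batch layout without running Pendulum by exercising the sampling step on synthetic transitions. This is a sketch: `make_fake_transitions` and `sample_batch` are hypothetical helpers. The random data only mimics Pendulum's shapes (a 3-d state and a scalar action), and the 64-row batch mirrors `get_sample()`:

```python
import random

import numpy as np
import torch

def make_fake_transitions(n):
    """Hypothetical Pendulum-shaped transitions: 3-d state, scalar action."""
    rng = np.random.default_rng(0)
    transitions = []
    for t in range(n):
        state = rng.standard_normal(3).astype(np.float32)
        next_state = rng.standard_normal(3).astype(np.float32)
        action = float(rng.uniform(-2.0, 2.0))
        reward = float(-rng.random())
        done = (t + 1) % 200 == 0  # episodes end every 200 steps
        transitions.append((state, action, reward, next_state, done))
    return transitions

def sample_batch(transitions, batch_size=64):
    """Same tensor layout as Data.get_sample: one row per transition."""
    samples = random.sample(transitions, batch_size)
    state, action, reward, next_state, over = zip(*samples)
    return (
        torch.as_tensor(np.stack(state), dtype=torch.float32),
        torch.as_tensor(action, dtype=torch.float32).reshape(-1, 1),
        torch.as_tensor(reward, dtype=torch.float32).reshape(-1, 1),
        torch.as_tensor(np.stack(next_state), dtype=torch.float32),
        torch.as_tensor(np.array(over, dtype=np.int64)).reshape(-1, 1),
    )

batch = sample_batch(make_fake_transitions(1000))
print([tuple(t.shape) for t in batch])
```

Using `zip(*samples)` unpacks the five fields in one pass. `np.stack` then turns the per-transition state arrays into a single `(batch_size, 3)` array before the tensor is built.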
Online Baseline — SAC Teacher
A soft actor-critic (SAC) teacher is trained online to fill the dataset. Only its networks are shown here. ModelAction is a tanh-squashed Gaussian policy, and ModelValue is a Q-network that SAC uses for its two Q-value heads:
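```python
import math

import torch

class SAC:
    # Excerpt: only the networks are shown. The rest of the class (constructor,
    # get_action, train, test, loss_fn and the critic loss _get_loss_value that
    # CQL overrides below) is not reproduced here.

    class ModelAction(torch.nn.Module):
        """Tanh-squashed Gaussian policy; actions are scaled to [-2, 2]."""

        def __init__(self):
            super().__init__()
            self.fc_state = torch.nn.Sequential(
                torch.nn.Linear(3, 128), torch.nn.ReLU()
            )
            self.fc_mu = torch.nn.Linear(128, 1)
            self.fc_std = torch.nn.Sequential(
                torch.nn.Linear(128, 1), torch.nn.Softplus()
            )

        def forward(self, state):
            state = self.fc_state(state)
            mu = self.fc_mu(state)
            std = self.fc_std(state)
            dist = torch.distributions.Normal(mu, std)
            # Reparameterized sample so gradients flow through the action
            sample = dist.rsample()
            action = torch.tanh(sample)
            # Tanh change of variables: log pi(a) = log N(u) - log(1 - tanh(u)^2).
            # `action` is already tanh(sample), so tanh must not be applied again.
            log_prob = dist.log_prob(sample) - (1 - action**2 + 1e-7).log()
            # Return -log pi(a|s) as a one-sample entropy estimate
            # (the constant log 2 from scaling by 2 is ignored)
            entropy = -log_prob
            return action * 2, entropy

    class ModelValue(torch.nn.Module):
        """Q(s, a): concatenates the 3-d state with the 1-d action."""

        def __init__(self):
            super().__init__()
            self.sequential = torch.nn.Sequential(
                torch.nn.Linear(4, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 128), torch.nn.ReLU(),
                torch.nn.Linear(128, 1),
            )

        def forward(self, state, action):
            state = torch.cat([state, action], dim=1)
            return self.sequential(state)
```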
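The critic loss `_get_loss_value` receives a precomputed `target`, but the SAC code that builds it is not shown. The following is a minimal sketch of the standard SAC soft Bellman target. It assumes two target Q-networks, a fixed temperature `alpha`, and a discount `gamma`. The function name and both default values are hypothetical placeholders, not the tutorial's settings:

```python
import torch

def soft_td_target(reward, next_state, over, model_action,
                   target_value1, target_value2, alpha=0.01, gamma=0.9):
    """Hypothetical helper: soft Bellman target y for SAC's critic loss.

    model_action must return (action, entropy) with entropy = -log pi(a|s'),
    matching ModelAction above. alpha and gamma are placeholder values.
    """
    with torch.no_grad():
        next_action, next_entropy = model_action(next_state)
        # Clipped double-Q: take the smaller of the two target estimates
        next_value = torch.min(target_value1(next_state, next_action),
                               target_value2(next_state, next_action))
        # Soft value: Q plus the entropy bonus weighted by the temperature
        next_value = next_value + alpha * next_entropy
        # No bootstrapping past the end of an episode (over == 1)
        return reward + gamma * next_value * (1 - over)
```

Taking the minimum of two critics counters the overestimation that a single maximized critic would accumulate. The same concern, in a much stronger form, motivates CQL in the offline setting.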
Train the teacher while it collects data. After this loop, `data` is no longer updated and serves as the frozen offline dataset:
```python
teacher = SAC()
data = Data()

# Online training loop: each epoch adds one 200-step episode, then runs 200 updates
for epoch in range(100):
    data.update_data(teacher)
    for i in range(200):
        teacher.train(*data.get_sample())

    if epoch % 10 == 0:
        # Average of teacher.test() over 10 evaluation episodes
        test_result = sum([teacher.test(play=False) for _ in range(10)]) / 10
        print(epoch, test_result)
```
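Nothing in the loop above enforces the freeze. One way to make it explicit, and to let offline training run in a separate process, is to write the buffer to disk once and reload it as an immutable tuple. The following is a sketch with hypothetical helpers built on the standard `pickle` module:

```python
import pickle

def save_dataset(transitions, path):
    """Hypothetical helper: write the collected transitions to disk once."""
    with open(path, "wb") as f:
        pickle.dump(list(transitions), f)

def load_frozen_dataset(path):
    """Hypothetical helper: load transitions as an immutable tuple."""
    with open(path, "rb") as f:
        return tuple(pickle.load(f))
```

For example, you could call `save_dataset(data.datas, 'pendulum_sac.pkl')` after the loop (the file name is a placeholder). Then assign `data.datas = load_frozen_dataset('pendulum_sac.pkl')` before offline training. This keeps `get_sample()` working, because `random.sample` accepts any sequence. It also means an accidental `update_data()` call would raise `AttributeError`, since tuples have no `append` method.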
Conservative Q-Learning (CQL)
CQL extends SAC's value loss with an additional penalty term. The penalty pushes down a soft maximum (logsumexp) of Q-values over uniformly random and policy-sampled actions, while pushing up the Q-values of the actions actually in the dataset. The student inherits all of SAC but overrides _get_loss_value:
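```python
import math

import torch

class CQL(SAC):
    def __init__(self):
        super().__init__()

    def _get_loss_value(self, model_value, target, state, action, next_state):
        # Standard SAC value loss on dataset transitions
        value = model_value(state, action)
        loss_value = self.loss_fn(value, target)

        # --- CQL penalty ---
        # Repeat each state and next_state 5 times to sample 5 actions per state
        state = state.unsqueeze(1).repeat(1, 5, 1).reshape(-1, 3)
        next_state = next_state.unsqueeze(1).repeat(1, 5, 1).reshape(-1, 3)

        # Uniform random actions over Pendulum's full action range [-2, 2]
        rand_action = torch.empty([len(state), 1]).uniform_(-2, 2)

        # Policy actions at s and at s'; only the critic is trained here
        with torch.no_grad():
            curr_action, curr_entropy = self.model_action(state)
            next_action, next_entropy = self.model_action(next_state)

        # All three action sets are scored at the current state s
        value_rand = model_value(state, rand_action).reshape(-1, 5, 1)
        value_curr = model_value(state, curr_action).reshape(-1, 5, 1)
        value_next = model_value(state, next_action).reshape(-1, 5, 1)

        curr_entropy = curr_entropy.reshape(-1, 5, 1)
        next_entropy = next_entropy.reshape(-1, 5, 1)

        # Importance weighting: subtract each proposal's log-density.
        # Densities are measured for the normalized action a/2, matching the
        # policy entropy (which ignores the x2 scaling): U(-1, 1) has density 0.5.
        # For policy samples, entropy = -log pi, so Q - log pi = Q + entropy.
        value_rand -= math.log(0.5)
        value_curr += curr_entropy
        value_next += next_entropy

        # Soft maximum over the 15 sampled actions per state
        value_cat = torch.cat([value_rand, value_curr, value_next], dim=1)
        loss_cat = torch.logsumexp(value_cat, dim=1).mean()

        # Push down the soft max and push up Q on dataset actions (weight 5.0)
        loss_value += 5.0 * (loss_cat - value.mean())
        return loss_value
```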
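In equation form, the loss above is a sample-based version of the CQL(H) critic objective. Here $\ell$ is `self.loss_fn`, $y$ is `target`, $\alpha = 5.0$ is the penalty weight, and $\mathcal{D}$ is the dataset:

$$
\mathcal{L}(\theta) = \ell\big(Q_\theta(s,a),\, y\big) + \alpha\Big(\mathbb{E}_{s\sim\mathcal{D}}\Big[\log \int \exp Q_\theta(s,a')\,da'\Big] - \mathbb{E}_{(s,a)\sim\mathcal{D}}\big[Q_\theta(s,a)\big]\Big)
$$

The integral is approximated by importance sampling. $N = 5$ actions are drawn from each of three proposals $\mu_k$: a uniform distribution over actions, $\pi(\cdot \mid s)$, and $\pi(\cdot \mid s')$. All of these actions are scored at the current state $s$:

$$
\log \int \exp Q_\theta(s,a')\,da' \approx \log \sum_{k=1}^{3N} \exp\big(Q_\theta(s,a_k) - \log \mu_k(a_k)\big) - \log(3N)
$$

The constant $\log(3N)$ does not affect gradients, which is why the code omits it. Since ModelAction returns $-\log\pi$ as its entropy, each policy term $Q_\theta - \log\pi$ equals $Q_\theta$ plus that entropy.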
Offline Training Loop
Once the dataset is frozen, train the student on it alone. The environment is used only by `student.test()` to evaluate progress:
```python
student = CQL()

# Offline training: the dataset never changes
for i in range(50000):
    student.train(*data.get_sample())

    if i % 2000 == 0:
        test_result = sum([student.test(play=False) for _ in range(10)]) / 10
        print(i, test_result)
```
Notice that data.update_data() is never called inside the offline training loop. The entire optimization happens on the fixed replay buffer collected by the teacher.
Why CQL Works
Standard SAC over-estimates OOD values
Without constraints, the critic can assign erroneously high Q-values to actions that do not appear in the dataset. Because those actions are never tried, nothing corrects the error, and the policy is drawn toward them.
CQL penalizes high OOD Q-values
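For each state in the batch, 5 uniformly random actions are sampled, along with 5 policy actions at s and 5 policy actions at the next state s'. The logsumexp (a soft maximum) of their importance-weighted Q-values is pushed down, while the Q-value of the dataset action is pushed up.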
The resulting policy stays near the data distribution
Because the Q-values of OOD actions are pushed down, the policy, which is trained to maximize Q, has little incentive to move toward actions the dataset does not support.
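A tabular toy shows the effect of the penalty. The setup is hypothetical: a single state with 5 discrete actions, a dataset that contains only action 2 with reward 1.0, and an initial spurious value of 3.0 on action 4, which never appears in the data. The penalty term has the same form as in `_get_loss_value` (logsumexp over actions minus Q on the dataset action). The sum is exact here because the action set is small:

```python
import torch

# Hypothetical one-state bandit with 5 discrete actions. The dataset only
# contains action 2 with reward 1.0; action 4 starts with a spurious high value.
DATA_ACTION, DATA_REWARD = 2, 1.0

def train_q(alpha, steps=500, lr=0.1):
    q = torch.tensor([0.0, 0.0, 0.0, 0.0, 3.0], requires_grad=True)
    opt = torch.optim.SGD([q], lr=lr)
    for _ in range(steps):
        # Regression to the observed reward for the dataset action
        loss = (q[DATA_ACTION] - DATA_REWARD) ** 2
        # CQL penalty: soft max over all actions minus Q on the dataset action
        loss = loss + alpha * (torch.logsumexp(q, dim=0) - q[DATA_ACTION])
        opt.zero_grad()
        loss.backward()
        opt.step()
    return q.detach()

q_plain = train_q(alpha=0.0)
q_cql = train_q(alpha=1.0)
print("plain:", q_plain.tolist(), "argmax", int(q_plain.argmax()))
print("cql:  ", q_cql.tolist(), "argmax", int(q_cql.argmax()))
```

Without the penalty, only Q for action 2 receives gradient, so the unsupported action 4 keeps its value of 3.0 and remains the argmax. With the penalty, every action is pushed down in proportion to its softmax weight, and action 2 is pulled back up. The greedy choice therefore moves to the only action the data supports.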
Distribution Shift Summary
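Even with CQL, the trained agent is largely bounded by the quality of the data. CQL keeps the policy close to the actions in the dataset, so if the dataset was collected by a weak policy, there is little good behavior to stay close to. Offline RL is most powerful when the dataset contains a diverse mixture of behaviors, including near-optimal demonstrations.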
| Challenge | Online RL | Offline RL |
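|---|---|---|
| Environment interaction | Required | Not allowed |
| Distribution shift | Corrected by collecting new data with the current policy | Must be explicitly addressed |
| Q-value extrapolation | Self-correcting: overestimated actions get tried and their values fixed | Dangerous: overestimated unseen actions are never corrected |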
| Key technique | Standard TD update | Conservative penalty (CQL) |
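Because the achievable performance depends on what the dataset contains, it helps to inspect the collected episodes before offline training. The following is a minimal sketch with a hypothetical helper that recovers per-episode returns from the flat transition list, assuming transitions are stored in collection order as `Data.update_data` does:

```python
def episode_returns(transitions):
    """Sum rewards per episode in a list of (state, action, reward, next_state, done).

    Assumes transitions are in collection order; a trailing episode
    without a done flag is ignored.
    """
    returns, total = [], 0.0
    for _state, _action, reward, _next_state, done in transitions:
        total += float(reward)
        if done:
            returns.append(total)
            total = 0.0
    return returns
```

For example, `episode_returns(data.datas)` gives one return per teacher episode. Its minimum, mean, and maximum indicate how much near-optimal behavior the dataset holds. With the settings above (100 epochs of one 200-step episode), the buffer holds 20,000 transitions and never reaches its 100,000 cap, so no episode is cut short by eviction.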