Imitation learning trains a policy by mimicking an expert’s behavior rather than learning through reward signals alone. Instead of exploring the environment and discovering a reward function, the agent observes a collection of (state, action) pairs produced by a trained expert and uses supervised learning to replicate those decisions. This approach can dramatically accelerate training when expert demonstrations are available, bypassing the slow trial-and-error process of conventional reinforcement learning.
The approach used here is Generative Adversarial Imitation Learning (GAIL): a discriminator network is trained to distinguish expert (state, action) pairs from student pairs, and the student policy is rewarded for fooling the discriminator. This is more robust than plain behavioral cloning because it lets the student interact with the environment and corrects for distribution shift online.
Key Concepts
Behavioral Cloning (BC) treats imitation as pure supervised learning: cross-entropy loss for discrete actions, MSE for continuous ones. Its key limitation is covariate shift: small prediction errors compound over time and push the agent into states never seen during training.
GAIL addresses covariate shift by letting the agent interact with the environment. A discriminator is trained to output 0 for expert data and 1 for student data; the student’s reward grows as the discriminator’s output on student data moves toward 0, that is, as the student fools the discriminator.
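As a rough illustration of the supervised-learning view of BC, the sketch below fits a small policy to (state, action) pairs with cross-entropy loss. The synthetic expert rule (push right whenever the pole angle is positive), the network size, and the hyperparameters are assumptions made for this example and are not taken from the tutorial.

```python
import torch

torch.manual_seed(0)

# Synthetic "expert" data (assumption for illustration only): the expert picks
# action 1 whenever the pole angle, stored in state[2], is positive.
states = torch.randn(256, 4)
actions = (states[:, 2] > 0).long()

# Policy network that outputs logits over the two CartPole actions.
policy = torch.nn.Sequential(
    torch.nn.Linear(4, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 2),
)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)
# Cross-entropy for discrete actions; a continuous-action policy would use MSELoss.
loss_fn = torch.nn.CrossEntropyLoss()

losses = []
for _ in range(200):
    loss = loss_fn(policy(states), actions)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Fraction of demonstration pairs on which the cloned policy agrees with the expert.
accuracy = (policy(states).argmax(dim=1) == actions).float().mean().item()
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, accuracy {accuracy:.2f}")
```

No environment is touched during this loop, which is why BC is cheap to train and also why it cannot correct for states the expert never visited.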
Environment Setup
The CartPole-v1 environment is wrapped to cap episodes at 200 steps:
import gym


class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        # The underlying reset() returns (observation, info); keep the observation only
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        # Collapse the two end-of-episode flags into a single done flag
        done = terminated or truncated
        self.step_n += 1
        # End the episode after 200 steps even if the pole is still balanced
        if self.step_n >= 200:
            done = True
        return state, reward, done, info


env = MyWrapper()
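The effect of the step cap can be checked without gym. The sketch below repeats the wrapper’s counter logic around a hypothetical `NeverEndingEnv` stub (not part of the tutorial) that pays a reward of 1 per step, as CartPole does, and never terminates by itself. Under that assumption the cap puts a ceiling of 200 on the episode return.

```python
class NeverEndingEnv:
    """Hypothetical stand-in for an environment that never terminates on its own."""

    def reset(self):
        return 0.0

    def step(self, action):
        # Same 4-tuple shape that MyWrapper.step returns: state, reward, done, info
        return 0.0, 1.0, False, {}


class StepCap:
    """Ends an episode after max_steps, mirroring the counter in MyWrapper."""

    def __init__(self, env, max_steps=200):
        self.env = env
        self.max_steps = max_steps
        self.step_n = 0

    def reset(self):
        self.step_n = 0
        return self.env.reset()

    def step(self, action):
        state, reward, done, info = self.env.step(action)
        self.step_n += 1
        if self.step_n >= self.max_steps:
            done = True
        return state, reward, done, info


def run_episode(env):
    """Plays one episode with a fixed action and returns the summed reward."""
    env.reset()
    total, done = 0.0, False
    while not done:
        _, reward, done, _ = env.step(0)
        total += reward
    return total


episode_return = run_episode(StepCap(NeverEndingEnv()))
print(episode_return)  # 200.0
```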
Step 1 — Train the Expert (PPO)
A PPO agent serves as the expert. It is trained for 500 episodes, by which point it is expected to reach the maximum score of 200 per episode:
import random

import torch


class PPO:
    def __init__(self):
        # Policy network: state -> probabilities over the 2 actions
        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 2),
            torch.nn.Softmax(dim=1),
        )
        # Value network: state -> scalar state-value estimate
        self.model_td = torch.nn.Sequential(
            torch.nn.Linear(4, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.optimizer_td = torch.optim.Adam(self.model_td.parameters(), lr=1e-2)
        self.loss_fn = torch.nn.MSELoss()

    def get_action(self, state):
        # Sample an action from the policy's output distribution
        state = torch.FloatTensor(state).reshape(1, 4)
        prob = self.model(state)
        action = random.choices(range(2), weights=prob[0].tolist(), k=1)[0]
        return action

    def get_data(self):
        # Play one full episode in the global env and return it as tensors
        states, rewards, actions, next_states, overs = [], [], [], [], []
        state = env.reset()
        over = False
        while not over:
            action = self.get_action(state)
            next_state, reward, over, _ = env.step(action)
            states.append(state)
            rewards.append(reward)
            actions.append(action)
            next_states.append(next_state)
            overs.append(over)
            state = next_state
        states = torch.FloatTensor(states).reshape(-1, 4)
        rewards = torch.FloatTensor(rewards).reshape(-1, 1)
        actions = torch.LongTensor(actions).reshape(-1, 1)
        next_states = torch.FloatTensor(next_states).reshape(-1, 4)
        overs = torch.LongTensor(overs).reshape(-1, 1)
        return states, rewards, actions, next_states, overs

    # train(states, rewards, actions, next_states, overs) and test(play) are
    # called below but are not shown in this excerpt.


teacher = PPO()

# Train the expert for 500 episodes
for i in range(500):
    teacher.train(*teacher.get_data())
    if i % 50 == 0:
        test_result = sum([teacher.test(play=False) for _ in range(10)]) / 10
        print(i, test_result)
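The `train` method is called above but its body is not listed on this page. As a hedged sketch of what a PPO update typically computes, and not a reconstruction of this tutorial’s implementation, the two helpers below show generalized advantage estimation over a list of TD errors and the clipped surrogate objective for a single sample. The function names and the values `gamma=0.98`, `lam=0.95`, and `eps=0.2` are illustrative assumptions.

```python
def gae(deltas, gamma=0.98, lam=0.95):
    """Generalized advantage estimation: discounted backward sum of TD errors."""
    advantages = []
    running = 0.0
    for delta in reversed(deltas):
        running = delta + gamma * lam * running
        advantages.append(running)
    advantages.reverse()
    return advantages


def clipped_surrogate(ratio, advantage, eps=0.2):
    """PPO clipped objective for one sample (a quantity to be maximized).

    ratio is new_prob / old_prob for the action that was actually taken.
    """
    clipped_ratio = max(1.0 - eps, min(1.0 + eps, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)


# With gamma = lam = 1 the advantage is simply the sum of the remaining TD errors.
print(gae([1.0, 1.0, 1.0], gamma=1.0, lam=1.0))  # [3.0, 2.0, 1.0]

# A large ratio with a positive advantage is clipped at 1 + eps.
print(clipped_surrogate(1.5, 2.0))   # 2.4
# A small ratio with a negative advantage is clipped at 1 - eps.
print(clipped_surrogate(0.5, -1.0))  # -0.8
```

The clipping removes the incentive to move the policy far from the one that collected the data, which matters in Step 4, where the reward signal itself changes as the discriminator is updated.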
Step 2 — Collect Expert Demonstrations
Once trained, collect one episode of expert data and discard the expert model:
# Collect expert demonstrations
teacher_states, _, teacher_actions, _, _ = teacher.get_data()
# Discard the expert — only the data is needed
del teacher
print(teacher_states.shape, teacher_actions.shape)
# torch.Size([200, 4]) torch.Size([200, 1])
Step 3 — Build the Discriminator
The discriminator takes (state, one-hot action) pairs and outputs the probability that the pair comes from the student: with the labels used in Step 4, expert pairs are pushed toward 0 and student pairs toward 1. Because CartPole has discrete actions, the action is one-hot encoded before concatenation:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            torch.nn.Linear(6, 128),  # 4 state dims + 2 one-hot action dims
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, states, actions):
        # actions has shape (N, 1); one_hot returns integers, so cast to float
        # to match the dtype of states before concatenating
        one_hot = torch.nn.functional.one_hot(
            actions.squeeze(dim=1), num_classes=2
        ).float()
        cat = torch.cat([states, one_hot], dim=1)
        return self.sequential(cat)


discriminator = Discriminator()
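To make the 6-dimensional input concrete, here is a small pure-Python sketch of the same encoding for a single pair. The helper names `one_hot` and `discriminator_input` and the sample state values are hypothetical and exist only for this example.

```python
def one_hot(action, num_classes=2):
    """Returns a list with 1.0 at the action index and 0.0 elsewhere."""
    return [1.0 if i == action else 0.0 for i in range(num_classes)]


def discriminator_input(state, action):
    """Concatenates a 4-dim state with a 2-dim one-hot action."""
    return list(state) + one_hot(action)


x = discriminator_input([0.01, -0.02, 0.03, 0.04], action=1)
print(x)       # [0.01, -0.02, 0.03, 0.04, 0.0, 1.0]
print(len(x))  # 6
```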
Step 4 — GAIL Training Loop
The student is trained in a GAN-style loop. The discriminator learns to distinguish expert from student trajectories; the student is rewarded for trajectories the discriminator classifies as expert-like:
student = PPO()


def copy_learn():
    optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
    bce_loss = torch.nn.BCELoss()

    for i in range(500):
        # Student collects a trajectory
        states, _, actions, next_states, overs = student.get_data()

        # Discriminator distinguishes expert (label=0) from student (label=1)
        prob_teacher = discriminator(teacher_states, teacher_actions)
        prob_student = discriminator(states, actions)
        loss_teacher = bce_loss(prob_teacher, torch.zeros_like(prob_teacher))
        loss_student = bce_loss(prob_student, torch.ones_like(prob_student))
        loss = loss_teacher + loss_student

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Use -log(prob_student) as reward: student is rewarded for fooling discriminator
        rewards = -prob_student.log().detach()

        # Update the student policy with PPO
        student.train(states, rewards, actions, next_states, overs)

        if i % 50 == 0:
            test_result = sum([student.test(play=False) for _ in range(10)]) / 10
            print(i, test_result)


copy_learn()
The key insight of GAIL: the discriminator is trained with label 0 for expert data and label 1 for student data. The student’s reward is -log(prob_student), which is high when the discriminator mistakes student data for expert data (output near 0) and close to 0 when it confidently recognizes the data as the student’s (output near 1).
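The numbers behind this reward can be worked out by hand. The sketch below evaluates the reward and the per-sample binary cross-entropy for a few discriminator outputs; the helper names `bce` and `student_reward` are introduced only for this example.

```python
import math


def bce(prob, label):
    """Binary cross-entropy for one prediction; label is 0 (expert) or 1 (student)."""
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))


def student_reward(prob_student):
    """Reward used in the training loop: -log of the discriminator output."""
    return -math.log(prob_student)


for p in (0.1, 0.5, 0.9):
    print(
        f"D={p:.1f}  reward={student_reward(p):.3f}  "
        f"loss_if_expert={bce(p, 0):.3f}  loss_if_student={bce(p, 1):.3f}"
    )
```

Note that `student_reward(p)` equals `bce(p, 1)`: the student’s reward on a pair is exactly the loss the discriminator incurs for that pair when it is labeled as student data, so the two networks pull the same quantity in opposite directions.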
Limitations
Covariate shift is the fundamental challenge of behavioral cloning. When the cloned policy makes even a small error, it enters a state distribution not seen during training, which can lead to cascading failures. GAIL mitigates this by letting the student explore and by training the discriminator online, but it requires the environment to be available at training time.
| Approach | Pros | Cons |
|---|---|---|
| Behavioral Cloning | Simple, no environment needed at train time | Covariate shift; degrades far from training distribution |
| GAIL | Online correction, more robust | Needs environment interaction; adversarial training instability |
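A toy calculation shows how quickly the compounding errors behind covariate shift add up. It assumes, purely for illustration, that a cloned policy deviates from the expert independently with probability `eps` at each step and that a single deviation takes it off the training distribution for the rest of the episode. Real policies can sometimes recover, so this is a pessimistic simplification and not a measurement.

```python
def on_distribution_probability(eps, horizon):
    """Probability of making no deviation in `horizon` independent steps."""
    return (1.0 - eps) ** horizon


# 200 steps matches the episode cap used in this tutorial.
for eps in (0.001, 0.01, 0.05):
    p = on_distribution_probability(eps, 200)
    print(f"per-step error {eps:.3f} -> P(no deviation in 200 steps) = {p:.4f}")
```

Under these assumptions even a 1% per-step error rate leaves only about a 13% chance of finishing a 200-step episode without leaving the expert’s state distribution.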