Imitation learning trains a policy by mimicking an expert’s behavior rather than by optimizing an environment reward. The agent is given a collection of (state, action) pairs produced by a trained expert and learns to reproduce those decisions. When expert demonstrations are available, this can speed up training considerably, because it avoids much of the slow trial-and-error exploration of conventional reinforcement learning.

The approach used here is Generative Adversarial Imitation Learning (GAIL): a discriminator network is trained to distinguish expert (state, action) pairs from student pairs, and the student policy is rewarded for fooling the discriminator. Unlike plain behavioral cloning, GAIL lets the student interact with the environment during training, which makes it more robust to distribution shift.

Key Concepts

Behavioral Cloning (BC) treats imitation as pure supervised learning: cross-entropy loss for discrete actions, MSE for continuous ones. Its key limitation is covariate shift: small prediction errors compound over time and push the agent into states never seen during training.

GAIL addresses covariate shift by letting the agent interact with the environment. A discriminator is trained to label expert data as 0 and student data as 1; the more convincingly the student fools the discriminator, the higher its reward.
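For contrast with GAIL, behavioral cloning on discrete actions can be sketched as ordinary supervised training. The snippet below is a minimal illustration, not part of the original code: the "expert" data is synthetic (a made-up rule that picks action 1 when the third state feature is positive), and the network emits logits rather than probabilities because `CrossEntropyLoss` applies log-softmax internally.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in for expert demonstrations (an assumption for illustration):
# the "expert" picks action 1 whenever the third state feature is positive.
states = torch.randn(256, 4)
actions = (states[:, 2] > 0).long()

# Policy with the same layer sizes used on this page, but without the final
# Softmax, since CrossEntropyLoss expects raw logits.
policy = torch.nn.Sequential(
    torch.nn.Linear(4, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 2),
)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()

losses = []
for _ in range(200):
    # Supervised step: predict the expert action for each recorded state.
    loss = loss_fn(policy(states), actions)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

accuracy = (policy(states).argmax(dim=1) == actions).float().mean().item()
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, accuracy {accuracy:.2f}")
```

Note that nothing in this loop touches the environment, which is exactly why errors made at rollout time are never corrected.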

Environment Setup

The CartPole-v1 environment is wrapped to cap episodes at 200 steps:
import gym

class MyWrapper(gym.Wrapper):
    def __init__(self):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        done = terminated or truncated
        self.step_n += 1
        if self.step_n >= 200:
            done = True
        return state, reward, done, info

env = MyWrapper()

Step 1 — Train the Expert (PPO)

A PPO agent serves as the expert. It is trained for 500 episodes and evaluated every 50, with the goal of reaching the maximum score of 200 per episode:
import torch
import random

class PPO:
    def __init__(self):
        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 2),
            torch.nn.Softmax(dim=1),
        )
        self.model_td = torch.nn.Sequential(
            torch.nn.Linear(4, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
        )
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        self.optimizer_td = torch.optim.Adam(self.model_td.parameters(), lr=1e-2)
        self.loss_fn = torch.nn.MSELoss()

    def get_action(self, state):
        state = torch.FloatTensor(state).reshape(1, 4)
        prob = self.model(state)
        action = random.choices(range(2), weights=prob[0].tolist(), k=1)[0]
        return action

    def get_data(self):
        states, rewards, actions, next_states, overs = [], [], [], [], []
        state = env.reset()
        over = False
        while not over:
            action = self.get_action(state)
            next_state, reward, over, _ = env.step(action)
            states.append(state)
            rewards.append(reward)
            actions.append(action)
            next_states.append(next_state)
            overs.append(over)
            state = next_state

        states = torch.FloatTensor(states).reshape(-1, 4)
        rewards = torch.FloatTensor(rewards).reshape(-1, 1)
        actions = torch.LongTensor(actions).reshape(-1, 1)
        next_states = torch.FloatTensor(next_states).reshape(-1, 4)
        overs = torch.LongTensor(overs).reshape(-1, 1)
        return states, rewards, actions, next_states, overs

teacher = PPO()

# Train the expert for 500 episodes
for i in range(500):
    teacher.train(*teacher.get_data())
    if i % 50 == 0:
        test_result = sum([teacher.test(play=False) for _ in range(10)]) / 10
        print(i, test_result)
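The listing above calls `train`, which is not shown on this page. The following is a hedged sketch of what such an update could look like, assuming a standard clipped-surrogate PPO step with generalized advantage estimation. The function names `compute_advantages` and `ppo_update`, and the hyperparameters (`gamma=0.98`, `lam=0.95`, `clip=0.2`, 10 epochs), are assumptions for illustration and may differ from the original implementation.

```python
import torch

def compute_advantages(deltas, gamma=0.98, lam=0.95):
    """Backward discounted sum of TD errors (generalized advantage estimation)."""
    advantages, running = [], 0.0
    for delta in reversed(deltas):
        running = delta + gamma * lam * running
        advantages.append(running)
    advantages.reverse()
    return advantages

def ppo_update(model, model_td, optimizer, optimizer_td,
               states, rewards, actions, next_states, overs,
               gamma=0.98, clip=0.2, epochs=10):
    """One clipped-surrogate PPO update on a single trajectory (hypothetical helper)."""
    with torch.no_grad():
        # TD targets; (1 - overs) zeroes the bootstrap term on the final step.
        targets = rewards + gamma * model_td(next_states) * (1 - overs)
        deltas = (targets - model_td(states)).squeeze(dim=1).tolist()
        # Probabilities of the taken actions under the pre-update policy.
        old_probs = model(states).gather(dim=1, index=actions)
    advantages = torch.tensor(
        compute_advantages(deltas, gamma), dtype=torch.float32
    ).reshape(-1, 1)

    for _ in range(epochs):
        new_probs = model(states).gather(dim=1, index=actions)
        ratios = new_probs / old_probs
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - clip, 1 + clip) * advantages
        loss = -torch.min(surr1, surr2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fit the value network to the fixed TD targets.
        loss_td = torch.nn.functional.mse_loss(model_td(states), targets)
        optimizer_td.zero_grad()
        loss_td.backward()
        optimizer_td.step()

# Usage on a synthetic 5-step batch (random numbers, not real CartPole data).
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(4, 128), torch.nn.ReLU(),
    torch.nn.Linear(128, 2), torch.nn.Softmax(dim=1),
)
model_td = torch.nn.Sequential(
    torch.nn.Linear(4, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer_td = torch.optim.Adam(model_td.parameters(), lr=1e-2)

states, next_states = torch.randn(5, 4), torch.randn(5, 4)
rewards = torch.ones(5, 1)
actions = torch.randint(0, 2, (5, 1))
overs = torch.tensor([[0], [0], [0], [0], [1]])

probs_before = model(states).detach().clone()
ppo_update(model, model_td, optimizer, optimizer_td,
           states, rewards, actions, next_states, overs)
probs_after = model(states).detach()
print((probs_after - probs_before).abs().max().item())
```

The argument order matches the tuple returned by `get_data`, so the same routine could in principle back a `train` method. The `test` method used for evaluation is likewise not shown and is not sketched here.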

Step 2 — Collect Expert Demonstrations

Once trained, collect one episode of expert data and discard the expert model:
# Collect expert demonstrations
teacher_states, _, teacher_actions, _, _ = teacher.get_data()

# Discard the expert — only the data is needed
del teacher

print(teacher_states.shape, teacher_actions.shape)
# For a full 200-step episode: torch.Size([200, 4]) torch.Size([200, 1])

Step 3 — Build the Discriminator

The discriminator takes (state, one-hot action) pairs and outputs the probability that the pair comes from the student: during training, expert pairs are labeled 0 and student pairs are labeled 1. Because CartPole has discrete actions, the action is one-hot encoded before concatenation:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            torch.nn.Linear(6, 128),   # 4 state dims + 2 one-hot action dims
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, states, actions):
        one_hot = torch.nn.functional.one_hot(
            actions.squeeze(dim=1), num_classes=2
        ).float()  # one_hot returns int64; cast to match the float states
        cat = torch.cat([states, one_hot], dim=1)
        return self.sequential(cat)

discriminator = Discriminator()
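To make the input layout concrete, the encoding step can be run in isolation on a tiny batch. This is a small illustrative check with made-up values (three all-zero states and hand-picked actions), not part of the original code.

```python
import torch

# Three dummy states (4 features each) and one discrete action per state.
states = torch.zeros(3, 4)
actions = torch.tensor([[0], [1], [1]])

# One-hot encode the actions, then append them to the state features.
one_hot = torch.nn.functional.one_hot(
    actions.squeeze(dim=1), num_classes=2
).float()
pairs = torch.cat([states, one_hot], dim=1)

print(pairs.shape)  # torch.Size([3, 6])
print(pairs[1])     # tensor([0., 0., 0., 0., 0., 1.])
```

The resulting width of 6 is what the first `Linear(6, 128)` layer of the discriminator expects.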

Step 4 — GAIL Training Loop

The student is trained in a GAN-style loop. The discriminator learns to distinguish expert from student trajectories; the student is rewarded for trajectories the discriminator classifies as expert-like:
student = PPO()

def copy_learn():
    optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
    bce_loss = torch.nn.BCELoss()

    for i in range(500):
        # Student collects a trajectory
        states, _, actions, next_states, overs = student.get_data()

        # Discriminator distinguishes expert (label=0) from student (label=1)
        prob_teacher = discriminator(teacher_states, teacher_actions)
        prob_student = discriminator(states, actions)

        loss_teacher = bce_loss(prob_teacher, torch.zeros_like(prob_teacher))
        loss_student = bce_loss(prob_student, torch.ones_like(prob_student))
        loss = loss_teacher + loss_student

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Use -log(prob_student) as reward: student is rewarded for fooling discriminator
        rewards = -prob_student.log().detach()

        # Update the student policy with PPO
        student.train(states, rewards, actions, next_states, overs)

        if i % 50 == 0:
            test_result = sum([student.test(play=False) for _ in range(10)]) / 10
            print(i, test_result)

copy_learn()
The labeling convention drives the reward. The discriminator is trained to output 0 for expert data and 1 for student data, so the student’s reward, -log(prob_student), is large when the discriminator’s output for a student pair is near 0, that is, when it mistakes the pair for expert data.
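A few numbers show how the reward behaves across discriminator outputs. The helper `gail_reward` below is hypothetical, and its `eps` guard against log(0) is an added assumption rather than something the training loop above does.

```python
import math

def gail_reward(prob_student, eps=1e-8):
    """Reward -log(p), where p is the discriminator's 'student' probability.

    eps is an illustrative guard against log(0); the original loop has no such clamp.
    """
    return -math.log(max(prob_student, eps))

for p in (0.1, 0.5, 0.9):
    print(f"D = {p:.1f} -> reward = {gail_reward(p):.3f}")
# D = 0.1 -> reward = 2.303
# D = 0.5 -> reward = 0.693
# D = 0.9 -> reward = 0.105
```

The reward is always non-negative and grows without bound as the discriminator output approaches 0, which is one reason a clamp may be worth considering in practice.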

Limitations

Covariate shift is the fundamental challenge of behavioral cloning. When the cloned policy makes even a small error, it enters a state distribution not seen during training, which can lead to cascading failures. GAIL mitigates this by letting the student explore and by training the discriminator online, but it requires the environment to be available at training time.

| Approach | Pros | Cons |
| --- | --- | --- |
| Behavioral Cloning | Simple, no environment needed at train time | Covariate shift; degrades far from training distribution |
| GAIL | Online correction, more robust | Needs environment interaction; adversarial training instability |
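The compounding effect behind covariate shift can be illustrated with a back-of-the-envelope calculation. The model below is a deliberate simplification and an assumption of this sketch: each step independently has a fixed probability of a mistake, and a single mistake is treated as unrecoverable. The horizon of 200 matches the episode cap used on this page.

```python
def prob_no_mistake(per_step_error, horizon):
    """Probability of completing `horizon` steps with zero mistakes,
    assuming independent errors with a fixed per-step probability."""
    return (1.0 - per_step_error) ** horizon

for eps in (0.001, 0.01, 0.05):
    print(f"per-step error {eps}: P(no mistake in 200 steps) = {prob_no_mistake(eps, 200):.3f}")
# per-step error 0.001: P(no mistake in 200 steps) = 0.819
# per-step error 0.01: P(no mistake in 200 steps) = 0.134
# per-step error 0.05: P(no mistake in 200 steps) = 0.000
```

Even a 1% per-step error rate leaves only about a 13% chance of an error-free episode under these assumptions, which is why a cloned policy that looks accurate on held-out demonstrations can still fail when rolled out.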
