Details and Tricks to Implement a Working PPO

08/09 11:11

Yoonsoo Kim

AI Research Engineer / Kaggle Grandmaster / Creator Of This Website

There are many open-source implementations of PPO and other reinforcement learning algorithms online, including Stable Baselines. You can use them, but I prefer implementing algorithms myself, since doing so dramatically improves both customizability and my understanding of the algorithm. I tried to implement PPO with the help of the Spinning Up code, but it was hard to make the algorithm work. There was so much beyond what the original paper tells you. The devil was in the details.

I came across this amazing post that analyzes the implementation details that make PPO work. With the help of that article, I managed to implement a working PPO that trains an agent to play a simple game.

In this article, I will explain how a minimal policy gradient algorithm can grow, trick by trick, into a working PPO implementation. Instead of going deep into the math, I will focus on intuition and code.

How Does Policy Gradient Method Work?

The policy is a neural network that receives states from the environment and outputs an action distribution (Categorical if the action space is discrete, usually Normal if it is continuous). The agent interacts with the environment for some number of steps, taking actions randomly sampled from the policy. After we've collected trajectories, we update the policy network weights so that actions with higher return are selected more often.

For example, let's think of a stock trading agent. The current policy is as follows:

  • State A -> 50% buy, 50% sell
  • State B -> hold

The agent interacts with the environment and collects samples:

  • Sample1: State A -> buy, received return 1
  • Sample2: State A -> buy, received return 2
  • Sample3: State A -> sell, received return -1
  • Sample4: State A -> sell, received return -3

Buying at state A has the higher return, so we update the policy network weights with gradient descent so that the probability of buying at state A increases. After the gradient step, the policy at state A now buys with, say, 53% and sells with 47% probability.

To take that gradient step, we need the gradient of the objective. Our objective is to maximize the return. After some math, we arrive at the following equation.

$$g = \nabla_\theta \mathbb{E}\left[\sum_{t=0}^T r_t\right] = \mathbb{E}\left[\sum_{t=0}^T R_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)\right]$$
  • $r$: reward
  • $R$: return
  • $\pi_\theta$: policy with network parameters $\theta$
  • $a$: action
  • $s$: state
  • $T$: total number of steps in an episode

Simplest Policy Gradient Method Pseudo Implementation

class PolicyNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
    
    def forward(self, state):
        logits = self.mlp(state)
        dist = Categorical(logits=logits)
        return dist
    
p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
env = Env()
for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns = [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            done = False
            while not done:
                dist = p_net(state)
                action = dist.sample().item()
                states.append(state)
                actions.append(action)
                state, reward, done = env.step(action)
                rewards.append(reward)
            rewards = stack(rewards)  # list -> tensor
            cur_returns = discounted_cumsum(rewards)  # reward-to-go at each step
            returns.extend(cur_returns)
        states, actions, returns = stack(states, actions, returns)
    # update policy network weights
    logps = p_net(states).log_prob(actions)
    p_loss = -(logps * returns).mean()
    p_opt.zero_grad()
    p_loss.backward()
    p_opt.step()
    

Now, we'll upgrade this algorithm by applying tricks one by one.

Trick 1: Replace the Policy Gradient Signal from Return with Advantage

The baseline is the expected return given a state. Advantage = Return - Baseline.

Let's think of an agent that plays soccer and say we collected trajectories like:

  • state: offender kicking penalty (baseline: 80)
    • action: shoot to the corner -> return: 100, advantage: 20
    • action: shoot at the audience -> return: 0, advantage: -80
  • state: defender passing the ball (baseline 0)
    • action: pass to mid-fielder -> return: 0, advantage: 0
    • action: pass to opponent -> return: -30, advantage: -30

If we use the return as the signal for the policy gradient, the network sees "shooting at the audience when kicking a penalty" and "the defender passing to our mid-fielder" as having the same value. This is not what we want. If we instead use the advantage, the network understands that "shooting at the audience when kicking a penalty" is a terrible action.

Then, how do we get the baseline? It's simple: we initialize another neural network (the value network) and train it to minimize the MSE loss against the true return for each state.

The vanilla policy gradient pseudo code from Spinning Up is:

(Image: vanilla policy gradient pseudo code from Spinning Up)

Now we add trick 1 to our previous implementation.

class ValueNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state):
        value = self.mlp(state).squeeze(-1)
        return value
    
p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
v_net = ValueNet()
v_opt = Adam(v_net.parameters())
env = Env()
for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages = [], [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            values = []
            done = False
            while not done:
                value = v_net(state).item()
                dist = p_net(state)
                action = dist.sample().item()
                states.append(state)
                values.append(value)
                actions.append(action)
                state, reward, done = env.step(action)
                rewards.append(reward)
            rewards, values = stack(rewards, values)  # lists -> tensors
            cur_returns = discounted_cumsum(rewards)
            cur_advantages = cur_returns - values  # advantage = return - baseline (trick 1)
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        states, actions, returns, advantages = stack(states, actions, returns, advantages)
    # update policy network weights
    logps = p_net(states).log_prob(actions)
    p_loss = -(logps * advantages).mean()
    p_opt.zero_grad()
    p_loss.backward()
    p_opt.step()
    # update value network weights
    values = v_net(states)
    v_loss = ((values - returns)**2).mean()
    v_opt.zero_grad()
    v_loss.backward()
    v_opt.step()

Trick 2: Clipped Loss

In the vanilla policy gradient method, we collect episode trajectories, call optimizer.step() once, throw away the collected data, and repeat. Collecting trajectories takes a long time, yet we use each sample only once, which is inefficient. PPO (Proximal Policy Optimization) makes it possible to take multiple gradient steps with the collected data, increasing sample efficiency. The main idea of PPO is the clipped loss: the objective is clipped when the updated policy deviates too far from the policy that collected the data.
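
For reference, the clipped objective from the PPO paper is:

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,A_t,\ \operatorname{clip}\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)A_t\right)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}$$

Here $A_t$ is the advantage and $\epsilon$ is the clip range (the clip variable in the code below). Once the probability ratio $r_t(\theta)$ leaves $[1-\epsilon, 1+\epsilon]$ in the direction that would further increase the objective, the gradient through that sample is cut off.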

The PPO pseudo code from Spinning Up is:

(Image: PPO pseudo code from Spinning Up)

We add trick 2 to our previous implementation.

p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
v_net = ValueNet()
v_opt = Adam(v_net.parameters())
env = Env()
for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages, logps = [], [], [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            values = []
            done = False
            while not done:
                value = v_net(state).item()
                dist = p_net(state)                
                action = dist.sample()
                logp = dist.log_prob(action)
                action = action.item()
                states.append(state)
                values.append(value)
                actions.append(action)
                logps.append(logp)
                state, reward, done = env.step(action)
                rewards.append(reward)
            rewards, values = stack(rewards, values)  # lists -> tensors
            cur_returns = discounted_cumsum(rewards)
            cur_advantages = cur_returns - values
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        states, actions, returns, advantages, logps = stack(states, actions, returns, advantages, logps)
    for _ in range(n_epochs):
        for b_states, b_actions, b_returns, b_advantages, b_logps \
        in generate_batch(states, actions, returns, advantages, logps):
            # update policy network weights            
            new_logps = p_net(b_states).log_prob(b_actions)
            ratio = torch.exp(new_logps - b_logps)
            p_loss1 = - b_advantages * ratio  # unclipped surrogate
            p_loss2 = - b_advantages * torch.clamp(ratio, 1-clip, 1+clip)  # clipped surrogate
            p_loss = torch.max(p_loss1, p_loss2).mean()  # pessimistic bound (max of negative losses = min of objectives)
            p_opt.zero_grad()
            p_loss.backward()
            p_opt.step()
            # update value network weights
            values = v_net(b_states)
            v_loss = ((values - b_returns)**2).mean()
            v_opt.zero_grad()
            v_loss.backward()
            v_opt.step()

Trick 3: TD-Lambda Return

We can express the return with different horizons:

  • $V_t \approx r_t + V_{t+1}$
  • $V_t \approx r_t + r_{t+1} + V_{t+2}$
  • ...
  • $V_t \approx r_t + r_{t+1} + \dots = R_t$

If we set the target of the value to $r_t + V_{t+1}$, we only look one step ahead (a short horizon). A value network trained with this target has high bias but low variance, so it is easy to train, but it might deviate from the true return. In contrast, if we consider long-term rewards, our estimate has low bias but high variance.

We want to control this bias-variance trade-off by taking a weighted average of all the possible estimates. We control the weights via the hyperparameter $\lambda$. If $\lambda=0$, we estimate the value with $r_t + V_{t+1}$, and if $\lambda=1$, we estimate the value with $R_t$.
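
With a discount factor $\gamma$ included (the sketches above omit it for simplicity), the $\lambda$-return is the geometrically weighted average of the $n$-step estimates:

$$R_t^\lambda = (1-\lambda)\sum_{n=1}^{\infty}\lambda^{n-1} R_t^{(n)}, \qquad R_t^{(n)} = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V_{t+n}$$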

Trick 4: TD-Lambda Advantage

We estimated the return with TD-Lambda. We can apply the same idea to the advantage, which is the signal used for updating the policy network.

Calculating it is simple: TD-Lambda Advantage = TD-Lambda Return - Value Estimation.

Here are some references: my implementation, GAE paper, GAE implementation.

Replacing the return and advantage with their TD-Lambda versions, our algorithm becomes:

...
for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages, logps = [], [], [], [], []
        for _ in range(n_episodes):
            ...
            cur_returns, cur_advantages = td_lambda(rewards, values)  # lambda-return and GAE-style advantage
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
...
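
The td_lambda call above is still pseudo code. For reference, here is a minimal sketch of what such a helper could look like, computed per episode with the GAE (Generalized Advantage Estimation) recursion; gamma and lam are assumed hyperparameters, and the value at the end of the episode is bootstrapped with 0:

import numpy as np

def td_lambda(rewards, values, gamma=0.99, lam=0.95):
    # rewards, values: 1-D arrays collected from a single finished episode
    T = len(rewards)
    advantages = np.zeros(T)
    last_adv = 0.
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.  # bootstrap with 0 at episode end
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        last_adv = delta + gamma * lam * last_adv  # GAE recursion
        advantages[t] = last_adv
    returns = advantages + values  # TD-lambda return = advantage + value estimate
    return returns, advantages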

Trick 5: Layer Initialization

In the Stable Baselines implementation, layer weights are initialized orthogonally and biases are initialized to 0, which is not the default initialization in PyTorch. Also, the policy head is initialized with std=0.01 and the value head with std=1.

def layer_init(layer, std=np.sqrt(2), bias_const=0.):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class PolicyNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01)
        )
    
    def forward(self, state):
        logits = self.mlp(state)
        dist = Categorical(logits=logits)
        return dist
    

class ValueNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.)
        )
    
    def forward(self, state):
        value = self.mlp(state).squeeze(-1)
        return value

Trick 6: Advantage Normalization

If the advantage has too large a scale, the policy gradient becomes too large and the network update becomes unstable. We can simply normalize the advantage per mini-batch so its variance is 1, which makes training more stable.
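
In practice this is a single line inside the mini-batch loop; the small epsilon added here is an assumption to avoid division by zero:

# zero mean, unit variance within each mini-batch
b_advantages = (b_advantages - b_advantages.mean()) / (b_advantages.std() + 1e-8)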

Trick 7: State Normalization

As with other deep neural networks, normalizing the inputs greatly enhances training stability. However, unlike in a supervised learning setting, the input distribution of an RL agent changes continuously, so we cannot pre-compute the mean and std. Instead, we update the mean and std while the agent sees new states. The update needs to be slow so that the model adapts smoothly. The simplest update method is a moving average.

Trick 8: Return Normalization

The scale of the return becomes large when the reward scale is large or gamma is close to 1. The return is the target of the value network, and when the target scale is too large, the gradient scale becomes too large and destabilizes training. Just like with state normalization, we can use a moving average to slowly update the mean and std of the return, and use those statistics to normalize the value network's target.

After applying the normalizations, our code becomes:

class MeanStd:
    
    def __init__(self, dim, momentum=0.999):
        self.mean = np.zeros(dim)
        self.std = np.ones(dim)  # start at 1 to avoid division by zero
        self.momentum = momentum
    
    def update(self, x):  # moving average version
        self.mean = self.momentum * self.mean + (1 - self.momentum) * x.mean()
        self.std = self.momentum * self.std + (1 - self.momentum) * x.std()
    
    def normalize(self, x):
        return (x - self.mean) / self.std
    
    def unnormalize(self, x):
        return x * self.std + self.mean
    
...
state_mean_std = MeanStd()
return_mean_std = MeanStd()
for _ in range(n_global_epochs):
    # collect episodes
    ...
        while not done:
            # normalize state
            state = state_mean_std.normalize(state)
            value = v_net(state).item()
            dist = p_net(state)                
            action = dist.sample()
            logp = dist.log_prob(action)
            action = action.item()
            states.append(state)
            # unnormalize value
            values.append(return_mean_std.unnormalize(value))
            actions.append(action)
            logps.append(logp)
            state, reward, done = env.step(action)
            rewards.append(reward)
    ...
    for _ in range(n_epochs):
        for b_states, b_actions, b_returns, b_advantages, b_logps \
        in generate_batch(states, actions, returns, advantages, logps):
            # normalize advantage
            b_advantages = (b_advantages-b_advantages.mean()) / (b_advantages.std())
            ...
            # normalize return
            b_returns = return_mean_std.normalize(b_returns)
            ...
    # update state & return mean_std
    state_mean_std.update(states)
    return_mean_std.update(returns)

Trick 9: Gradient Clipping

If the scale of the gradient is too large, the weights of the policy network change too rapidly, resulting in unstable training. We can clip the gradients so that the norm of the concatenated parameter gradients doesn't exceed 0.5.

...
p_opt.zero_grad()
p_loss.backward()
nn.utils.clip_grad_norm_(p_net.parameters(), 0.5)
p_opt.step()
...

Trick 10: Entropy Bonus

The higher the entropy of the policy's output distribution, the more likely the agent is to explore diverse actions. We add an entropy bonus term directly to the loss to control exploration.

...
# update policy network weights            
dist = p_net(b_states)
new_logps = dist.log_prob(b_actions)
entropy = dist.entropy()
ratio = torch.exp(new_logps - b_logps)
p_loss1 = - b_advantages * ratio
p_loss2 = - b_advantages * torch.clamp(ratio, 1-clip, 1+clip)
p_loss = torch.max(p_loss1, p_loss2).mean() - ent_coef * entropy.mean()
...

Trick 11: Action Masking

There can be actions that are invalid in certain states. For example, when you don't hold a stock, you can't take the sell action. We can generate an action mask from the state and replace the logits of masked actions with -inf. The probability of taking those actions then becomes 0%, and backpropagation through masked actions is blocked. To take it one step further, you can mask actions with heuristics to boost initial performance.

class PolicyNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01)
        )
    
    def forward(self, state, action_mask):
        logits = self.mlp(state)
        # invalid actions get -inf logits, i.e. probability 0
        logits = torch.where(action_mask, logits, torch.full_like(logits, float('-inf')))
        dist = Categorical(logits=logits)
        return dist

Final Version

Note that this is pseudo code.

def layer_init(layer, std=np.sqrt(2), bias_const=0.):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class PolicyNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01)
        )
    
    def forward(self, state, action_mask):
        logits = self.mlp(state)
        logits = torch.where(action_mask, logits, torch.full_like(logits, float('-inf')))  # trick 11
        dist = Categorical(logits=logits)
        return dist
    
class ValueNet(nn.Module):
    
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.)
        )
    
    def forward(self, state):
        value = self.mlp(state).squeeze(-1)
        return value

class MeanStd:
    
    def __init__(self, dim, momentum=0.999):
        self.mean = np.zeros(dim)
        self.std = np.ones(dim)  # start at 1 to avoid division by zero
        self.momentum = momentum
    
    def update(self, x):  # moving average update
        self.mean = self.momentum * self.mean + (1 - self.momentum) * x.mean()
        self.std = self.momentum * self.std + (1 - self.momentum) * x.std()
    
    def normalize(self, x):
        return (x - self.mean) / self.std
    
    def unnormalize(self, x):
        return x * self.std + self.mean
    
p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
v_net = ValueNet()
v_opt = Adam(v_net.parameters())
state_mean_std = MeanStd()
return_mean_std = MeanStd()
env = Env()
for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages, logps, masks = [], [], [], [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            values = []
            done = False
            while not done:
                mask = generate_mask(state)  # action mask from the raw state (trick 11)
                state = state_mean_std.normalize(state)  # trick 7
                value = v_net(state).item()
                dist = p_net(state, mask)
                action = dist.sample()
                logp = dist.log_prob(action)
                action = action.item()
                states.append(state)
                masks.append(mask)
                values.append(return_mean_std.unnormalize(value))  # value back in return scale (trick 8)
                actions.append(action)
                logps.append(logp)
                state, reward, done = env.step(action)
                rewards.append(reward)
            rewards, values = stack(rewards, values)  # lists -> tensors
            cur_returns, cur_advantages = td_lambda(rewards, values)  # tricks 3 & 4
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        states, actions, returns, advantages, logps, masks = stack(states, actions, returns, advantages, logps, masks)
    for _ in range(n_epochs):
        for b_states, b_actions, b_returns, b_advantages, b_logps, b_masks \
        in generate_batch(states, actions, returns, advantages, logps, masks):
            # normalize advantage
            b_advantages = (b_advantages-b_advantages.mean()) / (b_advantages.std())
            # update policy network weights            
            dist = p_net(b_states, b_masks)
            new_logps = dist.log_prob(b_actions)
            entropy = dist.entropy()
            ratio = torch.exp(new_logps - b_logps)
            p_loss1 = - b_advantages * ratio
            p_loss2 = - b_advantages * torch.clamp(ratio, 1-clip, 1+clip)
            p_loss = torch.max(p_loss1, p_loss2).mean() - ent_coef * entropy.mean()
            p_opt.zero_grad()
            p_loss.backward()
            nn.utils.clip_grad_norm_(p_net.parameters(), 0.5)
            p_opt.step()
            # update value network weights
            b_returns = return_mean_std.normalize(b_returns)
            values = v_net(b_states)
            v_loss = ((values - b_returns)**2).mean()
            v_opt.zero_grad()
            v_loss.backward()
            v_opt.step()
    state_mean_std.update(states)
    return_mean_std.update(returns)

Discussions

The Stable Baselines and OpenAI Gym implementations differ from the above pseudo code mainly in two ways.

  1. They use a different update method for MeanStd.
  2. They don't normalize the return directly; they normalize the reward.

You can refer here.
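
As a rough sketch of that reward-normalization idea (this only mirrors the spirit of the Gym/Stable Baselines wrappers, not their exact code: rewards are divided by a running standard deviation of the discounted return):

class RewardScaler:
    
    def __init__(self, gamma=0.99, momentum=0.999, eps=1e-8):
        self.gamma = gamma
        self.momentum = momentum
        self.eps = eps
        self.ret = 0.       # running discounted return
        self.ret_var = 1.   # moving-average estimate of its variance
    
    def scale(self, reward):
        self.ret = self.gamma * self.ret + reward
        self.ret_var = self.momentum * self.ret_var + (1 - self.momentum) * self.ret ** 2
        return reward / (self.ret_var ** 0.5 + self.eps)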

Finally, for reference, this is my implementation for the Kaggle Kore2022 competition, based on these tricks, and this is the blog post I referenced a lot when implementing PPO.
