08/09 11:11
There are many open-source implementations of PPO and other reinforcement learning algorithms online, including Stable Baselines. You can use them, but I prefer implementing algorithms myself, since doing so dramatically increases customizability and my understanding of the algorithm. I tried to implement PPO with the help of the Spinning Up code, but it was hard to make the algorithm work. There was so much needed beyond what the original paper tells you. The devil was in the details.
I came across this amazing post that analyzes the implementation details that make PPO work. With the help of that article, I managed to implement a working PPO that trains an agent to play a simple game.
In this article, I will explain how a minimal policy gradient algorithm can grow into a working PPO algorithm. Instead of going deep into the math, I will focus on intuition and code.
The policy is a neural network that receives states from the environment and outputs an action distribution (Categorical if the action space is discrete, usually Normal if it is continuous). The agent interacts with the environment for some steps, taking actions randomly sampled from the policy. After we've collected trajectories, we update the policy network weights so that actions with higher return are selected more often.
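As a side note, for continuous actions the policy head can output the parameters of a Normal distribution instead. A minimal sketch (the class below and its state-independent log-std are illustrative choices, not code from this post):

import torch
import torch.nn as nn
from torch.distributions import Normal

class GaussianPolicyNet(nn.Module):
    # Illustrative policy for continuous action spaces: outputs a Normal distribution.
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),  # mean of each action dimension
        )
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # state-independent log-std

    def forward(self, state):
        mean = self.mlp(state)
        return Normal(mean, self.log_std.exp())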
For example, let's think of a stock trading agent. Suppose the current policy buys with 50% probability and sells with 50% probability at state A. The agent interacts with the environment and collects samples of states, actions, and returns. Buying at state A turns out to have a higher return, so we update the policy network weights with gradient descent so that buying at state A becomes more probable. After the gradient step, the policy might buy with, say, 53% and sell with 47% at state A.
We need to know the gradient of the objective to do gradient descent (on its negative). Our objective is to maximize the return. After some math, we arrive at the following equation:

$$
\nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R_t\right]
$$

where $R_t$ is the return from time step $t$ onward.
class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state):
        logits = self.mlp(state)
        dist = Categorical(logits=logits)
        return dist

p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
env = Env()

for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns = [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            done = False
            while not done:
                dist = p_net(state)
                action = dist.sample().item()
                states.append(state)
                actions.append(action)
                state, reward, done = env.step(action)
                rewards.append(reward)
            stack(rewards)
            cur_returns = discounted_cumsum(rewards)  # reward-to-go for each step
            returns.extend(cur_returns)
        stack(states, actions, returns)

    # update policy network weights
    logps = p_net(states).log_prob(actions)
    p_loss = -(logps * returns).mean()
    p_opt.zero_grad()
    p_loss.backward()
    p_opt.step()
Now, we'll upgrade this algorithm by applying tricks one by one.
The baseline is the expected return given a state. Advantage = Return − Baseline.
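In symbols, writing $V(s_t)$ for the baseline:

$$
V(s_t) = \mathbb{E}\left[R_t \mid s_t\right], \qquad A_t = R_t - V(s_t)
$$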
Let's think of an agent that plays soccer. Say we collected trajectories in which "shooting at the audience when taking a penalty kick" and "a defender passing the ball to our midfielder" both ended up with the same return. If we use the raw return as the signal for the policy gradient, the network treats these two actions as equally good, which is not what we want. If we use the advantage instead, each action is compared against the expected return of its state: the baseline for a penalty kick is high, so "shooting at the audience" gets a strongly negative advantage and the network understands it is a terrible action.
Then, how do we get the baseline? It's simple: we initialize a neural network (the value network) and train it to minimize the MSE loss against the observed return for each state.
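That is, the value network $V_\phi$ is fit by regression on the empirical returns:

$$
\mathcal{L}(\phi) = \mathbb{E}_t\left[\big(V_\phi(s_t) - R_t\big)^2\right]
$$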
The pseudocode from Spinning Up boils down to: collect trajectories with the current policy, compute returns and advantage estimates, take a policy gradient step weighted by the advantages, and fit the value network by regression on the returns.
Now let's add trick 1 to our previous implementation.
class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        value = self.mlp(state).squeeze(-1)
        return value

p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
v_net = ValueNet()
v_opt = Adam(v_net.parameters())
env = Env()

for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages = [], [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            values = []
            done = False
            while not done:
                value = v_net(state).item()
                dist = p_net(state)
                action = dist.sample().item()
                states.append(state)
                values.append(value)
                actions.append(action)
                state, reward, done = env.step(action)
                rewards.append(reward)
            stack(rewards, values)
            cur_returns = discounted_cumsum(rewards)
            cur_advantages = cur_returns - values  # advantage = return - baseline
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        stack(states, actions, returns, advantages)

    # update policy network weights
    logps = p_net(states).log_prob(actions)
    p_loss = -(logps * advantages).mean()
    p_opt.zero_grad()
    p_loss.backward()
    p_opt.step()

    # update value network weights
    values = v_net(states)
    v_loss = ((values - returns)**2).mean()
    v_opt.zero_grad()
    v_loss.backward()
    v_opt.step()
In the vanilla policy gradient method, we collect episode trajectories, do optimizer.step() once, throw away the collected data, then repeat. Collecting episode trajectories takes a long time, yet we are not using the samples efficiently. PPO (Proximal Policy Optimization) makes it possible to take multiple gradient descent steps with the collected data, which increases sample efficiency. The main idea of PPO is the clipped loss: it clips the loss when the updated policy deviates too far from the policy that collected the data.
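Written out, the clipped objective from the PPO paper is, with probability ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_\text{old}}(a_t \mid s_t)$:

$$
L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\big(r_t(\theta)\, A_t,\ \operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\, A_t\big)\right]
$$

The code below minimizes the negative of this, which is why torch.max appears instead of min.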
The PPO-Clip pseudocode from Spinning Up has the same structure: collect trajectories, compute returns and advantages, then run several epochs of minibatch updates on the clipped surrogate objective while also fitting the value network.
We add trick 2 to our previous implementation.
p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
v_net = ValueNet()
v_opt = Adam(v_net.parameters())
env = Env()

for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages, logps = [], [], [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            values = []
            done = False
            while not done:
                value = v_net(state).item()
                dist = p_net(state)
                action = dist.sample()
                logp = dist.log_prob(action)
                action = action.item()
                states.append(state)
                values.append(value)
                actions.append(action)
                logps.append(logp)
                state, reward, done = env.step(action)
                rewards.append(reward)
            stack(rewards, values)
            cur_returns = discounted_cumsum(rewards)
            cur_advantages = cur_returns - values
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        stack(states, actions, returns, advantages, logps)

    # multiple update epochs on the same collected data
    for _ in range(n_epochs):
        for b_states, b_actions, b_returns, b_advantages, b_logps \
                in generate_batch(states, actions, returns, advantages, logps):
            # update policy network weights with the clipped loss
            new_logps = p_net(b_states).log_prob(b_actions)
            ratio = torch.exp(new_logps - b_logps)
            p_loss1 = -b_advantages * ratio
            p_loss2 = -b_advantages * torch.clamp(ratio, 1 - clip, 1 + clip)
            p_loss = torch.max(p_loss1, p_loss2).mean()
            p_opt.zero_grad()
            p_loss.backward()
            p_opt.step()

            # update value network weights
            values = v_net(b_states)
            v_loss = ((values - b_returns)**2).mean()
            v_opt.zero_grad()
            v_loss.backward()
            v_opt.step()
We can estimate the return over different horizons. If we set the value target to the one-step bootstrap $r_t + \gamma V(s_{t+1})$, we only look one step ahead (short horizon). A value network trained with this target has high bias but low variance: it is easy to train, but it may deviate from the true return. In contrast, if we use long-horizon rewards (up to the full Monte Carlo return), the target has low bias but high variance.
We can control this bias-variance trade-off by taking a weighted average of all the possible n-step estimates. The weights are controlled by a hyperparameter $\lambda$: if $\lambda = 0$, we estimate the value with the one-step target $r_t + \gamma V(s_{t+1})$, and if $\lambda = 1$, we estimate it with the full discounted (Monte Carlo) return.
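Writing $R_t^{(n)}$ for the $n$-step return, the $\lambda$-return is the weighted average:

$$
R_t^{(n)} = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}), \qquad
R_t^{\lambda} = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} R_t^{(n)}
$$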
We estimated the return with TD-Lambda. We can apply the same idea to the advantage, which is the signal used for updating the policy network.
Calculating it is simple: TD-Lambda Advantage = TD-Lambda Return - Value Estimation.
Here are some references: my implementation, the GAE paper, and a GAE implementation.
Replacing the advantage and return with their TD-Lambda versions, our algorithm becomes:
...
for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages, logps = [], [], [], [], []
        for _ in range(n_episodes):
            ...
            cur_returns, cur_advantages = td_lambda(rewards, values)
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        ...
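The td_lambda helper is left undefined in the pseudocode above. Here is a minimal sketch of one way it could be computed, in the GAE style (the function name and signature follow this post's pseudocode; gamma, lam, and the zero bootstrap at episode end are assumptions):

import numpy as np

def td_lambda(rewards, values, gamma=0.99, lam=0.95, last_value=0.0):
    # GAE-style advantages: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),
    # A_t = sum_l (gamma * lam)^l * delta_{t+l}, computed backwards in time.
    rewards = np.asarray(rewards, dtype=np.float32)
    values = np.asarray(values, dtype=np.float32)
    next_values = np.append(values[1:], last_value)  # bootstrap 0 after a terminal state
    deltas = rewards + gamma * next_values - values
    advantages = np.zeros_like(deltas)
    gae = 0.0
    for t in reversed(range(len(deltas))):
        gae = deltas[t] + gamma * lam * gae
        advantages[t] = gae
    # the lambda-return, used as the value network target
    returns = advantages + values
    return returns, advantages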
In the Stable Baselines implementations, layer weights are initialized orthogonally and biases are initialized to 0, which is not the default initialization in PyTorch. Also, the policy head is initialized with std=0.01 and the value head with std=1.
def layer_init(layer, std=np.sqrt(2), bias_const=0.):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01)  # policy head: std=0.01
        )

    def forward(self, state):
        logits = self.mlp(state)
        dist = Categorical(logits=logits)
        return dist

class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.)  # value head: std=1
        )

    def forward(self, state):
        value = self.mlp(state).squeeze(-1)
        return value
If the advantage has too large a scale, the policy gradient becomes too large and the network update becomes unstable. We can simply normalize the advantages mini-batch-wise (zero mean, unit variance) for stable training.
As with training other deep neural networks, normalizing the inputs greatly enhances training stability. However, unlike in the supervised learning setting, the input distribution of an RL agent changes continuously, so we cannot pre-compute the mean and std. Instead, we update the mean and std as the agent encounters states. The update needs to be slow so the model can adapt smoothly; the simplest method is a moving average.
The scale of the return becomes large when the reward scale is large or gamma is big. The return is the target of the value network, and when the target's scale is too large, the gradient scale becomes too large and destabilizes training. Just as with state normalization, we can use a moving average to slowly update the mean and std of the return, and use those statistics to normalize the value network's target.
After applying the normalizations, our code becomes:
class MeanStd:
    def __init__(self, dim, momentum=0.999):
        self.mean = np.zeros(dim)
        self.std = np.ones(dim)  # start at 1 to avoid division by zero
        self.momentum = momentum

    def update(self, x):
        # moving average version
        self.mean = self.momentum * self.mean + (1 - self.momentum) * x.mean()
        self.std = self.momentum * self.std + (1 - self.momentum) * x.std()

    def normalize(self, x):
        return (x - self.mean) / self.std

    def unnormalize(self, x):
        return x * self.std + self.mean

...
state_mean_std = MeanStd()
return_mean_std = MeanStd()

for _ in range(n_global_epochs):
    # collect episodes
    ...
    while not done:
        # normalize state
        state = state_mean_std.normalize(state)
        value = v_net(state).item()
        dist = p_net(state)
        action = dist.sample()
        logp = dist.log_prob(action)
        action = action.item()
        states.append(state)
        # unnormalize value so returns/advantages are computed in the original scale
        values.append(return_mean_std.unnormalize(value))
        actions.append(action)
        logps.append(logp)
        state, reward, done = env.step(action)
        rewards.append(reward)
    ...

    for _ in range(n_epochs):
        for b_states, b_actions, b_returns, b_advantages, b_logps \
                in generate_batch(states, actions, returns, advantages, logps):
            # normalize advantage mini-batch-wise
            b_advantages = (b_advantages - b_advantages.mean()) / b_advantages.std()
            ...
            # normalize return (the value network target)
            b_returns = return_mean_std.normalize(b_returns)
            ...

    # update state & return mean_std
    state_mean_std.update(states)
    return_mean_std.update(returns)
If the scale of the gradient is too large, the weights of the policy network change too rapidly, resulting in unstable training. We can clip the gradients so that the global norm of all parameter gradients concatenated doesn't exceed 0.5.
...
p_opt.zero_grad()
p_loss.backward()
nn.utils.clip_grad_norm_(p_net.parameters(), 0.5)
p_opt.step()
...
The higher the entropy of the policy's output distribution, the more diverse the actions the agent explores. We add an entropy bonus term directly to the loss to control exploration.
...
# update policy network weights
dist = p_net(b_states)
new_logps = dist.log_prob(b_actions)
entropy = dist.entropy()
ratio = torch.exp(new_logps - b_logps)
p_loss1 = -b_advantages * ratio
p_loss2 = -b_advantages * torch.clamp(ratio, 1 - clip, 1 + clip)
# subtract the entropy bonus to encourage exploration
p_loss = torch.max(p_loss1, p_loss2).mean() - ent_coef * entropy.mean()
...
There can be actions that are invalid in certain states. For example, if the agent doesn't possess a stock, it can't take the action of selling it. We can generate an action mask from the state and replace the logits of masked actions with -inf. The probability of taking those actions then becomes 0%, and backpropagation through masked actions is blocked. To take it one step further, you can also mask actions with heuristics to boost initial performance.
class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01)
        )

    def forward(self, state, action_mask):
        logits = self.mlp(state)
        # masked actions get -inf logits, i.e. probability 0
        logits = torch.where(action_mask, logits, torch.tensor(float("-inf")))
        dist = Categorical(logits=logits)
        return dist
Putting all the tricks together. Note that this is pseudocode.
def layer_init(layer, std=np.sqrt(2), bias_const=0.):
    torch.nn.init.orthogonal_(layer.weight, std)
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class PolicyNet(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, action_dim), std=0.01)
        )

    def forward(self, state, action_mask):
        logits = self.mlp(state)
        logits = torch.where(action_mask, logits, torch.tensor(float("-inf")))
        dist = Categorical(logits=logits)
        return dist

class ValueNet(nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_dim)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_dim, 1), std=1.)
        )

    def forward(self, state):
        value = self.mlp(state).squeeze(-1)
        return value

class MeanStd:
    def __init__(self, dim, momentum=0.999):
        self.mean = np.zeros(dim)
        self.std = np.ones(dim)
        self.momentum = momentum

    def update(self, x):
        self.mean = self.momentum * self.mean + (1 - self.momentum) * x.mean()
        self.std = self.momentum * self.std + (1 - self.momentum) * x.std()

    def normalize(self, x):
        return (x - self.mean) / self.std

    def unnormalize(self, x):
        return x * self.std + self.mean

p_net = PolicyNet()
p_opt = Adam(p_net.parameters())
v_net = ValueNet()
v_opt = Adam(v_net.parameters())
state_mean_std = MeanStd()
return_mean_std = MeanStd()
env = Env()

for _ in range(n_global_epochs):
    # collect episodes
    with torch.no_grad():
        states, actions, returns, advantages, logps, masks = [], [], [], [], [], []
        for _ in range(n_episodes):
            state = env.reset()
            rewards = []
            values = []
            done = False
            while not done:
                state = state_mean_std.normalize(state)
                value = v_net(state).item()
                mask = generate_mask(state)
                dist = p_net(state, mask)
                action = dist.sample()
                logp = dist.log_prob(action)
                action = action.item()
                states.append(state)
                masks.append(mask)
                values.append(return_mean_std.unnormalize(value))
                actions.append(action)
                logps.append(logp)
                state, reward, done = env.step(action)
                rewards.append(reward)
            stack(rewards, values)
            cur_returns, cur_advantages = td_lambda(rewards, values)
            returns.extend(cur_returns)
            advantages.extend(cur_advantages)
        stack(states, actions, returns, advantages, logps, masks)

    for _ in range(n_epochs):
        for b_states, b_actions, b_returns, b_advantages, b_logps, b_masks \
                in generate_batch(states, actions, returns, advantages, logps, masks):
            # normalize advantage
            b_advantages = (b_advantages - b_advantages.mean()) / b_advantages.std()

            # update policy network weights
            dist = p_net(b_states, b_masks)
            new_logps = dist.log_prob(b_actions)
            entropy = dist.entropy()
            ratio = torch.exp(new_logps - b_logps)
            p_loss1 = -b_advantages * ratio
            p_loss2 = -b_advantages * torch.clamp(ratio, 1 - clip, 1 + clip)
            p_loss = torch.max(p_loss1, p_loss2).mean() - ent_coef * entropy.mean()
            p_opt.zero_grad()
            p_loss.backward()
            nn.utils.clip_grad_norm_(p_net.parameters(), 0.5)
            p_opt.step()

            # update value network weights
            b_returns = return_mean_std.normalize(b_returns)
            values = v_net(b_states)
            v_loss = ((values - b_returns)**2).mean()
            v_opt.zero_grad()
            v_loss.backward()
            v_opt.step()

    # update state & return mean_std
    state_mean_std.update(states)
    return_mean_std.update(returns)
The Stable Baselines and OpenAI Gym implementations differ from the above pseudocode mainly in two parts. One of them is the MeanStd class: they keep a running mean and std computed from batch statistics instead of a simple moving average. You can refer here.
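For reference, a sketch of a running mean-std in that style, based on the parallel-variance update used by Gym's RunningMeanStd (the exact class here is illustrative, not the library code):

import numpy as np

class RunningMeanStd:
    # Accumulates mean and variance from batches using the parallel variance formula.
    def __init__(self, dim, epsilon=1e-4):
        self.mean = np.zeros(dim)
        self.var = np.ones(dim)
        self.count = epsilon

    def update(self, x):
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)
        batch_count = x.shape[0]

        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta**2 * self.count * batch_count / total

        self.mean, self.var, self.count = new_mean, m2 / total, total

    def normalize(self, x):
        return (x - self.mean) / np.sqrt(self.var + 1e-8)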
Finally, for reference, this is my implementation for the Kaggle Kore 2022 competition based on these tricks, and this is the blog post I referenced a lot when implementing PPO.