PPO + GRPO (#37)
natolambert authored Jan 22, 2025
1 parent 4474202 commit 68c5a41
Showing 1 changed file with 157 additions and 5 deletions: chapters/11-policy-gradients.md
$$\nabla_\theta J(\pi_\theta) = \mathbb{E}_\tau \left[ \sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t|s_t) A^{\pi_\theta}(s_t,a_t) \right]$$

TODO cite further reading



### REINFORCE

The REINFORCE update is as follows:
$$J(\theta) = \frac{1}{G}\sum_{i=1}^G \left(\min\left(\frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}A_i, \text{clip} \left( \frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}, 1-\varepsilon, 1+\varepsilon \right) A_i \right) - \beta D_{KL}(\pi_\theta||\pi_{ref})\right)$$

With the advantage computation for the completion index $i$:

$$A_i = \frac{r_i - \text{mean}({r_1, r_2, \cdots, r_G})}{\text{std}({r_1, r_2, \cdots, r_G})}. \quad (3)$$ {#eq:GRPO_ADV}
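
As a quick numerical check of @eq:GRPO_ADV, here is a minimal standalone sketch (made-up rewards for one group of $G=4$ completions):

```
import torch

# One group of G = 4 completions for a single prompt, scored with a
# binary verifiable reward.
rewards = torch.tensor([1.0, 0.0, 1.0, 1.0])

advantages = (rewards - rewards.mean()) / rewards.std()
print(advantages)  # tensor([ 0.5000, -1.5000,  0.5000,  0.5000])
```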

## Computing Policy Gradients with a Language Model
@eq:GRPO_ADV is the implementation of GRPO when working with outcome supervision (either a standard reward model or a single verifiable reward); a different implementation is needed with process supervision.
In this case, GRPO computes the advantage as the sum of the normalized rewards for the following reasoning steps.
To do so, the rewards are accumulated with additional tracking of a reasoning index $j$, and then computed step-wise as TODO, ref paper

## Implementation

- Only score a response with a reward model when the `eos_token` is generated; otherwise the response is truncated (see the sketch below).
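
A minimal sketch of this check, assuming scalar reward-model scores and a hypothetical fixed penalty for truncated responses (the name `missing_eos_penalty` is illustrative, not a specific library's API):

```
import torch

def apply_eos_gate(scores, completion_ids, eos_token_id, missing_eos_penalty=-1.0):
    # scores: (B,) scalar reward-model scores, completion_ids: (B, L) token ids.
    # Completions that never emit the eos token were cut off by the length
    # limit, so their score is replaced with a fixed penalty instead.
    has_eos = (completion_ids == eos_token_id).any(dim=1)  # Shape: (B,)
    return torch.where(has_eos, scores, torch.full_like(scores, missing_eos_penalty))
```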

https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#

https://lilianweng.github.io/posts/2018-04-08-policy-gradient/

### Policy Gradient

A simple implementation of the policy gradient, using advantages to estimate the gradient in preparation for advanced algorithms such as PPO and GRPO, follows:
```
pg_loss = -advantages * ratio
```
Here, `ratio` is the probability ratio of the new policy relative to the old policy that generated the samples, computed as the exponential of the difference of their token log-probabilities (the logratio).

In order to understand this equation it is good to understand different cases that can fall within a batch of updates.
Remember that we want the loss to *decrease* as the model gets better at the task.

Case 1: Positive advantage, so the action was better than the expected value of the state. We want to reinforce this. Because of the negative sign, minimizing the loss pushes the model to increase the logratio, making these tokens more likely. A positive logratio (the difference of the summed token log-probabilities) means that the new model is more likely to generate those tokens than the old one.

Case 2: Negative advantage, so the action was worse than the expected value of the state. This follows very similarly. Here, the loss is positive if the new model is more likely to produce the completion, so the update pushes the policy parameters to make this completion less likely.

Case 3: Zero advantage, so no update is needed. The loss is zero, don't change the policy model.
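
A small standalone sketch of these three cases, with made-up advantages and a shared logratio, shows the sign of the resulting gradient:

```
import torch

advantages = torch.tensor([1.0, -1.0, 0.0])  # Case 1, Case 2, Case 3
log_ratio = torch.tensor([0.2, 0.2, 0.2], requires_grad=True)
ratio = torch.exp(log_ratio)

pg_loss = (-advantages * ratio).sum()
pg_loss.backward()

# Gradient with respect to the logratio for each case:
#   Case 1 (A > 0): negative gradient, so gradient descent increases the
#                   logratio and makes the sampled tokens more likely.
#   Case 2 (A < 0): positive gradient, so the tokens become less likely.
#   Case 3 (A = 0): zero gradient, the policy is left unchanged.
print(log_ratio.grad)  # roughly [-1.2214, 1.2214, 0.0]
```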

### Proximal Policy Optimization

There are many, many implementations of PPO available.
The core *loss* computation is shown below.
Crucial to stable performance is also the *value* computation, where multiple options exist (including multiple options for the *value model* loss).



```
# B: Batch Size, L: Sequence Length, G: Num of Generations
# Apply KL penalty to rewards
rewards = rewards - self.beta * per_token_kl # Shape: (B*G, L)
# Get value predictions
values = value_net(completions) # Shape: (B*G, L)
# Compute simple advantages
advantages = rewards - values.detach() # Shape: (B*G, L)
# Normalize advantages (optional but stable)
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Shape: (B*G, L)
# Compute probability ratio between new and old policies
ratio = torch.exp(new_per_token_logps - per_token_logps) # Shape: (B*G, L)
# PPO clipping objective
eps = self.cliprange # e.g. 0.2
pg_losses1 = -advantages * ratio # Shape: (B*G, L)
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps) # Shape: (B*G, L)
pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, L)
# Simple value function loss
vf_loss = 0.5 * ((rewards - values) ** 2) # Shape: (B*G, L)
# Combine policy and value losses
per_token_loss = pg_loss_max + self.vf_coef * vf_loss # Shape: (B*G, L)
# Apply completion mask and compute final loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Scalar
# Compute metrics for logging
with torch.no_grad():
    # Compute clipping fraction
    clip_frac = ((pg_losses2 > pg_losses1).float() * completion_mask).sum() / completion_mask.sum()
    # Compute approximate KL
    approx_kl = 0.5 * ((new_per_token_logps - per_token_logps)**2).mean()
    # Compute value loss for logging
    value_loss = vf_loss.mean()
```
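
The block above assumes a few tensors are already in scope. As a hedged sketch (not the reference implementation) of how they are often produced, the helpers below compute per-token log-probabilities from aligned logits and one common per-token KL approximation; the names `per_token_logprobs`, `approx_per_token_kl`, and `ref_per_token_logps` are illustrative.

```
import torch
import torch.nn.functional as F

def per_token_logprobs(logits, input_ids):
    # logits: (B*G, L, V), assumed already aligned so that logits[:, t]
    # scores input_ids[:, t]; input_ids: (B*G, L) sampled completion tokens.
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# new_per_token_logps: from the current policy (gradients flow through these)
# per_token_logps:     from the old/rollout policy (computed without gradients)
# ref_per_token_logps: from the frozen reference policy (no gradients)

def approx_per_token_kl(new_per_token_logps, ref_per_token_logps):
    # One common always-non-negative per-token estimate of
    # KL(policy || reference): exp(ref - new) - (ref - new) - 1.
    delta = ref_per_token_logps - new_per_token_logps
    return torch.exp(delta) - delta - 1

# completion_mask: (B*G, L) with 1 for completion tokens and 0 for padding.
```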

The core piece to understand with PPO is how the policy gradient loss is updated.
Focus on these three lines:
```
pg_losses1 = -advantages * ratio # Shape: (B*G, L)
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps) # Shape: (B*G, L)
pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, L)
```
`pg_losses1` is the same as the vanilla advantage-based PG loss above, which is included in PPO, but the loss (and gradient update) can be clipped.
PPO controls the size of the update so that it is not too big. Because losses can be negative, we must create a more conservative version of the vanilla policy gradient update rule.

We know that if we *do not* constrain the loss, the policy gradient algorithm will update the weights exactly toward the new probability distribution dictated by the advantages.
Hence, by clamping the ratio, PPO limits the distance that a single update can move the policy parameters.

Finally, the max of the two is taken, as mentioned above, in order to use the more conservative loss update.
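
A small worked example with made-up numbers shows the clipping in action:

```
import torch

eps = 0.2
advantage = torch.tensor(1.0)  # positive advantage
ratio = torch.tensor(1.5)      # new policy is already 1.5x more likely

pg_loss1 = -advantage * ratio                                 # -1.5
pg_loss2 = -advantage * torch.clamp(ratio, 1 - eps, 1 + eps)  # -1.2
loss = torch.max(pg_loss1, pg_loss2)                          # -1.2, the clipped term wins

# Because the clipped term is selected and its clamp is saturated, the
# gradient with respect to the ratio is zero: once the policy has moved
# more than eps in the favorable direction, this sample stops pushing.
```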

For PPO, all of this happens *while* learning a value function, which adds more complexity, but this is the core logic for the parameter update.

### Group Relative Policy Optimization

The DeepSeekMath paper documents several implementation choices in GRPO that differ from PPO [@shao2024deepseekmath], especially when comparing to a standard application of PPO from deep RL rather than language models.
For example, the KL penalty within the RLHF optimization (recall that the KL penalty is also used when training reasoning models on verifiable rewards without a reward model) is applied directly in the loss update rather than to the reward function.
Where the standard KL penalty for RLHF is applied to the reward as $r=r_\theta - \beta D_{KL}$, the GRPO implementation is along the lines of:

$$ L = L_{\text{policy gradient}} + \beta \cdot D_{KL} $$

Though, there are multiple ways to implement this.
Traditionally, the KL divergence is computed with respect to each token in the completion to a prompt $s$.
For reasoning training, multiple completions are sampled from one prompt, and there are multiple prompts in one batch,
so the KL divergence will have a shape of [B, L, N], where B is the batch size, L is the sequence length, and N is the number of completions per prompt.
The question when implementing GRPO is: how do you aggregate the KL divergence and loss over tokens to design different types of value attribution?
In the implementation below, the per-token loss is averaged over the valid tokens of each completion (a masked mean) and then averaged over the batch; summing over the tokens is an alternative.
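
For concreteness, the two aggregation choices can be written as a small helper (a sketch; the function name `aggregate_loss` is illustrative):

```
import torch

def aggregate_loss(per_token_loss, completion_mask, how="mean"):
    # per_token_loss, completion_mask: (B*G, L)
    if how == "mean":
        # Masked mean over the tokens of each completion, then mean over
        # the batch -- this is what the implementation below does.
        per_sequence = (per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)
    else:
        # Alternative: sum over tokens, so longer completions weigh more.
        per_sequence = (per_token_loss * completion_mask).sum(dim=1)
    return per_sequence.mean()
```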

```
# B: Batch Size, L: Sequence Length, G: Number of Generations
# Compute group-wise reward statistics # Shape: (B,)
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# Shape: (B*G,)
# Compute advantages
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
advantages = advantages.unsqueeze(1)
# Shape: (B*G, 1)
# Compute probability ratio between new and old policies
ratio = torch.exp(new_per_token_logps - per_token_logps) # Shape: (B*G, L)
# PPO clipping objective
eps = self.cliprange # e.g. 0.2
pg_losses1 = -advantages * ratio # Shape: (B*G, L)
pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - eps, 1.0 + eps) # Shape: (B*G, L)
pg_loss_max = torch.max(pg_losses1, pg_losses2) # Shape: (B*G, L)
# important to GRPO -- PPO traditionally applies this penalty to the reward instead
# Combine with KL penalty
per_token_loss = pg_loss_max + self.beta * per_token_kl # Shape: (B*G, L)
# Apply completion mask and compute final loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# Scalar
# Compute core metric for logging (KL, reward, etc. also logged)
with torch.no_grad():
    # Compute clipping fraction
    clip_frac = ((pg_losses2 > pg_losses1).float() * completion_mask).sum() / completion_mask.sum()
    # Compute approximate KL
    approx_kl = 0.5 * ((new_per_token_logps - per_token_logps)**2).mean()
```

For more details on how to interpret this code, see the PPO section above.
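
As a usage note, the block above assumes the `rewards` tensor is laid out so that the `G` completions of each prompt are contiguous; otherwise `view(-1, self.num_generations)` would mix groups. A small standalone example with made-up rewards for `B = 2` prompts and `G = 3` generations:

```
import torch

num_generations = 3  # G
# Rewards for B = 2 prompts with G = 3 completions each, grouped by prompt:
rewards = torch.tensor([1.0, 0.0, 1.0,   # prompt 0
                        0.0, 0.0, 1.0])  # prompt 1

mean_g = rewards.view(-1, num_generations).mean(dim=1)  # tensor([0.6667, 0.3333])
std_g = rewards.view(-1, num_generations).std(dim=1)    # tensor([0.5774, 0.5774])

mean_g = mean_g.repeat_interleave(num_generations, dim=0)  # Shape: (6,)
std_g = std_g.repeat_interleave(num_generations, dim=0)    # Shape: (6,)

advantages = (rewards - mean_g) / (std_g + 1e-4)
# roughly tensor([ 0.5773, -1.1545,  0.5773, -0.5773, -0.5773,  1.1545])
```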


## KL Controllers

TODO: adaptive vs static KL control

