Custom DPO losses support #2292

Open · wants to merge 2 commits into main
121 changes: 121 additions & 0 deletions docs/source/recipes/dpo.rst
@@ -56,6 +56,127 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r
loss=torchtune.modules.loss.RSOLoss \
gamma=0.5

We also support custom contrastive losses. In keeping with our philosophy of keeping the recipes simple, torchtune does not ship every such loss directly; instead, we provide a mechanism for plugging a custom loss into a recipe without touching its internals.

Here's how:

1. Define your loss in the following format:

.. code-block:: python

    from typing import Tuple

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SimPOLoss(nn.Module):
        """
        SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.

        Intuition from the paper:

            The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as
            the implicit reward. Additionally, we introduce a target reward margin to the Bradley-Terry objective to
            encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance.

        Based on the TRL implementation:
        https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603

        SimPO is pretty much identical to DPO but uses average logprobs to eliminate the need for a reference model to regularize
        the policy during training. It also uses a target reward margin to guide the policy towards better responses.

        This is kind of the same intuition as in :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of optimizing against
        a margin between the reference policy and policy models, we're optimizing against a margin between the chosen and
        rejected responses.

        Args:
            beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5. Default is 2.0.
            gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``.
                Default is 0.5.
            label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0.
        """

        def __init__(
            self,
            beta: float = 2.0,
            gamma: float = 0.5,
            label_smoothing: float = 0.0,
        ):
            super().__init__()
            self.beta = beta
            self.gamma = gamma
            self.label_smoothing = label_smoothing

        def forward(
            self,
            policy_chosen_logps: torch.Tensor,
            policy_rejected_logps: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Compute the SimPO loss for a batch of chosen and rejected average log probabilities.

            Args:
                policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model
                    for the chosen responses with shape [b,].
                policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model
                    for the rejected responses with shape [b,].

            Returns:
                Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors with shape [b,]:
                    - losses: The SimPO loss for each example in the batch.
                    - chosen_rewards: Rewards for the chosen responses.
                    - rejected_rewards: Rewards for the rejected responses.
            """
            pi_logratios = policy_chosen_logps - policy_rejected_logps

            gamma_logratios = self.gamma / self.beta
            logits = pi_logratios - gamma_logratios

            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )

            chosen_rewards = self.beta * (policy_chosen_logps).detach()
            rejected_rewards = self.beta * (policy_rejected_logps).detach()

            return losses, chosen_rewards, rejected_rewards

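As a quick sanity check outside the recipe, you can call the loss on random average log probabilities. This is a minimal sketch that assumes only the imports and the ``SimPOLoss`` class defined above:

.. code-block:: python

    # Illustrative only: verify that the loss runs and returns per-example tensors.
    batch_size = 4
    policy_chosen_logps = torch.randn(batch_size)
    policy_rejected_logps = torch.randn(batch_size)

    loss_fn = SimPOLoss(beta=2.0, gamma=0.5)
    losses, chosen_rewards, rejected_rewards = loss_fn(
        policy_chosen_logps, policy_rejected_logps
    )
    assert losses.shape == (batch_size,)
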
2. Create a module containing this loss in your config directory, for instance ``my_loss.py``.
3. If required, you can also implement your own forward function. Note that it must be wrapped in a class:

.. code-block:: python

    # Also lives in my_loss.py; reuses the imports above and additionally needs:
    from torchtune import rlhf


    class RLHFForward:
        def concatenated_forward(
            self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor], _device, activations_handling_ctx
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            """
            Run forward pass of the model with chosen and rejected samples concatenated.

            Args:
                model (nn.Module): The model to be used for the forward pass.
                batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.
                _device (torch.device): Device to move the batch tensors to.
                activations_handling_ctx: Context manager supplied by the recipe
                    (e.g., for activation offloading).

            Returns:
                Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
            """
            concatenated_input_ids, concatenated_labels = batch
            concatenated_input_ids = concatenated_input_ids.to(_device)
            concatenated_labels = concatenated_labels.to(_device)

            # The batch is formed by concatenating an equal number of "chosen" and "rejected" samples.
            len_chosen = concatenated_input_ids.shape[0] // 2

            with activations_handling_ctx:
                all_logits = model(concatenated_input_ids)

            all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)

            chosen_log_probs = all_log_probs[:len_chosen]
            rejected_log_probs = all_log_probs[len_chosen:]

            chosen_logits = all_logits[:len_chosen]
            rejected_logits = all_logits[len_chosen:]

            return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)

4. Note that the arguments of ``concatenated_forward`` are fixed, but its output may contain additional values after the four required entries (chosen/rejected log probs and logits); any extras are passed through to the loss (see the sketch after the config example below).
5. Finally, reference both the loss and (if you defined one) the forward in the config:

.. code-block:: yaml

    loss:
      _component_: my_loss.SimPOLoss
    forward:
      _component_: my_loss.RLHFForward

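As noted in step 4, any values your ``concatenated_forward`` returns after the four required outputs are forwarded to the loss as extra positional arguments; the recipes call the loss with the policy log probs, then the reference log probs, then those extras. The sketch below only illustrates this calling convention: ``MarginScaledDPOLoss`` and ``margin_scale`` are hypothetical names, not part of torchtune.

.. code-block:: python

    from typing import Tuple

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class MarginScaledDPOLoss(nn.Module):
        """Hypothetical loss consuming one extra output of a custom forward."""

        def __init__(self, beta: float = 0.1):
            super().__init__()
            self.beta = beta

        def forward(
            self,
            policy_chosen_logps: torch.Tensor,
            policy_rejected_logps: torch.Tensor,
            reference_chosen_logps: torch.Tensor,
            reference_rejected_logps: torch.Tensor,
            margin_scale: torch.Tensor,  # extra value appended to the forward's output tuple
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Standard DPO-style log-ratio, scaled per example by the extra output.
            logits = (policy_chosen_logps - policy_rejected_logps) - (
                reference_chosen_logps - reference_rejected_logps
            )
            losses = -F.logsigmoid(self.beta * margin_scale * logits)
            chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
            rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
            return losses, chosen_rewards, rejected_rewards


    # Quick standalone check of the calling convention (illustrative only).
    b = 4
    loss_fn = MarginScaledDPOLoss()
    losses, _, _ = loss_fn(
        torch.randn(b), torch.randn(b), torch.randn(b), torch.randn(b), torch.ones(b)
    )
    assert losses.shape == (b,)

To supply ``margin_scale``, a custom forward would simply append it as a fifth element of the tuple it returns.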


For a deeper understanding of the different levers you can pull when using this recipe,
see our documentation for the different PEFT training paradigms we support:

22 changes: 20 additions & 2 deletions recipes/lora_dpo_distributed.py
@@ -11,6 +11,8 @@
from typing import Any, Dict, List, Optional, Tuple
from warnings import warn

import omegaconf

import torch
from omegaconf import DictConfig, ListConfig

@@ -290,6 +292,16 @@ def setup(self, cfg: DictConfig) -> None:

utils.log_rank_zero(log, "Loss is initialized.")

try:
self._forward = config.instantiate(cfg.forward)
Reviewer comment (Contributor): What's the motivation for this change?

self.concatenated_forward = self._forward.concatenated_forward
utils.log_rank_zero(log, "Concatenated forward is initialized.")
except omegaconf.errors.ConfigAttributeError:
utils.log_rank_zero(
log,
"Has not initialized custom concatenated_forward, using common one.",
)

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after all of these are setup
self._sampler, self._dataloader = self._setup_data(
@@ -685,7 +697,10 @@ def train(self) -> None:
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
*args,
) = self.concatenated_forward(
self._model, batch, self._device, self.activations_handling_ctx
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
@@ -699,12 +714,15 @@
reference_rejected_log_probs,
_,
_,
) = self.concatenated_forward(self._model, batch)
) = self.concatenated_forward(
self._model, batch, self._device, self.activations_handling_ctx
)
loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
*args,
)

loss = loss.mean()
21 changes: 19 additions & 2 deletions recipes/lora_dpo_single_device.py
@@ -11,6 +11,8 @@
from typing import Any, Dict, Optional, Tuple
from warnings import warn

import omegaconf

import torch
from omegaconf import DictConfig, ListConfig

@@ -241,6 +243,15 @@ def setup(self, cfg: DictConfig) -> None:
self._loss_fn = config.instantiate(cfg.loss)
log.info("Loss function is initialized.")

try:
self._forward = config.instantiate(cfg.forward)
self.concatenated_forward = self._forward.concatenated_forward
log.info("Concatenated forward is initialized.")
except omegaconf.errors.ConfigAttributeError:
log.info(
"Has not initialized custom concatenated_forward, using common one."
)

# Dataloader depends on the tokenizer and loss_fn and should be
# setup after all of these are setup
self._sampler, self._dataloader = self._setup_data(
@@ -536,7 +547,10 @@ def train(self) -> None:
policy_rejected_log_probs,
policy_chosen_logits,
policy_rejected_logits,
) = self.concatenated_forward(self._model, batch)
*args,
) = self.concatenated_forward(
self._model, batch, self._device, self.activations_handling_ctx
)

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
@@ -550,12 +564,15 @@
reference_rejected_log_probs,
_,
_,
) = self.concatenated_forward(self._model, batch)
) = self.concatenated_forward(
self._model, batch, self._device, self.activations_handling_ctx
)
loss, chosen_rewards, rejected_rewards = self._loss_fn(
policy_chosen_log_probs,
policy_rejected_log_probs,
reference_chosen_log_probs,
reference_rejected_log_probs,
*args,
)

loss = loss.mean()
3 changes: 1 addition & 2 deletions torchtune/models/t5/_encoder.py
@@ -147,8 +147,7 @@ class T5EncoderSelfAttention(nn.Module):
output_proj (nn.Module): Projection layer for output.

Raises:
ValueError: If ``num_heads % num_kv_heads != 0``
ValueError: If ``embed_dim // num_heads != head_dim``
ValueError: If ``num_heads % num_kv_heads != 0`` or ``embed_dim // num_heads != head_dim``
"""

def __init__(