diff --git a/docs/source/recipes/dpo.rst b/docs/source/recipes/dpo.rst
index c4854ef81e..6c9384e301 100644
--- a/docs/source/recipes/dpo.rst
+++ b/docs/source/recipes/dpo.rst
@@ -56,6 +56,127 @@ To use any of these, simply use the ``loss`` config entry or flag through the :r
     loss=torchtune.modules.loss.RSOLoss \
     gamma=0.5
 
+You can also use a custom contrastive loss! In keeping with our philosophy of keeping the recipes simple,
+we do not ship every contrastive loss directly in torchtune. Instead, we provide a mechanism for using a
+recipe with a custom loss without touching the recipe internals.
+
+Here's how:
+
+1. Define your loss in the following format (a standalone sanity check is shown after these steps):
+
+.. code-block:: python
+
+    from typing import Tuple
+
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+
+
+    class SimPOLoss(nn.Module):
+        """
+        SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.
+
+        Intuition from the paper: the effectiveness of SimPO is attributed to a key design, using the average
+        log probability of a sequence as the implicit reward. Additionally, a target reward margin is added to
+        the Bradley-Terry objective to encourage a larger margin between the winning and losing responses,
+        further enhancing the algorithm's performance.
+
+        Based on the TRL implementation:
+        https://github.com/huggingface/trl/blob/98ad01ddfd1e1b67ec018014b83cba40e0caea66/trl/trainer/cpo_trainer.py#L603
+
+        SimPO is pretty much identical to DPO, but it uses average logprobs to eliminate the need for a reference
+        model to regularize the policy during training. It also uses a target reward margin to guide the policy
+        towards better responses. The intuition is similar to :class:`~torchtune.rlhf.loss.IPOLoss`, but instead of
+        optimizing against a margin between the reference policy and policy models, we're optimizing against a
+        margin between the chosen and rejected responses.
+
+        Args:
+            beta (float): Equivalent temperature scaling parameter to DPO loss, typically in the range of 2.0 to 2.5.
+                Default is 2.0.
+            gamma (float): Target reward margin hyperparameter, typically we have ``gamma in (0, 1.5]``.
+                Default is 0.5.
+            label_smoothing (float): Parameter encoding uncertainty about the labels. Default is 0.
+        """
+
+        def __init__(
+            self,
+            beta: float = 2.0,
+            gamma: float = 0.5,
+            label_smoothing: float = 0.0,
+        ):
+            super().__init__()
+            self.beta = beta
+            self.gamma = gamma
+            self.label_smoothing = label_smoothing
+
+        def forward(
+            self,
+            policy_chosen_logps: torch.Tensor,
+            policy_rejected_logps: torch.Tensor,
+        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+            """
+            Compute the SimPO loss for a batch of chosen and rejected average log probabilities.
+
+            Args:
+                policy_chosen_logps (torch.Tensor): Average log probabilities of the policy model
+                    for the chosen responses with shape [b,].
+                policy_rejected_logps (torch.Tensor): Average log probabilities of the policy model
+                    for the rejected responses with shape [b,].
+
+            Returns:
+                Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple of three tensors with shape [b,]:
+                    - losses: The SimPO loss for each example in the batch.
+                    - chosen_rewards: Rewards for the chosen responses.
+                    - rejected_rewards: Rewards for the rejected responses.
+ """ + + pi_logratios = policy_chosen_logps - policy_rejected_logps + + gamma_logratios = self.gamma / self.beta + logits = pi_logratios - gamma_logratios + + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + + chosen_rewards = self.beta * (policy_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards +2. Create some module in your config directory with this loss, for instance `my_loss.py` +3. If it is required, you may implement your own forward function. Note that you need to wrap it in a class: + +.. code-block:: python + class RLHFForward: + def concatenated_forward( + self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor], _device, activations_handling_ctx + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Run forward pass of the model with chosen and rejected samples concatenated. + + Args: + model (nn.Module): The model to be used for the forward pass. + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels. + + Returns: + Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits. + """ + concatenated_input_ids, concatenated_labels = batch + concatenated_input_ids = concatenated_input_ids.to(_device) + concatenated_labels = concatenated_labels.to(_device) + + # formed by concatenating an equal number of "chosen" and "rejected". + len_chosen = concatenated_input_ids.shape[0] // 2 + + with activations_handling_ctx: + all_logits = model(concatenated_input_ids) + + all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels) + + chosen_log_probs = all_log_probs[:len_chosen] + rejected_log_probs = all_log_probs[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits) +4. Note that the arguments of the concatenated_forward are fixed, but the output, except for the required logits, can be different! +5. Finally, introduce both loss and forward (not always required) in the config: + +.. 
+
 
 For a deeper understanding of the different levers you can pull when using this recipe, see our documentation for the different PEFT training paradigms we support:
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index d54adc2cf4..b0805568df 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -11,6 +11,8 @@
 from typing import Any, Dict, List, Optional, Tuple
 from warnings import warn
 
+import omegaconf
+
 import torch
 from omegaconf import DictConfig, ListConfig
 
@@ -290,6 +292,16 @@ def setup(self, cfg: DictConfig) -> None:
 
         utils.log_rank_zero(log, "Loss is initialized.")
 
+        try:
+            self._forward = config.instantiate(cfg.forward)
+            self.concatenated_forward = self._forward.concatenated_forward
+            utils.log_rank_zero(log, "Concatenated forward is initialized.")
+        except omegaconf.errors.ConfigAttributeError:
+            utils.log_rank_zero(
+                log,
+                "No custom concatenated_forward was provided, using the default one.",
+            )
+
         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after all of these are setup
         self._sampler, self._dataloader = self._setup_data(
@@ -685,7 +697,10 @@ def train(self) -> None:
                     policy_rejected_log_probs,
                     policy_chosen_logits,
                     policy_rejected_logits,
-                ) = self.concatenated_forward(self._model, batch)
+                    *args,
+                ) = self.concatenated_forward(
+                    self._model, batch, self._device, self.activations_handling_ctx
+                )
 
                 policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
                 policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
@@ -699,12 +714,15 @@ def train(self) -> None:
                         reference_rejected_log_probs,
                         _,
                         _,
-                    ) = self.concatenated_forward(self._model, batch)
+                    ) = self.concatenated_forward(
+                        self._model, batch, self._device, self.activations_handling_ctx
+                    )
 
                 loss, chosen_rewards, rejected_rewards = self._loss_fn(
                     policy_chosen_log_probs,
                     policy_rejected_log_probs,
                     reference_chosen_log_probs,
                     reference_rejected_log_probs,
+                    *args,
                 )
 
                 loss = loss.mean()
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index c493b65602..78475d440c 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -11,6 +11,8 @@
 from typing import Any, Dict, Optional, Tuple
 from warnings import warn
 
+import omegaconf
+
 import torch
 from omegaconf import DictConfig, ListConfig
 
@@ -241,6 +243,15 @@ def setup(self, cfg: DictConfig) -> None:
 
         self._loss_fn = config.instantiate(cfg.loss)
         log.info("Loss function is initialized.")
 
+        try:
+            self._forward = config.instantiate(cfg.forward)
+            self.concatenated_forward = self._forward.concatenated_forward
+            log.info("Concatenated forward is initialized.")
+        except omegaconf.errors.ConfigAttributeError:
+            log.info(
+                "No custom concatenated_forward was provided, using the default one."
+            )
+
         # Dataloader depends on the tokenizer and loss_fn and should be
         # setup after all of these are setup
         self._sampler, self._dataloader = self._setup_data(
@@ -536,7 +547,10 @@ def train(self) -> None:
                     policy_rejected_log_probs,
                     policy_chosen_logits,
                     policy_rejected_logits,
-                ) = self.concatenated_forward(self._model, batch)
+                    *args,
+                ) = self.concatenated_forward(
+                    self._model, batch, self._device, self.activations_handling_ctx
+                )
 
                 policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
                 policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
@@ -550,12 +564,15 @@ def train(self) -> None:
                         reference_rejected_log_probs,
                         _,
                         _,
-                    ) = self.concatenated_forward(self._model, batch)
+                    ) = self.concatenated_forward(
+                        self._model, batch, self._device, self.activations_handling_ctx
+                    )
 
                 loss, chosen_rewards, rejected_rewards = self._loss_fn(
                     policy_chosen_log_probs,
                     policy_rejected_log_probs,
                     reference_chosen_log_probs,
                     reference_rejected_log_probs,
+                    *args,
                 )
 
                 loss = loss.mean()
diff --git a/torchtune/models/t5/_encoder.py b/torchtune/models/t5/_encoder.py
index 7828e9ecc5..1dcf7d4dc8 100644
--- a/torchtune/models/t5/_encoder.py
+++ b/torchtune/models/t5/_encoder.py
@@ -147,8 +147,7 @@ class T5EncoderSelfAttention(nn.Module):
         output_proj (nn.Module): Projection layer for output.
 
     Raises:
-        ValueError: If ``num_heads % num_kv_heads != 0``
-        ValueError: If ``embed_dim // num_heads != head_dim``
+        ValueError: If ``num_heads % num_kv_heads != 0`` or ``embed_dim // num_heads != head_dim``
     """
 
     def __init__(