diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py
index 7769797108..96ead61b4c 100644
--- a/trl/trainer/kto_trainer.py
+++ b/trl/trainer/kto_trainer.py
@@ -69,7 +69,6 @@
 if is_liger_kernel_available():
     from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
 
-
 if is_peft_available():
     from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
 
@@ -782,6 +781,20 @@ def make_inputs_require_grad(module, input, output):
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
+        # Set up the fused Liger KTO loss if enabled
+        if self.args.use_liger_loss:
+            if not is_liger_kernel_available():
+                raise ValueError(
+                    "You set `use_liger_loss=True` but the liger kernel is not available. "
+                    "Please install liger-kernel first: `pip install liger-kernel`"
+                )
+            if self.precompute_ref_log_probs:
+                raise ValueError(
+                    "You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set `precompute_ref_log_probs=False`."
+                )
+            # The fused linear + KTO loss avoids materializing the full logits tensor.
+            self.kto_loss_fn = LigerFusedLinearKTOLoss(ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None))
+
     def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
         # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
@@ -1196,6 +1209,11 @@ def kto_loss(
 
         return losses, chosen_rewards, rejected_rewards, kl
 
+    def _compute_loss_liger(self, model, batch):
+        # Compute the KTO loss with the fused Liger kernel (`self.kto_loss_fn`) and return the
+        # same statistics as the eager path (losses, logits, logps, rewards, and the KL estimate).
+        raise NotImplementedError("The fused Liger KTO loss is not implemented yet.")
+
     def get_batch_loss_metrics(
         self,
         model,
@@ -1205,58 +1223,73 @@ def get_batch_loss_metrics(
         metrics = {}
         batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
 
-        forward_output = self.forward(model, batch)
-        (
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_chosen_logits,
-            policy_rejected_logits,
-            policy_KL_logps,
-        ) = forward_output[:5]
-        if self.aux_loss_enabled:
-            aux_loss = forward_output[5]
-
-        # if reference_logps in batch use them, otherwise use the reference model
-        if "reference_logps" in batch:
-            chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
-            rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
-
-            reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
-            reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
-            if self.calculate_KL:
-                reference_KL_logps = batch["reference_KL_logps"]
-            else:
-                reference_KL_logps = None
+        if self.args.use_liger_loss:
+            model_output = self._compute_loss_liger(model, batch)
+            losses = model_output["losses"]
+            policy_chosen_logits = model_output["policy_chosen_logits"]
+            policy_rejected_logits = model_output["policy_rejected_logits"]
+            policy_chosen_logps = model_output["policy_chosen_logps"]
+            policy_rejected_logps = model_output["policy_rejected_logps"]
+            chosen_rewards = model_output["chosen_rewards"]
+            rejected_rewards = model_output["rejected_rewards"]
+            kl = model_output["kl"]
         else:
-            with torch.no_grad():
-                if self.ref_model is None:
-                    with self.null_ref_context():
+            forward_output = self.forward(model, batch)
+            (
+                policy_chosen_logps,
+                policy_rejected_logps,
+                policy_chosen_logits,
+                policy_rejected_logits,
+                policy_KL_logps,
+            ) = forward_output[:5]
+            if self.aux_loss_enabled:
+                aux_loss = forward_output[5]
+
+            # if reference_logps in batch use them, otherwise use the reference model
+            if "reference_logps" in batch:
+                chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
+                rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]
+
+                reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
+                reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
+                if self.calculate_KL:
+                    reference_KL_logps = batch["reference_KL_logps"]
+                else:
+                    reference_KL_logps = None
+            else:
+                with torch.no_grad():
+                    if self.ref_model is None:
+                        with self.null_ref_context():
+                            (
+                                reference_chosen_logps,
+                                reference_rejected_logps,
+                                _,
+                                _,
+                                reference_KL_logps,
+                            ) = self.forward(self.model, batch)[:5]
+                    else:
                         (
                             reference_chosen_logps,
                             reference_rejected_logps,
                             _,
                             _,
                             reference_KL_logps,
-                        ) = self.forward(self.model, batch)[:5]
-                else:
-                    (
-                        reference_chosen_logps,
-                        reference_rejected_logps,
-                        _,
-                        _,
-                        reference_KL_logps,
-                    ) = self.forward(self.ref_model, batch)[:5]
-
-        losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
-            policy_chosen_logps,
-            policy_rejected_logps,
-            policy_KL_logps,
-            reference_chosen_logps,
-            reference_rejected_logps,
-            reference_KL_logps,
-        )
+                        ) = self.forward(self.ref_model, batch)[:5]
+
+            losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
+                policy_chosen_logps,
+                policy_rejected_logps,
+                policy_KL_logps,
+                reference_chosen_logps,
+                reference_rejected_logps,
+                reference_KL_logps,
+            )
+            # print("losses: ", losses)
+            # print("chosen_rewards: ", chosen_rewards)
+            # print("rejected_rewards: ", rejected_rewards)
+            # print("kl: ", kl)
+
         metrics["kl"] = kl.item()
-
         num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
         num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
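
For reference, the non-Liger branch above still ends in `self.kto_loss(...)`, which produces the `(losses, chosen_rewards, rejected_rewards, kl)` tuple consumed right after it. Below is a condensed sketch of that computation for the default `loss_type="kto"`; it is a simplification, since the real method also handles `calculate_KL=False`, empty chosen/rejected splits, and gathers the KL estimate across processes, and the weights come from the config.

import torch

def kto_loss_sketch(
    policy_chosen_logps,
    policy_rejected_logps,
    policy_KL_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    reference_KL_logps,
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
):
    # KL estimate between policy and reference on the mismatched (KL) completions,
    # detached and clamped at zero so it only shifts the margin.
    kl = (policy_KL_logps - reference_KL_logps).mean().detach().clamp(min=0)

    chosen_logratios = policy_chosen_logps - reference_chosen_logps
    rejected_logratios = policy_rejected_logps - reference_rejected_logps

    # Desirable completions are rewarded for exceeding the KL baseline,
    # undesirable ones for falling below it.
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))

    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()

    losses = torch.cat((desirable_weight * chosen_losses, undesirable_weight * rejected_losses), dim=0)
    return losses, chosen_rewards, rejected_rewards, kl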
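
The new code path is gated on `self.args.use_liger_loss` and rejects `precompute_ref_log_probs=True`. Here is a minimal sketch of opting in from user code, assuming the flag is exposed on `KTOConfig` elsewhere in this change; the model and dataset choices are illustrative and the argument names follow recent TRL releases.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

# `use_liger_loss=True` is assumed to be a new KTOConfig field added by this change;
# it routes get_batch_loss_metrics() through _compute_loss_liger() and requires
# liger-kernel to be installed and precompute_ref_log_probs to stay False.
training_args = KTOConfig(
    output_dir="kto-liger-test",
    use_liger_loss=True,
    precompute_ref_log_probs=False,
)

trainer = KTOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()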