Commit

fix initialization and create basic structure
vaibhavjindal committed Feb 21, 2025
1 parent 7f2da45 commit aea11f2
Showing 1 changed file with 78 additions and 45 deletions.
123 changes: 78 additions & 45 deletions trl/trainer/kto_trainer.py
@@ -69,7 +69,6 @@
if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss


if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

@@ -782,6 +781,20 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

# Import Liger loss if enabled
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
if self.precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set `precompute_ref_log_probs=False`."
)
self.kto_loss_fn = LigerFusedLinearKTOLoss(ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None))
print("Correctly set the liger loss")

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
@@ -1196,6 +1209,11 @@ def kto_loss(

return losses, chosen_rewards, rejected_rewards, kl

def _compute_loss_liger(self, model, batch):
raise NotImplementedError("The Liger-kernel path for the KTO loss is not implemented yet; this commit only adds the basic structure.")
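The liger branch added to get_batch_loss_metrics below reads its results out of a dictionary, so whatever implementation eventually lands here is expected to return a mapping with the keys `losses`, `policy_chosen_logits`, `policy_rejected_logits`, `policy_chosen_logps`, `policy_rejected_logps`, `chosen_rewards`, `rejected_rewards`, and `kl`. A minimal sketch of that contract, under the assumption that those dictionary keys stay as shown in the diff; the helper name is hypothetical and the tensors are placeholders, not the real fused Liger computation:

import torch

def _compute_loss_liger_contract_sketch(batch_size: int) -> dict:
    # Placeholder values only; a real implementation would run the model
    # and the fused Liger KTO loss instead of returning zeros.
    zeros = torch.zeros(batch_size)
    return {
        "losses": zeros,                  # per-example losses
        "policy_chosen_logits": zeros,    # logits summary, chosen completions
        "policy_rejected_logits": zeros,  # logits summary, rejected completions
        "policy_chosen_logps": zeros,     # policy log-probs, chosen
        "policy_rejected_logps": zeros,   # policy log-probs, rejected
        "chosen_rewards": zeros,          # rewards for chosen examples
        "rejected_rewards": zeros,        # rewards for rejected examples
        "kl": zeros.mean(),               # scalar KL estimate, consumed by metrics["kl"]
    }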

def get_batch_loss_metrics(
self,
model,
@@ -1205,58 +1223,73 @@ def get_batch_loss_metrics(
metrics = {}
batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

forward_output = self.forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_KL_logps,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

# if reference_logps in batch use them, otherwise use the reference model
if "reference_logps" in batch:
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]

reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
if self.calculate_KL:
reference_KL_logps = batch["reference_KL_logps"]
else:
reference_KL_logps = None
if self.args.use_liger_loss:
model_output = self._compute_loss_liger(model, batch)
losses = model_output["losses"]
policy_chosen_logits = model_output["policy_chosen_logits"]
policy_rejected_logits = model_output["policy_rejected_logits"]
policy_chosen_logps = model_output["policy_chosen_logps"]
policy_rejected_logps = model_output["policy_rejected_logps"]
chosen_rewards = model_output["chosen_rewards"]
rejected_rewards = model_output["rejected_rewards"]
kl = model_output["kl"]
else:
with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
forward_output = self.forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_KL_logps,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

# if reference_logps in batch use them, otherwise use the reference model
if "reference_logps" in batch:
chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True]
rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False]

reference_chosen_logps = batch["reference_logps"][chosen_idx, ...]
reference_rejected_logps = batch["reference_logps"][rejected_idx, ...]
if self.calculate_KL:
reference_KL_logps = batch["reference_KL_logps"]
else:
reference_KL_logps = None
else:
with torch.no_grad():
if self.ref_model is None:
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
) = self.forward(self.model, batch)[:5]
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
) = self.forward(self.model, batch)[:5]
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
reference_KL_logps,
) = self.forward(self.ref_model, batch)[:5]

losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_KL_logps,
)
) = self.forward(self.ref_model, batch)[:5]

losses, chosen_rewards, rejected_rewards, kl = self.kto_loss(
policy_chosen_logps,
policy_rejected_logps,
policy_KL_logps,
reference_chosen_logps,
reference_rejected_logps,
reference_KL_logps,
)
# print("losses: ", losses)
# print("chosen_rewards: ", chosen_rewards)
# print("rejected_rewards: ", rejected_rewards)
# print("kl: ", kl)

metrics["kl"] = kl.item()

num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)

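For context, a hedged usage sketch of the flag this commit wires up. It assumes `use_liger_loss` is (or will be) exposed on KTOConfig, mirroring the `self.args.use_liger_loss` checks above, that liger-kernel is installed (`pip install liger-kernel`), and that `precompute_ref_log_probs` stays False as the new guard requires; the model and dataset names are illustrative.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

training_args = KTOConfig(
    output_dir="Qwen2-0.5B-KTO-liger",
    use_liger_loss=True,             # assumed config name, matching self.args.use_liger_loss above
    precompute_ref_log_probs=False,  # the liger path raises a ValueError if this is True
)

trainer = KTOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()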
