Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Liger] add native liger-kernel orpo loss #2482

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion tests/test_orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
Expand Down Expand Up @@ -146,3 +146,36 @@ def test_orpo_trainer_with_lora(self, config_name):
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
self.assertFalse(torch.equal(param, new_param))

@require_liger_kernel
def test_orpo_trainer_with_liger(self):
"""Test ORPO trainer with Liger loss enabled."""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(
output_dir=tmp_dir,
report_to="none",
use_liger_loss=True, # Enable Liger loss
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = ORPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
16 changes: 16 additions & 0 deletions trl/trainer/orpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class ORPOConfig(TrainingArguments):
string.
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use Liger loss.
base_model_attribute_name (`str`, *optional*, defaults to `"model"`):
Name of the attribute in the model that contains the base model. This is used to get the base model from the
model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`.
"""

learning_rate: float = field(
Expand Down Expand Up @@ -138,3 +143,14 @@ class ORPOConfig(TrainingArguments):
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
use_liger_loss: bool = field(
default=False,
metadata={"help": "Whether to use Liger loss."},
)
base_model_attribute_name: str = field(
default="model",
metadata={
"help": "Name of the attribute in the model that contains the base model. This is used to get the base model "
"from the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`."
},
)
227 changes: 152 additions & 75 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper
Expand All @@ -71,7 +71,6 @@
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
import wandb

Expand All @@ -81,6 +80,9 @@
if is_torch_xla_available():
import torch_xla.core.xla_model as xm

if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss


class ORPOTrainer(Trainer):
r"""
Expand Down Expand Up @@ -358,6 +360,15 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Import Liger loss if enabled
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
Expand Down Expand Up @@ -753,59 +764,112 @@ def concatenated_forward(
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True

outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

# orpo nll target is with respect to the concatenated prompt + completionlabels
if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
# orpo chosen nll loss is computed over the full prompt and response
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
if self.args.use_liger_loss:
if self.is_encoder_decoder:
# 1. Get encoder outputs
encoder_outputs = model.get_encoder()(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
return_dict=True,
)
# 2. Get decoder outputs
outputs = model.get_decoder()(
input_ids=model_kwargs["decoder_input_ids"],
encoder_hidden_states=encoder_outputs.last_hidden_state,
use_cache=False,
)
else:
# skip the lm head and get the last hidden state
if hasattr(model, "get_decoder"):
base_model = model.get_decoder()
else:
base_model = getattr(model, self.args.base_model_attribute_name)
outputs = base_model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
**model_kwargs,
)
lm_head = model.get_output_embeddings()

# return the final loss and aux_outputs tuple
loss, aux_outputs = self.orpo_loss_fn(
lm_head.weight,
outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state,
concatenated_batch["concatenated_labels"][:, 1:]
if not self.is_encoder_decoder
else concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
nll_target=labels[:, 1:] if not self.is_encoder_decoder else labels,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
if self.aux_loss_enabled:
loss += self.aux_loss_coef * outputs.aux_loss

if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1, :]
rejected_logits = all_logits[len_chosen:, :-1, :]
return loss, aux_outputs
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
outputs = model(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
output_hidden_states=False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
if not self.is_encoder_decoder:
chosen_logits = all_logits[:len_chosen, :-1, :]
rejected_logits = all_logits[len_chosen:, :-1, :]
else:
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
outputs.aux_loss,
)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
self,
Expand All @@ -817,46 +881,60 @@ def get_batch_loss_metrics(
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self.args.use_liger_loss:
# full ORPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
log_odds_ratio,
log_odds_chosen,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

reward_accuracies = (chosen_rewards > rejected_rewards).float()

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean()
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean()
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean()
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
chosen_rewards - rejected_rewards
).mean()
metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
metrics[f"{prefix}logits/rejected"] = (
self.accelerator.gather_for_metrics(policy_rejected_logits).detach().mean()
)
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logits).detach().mean()
metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).mean()
metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).mean()
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean()
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
if is_torch_xla_available():
xm.mark_step() # needed because .item() calls
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics

Expand All @@ -868,7 +946,6 @@ def compute_loss(
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

with compute_loss_context_manager:
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

Expand Down
Loading