Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Liger] add native liger-kernel orpo loss #2482

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion tests/test_orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.testing_utils import require_liger_kernel, require_peft

from trl import ORPOConfig, ORPOTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
Expand Down Expand Up @@ -148,3 +148,39 @@ def test_orpo_trainer_with_lora(self, config_name):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_liger_kernel
def test_orpo_trainer_with_liger(self):
"""Test ORPO trainer with Liger loss enabled."""
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = ORPOConfig(
output_dir=tmp_dir,
report_to="none",
use_liger_loss=True, # Enable Liger loss
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = ORPOTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

# Verify Liger is being used
self.assertTrue(trainer._using_liger)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
3 changes: 3 additions & 0 deletions trl/trainer/orpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class ORPOConfig(TrainingArguments):
string.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use Liger loss.
"""

learning_rate: float = 1e-6
Expand All @@ -76,3 +78,4 @@ class ORPOConfig(TrainingArguments):
is_encoder_decoder: Optional[bool] = None
model_init_kwargs: Optional[dict[str, Any]] = None
dataset_num_proc: Optional[int] = None
use_liger_loss: bool = False
167 changes: 113 additions & 54 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from transformers.utils import is_liger_kernel_available, is_peft_available, is_torch_fx_proxy
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
Expand All @@ -68,7 +68,6 @@
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


if is_wandb_available():
import wandb

Expand Down Expand Up @@ -357,6 +356,26 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Import Liger loss if enabled
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
try:
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
self._using_liger = True
except ImportError:
raise ImportError(
"Failed to import LigerFusedLinearORPOLoss from liger-kernel. "
"Please ensure you have the correct liger-kernel version installed."
)
else:
self._using_liger = False

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
Expand Down Expand Up @@ -756,51 +775,74 @@ def concatenated_forward(
concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
use_cache=False,
output_hidden_states=True if self._using_liger else False,
**model_kwargs,
)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
if self._using_liger:
lm_head = model.get_output_embeddings()

# Get the last hidden state from hidden_states tuple
last_hidden_state = outputs.hidden_states[-1]

# return the final loss and aux_outputs tuple
return self.orpo_loss_fn(
lm_head.weight,
last_hidden_state,
concatenated_batch["concatenated_labels"],
lm_head.bias if hasattr(lm_head, "bias") else None,
)
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
all_logits = outputs.logits

def cross_entropy_loss(logits, labels):
if not self.is_encoder_decoder:
# Shift so that tokens < n predict n
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = logits.view(-1, logits.shape[-1])
labels = labels.view(-1)
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
return loss

if self.is_encoder_decoder:
labels = concatenated_batch["concatenated_labels"].clone()
else:
labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
all_logps = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=True,
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]
chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]
chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

if self.aux_loss_enabled:
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)
if self.aux_loss_enabled:
return (
chosen_logps,
rejected_logps,
chosen_logits,
rejected_logits,
chosen_nll_loss,
outputs.aux_loss,
)

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
self,
Expand All @@ -812,21 +854,41 @@ def get_batch_loss_metrics(
metrics = {}

forward_output = self.concatenated_forward(model, batch)
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]
if self._using_liger:
# full ORPO loss and aux outputs
(
loss,
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
chosen_rewards,
rejected_rewards,
log_odds_ratio,
log_odds_chosen,
),
) = forward_output
else:
(
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_nll_loss,
) = forward_output[:5]
if self.aux_loss_enabled:
aux_loss = forward_output[5]

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()

losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
policy_chosen_logps, policy_rejected_logps
)
# full ORPO loss
loss = policy_nll_loss - losses.mean()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

reward_accuracies = (chosen_rewards > rejected_rewards).float()

Expand All @@ -846,8 +908,6 @@ def get_batch_loss_metrics(
xm.mark_step() # needed because .item() calls
for k, v in metrics.items():
metrics[k] = v.item()
if self.aux_loss_enabled:
loss += self.aux_loss_coef * aux_loss

return loss, metrics

Expand All @@ -859,7 +919,6 @@ def compute_loss(
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
compute_loss_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()

with compute_loss_context_manager:
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

Expand Down
Loading