From d62d59b9499c75aa583e7e924353e18acfda8139 Mon Sep 17 00:00:00 2001
From: John Giorgi
Date: Sun, 2 Feb 2025 15:59:16 -0500
Subject: [PATCH 1/4] fix: warn the user if target_modules="all-linear" when
 reward modeling

---
 trl/trainer/reward_trainer.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 063fe5e8e8..6b4b89519d 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -25,6 +25,7 @@
 from accelerate import PartialState
 from accelerate.utils import gather_object
 from datasets import Dataset
+from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 from transformers import (
     BaseImageProcessor,
     DataCollator,
@@ -159,6 +160,19 @@ def __init__(
 
                     model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
 
+                # Warn if the user passes "all-linear" for the target_modules
+                target_modules = (
+                    peft_config.get("target_modules", None)
+                    if isinstance(peft_config, dict)
+                    else peft_config.target_modules
+                )
+                if target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND:
+                    warnings.warn(
+                        f"You passed target_modules='{INCLUDE_LINEAR_LAYERS_SHORTHAND}' in the peft_config."
+                        " This will result in all linear layers except the output layer being adapted. "
+                        " This will negatively impact the performance of the reward model as the newly initialized output layer will not be adapted or trained.",
+                        UserWarning,
+                    )
                 model = get_peft_model(model, peft_config)
 
             # Disable dropout in the model

From a974572b52dd3586136e6d8ef9138045c99b4213 Mon Sep 17 00:00:00 2001
From: John Giorgi
Date: Sun, 2 Feb 2025 15:59:38 -0500
Subject: [PATCH 2/4] tests: add test confirming a UserWarning is raised when
 target_modules="all-linear" in reward modeling

---
 tests/test_reward_trainer.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 51ea1183f7..4321c1392d 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -189,6 +189,29 @@ def test_train_lora_pretokenized(self):
                 new_param = trainer.model.get_parameter(n)
                 self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
 
+    @require_peft
+    def test_all_linear_user_warning(self):
+        peft_config = LoraConfig(
+            task_type=TaskType.SEQ_CLS,
+            inference_mode=False,
+            r=8,
+            lora_alpha=32,
+            lora_dropout=0.1,
+            target_modules="all-linear",
+        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+            training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
+
+            with self.assertWarns(UserWarning):
+                _ = RewardTrainer(
+                    model=self.model,
+                    args=training_args,
+                    processing_class=self.tokenizer,
+                    train_dataset=dummy_dataset,
+                    peft_config=peft_config,
+                )
+
     def test_margin(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             dummy_dataset_dict = {

From e606875f5c1e122e532515097f95c7bf41ba2f8f Mon Sep 17 00:00:00 2001
From: John Giorgi
Date: Sun, 2 Feb 2025 16:13:26 -0500
Subject: [PATCH 3/4] fix: tweak UserWarning

---
 trl/trainer/reward_trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index 6b4b89519d..6069104608 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -168,9 +168,9 @@ def __init__(
                 )
                 if target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND:
                     warnings.warn(
-                        f"You passed target_modules='{INCLUDE_LINEAR_LAYERS_SHORTHAND}' in the peft_config."
-                        " This will result in all linear layers except the output layer being adapted. "
-                        " This will negatively impact the performance of the reward model as the newly initialized output layer will not be adapted or trained.",
+                        f"You passed target_modules='{INCLUDE_LINEAR_LAYERS_SHORTHAND}' in the peft_config. "
+                        "This will result in all linear layers except the output layer being adapted. "
+                        "This could negatively impact the performance of the reward model as the output layer (used for scoring of chosen and rejected completions) will not be adapted or trained.",
                         UserWarning,
                     )
                 model = get_peft_model(model, peft_config)

From 8550984c625ce5cffde5388c43adf099aaa4854e Mon Sep 17 00:00:00 2001
From: John Giorgi
Date: Sun, 2 Feb 2025 16:16:29 -0500
Subject: [PATCH 4/4] tests: make test look for the specific UserWarning

---
 tests/test_reward_trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py
index 4321c1392d..5a858b2f38 100644
--- a/tests/test_reward_trainer.py
+++ b/tests/test_reward_trainer.py
@@ -202,8 +202,7 @@ def test_all_linear_user_warning(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
             training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
-
-            with self.assertWarns(UserWarning):
+            with self.assertWarnsRegex(UserWarning, "You passed target_modules="):
                 _ = RewardTrainer(
                     model=self.model,
                     args=training_args,
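
A note on working around the warning introduced above: PEFT's modules_to_save option keeps named modules fully trainable (and saved) alongside the LoRA adapters, so a user who wants "all-linear" adaptation can still train the scoring head. The sketch below is illustrative and not part of the patch series: it assumes the head is named "score" (the usual name for *ForSequenceClassification models, though it varies by architecture) and uses a placeholder model ID and the test dataset from patch 2. Note the warning added in this series would still fire here, since the check only inspects target_modules.

# Minimal sketch: reward modeling with "all-linear" LoRA while keeping the
# newly initialized scoring head trainable via modules_to_save.
# Assumptions: the head is named "score"; the model ID is illustrative.
from datasets import load_dataset
from peft import LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.config.pad_token_id = tokenizer.pad_token_id

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules="all-linear",  # adapts every linear layer except the output layer
    modules_to_save=["score"],    # train the scoring head directly; name varies by model
)

trainer = RewardTrainer(
    model=model,
    args=RewardConfig(output_dir="reward-model", report_to="none"),
    processing_class=tokenizer,
    train_dataset=load_dataset("trl-internal-testing/zen", "conversational_preference", split="train"),
    peft_config=peft_config,
)
trainer.train()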