From 9b9cf6d7d9b47179e43cdc6cf4a045096eaf13fb Mon Sep 17 00:00:00 2001 From: nemo Date: Fri, 24 Jan 2025 14:21:26 +0100 Subject: [PATCH] Move fix to the right place --- tests/test_torch_compile.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index f3665fad18..b6011831b1 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -176,6 +176,10 @@ def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp "output_dir": tmp_dir, "seed": 0, } + + if isinstance(config, AdaLoraConfig): + train_kwargs["learning_rate"] = 1e-2 + training_args = TrainingArguments( torch_compile=not self.fake_compile, torch_compile_backend=compile_kwargs.get("torch_compile_backend", None), @@ -195,7 +199,6 @@ class OptimizerStepCallback(TrainerCallback): def on_optimizer_step(self, args, state, control, **kwargs): model.update_and_allocate(state.global_step) trainer.add_callback(OptimizerStepCallback()) - train_kwargs["learning_rate"] = 1e-2 trainer.train()