Skip to content

Commit

Permalink
Make style
Browse files Browse the repository at this point in the history
  • Loading branch information
nemo committed Jan 23, 2025
1 parent d1d25f8 commit 9bf10c2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
6 changes: 3 additions & 3 deletions src/peft/tuners/adalora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class AdaLoraConfig(LoraConfig):
The last phase, beginning once `total_step - tfinal` steps are reached, does not change the layer ranks anymore but
fine-tunes the reduced-rank layers that resulted from the previous phase.
A practical example: `tinit` is 10, `tfinal` is 20, `total_step` is 100. We spend 10 steps doing pre-training
without rank reduction because our budget is constant (init phase), then we spend 80 (100-20) steps in the
reduction phase where our budget decreases step-wise and, finally, 20 steps in the final fine-tuning stage without
reduction.
Args:
target_r (`int`): The target average rank of incremental matrix.
Expand Down
40 changes: 23 additions & 17 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,31 +332,37 @@ def test_adalora_config_r_warning(self):
def test_adalora_config_correct_timing_still_works(self):
    # Placeholder body — presumably superseded by the parametrized
    # test_adalora_config_valid_timing_works below; TODO confirm and remove.
    pass

# NOTE(review): the scraped diff contained both the pre- and post-`make style`
# versions of this decorator and of the call line; keeping both would apply
# `parametrize` twice (cartesian duplication) and run the constructor twice.
# Reconstructed as the single, black-formatted version.
@pytest.mark.parametrize(
    "timing_kwargs",
    [
        {"total_step": 100, "tinit": 0, "tfinal": 0},
        {"total_step": 100, "tinit": 10, "tfinal": 10},
        {"total_step": 100, "tinit": 79, "tfinal": 20},
        {"total_step": 100, "tinit": 80, "tfinal": 19},
    ],
)
def test_adalora_config_valid_timing_works(self, timing_kwargs):
    # Make sure that passing correct timing values is not prevented by faulty config checks.
    AdaLoraConfig(**timing_kwargs)  # does not raise

def test_adalora_config_invalid_total_step_raises(self):
    # `total_step=None` must be rejected at config-construction time.
    # Use pytest.raises(..., match=...) instead of `assert msg in str(e)`:
    # since pytest 5, `str(excinfo)` is not guaranteed to be the exception
    # message (pytest recommends `str(e.value)` or the `match` argument),
    # so the original containment check could pass vacuously or break.
    import re

    msg = re.escape("AdaLoRA does not work when `total_step` is None, supply a value > 0.")
    with pytest.raises(ValueError, match=msg):
        AdaLoraConfig(total_step=None)

@pytest.mark.parametrize('timing_kwargs', [
{'total_step': 100, 'tinit': 20, 'tfinal': 80},
{'total_step': 100, 'tinit': 80, 'tfinal': 20},
{'total_step': 10, 'tinit': 20, 'tfinal': 0},
{'total_step': 10, 'tinit': 0, 'tfinal': 10},
{'total_step': 10, 'tinit': 10, 'tfinal': 0},
{'total_step': 10, 'tinit': 20, 'tfinal': 0},
{'total_step': 10, 'tinit': 20, 'tfinal': 20},
{'total_step': 10, 'tinit': 0, 'tfinal': 20},
])
@pytest.mark.parametrize(
"timing_kwargs",
[
{"total_step": 100, "tinit": 20, "tfinal": 80},
{"total_step": 100, "tinit": 80, "tfinal": 20},
{"total_step": 10, "tinit": 20, "tfinal": 0},
{"total_step": 10, "tinit": 0, "tfinal": 10},
{"total_step": 10, "tinit": 10, "tfinal": 0},
{"total_step": 10, "tinit": 20, "tfinal": 0},
{"total_step": 10, "tinit": 20, "tfinal": 20},
{"total_step": 10, "tinit": 0, "tfinal": 20},
],
)
def test_adalora_config_timing_bounds_error(self, timing_kwargs):
# Check if the user supplied timing values that will certainly fail because it breaks
# AdaLoRA assumptions.
Expand Down

0 comments on commit 9bf10c2

Please sign in to comment.