From 5af9cbb9e006781f5585c402ff29d87d9aa65345 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 9 Jan 2025 06:52:23 +0000 Subject: [PATCH] Make yapf happy Signed-off-by: Jee Jee Li --- tests/lora/test_punica_ops_sizes.py | 22 +++++++++++----------- tests/lora/test_punica_ops_variation.py | 24 +++++++++++++----------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/tests/lora/test_punica_ops_sizes.py b/tests/lora/test_punica_ops_sizes.py index 9ec72e48959c5..433ca7577d084 100644 --- a/tests/lora/test_punica_ops_sizes.py +++ b/tests/lora/test_punica_ops_sizes.py @@ -124,7 +124,7 @@ @pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( @@ -198,7 +198,7 @@ def test_punica_sgmv( token_nums, scaling, ) - + else: with _dict_lock: _LORA_B_PTR_DICT.clear() @@ -215,7 +215,7 @@ def test_punica_sgmv( offset_start=0, add_inputs=True, ) - if nslices==1: + if nslices == 1: # Verify the torch's sgmv_expand op sgmv_expand( inputs_tensor[0], @@ -387,14 +387,14 @@ def test_punica_bgmv_expand_nslices( add_inputs=True, ) bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) slice_offset += hidden_size assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_punica_ops_variation.py b/tests/lora/test_punica_ops_variation.py index 6d3d79c2d7b82..2583da3fb6c0c 100644 --- a/tests/lora/test_punica_ops_variation.py +++ b/tests/lora/test_punica_ops_variation.py @@ -32,6 +32,7 @@ _dict_lock = Lock() + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -39,7 +40,7 @@ @pytest.mark.parametrize("scaling", SCALES) @pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", DEVICES) def test_punica_sgmv( @@ -113,7 +114,7 @@ def test_punica_sgmv( token_nums, scaling, ) - + else: with _dict_lock: _LORA_B_PTR_DICT.clear() @@ -131,7 +132,7 @@ def test_punica_sgmv( add_inputs=True, ) slice_offset = 0 - if nslices==1: + if nslices == 1: # Verify the torch's sgmv_expand op sgmv_expand( inputs_tensor[0], @@ -166,6 +167,7 @@ def test_punica_sgmv( assert_close(our_out_tensor, ref_out_tensor) + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -301,14 +303,14 @@ def test_punica_bgmv_expand_nslices( add_inputs=True, ) bgmv_expand_slice( - inputs_tensor, - lora_weights, - ref_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + inputs_tensor, + lora_weights, + ref_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) slice_offset += hidden_size assert_close(our_outputs, ref_outputs)