Skip to content

Commit

Permalink
Make yapf happy
Browse files Browse the repository at this point in the history
Signed-off-by: Jee Jee Li <[email protected]>
  • Loading branch information
jeejeelee committed Jan 9, 2025
1 parent 0d19f03 commit 5af9cbb
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
22 changes: 11 additions & 11 deletions tests/lora/test_punica_ops_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_punica_sgmv(
token_nums,
scaling,
)

else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
Expand All @@ -215,7 +215,7 @@ def test_punica_sgmv(
offset_start=0,
add_inputs=True,
)
if nslices==1:
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
Expand Down Expand Up @@ -387,14 +387,14 @@ def test_punica_bgmv_expand_nslices(
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
inputs_tensor,
lora_weights,
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)

slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)
24 changes: 13 additions & 11 deletions tests/lora/test_punica_ops_variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@

_dict_lock = Lock()


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
Expand Down Expand Up @@ -113,7 +114,7 @@ def test_punica_sgmv(
token_nums,
scaling,
)

else:
with _dict_lock:
_LORA_B_PTR_DICT.clear()
Expand All @@ -131,7 +132,7 @@ def test_punica_sgmv(
add_inputs=True,
)
slice_offset = 0
if nslices==1:
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
inputs_tensor[0],
Expand Down Expand Up @@ -166,6 +167,7 @@ def test_punica_sgmv(

assert_close(our_out_tensor, ref_out_tensor)


@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
Expand Down Expand Up @@ -301,14 +303,14 @@ def test_punica_bgmv_expand_nslices(
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
inputs_tensor,
lora_weights,
ref_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)

slice_offset += hidden_size
assert_close(our_outputs, ref_outputs)

0 comments on commit 5af9cbb

Please sign in to comment.