test (llm/sdpa): Added basic tests for SDPA
nickfraser committed Nov 20, 2024
1 parent f32115d commit b3b61ab
Showing 1 changed file with 18 additions and 3 deletions.
tests/brevitas_examples/test_llm.py: 21 changes (18 additions & 3 deletions)
@@ -240,14 +240,22 @@ def test_small_models_acc(caplog, acc_args_and_acc):

 @pytest_cases.fixture(
     ids=[
-        "opt-replace-mha",],
+        "opt-replace-mha",
+        "opt-quant-sdpa",],
     params=[
         {
             "model": "hf-internal-testing/tiny-random-OPTForCausalLM",  # Requires PT>=2.4 to run
             "weight_equalization": True,
             "ln_affine_merge": True,
             "replace_mha": True,
             "float_ppl": 50016.0,
             "quant_ppl": 50016.0},
+        {
+            "model": "hf-internal-testing/tiny-random-OPTForCausalLM",  # Requires PT>=2.4 to run
+            "weight_equalization": True,
+            "ln_affine_merge": True,
+            "quant_sdpa": True,
+            "float_ppl": 50016.0,
+            "quant_ppl": 50016.0},])
 def acc_args_and_acc_pt_ge_2_4(default_run_args, request):
     args = default_run_args
@@ -426,7 +434,8 @@ def test_small_models_quant_layer(caplog, layer_args):

 @pytest_cases.fixture(
     ids=[
-        "opt-replace-mha",],
+        "opt-replace-mha",
+        "opt-quant-sdpa",],
     params=[
         {
             "model": "hf-internal-testing/tiny-random-OPTForCausalLM",  # Requires PT>=2.4 to run
@@ -435,7 +444,13 @@ def test_small_models_quant_layer(caplog, layer_args):
                 "model.decoder.layers.0.self_attn":
                     "<class 'brevitas_examples.llm.llm_quant.mha_layers.QuantizableOPTAttention'>",
                 "model.decoder.layers.0.self_attn.mha":
-                    "<class 'brevitas.nn.quant_mha.QuantMultiheadAttention'>",}},])
+                    "<class 'brevitas.nn.quant_mha.QuantMultiheadAttention'>",}},
+        {
+            "model": "hf-internal-testing/tiny-random-OPTForCausalLM",  # Requires PT>=2.4 to run
+            "quant_sdpa": True,
+            "exp_layer_types": {
+                "scaled_dot_product_attention":
+                    "<class 'brevitas.nn.quant_sdpa.QuantScaledDotProductAttention'>",}},])
 def layer_args_pt_ge_2_4(default_run_args, request):
     args = default_run_args
     layer_dict = request.param
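The new "opt-quant-sdpa" cases pin the expected type of the SDPA op to brevitas.nn.quant_sdpa.QuantScaledDotProductAttention, i.e. the quant_sdpa option quantizes scaled-dot-product attention directly rather than swapping in the quantizable MHA wrapper that replace_mha uses. Below is a minimal, hypothetical sketch (not part of this commit) of how an exp_layer_types mapping like the ones above could be checked against a model; the helper name assert_layer_types and the suffix-matching rule are assumptions for illustration, and the actual assertion logic lives in tests/brevitas_examples/test_llm.py.

```python
# Hypothetical helper (not from the commit) sketching how an `exp_layer_types`
# mapping could be verified: each key should resolve to a module whose type
# string equals the expected value, e.g. "scaled_dot_product_attention" ->
# "<class 'brevitas.nn.quant_sdpa.QuantScaledDotProductAttention'>".
from typing import Dict

import torch.nn as nn


def assert_layer_types(model: nn.Module, exp_layer_types: Dict[str, str]) -> None:
    named_modules = dict(model.named_modules())
    for key, expected in exp_layer_types.items():
        # Match on name suffix so a short key can address a nested submodule.
        matches = [m for name, m in named_modules.items() if name.endswith(key)]
        assert matches, f"no module matching '{key}' found"
        actual = str(type(matches[0]))
        assert actual == expected, f"{key}: expected {expected}, got {actual}"
```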
