Skip to content

Commit

Permalink
#9370: removed ndpcc work around and debug code in sdpa decode and re-enabled CI
Browse files Browse the repository at this point in the history
  • Loading branch information
caixunshiren committed Sep 30, 2024
1 parent a195aa7 commit 76b6df1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 219 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ def run_test_sdpa_decode_single_iter(


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype, q_dtype",
[
Expand Down Expand Up @@ -482,7 +481,6 @@ def test_sdpa_decode(


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype, q_dtype",
[
Expand Down Expand Up @@ -704,7 +702,6 @@ def to_contiguous_cache(paged_cache, batch, num_kv, max_num_blocks_per_seq, bloc

@skip_for_blackhole("Unsupported on BH, see #12349")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
# @pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"kv_dtype, q_dtype",
[
Expand Down Expand Up @@ -756,7 +753,6 @@ def test_sdpa_decode_paged_attention(


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype, q_dtype",
[
Expand Down Expand Up @@ -840,7 +836,6 @@ def test_sdpa_decode_perf(device, use_program_cache):


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
Expand Down Expand Up @@ -962,8 +957,8 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty
)
dram_memcfg = ttnn.DRAM_MEMORY_CONFIG

K = fa_rand(nkv, b, s, d)
V = fa_rand(nkv, b, s, d)
K = fa_rand(b, nkv, s, d)
V = fa_rand(b, nkv, s, d)

tt_K = ttnn.as_tensor(K, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg)
tt_V = ttnn.as_tensor(V, device=device, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg)
Expand Down Expand Up @@ -998,8 +993,8 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty
attn_mask[:, :, :, start_idx + 1 :] = torch.finfo(torch.float32).min

Q_slice = Q[:, :, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, d
K_slice = K[:, :, :padded_layer_len, :].permute(1, 0, 2, 3) # nh, b, S, d
V_slice = V[:, :, :padded_layer_len, :].permute(1, 0, 2, 3) # nh, b, S, d
K_slice = K[:, :, :padded_layer_len, :]
V_slice = V[:, :, :padded_layer_len, :]
attn_mask_slice = attn_mask[:, :, :nh, :].permute(1, 2, 0, 3) # b, nh, 1, S

expect = torch.nn.functional.scaled_dot_product_attention(
Expand All @@ -1009,7 +1004,7 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty

all_out_pass = True

for i in range(200):
for i in range(500):
tt_Q = ttnn.as_tensor(
Q[:, :, :nh],
device=device,
Expand Down Expand Up @@ -1049,9 +1044,9 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty
if not all_out_pass:
failed_start_pos.append(start_idx)

start_idx += 20 # if start_idx < 4096 else 3001
start_idx += 200 # if start_idx < 4096 else 3001

logger.info(f"ND Start Pos: {failed_start_pos}")
logger.info(f"PCC failed Start Pos: {failed_start_pos}")


@pytest.mark.timeout(600)
Expand All @@ -1060,13 +1055,13 @@ def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dty
@pytest.mark.parametrize(
"dtype, q_dtype",
[
# [ttnn.bfloat16, ttnn.bfloat16],
# [ttnn.bfloat8_b, ttnn.bfloat8_b],
[ttnn.bfloat16, ttnn.bfloat16],
[ttnn.bfloat8_b, ttnn.bfloat8_b],
[ttnn.bfloat4_b, ttnn.bfloat4_b],
],
ids=[
# "bfp16_bfp16",
# "bfp8_bfp8",
"bfp16_bfp16",
"bfp8_bfp8",
"bfp4_bfp4",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ def run_test_sdpa_decode_single_iter(


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype, q_dtype",
[
Expand Down Expand Up @@ -287,7 +286,6 @@ def test_sdpa_decode(


@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
@pytest.mark.skip("Skipping due to potential nd pcc issue #9370")
@pytest.mark.parametrize(
"dtype",
[
Expand Down
Loading

0 comments on commit 76b6df1

Please sign in to comment.