Skip to content

Commit

Permalink
Add back checks that check expected collectives in HLO.
Browse files Browse the repository at this point in the history
Signed-off-by: Michael Goldfarb <[email protected]>
  • Loading branch information
mgoldfarb-nvidia committed Jan 13, 2025
1 parent ffc8268 commit fe2b02b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
13 changes: 11 additions & 2 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def generate_collectives_count_ref(
self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
_, seqlen, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
tp_size = 1
if mesh_resource.tp_resource is not None:
Expand Down Expand Up @@ -124,7 +124,15 @@ def test_self_attn(
mesh_resource=mesh_resource,
)

runner.test_backward()
col_ref = self.generate_collectives_count_ref(
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type != AttnBiasType.NO_BIAS,
data_shape,
dtype,
)
runner.test_backward(col_ref)


class TestDistributedCrossAttn:
Expand Down Expand Up @@ -186,6 +194,7 @@ def test_cross_attn(
mesh_resource=mesh_resource,
)

col_ref = self.generate_collectives_count_ref()
runner.test_backward()


Expand Down
18 changes: 16 additions & 2 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
get_cudnn_version,
)

from distributed_test_base import assert_equal_collectives
from utils import assert_allclose


Expand Down Expand Up @@ -680,9 +681,18 @@ def test_forward(self):
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

def test_backward(self):
if coll_count_ref is not None:
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
assert_equal_collectives(target_hlo, coll_count_ref)

def test_backward(self, coll_count_ref=None):
"""
Test value_and_grad with JIT, which includes both forward and backward
Test value_and_grad with JIT, which includes both forward and backward.
If coll_count_ref is not None then the HLO of the backward function
will be examined for the expected comms.
"""

self._setup_inputs()
Expand Down Expand Up @@ -804,6 +814,10 @@ def check_dqkv(primitive, reference, pad):
dtype=self.dtype,
)

if coll_count_ref is not None:
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, coll_count_ref)


@pytest.mark.parametrize(
"attn_mask_type",
Expand Down

0 comments on commit fe2b02b

Please sign in to comment.