Skip to content

Commit

Permalink
Fix various bugs in the test harness.
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoldfarb-nvidia committed Jan 14, 2025
1 parent 7eb6428 commit 3e3939c
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 67 deletions.
23 changes: 8 additions & 15 deletions tests/jax/distributed_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,30 @@

def generate_configs():
configs = []
mr = MeshResource(dp_resource="dp", tp_resource="tp")
axes = ("dp", "tp")
if is_devices_enough(2):
configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
configs.append(pytest.param(2, (2, 1), axes, mr, id="n2_dp2_tp1"))
configs.append(pytest.param(2, (1, 2), axes, mr, id="n2_dp1_tp2"))

if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
[4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
)
configs.append(pytest.param(4, (2, 2), axes, mr, id=f"n4_dp2_tp2"))

return configs


def generate_context_parallel_configs():
configs = []

mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp")
axes = ("dp", "cp", "tp")
DP_sizes = (1, 2)
CP_sizes = (1, 2, 4, 8)
TP_sizes = (1, 2)
for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes):
ndev = cp * tp * dp
if is_devices_enough(ndev):
configs.append(
pytest.param(
ndev,
(dp, cp, tp),
("dp", "cp", "tp"),
MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"),
id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}",
)
pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}")
)

return configs
Expand Down
62 changes: 34 additions & 28 deletions tests/jax/test_distributed_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,19 @@
)
from transformer_engine.jax.sharding import MeshResource

from test_fused_attn import FusedAttnRunner, general_dot_product_attention, make_mask
from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask

DTYPES = [jnp.float16, jnp.bfloat16]


class TestDistributedSelfAttn:

def generate_collectives_count_ref(
self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
):
def generate_collectives_count_ref(self, mesh_shape, attn_bias_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
tp_size = 1
if mesh_resource.tp_resource is not None:
idx = mesh_axes.index(mesh_resource.tp_resource)
tp_size = mesh_shape[idx]
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
dp_size, tp_size = mesh_shape[:2]
is_dp_enabled = dp_size > 1

all_reduce_loss_bytes = 4 # 1 * FP32
bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
Expand All @@ -62,13 +58,27 @@ def generate_collectives_count_ref(
return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
@pytest.mark.parametrize("data_shape", [[32, 512, 12, 64], [32, 1024, 16, 128]])
@pytest.mark.parametrize(
"attn_bias_type",
[AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
"data_shape",
[
pytest.param((32, 512, 12, 64), id="32-512-12-64"),
pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
],
)
@pytest.mark.parametrize(
"attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
],
)
@pytest.mark.parametrize("dtype", DTYPES)
def test_self_attn(
Expand All @@ -79,6 +89,7 @@ def test_self_attn(
mesh_resource,
data_shape,
attn_bias_type,
bias_shape,
attn_mask_type,
dtype,
):
Expand All @@ -103,6 +114,7 @@ def test_self_attn(
):
pytest.skip(f"No FusedAttn backend found")

col_ref = self.generate_collectives_count_ref(mesh_shape, attn_bias_type, data_shape, dtype)
runner = FusedAttnRunner(
batch,
seqlen,
Expand All @@ -116,23 +128,15 @@ def test_self_attn(
dtype,
is_training,
QKVLayout.BS3HD,
None,
bias_shape,
None,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
coll_count_ref=col_ref,
)

col_ref = self.generate_collectives_count_ref(
mesh_shape,
mesh_axes,
mesh_resource,
attn_bias_type != AttnBiasType.NO_BIAS,
data_shape,
dtype,
)
runner.test_backward(col_ref)
runner.test_backward()


class TestDistributedCrossAttn:
Expand All @@ -152,6 +156,7 @@ def test_cross_attn(
self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
dropout_prob = 0.0
is_training = True

Expand All @@ -173,6 +178,7 @@ def test_cross_attn(
):
pytest.skip(f"No FusedAttn backend found")

col_ref = self.generate_collectives_count_ref()
runner = FusedAttnRunner(
batch,
seqlen,
Expand All @@ -186,15 +192,14 @@ def test_cross_attn(
dtype,
is_training,
QKVLayout.BSHD_BS2HD,
None,
bias_shape,
None,
number_of_devices=device_count,
mesh_shape=mesh_shape,
mesh_axes=mesh_axes,
mesh_resource=mesh_resource,
coll_count_ref=col_ref,
)

col_ref = self.generate_collectives_count_ref()
runner.test_backward()


Expand Down Expand Up @@ -246,6 +251,7 @@ def impl_test_context_parallel_attn(
cp_strategy,
):
attn_bias_type = AttnBiasType.NO_BIAS
bias_shape = None
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
Expand Down Expand Up @@ -274,7 +280,7 @@ def impl_test_context_parallel_attn(
dtype,
is_training,
qkv_layout,
None,
bias_shape,
None,
number_of_devices=device_count,
mesh_shape=mesh_shape,
Expand Down
89 changes: 65 additions & 24 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from functools import partial
from math import sqrt
from typing import Tuple, Optional
from typing import Tuple, Optional, Dict
import random

import jax
Expand Down Expand Up @@ -43,7 +43,7 @@
)

from distributed_test_base import assert_equal_collectives
from utils import assert_allclose
from utils import assert_allclose, print_debug_tensor_stats


@pytest.fixture(autouse=True, scope="module")
Expand Down Expand Up @@ -321,6 +321,9 @@ class FusedAttnRunner:
cp_strategy: CPStrategy = CPStrategy.DEFAULT
cp_load_balanced: bool = True

# dictionary of expected collective comm bytes
coll_count_ref: Optional[Dict[str, int]] = None

# See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
Expand Down Expand Up @@ -387,7 +390,7 @@ def _setup_inputs(self):
self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1)
self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1)

key = jax.random.PRNGKey(0)
key = jax.random.PRNGKey(1124)
q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)

q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
Expand Down Expand Up @@ -644,10 +647,12 @@ def test_forward(self):
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
}

customcall_fused_dpa_jit = jit(
customcall_fused_dpa,
partial(customcall_fused_dpa, **kwargs),
static_argnames=kwargs.keys(),
in_shardings=[
self.qkvo_sharding,
Expand All @@ -659,15 +664,13 @@ def test_forward(self):
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.dropout_rng_sharding,
],
)

with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa_jit(*customcall_args, **kwargs)
primative_out = self.cp_inverse_reorder_fn(primative_out)
primitive_out = customcall_fused_dpa_jit(*customcall_args)
primitive_out = self.cp_inverse_reorder_fn(primitive_out)

reference_out = jax_dpa(*args, **kwargs)

Expand All @@ -681,13 +684,14 @@ def test_forward(self):
assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)

if coll_count_ref is not None:
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
assert_equal_collectives(target_hlo, coll_count_ref)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
target_hlo = (
customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text()
)
assert_equal_collectives(target_hlo, self.coll_count_ref)

def test_backward(self, coll_count_ref=None):
def test_backward(self):
"""
Test value_and_grad with JIT, which includes both forward and backward.
Expand All @@ -706,7 +710,9 @@ def grad_func(func, *args, **kwargs):
ret_valid = jnp.where(
self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs)
)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype)
return (
jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
).astype(self.dtype)

args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
customcall_args = [
Expand All @@ -732,10 +738,22 @@ def grad_func(func, *args, **kwargs):
"qkv_layout": self.qkv_layout,
"max_segments_per_seq": self._get_max_segments_per_sequence(),
"window_size": self.window_size,
"context_parallel_strategy": self.cp_strategy,
"context_parallel_causal_load_balanced": self.cp_load_balanced,
}

# We can compute dBias only for the [1, h, s, s] layout
arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2)
if self.bias_shape == BiasShape._1HSS:
arg_nums = (0, 1, 2, 3)
grad_shardings = (
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
)
else:
arg_nums = (0, 1, 2)
grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit(
Expand All @@ -744,7 +762,20 @@ def grad_func(func, *args, **kwargs):
customcall_fused_dpa, q, k, v, bias, *args, **kwargs
),
arg_nums,
)
),
in_shardings=(
self.qkvo_sharding,
self.qkvo_sharding,
self.qkvo_sharding,
self.bias_sharding,
self.mask_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.seq_length_offset_sharding,
self.dropout_rng_sharding,
),
out_shardings=(None, grad_shardings),
)
jitted_reference = jit(
value_and_grad(
Expand All @@ -762,13 +793,22 @@ def grad_func(func, *args, **kwargs):
if self.dropout_prob > 0.0:
return

print_debug_tensor_stats(f"primitive_out", primitive_out)
print_debug_tensor_stats(f"reference_grad_valid", reference_out)
print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out))
assert_allclose(primitive_out, reference_out, dtype=self.dtype)

def check_dqkv(primitive, reference, pad):
def check_dqkv(primitive, reference, pad, idx):
primitive_valid, primitive_invalid, reference_valid, reference_invalid = (
_split_valid_and_invalid(primitive, reference, pad)
)

print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx])
print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx])
print_debug_tensor_stats(
f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx])
)

assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype)
assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype)
assert_allclose(primitive_valid, reference_valid, dtype=self.dtype)
Expand All @@ -780,9 +820,9 @@ def check_dqkv(primitive, reference, pad):
primitive_dk = self.cp_inverse_reorder_fn(primitive_dk)
primitive_dv = self.cp_inverse_reorder_fn(primitive_dv)

check_dqkv(primitive_dq, reference_dq, self.pad_q)
check_dqkv(primitive_dk, reference_dk, self.pad_kv)
check_dqkv(primitive_dv, reference_dv, self.pad_kv)
check_dqkv(primitive_dq, reference_dq, self.pad_q, 0)
check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1)
check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2)

if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS:
# TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation.
Expand Down Expand Up @@ -814,9 +854,10 @@ def check_dqkv(primitive, reference, pad):
dtype=self.dtype,
)

if coll_count_ref is not None:
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, coll_count_ref)
if self.coll_count_ref is not None:
with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text()
assert_equal_collectives(target_hlo, self.coll_count_ref)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 3e3939c

Please sign in to comment.