diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py
index f1e3c58c4a..b246999d8a 100644
--- a/tests/jax/test_distributed_gemm.py
+++ b/tests/jax/test_distributed_gemm.py
@@ -18,7 +18,7 @@ from utils import assert_allclose
 
-jax.config.update('jax_enable_compilation_cache', False)
+jax.config.update("jax_enable_compilation_cache", False)
 
 
 # AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P)
@@ -48,20 +48,18 @@ def _get_mesh(parallel_dist):
     batched = False
     fsdp = False
     mesh_shape = dict(tp=NUM_DEVICES)
-    resources = dict(cp_resource='tp', tp_resource='tp')
+    resources = dict(cp_resource="tp", tp_resource="tp")
     if parallel_dist in ["DP_TP", "FSDP_TP"]:
         batched = True
-        mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=NUM_DEVICES//2))
-        resources.update(dict(dp_resource='dp'))
+        mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=NUM_DEVICES // 2))
+        resources.update(dict(dp_resource="dp"))
         if parallel_dist == "FSDP_TP":
             fsdp = True
-            mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=1, zp=NUM_DEVICES//2))
-            resources.update(dict(fsdp_resource='zp'))
+            mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2))
+            resources.update(dict(fsdp_resource="zp"))
     mesh_resource = te.MeshResource(**resources)
-    devices = mesh_utils.create_device_mesh(
-        (NUM_DEVICES, ), devices=jax.devices()[:NUM_DEVICES]
-    )
+    devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES])
     mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))
@@ -73,9 +71,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw
 
     # Operand and output shapes
     lhs_shape = (
-        [SEQ_LEN, HIDDEN_SIZE]
-        if fwd_comm_type == "ALL_GATHER"
-        else [SEQ_LEN, FFN_HIDDEN_SIZE]
+        [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE]
     )
     rhs_shape = (
         [HIDDEN_SIZE, FFN_HIDDEN_SIZE]
@@ -125,12 +121,12 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw
     sigma = 0.023
     shapes = (lhs_shape, rhs_shape)
     if fwd_bwd:
-        shapes += (out_shape, )
+        shapes += (out_shape,)
     global_operands = list(
         map(
             lambda key, shape: jax.device_put(
                 mu + (sigma * jax.random.normal(key, shape, dtype=dtype)),
-                NamedSharding(mesh, PartitionSpec(None))
+                NamedSharding(mesh, PartitionSpec(None)),
             ),
             split_keys,
             shapes,
@@ -140,7 +136,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw
     # Allocate sharded operands on device
     partition_axes = (lhs_spec, rhs_spec)
    if fwd_bwd:
-        partition_axes += (out_spec, )
+        partition_axes += (out_spec,)
     local_operands = list(
         map(
             lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))),
@@ -245,9 +241,7 @@ def test_gemm_impl(comm_type, mesh_type):
         global_operands,
         output_info,
         fsdp_gathered_rhs_spec,
-    ) = _get_inputs(
-        mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp
-    )
+    ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp)
 
     @jax.jit
     def _test_fn(lhs, rhs):
@@ -272,9 +266,7 @@ def test_gemm_fwd_bwd(comm_type, mesh_type):
         global_operands,
         output_info,
         fsdp_gathered_rhs_spec,
-    ) = _get_inputs(
-        mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True
-    )
+    ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True)
 
     @jax.jit
     def _test_fn(lhs, rhs, grad):
@@ -308,4 +300,3 @@ def _test_fn(lhs, rhs, grad):
     )
 
     _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True)
-
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py
index 30ff0ca54a..250e8e0c29 100644
--- a/transformer_engine/jax/cpp_extensions/gemm.py
+++ b/transformer_engine/jax/cpp_extensions/gemm.py
@@ -525,9 +525,9 @@ def infer_sharding_from_operands(
         rhs_spec_new = list(rhs_spec).copy()
         lhs_spec_new[lhs_outer_dim] = None
        if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None:
-            assert lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim], (
-                "Contracting dimensions of LHS and RHS operands must have the same sharding."
-            )
+            assert (
+                lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim]
+            ), "Contracting dimensions of LHS and RHS operands must have the same sharding."
         if lhs_spec_new[lhs_outer_dim] is not None:
             warnings.warn(
                 "Outer dimension of the LHS operand must be all-gathered when both contracting "