From ad6bf2ae16ed7d7c54ecf725c649d19e6eab5017 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 24 Oct 2024 22:38:05 +0000 Subject: [PATCH 01/19] added XLA custom op defs for TE GEMM Signed-off-by: Alp Dener Added XLA FFI custom op for TE GEMM Signed-off-by: Alp Dener finished GEMM custom op primitive and serial unit test Signed-off-by: Alp Dener fixed GEMM custom op batcher Signed-off-by: Alp Dener fixed output dtype error and contracting dimensions options Signed-off-by: Alp Dener AG overlap working but executes scatter to match outer LHS dim Signed-off-by: Alp Dener both all-gather and all-reduce are now working Signed-off-by: Alp Dener code style Signed-off-by: Alp Dener changed kwargs in abstract to be explicit Signed-off-by: Alp Dener added fwd/bwd implementation for non-fp8 gemm Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 55 ++ .../jax/cpp_extensions/__init__.py | 1 + transformer_engine/jax/cpp_extensions/gemm.py | 647 ++++++++++++++++++ transformer_engine/jax/cpp_extensions/misc.py | 7 + transformer_engine/jax/csrc/extensions.h | 39 ++ .../jax/csrc/extensions/gemm.cpp | 170 +++++ .../jax/csrc/extensions/packing.cpp | 11 + .../jax/csrc/extensions/pybind.cpp | 5 +- transformer_engine/jax/csrc/utils.h | 2 +- transformer_engine/jax/flax/module.py | 7 +- transformer_engine/jax/fp8.py | 7 +- transformer_engine/jax/gemm.py | 425 ++++++++++++ 12 files changed, 1370 insertions(+), 6 deletions(-) create mode 100644 transformer_engine/jax/cpp_extensions/gemm.py create mode 100644 transformer_engine/jax/csrc/extensions/gemm.cpp create mode 100644 transformer_engine/jax/gemm.py diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20b16c2809..9bf3f9fa91 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -25,6 +25,7 @@ _jax_dbias_cast_transpose, ) from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 +from transformer_engine.jax.gemm import fp8_gemm, gemm from transformer_engine.jax import cpp_extensions as tex @@ -415,6 +416,60 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_ ) +class TestGemm: + + @staticmethod + def _generate_inputs(b, m, n, k, dtype): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 3) + a = jax.random.normal(subkeys[0], (b, m, k), dtype) + b = jax.random.normal(subkeys[1], (n, k), dtype) + bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 + bias = jax.random.normal(subkeys[2], (n, ), bias_dtype) + return a, b, bias + + @staticmethod + def _generate_fp8_inputs(b, m, n, k, fp8_dtype): + a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) + a_scale, b_scale = map( + lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32), + [a, b] + ) + a_q, b_q = map( + lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), + [(a, a_scale), (b, b_scale)] + ) + return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias + + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("use_bias", (False, True)) + @pytest.mark.parametrize("do_gelu", (False, True)) + def test_gemm(self, b, m, n, k, use_bias, do_gelu): + a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) + + primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu) + ref_out = jnp.dot(a, b) + if use_bias: + ref_out += bias + if do_gelu: + ref_out = jax.nn.gelu(ref_out) + + assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) + def test_fp8_gemm(self, m, n, k, fp8_dtype): + a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs( + m, n, k, fp8_dtype + ) + + primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) + ref_out = jnp.dot(a, b) + + assert_allclose(primitive_out, ref_out, dtype=fp8_dtype) + + @pytest.fixture(name="random_inputs") def random_inputs_fixture(shape): key = jax.random.PRNGKey(0) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 579daa8e41..1e5cc4c07e 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -4,6 +4,7 @@ """Python interface for c++ extensions""" from .activation import * from .attention import * +from .gemm import * from .normalization import * from .quantization import * from .softmax import * diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py new file mode 100644 index 0000000000..677fabca59 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -0,0 +1,647 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for cuBlasLt GEMM""" +import warnings +import operator +from functools import reduce +from typing import Optional, Union, Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes +from jax.interpreters import mlir +from jax.interpreters.mlir import ir +from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi +from jax.typing import ArrayLike + +from transformer_engine import transformer_engine_jax as tex +from .base import BasePrimitive, register_primitive +from .custom_call import custom_caller, CustomCallArgsWrapper +from .misc import ( + jax_dtype_to_te_dtype, + jax_dtype_is_fp8, + get_padded_spec, + is_ffi_enabled, +) +from ..sharding import ( + global_mesh_resource, + get_mesh_axis_size, + lax_paral_op, + all_reduce_max_along_all_axes_except_PP, +) + + +__all__ = [ + "fp8_gemm_impl", + "gemm_impl", +] + + +def get_cublas_workspace_size_bytes() -> None: + """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" + if tex.get_device_compute_capability() >= 90: + return 33_554_432 + return 4_194_304 + + +class CollectiveGemmPrimitive(BasePrimitive): + """ + cuBlasLt GEMM Primitive w/ support for distributed inputs + """ + + name = "te_gemm" + impl_static_args = (8, 9, 10, 11, 12, 13, 14) + multiple_results = True + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_aval, + gelu_input_aval, out_amax_aval, out_scale_aval, out_dtype, contracting_dims, + fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): + """ + cuBlasLt GEMM abstract + """ + del grad, accumulate, use_split_accumulator + + # Validate operand dtypes + lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) + rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) + assert lhs_dtype == rhs_dtype, "Mismatched matrix dtypes for GEMM." + is_fp8 = False + if jax_dtype_is_fp8(lhs_dtype): + assert ( + lhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(lhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing LHS operand scale inverse in FP8 GEMM." + is_fp8 = True + if jax_dtype_is_fp8(rhs_dtype): + assert ( + rhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing RHS operand scale inverse in FP8 GEMM." + + # Disallow batching for RHS + assert rhs_aval.ndim == 2, "GEMM does not support batching the RHS operand." + + # Validate operand layouts + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim) + ) + assert ( + lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] + ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." + + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + assert ( + not (lhs_trans and rhs_trans) + ), "GEMM does not support transposed LHS and transposed RHS at the same time." + if is_fp8: + assert lhs_trans, "FP8 GEMM does not support transposed LHS." + assert rhs_trans, "FP8 GEMM requires transposed RHS." + + # Validate output dtype + if jax_dtype_is_fp8(out_dtype): + assert ( + jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype) + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." + out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." + else: + out_dtype = lhs_dtype + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Infer output shape + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + lhs_bdims = [dim for dim in range(lhs_aval.ndim) + if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + + # Validate bias/bias_grad shape against inferred output + bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype + if fuse_bias: + assert ( + bias_aval.size > 0 + and bias_aval.ndim == 1 + and bias_aval.shape[0] == out_shape[-1] + ), "Incorrect bias shape." + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + else: + assert bias_aval.size == 0, "Internal TE error." + + # Validate GELU input/output + if fuse_gelu: + assert ( + all([gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)]) + ), "Invalid GELU input shape." + assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." + else: + assert gelu_input_aval.size == 0, "Internal TE error." + + # Create abstract arrays for all outputs + out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) + out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, + dtype=out_amax_updated_dtype) + out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, + dtype=out_scale_updated_dtype) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) + bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + workspace_aval = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), + dtype=jnp.uint8) + + return ( + out_aval, + out_amax_updated_aval, + out_scale_updated_aval, + pre_gelu_out_aval, + bias_grad_aval, + workspace_aval + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + cuBlasLt GEMM outer abstract + """ + ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + _ + ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval + + @staticmethod + def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, + *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator): + """ + Fused attention fwd lowering rules + """ + lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim) + ) + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + + operand_output_aliases = { + 4: 4, # bias <--> bias_grad + 5: 3, # gelu_input <--> pre_gelu_out + 6: 1, # out_amax <--> out_amax_updated + 7: 2, # out_scale <--> out_scale_updated + } + + if is_ffi_enabled(): + name = "te_gemm_ffi" + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + lhs_trans=lhs_trans, + rhs_trans=rhs_trans, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + else: + operands = [ + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + ] + operand_shapes = map(lambda x: ir.RankedTensorType(x.type).shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_dtype(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + lhs_bdims = [dim for dim in range(lhs_aval.ndim) + if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim] + k = rhs_aval.shape[rhs_inner_dim] + n = rhs_aval.shape[rhs_outer_dim] + workspace_size = get_cublas_workspace_size_bytes() + operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) + bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) + opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype, + jax_dtype_to_te_dtype(out_dtype), bias_dtype, + lhs_trans, rhs_trans, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator) + + return custom_caller( + CollectiveGemmPrimitive.name, + args, + opaque, + has_side_effect=False, + operand_output_aliases=operand_output_aliases, + ) + + @staticmethod + def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, + out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator): + assert CollectiveGemmPrimitive.inner_primitive is not None + + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + _, + ) = CollectiveGemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator): + assert CollectiveGemmPrimitive.outer_primitive is not None + + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args + assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands." + + # Get contracting and batch dimensions out + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + + # FP8 GEMM only supports lhs_trans = False and rhs_trans = True so we may need to + # reorder the axes here to match + if jax_dtype_is_fp8(lhs.dtype): + lhs = jnp.transpose(lhs, (*lhs_bdims, lhs_outer_dim, lhs_inner_dim)) + lhs_trans = False + rhs = jnp.transpose(rhs, (rhs_outer_dim, rhs_inner_dim)) + rhs_trans = True + contracting_dims = (1, 1) + + # Collapse all non-contracting dimensions + batch_shape = [lhs.shape[dim] for dim in lhs_bdims] + batch_size = reduce(operator.mul, batch_shape, 1) + lhs_outer_size = lhs.shape[lhs_outer_dim] + lhs_shape_2d = ( + (lhs.shape[lhs_inner_dim], batch_size * lhs_outer_size) + if lhs_trans + else (batch_size * lhs_outer_size, lhs.shape[lhs_inner_dim]) + ) + lhs = jnp.reshape(lhs, lhs_shape_2d) + if fuse_gelu: + gelu_input = jnp.reshape( + gelu_input, (batch_size * lhs_outer_size, rhs.shape[rhs_outer_dim]) + ) + + outputs = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Reshape output to recover original LHS batch shape + outputs[0] = jnp.reshape( + outputs[0], + (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim]) + ) + gelu_bdims = batch_dims[3] + if fuse_gelu: + outputs[3] = jnp.reshape(outputs[3], outputs[0].shape) + gelu_bdims = lhs_bdims + + return ( + outputs, + (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]) + ) + + @staticmethod + def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator, mesh, arg_infos, + result_infos): + del out_dtype, accumulate, use_split_accumulator, result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: + warnings.warn("Forcing the inner dimension of LHS to match the sharding of inner " + + "dimension of RHS. This can trigger additional communication if LHS is " + + "not already partitioned correctly.") + + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] + rhs_outer_spec = rhs_spec[rhs_outer_dim] + + if rhs_spec[rhs_inner_dim] is not None and rhs_outer_spec is not None: + raise RuntimeError("Both inner and outer dimensions of RHS cannot be sharded.") + + # Outer (sequence) dimension of the GEMM output is always unsharded + out_spec = [*batch_specs, None, rhs_outer_spec] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded + gelu_spec = out_spec if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Bias gradient spec matches outer dimension of output if bias fusion is turned on + bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) + + @staticmethod + def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator, mesh, arg_infos, result_infos): + del result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] + rhs_outer_spec = rhs_spec[rhs_outer_dim] + + # Force all-gather the outer (sequence) dimension of the LHS operand + lhs_spec_new = [spec for spec in lhs_spec] + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + + # RHS operand is unchanged, we already enforce that only one dimension can be sharded + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + + # Bias is sharded to match outer dimension spec of the RHS operand (also the output) + bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Outer (sequence) dimension of the GEMM output is always unsharded + out_spec = [*batch_specs, None, rhs_outer_spec] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded + gelu_spec = out_spec if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + arg_shardings = (lhs_sharding, fp8_meta_sharding, rhs_sharding, fp8_meta_sharding, + bias_sharding, gelu_sharding, fp8_meta_sharding, fp8_meta_sharding) + out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, + bias_sharding) + + def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, + out_scale): + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + ) = CollectiveGemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # FP8 amax reduction + if jax_dtype_is_fp8(lhs.dtype): + out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) + + if rhs_spec[rhs_inner_dim] is not None: + # GEMM output needs to be all-reduced when the contracting dimension is sharded. + # If the layer is sequence-parallel, we also need to scatter the output, which we + # can combine into a reduce-scatter here. + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, + mesh) + if fuse_gelu: + pre_gelu_out = lax_paral_op( + pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh + ) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CollectiveGemmPrimitive) + + +def fp8_gemm_impl( + lhs: ArrayLike, + lhs_scale_inv: ArrayLike, + rhs: ArrayLike, + rhs_scale_inv: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + contracting_dims: Tuple[int, int] = (1, 1), + fuse_gelu: bool = False, + fuse_bias: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + if out_dtype is not None and jax_dtype_is_fp8(out_dtype): + assert out_amax is not None and out_scale is not None, "Missing output amax and scale." + else: + out_amax = jnp.zeros(0, dtype=jnp.float32) + out_scale = jnp.zeros(0, dtype=jnp.float32) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=jnp.bfloat16) + else: + assert ( + bias is not None + ), "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=bias.dtype) + elif gelu_input is None: + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) + + out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( + rhs, + rhs_scale_inv, + lhs, + lhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=tuple(reversed(contracting_dims)), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + return out, out_amax, out_scale, pre_gelu_out + + +def gemm_impl( + lhs: ArrayLike, + rhs: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + fuse_bias: bool = False, + grad: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) + + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) + else: + assert ( + bias is not None + ), "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + elif gelu_input is None: + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + + out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + dummy_fp8_meta, + rhs, + dummy_fp8_meta, + bias, + gelu_input, + dummy_fp8_meta, + dummy_fp8_meta, + out_dtype=lhs.dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if grad: + return out, pre_gelu_out, bias_grad + else: + return out, pre_gelu_out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 1f13484b98..15d7537fbd 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -81,6 +81,13 @@ def jax_dtype_to_te_dtype(jax_dtype): return converter.get(jax_dtype) +def jax_dtype_is_fp8(dtype): + """ + Check if the given jax.numpy.dtype is an FP8 dtype. + """ + return dtypes.canonicalize_dtype(dtype) in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + def get_padded_spec(arg_info): """ Get padded spec for partitioning from arguments' information diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 02e6aaf9d5..afac283a6f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -147,6 +147,31 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right); +struct CustomCallGemmDescriptor { + size_t batch; + size_t m; + size_t k; + size_t n; + size_t workspace_size; + DType operand_dtype; + DType bias_dtype; + DType out_dtype; + bool lhs_trans; + bool rhs_trans; + bool fuse_gelu; + bool fuse_bias; + bool grad; + bool accumulate; + bool use_split_accumulator; +}; + +pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, + size_t workspace_size, DType operand_dtype, + DType out_dtype, DType bias_dtype, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, + bool use_split_accumulator); + // Transpose void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -308,6 +333,20 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +// GEMM + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp new file mode 100644 index 0000000000..f60ae510df --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -0,0 +1,170 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/gemm.h" + +#include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "extensions.h" + +namespace transformer_engine { + +namespace jax { + +void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_shape, + float *lhs_scale_inv, bool lhs_trans, void *rhs, const std::vector &rhs_shape, + float *rhs_scale_inv, bool rhs_trans, DType operand_dtype, void *bias, + DType bias_dtype, void *out, float *out_amax, float *out_scale, DType out_dtype, + void *pre_gelu_out, void *workspace, size_t workspace_size, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, operand_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, operand_dtype, nullptr, nullptr, rhs_scale_inv); + + std::vector out_shape(2, 0); + out_shape[0] = (lhs_trans) ? lhs_shape[1] : lhs_shape[0]; + out_shape[1] = (rhs_trans) ? rhs_shape[0] : rhs_shape[1]; + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + void *bias_ptr = (fuse_bias) ? bias : nullptr; + std::vector bias_shape = (fuse_bias) ? std::vector{out_shape[1]} + : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + std::vector pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + // cuBLAS is column-major, so we swap LHS and RHS in the arguments + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_out_.data(), + (rhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, (lhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, + grad, workspace_.data(), accumulate, use_split_accumulator, num_math_sm, stream); +} + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + // Inputs + auto *lhs = buffers[0]; + auto *lhs_scale_inv = reinterpret_cast(buffers[1]); + auto *rhs = buffers[2]; + auto *rhs_scale_inv = reinterpret_cast(buffers[3]); + auto *bias = buffers[4]; + auto *gelu_input = buffers[5]; + auto *out_amax = reinterpret_cast(buffers[6]); + auto *out_scale = reinterpret_cast(buffers[7]); + + // Outputs + auto *out = buffers[8]; + auto *out_amax_updated = reinterpret_cast(buffers[9]); + auto *out_scale_updated = reinterpret_cast(buffers[10]); + auto *pre_gelu_out = buffers[11]; + auto *bias_grad = buffers[12]; + auto *workspace = buffers[13]; + + // Operand aliasing + NVTE_CHECK(bias == bias_grad, + "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input == pre_gelu_out, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax == out_amax_updated, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale == out_scale_updated, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + const auto &desc = *UnpackOpaque(opaque, opaque_len); + std::vector lhs_shape = {(desc.lhs_trans) ? desc.k : desc.m, + (desc.lhs_trans) ? desc.m : desc.k}; + std::vector rhs_shape = {(desc.rhs_trans) ? desc.n : desc.k, + (desc.rhs_trans) ? desc.k : desc.n}; + + GemmImpl(stream, lhs, lhs_shape, lhs_scale_inv, desc.lhs_trans, rhs, rhs_shape, rhs_scale_inv, + desc.rhs_trans, desc.operand_dtype, bias, desc.bias_dtype, out, out_amax, out_scale, + desc.out_dtype, pre_gelu_out, workspace, desc.workspace_size, desc.fuse_gelu, + desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto operand_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + + // Outputs + auto out_ptr = out->untyped_data(); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); + auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->dimensions().back(); + + // Operand aliasing + NVTE_CHECK(bias_ptr == bias_grad_ptr, + "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + std::vector lhs_shape(lhs.dimensions().begin(), lhs.dimensions().end()); + std::vector rhs_shape(rhs.dimensions().begin(), rhs.dimensions().end()); + + // Swap A and B argument locations to match what the TE/common kernel expects + GemmImpl(stream, lhs_ptr, lhs_shape, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, + rhs_scale_inv_ptr, rhs_trans, operand_dtype, bias_ptr, bias_dtype, out_ptr, out_amax_ptr, + out_scale_ptr, out_dtype, pre_gelu_out_ptr, workspace_ptr, workspace_size, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out_amax + .Arg() // out_scale + .Ret() // out + .Ret() // out_amax_updated + .Ret() // out_scale_updated + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + +} // namespace jax + +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..1a9ce987af 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -80,5 +80,16 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( deterministic, window_size_left, window_size_right}); } +pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, + size_t workspace_size, DType operand_dtype, + DType bias_dtype, DType out_dtype, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, + bool use_split_accumulator) { + return PackOpaque(CustomCallGemmDescriptor{batch, m, n, k, workspace_size, operand_dtype, + bias_dtype, out_dtype, lhs_trans, rhs_trans, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator}); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..7b8ebdcdd2 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -51,6 +51,7 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + dict["te_gemm"] = EncapsulateFunction(Gemm); // Transpose dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); @@ -101,6 +102,7 @@ pybind11::dict Registrations() { fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; + dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); return dict; } @@ -114,10 +116,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); + m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_device_compute_capability", &GetDeviceComputeCapability, pybind11::arg("gpu_id") = -1); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 32de33bac9..b328c6e278 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -23,7 +23,7 @@ namespace jax { int GetCudaRuntimeVersion(); size_t GetCudnnRuntimeVersion(); -int GetDeviceComputeCapability(int gpu_id); +int GetDeviceComputeCapability(int gpu_id = -1); void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8b13c47cd4..7312aa8295 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -334,6 +334,7 @@ def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage: input_name_post_fix = f"_i_{postfix}" weight_name_post_fix = f"_w_{postfix}" grad_name_post_fix = f"_g_{postfix}" + output_name_post_fix = f"_o_{postfix}" def generate_a_set(target_postfix): amax = nn_partitioning.variable_with_axes( @@ -359,10 +360,10 @@ def generate_a_set(target_postfix): input_amax, input_scale = generate_a_set(input_name_post_fix) weight_amax, weight_scale = generate_a_set(weight_name_post_fix) grad_amax, grad_scale = generate_a_set(grad_name_post_fix) + output_amax, output_scale = generate_a_set(output_name_post_fix) - return FP8MetaPackage( - input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale - ) + return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax, + grad_scale, output_amax, output_scale) class DenseGeneral(TransformerEngineBase): diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 5df8ce4386..3d58c86e3e 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -86,10 +86,11 @@ class FP8MetaPackage: A container that contains all required meta data for FP8 """ - NUM_OF_META: int = 3 + NUM_OF_META: int = 4 INPUT_IDX: int = 0 WEIGHT_IDX: int = 1 GRAD_IDX: int = 2 + OUTPUT_IDX: int = 3 def __init__( self, @@ -99,6 +100,8 @@ def __init__( weight_scale: jnp.ndarray, grad_amax: jnp.ndarray, grad_scale: jnp.ndarray, + output_amax: jnp.ndarray, + output_scale: jnp.ndarray, ) -> None: self._amax_list = [None] * FP8MetaPackage.NUM_OF_META @@ -110,6 +113,8 @@ def __init__( self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale + self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax + self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale @property def amax_list(self) -> List[jnp.ndarray]: diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py new file mode 100644 index 0000000000..ccd109e095 --- /dev/null +++ b/transformer_engine/jax/gemm.py @@ -0,0 +1,425 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +from functools import partial +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax.ad_checkpoint import checkpoint_name + +from .fp8 import FP8Helper, FP8MetaPackage +from .cpp_extensions import ( + gemm_impl, + fp8_gemm_impl, + cast_fp8, + cast_transpose, + dact_lu, + dbias_cast_transpose, + dact_lu_dbias_cast_transpose, +) + + + +__all__ = [ + "gemm", + "fp8_gemm", + "type_safe_gemm", +] + + +def gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" + return _gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Union[ArrayLike, None], + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + out, _ = _gemm_fwd_rule(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, + use_split_accumulator) + return out + + +def _gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + fuse_bias = bias is not None + + out, pre_gelu_out = gemm_impl( + x, + kernel, + bias=bias, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + + ctx = ( + x, + kernel, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + ) + + return out, ctx + + +def _gemm_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + x, kernel, pre_gelu_out, fuse_bias = ctx + + x_t_contracting = 0 if contracting_dims[0] == 1 else 1 + wgrad, dgelu, bgrad = gemm_impl( + x, + grad, + gelu_input=pre_gelu_out, + contracting_dims=(x_t_contracting, 0), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + kernel_t_contracting = 1 if contracting_dims[1] == 0 else 0 + dgrad, *_ = gemm_impl( + dgelu if fuse_gelu else grad, + kernel, + gelu_input=pre_gelu_out, + contracting_dims=(1, kernel_t_contracting), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_gemm.defvjp(_gemm_fwd_rule, _gemm_bwd_rule) + + +def fp8_gemm( + x: ArrayLike, + kernel: ArrayLike, + fp8_meta: FP8MetaPackage, + bias: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + contracting_dims: Tuple[int, int] = (1, 1), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + return _fp8_gemm(x, kernel, bias, fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, + contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +def _fp8_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" + out, _ = _fp8_gemm_fwd_rule(x, kernel, bias, amax_list, scale_list, out_dtype, + contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + return out + + +def _fp8_gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + fuse_bias = bias is not None + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, *scale_list, + ) + amax_list = maybe_fm32_to_fp32(*amax_list) + scale_list = maybe_fm32_to_fp32(*scale_list) + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype, fwd_dtype] + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) + amax_list = FP8MetaPackage.update_amax_list(amax_list) + + x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] + x_scale = scale_list[FP8MetaPackage.INPUT_IDX] + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if contracting_dims[0] == 0: + _, casted_x, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_x, updated_x_amax = cast_fp8(x, x_amax, x_scale, x_scale_inv, fwd_dtype) + else: + if contracting_dims[0] == 0: + casted_x_t = x + casted_x = casted_x_t.transpose() + else: + casted_x = x + updated_x_amax = x_amax + + kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] + kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + if kernel.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if contracting_dims[1] == 0: # need to transpose the kernel for FP8 GEMM + _, casted_kernel_t, updated_kernel_amax = cast_transpose( + kernel, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_kernel_t, updated_kernel_amax = cast_fp8( + kernel, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + ) + else: + if contracting_dims[1] == 0: + casted_kernel = kernel + casted_kernel_t = casted_kernel.transpose() + else: + casted_kernel_t = kernel + updated_kernel_amax = kernel_amax + + out_amax = ( + amax_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out_scale = ( + scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out, updated_out_amax, updated_out_scale, pre_gelu_out = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_kernel_t, + kernel_scale_inv, + bias=bias, + out_amax=out_amax, + out_scale=out_scale, + out_dtype=out_dtype, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + updated_out_amax = None + updated_out_scale = None + + ctx = ( + casted_x, + casted_kernel_t, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + maybe_fp32_to_fm32 + ) + + return (out, updated_out_amax, updated_out_scale), ctx + + +def _fp8_gemm_bwd_rule( + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + ( + casted_x, + casted_kernel_t, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + pre_gelu_out, + fuse_bias, + maybe_fp32_to_fm32 + ) = ctx + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + + grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] + grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] + grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] + if fuse_bias and not fuse_gelu: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + _, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + _, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None + + + + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_ = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_grad_t, + grad_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if fuse_gelu and fuse_bias: + # Fuse dbias into this dGELU. + casted_dgelu, casted_dgelu_t, bgrad, updated_dgelu_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu", ), + ) + elif fuse_gelu: + # No bias to fuse so we just do dGELU. + casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu", )) + bgrad = None + + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_ = gemm_impl( + casted_dgelu if fuse_gelu else grad, + grad_scale_inv, + casted_kernel_t, + kernel_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + amax_list[FP8MetaPackage.INPUT_IDX] = ( + amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( + amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) + ) + + amax_list = maybe_fp32_to_fm32(*amax_list) + scale_list = maybe_fp32_to_fm32(*scale_list) + + return dgrad, wgrad, bgrad, amax_list, scale_list + + +_fp8_gemm.defvjp(_fp8_gemm_fwd_rule, _fp8_gemm_bwd_rule) + + +def type_safe_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + fp8_meta: Optional[FP8MetaPackage] = None, + out_dtype: Optional[jnp.dtype] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." + + if fp8_meta is not None: + return fp8_gemm(x, kernel, bias, fp8_meta, out_dtype, contracting_dims, fuse_gelu, + accumulate, use_split_accumulator) + else: + return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) From c9774d8c203d5b0f5769f47daf70e0c655d0d110 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 14 Nov 2024 17:59:20 +0000 Subject: [PATCH 02/19] fixed batching rules to accommodated batched RHS operand for GEMM Signed-off-by: Alp Dener --- .../common/util/pybind_helper.h | 138 ++++++++++-------- transformer_engine/jax/cpp_extensions/gemm.py | 133 ++++++----------- .../jax/csrc/extensions/pybind.cpp | 59 +------- 3 files changed, 123 insertions(+), 207 deletions(-) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..a36ff3f0f9 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,72 +8,88 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include #include #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Format") \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "NVTE_Activation_Type") \ + .value("GELU", NVTE_Activation_Type::GELU) \ + .value("GEGLU", NVTE_Activation_Type::GEGLU) \ + .value("SILU", NVTE_Activation_Type::SILU) \ + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) \ + .value("RELU", NVTE_Activation_Type::RELU) \ + .value("REGLU", NVTE_Activation_Type::REGLU) \ + .value("QGELU", NVTE_Activation_Type::QGELU) \ + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ + .value("SRELU", NVTE_Activation_Type::SRELU) \ + .value("SREGLU", NVTE_Activation_Type::SREGLU); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 677fabca59..ceafce46e1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,10 +24,10 @@ jax_dtype_is_fp8, get_padded_spec, is_ffi_enabled, + check_valid_batch_dims, ) from ..sharding import ( global_mesh_resource, - get_mesh_axis_size, lax_paral_op, all_reduce_max_along_all_axes_except_PP, ) @@ -83,9 +83,6 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 ), "Missing RHS operand scale inverse in FP8 GEMM." - # Disallow batching for RHS - assert rhs_aval.ndim == 2, "GEMM does not support batching the RHS operand." - # Validate operand layouts lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, @@ -97,12 +94,12 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 assert ( not (lhs_trans and rhs_trans) ), "GEMM does not support transposed LHS and transposed RHS at the same time." if is_fp8: - assert lhs_trans, "FP8 GEMM does not support transposed LHS." + assert not lhs_trans, "FP8 GEMM does not support transposed LHS." assert rhs_trans, "FP8 GEMM requires transposed RHS." # Validate output dtype @@ -124,11 +121,18 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av out_scale_updated_dtype = jnp.float32 # Infer output shape - rhs_outer_dim = 0 if rhs_trans else 1 lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 lhs_bdims = [dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + rhs_bdims = [dim for dim in range(rhs_aval.ndim) + if dim not in [rhs_outer_dim, rhs_inner_dim]] + rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) + assert ( + lhs_batch_size == rhs_batch_size + ), "LHS and RHS operands must have the same batched sizes." out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output @@ -201,7 +205,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ (lhs_aval.ndim, rhs_aval.ndim) ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 operand_output_aliases = { 4: 4, # bias <--> bias_grad @@ -248,12 +252,9 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - rhs_outer_dim = 0 if rhs_trans else 1 lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - lhs_bdims = [dim for dim in range(lhs_aval.ndim) - if dim not in [lhs_outer_dim, lhs_inner_dim]] - lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] - m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim] + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] n = rhs_aval.shape[rhs_outer_dim] workspace_size = get_cublas_workspace_size_bytes() @@ -308,77 +309,32 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): assert CollectiveGemmPrimitive.outer_primitive is not None + check_valid_batch_dims(batch_dims) + lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args - assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands." - - # Get contracting and batch dimensions out - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim) - ) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - - # FP8 GEMM only supports lhs_trans = False and rhs_trans = True so we may need to - # reorder the axes here to match - if jax_dtype_is_fp8(lhs.dtype): - lhs = jnp.transpose(lhs, (*lhs_bdims, lhs_outer_dim, lhs_inner_dim)) - lhs_trans = False - rhs = jnp.transpose(rhs, (rhs_outer_dim, rhs_inner_dim)) - rhs_trans = True - contracting_dims = (1, 1) - - # Collapse all non-contracting dimensions - batch_shape = [lhs.shape[dim] for dim in lhs_bdims] - batch_size = reduce(operator.mul, batch_shape, 1) - lhs_outer_size = lhs.shape[lhs_outer_dim] - lhs_shape_2d = ( - (lhs.shape[lhs_inner_dim], batch_size * lhs_outer_size) - if lhs_trans - else (batch_size * lhs_outer_size, lhs.shape[lhs_inner_dim]) - ) - lhs = jnp.reshape(lhs, lhs_shape_2d) - if fuse_gelu: - gelu_input = jnp.reshape( - gelu_input, (batch_size * lhs_outer_size, rhs.shape[rhs_outer_dim]) - ) - - outputs = CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - lhs_scale_inv, - rhs, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - # Reshape output to recover original LHS batch shape - outputs[0] = jnp.reshape( - outputs[0], - (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim]) - ) - gelu_bdims = batch_dims[3] - if fuse_gelu: - outputs[3] = jnp.reshape(outputs[3], outputs[0].shape) - gelu_bdims = lhs_bdims + # FP8 GEMM only supports non-transposed LHS and transposed RHS + lhs, _, rhs, *_ = batched_args + lhs_trans = contracting_dims[0] != lhs.ndim - 1 + rhs_trans = contracting_dims[1] == rhs.ndim - 1 + lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs + rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs + contracting_dims = (1, 1) return ( - outputs, - (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]) + CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + batched_args[1], + rhs, + *batched_args[3:], + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) ) @staticmethod @@ -400,9 +356,9 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi + "not already partitioned correctly.") lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 + rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -440,9 +396,9 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 + rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -558,7 +514,7 @@ def fp8_gemm_impl( gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) @@ -599,7 +555,7 @@ def gemm_impl( dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: @@ -618,9 +574,6 @@ def gemm_impl( gelu_input is not None ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 7b8ebdcdd2..ddf98d9d78 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include "common/util/pybind_helper.h" #include "extensions.h" namespace transformer_engine { @@ -107,6 +108,8 @@ pybind11::dict Registrations() { } PYBIND11_MODULE(transformer_engine_jax, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("registrations", &Registrations); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); @@ -129,62 +132,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); - - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); - - pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU) - .export_values(); - - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); } } // namespace jax From e523018a8f7e3de2e1e4ab2a989eb6e13ca4a9b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:14:24 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 15 +- transformer_engine/jax/cpp_extensions/gemm.py | 275 ++++++++++++------ .../jax/csrc/extensions/gemm.cpp | 16 +- transformer_engine/jax/flax/module.py | 12 +- transformer_engine/jax/gemm.py | 70 +++-- 5 files changed, 254 insertions(+), 134 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9bf3f9fa91..355f587265 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -425,19 +425,16 @@ def _generate_inputs(b, m, n, k, dtype): a = jax.random.normal(subkeys[0], (b, m, k), dtype) b = jax.random.normal(subkeys[1], (n, k), dtype) bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 - bias = jax.random.normal(subkeys[2], (n, ), bias_dtype) + bias = jax.random.normal(subkeys[2], (n,), bias_dtype) return a, b, bias @staticmethod def _generate_fp8_inputs(b, m, n, k, fp8_dtype): a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) - a_scale, b_scale = map( - lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32), - [a, b] - ) + a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b]) a_q, b_q = map( lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), - [(a, a_scale), (b, b_scale)] + [(a, a_scale), (b, b_scale)], ) return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias @@ -447,7 +444,7 @@ def _generate_fp8_inputs(b, m, n, k, fp8_dtype): def test_gemm(self, b, m, n, k, use_bias, do_gelu): a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) - primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu) + primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu) ref_out = jnp.dot(a, b) if use_bias: ref_out += bias @@ -460,9 +457,7 @@ def test_gemm(self, b, m, n, k, use_bias, do_gelu): @pytest.mark.parametrize("m,n,k", GEMM_CASES) @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) def test_fp8_gemm(self, m, n, k, fp8_dtype): - a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs( - m, n, k, fp8_dtype - ) + a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype) primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) ref_out = jnp.dot(a, b) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ceafce46e1..2df05d6df4 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -58,9 +58,23 @@ class CollectiveGemmPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_aval, - gelu_input_aval, out_amax_aval, out_scale_aval, out_dtype, contracting_dims, - fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): + def abstract( + lhs_aval, + lhs_scale_inv_aval, + rhs_aval, + rhs_scale_inv_aval, + bias_aval, + gelu_input_aval, + out_amax_aval, + out_scale_aval, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): """ cuBlasLt GEMM abstract """ @@ -87,7 +101,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] @@ -95,8 +109,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 - assert ( - not (lhs_trans and rhs_trans) + assert not ( + lhs_trans and rhs_trans ), "GEMM does not support transposed LHS and transposed RHS at the same time." if is_fp8: assert not lhs_trans, "FP8 GEMM does not support transposed LHS." @@ -104,8 +118,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Validate output dtype if jax_dtype_is_fp8(out_dtype): - assert ( - jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype) + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( + rhs_dtype ), "FP8 GEMM output requires FP8 inputs." assert ( out_amax_aval.size == out_scale_aval.size == 1 @@ -122,13 +136,15 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Infer output shape lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - lhs_bdims = [dim for dim in range(lhs_aval.ndim) - if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_bdims = [ + dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] + ] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 - rhs_bdims = [dim for dim in range(rhs_aval.ndim) - if dim not in [rhs_outer_dim, rhs_inner_dim]] + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) assert ( lhs_batch_size == rhs_batch_size @@ -139,9 +155,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: assert ( - bias_aval.size > 0 - and bias_aval.ndim == 1 - and bias_aval.shape[0] == out_shape[-1] + bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1] ), "Incorrect bias shape." bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) else: @@ -149,8 +163,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Validate GELU input/output if fuse_gelu: - assert ( - all([gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)]) + assert all( + [gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)] ), "Invalid GELU input shape." assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." else: @@ -158,14 +172,17 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Create abstract arrays for all outputs out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) - out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, - dtype=out_amax_updated_dtype) - out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, - dtype=out_scale_updated_dtype) + out_amax_updated_aval = out_amax_aval.update( + shape=out_amax_aval.shape, dtype=out_amax_updated_dtype + ) + out_scale_updated_aval = out_scale_aval.update( + shape=out_scale_aval.shape, dtype=out_scale_updated_dtype + ) pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - workspace_aval = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), - dtype=jnp.uint8) + workspace_aval = jax.core.ShapedArray( + shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 + ) return ( out_aval, @@ -173,7 +190,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av out_scale_updated_aval, pre_gelu_out_aval, bias_grad_aval, - workspace_aval + workspace_aval, ) @staticmethod @@ -181,20 +198,31 @@ def outer_abstract(*args, **kwargs): """ cuBlasLt GEMM outer abstract """ - ( - out_aval, - out_amax_aval, - out_scale_aval, - pre_gelu_out_aval, - bias_grad_aval, - _ - ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + (out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval, _) = ( + CollectiveGemmPrimitive.abstract(*args, **kwargs) + ) return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval @staticmethod - def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, - *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, - use_split_accumulator): + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + *, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): """ Fused attention fwd lowering rules """ @@ -202,7 +230,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 @@ -232,7 +260,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ fuse_bias=fuse_bias, grad=grad, accumulate=accumulate, - use_split_accumulator=use_split_accumulator + use_split_accumulator=use_split_accumulator, ) else: operands = [ @@ -260,10 +288,22 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ workspace_size = get_cublas_workspace_size_bytes() operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) - opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype, - jax_dtype_to_te_dtype(out_dtype), bias_dtype, - lhs_trans, rhs_trans, fuse_gelu, fuse_bias, grad, - accumulate, use_split_accumulator) + opaque = tex.pack_gemm_descriptor( + m, + n, + k, + workspace_size, + operand_dtype, + jax_dtype_to_te_dtype(out_dtype), + bias_dtype, + lhs_trans, + rhs_trans, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ) return custom_caller( CollectiveGemmPrimitive.name, @@ -274,9 +314,23 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ ) @staticmethod - def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, - out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, - use_split_accumulator): + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): assert CollectiveGemmPrimitive.inner_primitive is not None ( @@ -306,13 +360,23 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, - accumulate, use_split_accumulator): + def batcher( + batched_args, + batch_dims, + *, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): assert CollectiveGemmPrimitive.outer_primitive is not None check_valid_batch_dims(batch_dims) lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - # FP8 GEMM only supports non-transposed LHS and transposed RHS + # FP8 GEMM only supports non-transposed LHS and transposed RHS lhs, _, rhs, *_ = batched_args lhs_trans = contracting_dims[0] != lhs.ndim - 1 rhs_trans = contracting_dims[1] == rhs.ndim - 1 @@ -320,27 +384,33 @@ def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs contracting_dims = (1, 1) - return ( - CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - batched_args[1], - rhs, - *batched_args[3:], - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) - ) + return CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + batched_args[1], + rhs, + *batched_args[3:], + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + )(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) @staticmethod - def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, - accumulate, use_split_accumulator, mesh, arg_infos, - result_infos): + def infer_sharding_from_operands( + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): del out_dtype, accumulate, use_split_accumulator, result_infos lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) @@ -348,12 +418,14 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: - warnings.warn("Forcing the inner dimension of LHS to match the sharding of inner " - + "dimension of RHS. This can trigger additional communication if LHS is " - + "not already partitioned correctly.") + warnings.warn( + "Forcing the inner dimension of LHS to match the sharding of inner " + + "dimension of RHS. This can trigger additional communication if LHS is " + + "not already partitioned correctly." + ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 rhs_trans = rhs_inner_dim == rhs.ndim - 1 @@ -383,8 +455,18 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @staticmethod - def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, - use_split_accumulator, mesh, arg_infos, result_infos): + def partition( + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): del result_infos lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) @@ -392,7 +474,7 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 @@ -426,13 +508,27 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat gelu_spec = out_spec if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) - arg_shardings = (lhs_sharding, fp8_meta_sharding, rhs_sharding, fp8_meta_sharding, - bias_sharding, gelu_sharding, fp8_meta_sharding, fp8_meta_sharding) - out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, - bias_sharding) + arg_shardings = ( + lhs_sharding, + fp8_meta_sharding, + rhs_sharding, + fp8_meta_sharding, + bias_sharding, + gelu_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + ) + out_shardings = ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + ) - def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, - out_scale): + def sharded_impl( + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale + ): ( out, out_amax_updated, @@ -465,8 +561,7 @@ def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_a # GEMM output needs to be all-reduced when the contracting dimension is sharded. # If the layer is sequence-parallel, we also need to scatter the output, which we # can combine into a reduce-scatter here. - out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, - mesh) + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh @@ -485,10 +580,10 @@ def fp8_gemm_impl( lhs_scale_inv: ArrayLike, rhs: ArrayLike, rhs_scale_inv: ArrayLike, - bias: Optional[ArrayLike] = None, + bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, - out_amax: Optional[ArrayLike] = None, - out_scale: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, contracting_dims: Tuple[int, int] = (1, 1), fuse_gelu: bool = False, @@ -506,9 +601,7 @@ def fp8_gemm_impl( if not fuse_bias: bias = jnp.zeros(0, dtype=jnp.bfloat16) else: - assert ( - bias is not None - ), "Missing bias in forward GEMM when bias epilogue is enabled." + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) @@ -542,8 +635,8 @@ def fp8_gemm_impl( def gemm_impl( lhs: ArrayLike, rhs: ArrayLike, - bias: Optional[ArrayLike] = None, - gelu_input: Optional[ArrayLike] = None, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, contracting_dims: Tuple[int, int] = (1, 0), fuse_gelu: bool = False, fuse_bias: bool = False, @@ -563,9 +656,7 @@ def gemm_impl( elif grad: bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) else: - assert ( - bias is not None - ), "Missing bias in forward GEMM when bias epilogue is enabled." + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=lhs.dtype) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f60ae510df..5dae9d6757 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -29,8 +29,8 @@ void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_sha auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); void *bias_ptr = (fuse_bias) ? bias : nullptr; - std::vector bias_shape = (fuse_bias) ? std::vector{out_shape[1]} - : std::vector{0}; + std::vector bias_shape = + (fuse_bias) ? std::vector{out_shape[1]} : std::vector{0}; auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; @@ -65,12 +65,9 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque auto *workspace = buffers[13]; // Operand aliasing - NVTE_CHECK(bias == bias_grad, - "bias not bound to bias_grad in TE/JAX GEMM"); - NVTE_CHECK(gelu_input == pre_gelu_out, - "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); - NVTE_CHECK(out_amax == out_amax_updated, - "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input == pre_gelu_out, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax == out_amax_updated, "out_amax not bound to out_amax_updated in TE/JAX GEMM"); NVTE_CHECK(out_scale == out_scale_updated, "out_scale not bound to out_scale_updated in TE/JAX GEMM"); @@ -117,8 +114,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto workspace_size = workspace->dimensions().back(); // Operand aliasing - NVTE_CHECK(bias_ptr == bias_grad_ptr, - "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM"); NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 7312aa8295..abe23fdf8b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -362,8 +362,16 @@ def generate_a_set(target_postfix): grad_amax, grad_scale = generate_a_set(grad_name_post_fix) output_amax, output_scale = generate_a_set(output_name_post_fix) - return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax, - grad_scale, output_amax, output_scale) + return FP8MetaPackage( + input_amax, + input_scale, + weight_amax, + weight_scale, + grad_amax, + grad_scale, + output_amax, + output_scale, + ) class DenseGeneral(TransformerEngineBase): diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index ccd109e095..79499725b7 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -21,7 +21,6 @@ ) - __all__ = [ "gemm", "fp8_gemm", @@ -52,8 +51,9 @@ def _gemm( accumulate: bool, use_split_accumulator: bool, ) -> ArrayLike: - out, _ = _gemm_fwd_rule(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, - use_split_accumulator) + out, _ = _gemm_fwd_rule( + x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator + ) return out @@ -76,7 +76,7 @@ def _gemm_fwd_rule( fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, accumulate=accumulate, - use_split_accumulator=use_split_accumulator + use_split_accumulator=use_split_accumulator, ) ctx = ( @@ -145,8 +145,18 @@ def fp8_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - return _fp8_gemm(x, kernel, bias, fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + return _fp8_gemm( + x, + kernel, + bias, + fp8_meta.amax_list, + fp8_meta.scale_list, + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ) @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) @@ -163,8 +173,18 @@ def _fp8_gemm( use_split_accumulator: bool, ) -> ArrayLike: """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" - out, _ = _fp8_gemm_fwd_rule(x, kernel, bias, amax_list, scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + out, _ = _fp8_gemm_fwd_rule( + x, + kernel, + bias, + amax_list, + scale_list, + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ) return out @@ -183,7 +203,8 @@ def _fp8_gemm_fwd_rule( fuse_bias = bias is not None maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( - *amax_list, *scale_list, + *amax_list, + *scale_list, ) amax_list = maybe_fm32_to_fp32(*amax_list) scale_list = maybe_fm32_to_fp32(*scale_list) @@ -272,7 +293,7 @@ def _fp8_gemm_fwd_rule( fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, accumulate=accumulate, - use_split_accumulator=use_split_accumulator + use_split_accumulator=use_split_accumulator, ) if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: updated_out_amax = None @@ -288,7 +309,7 @@ def _fp8_gemm_fwd_rule( updated_kernel_amax, pre_gelu_out if fuse_gelu else None, fuse_bias, - maybe_fp32_to_fm32 + maybe_fp32_to_fm32, ) return (out, updated_out_amax, updated_out_scale), ctx @@ -313,7 +334,7 @@ def _fp8_gemm_bwd_rule( updated_kernel_amax, pre_gelu_out, fuse_bias, - maybe_fp32_to_fm32 + maybe_fp32_to_fm32, ) = ctx fwd_dtype = FP8Helper.FWD_DTYPE @@ -347,8 +368,6 @@ def _fp8_gemm_bwd_rule( ) bgrad = None - - x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] wgrad, *_ = fp8_gemm_impl( casted_x, @@ -370,11 +389,11 @@ def _fp8_gemm_bwd_rule( bwd_dtype, static_axis_boundary=-1, transpose_axis_boundary=-1, - activation_type=("gelu", ), + activation_type=("gelu",), ) elif fuse_gelu: # No bias to fuse so we just do dGELU. - casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu", )) + casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) bgrad = None kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] @@ -414,12 +433,23 @@ def type_safe_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] - or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ]: assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: - return fp8_gemm(x, kernel, bias, fp8_meta, out_dtype, contracting_dims, fuse_gelu, - accumulate, use_split_accumulator) + return fp8_gemm( + x, + kernel, + bias, + fp8_meta, + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ) else: return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) From 2c3dbf1cf516d3dec5022b9b8304ee0d053170ba Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 15 Nov 2024 23:56:38 +0000 Subject: [PATCH 04/19] re-applied bug fixes to working older version, updated backward pass, passing test Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 93 +++---- transformer_engine/jax/gemm.py | 260 +++++++++--------- 2 files changed, 174 insertions(+), 179 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2df05d6df4..ee4c38d076 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1,7 +1,6 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""JAX/TE custom ops for cuBlasLt GEMM""" import warnings import operator from functools import reduce @@ -39,6 +38,10 @@ ] +def sanitize_dims(dim, ndims): + return (ndims + dim) if dim < 0 else dim + + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if tex.get_device_compute_capability() >= 90: @@ -98,11 +101,8 @@ def abstract( ), "Missing RHS operand scale inverse in FP8 GEMM." # Validate operand layouts - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim)) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." @@ -134,23 +134,31 @@ def abstract( out_amax_updated_dtype = jnp.float32 out_scale_updated_dtype = jnp.float32 - # Infer output shape + # Make sure leading dimensions of RHS is broadcast-compatible with LHS lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] ] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) - rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 - rhs_bdims = [ - dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] - ] - rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) - assert ( - lhs_batch_size == rhs_batch_size - ), "LHS and RHS operands must have the same batched sizes." - out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + if rhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + if rhs_batch_size > 1: + assert ( + lhs_batch_size == rhs_batch_size + ), ( + f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_batch_shape=})." + ) + # Infer output shape + out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: @@ -227,11 +235,8 @@ def lowering( Fused attention fwd lowering rules """ lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim)) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 @@ -376,19 +381,8 @@ def batcher( check_valid_batch_dims(batch_dims) lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - # FP8 GEMM only supports non-transposed LHS and transposed RHS - lhs, _, rhs, *_ = batched_args - lhs_trans = contracting_dims[0] != lhs.ndim - 1 - rhs_trans = contracting_dims[1] == rhs.ndim - 1 - lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs - rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs - contracting_dims = (1, 1) - return CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - batched_args[1], - rhs, - *batched_args[3:], + *batched_args, out_dtype=out_dtype, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, @@ -415,11 +409,7 @@ def infer_sharding_from_operands( lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: warnings.warn( "Forcing the inner dimension of LHS to match the sharding of inner " @@ -471,11 +461,7 @@ def partition( lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_trans = lhs_inner_dim != lhs.ndim - 1 rhs_trans = rhs_inner_dim == rhs.ndim - 1 @@ -578,14 +564,13 @@ def sharded_impl( def fp8_gemm_impl( lhs: ArrayLike, lhs_scale_inv: ArrayLike, - rhs: ArrayLike, + rhs_t: ArrayLike, rhs_scale_inv: ArrayLike, bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, - contracting_dims: Tuple[int, int] = (1, 1), fuse_gelu: bool = False, fuse_bias: bool = False, accumulate: bool = False, @@ -606,22 +591,20 @@ def fp8_gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + out_shape = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]) gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( - rhs, - rhs_scale_inv, lhs, lhs_scale_inv, + rhs_t, + rhs_scale_inv, bias, gelu_input, out_amax, out_scale, out_dtype=out_dtype, - contracting_dims=tuple(reversed(contracting_dims)), + contracting_dims=(-1, -1), fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, grad=False, @@ -645,10 +628,9 @@ def gemm_impl( use_split_accumulator: bool = False, ) -> Tuple[ArrayLike, ...]: """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) - - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim == lhs.ndim - 2 else lhs.ndim - 2 + rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: @@ -667,6 +649,7 @@ def gemm_impl( elif gelu_input is None: gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( lhs, dummy_fp8_meta, diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 79499725b7..e9e046d182 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -1,7 +1,8 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -from functools import partial +import operator +from functools import partial, reduce from typing import Optional, Tuple, Union import jax @@ -19,6 +20,7 @@ dbias_cast_transpose, dact_lu_dbias_cast_transpose, ) +from .cpp_extensions.gemm import sanitize_dims __all__ = [ @@ -98,27 +100,48 @@ def _gemm_bwd_rule( grad, ): x, kernel, pre_gelu_out, fuse_bias = ctx + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - x_t_contracting = 0 if contracting_dims[0] == 1 else 1 - wgrad, dgelu, bgrad = gemm_impl( - x, + + kernel_t_contracting = ( + kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 + ) + # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) + dgrad, dgelu, _ = gemm_impl( grad, + kernel, gelu_input=pre_gelu_out, - contracting_dims=(x_t_contracting, 0), + contracting_dims=(-1, kernel_t_contracting), fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, + fuse_bias=False, grad=True, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) - kernel_t_contracting = 1 if contracting_dims[1] == 0 else 0 - dgrad, *_ = gemm_impl( - dgelu if fuse_gelu else grad, - kernel, + # Collapse batch x sequence dimensions for WGRAD + x_outer_dim = x.ndim - 2 if x_inner_dim == x.ndim - 1 else x.ndim - 1 + wgrad_rhs = dgelu if fuse_gelu else grad + if x.ndim > 2: + batch_size = reduce(operator.mul, x.shape[:-2], 1) + x = jax.lax.reshape( + jax.lax.transpose(x, (*list(range(x.ndim - 2)), x_outer_dim, x_inner_dim)), + (batch_size * x.shape[x_outer_dim], x.shape[x_inner_dim]), + ) + wgrad_rhs = jnp.reshape( + wgrad_rhs, shape=(batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1]) + ) + x_t_contracting = 0 + else: + x_t_contracting = x_outer_dim + + # WGRAD: ([B], M, K)^T x ([B], M, N) = ([B], K, N) + wgrad, _, bgrad = gemm_impl( + x, + wgrad_rhs, gelu_input=pre_gelu_out, - contracting_dims=(1, kernel_t_contracting), - fuse_gelu=fuse_gelu, + contracting_dims=(x_t_contracting, wgrad_rhs.ndim - 2), + fuse_gelu=False, fuse_bias=fuse_bias, grad=True, accumulate=accumulate, @@ -140,7 +163,6 @@ def fp8_gemm( fp8_meta: FP8MetaPackage, bias: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, - contracting_dims: Tuple[int, int] = (1, 1), fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, @@ -152,7 +174,6 @@ def fp8_gemm( fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, @@ -162,12 +183,11 @@ def fp8_gemm( @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) def _fp8_gemm( x: ArrayLike, - kernel: ArrayLike, + kernel_t: ArrayLike, bias: ArrayLike, amax_list: ArrayLike, scale_list: ArrayLike, out_dtype: jnp.dtype, - contracting_dims: Tuple[int, int], fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, @@ -175,12 +195,11 @@ def _fp8_gemm( """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" out, _ = _fp8_gemm_fwd_rule( x, - kernel, + kernel_t, bias, amax_list, scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, @@ -190,12 +209,11 @@ def _fp8_gemm( def _fp8_gemm_fwd_rule( x: ArrayLike, - kernel: ArrayLike, + kernel_t: ArrayLike, bias: ArrayLike, amax_list: ArrayLike, scale_list: ArrayLike, out_dtype: jnp.dtype, - contracting_dims: Tuple[int, int], fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, @@ -221,54 +239,36 @@ def _fp8_gemm_fwd_rule( x_scale = scale_list[FP8MetaPackage.INPUT_IDX] x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - if contracting_dims[0] == 0: - _, casted_x, updated_x_amax = cast_transpose( - x, - x_amax, - x_scale, - x_scale_inv, - fwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - casted_x, updated_x_amax = cast_fp8(x, x_amax, x_scale, x_scale_inv, fwd_dtype) + casted_x, casted_x_t, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) else: - if contracting_dims[0] == 0: - casted_x_t = x - casted_x = casted_x_t.transpose() - else: - casted_x = x + casted_x = x + casted_x_t = jnp.matrix_transpose(x) updated_x_amax = x_amax kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - if kernel.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - if contracting_dims[1] == 0: # need to transpose the kernel for FP8 GEMM - _, casted_kernel_t, updated_kernel_amax = cast_transpose( - kernel, - kernel_amax, - kernel_scale, - kernel_scale_inv, - fwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - casted_kernel_t, updated_kernel_amax = cast_fp8( - kernel, - kernel_amax, - kernel_scale, - kernel_scale_inv, - fwd_dtype, - ) + if kernel_t.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + casted_kernel_t, casted_kernel, updated_kernel_amax = cast_transpose( + kernel_t, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) else: - if contracting_dims[1] == 0: - casted_kernel = kernel - casted_kernel_t = casted_kernel.transpose() - else: - casted_kernel_t = kernel + casted_kernel = jnp.matrix_transpose(kernel_t) + casted_kernel_t = kernel_t updated_kernel_amax = kernel_amax out_amax = ( @@ -300,24 +300,24 @@ def _fp8_gemm_fwd_rule( updated_out_scale = None ctx = ( - casted_x, - casted_kernel_t, + casted_x_t, + casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax, updated_kernel_amax, + updated_out_amax, pre_gelu_out if fuse_gelu else None, fuse_bias, maybe_fp32_to_fm32, ) - return (out, updated_out_amax, updated_out_scale), ctx + return (out, updated_out_scale), ctx def _fp8_gemm_bwd_rule( out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, @@ -325,83 +325,84 @@ def _fp8_gemm_bwd_rule( grad, ): ( - casted_x, - casted_kernel_t, + casted_x_t, + casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax, updated_kernel_amax, + updated_out_amax, pre_gelu_out, fuse_bias, maybe_fp32_to_fm32, ) = ctx - fwd_dtype = FP8Helper.FWD_DTYPE bwd_dtype = FP8Helper.BWD_DTYPE grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] - if fuse_bias and not fuse_gelu: - # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. - _, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( - grad, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) + if fuse_gelu: + if fuse_bias: + # Fuse dbias into this dGELU. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu",), + ) + else: + # No bias to fuse so we just do dGELU. + casted_grad, casted_grad_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) + bgrad = None else: - # If both bias and GELU is fused into the forward pass, we will fuse dbias later with - # dGELU. No need to do it here. - _, casted_grad_t, updated_grad_amax = cast_transpose( - grad, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - bgrad = None + if fuse_bias: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + casted_grad, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None - x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - wgrad, *_ = fp8_gemm_impl( - casted_x, - x_scale_inv, - casted_grad_t, + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_ = fp8_gemm_impl( + casted_grad, grad_scale_inv, + casted_kernel, + kernel_scale_inv, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) - if fuse_gelu and fuse_bias: - # Fuse dbias into this dGELU. - casted_dgelu, casted_dgelu_t, bgrad, updated_dgelu_amax = dact_lu_dbias_cast_transpose( - grad, - pre_gelu_out, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - activation_type=("gelu",), - ) - elif fuse_gelu: - # No bias to fuse so we just do dGELU. - casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) - bgrad = None - - kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - dgrad, *_ = gemm_impl( - casted_dgelu if fuse_gelu else grad, + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_ = fp8_gemm_impl( + casted_x_t, + x_scale_inv, + casted_grad_t, grad_scale_inv, - casted_kernel_t, - kernel_scale_inv, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) @@ -412,6 +413,13 @@ def _fp8_gemm_bwd_rule( amax_list[FP8MetaPackage.WEIGHT_IDX] = ( amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) ) + amax_list[FP8MetaPackage.GRAD_IDX] = ( + amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + amax_list[FP8MetaPackage.OUTPUT_IDX] = ( + amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) + ) amax_list = maybe_fp32_to_fm32(*amax_list) scale_list = maybe_fp32_to_fm32(*scale_list) @@ -433,20 +441,24 @@ def type_safe_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ - jnp.float8_e4m3fn, - jnp.float8_e5m2, - ]: + if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + assert ( + x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2 + ), ( + "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + + "i.e. contracting_dims=(-1, -1)." + ) return fp8_gemm( x, kernel, bias, fp8_meta, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, From 448eaa99a3c3c93d8bcf2cb2d8ca6273f4f950d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Nov 2024 23:57:09 +0000 Subject: [PATCH 05/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 14 +++++++------- transformer_engine/jax/gemm.py | 11 +++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ee4c38d076..b935a5c2f7 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -101,8 +101,9 @@ def abstract( ), "Missing RHS operand scale inverse in FP8 GEMM." # Validate operand layouts - lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim)) + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." @@ -150,9 +151,7 @@ def abstract( rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) if rhs_batch_size > 1: - assert ( - lhs_batch_size == rhs_batch_size - ), ( + assert lhs_batch_size == rhs_batch_size, ( f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " + f"with the leading dimensions of LHS ({lhs_batch_shape=})." ) @@ -235,8 +234,9 @@ def lowering( Fused attention fwd lowering rules """ lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in - lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim)) + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index e9e046d182..3cab17b10b 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -102,7 +102,6 @@ def _gemm_bwd_rule( x, kernel, pre_gelu_out, fuse_bias = ctx x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - kernel_t_contracting = ( kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 ) @@ -441,15 +440,15 @@ def type_safe_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] - or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ]: assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - assert ( - x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2 - ), ( + assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2, ( "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + "i.e. contracting_dims=(-1, -1)." ) From cb6ae3cf7570285a13aae30b414a3a7ec19b4f6c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 18 Nov 2024 22:31:35 +0000 Subject: [PATCH 06/19] batched operands for GEMM custom op seem to be working now Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 151 +++++++++++++----- transformer_engine/jax/gemm.py | 26 +-- 2 files changed, 119 insertions(+), 58 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b935a5c2f7..cf029d16db 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -136,8 +136,11 @@ def abstract( out_scale_updated_dtype = jnp.float32 # Make sure leading dimensions of RHS is broadcast-compatible with LHS - lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim) + ) lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] @@ -152,12 +155,17 @@ def abstract( rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) if rhs_batch_size > 1: assert lhs_batch_size == rhs_batch_size, ( - f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " - + f"with the leading dimensions of LHS ({lhs_batch_shape=})." + f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_aval.shape=})." ) - # Infer output shape + # Infer output shape: out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + if lhs_aval.ndim > 2 and rhs_aval.ndim > 2 and lhs_batch_size > 1: + # When both RHS and LHS are batched, the batch dimensions are collapsed into the + # contracting dimension. + out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + # Validate bias/bias_grad shape against inferred output bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: @@ -169,9 +177,16 @@ def abstract( assert bias_aval.size == 0, "Internal TE error." # Validate GELU input/output + gelu_shape = (0, ) if fuse_gelu: - assert all( - [gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)] + gelu_shape = ( + (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) + if len(out_shape) > 2 + else out_shape + ) + assert ( + gelu_input_aval.ndim == 2 + and all([gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)]) ), "Invalid GELU input shape." assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." else: @@ -185,7 +200,7 @@ def abstract( out_scale_updated_aval = out_scale_aval.update( shape=out_scale_aval.shape, dtype=out_scale_updated_dtype ) - pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) workspace_aval = jax.core.ShapedArray( shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 @@ -285,8 +300,11 @@ def lowering( ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim) + ) m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] n = rhs_aval.shape[rhs_outer_dim] @@ -338,6 +356,43 @@ def impl( ): assert CollectiveGemmPrimitive.inner_primitive is not None + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 + + # Squeeze batch dimensions of size 1 without any modification. + squeeze_dims = [] + expand_out = False + if lhs.ndim > 2: + squeeze_dims = [dim for dim in range(lhs.ndim - 2) if lhs.shape[dim] == 1] + if len(squeeze_dims) > 0: + expand_out = True + lhs = jax.lax.squeeze(lhs, squeeze_dims) + contracting_dims = (lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, + contracting_dims[1]) + if rhs.ndim > 2: + rhs_squeeze_dims = [dim for dim in range(rhs.ndim - 2) if rhs.shape[dim] == 1] + if len(squeeze_dims) > 0: + rhs = jax.lax.squeeze(rhs, rhs_squeeze_dims) + contracting_dims = (contracting_dims[0], + rhs.ndim - 1 if rhs_trans else rhs.ndim - 2) + + # Collapse batch dimensions that are larger thanm size 1. + # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) + # DGRAD: (B, M, N) x (K, N)^T = (B*M, N) x (N, K) = (B*M, K) + # WGRAD: (B, M, K)^T x (B, M, N) = (K, B*M) x (B*M, N) = (K, N) + batch_shape = [lhs.shape[dim] for dim in range(lhs.ndim - 2)] + batch_size = reduce(operator.mul, batch_shape, 1) + reshape_output = not (lhs.ndim > 2 and rhs.ndim > 2) + if lhs.ndim > 2: + lhs_2d_shape = (batch_size * lhs.shape[-2], lhs.shape[-1]) + lhs = jax.lax.reshape(lhs, lhs_2d_shape) + contracting_dims = (0 if lhs_trans else 1, contracting_dims[1]) + if rhs.ndim > 2: + rhs_2d_shape = (reduce(operator.mul, rhs.shape[:-1], 1), rhs.shape[-1]) + rhs = jax.lax.reshape(rhs, rhs_2d_shape) + contracting_dims = (contracting_dims[0], 1 if rhs_trans else 0) + ( out, out_amax_updated, @@ -362,6 +417,15 @@ def impl( accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) + + # Recover batched dimensions in the output + if reshape_output: + out_batched_shape = (*batch_shape, int(out.shape[-2] / batch_size), out.shape[-1]) + out = jax.lax.reshape(out, out_batched_shape) + + if expand_out: + out = jax.lax.expand_dims(out, squeeze_dims) + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @staticmethod @@ -381,16 +445,19 @@ def batcher( check_valid_batch_dims(batch_dims) lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - return CollectiveGemmPrimitive.outer_primitive.bind( - *batched_args, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - )(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) + return ( + CollectiveGemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ), + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) + ) @staticmethod def infer_sharding_from_operands( @@ -417,10 +484,12 @@ def infer_sharding_from_operands( + "not already partitioned correctly." ) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == rhs.ndim - 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim) + ) + rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -430,18 +499,20 @@ def infer_sharding_from_operands( # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] + batch_size = reduce(operator.mul, lhs.shape[:-2], 1) + if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) - # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded - gelu_spec = out_spec if fuse_gelu else [None] + # Pre-GELU output matches output, if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) # Bias gradient spec matches outer dimension of output if bias fusion is turned on bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) - return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @staticmethod @@ -462,11 +533,11 @@ def partition( lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == rhs.ndim - 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim) + ) lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -488,10 +559,13 @@ def partition( # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] + batch_size = reduce(operator.mul, lhs.shape[:-2], 1) + if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded - gelu_spec = out_spec if fuse_gelu else [None] + gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) arg_shardings = ( @@ -547,10 +621,10 @@ def sharded_impl( # GEMM output needs to be all-reduced when the contracting dimension is sharded. # If the layer is sequence-parallel, we also need to scatter the output, which we # can combine into a reduce-scatter here. - out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, mesh) + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( - pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh + pre_gelu_out, jax.lax.psum, global_mesh_resource().tp_resource, mesh ) return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @@ -629,8 +703,11 @@ def gemm_impl( ) -> Tuple[ArrayLike, ...]: """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim == lhs.ndim - 2 else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim) + ) out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 3cab17b10b..01ee60f24b 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -101,16 +101,15 @@ def _gemm_bwd_rule( ): x, kernel, pre_gelu_out, fuse_bias = ctx x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + x_outer_dim = x.ndim - 1 if x_inner_dim != x.ndim - 1 else x.ndim - 2 + kernel_outer_dim = kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 - kernel_t_contracting = ( - kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 - ) # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) dgrad, dgelu, _ = gemm_impl( grad, kernel, gelu_input=pre_gelu_out, - contracting_dims=(-1, kernel_t_contracting), + contracting_dims=(-1, kernel_outer_dim), fuse_gelu=fuse_gelu, fuse_bias=False, grad=True, @@ -118,28 +117,13 @@ def _gemm_bwd_rule( use_split_accumulator=use_split_accumulator, ) - # Collapse batch x sequence dimensions for WGRAD - x_outer_dim = x.ndim - 2 if x_inner_dim == x.ndim - 1 else x.ndim - 1 + # WGRAD: ([B], M, K)^T x ([B], M, N) = (K, N) wgrad_rhs = dgelu if fuse_gelu else grad - if x.ndim > 2: - batch_size = reduce(operator.mul, x.shape[:-2], 1) - x = jax.lax.reshape( - jax.lax.transpose(x, (*list(range(x.ndim - 2)), x_outer_dim, x_inner_dim)), - (batch_size * x.shape[x_outer_dim], x.shape[x_inner_dim]), - ) - wgrad_rhs = jnp.reshape( - wgrad_rhs, shape=(batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1]) - ) - x_t_contracting = 0 - else: - x_t_contracting = x_outer_dim - - # WGRAD: ([B], M, K)^T x ([B], M, N) = ([B], K, N) wgrad, _, bgrad = gemm_impl( x, wgrad_rhs, gelu_input=pre_gelu_out, - contracting_dims=(x_t_contracting, wgrad_rhs.ndim - 2), + contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), fuse_gelu=False, fuse_bias=fuse_bias, grad=True, From 6f673559d250c9cf9c2713201da256b641cad279 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:32:02 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cf029d16db..0948139dc9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -139,7 +139,7 @@ def abstract( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) lhs_bdims = [ @@ -177,16 +177,15 @@ def abstract( assert bias_aval.size == 0, "Internal TE error." # Validate GELU input/output - gelu_shape = (0, ) + gelu_shape = (0,) if fuse_gelu: gelu_shape = ( (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) if len(out_shape) > 2 else out_shape ) - assert ( - gelu_input_aval.ndim == 2 - and all([gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)]) + assert gelu_input_aval.ndim == 2 and all( + [gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)] ), "Invalid GELU input shape." assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." else: @@ -303,7 +302,7 @@ def lowering( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] @@ -368,14 +367,18 @@ def impl( if len(squeeze_dims) > 0: expand_out = True lhs = jax.lax.squeeze(lhs, squeeze_dims) - contracting_dims = (lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, - contracting_dims[1]) + contracting_dims = ( + lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, + contracting_dims[1], + ) if rhs.ndim > 2: rhs_squeeze_dims = [dim for dim in range(rhs.ndim - 2) if rhs.shape[dim] == 1] if len(squeeze_dims) > 0: rhs = jax.lax.squeeze(rhs, rhs_squeeze_dims) - contracting_dims = (contracting_dims[0], - rhs.ndim - 1 if rhs_trans else rhs.ndim - 2) + contracting_dims = ( + contracting_dims[0], + rhs.ndim - 1 if rhs_trans else rhs.ndim - 2, + ) # Collapse batch dimensions that are larger thanm size 1. # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) @@ -456,7 +459,7 @@ def batcher( accumulate=accumulate, use_split_accumulator=use_split_accumulator, ), - (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims), ) @staticmethod @@ -487,7 +490,7 @@ def infer_sharding_from_operands( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] @@ -536,7 +539,7 @@ def partition( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] @@ -706,7 +709,7 @@ def gemm_impl( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) From 4b2b2d44d735714ea9917fb00748c77e473fdafa Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 19 Nov 2024 17:57:33 +0000 Subject: [PATCH 08/19] fixed batch size 1 issue and enabled FSDP sharding for RHS operand Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 65 ++++++++----------- transformer_engine/jax/gemm.py | 18 +++-- 2 files changed, 39 insertions(+), 44 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0948139dc9..431dea6c1d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -161,7 +161,7 @@ def abstract( # Infer output shape: out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) - if lhs_aval.ndim > 2 and rhs_aval.ndim > 2 and lhs_batch_size > 1: + if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: # When both RHS and LHS are batched, the batch dimensions are collapsed into the # contracting dimension. out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) @@ -359,27 +359,6 @@ def impl( lhs_trans = lhs_inner_dim != lhs.ndim - 1 rhs_trans = rhs_inner_dim == rhs.ndim - 1 - # Squeeze batch dimensions of size 1 without any modification. - squeeze_dims = [] - expand_out = False - if lhs.ndim > 2: - squeeze_dims = [dim for dim in range(lhs.ndim - 2) if lhs.shape[dim] == 1] - if len(squeeze_dims) > 0: - expand_out = True - lhs = jax.lax.squeeze(lhs, squeeze_dims) - contracting_dims = ( - lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, - contracting_dims[1], - ) - if rhs.ndim > 2: - rhs_squeeze_dims = [dim for dim in range(rhs.ndim - 2) if rhs.shape[dim] == 1] - if len(squeeze_dims) > 0: - rhs = jax.lax.squeeze(rhs, rhs_squeeze_dims) - contracting_dims = ( - contracting_dims[0], - rhs.ndim - 1 if rhs_trans else rhs.ndim - 2, - ) - # Collapse batch dimensions that are larger thanm size 1. # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) # DGRAD: (B, M, N) x (K, N)^T = (B*M, N) x (N, K) = (B*M, K) @@ -426,9 +405,6 @@ def impl( out_batched_shape = (*batch_shape, int(out.shape[-2] / batch_size), out.shape[-1]) out = jax.lax.reshape(out, out_batched_shape) - if expand_out: - out = jax.lax.expand_dims(out, squeeze_dims) - return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @staticmethod @@ -497,13 +473,9 @@ def infer_sharding_from_operands( batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] - if rhs_spec[rhs_inner_dim] is not None and rhs_outer_spec is not None: - raise RuntimeError("Both inner and outer dimensions of RHS cannot be sharded.") - # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] - batch_size = reduce(operator.mul, lhs.shape[:-2], 1) - if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + if lhs.ndim > 2 and rhs.ndim > 2: out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) @@ -543,7 +515,6 @@ def partition( ) lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] - rhs_outer_spec = rhs_spec[rhs_outer_dim] # Force all-gather the outer (sequence) dimension of the LHS operand lhs_spec_new = [spec for spec in lhs_spec] @@ -551,8 +522,29 @@ def partition( lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + # If both dims of RHS is sharded (i.e. FSDP), determine if we do AG or AR based on LHS + # sharding. + rhs_spec_new = [spec for spec in rhs_spec] + if rhs_spec[rhs_inner_dim] is not None and rhs_spec[rhs_outer_dim] is not None: + if lhs_spec[lhs_inner_dim] is not None and lhs_spec[lhs_outer_dim] is not None: + # All dimensions of both LHS and RHS are sharded and the collective operation is + # ambiguous, we cannot infer sharding. + raise RuntimeError( + "Collective GEMM custom op cannot infer partitioning when both outer and " + + "contracting dimensions of both LHS and RHS operands are sharded." + ) + elif lhs_spec[lhs_inner_dim] is not None: + # All-reduce after GEMM, so unshard the outer dimension of RHS + rhs_spec_new[rhs_outer_dim] = None + else: + # We either do all-gather before GEMM, or LHS is already unsharded, so unshard + # the inner dimension of RHS to match + rhs_spec_new[rhs_inner_dim] = None + + rhs_outer_spec = rhs_spec_new[rhs_outer_dim] + # RHS operand is unchanged, we already enforce that only one dimension can be sharded - rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) # Bias is sharded to match outer dimension spec of the RHS operand (also the output) bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) @@ -562,8 +554,7 @@ def partition( # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] - batch_size = reduce(operator.mul, lhs.shape[:-2], 1) - if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + if lhs.ndim > 2 and rhs.ndim > 2: out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) @@ -620,10 +611,8 @@ def sharded_impl( if jax_dtype_is_fp8(lhs.dtype): out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) - if rhs_spec[rhs_inner_dim] is not None: - # GEMM output needs to be all-reduced when the contracting dimension is sharded. - # If the layer is sequence-parallel, we also need to scatter the output, which we - # can combine into a reduce-scatter here. + # GEMM output needs to be all-reduced when the contracting dimension is sharded. + if rhs_spec_new[rhs_inner_dim] is not None: out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 01ee60f24b..3b562e4ffa 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -8,13 +8,11 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike -from jax.ad_checkpoint import checkpoint_name from .fp8 import FP8Helper, FP8MetaPackage from .cpp_extensions import ( gemm_impl, fp8_gemm_impl, - cast_fp8, cast_transpose, dact_lu, dbias_cast_transpose, @@ -68,6 +66,10 @@ def _gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: + assert kernel.ndim == 2, ( + "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + ) + fuse_bias = bias is not None out, pre_gelu_out = gemm_impl( @@ -142,7 +144,7 @@ def _gemm_bwd_rule( def fp8_gemm( x: ArrayLike, - kernel: ArrayLike, + kernel_t: ArrayLike, fp8_meta: FP8MetaPackage, bias: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, @@ -150,9 +152,10 @@ def fp8_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: + """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" return _fp8_gemm( x, - kernel, + kernel_t, bias, fp8_meta.amax_list, fp8_meta.scale_list, @@ -175,7 +178,6 @@ def _fp8_gemm( accumulate: bool, use_split_accumulator: bool, ) -> ArrayLike: - """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" out, _ = _fp8_gemm_fwd_rule( x, kernel_t, @@ -201,6 +203,10 @@ def _fp8_gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: + assert kernel_t.ndim == 2, ( + "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + ) + fuse_bias = bias is not None maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( @@ -432,7 +438,7 @@ def type_safe_gemm( if fp8_meta is not None: x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2, ( + assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 1, ( "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + "i.e. contracting_dims=(-1, -1)." ) From 2b2753e2463ce788f5f7c582e898a304156b4f54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:58:03 +0000 Subject: [PATCH 09/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/gemm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 3b562e4ffa..730d17846e 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -66,9 +66,9 @@ def _gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: - assert kernel.ndim == 2, ( - "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." - ) + assert ( + kernel.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." fuse_bias = bias is not None @@ -203,9 +203,9 @@ def _fp8_gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: - assert kernel_t.ndim == 2, ( - "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." - ) + assert ( + kernel_t.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." fuse_bias = bias is not None From 969f597cb11fe9fd5b9780e57e818d402704fc0c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 09:28:44 +0000 Subject: [PATCH 10/19] fixed FSDP+TP w/ DP=1 and TP+DP, but FSDP+TP w/ DP>1 still crashes Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 283 +++++++++++------- transformer_engine/jax/gemm.py | 29 +- 2 files changed, 205 insertions(+), 107 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 431dea6c1d..bf80941f85 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -4,7 +4,8 @@ import warnings import operator from functools import reduce -from typing import Optional, Union, Tuple +from typing import Optional, Tuple +from collections.abc import Iterable import jax import jax.numpy as jnp @@ -42,6 +43,34 @@ def sanitize_dims(dim, ndims): return (ndims + dim) if dim < 0 else dim +def mirror_dim(dim, ndims): + return ndims - 2 if dim == ndims - 1 else ndims - 1 + + +def remove_fsdp_specs(pspecs): + fsdp_resource = global_mesh_resource().fsdp_resource + new_pspecs = [] + for spec in pspecs: + if spec is None: + new_pspecs.append(None) + elif fsdp_resource not in spec: + new_pspecs.append(spec) + elif isinstance(spec, Iterable) and not isinstance(spec, str): + new_spec = [] + for s in spec: + if s != fsdp_resource: + new_spec.append(s) + if len(new_spec) > 1: + new_pspecs.append(new_spec) + elif len(new_spec) == 1: + new_pspecs.append(new_spec[0]) + else: + new_pspecs.append(None) + else: + new_pspecs.append(None) + return new_pspecs + + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if tex.get_device_compute_capability() >= 90: @@ -55,7 +84,7 @@ class CollectiveGemmPrimitive(BasePrimitive): """ name = "te_gemm" - impl_static_args = (8, 9, 10, 11, 12, 13, 14) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) multiple_results = True inner_primitive = None outer_primitive = None @@ -71,6 +100,7 @@ def abstract( out_amax_aval, out_scale_aval, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -137,33 +167,40 @@ def abstract( # Make sure leading dimensions of RHS is broadcast-compatible with LHS lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs_aval.ndim, rhs_aval.ndim), ) - lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] ] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) - if rhs_aval.ndim > 2: - rhs_bdims = [ - dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] - ] - rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] - rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) - if rhs_batch_size > 1: + + # Infer output shape + if batched_output: + assert lhs_aval.ndim > 2 and rhs_aval.ndim == 2, ( + "Batched output requires batched LHS and non-batched RHS operands." + ) + out_shape = ( + *lhs_batch_shape, + lhs_aval.shape[lhs_outer_dim], + rhs_aval.shape[rhs_outer_dim] + ) + else: + assert lhs_aval.ndim == rhs_aval.ndim, ( + "Non-batched output requires LHS and RHS operands with same number of dimensions." + ) + if lhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) assert lhs_batch_size == rhs_batch_size, ( f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " + f"with the leading dimensions of LHS ({lhs_aval.shape=})." ) - - # Infer output shape: - out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) - if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: - # When both RHS and LHS are batched, the batch dimensions are collapsed into the - # contracting dimension. out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output @@ -237,6 +274,7 @@ def lowering( out_scale, *, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -247,6 +285,7 @@ def lowering( """ Fused attention fwd lowering rules """ + del batched_output lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in lhs_inner_dim, rhs_inner_dim = map( sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) @@ -300,9 +339,9 @@ def lowering( args = CustomCallArgsWrapper(out_types, operands, operand_shapes) lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), - (lhs_aval.ndim, rhs_aval.ndim), + (lhs.ndim, rhs.ndim), ) m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] @@ -346,6 +385,7 @@ def impl( out_amax, out_scale, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -356,25 +396,59 @@ def impl( assert CollectiveGemmPrimitive.inner_primitive is not None lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == rhs.ndim - 1 - - # Collapse batch dimensions that are larger thanm size 1. - # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) - # DGRAD: (B, M, N) x (K, N)^T = (B*M, N) x (N, K) = (B*M, K) - # WGRAD: (B, M, K)^T x (B, M, N) = (K, B*M) x (B*M, N) = (K, N) - batch_shape = [lhs.shape[dim] for dim in range(lhs.ndim - 2)] - batch_size = reduce(operator.mul, batch_shape, 1) - reshape_output = not (lhs.ndim > 2 and rhs.ndim > 2) - if lhs.ndim > 2: - lhs_2d_shape = (batch_size * lhs.shape[-2], lhs.shape[-1]) - lhs = jax.lax.reshape(lhs, lhs_2d_shape) - contracting_dims = (0 if lhs_trans else 1, contracting_dims[1]) - if rhs.ndim > 2: - rhs_2d_shape = (reduce(operator.mul, rhs.shape[:-1], 1), rhs.shape[-1]) - rhs = jax.lax.reshape(rhs, rhs_2d_shape) - contracting_dims = (contracting_dims[0], 1 if rhs_trans else 0) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim) + ) + + # Infer output shape and collapse batch dimensions + lhs_2d_shape = rhs_2d_shape = None + lhs_layout = rhs_layout = None + lhs_batch_dims = [ + dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim] + ] + lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + contracting_dims_2d = list(contracting_dims).copy() + if batched_output: + # If output is batched, the LSH batch dimension collapses into the outer dimension + # and RHS cannot be batched + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 + else: + # If the output is not batched, both LHS and RHS batch dimensions collapse into the + # contracting dimensions + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) + lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) + contracting_dims_2d[0] = 0 + + rhs_batch_dims = [ + dim for dim in range(rhs.ndim) if dim not in [rhs_inner_dim, rhs_outer_dim] + ] + rhs_batch_shape = [rhs.shape[dim] for dim in rhs_batch_dims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) + rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) + contracting_dims_2d[1] = 0 + + # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM + if lhs_2d_shape is not None and lhs.ndim > 2: + lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout) + if jax_dtype_is_fp8(lhs.dtype): + lhs = jax.lax.transpose(lhs, (1, 0)) + contracting_dims_2d[0] = 1 + else: + contracting_dims_2d[0] = contracting_dims[0] + + if rhs_2d_shape is not None and rhs.ndim > 2: + rhs = jax.lax.reshape(rhs, rhs_2d_shape, dimensions=rhs_layout) + if jax_dtype_is_fp8(rhs.dtype): + rhs = jax.lax.transpose(rhs, (1, 0)) + contracting_dims_2d[1] = 1 + else: + contracting_dims_2d[1] = contracting_dims[1] + # Invoke GEMM with guaranteed 2D inputs, so batched_output=False ( out, out_amax_updated, @@ -392,7 +466,8 @@ def impl( out_amax, out_scale, out_dtype=out_dtype, - contracting_dims=contracting_dims, + batched_output=False, + contracting_dims=contracting_dims_2d, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, grad=grad, @@ -401,9 +476,9 @@ def impl( ) # Recover batched dimensions in the output - if reshape_output: - out_batched_shape = (*batch_shape, int(out.shape[-2] / batch_size), out.shape[-1]) - out = jax.lax.reshape(out, out_batched_shape) + if batched_output: + out_shape = (*lhs_batch_shape, out.shape[-2] // lhs_batch_size, out.shape[-1]) + out = jax.lax.reshape(out, out_shape) return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @@ -413,6 +488,7 @@ def batcher( batch_dims, *, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -428,6 +504,7 @@ def batcher( CollectiveGemmPrimitive.outer_primitive.bind( *batched_args, out_dtype=out_dtype, + batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -441,6 +518,7 @@ def batcher( @staticmethod def infer_sharding_from_operands( out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -456,34 +534,43 @@ def infer_sharding_from_operands( lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs: + # - FSDP axes are all-gathered + # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded + # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension + lhs_spec_new = remove_fsdp_specs(lhs_spec) + rhs_spec_new = remove_fsdp_specs(rhs_spec) + if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: warnings.warn( "Forcing the inner dimension of LHS to match the sharding of inner " + "dimension of RHS. This can trigger additional communication if LHS is " + "not already partitioned correctly." ) - - lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, - (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim), - ) - rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] - rhs_outer_spec = rhs_spec[rhs_outer_dim] - - # Outer (sequence) dimension of the GEMM output is always unsharded - out_spec = [*batch_specs, None, rhs_outer_spec] - if lhs.ndim > 2 and rhs.ndim > 2: - out_spec = [None, rhs_outer_spec] + rhs_outer_spec = rhs_spec_new[rhs_outer_dim] + if rhs_outer_spec is not None: + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + lhs_outer_spec = lhs_spec_new[lhs_outer_dim] + out_spec = [lhs_outer_spec, rhs_outer_spec] + if batched_output: + out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) - # Pre-GELU output matches output, if GELU fusion is turned on, otherwise unsharded - gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) # Bias gradient spec matches outer dimension of output if bias fusion is turned on @@ -493,6 +580,7 @@ def infer_sharding_from_operands( @staticmethod def partition( out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -509,41 +597,22 @@ def partition( lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim), ) - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] - - # Force all-gather the outer (sequence) dimension of the LHS operand - lhs_spec_new = [spec for spec in lhs_spec] - lhs_spec_new[lhs_outer_dim] = None - lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] - lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) - - # If both dims of RHS is sharded (i.e. FSDP), determine if we do AG or AR based on LHS - # sharding. - rhs_spec_new = [spec for spec in rhs_spec] - if rhs_spec[rhs_inner_dim] is not None and rhs_spec[rhs_outer_dim] is not None: - if lhs_spec[lhs_inner_dim] is not None and lhs_spec[lhs_outer_dim] is not None: - # All dimensions of both LHS and RHS are sharded and the collective operation is - # ambiguous, we cannot infer sharding. - raise RuntimeError( - "Collective GEMM custom op cannot infer partitioning when both outer and " - + "contracting dimensions of both LHS and RHS operands are sharded." - ) - elif lhs_spec[lhs_inner_dim] is not None: - # All-reduce after GEMM, so unshard the outer dimension of RHS - rhs_spec_new[rhs_outer_dim] = None - else: - # We either do all-gather before GEMM, or LHS is already unsharded, so unshard - # the inner dimension of RHS to match - rhs_spec_new[rhs_inner_dim] = None + # Modify operand specs: + # - FSDP axes are all-gathered + # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded + # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension + lhs_spec_new = remove_fsdp_specs(lhs_spec) + rhs_spec_new = remove_fsdp_specs(rhs_spec) rhs_outer_spec = rhs_spec_new[rhs_outer_dim] - - # RHS operand is unchanged, we already enforce that only one dimension can be sharded + if rhs_outer_spec is not None: + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) # Bias is sharded to match outer dimension spec of the RHS operand (also the output) @@ -552,14 +621,17 @@ def partition( # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) - # Outer (sequence) dimension of the GEMM output is always unsharded - out_spec = [*batch_specs, None, rhs_outer_spec] - if lhs.ndim > 2 and rhs.ndim > 2: - out_spec = [None, rhs_outer_spec] + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + lhs_outer_spec = lhs_spec_new[lhs_outer_dim] + out_spec = [lhs_outer_spec, rhs_outer_spec] + if batched_output: + out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) - # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded - gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) arg_shardings = ( @@ -599,6 +671,7 @@ def sharded_impl( out_amax, out_scale, out_dtype=out_dtype, + batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -637,6 +710,7 @@ def fp8_gemm_impl( out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, + batched_output: bool = False, fuse_gelu: bool = False, fuse_bias: bool = False, accumulate: bool = False, @@ -657,8 +731,8 @@ def fp8_gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: - out_shape = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]) - gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) + gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) + gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( lhs, @@ -670,6 +744,7 @@ def fp8_gemm_impl( out_amax, out_scale, out_dtype=out_dtype, + batched_output=batched_output, contracting_dims=(-1, -1), fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -686,7 +761,8 @@ def gemm_impl( rhs: ArrayLike, bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, - contracting_dims: Tuple[int, int] = (1, 0), + batched_output: bool = False, + contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, fuse_bias: bool = False, grad: bool = False, @@ -696,16 +772,15 @@ def gemm_impl( """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim), ) - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: bias = jnp.zeros(0, dtype=lhs.dtype) elif grad: - bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) + bias = jnp.zeros(rhs.shape[rhs_outer_dim], dtype=lhs.dtype) else: assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." @@ -716,7 +791,10 @@ def gemm_impl( gelu_input is not None ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: - gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_size = reduce(operator.mul, [lhs.shape[dim] for dim in bdims], 1) + gelu_shape = (batch_size * lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(gelu_shape, dtype=lhs.dtypes) dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( @@ -729,6 +807,7 @@ def gemm_impl( dummy_fp8_meta, dummy_fp8_meta, out_dtype=lhs.dtype, + batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 730d17846e..18d1f76da7 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -18,7 +18,7 @@ dbias_cast_transpose, dact_lu_dbias_cast_transpose, ) -from .cpp_extensions.gemm import sanitize_dims +from .cpp_extensions.gemm import sanitize_dims, mirror_dim __all__ = [ @@ -72,10 +72,13 @@ def _gemm_fwd_rule( fuse_bias = bias is not None + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) out, pre_gelu_out = gemm_impl( x, kernel, bias=bias, + batched_output=(x.ndim > 2), contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -103,14 +106,22 @@ def _gemm_bwd_rule( ): x, kernel, pre_gelu_out, fuse_bias = ctx x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - x_outer_dim = x.ndim - 1 if x_inner_dim != x.ndim - 1 else x.ndim - 2 - kernel_outer_dim = kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 + x_outer_dim, kernel_outer_dim = map( + mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) + ) + + # FWD MODE: + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) - # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) + # DGRAD: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T --(AR)--> ([B], M, K) + # GEMM+AR: ([B], M, N) x (K/P, N)^T --------> ([B], M, K/P) dgrad, dgelu, _ = gemm_impl( grad, kernel, gelu_input=pre_gelu_out, + batched_output=(x.ndim > 2), contracting_dims=(-1, kernel_outer_dim), fuse_gelu=fuse_gelu, fuse_bias=False, @@ -119,12 +130,15 @@ def _gemm_bwd_rule( use_split_accumulator=use_split_accumulator, ) - # WGRAD: ([B], M, K)^T x ([B], M, N) = (K, N) + # WGRAD: + # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # GEMM+AR: ([B], M, K/P)^T x ([B], M, N) ----> (K/P, N) wgrad_rhs = dgelu if fuse_gelu else grad wgrad, _, bgrad = gemm_impl( x, wgrad_rhs, gelu_input=pre_gelu_out, + batched_output=False, contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), fuse_gelu=False, fuse_bias=fuse_bias, @@ -279,6 +293,7 @@ def _fp8_gemm_fwd_rule( out_amax=out_amax, out_scale=out_scale, out_dtype=out_dtype, + batched_output=(x.ndim > 2), fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, accumulate=accumulate, @@ -300,6 +315,7 @@ def _fp8_gemm_fwd_rule( pre_gelu_out if fuse_gelu else None, fuse_bias, maybe_fp32_to_fm32, + (x.ndim > 2), ) return (out, updated_out_scale), ctx @@ -325,6 +341,7 @@ def _fp8_gemm_bwd_rule( pre_gelu_out, fuse_bias, maybe_fp32_to_fm32, + batched_input, ) = ctx bwd_dtype = FP8Helper.BWD_DTYPE @@ -382,6 +399,7 @@ def _fp8_gemm_bwd_rule( grad_scale_inv, casted_kernel, kernel_scale_inv, + batched_output=batched_input, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) @@ -392,6 +410,7 @@ def _fp8_gemm_bwd_rule( x_scale_inv, casted_grad_t, grad_scale_inv, + out_shape=False, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) From ce86dcb9c5d55c409ac92f9d8bafb0b7f01bc042 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 11:38:22 +0000 Subject: [PATCH 11/19] fixed logic to remove FSDP sharding Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index bf80941f85..d54009e60b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -49,25 +49,44 @@ def mirror_dim(dim, ndims): def remove_fsdp_specs(pspecs): fsdp_resource = global_mesh_resource().fsdp_resource + if fsdp_resource is None: + return list(pspecs).copy() + new_pspecs = [] for spec in pspecs: if spec is None: new_pspecs.append(None) - elif fsdp_resource not in spec: - new_pspecs.append(spec) + elif isinstance(spec, Iterable) and not isinstance(spec, str): new_spec = [] for s in spec: - if s != fsdp_resource: + if s == fsdp_resource: + new_spec.append(None) + else: new_spec.append(s) + if len(new_spec) > 1: new_pspecs.append(new_spec) elif len(new_spec) == 1: new_pspecs.append(new_spec[0]) else: new_pspecs.append(None) + + elif isinstance(spec, str): + if spec == fsdp_resource: + new_pspecs.append(None) + else: + new_pspecs.append(spec) + else: - new_pspecs.append(None) + new_pspecs.append(spec) + + assert len(new_pspecs) == len(pspecs), ( + "Length of partition specs changed when removing FSDP sharding!\n" + + f"Original: {pspecs}\n" + + f"Filtered: {new_pspecs}\n" + ) + return new_pspecs From b215f207bd78acfd672264f7e52880a0a8137598 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:38:49 +0000 Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d54009e60b..3c4bf15d00 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -198,18 +198,18 @@ def abstract( # Infer output shape if batched_output: - assert lhs_aval.ndim > 2 and rhs_aval.ndim == 2, ( - "Batched output requires batched LHS and non-batched RHS operands." - ) + assert ( + lhs_aval.ndim > 2 and rhs_aval.ndim == 2 + ), "Batched output requires batched LHS and non-batched RHS operands." out_shape = ( *lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], - rhs_aval.shape[rhs_outer_dim] + rhs_aval.shape[rhs_outer_dim], ) else: - assert lhs_aval.ndim == rhs_aval.ndim, ( - "Non-batched output requires LHS and RHS operands with same number of dimensions." - ) + assert ( + lhs_aval.ndim == rhs_aval.ndim + ), "Non-batched output requires LHS and RHS operands with same number of dimensions." if lhs_aval.ndim > 2: rhs_bdims = [ dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] From cbab16c03109cf5b802b93adc03841828df332dd Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 19:03:52 +0000 Subject: [PATCH 13/19] retained FSDP dims and pushed FSDP all-gather of weight array to outside the custom op Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 52 ++----------------- transformer_engine/jax/gemm.py | 1 + 2 files changed, 6 insertions(+), 47 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3c4bf15d00..353f2d2509 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -47,49 +47,6 @@ def mirror_dim(dim, ndims): return ndims - 2 if dim == ndims - 1 else ndims - 1 -def remove_fsdp_specs(pspecs): - fsdp_resource = global_mesh_resource().fsdp_resource - if fsdp_resource is None: - return list(pspecs).copy() - - new_pspecs = [] - for spec in pspecs: - if spec is None: - new_pspecs.append(None) - - elif isinstance(spec, Iterable) and not isinstance(spec, str): - new_spec = [] - for s in spec: - if s == fsdp_resource: - new_spec.append(None) - else: - new_spec.append(s) - - if len(new_spec) > 1: - new_pspecs.append(new_spec) - elif len(new_spec) == 1: - new_pspecs.append(new_spec[0]) - else: - new_pspecs.append(None) - - elif isinstance(spec, str): - if spec == fsdp_resource: - new_pspecs.append(None) - else: - new_pspecs.append(spec) - - else: - new_pspecs.append(spec) - - assert len(new_pspecs) == len(pspecs), ( - "Length of partition specs changed when removing FSDP sharding!\n" - + f"Original: {pspecs}\n" - + f"Filtered: {new_pspecs}\n" - ) - - return new_pspecs - - def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if tex.get_device_compute_capability() >= 90: @@ -563,8 +520,8 @@ def infer_sharding_from_operands( # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = remove_fsdp_specs(lhs_spec) - rhs_spec_new = remove_fsdp_specs(rhs_spec) + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: warnings.warn( "Forcing the inner dimension of LHS to match the sharding of inner " @@ -594,6 +551,7 @@ def infer_sharding_from_operands( # Bias gradient spec matches outer dimension of output if bias fusion is turned on bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @staticmethod @@ -625,8 +583,8 @@ def partition( # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = remove_fsdp_specs(lhs_spec) - rhs_spec_new = remove_fsdp_specs(rhs_spec) + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] rhs_outer_spec = rhs_spec_new[rhs_outer_dim] if rhs_outer_spec is not None: lhs_spec_new[lhs_outer_dim] = None diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 18d1f76da7..464ccb12f9 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike +from jax.sharding import NamedSharding, PartitionSpec from .fp8 import FP8Helper, FP8MetaPackage from .cpp_extensions import ( From 0ea55c0eed1c5551a8b8872ff095d70d9e5d1625 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 19:46:02 +0000 Subject: [PATCH 14/19] Added useful warning about DGRAD sharding not matching sequence/context-parallel LHS operands Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 353f2d2509..823e9f7ea1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -53,7 +53,6 @@ def get_cublas_workspace_size_bytes() -> None: return 33_554_432 return 4_194_304 - class CollectiveGemmPrimitive(BasePrimitive): """ cuBlasLt GEMM Primitive w/ support for distributed inputs @@ -385,15 +384,9 @@ def impl( lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) contracting_dims_2d = list(contracting_dims).copy() - if batched_output: - # If output is batched, the LSH batch dimension collapses into the outer dimension - # and RHS cannot be batched - lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) - lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) - contracting_dims_2d[0] = 1 - else: - # If the output is not batched, both LHS and RHS batch dimensions collapse into the - # contracting dimensions + if lhs.ndim > 2 and rhs.ndim > 2: + # If both LHS and RHS are batched, the batch dimensions collapse into the + # contracting dimensions for both operands lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) contracting_dims_2d[0] = 0 @@ -406,6 +399,11 @@ def impl( rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) contracting_dims_2d[1] = 0 + elif lhs.ndim > 2: + # If only the LHS is batched,the batch dimension collapses into the outer dimension + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM if lhs_2d_shape is not None and lhs.ndim > 2: @@ -524,12 +522,17 @@ def infer_sharding_from_operands( rhs_spec_new = [spec for spec in rhs_spec] if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: warnings.warn( - "Forcing the inner dimension of LHS to match the sharding of inner " - + "dimension of RHS. This can trigger additional communication if LHS is " - + "not already partitioned correctly." + "Forcing LHS sharding in the contracting dimension to match RHS. This can trigger " + + "additional communication if LHS is not already partitioned correctly." ) rhs_outer_spec = rhs_spec_new[rhs_outer_dim] if rhs_outer_spec is not None: + warnings.warn( + "Forcing the outer dimension of LHS (sequence/context dim) to be all- gathered. " + + "This may trigger additional communication if LHS is not already partitioned " + + "correctly. Additionally, the DGRAD output in the backward pass will not match " + + "the sharding of a sequence/context-parallel LHS operand." + ) lhs_spec_new[lhs_outer_dim] = None lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] @@ -661,8 +664,8 @@ def sharded_impl( if jax_dtype_is_fp8(lhs.dtype): out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) - # GEMM output needs to be all-reduced when the contracting dimension is sharded. if rhs_spec_new[rhs_inner_dim] is not None: + # GEMM output needs to be all-reduced when the contracting dimension is sharded. out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( From 2acb92f49b4687fde25f803f3115b693b900569b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:46:35 +0000 Subject: [PATCH 15/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 823e9f7ea1..31a8760564 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -53,6 +53,7 @@ def get_cublas_workspace_size_bytes() -> None: return 33_554_432 return 4_194_304 + class CollectiveGemmPrimitive(BasePrimitive): """ cuBlasLt GEMM Primitive w/ support for distributed inputs From b07bb2db5726d45dba28b7207bbc2051f166d8c4 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 19:47:44 +0000 Subject: [PATCH 16/19] documentation fixes Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 31a8760564..0f567eecef 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -516,7 +516,6 @@ def infer_sharding_from_operands( ) # Modify operand specs: - # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension lhs_spec_new = [spec for spec in lhs_spec] @@ -584,7 +583,6 @@ def partition( ) # Modify operand specs: - # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension lhs_spec_new = [spec for spec in lhs_spec] From 765b844525e42d2def624bce7430f798828874d9 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 27 Nov 2024 21:54:39 +0000 Subject: [PATCH 17/19] added unit test, both AG+GEMM and GEMM+AR passing with FSDP+TP, DP+TP and TP-only meshes Signed-off-by: Alp Dener --- tests/jax/test_distributed_gemm.py | 311 ++++++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 107 +++--- transformer_engine/jax/gemm.py | 31 +- 3 files changed, 400 insertions(+), 49 deletions(-) create mode 100644 tests/jax/test_distributed_gemm.py diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py new file mode 100644 index 0000000000..f1e3c58c4a --- /dev/null +++ b/tests/jax/test_distributed_gemm.py @@ -0,0 +1,311 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest +from functools import partial +from collections.abc import Iterable + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.jax.gemm import gemm + +from utils import assert_allclose + + +jax.config.update('jax_enable_compilation_cache', False) + + +# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) +# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128) +# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P) + +# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128) +# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P) +# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128) + +BATCH = 4 +BASE_SIZE = 16 +SEQ_LEN = BASE_SIZE * 8 +HIDDEN_SIZE = BASE_SIZE * 6 +FFN_HIDDEN_SIZE = BASE_SIZE * 16 + +COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"] +MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"] +NUM_DEVICES = 4 + +is_fp8_supported, no_fp8_reason = te.fp8.is_fp8_available() + + +def _get_mesh(parallel_dist): + jax.clear_caches() + + batched = False + fsdp = False + mesh_shape = dict(tp=NUM_DEVICES) + resources = dict(cp_resource='tp', tp_resource='tp') + if parallel_dist in ["DP_TP", "FSDP_TP"]: + batched = True + mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=NUM_DEVICES//2)) + resources.update(dict(dp_resource='dp')) + if parallel_dist == "FSDP_TP": + fsdp = True + mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=1, zp=NUM_DEVICES//2)) + resources.update(dict(fsdp_resource='zp')) + mesh_resource = te.MeshResource(**resources) + + devices = mesh_utils.create_device_mesh( + (NUM_DEVICES, ), devices=jax.devices()[:NUM_DEVICES] + ) + + mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) + + return mesh, mesh_resource, batched, fsdp + + +def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False): + fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + # Operand and output shapes + lhs_shape = ( + [SEQ_LEN, HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [SEQ_LEN, FFN_HIDDEN_SIZE] + ) + rhs_shape = ( + [HIDDEN_SIZE, FFN_HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [FFN_HIDDEN_SIZE, HIDDEN_SIZE] + ) + out_shape = [lhs_shape[0], rhs_shape[1]] + + if batched: + lhs_shape = [BATCH] + lhs_shape + out_shape = [BATCH] + out_shape + + # Operand and output partition specs + lhs_spec = ( + [mesh_resource.tp_resource, None] + if fwd_comm_type == "ALL_GATHER" + else [None, mesh_resource.tp_resource] + ) + rhs_spec = ( + [None, mesh_resource.tp_resource] + if fwd_comm_type == "ALL_GATHER" + else [mesh_resource.tp_resource, None] + ) + out_spec = [None, rhs_spec[-1]] + + # Modify RHS operand for FP8 + fsdp_gathered_rhs_spec = rhs_spec.copy() + if fp8_gemm: + rhs_shape = list(reversed(rhs_shape)) + rhs_spec = list(reversed(rhs_spec)) + fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec)) + + # Add batch dimensions and specs + if batched: + if fsdp: + lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec + rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec] + out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec + else: + lhs_spec = [mesh_resource.dp_resource] + lhs_spec + out_spec = [mesh_resource.dp_resource] + out_spec + + # Allocate global operands on device + key = jax.random.PRNGKey(42) + split_keys = jax.random.split(key, 3 if fwd_bwd else 2) + mu = 0.0 + sigma = 0.023 + shapes = (lhs_shape, rhs_shape) + if fwd_bwd: + shapes += (out_shape, ) + global_operands = list( + map( + lambda key, shape: jax.device_put( + mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), + NamedSharding(mesh, PartitionSpec(None)) + ), + split_keys, + shapes, + ) + ) + + # Allocate sharded operands on device + partition_axes = (lhs_spec, rhs_spec) + if fwd_bwd: + partition_axes += (out_spec, ) + local_operands = list( + map( + lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), + global_operands, + partition_axes, + ) + ) + + # Tranpose global RHS back to non-transpoosed orientation if it was originally allocated + # for FP8 GEMM + if fp8_gemm: + rhs_global = jnp.matrix_transpose(global_operands[1]) + global_operands = (global_operands[0], rhs_global, *global_operands[2:]) + + return ( + local_operands, + global_operands, + (out_shape, out_spec), + fsdp_gathered_rhs_spec, + ) + + +def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False): + num_operands = 3 if fwd_bwd else 2 + ref_operands = tensors[:num_operands] + test_outputs = tensors[num_operands:] + + # Check number of dimensions + assert test_outputs[0].ndim == len(expected_out_shape), ( + f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected " + + f"({len(expected_out_shape)})" + ) + + # Pad test output spec for unsharded dimensions + test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim) + + for i in range(test_outputs[0].ndim): + # Check shape + assert test_outputs[0].shape[i] == expected_out_shape[i], ( + f"Output with shape {test_outputs[0].shape} does not match expected shape " + + f"{expected_out_shape} in dimension index {i}." + ) + + # Check shardings (with padded output spec) + spec_mismatch = False + if isinstance(expected_out_specs[i], str): + if test_spec[i] != expected_out_specs[i]: + spec_mismatch = True + elif isinstance(expected_out_specs[i], Iterable): + if not isinstance(test_spec[i], type(expected_out_specs[i])): + if test_spec[i] not in expected_out_specs[i]: + spec_mismatch = True + elif len(test_spec[i]) != len(expected_out_specs[i]): + spec_mismatch = True + else: + for j in range(len(expected_out_specs[i])): + if test_spec[i][j] != expected_out_specs[i][j]: + spec_mismatch = True + break + elif expected_out_specs[i] == None: + if test_spec[i] != None: + spec_mismatch = True + else: + raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.") + if spec_mismatch: + raise AssertionError( + f"Output sharding {test_spec} does not match expected sharding " + + f"{expected_out_specs} in dimension index {i}." + ) + + def _native_gemm_fwd_bwd(lhs, rhs, grad): + fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs) + lhs_grad, rhs_grad = vjp_fn(grad) + return fwd_out, lhs_grad, rhs_grad + + ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot) + + out_names = ["output"] + ref_outputs = ref_fn(*ref_operands) + if not fwd_bwd: + ref_outputs = [ref_outputs] + else: + out_names += ["dgrad", "wgrad"] + + for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)): + test_out_global = jax.lax.with_sharding_constraint( + test_out, NamedSharding(mesh, PartitionSpec(None)) + ) + try: + assert_allclose(ref_out, test_out_global) + except AssertionError as err: + raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err)) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_impl(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs( + mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp + ) + + @jax.jit + def _test_fn(lhs, rhs): + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + return te.cpp_extensions.gemm_impl(lhs, rhs_no_fsdp, batched_output=batched) + + with te.sharding.global_shard_guard(mesh_resource): + output, *_ = _test_fn(*local_operands) + + _check_output(mesh, *output_info, *global_operands, output) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_fwd_bwd(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs( + mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True + ) + + @jax.jit + def _test_fn(lhs, rhs, grad): + # Gather weights in FSDP axis + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + + # FWD pass + fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp) + + # BWD pass + lhs_grad, rhs_grad = vjp_fn(grad) + + return fwd_out, lhs_grad, rhs_grad + + print( + f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n" + + f" LHS sharding: {local_operands[0].sharding.spec}\n" + + f" RHS sharding: {local_operands[1].sharding.spec}\n" + ) + + with te.sharding.global_shard_guard(mesh_resource): + output, dgrad, wgrad = _test_fn(*local_operands) + + print( + f"{'AG + GEMM' if comm_type == 'AG' else 'GEMM + AR'} output: " + + f"{output.shape} | {output.sharding.spec}\n" + + f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n" + + f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n" + ) + + _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) + diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0f567eecef..30ff0ca54a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -516,31 +516,53 @@ def infer_sharding_from_operands( ) # Modify operand specs: - # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded - # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = [spec for spec in lhs_spec] - rhs_spec_new = [spec for spec in rhs_spec] - if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: - warnings.warn( - "Forcing LHS sharding in the contracting dimension to match RHS. This can trigger " - + "additional communication if LHS is not already partitioned correctly." + # - If contracting dimensions of both operands are sharded, force them to match. + # - If contracting dimensions of both operands are sharded, all-gather outer dimensions. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + assert lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim], ( + "Contracting dimensions of LHS and RHS operands must have the same sharding." ) - rhs_outer_spec = rhs_spec_new[rhs_outer_dim] - if rhs_outer_spec is not None: - warnings.warn( - "Forcing the outer dimension of LHS (sequence/context dim) to be all- gathered. " - + "This may trigger additional communication if LHS is not already partitioned " - + "correctly. Additionally, the DGRAD output in the backward pass will not match " - + "the sharding of a sequence/context-parallel LHS operand." - ) - lhs_spec_new[lhs_outer_dim] = None - lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + if lhs_spec_new[lhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the LHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + + if rhs_spec_new[rhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the RHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + rhs_spec_new[rhs_outer_dim] = None + else: + if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: + warnings.warn( + "Contracting dimension of the RHS operand must be all-gathered when the " + + "contracting dimension of the LHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: + if not grad: + # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + warnings.warn( + "Contracting dimension of the LHS operand must be all-gathered when the " + + "contracting dimension of the RHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] # Output sharding is conditional on output shape lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] - lhs_outer_spec = lhs_spec_new[lhs_outer_dim] - out_spec = [lhs_outer_spec, rhs_outer_spec] + out_spec = [None, out_col_spec] if batched_output: out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) @@ -549,11 +571,11 @@ def infer_sharding_from_operands( fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded - gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) # Bias gradient spec matches outer dimension of output if bias fusion is turned on - bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @@ -583,19 +605,27 @@ def partition( ) # Modify operand specs: - # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded - # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = [spec for spec in lhs_spec] - rhs_spec_new = [spec for spec in rhs_spec] - rhs_outer_spec = rhs_spec_new[rhs_outer_dim] - if rhs_outer_spec is not None: - lhs_spec_new[lhs_outer_dim] = None - lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + # - Always all-gather the outer dimension of LHS. + # - If contracting dimensions of both operands are sharded, all-gather RHS outer dimension. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + reduce_output = False + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) # Bias is sharded to match outer dimension spec of the RHS operand (also the output) - bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -603,14 +633,13 @@ def partition( # Output sharding is conditional on output shape lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] - lhs_outer_spec = lhs_spec_new[lhs_outer_dim] - out_spec = [lhs_outer_spec, rhs_outer_spec] + out_spec = [None, out_col_spec] if batched_output: out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded - gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) arg_shardings = ( @@ -663,13 +692,11 @@ def sharded_impl( if jax_dtype_is_fp8(lhs.dtype): out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) - if rhs_spec_new[rhs_inner_dim] is not None: - # GEMM output needs to be all-reduced when the contracting dimension is sharded. - out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) + # All-reduce sum GEMM output when contracting dimensions are sharded + if reduce_output: + out = jax.lax.psum(out, global_mesh_resource().tp_resource) if fuse_gelu: - pre_gelu_out = lax_paral_op( - pre_gelu_out, jax.lax.psum, global_mesh_resource().tp_resource, mesh - ) + pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 464ccb12f9..4cf09a204f 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -33,7 +33,7 @@ def gemm( x: ArrayLike, kernel: ArrayLike, bias: Optional[ArrayLike] = None, - contracting_dims: Tuple[int, int] = (1, 0), + contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, @@ -73,8 +73,11 @@ def _gemm_fwd_rule( fuse_bias = bias is not None - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) - # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) out, pre_gelu_out = gemm_impl( x, kernel, @@ -112,12 +115,18 @@ def _gemm_bwd_rule( ) # FWD MODE: - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) - # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) # DGRAD: - # AG+GEMM: ([B], M, N/P) x (K, N/P)^T --(AR)--> ([B], M, K) - # GEMM+AR: ([B], M, N) x (K/P, N)^T --------> ([B], M, K/P) + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) + # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) + # (DP, None, None) x (TP, None)^T --> (DP, None, TP) dgrad, dgelu, _ = gemm_impl( grad, kernel, @@ -133,7 +142,11 @@ def _gemm_bwd_rule( # WGRAD: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) - # GEMM+AR: ([B], M, K/P)^T x ([B], M, N) ----> (K/P, N) + # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') + # + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) + # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) + # Make XLA scatter output in first dim. wgrad_rhs = dgelu if fuse_gelu else grad wgrad, _, bgrad = gemm_impl( x, @@ -445,7 +458,7 @@ def type_safe_gemm( bias: Optional[ArrayLike] = None, fp8_meta: Optional[FP8MetaPackage] = None, out_dtype: Optional[jnp.dtype] = None, - contracting_dims: Tuple[int, int] = (1, 0), + contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, From 2ce4377702d20d48564383647caede1f2dcf1e6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 21:55:29 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_gemm.py | 35 +++++++------------ transformer_engine/jax/cpp_extensions/gemm.py | 6 ++-- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py index f1e3c58c4a..b246999d8a 100644 --- a/tests/jax/test_distributed_gemm.py +++ b/tests/jax/test_distributed_gemm.py @@ -18,7 +18,7 @@ from utils import assert_allclose -jax.config.update('jax_enable_compilation_cache', False) +jax.config.update("jax_enable_compilation_cache", False) # AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) @@ -48,20 +48,18 @@ def _get_mesh(parallel_dist): batched = False fsdp = False mesh_shape = dict(tp=NUM_DEVICES) - resources = dict(cp_resource='tp', tp_resource='tp') + resources = dict(cp_resource="tp", tp_resource="tp") if parallel_dist in ["DP_TP", "FSDP_TP"]: batched = True - mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=NUM_DEVICES//2)) - resources.update(dict(dp_resource='dp')) + mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=NUM_DEVICES // 2)) + resources.update(dict(dp_resource="dp")) if parallel_dist == "FSDP_TP": fsdp = True - mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=1, zp=NUM_DEVICES//2)) - resources.update(dict(fsdp_resource='zp')) + mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2)) + resources.update(dict(fsdp_resource="zp")) mesh_resource = te.MeshResource(**resources) - devices = mesh_utils.create_device_mesh( - (NUM_DEVICES, ), devices=jax.devices()[:NUM_DEVICES] - ) + devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES]) mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) @@ -73,9 +71,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw # Operand and output shapes lhs_shape = ( - [SEQ_LEN, HIDDEN_SIZE] - if fwd_comm_type == "ALL_GATHER" - else [SEQ_LEN, FFN_HIDDEN_SIZE] + [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE] ) rhs_shape = ( [HIDDEN_SIZE, FFN_HIDDEN_SIZE] @@ -125,12 +121,12 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw sigma = 0.023 shapes = (lhs_shape, rhs_shape) if fwd_bwd: - shapes += (out_shape, ) + shapes += (out_shape,) global_operands = list( map( lambda key, shape: jax.device_put( mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), - NamedSharding(mesh, PartitionSpec(None)) + NamedSharding(mesh, PartitionSpec(None)), ), split_keys, shapes, @@ -140,7 +136,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw # Allocate sharded operands on device partition_axes = (lhs_spec, rhs_spec) if fwd_bwd: - partition_axes += (out_spec, ) + partition_axes += (out_spec,) local_operands = list( map( lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), @@ -245,9 +241,7 @@ def test_gemm_impl(comm_type, mesh_type): global_operands, output_info, fsdp_gathered_rhs_spec, - ) = _get_inputs( - mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp - ) + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp) @jax.jit def _test_fn(lhs, rhs): @@ -272,9 +266,7 @@ def test_gemm_fwd_bwd(comm_type, mesh_type): global_operands, output_info, fsdp_gathered_rhs_spec, - ) = _get_inputs( - mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True - ) + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True) @jax.jit def _test_fn(lhs, rhs, grad): @@ -308,4 +300,3 @@ def _test_fn(lhs, rhs, grad): ) _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) - diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 30ff0ca54a..250e8e0c29 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -525,9 +525,9 @@ def infer_sharding_from_operands( rhs_spec_new = list(rhs_spec).copy() lhs_spec_new[lhs_outer_dim] = None if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: - assert lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim], ( - "Contracting dimensions of LHS and RHS operands must have the same sharding." - ) + assert ( + lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] + ), "Contracting dimensions of LHS and RHS operands must have the same sharding." if lhs_spec_new[lhs_outer_dim] is not None: warnings.warn( "Outer dimension of the LHS operand must be all-gathered when both contracting " From f68d71edc56980932b4a4a07ab7d26c44fdaa4e7 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 5 Dec 2024 21:29:27 +0000 Subject: [PATCH 19/19] restored old test_custom_call_compute.py to remove erroneous changes Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 50 --------------------------- 1 file changed, 50 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 355f587265..20b16c2809 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -25,7 +25,6 @@ _jax_dbias_cast_transpose, ) from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 -from transformer_engine.jax.gemm import fp8_gemm, gemm from transformer_engine.jax import cpp_extensions as tex @@ -416,55 +415,6 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_ ) -class TestGemm: - - @staticmethod - def _generate_inputs(b, m, n, k, dtype): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 3) - a = jax.random.normal(subkeys[0], (b, m, k), dtype) - b = jax.random.normal(subkeys[1], (n, k), dtype) - bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 - bias = jax.random.normal(subkeys[2], (n,), bias_dtype) - return a, b, bias - - @staticmethod - def _generate_fp8_inputs(b, m, n, k, fp8_dtype): - a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) - a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b]) - a_q, b_q = map( - lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), - [(a, a_scale), (b, b_scale)], - ) - return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias - - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - @pytest.mark.parametrize("use_bias", (False, True)) - @pytest.mark.parametrize("do_gelu", (False, True)) - def test_gemm(self, b, m, n, k, use_bias, do_gelu): - a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) - - primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu) - ref_out = jnp.dot(a, b) - if use_bias: - ref_out += bias - if do_gelu: - ref_out = jax.nn.gelu(ref_out) - - assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) - def test_fp8_gemm(self, m, n, k, fp8_dtype): - a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype) - - primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) - ref_out = jnp.dot(a, b) - - assert_allclose(primitive_out, ref_out, dtype=fp8_dtype) - - @pytest.fixture(name="random_inputs") def random_inputs_fixture(shape): key = jax.random.PRNGKey(0)