From a3a285dddc84b2ab98c9ec27b541b79fe347376e Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Mon, 27 Jan 2025 05:17:31 -0800 Subject: [PATCH] [Mosaic GPU] Handle the `swizzle` attribute in the lowering of `async_store` and `async_load` PiperOrigin-RevId: 720129408 --- jax/experimental/mosaic/gpu/dialect_lowering.py | 2 ++ jax/experimental/mosaic/gpu/launch_context.py | 14 ++++++++++++-- jaxlib/mosaic/gpu/runtime.cc | 2 +- tests/mosaic/gpu_test.py | 9 +++++++-- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index e01ac9c99ce4..a1d1dd825532 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -227,6 +227,7 @@ def _mgpu_async_load_op_lowering_rule( barrier=barrier, arrive=load_op.arrive, uniform=False, + swizzle=load_op.swizzle.value, ) return [] @@ -239,6 +240,7 @@ def _mgpu_async_store_op_lowering_rule( launch_context.async_copy( src_ref=store_op.source, dst_ref=store_op.destination, + swizzle=store_op.swizzle.value, ) return [] diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 29d38b8ab3a7..c6aa419696e0 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -20,6 +20,7 @@ import math from typing import Any +from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import func @@ -287,6 +288,11 @@ def init_tma_desc(host_ptr): ) rank = ref_ty.rank assert rank * 2 == len(sizes_and_strides) + swizzle_arg = ( + mgpu_dialect.SwizzlingMode.kNoSwizzle + if swizzle is None + else swizzle + ) args = [ host_ptr, base_ptr, @@ -294,7 +300,7 @@ def init_tma_desc(host_ptr): c(rank, i64), utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), - c(0 if swizzle is None else swizzle, i64), + c(swizzle_arg, i64), utils.pack_array([c(v, i64) for v in transformed_slice_shape]), ] func.call([], "mosaic_gpu_init_tma_desc", args) @@ -513,7 +519,11 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int): "Async copies require the number of bytes copied along the last" f" dimension to be divisible by 16, but got {zeroth_bw}" ) - if swizzle is not None and slice_shape[-1] != (swizzle * 8) // element_bitwidth: + if ( + swizzle is not None + and swizzle != mgpu_dialect.SwizzlingMode.kNoSwizzle + and slice_shape[-1] != (swizzle * 8) // element_bitwidth + ): raise ValueError( f"Async copies with {swizzle=} require the last dimension of the" f" slice to be exactly {swizzle} bytes i.e. " diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index 6a6bf5a94dfa..ad3cd0e19644 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -122,7 +122,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, } cuuint32_t element_strides[5] = {1, 1, 1, 1, 1}; CUtensorMapSwizzle swizzle; - if (swizzle_bytes == 0) { + if (swizzle_bytes == 16) { swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; } else if (swizzle_bytes == 32) { swizzle = CU_TENSOR_MAP_SWIZZLE_32B; diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 5d3fc148a6e3..113373c21b21 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -1997,7 +1997,8 @@ def add(ctx, a, b, result, smem): self.assertArraysEqual(jax.jit(kernel)(x, y), x + y) - def test_pointwise_kernel_with_tma(self): + @parameterized.parameters(*mgpu_dialect.SwizzlingMode) + def test_pointwise_kernel_with_tma(self, swizzle): def add( ctx: launch_context.LaunchContext, a_gmem_ref: ir.Value, @@ -2029,6 +2030,7 @@ def add( transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), arrive=False, + swizzle=swizzle, ) mgpu_dialect.async_load( source=b_gmem_ref, @@ -2039,6 +2041,7 @@ def add( transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), arrive=False, + swizzle=swizzle, ) tma_barrier.wait() @@ -2063,12 +2066,14 @@ def add( indices=[zero_i32, zero_i32], slice_lengths=shape, transforms=ir.ArrayAttr.get([]), + swizzle=swizzle, ) nvvm.cp_async_bulk_wait_group(0) utils.warpgroup_barrier() dtype = jnp.bfloat16 - shape = (128, 128) + shape = (128, swizzle*8 // jnp.finfo(dtype).bits) + jax_shape = jax.ShapeDtypeStruct(shape, dtype) kernel = mgpu.as_gpu_kernel( add,