
Commit

[Mosaic GPU] Handle the swizzle attribute in the lowering of `async_store` and `async_load`

PiperOrigin-RevId: 720129408
dimitar-asenov authored and Google-ML-Automation committed Jan 27, 2025
1 parent 101f18d commit a3a285d
Showing 4 changed files with 22 additions and 5 deletions.
2 changes: 2 additions & 0 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -227,6 +227,7 @@ def _mgpu_async_load_op_lowering_rule(
barrier=barrier,
arrive=load_op.arrive,
uniform=False,
+ swizzle=load_op.swizzle.value,
)
return []

@@ -239,6 +240,7 @@ def _mgpu_async_store_op_lowering_rule(
launch_context.async_copy(
src_ref=store_op.source,
dst_ref=store_op.destination,
+ swizzle=store_op.swizzle.value,
)
return []

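Both lowering rules now read the op's `swizzle` attribute and forward it to `launch_context.async_copy`. Below is a minimal, runnable sketch of that pattern in plain Python; `ToySwizzlingMode`, `ToyAsyncLoad`, and `toy_async_copy` are illustrative stand-ins under assumed enum values, not the real Mosaic GPU API.

```python
import dataclasses
import enum


class ToySwizzlingMode(enum.IntEnum):
  # Assumed byte widths; the dialect's real SwizzlingMode is authoritative.
  kNoSwizzle = 16
  k32ByteSwizzle = 32
  k64ByteSwizzle = 64
  k128ByteSwizzle = 128


@dataclasses.dataclass
class ToyAsyncLoad:
  """Stand-in for the dialect's async_load op, carrying a swizzle attribute."""
  source: str
  destination: str
  swizzle: ToySwizzlingMode


def toy_async_copy(*, src_ref, dst_ref, swizzle):
  """Stand-in for LaunchContext.async_copy: records what the lowering passed."""
  return {"src": src_ref, "dst": dst_ref, "swizzle": int(swizzle)}


def lower_async_load(load_op: ToyAsyncLoad):
  # Mirrors the new lowering rules: the op's attribute value is forwarded unchanged.
  return toy_async_copy(
      src_ref=load_op.source,
      dst_ref=load_op.destination,
      swizzle=load_op.swizzle,
  )


print(lower_async_load(ToyAsyncLoad("a_gmem", "a_smem", ToySwizzlingMode.k128ByteSwizzle)))
```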
14 changes: 12 additions & 2 deletions jax/experimental/mosaic/gpu/launch_context.py
@@ -20,6 +20,7 @@
import math
from typing import Any

+ from jax._src.lib import mosaic_gpu_dialect as mgpu_dialect
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
from jaxlib.mlir.dialects import func
@@ -287,14 +288,19 @@ def init_tma_desc(host_ptr):
)
rank = ref_ty.rank
assert rank * 2 == len(sizes_and_strides)
+ swizzle_arg = (
+     mgpu_dialect.SwizzlingMode.kNoSwizzle
+     if swizzle is None
+     else swizzle
+ )
args = [
host_ptr,
base_ptr,
c(utils.bitwidth(ref_ty.element_type), i64),
c(rank, i64),
utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
- c(0 if swizzle is None else swizzle, i64),
+ c(swizzle_arg, i64),
utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
]
func.call([], "mosaic_gpu_init_tma_desc", args)
@@ -513,7 +519,11 @@ def partition_dim(dim: int, idx: ir.Value, num_chunks: int):
"Async copies require the number of bytes copied along the last"
f" dimension to be divisible by 16, but got {zeroth_bw}"
)
- if swizzle is not None and slice_shape[-1] != (swizzle * 8) // element_bitwidth:
+ if (
+     swizzle is not None
+     and swizzle != mgpu_dialect.SwizzlingMode.kNoSwizzle
+     and slice_shape[-1] != (swizzle * 8) // element_bitwidth
+ ):
raise ValueError(
f"Async copies with {swizzle=} require the last dimension of the"
f" slice to be exactly {swizzle} bytes i.e. "
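Two things change in `launch_context.py`: an unspecified swizzle now defaults to `SwizzlingMode.kNoSwizzle` when the TMA descriptor arguments are packed, and the last-dimension check is skipped for `kNoSwizzle` as well as for `None`. The following self-contained sketch mirrors that logic; the enum byte values are assumptions consistent with the `runtime.cc` change further down.

```python
import enum


class SwizzlingMode(enum.IntEnum):
  # Assumed encoding in bytes; kNoSwizzle == 16 matches the runtime.cc mapping below.
  kNoSwizzle = 16
  k32ByteSwizzle = 32
  k64ByteSwizzle = 64
  k128ByteSwizzle = 128


def resolve_and_check(slice_shape, element_bitwidth, swizzle):
  """Defaults a missing swizzle and validates the last slice dimension."""
  if (
      swizzle is not None
      and swizzle != SwizzlingMode.kNoSwizzle
      and slice_shape[-1] != (swizzle * 8) // element_bitwidth
  ):
    raise ValueError(
        f"Async copies with {swizzle=} require the last dimension of the"
        f" slice to be exactly {(swizzle * 8) // element_bitwidth} elements"
    )
  # Value ultimately handed to the TMA descriptor initializer.
  return SwizzlingMode.kNoSwizzle if swizzle is None else swizzle


# bf16 (16-bit) elements: a 128-byte swizzle needs a last dimension of 64 elements.
print(int(resolve_and_check((128, 64), 16, SwizzlingMode.k128ByteSwizzle)))  # 128
print(int(resolve_and_check((128, 8), 16, None)))                            # 16 (kNoSwizzle)
```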
2 changes: 1 addition & 1 deletion jaxlib/mosaic/gpu/runtime.cc
@@ -122,7 +122,7 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr,
}
cuuint32_t element_strides[5] = {1, 1, 1, 1, 1};
CUtensorMapSwizzle swizzle;
- if (swizzle_bytes == 0) {
+ if (swizzle_bytes == 16) {
swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;
} else if (swizzle_bytes == 32) {
swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
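On the runtime side, the descriptor initializer now treats a swizzle argument of 16 bytes, rather than 0, as "no swizzle", matching the lowering that always passes an explicit `SwizzlingMode` value. A Python rendering of the assumed full mapping follows; only the 16- and 32-byte branches are visible in the hunk, and the 64- and 128-byte cases are presumed to follow the same pattern.

```python
import enum


class TensorMapSwizzle(enum.Enum):
  # Stand-ins for the CUDA driver's CUtensorMapSwizzle values used in runtime.cc.
  NONE = "CU_TENSOR_MAP_SWIZZLE_NONE"
  B32 = "CU_TENSOR_MAP_SWIZZLE_32B"
  B64 = "CU_TENSOR_MAP_SWIZZLE_64B"
  B128 = "CU_TENSOR_MAP_SWIZZLE_128B"


def select_swizzle(swizzle_bytes: int) -> TensorMapSwizzle:
  """16 bytes (the no-swizzle encoding) now selects SWIZZLE_NONE."""
  if swizzle_bytes == 16:
    return TensorMapSwizzle.NONE
  elif swizzle_bytes == 32:
    return TensorMapSwizzle.B32
  elif swizzle_bytes == 64:
    return TensorMapSwizzle.B64
  elif swizzle_bytes == 128:
    return TensorMapSwizzle.B128
  raise ValueError(f"Unsupported swizzle: {swizzle_bytes} bytes")


print(select_swizzle(16).value)  # CU_TENSOR_MAP_SWIZZLE_NONE
```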
9 changes: 7 additions & 2 deletions tests/mosaic/gpu_test.py
@@ -1997,7 +1997,8 @@ def add(ctx, a, b, result, smem):

self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)

- def test_pointwise_kernel_with_tma(self):
+ @parameterized.parameters(*mgpu_dialect.SwizzlingMode)
+ def test_pointwise_kernel_with_tma(self, swizzle):
def add(
ctx: launch_context.LaunchContext,
a_gmem_ref: ir.Value,
@@ -2029,6 +2030,7 @@ def add(
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
+ swizzle=swizzle,
)
mgpu_dialect.async_load(
source=b_gmem_ref,
@@ -2039,6 +2041,7 @@ def add(
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
arrive=False,
+ swizzle=swizzle,
)

tma_barrier.wait()
@@ -2063,12 +2066,14 @@ def add(
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
+ swizzle=swizzle,
)
nvvm.cp_async_bulk_wait_group(0)
utils.warpgroup_barrier()

dtype = jnp.bfloat16
- shape = (128, 128)
+ shape = (128, swizzle*8 // jnp.finfo(dtype).bits)

jax_shape = jax.ShapeDtypeStruct(shape, dtype)
kernel = mgpu.as_gpu_kernel(
add,
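The test is now parameterized over every `SwizzlingMode` and derives the second dimension from the swizzle so that each mode satisfies the last-dimension constraint enforced in `launch_context.async_copy`. Working through `shape = (128, swizzle*8 // jnp.finfo(dtype).bits)` for bfloat16, and assuming swizzle values of 16/32/64/128 bytes:

```python
# For bfloat16, jnp.finfo(dtype).bits == 16, so the last dimension always
# spans exactly `swizzle` bytes -- which is what the async_copy check requires,
# and the 16-byte (no-swizzle) case still meets the 16-byte divisibility rule.
for swizzle_bytes in (16, 32, 64, 128):
  last_dim = swizzle_bytes * 8 // 16
  print(f"swizzle={swizzle_bytes:3d}B -> shape=(128, {last_dim}), "
        f"last-dim bytes={last_dim * 2}")
```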
