From c9dfdb4e2367f489901f09fa7fb2cfa11be77046 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 29 Jan 2025 04:27:01 -0800 Subject: [PATCH] Relax static offset restriction on memref_ptr for sub-byte types It simply assumes that the base offset is a multiple of byte packing PiperOrigin-RevId: 720919148 --- jax/experimental/mosaic/gpu/utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 5b466e6046b6..c3017bf9d574 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1079,13 +1079,19 @@ def memref_ptr(memref_arg, memory_space=None): elem_bitwidth = bitwidth(memref_ty.element_type) if elem_bitwidth < 8: *_, static_offset = memref_ty.get_strides_and_offset() - if static_offset == ir.ShapedType.get_dynamic_stride_or_offset(): - raise NotImplementedError - assert elem_bitwidth.bit_count() == 1 - packing = 8 // elem_bitwidth - if static_offset % packing != 0: - raise ValueError - offset_bytes = c(static_offset // packing, i64) + if static_offset != ir.ShapedType.get_dynamic_stride_or_offset(): + assert elem_bitwidth.bit_count() == 1 + packing = 8 // elem_bitwidth + if static_offset % packing != 0: + raise ValueError + offset_bytes = c(static_offset // packing, i64) + else: + offset_bits = llvm.mul( + offset_elems, + c(elem_bitwidth, i64), + overflow_flags=llvm.IntegerOverflowFlags.none, + ) + offset_bytes = llvm.udiv(offset_bits, c(8, i64)) else: assert elem_bitwidth % 8 == 0 offset_bytes = llvm.mul(