diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 5b466e6046b6..c3017bf9d574 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -1079,13 +1079,19 @@ def memref_ptr(memref_arg, memory_space=None): elem_bitwidth = bitwidth(memref_ty.element_type) if elem_bitwidth < 8: *_, static_offset = memref_ty.get_strides_and_offset() - if static_offset == ir.ShapedType.get_dynamic_stride_or_offset(): - raise NotImplementedError - assert elem_bitwidth.bit_count() == 1 - packing = 8 // elem_bitwidth - if static_offset % packing != 0: - raise ValueError - offset_bytes = c(static_offset // packing, i64) + if static_offset != ir.ShapedType.get_dynamic_stride_or_offset(): + assert elem_bitwidth.bit_count() == 1 + packing = 8 // elem_bitwidth + if static_offset % packing != 0: + raise ValueError + offset_bytes = c(static_offset // packing, i64) + else: + offset_bits = llvm.mul( + offset_elems, + c(elem_bitwidth, i64), + overflow_flags=llvm.IntegerOverflowFlags.none, + ) + offset_bytes = llvm.udiv(offset_bits, c(8, i64)) else: assert elem_bitwidth % 8 == 0 offset_bytes = llvm.mul(