Skip to content

Commit

Permalink
Relax static offset restriction on memref_ptr for sub-byte types
Browse files Browse the repository at this point in the history
It simply assumes that the base offset is a multiple of byte packing

PiperOrigin-RevId: 720919148
  • Loading branch information
apaszke authored and Google-ML-Automation committed Jan 29, 2025
1 parent 9d39ab3 commit c9dfdb4
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions jax/experimental/mosaic/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,13 +1079,19 @@ def memref_ptr(memref_arg, memory_space=None):
elem_bitwidth = bitwidth(memref_ty.element_type)
if elem_bitwidth < 8:
*_, static_offset = memref_ty.get_strides_and_offset()
if static_offset == ir.ShapedType.get_dynamic_stride_or_offset():
raise NotImplementedError
assert elem_bitwidth.bit_count() == 1
packing = 8 // elem_bitwidth
if static_offset % packing != 0:
raise ValueError
offset_bytes = c(static_offset // packing, i64)
if static_offset != ir.ShapedType.get_dynamic_stride_or_offset():
assert elem_bitwidth.bit_count() == 1
packing = 8 // elem_bitwidth
if static_offset % packing != 0:
raise ValueError
offset_bytes = c(static_offset // packing, i64)
else:
offset_bits = llvm.mul(
offset_elems,
c(elem_bitwidth, i64),
overflow_flags=llvm.IntegerOverflowFlags.none,
)
offset_bytes = llvm.udiv(offset_bits, c(8, i64))
else:
assert elem_bitwidth % 8 == 0
offset_bytes = llvm.mul(
Expand Down

0 comments on commit c9dfdb4

Please sign in to comment.