Skip to content

Commit

Permalink
[sharding_in_types] Error out if the sharding's specs passed to with_…
Browse files Browse the repository at this point in the history
…sharding_constraint don't refer to Auto axes.

PiperOrigin-RevId: 725679220
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 11, 2025
1 parent 7000747 commit 005c14b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
19 changes: 19 additions & 0 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,6 +2521,23 @@ def _pjit_state_discharge_rule(

# -------------------- with_sharding_constraint --------------------

def check_shardings_are_auto(shardings_flat):
if not config.sharding_in_types.value:
return

for s in shardings_flat:
if not isinstance(s, NamedSharding):
continue
mesh = s.mesh.abstract_mesh
if not all(mesh._name_to_type[i] == mesh_lib.AxisTypes.Auto
for axes in s._parsed_pspec
if axes is not PartitionSpec.UNCONSTRAINED for i in axes):
raise ValueError(
'The spec of NamedSharding passed to with_sharding_constraint can'
f' only refer to Auto axes of the mesh. Got spec={s.spec} and'
f' mesh={mesh}')


def with_sharding_constraint(x, shardings):
"""Mechanism to constrain the sharding of an Array inside a jitted computation
Expand Down Expand Up @@ -2575,6 +2592,8 @@ def with_sharding_constraint(x, shardings):
shardings_flat, x_flat, None, "with_sharding_constraint arguments",
allow_uneven_sharding=True)

check_shardings_are_auto(shardings_flat)

check_aval_layout_compatibility(user_layouts_flat, x_flat, None,
"with_sharding_constraint arguments")

Expand Down
26 changes: 26 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6549,6 +6549,32 @@ def f():
shmap_f() # doesn't crash
jax.jit(shmap_f)() # doesn't crash

@jtu.with_user_mesh((2, 1), ('x', 'y'))
def test_wsc_error(self, mesh):
s = NamedSharding(mesh, P('x'))
with self.assertRaisesRegex(
ValueError,
"The spec of NamedSharding passed to with_sharding_constraint"):
jax.lax.with_sharding_constraint(np.arange(8), s)

s = NamedSharding(mesh, P(('x', 'y'), None))
with self.assertRaisesRegex(
ValueError,
"The spec of NamedSharding passed to with_sharding_constraint"):
jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s)

s = NamedSharding(mesh, P())
jax.lax.with_sharding_constraint(np.arange(8), s)

s = NamedSharding(mesh, P(P.UNCONSTRAINED, 'x'))
with self.assertRaisesRegex(
ValueError,
"The spec of NamedSharding passed to with_sharding_constraint"):
jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s)

s = NamedSharding(mesh, P(P.UNCONSTRAINED))
jax.lax.with_sharding_constraint(np.arange(8), s)


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 005c14b

Please sign in to comment.