From af667199dbc1188d1ee0f70cf8922cdb7d64a701 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Thu, 16 Jan 2025 18:56:16 -0800 Subject: [PATCH] [sharding_in_types] Rename `.at[...].get(out_spec)` to `.at[...].get(out_sharding)`. PiperOrigin-RevId: 716466870 --- jax/_src/basearray.pyi | 2 +- jax/_src/lax/slicing.py | 4 ++-- jax/_src/numpy/array_methods.py | 10 ++++++---- tests/pjit_test.py | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index d5bd7043ad81..a368b593332d 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -280,7 +280,7 @@ class _IndexUpdateHelper: class _IndexUpdateRef: def get(self, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None, - out_spec: PartitionSpec | None = None) -> Array: ... + out_spec: Sharding | PartitionSpec | None = None) -> Array: ... def set(self, values: Any, indices_are_sorted: bool = False, unique_indices: bool = False, mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ... diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 37cb099d751a..9e7dd24c744c 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1886,8 +1886,8 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers, if mesh_lib.get_abstract_mesh()._are_all_axes_hidden: # type: ignore return None raise GatherShardingError( - "Use `.at[...].get(out_specs=)` to provide output PartitionSpec for the" - " gather indexing.") + "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for" + " the gather indexing.") def _gather_fill(operand, indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, fill_value, diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index eda6e36fb612..894c0d865d23 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -43,6 +43,7 @@ from jax._src.numpy import lax_numpy from jax._src import mesh as mesh_lib from jax._src.pjit import hidden_mode, PartitionSpec +from jax._src.sharding_impls import canonicalize_sharding, NamedSharding from jax._src.numpy import reductions from jax._src.numpy import ufuncs from jax._src.ops import scatter @@ -765,7 +766,7 @@ def __repr__(self) -> str: return f"_IndexUpdateRef({self.array!r}, {self.index!r})" def get(self, *, indices_are_sorted=False, unique_indices=False, - mode=None, fill_value=None, out_spec=None): + mode=None, fill_value=None, out_sharding=None): """Equivalent to ``x[idx]``. Returns the value of ``x`` that would result from the NumPy-style @@ -779,10 +780,11 @@ def get(self, *, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode, fill_value=fill_value) - if out_spec is not None: - assert isinstance(out_spec, PartitionSpec) + if out_sharding is not None: + assert isinstance(out_sharding, (NamedSharding, PartitionSpec)) + out_sharding = canonicalize_sharding(out_sharding) take = hidden_mode(take, axes=mesh_lib.get_abstract_mesh().axis_names, # type: ignore - out_specs=out_spec) + out_specs=out_sharding.spec) return take(self.array, self.index) def set(self, values, *, indices_are_sorted=False, unique_indices=False, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 180dea92fae0..d7339d2dba48 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6117,7 +6117,7 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) @jtu.with_user_mesh((2, 2), ('x', 'y')) - def test_auto_gather_out_spec(self, mesh): + def test_auto_gather_out_sharding(self, mesh): embed = jax.device_put(jnp.arange(128 * 8.).reshape(64, 16), jax.NamedSharding(mesh, P(None, 'x'))) tok = jax.device_put(jnp.arange(8 * 4).reshape(8, 4), @@ -6125,7 +6125,7 @@ def test_auto_gather_out_spec(self, mesh): @jax.jit def f(embed_vd, token_bt): - out = embed_vd.at[token_bt].get(out_spec=P('x', None, None)) + out = embed_vd.at[token_bt].get(out_sharding=P('x', None, None)) self.assertEqual(out.shape, (8, 4, 16)) self.assertEqual(out.sharding.spec, P('x', None, None)) return out