Skip to content

Commit

Permalink
Use an PartitionSpec.UNCONSTRAINED to represent unconstrained dimensi…
Browse files Browse the repository at this point in the history
…ons in ParsedPartitionSpec, rather than None.

This makes PartitionSpec and ParsedPartitionSpec more similar, and fixes some TODOs.

PiperOrigin-RevId: 724927217
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Feb 9, 2025
1 parent 8401d9b commit cf308a8
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
16 changes: 12 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (AUTO, NamedSharding,
modify_sdy_sharding_wrt_axis_types,
Expand Down Expand Up @@ -1062,20 +1063,27 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
assert isinstance(s, JSharding)
return s.memory_kind


def contains_unconstrained(s):
return isinstance(s, NamedSharding) and None in s._parsed_pspec
return (
isinstance(s, NamedSharding)
and PartitionSpec.UNCONSTRAINED in s._parsed_pspec
)


def all_unconstrained(s, aval):
if isinstance(s, NamedSharding):
if aval.ndim != len(s._parsed_pspec):
return False
return all(p is None for p in s._parsed_pspec)
return all(p is PartitionSpec.UNCONSTRAINED for p in s._parsed_pspec)
return False

def _get_unconstrained_dimensions(s, aval):
us = contains_unconstrained(s)
return (us, all_unconstrained(s, aval),
({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
return (
us, all_unconstrained(s, aval),
({i for i, p in enumerate(s._parsed_pspec)
if p is PartitionSpec.UNCONSTRAINED} if us else None))

def lower_jaxpr_to_module(
module_name: str,
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/partition_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _normalized_spec_for_aval(self, ndim: int) -> PartitionSpec:
for p in self:
if p is None:
out.append(None)
elif isinstance(p, UnconstrainedSingleton):
elif p is _UNCONSTRAINED_PARTITION:
out.append(None)
elif isinstance(p, (list, tuple)):
if len(p) == 1:
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,7 +2025,8 @@ def _pjit_batcher_for_sharding(
if sharding_impls.is_op_sharding_replicated(hlo_s):
return s
if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, None)
parsed_pspec = s._parsed_pspec.insert_axis_partitions(
dim, PartitionSpec.UNCONSTRAINED)
return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec)
new_op = hlo_s.to_proto().clone()
tad = list(new_op.tile_assignment_dimensions)
Expand Down Expand Up @@ -2659,7 +2660,6 @@ def _sharding_constraint_batcher(
f"{sharding.spec}")
x, = vals_in
d, = dims_in
# None means unconstrained in ParsedPartitionSpec
unconstrained_dims = {ud + (d <= ud) for ud in unconstrained_dims}
if axis_data.spmd_name is None:
unconstrained_dims.add(d)
Expand Down Expand Up @@ -2887,7 +2887,7 @@ def use_explicit_axes(*axes):
def get_unconstrained_dims(sharding: NamedSharding):
assert sharding._parsed_pspec is not None
return {i for i, axes in enumerate(sharding._parsed_pspec)
if axes is None}
if axes is PartitionSpec.UNCONSTRAINED}


def _get_partition_spec(
Expand Down
19 changes: 9 additions & 10 deletions jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class TransferToMemoryKind:
@util.cache(max_size=128, trace_context_in_key=False)
def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
for p in parsed_pspec:
if p is not None:
if p is not PartitionSpec.UNCONSTRAINED:
for r in p:
if r not in mesh.shape:
raise ValueError(
Expand All @@ -71,7 +71,7 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
for p in parsed_pspec:
if p is not None:
if p is not PartitionSpec.UNCONSTRAINED:
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
Expand Down Expand Up @@ -431,7 +431,7 @@ def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
dim_shardings = [SdyDimSharding(axes=[], is_closed=True)
for _ in range(num_dimensions)]
for i, dim_spec in enumerate(self._parsed_pspec):
if dim_spec is None:
if dim_spec is PartitionSpec.UNCONSTRAINED:
dim_shardings[i].is_closed = False
elif not dim_spec:
# Already empty and closed sharding.
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def get_array_mapping(
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)
if axes is not None for axis in axes)
if axes is not PartitionSpec.UNCONSTRAINED for axis in axes)


get_single_pspec = lambda p: array_mapping_to_axis_resources(
Expand All @@ -1090,12 +1090,11 @@ class ParsedPartitionSpec:
__slots__ = ('_user_spec', 'partitions')

_user_spec: PartitionSpec | None
partitions: tuple[tuple[MeshAxisName, ...] | None, ...]
partitions: tuple[tuple[MeshAxisName, ...] | UnconstrainedSingleton, ...]

def __init__(self, user_spec, partitions):
self._user_spec = user_spec
# None in partitions represents unconstrained dim.
# TODO(yashkatariya): May use a sentinel value.
assert None not in partitions, partitions
self.partitions = tuple(partitions)

def get_partition_spec(self) -> PartitionSpec:
Expand Down Expand Up @@ -1130,10 +1129,10 @@ def from_user_input(
axis_spec = ()
elif isinstance(axis_spec, (list, tuple)):
axis_spec = tuple(axis_spec)
elif isinstance(axis_spec, UnconstrainedSingleton):
elif axis_spec is PartitionSpec.UNCONSTRAINED:
if not allow_unconstrained_dims:
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
axis_spec = None
axis_spec = PartitionSpec.UNCONSTRAINED
else:
axis_spec = (axis_spec,)
axis_specs.append(axis_spec)
Expand Down Expand Up @@ -1204,7 +1203,7 @@ def _check_unique_resources(
resource_counts: dict[MeshAxisName, int] = {}
duplicate = False
for d in arg_axis_resources:
if d is not None:
if d is not PartitionSpec.UNCONSTRAINED:
for resource in d:
count = resource_counts.get(resource, 0)
if count > 0:
Expand Down

0 comments on commit cf308a8

Please sign in to comment.