Skip to content

Commit

Permalink
Add a custom __reduce__ for UnconstrainedSingleton because it can…
Browse files Browse the repository at this point in the history
… be picked and then loaded back and we need the `id` of `P.UNCONSTRAINED` to match before and after loading.

PiperOrigin-RevId: 725874879
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 12, 2025
1 parent bba0913 commit 675be01
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/_src/partition_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,18 @@ class UnconstrainedSingleton:
def __repr__(self):
return "UNCONSTRAINED"

def __reduce__(self):
return (_get_default_unconstrained, ())


# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()

def _get_default_unconstrained():
return _UNCONSTRAINED_PARTITION


class PartitionSpec(tuple):
"""Tuple describing how to partition an array across a mesh of devices.
Expand Down

0 comments on commit 675be01

Please sign in to comment.