diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index dd4940c9a9fb..d1f5e6b9ce98 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -19,12 +19,18 @@ class UnconstrainedSingleton: def __repr__(self): return "UNCONSTRAINED" + def __reduce__(self): + return (_get_default_unconstrained, ()) + # Unconstrained sentinel value for PartitionSpec, representing a dimension for # which the user wants XLA to assign the best partitioning. # TODO(yashkatariya): May rename to AUTO. _UNCONSTRAINED_PARTITION = UnconstrainedSingleton() +def _get_default_unconstrained(): + return _UNCONSTRAINED_PARTITION + class PartitionSpec(tuple): """Tuple describing how to partition an array across a mesh of devices.