Skip to content

Commit

Permalink
Set the mesh as manual during partial_eval_custom in shard_map so tha…
Browse files Browse the repository at this point in the history
…t `_add_reshapes` happens under the correct mesh.

PiperOrigin-RevId: 723268798
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Feb 5, 2025
1 parent 02f4531 commit 307006e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions jax/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,8 @@ def __exit__(self, exc_type, exc_value, traceback):

@staticmethod
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
jax_config.abstract_mesh_context_manager.set_local(mesh)
return
prev = jax_config.abstract_mesh_context_manager.swap_local(mesh)
return prev


# Create this indirection because pytype fails to recognize a property if a
Expand Down
4 changes: 3 additions & 1 deletion jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,7 +1719,9 @@ def _partial_eval_jaxpr_custom_rule(
idx_map = {id(v): i for i, v in enumerate(out_vars)}
out_fwd = [idx_map.get(id(v)) for v in res_vars]
which = [f1 is None and f2 is None for f1, f2 in zip(in_fwd, out_fwd)]
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
mesh = eqn.params['mesh']
with (core.extend_axis_env_nd(mesh.shape.items()),
set_abstract_mesh(_as_manual_mesh(mesh))):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
Expand Down

0 comments on commit 307006e

Please sign in to comment.