Skip to content

Commit

Permalink
[JAX] Update users of jax.tree.map() to be more careful about how the…
Browse files Browse the repository at this point in the history
…y handle Nones.

Due to a bug in JAX, JAX previously permitted `jax.tree.map(f, None, x)` where `x` is not `None`, effectively treating `None` as if it were pytree-prefix of any value. But `None` is a pytree container, and it is only a prefix of `None` itself.

Fix user code that was relying on this bug. Most commonly, the fix is to write
`jax.tree.map(lambda a, b: (None if a is None else f(a, b)), x, y, is_leaf=lambda t: t is None)`.

PiperOrigin-RevId: 642481961
  • Loading branch information
hawkinsp authored and pax authors committed Jun 12, 2024
1 parent 4b1899e commit fe19dcb
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion paxml/tasks_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _replace_bprop_masked(x, from_mdl_vars):
_replace_bprop_masked,
extracted,
model_states.mdl_vars,
is_leaf=py_utils.is_bprop_masked_node,
is_leaf=lambda x: x is None or py_utils.is_bprop_masked_node(x),
)
return TrainState(
step=model_states.step,
Expand Down

0 comments on commit fe19dcb

Please sign in to comment.