diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py
index f70c25a96cf0..9fb4a8f1111c 100644
--- a/jax/_src/custom_partitioning.py
+++ b/jax/_src/custom_partitioning.py
@@ -476,7 +476,7 @@ def __call__(self, *args, **kwargs):
       args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
       dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
       f_, dyn_args = api_util.argnums_partial(
-          lu.wrap_init(self.fun),
+          lu.wrap_init(self.fun, debug_info=debug),
           dyn_argnums,
           args,
           require_static_args_hashable=False,
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index 4aa77c861fe8..ac3c03933795 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -33,7 +33,7 @@
     add_jaxvals, replace_internal_symbolic_zeros,
     replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, SymbolicZero)
 from jax._src.ad_util import zeros_like_p, add_jaxvals_p  # noqa: F401
-from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
+from jax._src.api_util import flatten_fun, flatten_fun_nokwargs, debug_info
 from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
 from jax._src.dtypes import dtype, float0
 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
@@ -632,7 +632,9 @@ def process_primitive(self, primitive, args, params):
     else:
       return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)
 
-  def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
+  def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
+                              f_jvp: lu.WrappedFun, tracers, *,
+                              symbolic_zeros: bool):
     primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
     if all(type(t) is Zero for t in tangents_in):
       return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
@@ -647,7 +649,8 @@ def _f_jvp(primals, tangents):
     instantiate_zeros = not symbolic_zeros
     nonzeros_in = [type(t) is not Zero for t in tangents_in]
     primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
-        _f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros, primals_in, {})
+        _f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
+        f_jvp.debug_info, primals_in, {})
 
     with core.set_current_trace(self.tangent_trace):
       tangents_out = linearized(residuals, *tangents_in)
@@ -751,15 +754,22 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
     assert type(tangent) is Zero
     return primal
 
-def fallback_linearize_rule(_prim, _nonzeros, *primals, **params):
+def fallback_linearize_rule(_prim: core.Primitive,
+                            _nonzeros: Sequence[bool], *primals, **params):
   jvp = primitive_jvps.get(_prim)
   if not jvp:
    msg = f"Differentiation rule for '{_prim}' not implemented"
    raise NotImplementedError(msg)
-  return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False, primals, params)
-
-def linearize_from_jvp(jvp, multiple_results, nonzeros,
-                       user_facing_symbolic_zeros, instantiate_input_zeros, primals, params):
+  debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
+  return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
+                            debug_jvp, primals, params)
+
+def linearize_from_jvp(jvp: Callable,
+                       multiple_results: bool,
+                       nonzeros: Sequence[bool],
+                       user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
+                       debug_info: core.DebugInfo | None,
+                       primals, params):
  current_name_stack = source_info_util.current_name_stack()
  with core.take_current_trace() as parent_trace:
    trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
@@ -776,7 +786,7 @@ def make_zero(aval):
     if user_facing_symbolic_zeros:
       zero_type = SymbolicZero
     else:
-      zero_type = Zero
+      zero_type = Zero  # type: ignore[assignment]
 
     tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
                          for aval, nz in zip(tangent_avals, nonzeros))
@@ -795,8 +805,7 @@ def make_zero(aval):
     out_nz_tracers = [trace.to_jaxpr_tracer(r)
                       for (r, nz) in zip(out_tangents, out_nzs) if nz]
     in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
-    # TODO(necula): pass debug_info here
-    jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, None)
+    jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info)
 
     def linearized(residuals, *tangents):
       nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index f8fe90e81292..8121f9edbe64 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -3127,7 +3127,10 @@ def call(self, *args):
     if self._all_args_info is None:
       kept_args = args_after_dce
       ref_avals = self.in_avals
-      debug_info = None
+      # TODO(necula): ensure we have actual debug info
+      debug_info = core.DebugInfo(
+          "MeshExecutable", "",
+          tuple(f"args[{i}]" for i in range(len(kept_args))), ())
     else:
       kept_args = args
       ref_avals = self._all_args_info.in_avals
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index f873394e782f..3247e7ba7bac 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -1626,7 +1626,8 @@ def f_tangent(*args):
 
   nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
   nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
-      (lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), tangent_params)
+      (lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
+       *residuals, *nz_tangents_in), tangent_params)
   nz_tangents_out_iter = iter(nz_tangents_out)
   tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
                   for nz, primal in zip(nzs_out, primals_out)]
diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py
index c1fb076c5a09..4ff6208bdd92 100644
--- a/tests/pallas/tpu_pallas_test.py
+++ b/tests/pallas/tpu_pallas_test.py
@@ -21,9 +21,11 @@
 import math
 import re
 import sys
+from typing import Callable
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
+from jax import api_util
 from jax import lax
 from jax._src import checkify
 from jax._src import state
@@ -61,6 +63,13 @@ def string_stdout():
   sys.stdout = initial_stdout
 
 
+def wrap_init(f: Callable, nr_args: int):
+  # wrapper for lu.wrap_init with debugging info
+  return lu.wrap_init(
+      f,
+      debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))
+
+
 class PallasBaseTest(jtu.JaxTestCase):
   INTERPRET: bool = False
 
@@ -793,7 +802,7 @@ def body(temp_ref):
       return []
 
     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(kernel),
+        wrap_init(kernel, 2),
        [
            state.shaped_array_ref((8,), jnp.float32),
            state.shaped_array_ref((8,), jnp.float32),
@@ -923,7 +932,7 @@ def scope():
     aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
     aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
     in_avals = [aref1, aref2]
-    stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
+    stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
                                                           in_avals)
     discharged_jaxpr, _ = state_discharge.discharge_state(
         stateful_jaxpr, consts=(), should_discharge=[False, True])
@@ -2509,7 +2518,7 @@ def inner(x_ref, sem):
       return []
 
     jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(body), [state.shaped_array_ref((2, 8, 128), jnp.int32),
+        wrap_init(body, 2), [state.shaped_array_ref((2, 8, 128), jnp.int32),
                              jax.core.ShapedArray((), jnp.int32)]
     )
     self.assertIn(expected, jaxpr.pretty_print(use_color=False))
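The common pattern in this diff is that every lu.wrap_init call now receives a core.DebugInfo built with api_util.debug_info, instead of being traced with debug_info=None. The snippet below is an illustrative sketch of that pattern, modeled on the wrap_init helper added to tests/pallas/tpu_pallas_test.py above; it uses jax._src internals as shown in the diff, and the function f and the "example" tag are placeholders, not part of this change.

# Sketch only: demonstrates attaching debug info when wrapping a function for
# tracing, mirroring the test helper above. Internal APIs; names are illustrative.
import jax
import jax.numpy as jnp
from jax import api_util
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe

def f(x, y):
  return x + y

# Build debug info from a traced-for tag, the function, and example args/kwargs.
dbg = api_util.debug_info("example", f, (0.0, 0.0), {})
# Attach it to the WrappedFun so tracing can propagate it into the jaxpr.
wrapped = lu.wrap_init(f, debug_info=dbg)
jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
    wrapped, [jax.core.ShapedArray((), jnp.float32)] * 2)
# The resulting jaxpr carries the debug info (e.g. argument names) from f.
print(jaxpr.debug_info)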