[better_errors] Continue adding debug info to Jaxprs (step 8)
This continues a series, started with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
gnecula committed Feb 12, 2025
1 parent 550d1aa commit 341313d
Showing 5 changed files with 39 additions and 17 deletions.
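The pattern applied throughout this commit is to build a debug-info object for the traced callable (via `api_util.debug_info`) and attach it when wrapping the function with `lu.wrap_init`, so the jaxprs produced during tracing carry argument names and source information instead of `None`. Below is a minimal sketch of that pattern, mirroring the calls added in the diffs; the internal import paths (`jax._src.api_util`, `jax._src.linear_util`) and the example function are assumptions for illustration, not public API.

```python
# Hedged sketch of the wrap_init + debug_info pattern used in this commit
# (internal JAX APIs, subject to change).
from jax._src import api_util
from jax._src import linear_util as lu

def f(x, y):
  return x * y

# Describe what is being traced: a label, the callable, and example args/kwargs.
dbg = api_util.debug_info("example_transform", f, (1.0, 2.0), {})

# Attach the debug info when wrapping the callable for tracing; jaxprs built
# from `wrapped` can then report argument names in error messages.
wrapped = lu.wrap_init(f, debug_info=dbg)
```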
2 changes: 1 addition & 1 deletion jax/_src/custom_partitioning.py
@@ -476,7 +476,7 @@ def __call__(self, *args, **kwargs):
args = tuple(x if i in static_argnums else x for i, x in enumerate(args))
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
f_, dyn_args = api_util.argnums_partial(
lu.wrap_init(self.fun),
lu.wrap_init(self.fun, debug_info=debug),
dyn_argnums,
args,
require_static_args_hashable=False,
31 changes: 20 additions & 11 deletions jax/_src/interpreters/ad.py
@@ -33,7 +33,7 @@
add_jaxvals, replace_internal_symbolic_zeros,
replace_rule_output_symbolic_zeros, Zero, zeros_like_aval, SymbolicZero)
from jax._src.ad_util import zeros_like_p, add_jaxvals_p # noqa: F401
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs, debug_info
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
@@ -632,7 +632,9 @@ def process_primitive(self, primitive, args, params):
else:
return maybe_linearize_tracer(self, primal_out, tangent_nzs_out, tangent_out)

def process_custom_jvp_call(self, prim, fun, f_jvp, tracers, *, symbolic_zeros):
def process_custom_jvp_call(self, prim, fun: lu.WrappedFun,
f_jvp: lu.WrappedFun, tracers, *,
symbolic_zeros: bool):
primals_in, tangents_in = unzip2(map(self.to_primal_tangent_pair, tracers))
if all(type(t) is Zero for t in tangents_in):
return prim.bind_with_trace(self.parent_trace, (fun, f_jvp, *primals_in),
@@ -647,7 +649,8 @@ def _f_jvp(primals, tangents):
instantiate_zeros = not symbolic_zeros
nonzeros_in = [type(t) is not Zero for t in tangents_in]
primals_out, tangent_nzs_out, residuals, linearized = linearize_from_jvp(
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros, primals_in, {})
_f_jvp, True, nonzeros_in, symbolic_zeros, instantiate_zeros,
f_jvp.debug_info, primals_in, {})

with core.set_current_trace(self.tangent_trace):
tangents_out = linearized(residuals, *tangents_in)
@@ -751,15 +754,22 @@ def maybe_linearize_tracer(trace, primal, is_nonzero, tangent):
assert type(tangent) is Zero
return primal

def fallback_linearize_rule(_prim, _nonzeros, *primals, **params):
def fallback_linearize_rule(_prim: core.Primitive,
_nonzeros: Sequence[bool], *primals, **params):
jvp = primitive_jvps.get(_prim)
if not jvp:
msg = f"Differentiation rule for '{_prim}' not implemented"
raise NotImplementedError(msg)
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False, primals, params)

def linearize_from_jvp(jvp, multiple_results, nonzeros,
user_facing_symbolic_zeros, instantiate_input_zeros, primals, params):
debug_jvp = debug_info("linearize_prim_jvp", jvp, primals, params)
return linearize_from_jvp(jvp, _prim.multiple_results, _nonzeros, False, False,
debug_jvp, primals, params)

def linearize_from_jvp(jvp: Callable,
multiple_results: bool,
nonzeros: Sequence[bool],
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
debug_info: core.DebugInfo | None,
primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
trace = pe.JaxprTrace(parent_trace, current_name_stack, core.TraceTag())
@@ -776,7 +786,7 @@ def make_zero(aval):
if user_facing_symbolic_zeros:
zero_type = SymbolicZero
else:
zero_type = Zero
zero_type = Zero # type: ignore[assignment]

tangent_args = tuple(trace.new_arg(pe.PartialVal.unknown(aval)) if nz else make_zero(aval)
for aval, nz in zip(tangent_avals, nonzeros))
@@ -795,8 +805,7 @@ def make_zero(aval):
out_nz_tracers = [trace.to_jaxpr_tracer(r)
for (r, nz) in zip(out_tangents, out_nzs) if nz]
in_tracers = [t for t, nz in zip(tangent_args, nonzeros) if nz]
# TODO(necula): pass debug_info here
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, None)
jaxpr, out_consts, _ = pe.tracers_to_jaxpr(in_tracers, out_nz_tracers, debug_info)

def linearized(residuals, *tangents):
nz_tangents_in = [t for (t, nz) in zip(tangents, nonzeros) if nz]
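For context on the ad.py changes above: the direct-linearize path mentioned in the commit message is what exercises `fallback_linearize_rule` and `linearize_from_jvp`, whose jaxprs now receive debug_info. A hedged end-to-end sketch of driving that path follows; the flag name is taken from the commit message, and the assumption that it is read from the environment at import time (like other `JAX_*` settings) is mine, not the source's.

```python
# Hedged sketch: run reverse-mode AD with direct linearization enabled, so the
# LinearizeTrace machinery patched above is used (flag behavior is an assumption).
import os
os.environ["JAX_USE_DIRECT_LINEARIZE"] = "1"  # set before importing jax

import jax
import jax.numpy as jnp

def loss(x):
  return jnp.sum(jnp.sin(x) ** 2)

# Gradient computation goes through linearization; with this commit the jaxprs
# it builds carry debug_info rather than None.
print(jax.grad(loss)(jnp.arange(4.0)))
```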
5 changes: 4 additions & 1 deletion jax/_src/interpreters/pxla.py
@@ -3127,7 +3127,10 @@ def call(self, *args):
if self._all_args_info is None:
kept_args = args_after_dce
ref_avals = self.in_avals
debug_info = None
# TODO(necula): ensure we have actual debug info
debug_info = core.DebugInfo(
"MeshExecutable", "<unknown>",
tuple(f"args[{i}]" for i in range(len(kept_args))), ())
else:
kept_args = args
ref_avals = self._all_args_info.in_avals
3 changes: 2 additions & 1 deletion jax/experimental/shard_map.py
@@ -1626,7 +1626,8 @@ def f_tangent(*args):

nz_tangents_in = [t for (t, nz) in zip(tangents, nzs_in) if nz]
nz_tangents_out = shard_map_p.bind_with_trace(trace.tangent_trace,
(lu.wrap_init(f_tangent), *residuals, *nz_tangents_in), tangent_params)
(lu.wrap_init(f_tangent, debug_info=lin_jaxpr.debug_info),
*residuals, *nz_tangents_in), tangent_params)
nz_tangents_out_iter = iter(nz_tangents_out)
tangents_out = [next(nz_tangents_out_iter) if nz else ad.Zero.from_primal_value(primal)
for nz, primal in zip(nzs_out, primals_out)]
15 changes: 12 additions & 3 deletions tests/pallas/tpu_pallas_test.py
@@ -21,9 +21,11 @@
import math
import re
import sys
from typing import Callable
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import api_util
from jax import lax
from jax._src import checkify
from jax._src import state
@@ -61,6 +63,13 @@ def string_stdout():
sys.stdout = initial_stdout


def wrap_init(f: Callable, nr_args: int):
# wrapper for lu.wrap_init with debugging info
return lu.wrap_init(
f,
debug_info=api_util.debug_info("state_test", f, (0,) * nr_args, {}))


class PallasBaseTest(jtu.JaxTestCase):
INTERPRET: bool = False

@@ -793,7 +802,7 @@ def body(temp_ref):
return []

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(kernel),
wrap_init(kernel, 2),
[
state.shaped_array_ref((8,), jnp.float32),
state.shaped_array_ref((8,), jnp.float32),
@@ -923,7 +932,7 @@ def scope():
aref1 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
aref2 = state.AbstractRef(jax.core.ShapedArray((4,), jnp.dtype('float32')))
in_avals = [aref1, aref2]
stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(lu.wrap_init(f),
stateful_jaxpr, _, (), () = pe.trace_to_jaxpr_dynamic(wrap_init(f, 2),
in_avals)
discharged_jaxpr, _ = state_discharge.discharge_state(
stateful_jaxpr, consts=(), should_discharge=[False, True])
@@ -2509,7 +2518,7 @@ def inner(x_ref, sem):
return []

jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(body), [state.shaped_array_ref((2, 8, 128), jnp.int32),
wrap_init(body, 2), [state.shaped_array_ref((2, 8, 128), jnp.int32),
jax.core.ShapedArray((), jnp.int32)]
)
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
