Skip to content

Commit

Permalink
[better_errors] Make it explicit that debug_info is not None.
Browse files Browse the repository at this point in the history
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See #26480 for more details.

PiperOrigin-RevId: 725512692
  • Loading branch information
gnecula authored and Google-ML-Automation committed Feb 13, 2025
1 parent 305e55f commit 8103ab6
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 106 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
This package may safely be removed if it is present on your machine; JAX now
uses `libtpu` instead.

* Deprecations
* The internal function `linear_util.wrap_init` and the constructor
`core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
a limited time, a `DeprecationWarning` is printed if
`jax.extend.linear_util.wrap_init` is used without debugging info.
A downstream effect of this several other internal functions need debug
info. This change does not affect public APIs.
See https://github.com/jax-ml/jax/issues/26480 for more detail.

## jax 0.5.0 (Jan 17, 2025)

As of this release, JAX now uses
Expand Down
8 changes: 4 additions & 4 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,15 +620,15 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
return None

def save_wrapped_fun_sourceinfo(wrapper: Callable,
wrapped: Callable | core.DebugInfo | None) -> None:
wrapped: Callable | core.DebugInfo) -> None:
# Prefer this to functools.wraps because it does not create a reference to
# the wrapped function.
if isinstance(wrapped, core.DebugInfo):
func_src_info = wrapped.func_src_info
elif callable(wrapped):
func_src_info = fun_sourceinfo(wrapped)
else:
return
assert False, wrapped # Unreachable
setattr(wrapper, "__fun_sourceinfo__", func_src_info)

_fun_name_re = re.compile(r"(?:<built-in function (\S+)>)")
Expand Down Expand Up @@ -716,7 +716,7 @@ def register_class_with_attrs(t: type) -> None:
_class_with_attrs: set[type] = set()

# TODO(mattjj): make this function faster
def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args):
assert config.mutable_array_checks.value
refs: dict[int, int] = {}
for i, (a, x) in enumerate(zip(avals, args)):
Expand All @@ -730,7 +730,7 @@ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
if dbg else
f"at both flat index {dup_idx} and flat index {i}") from None

def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
assert config.mutable_array_checks.value
refs: set[int] = {id(core.get_referent(c)) for c in consts
if isinstance(core.get_aval(c), AbstractRef)}
Expand Down
16 changes: 11 additions & 5 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Jaxpr:
_outvars: list[Atom]
_eqns: list[JaxprEqn]
_effects: Effects
_debug_info: DebugInfo | None
_debug_info: DebugInfo

@property
def constvars(self) -> list[Var]:
Expand All @@ -116,13 +116,17 @@ def effects(self) -> Effects:
return self._effects

@property
def debug_info(self) -> DebugInfo | None:
def debug_info(self) -> DebugInfo:
return self._debug_info

def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
effects: Effects = no_effects,
debug_info: DebugInfo | None = None):
# We want all calls to pass a DebugInfo object, but for backwards
# compatibility we have to allow calls when the debug_info
# is missing.
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch]
):
"""
Args:
constvars: list of variables introduced for constants. Array constants are
Expand All @@ -133,14 +137,16 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
eqns: list of equations.
effects: set of effects. The effects on a jaxpr are a superset of the
union of the effects for each equation.
debug_info: optional DebugInfo.
debug_info: debugging information.
"""
self._constvars = list(constvars)
self._invars = list(invars)
self._outvars = list(outvars)
self._eqns = list(eqns)
self._effects = effects
self._debug_info = debug_info and debug_info.resolve_result_paths()
# TODO(https://github.com/jax-ml/jax/issues/26480)
debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
self._debug_info = debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
Expand Down
20 changes: 7 additions & 13 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,28 +686,22 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable

@lu.transformation2
def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int],
debug_info: core.DebugInfo | None, *args):
debug_info: core.DebugInfo, *args):
_check_for_aliased_refs(f, nondiff_argnums, debug_info, args)
out = f(*args)
_check_for_returned_refs(f, out, 'primal')
return out

def _check_for_aliased_refs(f: Callable,
nondiff_argnums: Sequence[int],
debug: core.DebugInfo | None,
debug: core.DebugInfo,
args):
leaves = tree_leaves(args)
refs: dict[int, int] = {}
for i, x in enumerate(leaves):
if (isinstance((a := core.get_aval(x)), AbstractRef) and
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
if debug is not None:
arg_names = debug.safe_arg_names(len(leaves))
else:
# TODO(necula): drop this branch
arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
if arg_names is None:
arg_names = [f'flat index {j}' for j in range(len(leaves))]
arg_names = debug.safe_arg_names(len(leaves))
raise ValueError(
"only one reference to a mutable array may be passed as an argument "
f"to a function, but custom_vjp function {f} got the same mutable "
Expand Down Expand Up @@ -763,8 +757,8 @@ def _check_for_tracers(x):
def _flatten_fwd(f: Callable, store: lu.EqualStore,
nondiff_argnums: Sequence[int],
symbolic_zeros: bool,
debug_primal: core.DebugInfo | None,
debug_fwd: core.DebugInfo | None,
debug_primal: core.DebugInfo,
debug_fwd: core.DebugInfo,
in_tree: PyTreeDef, maybe_out_type, *args):
primal_name = debug_primal.func_name if debug_primal else str(f)
fwd_name = debug_fwd.func_name if debug_fwd else "<unknown>"
Expand Down Expand Up @@ -1560,9 +1554,9 @@ def jvp(primals, tangents):
# simpler, but it would be worth revisiting this.
def optimize_remat_of_custom_vjp_fwd(
fun: Callable[..., ReturnValue],
debug_fun: core.DebugInfo | None,
debug_fun: core.DebugInfo,
fwd: Callable[..., tuple[ReturnValue, Any]],
debug_fwd: core.DebugInfo | None,
debug_fwd: core.DebugInfo,
nondiff_argnums: Sequence[int] = (),
symbolic_zeros: bool = False,
) -> Callable[..., tuple[ReturnValue, Any]]:
Expand Down
23 changes: 11 additions & 12 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
@lu.transformation_with_aux2
def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
nzs_in: Sequence[bool],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
*primals, **params):
with core.take_current_trace() as parent_trace:
tangent_trace = pe.DynamicJaxprTrace(debug_info)
Expand Down Expand Up @@ -133,7 +133,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
return out_primals, out_tangents

def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
return core.Jaxpr(constvars=(),
invars=jaxpr.invars + jaxpr.constvars,
Expand Down Expand Up @@ -768,7 +768,7 @@ def linearize_from_jvp(jvp: Callable,
multiple_results: bool,
nonzeros: Sequence[bool],
user_facing_symbolic_zeros: bool, instantiate_input_zeros: bool,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
primals, params):
current_name_stack = source_info_util.current_name_stack()
with core.take_current_trace() as parent_trace:
Expand Down Expand Up @@ -1100,15 +1100,14 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
new_debug_info = jaxpr.jaxpr.debug_info
if new_debug_info is not None:
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_arg_names = tuple(_perm(primals_in, tangents_in,
jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
new_result_paths = tuple(_perm(primals_out, tangents_out,
jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
new_debug_info = new_debug_info._replace(
arg_names=new_arg_names,
result_paths=new_result_paths,
)
new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
new_invars, new_outvars, jaxpr.jaxpr.eqns,
jaxpr.jaxpr.effects,
Expand Down
47 changes: 23 additions & 24 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,8 @@ def _closed_call_param_updater(params, _, __):
return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
call_param_updaters[core.closed_call_p] = _closed_call_param_updater

def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
def abstract_eval_fun(fun: Callable, *avals,
debug_info: core.DebugInfo, **params):
_, avals_out, _, () = trace_to_jaxpr_dynamic(
lu.wrap_init(fun, params, debug_info=debug_info), avals)
assert all(isinstance(aval, AbstractValue) for aval in avals_out)
Expand Down Expand Up @@ -582,7 +583,7 @@ def trace_to_subjaxpr_nounits(
f: Callable,
trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
Expand All @@ -595,7 +596,7 @@ def trace_to_subjaxpr_nounits(
def trace_to_subjaxpr_nounits2(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert isinstance(tag, TraceTag)
Expand All @@ -612,7 +613,7 @@ def trace_to_subjaxpr_nounits2(
def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
instantiate: Sequence[bool] | bool,
in_pvals: Sequence[PartialVal],
debug_info: core.DebugInfo | None):
debug_info: core.DebugInfo):
in_knowns = [pval.is_known() for pval in in_pvals]
in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
Expand All @@ -639,7 +640,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
def trace_to_subjaxpr_nounits_fwd(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
Expand Down Expand Up @@ -669,7 +670,7 @@ def trace_to_subjaxpr_nounits_fwd(
def trace_to_subjaxpr_nounits_fwd2(
f: Callable,
tag: TraceTag,
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
instantiate: bool | Sequence[bool],
in_pvals: Sequence[PartialVal]):
assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
Expand Down Expand Up @@ -752,13 +753,14 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
"""Constructs Jaxpr given tracers for inputs and outputs.
Params:
in_tracers: the tracers that were created for the function inputs
out_tracers: the tracers that were output by the function.
debug_info: the debug info for the function.
Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
the `constvars` in the returned Jaxps, and a list of environment values.
Expand Down Expand Up @@ -838,7 +840,7 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
config.enable_checks.value and core.check_jaxpr(jaxpr)
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
Expand All @@ -854,7 +856,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
return jaxpr.replace() # 'return jaxpr' would create cache reference cycle
config.enable_checks.value and core.check_jaxpr(jaxpr)
constvars, invars = split_list(jaxpr.invars, [n])
dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
dbg = jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names[n:])
lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
debug_info=dbg)
Expand All @@ -868,7 +870,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
effects=jaxpr.effects)
effects=jaxpr.effects, debug_info=jaxpr.debug_info)
config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
return converted_jaxpr

Expand Down Expand Up @@ -1363,7 +1365,7 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:

def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
dbg = jaxpr.debug_info and core.DebugInfo(
dbg = core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.arg_names,
jaxpr.debug_info.filter_result_paths(used_outputs))
Expand Down Expand Up @@ -1451,7 +1453,7 @@ def write(x: Atom, b: bool) -> None:
eqns = new_eqns[::-1]
jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)

dbg = jaxpr.debug_info and core.DebugInfo(
dbg = core.DebugInfo(
jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
jaxpr.debug_info.filter_arg_names(used_inputs),
jaxpr.debug_info.filter_result_paths(used_outputs))
Expand Down Expand Up @@ -1653,9 +1655,9 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: core.DebugInfo | None
debug_info: core.DebugInfo

def __init__(self, debug_info: core.DebugInfo | None):
def __init__(self, debug_info: core.DebugInfo):
self.gensym = core.gensym()
self.tracer_to_var = {}
self.constid_to_tracer = {}
Expand All @@ -1674,7 +1676,7 @@ def add_eqn(self, eqn: core.JaxprEqn):

def to_jaxpr(self, trace: DynamicJaxprTrace,
out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None,
debug_info: core.DebugInfo,
) -> tuple[Jaxpr, list[Any], list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
Expand All @@ -1696,7 +1698,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace,
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)

def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
debug_info: core.DebugInfo | None):
debug_info: core.DebugInfo):
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
constvars, constvals = unzip2(self.constvar_to_val.items())
Expand Down Expand Up @@ -1843,7 +1845,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
class DynamicJaxprTrace(core.Trace):
__slots__ = ("frame",)

def __init__(self, debug_info: core.DebugInfo | None):
def __init__(self, debug_info: core.DebugInfo):
self.frame = JaxprStackFrame(debug_info)

def invalidate(self):
Expand Down Expand Up @@ -2117,7 +2119,7 @@ def transpose_jaxpr_thunk():
return out_tracers

def to_jaxpr(self, out_tracers: Sequence[Tracer],
debug_info: core.DebugInfo | None):
debug_info: core.DebugInfo):
return self.frame.to_jaxpr(self, out_tracers, debug_info)


Expand Down Expand Up @@ -2180,17 +2182,13 @@ def trace_to_jaxpr_dynamic(
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked

def _check_no_returned_refs(
dbg: core.DebugInfo | None,
dbg: core.DebugInfo,
out_tracers: Sequence[DynamicJaxprTracer]
) -> None:
if not config.mutable_array_checks.value: return
for i, t in enumerate(out_tracers):
a = t.aval
if isinstance(a, AbstractRef):
if dbg is None:
raise ValueError(
f"function returned a mutable array reference of type {a.str_short()}, "
"but mutable array references cannot be returned.")
result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
loc = result_paths[i] and f' at output tree path {result_paths[i]}'
frame = t._trace.frame
Expand Down Expand Up @@ -2469,7 +2467,8 @@ def substitute(aval: AbstractValue) -> AbstractValue:
return aval

in_avals = [substitute(v.aval) for v in jaxpr.invars]
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts),
debug_info=jaxpr.debug_info)
padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
return padded_jaxpr, padded_consts

Expand Down
Loading

0 comments on commit 8103ab6

Please sign in to comment.