From 8103ab6091740fc107ca91ee4d871f884827b9f4 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Tue, 11 Feb 2025 00:30:12 -0800
Subject: [PATCH] [better_errors] Make it explicit that debug_info is not None.

Now all internal uses of lu.wrap_init and core.Jaxpr carry actual debug
info. This enables us to clean up the type declarations and to remove
the checks for whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info, for
temporary backwards compatibility. We emit a deprecation warning and
fill in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 725512692
---
 CHANGELOG.md                          |  9 +++++
 jax/_src/api_util.py                  |  8 ++---
 jax/_src/core.py                      | 16 ++++++---
 jax/_src/custom_derivatives.py        | 20 ++++--------
 jax/_src/interpreters/ad.py           | 23 +++++++------
 jax/_src/interpreters/partial_eval.py | 47 +++++++++++++--------------
 jax/_src/interpreters/pxla.py         | 25 ++++++--------
 jax/_src/lax/control_flow/for_loop.py |  2 +-
 jax/_src/lax/control_flow/solves.py   |  2 +-
 jax/_src/linear_util.py               | 18 +++++-----
 jax/_src/pjit.py                      | 29 +++++++----------
 jax/extend/linear_util.py             | 11 ++++++-
 tests/debug_info_test.py              |  4 +--
 tests/extend_test.py                  |  3 +-
 14 files changed, 111 insertions(+), 106 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b29ad3f1a6cb..11f93483d835 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -37,6 +37,15 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
   This package may safely be removed if it is present on your machine;
   JAX now uses `libtpu` instead.
 
+* Deprecations
+  * The internal function `linear_util.wrap_init` and the constructor
+    `core.Jaxpr` now must take a non-empty `core.DebugInfo` kwarg. For
+    a limited time, a `DeprecationWarning` is printed if
+    `jax.extend.linear_util.wrap_init` is used without debugging info.
+    A downstream effect is that several other internal functions now
+    need debug info. This change does not affect public APIs.
+    See https://github.com/jax-ml/jax/issues/26480 for more details.
+
 ## jax 0.5.0 (Jan 17, 2025)
 
 As of this release, JAX now uses
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index 5cc31f53bfa0..a597e8b5b910 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -620,7 +620,7 @@ def fun_signature(fun: Callable) -> inspect.Signature | None:
     return None
 
 def save_wrapped_fun_sourceinfo(wrapper: Callable,
-                                wrapped: Callable | core.DebugInfo | None) -> None:
+                                wrapped: Callable | core.DebugInfo) -> None:
   # Prefer this to functools.wraps because it does not create a reference to
   # the wrapped function.
   if isinstance(wrapped, core.DebugInfo):
@@ -628,7 +628,7 @@ def save_wrapped_fun_sourceinfo(wrapper: Callable,
   elif callable(wrapped):
     func_src_info = fun_sourceinfo(wrapped)
   else:
-    return
+    assert False, wrapped  # Unreachable
   setattr(wrapper, "__fun_sourceinfo__", func_src_info)
 
 _fun_name_re = re.compile(r"(?:)")
@@ -716,7 +716,7 @@ def register_class_with_attrs(t: type) -> None:
 _class_with_attrs: set[type] = set()
 
 # TODO(mattjj): make this function faster
-def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
+def _check_no_aliased_ref_args(dbg: core.DebugInfo, avals, args):
   assert config.mutable_array_checks.value
   refs: dict[int, int] = {}
   for i, (a, x) in enumerate(zip(avals, args)):
@@ -730,7 +730,7 @@ def _check_no_aliased_ref_args(dbg: core.DebugInfo | None, avals, args):
            if dbg else f"at both flat index {dup_idx} and flat index {i}") from None
 
-def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo | None, consts, args) -> None:
+def _check_no_aliased_closed_over_refs(dbg: core.DebugInfo, consts, args) -> None:
   assert config.mutable_array_checks.value
   refs: set[int] = {id(core.get_referent(c)) for c in consts
                     if isinstance(core.get_aval(c), AbstractRef)}
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 90020dcccfca..21b4b48fb86e 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -93,7 +93,7 @@ class Jaxpr:
   _outvars: list[Atom]
   _eqns: list[JaxprEqn]
   _effects: Effects
-  _debug_info: DebugInfo | None
+  _debug_info: DebugInfo
 
   @property
   def constvars(self) -> list[Var]:
@@ -116,13 +116,17 @@ def effects(self) -> Effects:
     return self._effects
 
   @property
-  def debug_info(self) -> DebugInfo | None:
+  def debug_info(self) -> DebugInfo:
     return self._debug_info
 
   def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
                outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
                effects: Effects = no_effects,
-               debug_info: DebugInfo | None = None):
+               # We want all calls to pass a DebugInfo object, but for backwards
+               # compatibility we have to allow calls when the debug_info
+               # is missing.
+               debug_info: DebugInfo = None,  # type: ignore[annotation-type-mismatch]
+               ):
     """
     Args:
       constvars: list of variables introduced for constants. Array constants are
         replaced with such variables while scalar constants are kept inline.
       invars: list of input variables. Together, `constvars` and `invars` are
        the inputs to the Jaxpr.
      outvars: list of output atoms.
      eqns: list of equations.
      effects: set of effects. The effects on a jaxpr are a superset of the
        union of the effects for each equation.
-      debug_info: optional DebugInfo.
+      debug_info: debugging information.
""" self._constvars = list(constvars) self._invars = list(invars) self._outvars = list(outvars) self._eqns = list(eqns) self._effects = effects - self._debug_info = debug_info and debug_info.resolve_result_paths() + # TODO(https://github.com/jax-ml/jax/issues/26480) + debug_info = debug_info or lu._missing_debug_info("core.Jaxpr") + self._debug_info = debug_info.resolve_result_paths() # TODO(necula): re-enable these safety checks # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 2c6c41ff59c6..f410a3d17ee5 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -686,7 +686,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable @lu.transformation2 def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], - debug_info: core.DebugInfo | None, *args): + debug_info: core.DebugInfo, *args): _check_for_aliased_refs(f, nondiff_argnums, debug_info, args) out = f(*args) _check_for_returned_refs(f, out, 'primal') @@ -694,20 +694,14 @@ def _check_primal_refs(f: Callable, nondiff_argnums: Sequence[int], def _check_for_aliased_refs(f: Callable, nondiff_argnums: Sequence[int], - debug: core.DebugInfo | None, + debug: core.DebugInfo, args): leaves = tree_leaves(args) refs: dict[int, int] = {} for i, x in enumerate(leaves): if (isinstance((a := core.get_aval(x)), AbstractRef) and (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): - if debug is not None: - arg_names = debug.safe_arg_names(len(leaves)) - else: - # TODO(necula): drop this branch - arg_names = _non_static_arg_names(fun_signature(f), args, {}, nondiff_argnums, ()) - if arg_names is None: - arg_names = [f'flat index {j}' for j in range(len(leaves))] + arg_names = debug.safe_arg_names(len(leaves)) raise ValueError( "only one reference to a mutable array may be passed as an argument " f"to a function, but custom_vjp function {f} got the same mutable " @@ -763,8 +757,8 @@ def _check_for_tracers(x): def _flatten_fwd(f: Callable, store: lu.EqualStore, nondiff_argnums: Sequence[int], symbolic_zeros: bool, - debug_primal: core.DebugInfo | None, - debug_fwd: core.DebugInfo | None, + debug_primal: core.DebugInfo, + debug_fwd: core.DebugInfo, in_tree: PyTreeDef, maybe_out_type, *args): primal_name = debug_primal.func_name if debug_primal else str(f) fwd_name = debug_fwd.func_name if debug_fwd else "" @@ -1560,9 +1554,9 @@ def jvp(primals, tangents): # simpler, but it would be worth revisiting this. 
 def optimize_remat_of_custom_vjp_fwd(
     fun: Callable[..., ReturnValue],
-    debug_fun: core.DebugInfo | None,
+    debug_fun: core.DebugInfo,
     fwd: Callable[..., tuple[ReturnValue, Any]],
-    debug_fwd: core.DebugInfo | None,
+    debug_fwd: core.DebugInfo,
     nondiff_argnums: Sequence[int] = (),
     symbolic_zeros: bool = False,
 ) -> Callable[..., tuple[ReturnValue, Any]]:
diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py
index ac3c03933795..928a02cc8d8e 100644
--- a/jax/_src/interpreters/ad.py
+++ b/jax/_src/interpreters/ad.py
@@ -86,7 +86,7 @@ def jvpfun(f: Callable, instantiate, transform_stack, primals, tangents):
 
 @lu.transformation_with_aux2
 def linearize_subtrace(_f: Callable, _store: lu.Store, _tag: core.TraceTag,
                        nzs_in: Sequence[bool],
-                       debug_info: core.DebugInfo | None,
+                       debug_info: core.DebugInfo,
                        *primals, **params):
   with core.take_current_trace() as parent_trace:
     tangent_trace = pe.DynamicJaxprTrace(debug_info)
@@ -133,7 +133,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
   return out_primals, out_tangents
 
 def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
+  dbg = jaxpr.debug_info._replace(
       arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
   return core.Jaxpr(constvars=(),
                     invars=jaxpr.invars + jaxpr.constvars,
@@ -768,7 +768,7 @@ def linearize_from_jvp(jvp: Callable,
                        multiple_results: bool,
                        nonzeros: Sequence[bool],
                        user_facing_symbolic_zeros: bool,
                        instantiate_input_zeros: bool,
-                       debug_info: core.DebugInfo | None,
+                       debug_info: core.DebugInfo,
                        primals, params):
   current_name_stack = source_info_util.current_name_stack()
   with core.take_current_trace() as parent_trace:
@@ -1100,15 +1100,14 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
   new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
   new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
   new_debug_info = jaxpr.jaxpr.debug_info
-  if new_debug_info is not None:
-    new_arg_names = tuple(_perm(primals_in, tangents_in,
-                                jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
-    new_result_paths = tuple(_perm(primals_out, tangents_out,
-                                   jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
-    new_debug_info = new_debug_info._replace(
-        arg_names=new_arg_names,
-        result_paths=new_result_paths,
-    )
+  new_arg_names = tuple(_perm(primals_in, tangents_in,
+                              jaxpr.jaxpr.debug_info.safe_arg_names(len(jaxpr.jaxpr.invars))))
+  new_result_paths = tuple(_perm(primals_out, tangents_out,
+                                 jaxpr.jaxpr.debug_info.safe_result_paths(len(jaxpr.jaxpr.outvars))))
+  new_debug_info = new_debug_info._replace(
+      arg_names=new_arg_names,
+      result_paths=new_result_paths,
+  )
   new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
                          new_invars, new_outvars, jaxpr.jaxpr.eqns,
                          jaxpr.jaxpr.effects,
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index ee3df03d5ed3..999ca37d2803 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -502,7 +502,8 @@ def _closed_call_param_updater(params, _, __):
   return dict(params, call_jaxpr=core.ClosedJaxpr(jaxpr, ()))
 call_param_updaters[core.closed_call_p] = _closed_call_param_updater
 
-def abstract_eval_fun(fun: Callable, *avals, debug_info=None, **params):
+def abstract_eval_fun(fun: Callable, *avals,
+                      debug_info: core.DebugInfo, **params):
   _, avals_out, _, () = trace_to_jaxpr_dynamic(
       lu.wrap_init(fun, params, debug_info=debug_info), avals)
   assert all(isinstance(aval, AbstractValue) for aval in avals_out)
@@ -582,7 +583,7 @@ def trace_to_subjaxpr_nounits(
     f: Callable,
     trace: JaxprTrace,
     instantiate: Sequence[bool] | bool,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
   out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits(
@@ -595,7 +596,7 @@ def trace_to_subjaxpr_nounits(
 def trace_to_subjaxpr_nounits2(
     f: Callable,
     tag: TraceTag,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
   assert isinstance(tag, TraceTag)
@@ -612,7 +613,7 @@ def trace_to_subjaxpr_nounits2(
 def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
                                instantiate: Sequence[bool] | bool,
                                in_pvals: Sequence[PartialVal],
-                               debug_info: core.DebugInfo | None):
+                               debug_info: core.DebugInfo):
   in_knowns = [pval.is_known() for pval in in_pvals]
   in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()]
   in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()]
@@ -639,7 +640,7 @@ def _trace_to_subjaxpr_nounits(f: Callable, trace: JaxprTrace,
 def trace_to_subjaxpr_nounits_fwd(
     f: Callable,
     tag: TraceTag,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@@ -669,7 +670,7 @@ def trace_to_subjaxpr_nounits_fwd(
 def trace_to_subjaxpr_nounits_fwd2(
     f: Callable,
     tag: TraceTag,
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
     instantiate: bool | Sequence[bool],
     in_pvals: Sequence[PartialVal]):
   assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals
@@ -752,13 +753,14 @@ def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
 def tracers_to_jaxpr(
     in_tracers: Sequence[JaxprTracer],
     out_tracers: Sequence[JaxprTracer],
-    debug_info: core.DebugInfo | None,
+    debug_info: core.DebugInfo,
 ) -> tuple[Jaxpr, tuple[Any, ...], tuple[Any, ...]]:
   """Constructs Jaxpr given tracers for inputs and outputs.
 
   Params:
     in_tracers: the tracers that were created for the function inputs
     out_tracers: the tracers that were output by the function.
+    debug_info: the debug info for the function.
 
   Returns: a triple of a `Jaxpr`, a list of constant values corresponding to
     the `constvars` in the returned Jaxpr, and a list of environment values.
@@ -838,7 +840,7 @@ def type_substitute(aval: AbstractValue) -> AbstractValue:
 def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
   """Moves the constvars to the start of invars."""
   config.enable_checks.value and core.check_jaxpr(jaxpr)
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
+  dbg = jaxpr.debug_info._replace(
       arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
   lifted_jaxpr = Jaxpr(constvars=(),
                        invars=jaxpr.constvars + jaxpr.invars,
@@ -854,7 +856,7 @@ def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
     return jaxpr.replace()  # 'return jaxpr' would create cache reference cycle
   config.enable_checks.value and core.check_jaxpr(jaxpr)
   constvars, invars = split_list(jaxpr.invars, [n])
-  dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
+  dbg = jaxpr.debug_info._replace(
       arg_names=jaxpr.debug_info.arg_names[n:])
   lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars),
                                invars=invars, debug_info=dbg)
@@ -868,7 +870,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
   env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
   converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars,
                           invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns,
-                          effects=jaxpr.effects)
+                          effects=jaxpr.effects, debug_info=jaxpr.debug_info)
   config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
   return converted_jaxpr
 
@@ -1363,7 +1365,7 @@ def prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: Sequence[bool]) -> Jaxpr:
 
 def _prune_jaxpr_outputs(jaxpr: Jaxpr, used_outputs: tuple[bool, ...]) -> Jaxpr:
   outvars = [v for v, b in zip(jaxpr.outvars, used_outputs) if b]
-  dbg = jaxpr.debug_info and core.DebugInfo(
+  dbg = core.DebugInfo(
       jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
       jaxpr.debug_info.arg_names,
       jaxpr.debug_info.filter_result_paths(used_outputs))
@@ -1451,7 +1453,7 @@ def write(x: Atom, b: bool) -> None:
   eqns = new_eqns[::-1]
   jaxpr_effects = make_jaxpr_effects(jaxpr.constvars, invars, outvars, eqns)
-  dbg = jaxpr.debug_info and core.DebugInfo(
+  dbg = core.DebugInfo(
       jaxpr.debug_info.traced_for, jaxpr.debug_info.func_src_info,
       jaxpr.debug_info.filter_arg_names(used_inputs),
       jaxpr.debug_info.filter_result_paths(used_outputs))
@@ -1653,9 +1655,9 @@ class JaxprStackFrame:
   attrs_tracked: list[tuple[Any, str]]
   attrs_inits: list
   attrs_vars: list[Var]
-  debug_info: core.DebugInfo | None
+  debug_info: core.DebugInfo
 
-  def __init__(self, debug_info: core.DebugInfo | None):
+  def __init__(self, debug_info: core.DebugInfo):
     self.gensym = core.gensym()
     self.tracer_to_var = {}
     self.constid_to_tracer = {}
@@ -1674,7 +1676,7 @@ def add_eqn(self, eqn: core.JaxprEqn):
   def to_jaxpr(self, trace: DynamicJaxprTrace,
                out_tracers: Sequence[Tracer],
-               debug_info: core.DebugInfo | None,
+               debug_info: core.DebugInfo,
                ) -> tuple[Jaxpr, list[Any],
                           list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
@@ -1696,7 +1698,7 @@ def to_jaxpr(self, trace: DynamicJaxprTrace,
     return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
 
   def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
-                debug_info: core.DebugInfo | None):
+                debug_info: core.DebugInfo):
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
     constvars, constvals = unzip2(self.constvar_to_val.items())
@@ -1843,7 +1845,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:
 class DynamicJaxprTrace(core.Trace):
   __slots__ = ("frame",)
 
-  def __init__(self, debug_info: core.DebugInfo | None):
+  def __init__(self, debug_info: core.DebugInfo):
     self.frame = JaxprStackFrame(debug_info)
 
   def invalidate(self):
@@ -2117,7 +2119,7 @@ def transpose_jaxpr_thunk():
     return out_tracers
 
   def to_jaxpr(self, out_tracers: Sequence[Tracer],
-               debug_info: core.DebugInfo | None):
+               debug_info: core.DebugInfo):
     return self.frame.to_jaxpr(self, out_tracers, debug_info)
 
 
@@ -2180,17 +2182,13 @@ def trace_to_jaxpr_dynamic(
   return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
 
 def _check_no_returned_refs(
-    dbg: core.DebugInfo | None,
+    dbg: core.DebugInfo,
     out_tracers: Sequence[DynamicJaxprTracer]
 ) -> None:
   if not config.mutable_array_checks.value: return
   for i, t in enumerate(out_tracers):
     a = t.aval
     if isinstance(a, AbstractRef):
-      if dbg is None:
-        raise ValueError(
-            f"function returned a mutable array reference of type {a.str_short()}, "
-            "but mutable array references cannot be returned.")
       result_paths = dbg.resolve_result_paths().safe_result_paths(len(out_tracers))
       loc = result_paths[i] and f' at output tree path {result_paths[i]}'
       frame = t._trace.frame
@@ -2469,7 +2467,8 @@ def substitute(aval: AbstractValue) -> AbstractValue:
       return aval
 
   in_avals = [substitute(v.aval) for v in jaxpr.invars]
-  eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts))
+  eval_padded = lu.wrap_init(partial(_eval_jaxpr_padded, jaxpr, consts),
+                             debug_info=jaxpr.debug_info)
   padded_jaxpr, _, padded_consts, () = trace_to_jaxpr_dynamic(eval_padded, in_avals)
   return padded_jaxpr, padded_consts
 
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index e4b528136320..c3866cfc7845 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -877,8 +877,8 @@ def lower_parallel_callable(
       replicated_args=replicated_args,
       arg_shardings=None,
       result_shardings=None,
-      arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
-      result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
+      arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
+      result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
      num_replicas=replicas.num_global_replicas,
      lowering_parameters=lowering_parameters)
   return PmapComputation(lowering_result.module,
@@ -1968,8 +1968,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
         result_shardings=out_mlir_shardings,
         in_layouts=in_layouts,
         out_layouts=out_layouts,
-        arg_names=jaxpr._debug_info and jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
-        result_names=jaxpr._debug_info and jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
+        arg_names=jaxpr._debug_info.safe_arg_names(len(jaxpr.invars)),
+        result_names=jaxpr._debug_info.safe_result_paths(len(jaxpr.outvars)),
         num_replicas=nreps,
         num_partitions=num_partitions,
         all_default_mem_kind=all_default_mem_kind,
@@ -2123,7 +2123,7 @@ def _get_num_devices(
 class AllArgsInfo(NamedTuple):
   """Avals and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
-  debug_info: core.DebugInfo | None
+  debug_info: core.DebugInfo
 
 
 @lru_cache(maxsize=2048)
@@ -3200,17 +3200,13 @@ def cc_shard_arg(x, sharding, layout):
 
 def check_arg_avals_for_call(ref_avals, arg_avals,
-                             jaxpr_debug_info: core.DebugInfo | None = None):
+                             jaxpr_debug_info: core.DebugInfo):
   if len(ref_avals) != len(arg_avals):
     raise TypeError(
f"Computation compiled for {len(ref_avals)} inputs " f"but called with {len(arg_avals)}") - if jaxpr_debug_info is not None: - arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))] - else: - num_args = len(ref_avals) - arg_names = [f"{i + 1}/{num_args}" for i in range(num_args)] + arg_names = [f"'{name}'" for name in jaxpr_debug_info.safe_arg_names(len(ref_avals))] errors = [] for ref_aval, arg_aval, name in safe_zip(ref_avals, arg_avals, arg_names): @@ -3262,14 +3258,13 @@ def check_array_xla_sharding_layout_match( args_after_dce, in_xla_shardings: Sequence[JSharding], in_xla_layouts: Sequence[DeviceLocalLayout], - jaxpr_debug_info: core.DebugInfo | None, + jaxpr_debug_info: core.DebugInfo, kept_var_idx: set[int]) -> None: from jax._src.array import ArrayImpl # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. arg_names = ( - [""] * len(args_after_dce) if jaxpr_debug_info is None - else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore - if i in kept_var_idx] + [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore + if i in kept_var_idx] ) errors = [] num_errors = 5 diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 52c6fd97a1da..6918f74df6a0 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -73,7 +73,7 @@ class Ref(Generic[T]): pass def _trace_to_jaxpr_with_refs(f: Callable, state_tree: PyTreeDef, state_avals: Sequence[core.AbstractValue], - debug_info: core.DebugInfo | None, + debug_info: core.DebugInfo, ) -> tuple[core.Jaxpr, list[Any], PyTreeDef]: f, out_tree_thunk = api_util.flatten_fun_nokwargs( lu.wrap_init(f, debug_info=debug_info), diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index e278e7ee9b00..acfcfd7ff3d3 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -336,7 +336,7 @@ def _custom_linear_solve_impl(*args, const_lengths, jaxprs): def _tangent_linear_map(func: Callable, params, params_dot, - debug_info: core.DebugInfo | None, + debug_info: core.DebugInfo, *x): """Compute the tangent of a linear map. diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 565598e51485..b58f2ee2b212 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -161,7 +161,7 @@ def __init__(self, f: Callable, f_transformed: Callable, transforms, stores: tuple[Store | EqualStore | None, ...], params, in_type, - debug_info: DebugInfo | None): + debug_info: DebugInfo): self.f = f self.f_transformed = f_transformed self.transforms = transforms @@ -258,6 +258,7 @@ def fun_name(f): except: return str(f) + class DebugInfo(NamedTuple): """Debugging info about a func, its arguments, and results.""" traced_for: str # e.g. 
'jit', 'scan', etc @@ -331,18 +332,17 @@ def _missing_debug_info(for_what: str) -> DebugInfo: return DebugInfo("missing_debug_info", "", (), ()) def wrap_init(f: Callable, params=None, *, - debug_info: DebugInfo | None = None) -> WrappedFun: + debug_info: DebugInfo) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) fun = WrappedFun(f, partial(f, **params_dict), (), (), params, None, debug_info) - if debug_info: - if debug_info.result_paths is None: - fun, result_paths_thunk = _get_result_paths_thunk(fun) - debug_info = debug_info._replace( - result_paths=HashableFunction(result_paths_thunk, closure=())) - fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores, - fun.params, fun.in_type, debug_info) + if debug_info.result_paths is None: + fun, result_paths_thunk = _get_result_paths_thunk(fun) + debug_info = debug_info._replace( + result_paths=HashableFunction(result_paths_thunk, closure=())) + fun = WrappedFun(fun.f, fun.f_transformed, fun.transforms, fun.stores, + fun.params, fun.in_type, debug_info) return fun diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 68cddd90f90a..3bb69c7a2942 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -749,22 +749,20 @@ def _infer_params( entry.pjit_params = p return entry.pjit_params, entry.pjit_params.consts + dynargs -def _infer_input_type(fun: Callable, dbg: core.DebugInfo | None, +def _infer_input_type(fun: Callable, dbg: core.DebugInfo, explicit_args) -> tuple[core.AbstractValue, ...]: avals = [] try: for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg # type: ignore - else f"flattened argument number is {i}") # type: ignore + arg_path = f"argument path is {dbg.arg_names[i]}" # type: ignore raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from None except TypeError: - arg_description = (f"path {dbg.arg_names[i]}" if dbg # type: ignore - else f"flattened argument number {i}") # type: ignore + arg_description = f"path {dbg.arg_names[i]}" # type: ignore raise TypeError( f"Error interpreting argument to {fun} as an abstract array." 
f" The problematic value is of type {type(x)} and was passed to" # type: ignore @@ -1129,7 +1127,7 @@ def __repr__(self): return "pytree leaf" @util.cache(max_size=4096, trace_context_in_key=False) def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, in_layouts_treedef, in_layouts_leaves, - in_avals, in_tree, debug_info, + in_avals, in_tree, debug_info: core.DebugInfo, device_or_backend_set, kws): if not kws: in_tree, _ = treedef_children(in_tree) @@ -1154,11 +1152,11 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals) if not config.dynamic_shapes.value and not attrs_tracked: pjit_check_aval_sharding(in_shardings_flat, in_avals, - None if debug_info is None else debug_info.safe_arg_names(len(in_avals)), + debug_info.safe_arg_names(len(in_avals)), "pjit arguments", allow_uneven_sharding=False) check_aval_layout_compatibility( in_layouts_flat, in_avals, - None if debug_info is None else debug_info.arg_names, "jit arguments") + debug_info.safe_arg_names(len(in_avals)), "jit arguments") return in_shardings_flat, in_layouts_flat callsites: set[str] = set() @@ -1185,7 +1183,7 @@ def unpack(key): # have we seen this function before at all? fun_name = getattr(fun.f, '__qualname__', fun.f) - if debug_info is not None and debug_info.func_src_info: + if debug_info.func_src_info: # TODO(necula): clean up the extraction of the source info _, *rest = debug_info.func_src_info.split(' at ') src_info = " defined at " + ' '.join(rest) @@ -1257,7 +1255,7 @@ def unpack(key): # have we never seen these input types (eg shapes, dtypes) before? types_match = [k for k in trees_match if k[1] == in_type] if not types_match: - if len(in_type) < 5 and debug_info is not None: + if len(in_type) < 5: in_type_str = ':\n {}'.format(', '.join( f'{n}: {ty.str_short(short_dtypes=True)}' for n, ty in zip(debug_info.arg_names, in_type))) @@ -1269,10 +1267,7 @@ def unpack(key): num_mismatch = sum(map(op.ne, closest_ty, in_type)) p(f" closest seen input type signature has {num_mismatch} mismatches, including:") add_weak_type_hint = False - if debug_info: - arg_names = debug_info.safe_arg_names(len(in_type)) - else: - arg_names = (None,) * len(in_type) + arg_names = debug_info.safe_arg_names(len(in_type)) for name, ty1, ty2 in zip(arg_names, closest_ty, in_type): if ty1 != ty2: @@ -1338,7 +1333,7 @@ def _create_pjit_jaxpr( def _check_and_canonicalize_out_shardings( out_shardings_treedef, out_shardings_leaves, out_layouts_treedef, out_layouts_leaves, out_tree, out_avals, - debug_info: core.DebugInfo | None, + debug_info: core.DebugInfo, device_or_backend_set): orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves) if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)): @@ -1358,11 +1353,11 @@ def _check_and_canonicalize_out_shardings( if not config.dynamic_shapes.value: pjit_check_aval_sharding( out_shardings_flat, out_avals, - None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] "pjit outputs", allow_uneven_sharding=False) check_aval_layout_compatibility( out_layouts_flat, out_avals, - None if debug_info is None else debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] "jit outputs") return out_shardings_flat, out_layouts_flat diff --git a/jax/extend/linear_util.py 
b/jax/extend/linear_util.py index 7a3bc9bc8106..0cf9a013a9e4 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -15,6 +15,8 @@ # Note: import as is required for names to be exported. # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 +from typing import Callable + from jax._src.linear_util import ( StoreException as StoreException, WrappedFun as WrappedFun, @@ -24,7 +26,14 @@ transformation_with_aux as transformation_with_aux, transformation2 as transformation2, transformation_with_aux2 as transformation_with_aux2, - wrap_init as wrap_init, # TODO(b/396086979): remove this once we pass debug_info everywhere. + wrap_init as _wrap_init, _missing_debug_info as _missing_debug_info, ) + +# Version of wrap_init that does not require a DebugInfo object. +# This usage is deprecated, use api_util.debug_info() to construct a proper +# DebugInfo object. +def wrap_init(f: Callable, params=None, *, debug_info=None) -> WrappedFun: + debug_info = debug_info or _missing_debug_info("linear_util.wrap_init") + return _wrap_init(f, params, debug_info=debug_info) diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index bca62a2de5cf..8a79f3867dc9 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -71,8 +71,7 @@ def _collect_jaxprs(jaxpr: core.Jaxpr, return acc -def _debug_info_to_string(dbg: core.DebugInfo | None) -> list[str]: - if dbg is None: return "None" +def _debug_info_to_string(dbg: core.DebugInfo) -> list[str]: # Strip the absolute path and the line number but check that it references # this file (to catch errors when the source info points in JAX internals) func_src_info = re.sub(r"^(\S+)( at .*.debug_info_test.py:\d+)?", "\\1", dbg.func_src_info) @@ -294,7 +293,6 @@ def test_debug_info_save_wrapped_fun_source_info(self): def wrapper(x, y): return x - api_util.save_wrapped_fun_sourceinfo(wrapper, None) # No effect dbg = api_util.debug_info("test", wrapper, (1, 2), {}) self.assertEqual("wrapper", dbg.func_name) diff --git a/tests/extend_test.py b/tests/extend_test.py index e37bea42c3e6..fcf9d3b54c6d 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -53,7 +53,8 @@ def test_symbols(self): self.assertIs(jex.linear_util.merge_linear_aux, linear_util.merge_linear_aux) self.assertIs(jex.linear_util.transformation, linear_util.transformation) self.assertIs(jex.linear_util.transformation_with_aux, linear_util.transformation_with_aux) - self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init) + # TODO(necula): revert this change once we deprecate the old wrap_init + # self.assertIs(jex.linear_util.wrap_init, linear_util.wrap_init) class RandomTest(jtu.JaxTestCase):
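
Usage note (editor's sketch, not part of the patch): after this change,
internal callers are expected to construct a core.DebugInfo explicitly and
pass it to lu.wrap_init. The api_util.debug_info(...) call below matches the
one exercised in tests/debug_info_test.py; the function f and the
"my_transform" tag are hypothetical placeholders.

    from jax._src import api_util
    from jax._src import linear_util as lu

    def f(x, y):
      return x + y

    # Arguments: the traced_for tag, the traced function, example args, kwargs.
    dbg = api_util.debug_info("my_transform", f, (1, 2), {})
    wrapped = lu.wrap_init(f, debug_info=dbg)  # debug_info is now required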
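
For callers outside JAX, the jax.extend.linear_util.wrap_init shim above keeps
the old call working for now. A minimal sketch of the deprecated path,
assuming WrappedFun exposes the stored info as a .debug_info attribute (not
shown in this patch):

    import jax.extend.linear_util as jex_lu

    # Deprecated: omitting debug_info makes the shim substitute the
    # placeholder DebugInfo("missing_debug_info", "", (), ()) built by
    # _missing_debug_info, and (per the commit message) warn about it.
    wrapped = jex_lu.wrap_init(lambda x: x)
    print(wrapped.debug_info.traced_for)  # expected: "missing_debug_info"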