Skip to content

Commit

Permalink
Remove references to jax.core.raise_to_shaped
Browse files Browse the repository at this point in the history
As of JAX v0.4.36, `core.raise_to_shaped` is deprecated, and simply returns the input unchanged.

PiperOrigin-RevId: 705174146
  • Loading branch information
Jake VanderPlas authored and Google-ML-Automation committed Dec 11, 2024
1 parent b00ec0c commit 0a8bdd5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax_triton/experimental/fusion/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def jit(f, *, fuse: bool = True, debug: bool = False):
def wrapped(*args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args]
in_avals = [core.get_aval(a) for a in flat_args]
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
jaxpr, consts = lower_jaxpr(jaxpr, consts, fuse=fuse, debug=debug)
out_vals = core.eval_jaxpr(jaxpr, consts, *flat_args)
Expand Down

0 comments on commit 0a8bdd5

Please sign in to comment.