From 0a8bdd56c317f9cc092ceaa685d5109980baa72f Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 11 Dec 2024 11:20:28 -0800 Subject: [PATCH] Remove references to jax.core.raise_to_shaped As of JAX v0.4.36, `core.raise_to_shaped` is deprecated, and simply returns the input unchanged. PiperOrigin-RevId: 705174146 --- jax_triton/experimental/fusion/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_triton/experimental/fusion/lowering.py b/jax_triton/experimental/fusion/lowering.py index c05edcaa..75903d07 100644 --- a/jax_triton/experimental/fusion/lowering.py +++ b/jax_triton/experimental/fusion/lowering.py @@ -369,7 +369,7 @@ def jit(f, *, fuse: bool = True, debug: bool = False): def wrapped(*args, **kwargs): flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) - in_avals = [core.raise_to_shaped(core.get_aval(a)) for a in flat_args] + in_avals = [core.get_aval(a) for a in flat_args] jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) jaxpr, consts = lower_jaxpr(jaxpr, consts, fuse=fuse, debug=debug) out_vals = core.eval_jaxpr(jaxpr, consts, *flat_args)