diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 9b937bd67055..7ae1c1dc6862 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -28,6 +28,7 @@ from jax._src import dispatch from jax._src import test_util as jtu from jax._src import util +from jax._src.lib import xla_extension_version from jax.experimental import io_callback from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -1033,6 +1034,32 @@ def f(x, y): jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error + @jtu.thread_unsafe_test() # Requires a lot of memory. + @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.run_on_devices("cpu") + def test_async_deadlock(self): + if xla_extension_version < 307: + self.skipTest("deadlock expected") + + # See https://github.com/jax-ml/jax/issues/24255 + eig = jax.jit(jnp.linalg.eig) + + def callback(x): + return jax.block_until_ready(eig(x)) + + def fun(x): + self.assertEqual(x.dtype, jnp.complex64) + out_type = ( + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + return jax.pure_callback(callback, out_type, x) + + result = 0.0 + for _ in range(10): + result += fun(jnp.ones((500, 500), jnp.complex64))[1] + jax.block_until_ready(result) # doesn't deadlock + class IOCallbackTest(jtu.JaxTestCase):