Skip to content

Commit

Permalink
Disable async dispatch within the body of a host callback.
Browse files Browse the repository at this point in the history
This is a follow-up to #26160 and openxla/xla#21980. See those PRs for more discussion of the motivation for this change.

In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.

PiperOrigin-RevId: 720918318
  • Loading branch information
dfm authored and Google-ML-Automation committed Jan 29, 2025
1 parent a459e7e commit 9d39ab3
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tests/python_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from jax._src import dispatch
from jax._src import test_util as jtu
from jax._src import util
from jax._src.lib import xla_extension_version
from jax.experimental import io_callback
from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
Expand Down Expand Up @@ -1033,6 +1034,32 @@ def f(x, y):

jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error

@jtu.thread_unsafe_test()  # Requires a lot of memory.
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.run_on_devices("cpu")
def test_async_deadlock(self):
  """Check that dispatching a jitted computation inside a host callback does not deadlock.

  A jitted ``jnp.linalg.eig`` is invoked (and blocked on) from within the body
  of a ``jax.pure_callback``; the callback is issued several times and the
  accumulated result is then waited on.

  See https://github.com/jax-ml/jax/issues/24255
  """
  if xla_extension_version < 307:
    self.skipTest("deadlock expected")

  jitted_eig = jax.jit(jnp.linalg.eig)

  def host_eig(operand):
    # Executed on the host as the callback body: launch the jitted
    # computation and wait for it to finish before returning.
    return jax.block_until_ready(jitted_eig(operand))

  def eig_via_callback(x):
    self.assertEqual(x.dtype, jnp.complex64)
    # The callback returns a pair: one output drops the last axis of the
    # input, the other keeps the input's full shape.
    first_spec = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
    second_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_eig, (first_spec, second_spec), x)

  result = 0.0
  for _ in range(10):
    outputs = eig_via_callback(jnp.ones((500, 500), jnp.complex64))
    result += outputs[1]
  jax.block_until_ready(result)  # doesn't deadlock


class IOCallbackTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 9d39ab3

Please sign in to comment.