Skip to content

Commit

Permalink
Merge pull request #26154 from jakevdp:pure-callback-doc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720763192
  • Loading branch information
Google-ML-Automation committed Jan 29, 2025
2 parents 809e113 + ba2858f commit bf22b53
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
41 changes: 40 additions & 1 deletion docs/external-callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ kernelspec:

<!--* freshness: { reviewed: '2024-05-16' } *-->

This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` and {func}`jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`.
This tutorial outlines how you can use various callback functions, which allow JAX runtimes to execute Python code on the host. Examples of JAX callbacks are `jax.pure_callback`, `jax.experimental.io_callback` and `jax.debug.callback`. You can use them even while running under JAX transformations, including {func}`~jax.jit`, {func}`~jax.vmap`, {func}`~jax.grad`.

## Why callbacks?

Expand Down Expand Up @@ -66,8 +66,11 @@ This works by passing the runtime value of `y` as a CPU {class}`jax.Array` back
In earlier versions of JAX, there was only one kind of callback available, implemented in {func}`jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:

- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effects.
See {ref}`external-callbacks-exploring-pure-callback`.
- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.
See {ref}`external-callbacks-exploring-io-callback`.
- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.
See {ref}`external-callbacks-exploring-debug-callback`.

(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).

Expand All @@ -85,6 +88,7 @@ From the user perspective, these three flavors of callback are mainly distinguis

³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases.

(external-callbacks-exploring-pure-callback)=
### Exploring `pure_callback`

{func}`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).
Expand Down Expand Up @@ -163,6 +167,41 @@ f2();
In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.
In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.

#### `pure_callback` and exceptions

In the context of JAX transformations, Python runtime exceptions should be considered side-effects:
this means that intentionally raising an error within a `pure_callback` breaks the API contract,
and the behavior of the resulting program is undefined. In particular, the manner in which
such a program halts will generally depend on the backend, and the details of that behavior may
change in future releases.

Additionally, passing impure functions to `pure_callback` may result in unexpected behavior during
transformations like {func}`jax.jit` or {func}`jax.vmap`, because the transformation rules for
`pure_callback` are defined under the assumption that the callback function is pure. Here's one
simple example of an impure callback behaving unexpectedly under `vmap`:
```python
import jax
import jax.numpy as jnp

def raise_via_callback(x):
def _raise(x):
raise ValueError(f"value of x is {x}")
return jax.pure_callback(_raise, x, x)

def raise_if_negative(x):
return jax.lax.cond(x < 0, raise_via_callback, lambda x: x, x)

x_batch = jnp.arange(4)

[raise_if_negative(x) for x in x_batch] # does not raise

jax.vmap(raise_if_negative)(x_batch) # ValueError: value of x is 0
```
To avoid this and similar unexpected behavior, we recommend not attempting to use
`pure_callback` to raise runtime errors.


(external-callbacks-exploring-io-callback)=
### Exploring `io_callback`

In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.
Expand Down
9 changes: 8 additions & 1 deletion jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,13 @@ def pure_callback(
`jit`-decorated function has no data dependence on its value. Pure callbacks
may also be reordered if data-dependence allows.
.. warning::
In the context of JAX transformations, Python exceptions should be
considered side-effects: this means that intentionally raising an error
within a `pure_callback` breaks the API contract, and the behavior of
the resulting program is undefined.
When `vmap`-ed the behavior will depend on the value of the ``vmap_method``.
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
Expand Down Expand Up @@ -440,7 +447,7 @@ def pure_callback(
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
.. _External Callbacks: https://jax.readthedocs.io/en/latest/external-callbacks.html
"""
if not isinstance(vectorized, DeprecatedArg) and not vectorized is None:
deprecations.warn(
Expand Down

0 comments on commit bf22b53

Please sign in to comment.