Skip to content

Commit

Permalink
Merge pull request #26476 from froystig:aot-doc-traced
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 725902103
  • Loading branch information
Google-ML-Automation committed Feb 12, 2025
2 parents 675be01 + af381a7 commit 914adaf
Showing 1 changed file with 44 additions and 47 deletions.
91 changes: 44 additions & 47 deletions docs/aot.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

<!--* freshness: { reviewed: '2024-06-12' } *-->

JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens _just-in-time_ for execution.
JAX's `jax.jit` transformation returns a function that, when called,
compiles a computation and runs it on accelerators (or the CPU). As
the JIT acronym indicates, all compilation happens _just-in-time_ for
execution.

Some situations call for _ahead-of-time_ (AOT) compilation instead. When you
want to fully compile prior to execution time, or you want control over when
Expand All @@ -18,10 +19,14 @@ function/callable output by {func}`jax.jit`, say `f = jax.jit(F)` for some input
callable `F`. When it is invoked with arguments, say `f(x, y)` where `x` and `y`
are arrays, JAX does the following in order:

1. **Stage out** a specialized version of the original Python callable `F` to an
internal representation. The specialization reflects a restriction of `F` to
input types inferred from properties of the arguments `x` and `y` (usually
their shape and element type).
1. **Stage out** a specialized version of the original Python callable
`F` to an internal representation. The specialization reflects a
restriction of `F` to input types inferred from properties of the
arguments `x` and `y` (usually their shape and element type). JAX
carries out this specialization by a process that we call
_tracing_. During tracing, JAX stages the specialization of `F` to
a jaxpr, which is a function in the [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html).

2. **Lower** this specialized, staged-out computation to the XLA compiler's
input language, StableHLO.
Expand All @@ -31,17 +36,22 @@ are arrays, JAX does the following in order:

4. **Execute** the compiled executable with the arrays `x` and `y` as arguments.

JAX's AOT API gives you direct control over steps #2, #3, and #4 (but [not
#1](#inspecting-staged-out-computations)), plus some other features along the
way. An example:
JAX's AOT API gives you direct control over each of these steps, plus
some other features along the way. An example:

```python
>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)
>>> traced = jax.jit(f).trace(x, y)

>>> # Print the specialized, staged-out representation (as Jaxpr IR)
>>> print(traced.jaxpr)
{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }

>>> lowered = traced.lower()

>>> # Print lowered HLO
>>> print(lowered.as_text())
Expand All @@ -67,45 +77,44 @@ Array(10, dtype=int32, weak_type=True)
```

Note that the lowered objects can be used only in the same process
in which they were lowered. For exporting use cases,
see the {ref}`export` APIs.
in which they were lowered. For exporting use cases, see the {ref}`export` APIs.

See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.

All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution.
the corresponding tracing, lowering, compilation, and execution.

In the example above, we can replace the arguments to `lower` with any objects
In the example above, we can replace the arguments to `trace` with any objects
that have `shape` and `dtype` attributes:

```python
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
Array(10, dtype=int32)

```

More generally, `lower` only needs its arguments to structurally supply what JAX
More generally, `trace` only needs its arguments to structurally supply what JAX
must know for specialization and lowering. For typical array arguments like the
ones above, this means `shape` and `dtype` fields. For static arguments, by
contrast, JAX needs actual array values (more on this
[below](#lowering-with-static-arguments)).
[below](#tracing-with-static-arguments)).

Invoking an AOT-compiled function with arguments that are incompatible with its
lowering raises an error:
tracing raises an error:

```python
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]

>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
Expand All @@ -119,14 +128,14 @@ transformations](#aot-compiled-functions-cannot-be-transformed) such as
`jax.jit`, {func}`jax.grad`, and {func}`jax.vmap`.


## Lowering with static arguments
## Tracing with static arguments

Lowering with static arguments underscores the interaction between options
passed to `jax.jit`, the arguments passed to `lower`, and the arguments needed
Tracing with static arguments underscores the interaction between options
passed to `jax.jit`, the arguments passed to `trace`, and the arguments needed
to invoke the resulting compiled function. Continuing with our example above:

```python
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
Expand All @@ -143,30 +152,29 @@ Array(19, dtype=int32, weak_type=True)

```

The result of `lower` is not safe to serialize directly for use
in a different process.
See {ref}`export` for additional APIs for this purpose.

Note that `lower` here takes two arguments as usual, but the subsequent compiled
Note that `trace` here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
lowered computation, where it is possibly folded in with other constants. In
this case, its multiplication by 2 is simplified, resulting in the constant 14.

Although the second argument to `lower` above can be replaced by a hollow
Although the second argument to `trace` above can be replaced by a hollow
shape/dtype structure, it is necessary that the static first argument be a
concrete value. Otherwise, lowering would err:
concrete value. Otherwise, tracing errs:

```python
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP
>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar) # doctest: +SKIP
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
Array(25, dtype=int32)

```

The results of `trace` and of `lower` are not safe to serialize directly for use
in a different process. See {ref}`export` for additional APIs for this purpose.

## AOT-compiled functions cannot be transformed

Compiled functions are specialized to a particular set of argument "types," such
Expand All @@ -187,7 +195,7 @@ in transformations. Example:
>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()
>>> g_aot = jax.jit(g).trace(z).lower().compile()

>>> jax.vmap(g_jit)(zs)
Array([[ 1., 5., 9.],
Expand Down Expand Up @@ -218,7 +226,7 @@ a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
`lowered.as_text()` and `compiled.cost_analysis()` above).
You can obtain more debbugging information, e.g., source location,
You can obtain more debugging information, e.g., source location,
by using the `debug_info` parameter to `lowered.as_text()`.

These methods are meant as an aid for manual inspection and debugging, not as a
Expand All @@ -238,14 +246,3 @@ platform, and runtime. This makes for two important caveats:
remain the same on the following day.

When in doubt, see the package API documentation for {mod}`jax.stages`.


## Inspecting staged-out computations

Stage #1 in the list at the top of this note mentions specialization and
staging, prior to lowering. JAX's internal notion of a function specialized to
the types of its arguments is not always a reified data structure in memory. To
explicitly construct a view of JAX's specialization of a function in the
internal [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html), see
{func}`jax.make_jaxpr`.

0 comments on commit 914adaf

Please sign in to comment.