diff --git a/docs/aot.md b/docs/aot.md
index f4e7e020ca1e..c177ff2bd7f9 100644
--- a/docs/aot.md
+++ b/docs/aot.md
@@ -4,9 +4,10 @@
-JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
-a function that is compiled and runs on accelerators or the CPU. As the JIT
-acronym indicates, all compilation happens _just-in-time_ for execution.
+JAX's `jax.jit` transformation returns a function that, when called,
+compiles a computation and runs it on accelerators (or the CPU). As
+the JIT acronym indicates, all compilation happens _just-in-time_ for
+execution.

 Some situations call for _ahead-of-time_ (AOT) compilation instead. When you
 want to fully compile prior to execution time, or you want control over when
@@ -18,10 +19,14 @@ function/callable output by {func}`jax.jit`, say `f = jax.jit(F)` for some input
 callable `F`. When it is invoked with arguments, say `f(x, y)` where `x` and `y`
 are arrays, JAX does the following in order:

-1. **Stage out** a specialized version of the original Python callable `F` to an
-   internal representation. The specialization reflects a restriction of `F` to
-   input types inferred from properties of the arguments `x` and `y` (usually
-   their shape and element type).
+1. **Stage out** a specialized version of the original Python callable
+   `F` to an internal representation. The specialization reflects a
+   restriction of `F` to input types inferred from properties of the
+   arguments `x` and `y` (usually their shape and element type). JAX
+   carries out this specialization by a process that we call
+   _tracing_. During tracing, JAX stages the specialization of `F` to
+   a jaxpr, which is a function in the [Jaxpr intermediate
+   language](https://jax.readthedocs.io/en/latest/jaxpr.html).

 2. **Lower** this specialized, staged-out computation to the XLA compiler's
    input language, StableHLO.
@@ -31,9 +36,8 @@ are arrays, JAX does the following in order:

 4. **Execute** the compiled executable with the arrays `x` and `y` as arguments.

-JAX's AOT API gives you direct control over steps #2, #3, and #4 (but [not
-#1](#inspecting-staged-out-computations)), plus some other features along the
-way. An example:
+JAX's AOT API gives you direct control over each of these steps, plus
+some other features along the way. An example:

 ```python
 >>> import jax
@@ -41,7 +45,13 @@ way. An example:
 >>> def f(x, y): return 2 * x + y
 >>> x, y = 3, 4

->>> lowered = jax.jit(f).lower(x, y)
+>>> traced = jax.jit(f).trace(x, y)
+
+>>> # Print the specialized, staged-out representation (as Jaxpr IR)
+>>> print(traced.jaxpr)
+{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }
+
+>>> lowered = traced.lower()

 >>> # Print lowered HLO
 >>> print(lowered.as_text())
@@ -67,37 +77,36 @@ Array(10, dtype=int32, weak_type=True)
 ```

 Note that the lowered objects can be used only in the same process
-in which they were lowered. For exporting use cases,
-see the {ref}`export` APIs.
+in which they were lowered. For exporting use cases, see the {ref}`export` APIs.

 See the {mod}`jax.stages` documentation for more details on what functionality
 the lowering and compiled functions provide.

 All optional arguments to `jit`---such as `static_argnums`---are respected in
-the corresponding lowering, compilation, and execution.
+the corresponding tracing, lowering, compilation, and execution.

-In the example above, we can replace the arguments to `lower` with any objects
+In the example above, we can replace the arguments to `trace` with any objects
 that have `shape` and `dtype` attributes:

 ```python
 >>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
+>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
 Array(10, dtype=int32)
 ```

-More generally, `lower` only needs its arguments to structurally supply what JAX
+More generally, `trace` only needs its arguments to structurally supply what JAX
 must know for specialization and lowering. For typical array arguments like the
 ones above, this means `shape` and `dtype` fields.

 For static arguments, by contrast, JAX needs actual array values (more on this
 [below](#lowering-with-static-arguments)).

 Invoking an AOT-compiled function with arguments that are incompatible with its
-lowering raises an error:
+tracing raises an error:

 ```python
 >>> x_1d = y_1d = jnp.arange(3)
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
+>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
 ...
 Traceback (most recent call last):
 TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
@@ -105,7 +114,7 @@ Argument 'x' compiled with int32[] and called with int32[3]
 Argument 'y' compiled with int32[] and called with int32[3]

 >>> x_f = y_f = jnp.float32(72.)
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
+>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
 ...
 Traceback (most recent call last):
 TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
@@ -119,14 +128,14 @@ transformations](#aot-compiled-functions-cannot-be-transformed) such as
 `jax.jit`, {func}`jax.grad`, and {func}`jax.vmap`.


-## Lowering with static arguments
+## Tracing with static arguments

-Lowering with static arguments underscores the interaction between options
-passed to `jax.jit`, the arguments passed to `lower`, and the arguments needed
+Tracing with static arguments underscores the interaction between options
+passed to `jax.jit`, the arguments passed to `trace`, and the arguments needed
 to invoke the resulting compiled function. Continuing with our example above:

 ```python
->>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
+>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()

 >>> # Lowered HLO, specialized to the *value* of the first argument (7)
 >>> print(lowered_with_x.as_text())
@@ -140,33 +149,30 @@ module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 :

 >>> lowered_with_x.compile()(5)
 Array(19, dtype=int32, weak_type=True)
-
 ```

-The result of `lower` is not safe to serialize directly for use
-in a different process.
-See {ref}`export` for additional APIs for this purpose.
-
-Note that `lower` here takes two arguments as usual, but the subsequent compiled
+Note that `trace` here takes two arguments as usual, but the subsequent compiled
 function accepts only the remaining non-static second argument. The static first
 argument (value 7) is taken as a constant at lowering time and built into the
 lowered computation, where it is possibly folded in with other constants. In
 this case, its multiplication by 2 is simplified, resulting in the constant 14.

-Although the second argument to `lower` above can be replaced by a hollow
+Although the second argument to `trace` above can be replaced by a hollow
 shape/dtype structure, it is necessary that the static first argument be a
-concrete value. Otherwise, lowering would err:
+concrete value. Otherwise, tracing errs:

 ```python
->>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP
+>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar) # doctest: +SKIP
 Traceback (most recent call last):
 TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

->>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
+>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
 Array(25, dtype=int32)
-
 ```

+The results of `trace` and of `lower` are not safe to serialize directly for use
+in a different process. See {ref}`export` for additional APIs for this purpose.
+
 ## AOT-compiled functions cannot be transformed

 Compiled functions are specialized to a particular set of argument "types," such
@@ -187,7 +193,7 @@ in transformations. Example:
 >>> z, zs = make_z(3, 2), make_z(4, 3, 2)

 >>> g_jit = jax.jit(g)
->>> g_aot = jax.jit(g).lower(z).compile()
+>>> g_aot = jax.jit(g).trace(z).lower().compile()

 >>> jax.vmap(g_jit)(zs)
 Array([[ 1., 5., 9.],
@@ -198,7 +204,6 @@ Array([[ 1., 5., 9.],
 >>> jax.vmap(g_aot)(zs) # doctest: +SKIP
 Traceback (most recent call last):
 TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type
-
 ```

 A similar error is raised when `g_aot` is involved in autodiff
@@ -218,7 +223,7 @@ a text representation. Compiled functions do the same, and also offer cost and
 memory analyses from the compiler. All of these are provided via methods on the
 {class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
 `lowered.as_text()` and `compiled.cost_analysis()` above).
-You can obtain more debbugging information, e.g., source location,
+You can obtain more debugging information, e.g., source location,
 by using the `debug_info` parameter to `lowered.as_text()`.

 These methods are meant as an aid for manual inspection and debugging, not as a
@@ -238,14 +243,3 @@ platform, and runtime. This makes for two important caveats:
   remain the same on the following day.

 When in doubt, see the package API documentation for {mod}`jax.stages`.
-
-
-## Inspecting staged-out computations
-
-Stage #1 in the list at the top of this note mentions specialization and
-staging, prior to lowering. JAX's internal notion of a function specialized to
-the types of its arguments is not always a reified data structure in memory. To
-explicitly construct a view of JAX's specialization of a function in the
-internal [Jaxpr intermediate
-language](https://jax.readthedocs.io/en/latest/jaxpr.html), see
-{func}`jax.make_jaxpr`.