diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index d605242b..9830b077 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -687,7 +687,11 @@ def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     def mask(query_position: Tensor, key_position: Tensor):
         return query_position - key_position <= sliding_window_size

-    return and_masks(causal_mask, mask)
+    fun = and_masks(causal_mask, mask)
+    # Flash attention needs to recognize the sliding window size in _to_splash_mask().
+    # pylint: disable-next=protected-access
+    fun._sliding_window_size = sliding_window_size
+    return fun


 def make_causal_biases(seq_len: int) -> Tensor:
diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index de18a402..4de717b4 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -42,15 +42,16 @@ def tpu_flash_attention(
-    query: Tensor,  # [batch_size, source_len, num_heads, head_dim]
-    key: Tensor,  # [batch_size, target_len, num_heads, head_dim]
-    value: Tensor,  # [batch_size, target_len, num_heads, head_dim]
-    bias: Tensor = None,  # [batch_size, num_heads, source_len, target_len]
-    segment_ids: Tensor = None,  # [batch_size, source_len]
+    query: Tensor,  # [batch_size, target_len, num_heads, head_dim]
+    key: Tensor,  # [batch_size, source_len, num_heads, head_dim]
+    value: Tensor,  # [batch_size, source_len, num_heads, head_dim]
+    bias: Tensor = None,  # [batch_size, num_heads, target_len, source_len]
+    segment_ids: Tensor = None,  # [batch_size, target_len]
     *,
     mask: Optional[MaskFnAttentionBias] = None,
     softmax_scale: float = 1.0,
     block_size: int = 128,
+    interpret: bool = False,
 ):
     """Wraps JAX's TPU flash-attention, with reshapes and softmax-scaling outside kernel.
@@ -63,20 +64,21 @@ def tpu_flash_attention(
     If provided, bias, segment_ids, and mask are applied on top of one another.

     Args:
-        query: The query tensor, of shape [batch_size, source_len, num_heads, head_dim].
-        key: The key tensor, of shape [batch_size, target_len, num_heads, head_dim].
+        query: The query tensor, of shape [batch_size, target_len, num_heads, head_dim].
+        key: The key tensor, of shape [batch_size, source_len, num_heads, head_dim].
         value: The value tensor, of shape [batch_size, source_len, num_heads, head_dim].
         bias: The attention biases, can broadcast to shape
-            [batch_size, num_heads, source_len, target_len].
+            [batch_size, num_heads, target_len, source_len].
         segment_ids: The id of which segment each token belongs to. Attention is not computed
             between tokens in different segments.
-            Shape: [batch_size, source_len].
+            Shape: [batch_size, target_len].
         mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
         softmax_scale: A scaling factor applied to the query.
         block_size: The block size to use for chunking data in the kernel.
+        interpret: If True, run the kernel with the Pallas interpreter. Required on CPU.

     Returns:
-        The context tensor, of shape [batch_size, source_len, num_heads, head_dim].
+        The context tensor, of shape [batch_size, target_len, num_heads, head_dim].

     Raises:
         NotImplementedError: If no implementation with support for the arguments is found.
@@ -111,13 +113,20 @@ def tpu_flash_attention(
            f"Source seq len {key.shape[1]} must be divisible by block size {block_size}."
        )

-    mask = as_attention_bias(mask)
+    mask: Union[MaskFnAttentionBias, ZeroAttentionBias] = as_attention_bias(mask)

     # Switch num_heads and seq_len axes.
     query = jnp.einsum("btnh->bnth", query)
     key = jnp.einsum("bsnh->bnsh", key)
     value = jnp.einsum("bsnh->bnsh", value)

     try:
+        check_tpu_splash_attention(
+            query=query,
+            key=key,
+            mask=mask,
+            has_segment_ids=(segment_ids is not None),
+            has_bias=(bias is not None),
+        )
         block_sizes = splash_attention_kernel.BlockSizes(
             block_q=block_size,
             block_kv=block_size,
@@ -131,7 +140,13 @@ def tpu_flash_attention(
             use_fused_bwd_kernel=True,
         )
         context = _tpu_splash_attention(
-            query, key, value, bias, segment_ids=segment_ids, mask=mask, block_sizes=block_sizes
+            query,
+            key,
+            value,
+            mask=mask,
+            segment_ids=segment_ids,
+            block_sizes=block_sizes,
+            interpret=interpret,
         )
         logging.info("Using SplashAttention.")
     except SplashAttentionUnsupportedError as e:
@@ -150,7 +165,14 @@ def tpu_flash_attention(
             block_q_dq=block_size,
         )
         context = _legacy_tpu_flash_attention(
-            query, key, value, bias, segment_ids=segment_ids, mask=mask, block_sizes=block_sizes
+            query,
+            key,
+            value,
+            bias,
+            segment_ids=segment_ids,
+            mask=mask,
+            block_sizes=block_sizes,
+            interpret=interpret,
         )
         logging.warning(
             "Falling back to legacy flash attention because SplashAttention is not supported.\n"
@@ -167,36 +189,39 @@ def tpu_flash_attention(
     static_argnames=[
         "mask",  # Mask objects don't actually contain jax arrays, so they are static.
         "block_sizes",
+        "interpret",
     ],
 )
 def _legacy_tpu_flash_attention(
-    query: Tensor,  # [batch_size, num_heads, source_len, head_dim]
-    key: Tensor,  # [batch_size, num_heads, target_len, head_dim]
-    value: Tensor,  # [batch_size, num_heads, target_len, head_dim]
-    bias: Tensor = None,  # [batch_size, num_heads, source_len, target_len]
-    segment_ids: Tensor = None,  # [batch_size, source_len]
+    query: Tensor,  # [batch_size, num_heads, target_len, head_dim]
+    key: Tensor,  # [batch_size, num_heads, source_len, head_dim]
+    value: Tensor,  # [batch_size, num_heads, source_len, head_dim]
+    bias: Tensor = None,  # [batch_size, num_heads, target_len, source_len]
+    segment_ids: Tensor = None,  # [batch_size, target_len]
     *,
     mask: MaskFnAttentionBias,
     block_sizes: Optional[LegacyBlockSizes] = None,
-) -> Tensor:  # [batch_size, num_heads, source_len, head_dim].
+    interpret: bool = False,
+) -> Tensor:  # [batch_size, num_heads, target_len, head_dim].
     """Wraps JAX's legacy TPU flash-attention.

     If provided, bias, segment_ids, and mask are applied on top of one another.

     Args:
-        query: The query tensor, of shape [batch_size, num_heads, source_len, head_dim].
-        key: The key tensor, of shape [batch_size, num_heads, target_len, head_dim].
+        query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim].
+        key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim].
         value: The value tensor, of shape [batch_size, num_heads, source_len, head_dim].
-        bias: The attention biases, of shape [batch_size, num_heads, source_len, target_len].
+        bias: The attention biases, of shape [batch_size, num_heads, target_len, source_len].
         segment_ids: The id of which segment each token belongs to. Attention is not computed
            between tokens in different segments.
-            Shape: [batch_size, source_len].
+            Shape: [batch_size, target_len].
         mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
         block_sizes: An object containing values that can be used to tune the performance
            such as the block size to chunk things into.
+        interpret: If True, run the kernel with the Pallas interpreter. Required on CPU.

     Returns:
-        The context tensor, of shape [batch_size, num_heads, source_len, head_dim].
+        The context tensor, of shape [batch_size, num_heads, target_len, head_dim].

     Raises:
         NotImplementedError: If a custom (non-causal, non-full) mask is specified.
@@ -216,6 +241,7 @@ def _legacy_tpu_flash_attention(
         softmax_scale=1.0,
         block_sizes=block_sizes,
         debug=False,
+        interpret=interpret,
     )
     return context
@@ -225,19 +251,103 @@ class SplashAttentionUnsupportedError(NotImplementedError):
     """An error indicating splash attention is not supported."""


+def check_tpu_splash_attention(
+    *,
+    query: Tensor,  # [batch_size, num_heads, target_len, head_dim]
+    key: Tensor,  # [batch_size, num_heads, source_len, head_dim]
+    mask: Union[MaskFnAttentionBias, ZeroAttentionBias],
+    has_segment_ids: bool = False,
+    has_bias: bool = False,
+):
+    """Checks if splash attention is supported on TPU for the given arguments.
+
+    Args:
+        query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim].
+        key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim].
+        mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
+        has_segment_ids: Whether segment_ids is provided.
+        has_bias: Whether an attention bias is provided.
+
+    Raises:
+        SplashAttentionUnsupportedError: If splash attention is not supported for the given
+            arguments.
+    """
+    target_len = query.shape[2]
+    source_len = key.shape[2]
+    head_dim = query.shape[3]
+
+    if has_bias:
+        raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.")
+    with jax.ensure_compile_time_eval():
+        if jnp.any(
+            jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0
+        ):
+            raise SplashAttentionUnsupportedError(
+                "SplashAttention requires target_len, source_len, head_dim are divisible by"
+                f" {splash_attention_kernel.NUM_LANES}, got {target_len, source_len, head_dim}."
+            )
+    if has_segment_ids:
+        raise SplashAttentionUnsupportedError(
+            "The public API for SplashAttention that we "
+            "currently use does not support segment ids."
+        )
+    if mask.value() is not None:
+        assert isinstance(mask, MaskFnAttentionBias)
+        if target_len != source_len:
+            raise SplashAttentionUnsupportedError(
+                "Query and key/value must have same length when mask is used."
+            )
+        if isinstance(mask.target_positions, jax.core.Tracer):
+            raise SplashAttentionUnsupportedError(
+                "Non-static value of `target_positions` is not supported.\n"
+                "Are you decoding using SplashAttention? That's not supported."
+            )
+
+
+def _to_splash_mask(
+    mask: Union[MaskFnAttentionBias, ZeroAttentionBias],
+    *,
+    mask_shape: tuple[int, int],
+    q_seq_shards: int = 1,
+) -> splash_attention_mask.Mask:
+    """Converts a mask to a splash mask."""
+    if mask.value() is None:
+        return splash_attention_mask.FullMask(mask_shape)
+    assert isinstance(mask, MaskFnAttentionBias)
+    if isinstance(mask, CausalAttentionBias):
+        return splash_attention_mask.CausalMask(shape=mask_shape, shard_count=q_seq_shards)
+    if hasattr(mask.mask, "_sliding_window_size"):
+        # TODO(dhwang2): introduce SlidingWindowAttentionBias instead of "_sliding_window_size".
+        # This is set in sliding_window_causal_mask().
+        left_size = getattr(mask.mask, "_sliding_window_size")
+        return splash_attention_mask.LocalMask(
+            shape=mask_shape, window_size=(left_size, 0), offset=0, shard_count=q_seq_shards
+        )

+    with jax.ensure_compile_time_eval():
+        # MaskFn always supports compile time eval.
+        mask_array = np.asarray(mask.bool_value())
+        # Squeeze first two leading dimensions.
+        mask_array = mask_array.reshape(mask_array.shape[-2:])
+
+    # NumpyMask is backed by a dense [target_len, source_len] numpy array.
+    # May consume a large amount of host memory for long sequences at compile time.
+    return splash_attention_mask.NumpyMask(array=mask_array)
+
+
 @functools.partial(
     jax.jit,
-    static_argnames=["block_sizes"],
+    static_argnames=["block_sizes", "interpret"],
 )
 def _tpu_splash_attention(
     query: Tensor,  # [batch_size, num_heads, target_len, head_dim]
     key: Tensor,  # [batch_size, num_heads, source_len, head_dim]
     value: Tensor,  # [batch_size, num_heads, source_len, head_dim]
-    bias: Optional[Tensor] = None,  # [batch_size, num_heads, target_len, source_len]
-    segment_ids: Optional[Tensor] = None,  # [batch_size, target_len]
     *,
     mask: Union[MaskFnAttentionBias, ZeroAttentionBias],
+    segment_ids: Optional[Tensor] = None,  # [batch_size, target_len]
     block_sizes: Optional[splash_attention_kernel.BlockSizes] = None,
+    interpret: bool = False,
 ) -> Tensor:  # [batch_size, num_heads, target_len, head_dim].
     """Wraps JAX's sparse TPU flash-attention.
@@ -245,13 +355,12 @@ def _tpu_splash_attention(
         query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim].
         key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim].
         value: The value tensor, of shape [batch_size, num_heads, source_len, head_dim].
-        bias: The attention biases, of shape [batch_size, num_heads, target_len, source_len].
-        segment_ids: The id of which segment each token belongs to. Attention is not computed
-            between tokens in different segments.
-            Shape: [batch_size, target_len].
         mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
+        segment_ids: The id of which segment each token belongs to. Attention is not computed
+            between tokens in different segments. Shape: [batch_size, target_len].
         block_sizes: An object containing values that can be used to tune the performance
            such as the block size to chunk things into.
+        interpret: If True, run the kernel with the Pallas interpreter. Required on CPU.

     Returns:
         The context tensor, of shape [batch_size, num_heads, target_len, head_dim].
@@ -266,57 +375,19 @@ def _tpu_splash_attention(
         TypeError: If mask is not an instance of `MaskFnAttentionBias.
     """
-    target_len = query.shape[2]
-    source_len = key.shape[2]
+    # TODO(dhwang2): splash attention can support segment_ids. Support it when needed.
+    del segment_ids
     num_heads = query.shape[1]
-    head_dim = query.shape[3]
-
-    if bias is not None:
-        raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.")
-    with jax.ensure_compile_time_eval():
-        if jnp.any(
-            jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0
-        ):
-            raise SplashAttentionUnsupportedError(
-                "SplashAttention requires target_len, source_len, head_dim are divisible by"
-                f" {splash_attention_kernel.NUM_LANES}, got {target_len, source_len, head_dim}."
-            )
-        if segment_ids is not None:
-            raise SplashAttentionUnsupportedError(
-                "The public API for SplashAttention that we "
-                "currently use does not support segment ids."
-            )
-        if target_len != source_len and mask.value() is not None:
-            raise SplashAttentionUnsupportedError(
-                "Query and key/value must have same length when mask is used."
- ) - if mask.value() is not None and not isinstance(mask, MaskFnAttentionBias): - raise TypeError(type(mask)) - if mask.value() is not None and isinstance(mask.target_positions, jax.core.Tracer): - raise SplashAttentionUnsupportedError( - "Non-static value of `target_positions` is not supported.\n" - "Are you decoding using SplashAttention? That's not supported." - ) - - mask_shape = (target_len, source_len) - if mask.value() is None: - mask = splash_attention_mask.FullMask(mask_shape) - else: - with jax.ensure_compile_time_eval(): - # MaskFn always supports compile time eval. - mask_array = np.asarray(mask.bool_value()) - # Squeeze first two leading dimensions. - mask_array = mask_array.reshape(mask_array.shape[-2:]) - - # NumpyMask is backed by a dense [target_len, source_len] numpy array. - # May consume a large amount of host memory for long sequences at compile time. - mask = splash_attention_mask.NumpyMask(array=mask_array) + mask_shape = (query.shape[2], key.shape[2]) + splash_mask = _to_splash_mask(mask, mask_shape=mask_shape) kernel = splash_attention_kernel.make_splash_mha( - mask=splash_attention_mask.MultiHeadMask(masks=[mask] * num_heads), + mask=splash_attention_mask.MultiHeadMask(masks=[splash_mask] * num_heads), block_sizes=block_sizes, + # TODO(dhwang2): support "seq" and "model" shard. head_shards=1, q_seq_shards=1, + interpret=interpret, ) kernel = jax.vmap(kernel) context = kernel(q=query, k=key, v=value) @@ -335,6 +406,7 @@ def _tpu_splash_attention( "softmax_scale", "block_sizes", "debug", + "interpret", ], ) def pallas_tpu_flash_attention( @@ -348,6 +420,7 @@ def pallas_tpu_flash_attention( softmax_scale: float = 1.0, block_sizes: Optional[LegacyBlockSizes] = None, debug: bool = False, + interpret: bool = False, ): batch_size, num_heads, q_seq_len, d_model = q.shape batch_size_k, num_heads_k, kv_seq_len, d_model_k = k.shape @@ -397,11 +470,11 @@ def pallas_tpu_flash_attention( batch_size, num_heads, q_seq_len, kv_seq_len, d_model ) return _flash_attention( - q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug + q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug, interpret ) -@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10)) +@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11)) def _flash_attention( q, k, @@ -413,6 +486,7 @@ def _flash_attention( softmax_scale, block_sizes, debug, + interpret, ): return _flash_attention_impl( q, @@ -428,6 +502,7 @@ def _flash_attention( block_sizes.block_k_major, block_sizes.block_k, debug, + interpret, ) @@ -442,11 +517,12 @@ def _flash_attention_fwd( softmax_scale, block_sizes, debug, + interpret, ): if save_residuals: raise NotImplementedError("Higher-order AD not supported") o, l, m = _flash_attention( - q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug + q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug, interpret ) return o, (q, k, v, ab, segment_ids, o, l, m) @@ -457,6 +533,7 @@ def _flash_attention_bwd( softmax_scale: float, block_sizes: LegacyBlockSizes, debug: bool, + interpret: bool, residuals, do, ): @@ -491,6 +568,7 @@ def _flash_attention_bwd( causal=causal, mask_value=DEFAULT_MASK_VALUE, debug=debug, + interpret=interpret, ) dq, ds = _flash_attention_bwd_dq( @@ -510,6 +588,7 @@ def _flash_attention_bwd( causal=causal, mask_value=DEFAULT_MASK_VALUE, debug=debug, + interpret=interpret, ) return dq, dk, dv, ds, None @@ -531,6 +610,7 @@ def _flash_attention_impl( block_k_major, block_k, debug, + 
interpret, ): batch_size, num_heads, q_seq_len, head_dim = q.shape _, _, kv_seq_len, _ = k.shape @@ -693,6 +773,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index) ), out_shape=out_shape, debug=debug, + interpret=interpret, compiler_params=dict( mosaic=dict( dimension_semantics=( @@ -730,6 +811,7 @@ def _flash_attention_bwd_dkv( causal: bool = False, mask_value: float = DEFAULT_MASK_VALUE, debug: bool = False, + interpret: bool = False, ): batch_size, num_heads, q_seq_len, head_dim = q.shape _, _, kv_seq_len, _ = k.shape @@ -896,6 +978,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, + interpret=interpret, compiler_params=dict( mosaic=dict( dimension_semantics=( @@ -930,6 +1013,7 @@ def _flash_attention_bwd_dq( causal: bool, mask_value: float, debug: bool, + interpret: bool, ): batch_size, num_heads, q_seq_len, head_dim = q.shape _, _, kv_seq_len, _ = k.shape @@ -1087,6 +1171,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index) ), out_shape=out_shapes, debug=debug, + interpret=interpret, compiler_params=dict( mosaic=dict( dimension_semantics=( diff --git a/axlearn/common/flash_attention/tpu_attention_benchmark.py b/axlearn/common/flash_attention/tpu_attention_benchmark.py index 379048bc..7e6400fa 100644 --- a/axlearn/common/flash_attention/tpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/tpu_attention_benchmark.py @@ -3,34 +3,28 @@ """Benchmark TPU FlashAttention kernels. Sample outputs: (v5p) +CMD: python \ +/opt/venv/lib/python3.10/site-packages/axlearn/common/flash_attention/tpu_attention_benchmark.py \ +2>&1 | grep -E "Benchmarking|ref_|HBM usage" Benchmarking attention representative of 1.2b model layer on TPU v5. -ref_fwd:0.0008s, flash_fwd:0.0007s -ref_bwd:0.0027s, flash_bwd:0.0026s - - Benchmarking attention representative of 12.6b model layer on TPU v5. -ref_fwd:0.0012s, flash_fwd:0.0010s -ref_bwd:0.0037s, flash_bwd:0.0026s - - Benchmarking attention representative of 29.6b model layer on TPU v5. -ref_fwd:0.0017s, flash_fwd:0.0013s -ref_bwd:0.0053s, flash_bwd:0.0034s - - Benchmarking attention representative of 65.2b model layer on TPU v5. -ref_fwd:0.0021s, flash_fwd:0.0015s -ref_bwd:0.0067s, flash_bwd:0.0042s - - Benchmarking attention representative of 134b model layer on TPU v5. -ref_fwd:0.0024s, flash_fwd:0.0018s -ref_bwd:0.0080s, flash_bwd:0.0050s - - Benchmarking attention representative of 261.7b model layer on TPU v5. -ref_fwd:0.0027s, flash_fwd:0.0019s -ref_bwd:0.0092s, flash_bwd:0.0056s - - Benchmarking attention representative of 539.5b model layer on TPU v5. -ref_fwd:0.0034s, flash_fwd:0.0023s -ref_bwd:0.0126s, flash_bwd:0.0070s +ref_fwd:0.2291s, flash_fwd:0.0014s +ref_bwd:0.0217s, flash_bwd:0.0058s +Benchmarking attention representative of 12.6b model layer on TPU v5. +ref_fwd:0.5699s, flash_fwd:0.0032s +ref_bwd:0.0524s, flash_bwd:0.0152s +Benchmarking attention representative of 29.6b model layer on TPU v5. +ref_fwd:0.7957s, flash_fwd:0.0043s +ref_bwd:0.0731s, flash_bwd:0.0204s +Benchmarking attention representative of 65.2b model layer on TPU v5. +ref_fwd:1.0225s, flash_fwd:0.0055s +ref_bwd:0.0948s, flash_bwd:0.0262s +Benchmarking attention representative of 134b model layer on TPU v5. +ref_fwd:1.2485s, flash_fwd:0.0067s +ref_bwd:0.1159s, flash_bwd:0.0313s +Benchmarking attention representative of 261.7b model layer on TPU v5. 
+ref_fwd:1.5577s, flash_fwd:0.0072s +ref_bwd:0.1349s, flash_bwd:0.0373s """ import time from typing import Callable, Optional @@ -49,8 +43,8 @@ _BENCHMARK_CONFIGS = { "1.2b": dict( - num_heads=32, - per_head_dim=64, + num_heads=16, + per_head_dim=128, ), "12.6b": dict( num_heads=40, @@ -72,10 +66,11 @@ num_heads=110, per_head_dim=128, ), - "539.5b": dict( - num_heads=140, - per_head_dim=128, - ), + # OOM in mha_reference. + # "539.5b": dict( + # num_heads=140, + # per_head_dim=128, + # ), } @@ -167,7 +162,8 @@ def _benchmark( print(f"Benchmarking attention representative of {name} model layer on {device_kind}.") _benchmark( batch_size=2, - seq_len=2048, + seq_len=1024 * 8, block_size=4 * 128, + sliding_window_size=1024, **cfg, ) diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py index 14c3dea7..f9a99c31 100644 --- a/axlearn/common/flash_attention/tpu_attention_test.py +++ b/axlearn/common/flash_attention/tpu_attention_test.py @@ -5,11 +5,12 @@ import unittest +import chex import jax import jax.numpy as jnp import numpy as np import pytest -from absl.testing import parameterized +from absl.testing import absltest, parameterized from jax.experimental import mesh_utils from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask from jax.experimental.shard_map import shard_map @@ -17,7 +18,9 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from axlearn.common.attention_bias import ( + CausalAttentionBias, MaskFnAttentionBias, + ZeroAttentionBias, causal_mask, sliding_window_causal_mask, ) @@ -26,10 +29,16 @@ from axlearn.common.test_utils import TestCase, is_supported_mesh_shape from axlearn.common.utils import Tensor +# Comment out to test on CPU manually. Technically, this test runs on the CPU, albeit very slowly. if jax.default_backend() != "tpu": pytest.skip(reason="Incompatible hardware", allow_module_level=True) +def setUpModule(): + # If on CPU, emulate 4 devices. + chex.set_n_cpu_devices(4) + + def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor: """A MaskFn that calls jax. 
@@ -73,21 +82,36 @@ def test_sliding_window_mask_equivalence(self, seq_len, sliding_window_size): for i in range(seq_len): self.assertNestedAllClose(ref_mask[i:, i:], test_mask[i:, i:]) + @parameterized.parameters( + [ZeroAttentionBias(), splash_attention_mask.FullMask((8, 8))], + [CausalAttentionBias(shape=(8, 8)), splash_attention_mask.CausalMask(shape=(8, 8))], + [ + MaskFnAttentionBias(sliding_window_causal_mask(4), shape=(8, 8)), + splash_attention_mask.LocalMask(shape=(8, 8), window_size=(4, 0), offset=0), + ], + ) + def test_to_splash_mask(self, mask, expected): + # pylint: disable-next=protected-access + splash_mask = tpu_attention._to_splash_mask(mask, mask_shape=(8, 8)) + self.assertEqual(splash_mask, expected) + @parameterized.product( batch_size=[4], - seq_len=[32768], + seq_len=[1024, 32768], + mask_fn=["zero", "causal", "sliding"], sliding_window_size=[1024], num_heads=[4], per_head_dim=[256], mesh=[(4, 1)], mesh_axis_names=[("data", "model")], ) - def test_sliding_window_mask( + def test_forward( self, batch_size, seq_len, num_heads, per_head_dim, + mask_fn, sliding_window_size, mesh, mesh_axis_names, @@ -121,12 +145,22 @@ def fn(q, k, v): ) softmax_scale = q.shape[-1] ** -0.5 - mask = MaskFnAttentionBias( - sliding_window_causal_mask(sliding_window_size), shape=(seq_len, seq_len) - ) + if mask_fn == "zero": + mask = ZeroAttentionBias() + elif mask_fn == "causal": + mask = CausalAttentionBias(shape=(seq_len, seq_len)) + elif mask_fn.startswith("sliding"): + mask = MaskFnAttentionBias( + sliding_window_causal_mask(sliding_window_size), shape=(seq_len, seq_len) + ) attn = lambda q, k, v: tpu_attention.tpu_flash_attention( - q, k, v, mask=mask, softmax_scale=softmax_scale + q, + k, + v, + mask=mask, + softmax_scale=softmax_scale, + interpret=(jax.default_backend() == "cpu"), ) partitioned_mha = shard_map( @@ -168,6 +202,9 @@ def test_forward_and_backward( attention_bias_type, with_segment_ids, ): + if jax.default_backend() == "cpu": + # TODO(dhwang2): this has been broken for a while on CPU. + pytest.skip(reason="Backward path is broken on CPU") # pylint: disable=protected-access causal = mask in [causal_mask, jax_fn_mask] @@ -224,7 +261,14 @@ def fn(q, k, v, bias, ids): ) with record_legacy_call: return tpu_attention.tpu_flash_attention( - q, k, v, bias, ids, mask=mask, softmax_scale=softmax_scale + q, + k, + v, + bias, + ids, + mask=mask, + softmax_scale=softmax_scale, + interpret=(jax.default_backend() == "cpu"), ) # Compare outputs. @@ -246,3 +290,7 @@ def fn(q, k, v, bias, ids): legacy_flash_wrapper.assert_called() else: legacy_flash_wrapper.assert_not_called() + + +if __name__ == "__main__": + absltest.main() diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 938a4708..3859f8b6 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -241,6 +241,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: ) elif backend == "tpu": + # TODO(dhwang2): splash attention supports GQA natively, so don't repeat it. + # https://github.com/jax-ml/jax/blob/7b9914d711593dca8725d46aa1dadb2194284519/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py#L934 key = _repeat_kv_heads(query.shape[2], key) value = _repeat_kv_heads(query.shape[2], value) # `mask` is supported.
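Usage note (illustrative only, not part of the patch): a minimal sketch of how the new pieces fit together, that is, the sliding-window MaskFn tagged by sliding_window_causal_mask(), its lowering to a splash LocalMask via _to_splash_mask(), and the new interpret flag that runs the Pallas kernel in interpreter mode (e.g. on CPU), mirroring test_forward in tpu_attention_test.py. The batch/sequence sizes, dtype, and PRNG seed below are arbitrary choices for illustration.

import jax
import jax.numpy as jnp

from axlearn.common.attention_bias import MaskFnAttentionBias, sliding_window_causal_mask
from axlearn.common.flash_attention import tpu_attention

# Illustrative sizes: seq_len must be divisible by the default block size (128), and for
# SplashAttention seq_len and per_head_dim must be divisible by NUM_LANES (128).
batch, seq_len, num_heads, per_head_dim = 2, 1024, 4, 128
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(k2, (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(k3, (batch, seq_len, num_heads, per_head_dim), dtype=jnp.bfloat16)

# sliding_window_causal_mask() now tags the returned MaskFn with _sliding_window_size,
# which _to_splash_mask() lowers to splash_attention_mask.LocalMask(window_size=(1024, 0)).
mask = MaskFnAttentionBias(sliding_window_causal_mask(1024), shape=(seq_len, seq_len))

context = tpu_attention.tpu_flash_attention(
    q,
    k,
    v,
    mask=mask,
    softmax_scale=per_head_dim**-0.5,
    # interpret=True runs the Pallas kernel in interpreter mode, as the tests do on CPU.
    interpret=(jax.default_backend() == "cpu"),
)
print(context.shape)  # (2, 1024, 4, 128), i.e. [batch_size, target_len, num_heads, head_dim]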