diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_multi_queries_paged_attention_kernel.py similarity index 100% rename from test/test_tpu_paged_attention_kernel.py rename to test/test_multi_queries_paged_attention_kernel.py diff --git a/test/test_ragged_paged_attention_kernel.py b/test/test_ragged_paged_attention_kernel.py new file mode 100644 index 000000000000..95f3d937da44 --- /dev/null +++ b/test/test_ragged_paged_attention_kernel.py @@ -0,0 +1,434 @@ +from typing import List, Optional, Tuple + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention, make_sequence_metadata, DEFAULT_MASK_VALUE +import jax.numpy as jnp +import numpy as np + +jax.config.parse_flags_with_absl() + +ATOL_FP32 = 2e-1 + + +# https://github.com/vllm-project/flash-attention/blob/98a4f8df6f5f50413e03f102dc319690300d4aaf/tests/test_vllm_flash_attn.py#L22 +def _ref_ragged_paged_attention( + queries: jax.Array, # [num_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + kv_lens: jax.Array, # i32[num_tokens] + page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + cu_q_lens: jax.Array, # i32[num_tokens + 1] + num_seqs: int, +): + num_kv_heads, _, page_size, head_dim = k_pages.shape + num_q_heads = queries.shape[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." + num_query_per_kv = num_q_heads // num_kv_heads + start_idx = 0 + outputs: List[jax.Array] = [] + for i in range(num_seqs): + cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] + q = queries[start_idx:start_idx + + cur_q_len] # [cur_q_len, num_q_heads, head_dim] + + cur_kv_len = kv_lens[i] + num_pages = (cur_kv_len + page_size - 1) // page_size + page_indices_to_use = page_indices[i, :num_pages] + k = k_pages[:, + page_indices_to_use, :, :] # [num_kv_heads, page_indices_to_use, page_size, head_dim] + k = jnp.permute_dims( + k, (1, 2, 0, + 3)) # [page_indices_to_use, page_size, num_kv_heads, head_dim] + k = jnp.reshape( + k, (-1, num_kv_heads, head_dim)) # [kv_len, num_kv_heads, head_dim] + k = k[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] + + v = v_pages[:, page_indices_to_use, :, :] + v = jnp.permute_dims(v, (1, 2, 0, 3)) + v = jnp.reshape(v, (-1, num_kv_heads, head_dim)) + v = v[:cur_kv_len] # [cur_kv_lens, num_kv_heads, head_dim] + + if num_query_per_kv != 1: + k = jnp.repeat(k, num_query_per_kv, axis=1) + v = jnp.repeat(v, num_query_per_kv, axis=1) + + attn = jnp.einsum("qhd,khd->hqk", q, k) + attn = attn.astype('float32') + q_span = (cur_kv_len - cur_q_len) + jax.lax.broadcasted_iota( + jnp.int32, (cur_q_len, cur_kv_len), 0) + kv_span = jax.lax.broadcasted_iota(jnp.int32, (cur_q_len, cur_kv_len), 1) + # Use the same DEFAULT_MASK_VALUE as in the kernel instead of float("-inf") so that the kernel can match the ref implement better. + mask = jnp.where(q_span < kv_span, DEFAULT_MASK_VALUE, 0.) 
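+    # For example (illustrative values), with cur_q_len=2 and cur_kv_len=4, q_span is
+    # [[2,2,2,2],[3,3,3,3]] and kv_span is [[0,1,2,3],[0,1,2,3]], so only the last KV
+    # position is masked out for the first query row and nothing is masked for the second.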
+ with jax.numpy_rank_promotion("allow"): + attn = attn + mask + attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype) + out = jnp.einsum("hqk,khd->qhd", attn, + v) # [cur_q_len, num_q_heads, head_dim] + + outputs.append(out) + start_idx += cur_q_len + + return jnp.concatenate(outputs, axis=0) + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class RaggedPagedAttentionKernelTest(jtu.JaxTestCase): + + def _verify_ragged_paged_attention( + self, + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_queries_per_block=128, + ): + num_seqs = len(seq_lens) + # Make sure the q_len is no longer than the kv_len. For example, + # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because + # the 3rd sequence has q_len(506) > kv_len(463). + for i in range(num_seqs): + cur_q_len = seq_lens[i][0] + cur_kv_len = seq_lens[i][1] + assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}" + + query_lens = [seq_len[0] for seq_len in seq_lens] + num_q_tokens = sum(query_lens) + kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) + num_q_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." + + prng_key = jax.random.key(0) + k1, k2, k3, k4 = jax.random.split(prng_key, 4) + queries = jax.random.normal( + k1, (num_q_tokens, num_q_heads, head_dim), dtype=dtype) + k_pages = jax.random.normal( + k2, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + v_pages = jax.random.normal( + k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + + # Create a kv_lens: i32[num_tokens] + kv_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + kv_lens_with_paddings[i] = kv_lens[i] + kv_lens_np = jnp.array(kv_lens_with_paddings) + + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + max_kv_len = max([seq_len[1] for seq_len in seq_lens]) + max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size + # The reason why we need to pad max_num_pages_per_seq is that + # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 + max_num_pages_per_seq = self._get_closest_power_of_two( + max_num_pages_per_seq) + # The assert below mimics the reality that each page get a unique index. + # But for testing, the assert could be omitted. + # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + page_indices = jax.random.randint( + k4, (num_q_tokens, max_num_pages_per_seq), + 0, + num_pages, + dtype=jnp.int32) + + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] + q_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + q_lens_with_paddings[i] = query_lens[i] + cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings)) + + err, actual_output = ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + num_queries_per_block=num_queries_per_block, + ) + err.throw() # noop if there is not err. 
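+    # ragged_paged_attention is wrapped in checkify.checkify, so data-dependent input
+    # checks (e.g. q_len <= kv_len) surface through `err` rather than raising directly.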
+ actual_output = jax.block_until_ready(actual_output) + + expected_output = _ref_ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + + self.assertEqual(actual_output.shape, expected_output.shape) + self.assertEqual(actual_output.dtype, expected_output.dtype) + + print( + f'Output max diff: {jnp.max(jnp.abs(expected_output - actual_output))}') + print( + f'Output mean diff: {jnp.mean(jnp.abs(expected_output - actual_output))}' + ) + if dtype == jnp.float32: + atol = 2e-1 + rtol = 1e-2 + elif dtype == jnp.bfloat16: + atol = 6e-1 + rtol = 1e-1 + else: + self.fail(f'Unsupported dtype: {dtype}') + self.assertTrue( + jnp.allclose(actual_output, expected_output, atol=atol, rtol=rtol)) + + def _get_closest_power_of_two(self, x): + if x <= 0: + raise ValueError(f"x must be positive. Got {x}") + return 2**int(np.ceil(np.log2(x))) + + def test_paged_attention_basic(self,): + # Same setup as in the design doc. + # assuming q_blk_size=128, page_size=16, num_kv_pages_per_compute_block=16 + # Note one of the constraints of the kernel is that q.shape[0]%q_blk_size==0 as in _calculate_num_tiles. + seq_lens = [(192, 328), (128, 180), (64, 255)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + @parameterized.product( + seq_lens=[[(1, 1328), (5, 18), (506, 563)]], + num_heads=[(4, 4), (8, 2), (16, 2)], + head_dim=[128, 256], + dtype=(jnp.float32, jnp.bfloat16), + page_size=[16, 32], + num_pages=[32768, 2048], + ) + def test_paged_attention_varlen_comprehensive( + self, + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_dim: int, + dtype, + page_size: int, + num_pages: int, + ): + # assuming q_blk_size=128 + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_mix_prefill_and_decode1(self,): + # assuming q_blk_size=128 + seq_lens = [ + (1, 1328), + (5, 18), + (1, 129), + (120, 229), + (1, 122), # first physical q block + (1, 64), + (32, 100), + (250, 463), + (1, 18), + (1, 17), + (99, 123) + ] # last 3 physical q blocks [(q_len, kv_len),...] + num_heads = (4, 4) + head_dim = 128 + dtype = jnp.float32 + page_size = 16 + num_pages = 32768 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_mix_prefill_and_decode2(self,): + # assuming q_blk_size=128 + seq_lens = [(1, 127), (120, 1328), (1, 64), (1, 64), (1, 64), (1, 64), + (256, 256), (131, 463)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_extreme_all_tokens_belong_to_one_sequence(self,): + # assuming q_blk_size=128 + seq_lens = [(512, 1328)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_paged_attention_extreme_one_tokens_per_sequence_min(self,): + seq_lens = [] # [(q_len, kv_len),...] 
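+    # 64 decode-only sequences: each has q_len=1 and kv_len ranging from 256 to 319.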
+ num_seqs = 64 + num_queries_per_block = 16 + for i in range(num_seqs): + seq_lens.append((1, 256 + i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 1024 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + num_queries_per_block=num_queries_per_block, + ) + + def test_paged_attention_q_len_should_be_no_longer_than_kv_len(self,): + # assuming q_blk_size=128 + seq_lens = [(1, 0), (511, 256)] # [(q_len, kv_len),...] + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + num_seqs = len(seq_lens) + query_lens = [seq_len[0] for seq_len in seq_lens] + num_q_tokens = sum(query_lens) + kv_lens = jnp.array([seq_len[1] for seq_len in seq_lens]) + num_q_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0." + + prng_key = jax.random.key(0) + k1, k2, k3, k4 = jax.random.split(prng_key, 4) + queries = jax.random.normal( + k1, (num_q_tokens, num_q_heads, head_dim), dtype=dtype) + k_pages = jax.random.normal( + k2, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + v_pages = jax.random.normal( + k3, (num_kv_heads, num_pages, page_size, head_dim), dtype=dtype) + + # Create a kv_lens: i32[num_tokens] + kv_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + kv_lens_with_paddings[i] = kv_lens[i] + kv_lens_np = jnp.array(kv_lens_with_paddings) + + # Create a page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + max_kv_len = max([seq_len[1] for seq_len in seq_lens]) + max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size + # The reason why we need to pad max_num_pages_per_seq is that + # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0 + max_num_pages_per_seq = self._get_closest_power_of_two( + max_num_pages_per_seq) + # The assert below mimics the reality that each page get a unique index. + # But for testing, the assert could be omitted. + assert max_num_pages_per_seq * num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}" + page_indices = jax.random.randint( + k4, (num_q_tokens, max_num_pages_per_seq), + 0, + num_pages, + dtype=jnp.int32) + + # Create a cu_q_lens: jax.Array, # i32[num_tokens + 1] + q_lens_with_paddings = [0] * num_q_tokens + for i in range(num_seqs): + q_lens_with_paddings[i] = query_lens[i] + cu_q_lens = jnp.cumsum(jnp.array([0] + q_lens_with_paddings)) + + with self.assertRaisesRegex( + ValueError, "cur_q_len must be less or equal to cur_kv_len"): + err, _ = ragged_paged_attention( + queries, + k_pages, + v_pages, + kv_lens_np, + page_indices, + cu_q_lens, + num_seqs, + ) + err.throw() + + def test_paged_attention_extreme_one_tokens_per_sequence_large(self,): + # assuming q_blk_size=128 + seq_lens = [] # [(q_len, kv_len),...] 
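+    # 512 decode-only sequences: each has q_len=1 and kv_len ranging from 128 to 639.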
+ num_seqs = 512 + for i in range(num_seqs): + seq_lens.append((1, 128 + i)) + num_heads = (1, 1) + head_dim = 128 + page_size = 16 + dtype = jnp.float32 + num_pages = 65536 + + self._verify_ragged_paged_attention( + seq_lens, + num_heads, + head_dim, + page_size, + dtype, + num_pages, + ) + + def test_make_sequence_metadata(self,): + cu_q_lens = jnp.array([0, 192, 448, 512] + [512] * (512 - 4)) + num_q_tokens = 512 + num_queries_per_compute_block = 128 + start_group = jnp.array([0]) + num_seqs = 3 + metadata, num_logical_q_tiles = make_sequence_metadata( + cu_q_lens=cu_q_lens, + m=num_q_tokens, + tm=num_queries_per_compute_block, + start_sequence=start_group, + num_sequences=num_seqs) + seq_ids, physical_q_tile_ids = metadata + self.assertEqual(num_logical_q_tiles, 6) + self.assertTrue(jnp.array_equal(seq_ids, [0, 0, 1, 1, 1, 2])) + self.assertTrue(jnp.array_equal(physical_q_tile_ids, [0, 1, 1, 2, 3, 3])) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index e429a782f6b2..d11c8eecc2b0 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -37,7 +37,8 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py" python3 "$TEST_CDIR/test_pallas.py" -v python3 "$TEST_CDIR/test_pallas_spmd.py" XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py" -python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py" +python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py" +python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py" python3 "$TEST_CDIR/test_input_output_aliases.py" python3 "$TEST_CDIR/test_gmm.py" python3 "$TEST_CDIR/eager/test_eager_spmd.py" diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py new file mode 100644 index 000000000000..cfcc04436969 --- /dev/null +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_kernel.py @@ -0,0 +1,923 @@ +from collections.abc import Sequence +from collections import namedtuple +import functools +from typing import Any, Literal, Optional, cast + +import jax +from jax import lax +from jax.experimental import checkify +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu.paged_attention import quantization_utils +import jax.numpy as jnp +import numpy as np + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + + +class MultiPageAsyncCopyDescriptor: + """Descriptor for async copy of multiple K/V pages from HBM.""" + + def __init__( + self, + pages_hbm_ref, # [num_kv_heads, total_num_pages, page_size, head_dim] + scales_pages_hbm_ref, + vmem_buffer, # [pages_per_compute_block, page_size, head_dim] + scales_vmem_buffer, + sem, + page_indices, + page_indices_start_offset, + num_pages_to_load, + kv_head_index, + ): + # Original k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim] + self._vmem_buffer = vmem_buffer + self._scales_vmem_buffer = scales_vmem_buffer + self._num_pages_to_load = num_pages_to_load + if kv_head_index is not None: + self._pages_hbm_ref = pages_hbm_ref.at[kv_head_index] + if scales_pages_hbm_ref is not None: + self._scales_pages_hbm_ref = scales_pages_hbm_ref.at[kv_head_index] + else: + self._scales_pages_hbm_ref = None + else: + self._pages_hbm_ref = pages_hbm_ref + self._scales_pages_hbm_ref = scales_pages_hbm_ref + self._sem = sem + self._page_indices = page_indices + 
self._page_indices_start_offset = page_indices_start_offset + self._async_copies = [ + self._make_async_copy(i) for i in range(self._num_pages_to_load) + ] + if (self._scales_pages_hbm_ref is not None and + self._scales_vmem_buffer is not None): + self._async_copies += [ + self._make_scales_async_copy(i) + for i in range(self._num_pages_to_load) + ] + + def _make_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy(self._pages_hbm_ref.at[page_index], + self._vmem_buffer.at[i], self._sem) + + def _make_scales_async_copy(self, i): + page_index = self._page_indices[self._page_indices_start_offset + i] + return pltpu.make_async_copy( + self._scales_pages_hbm_ref.at[page_index], # pytype: disable=attribute-error + self._scales_vmem_buffer.at[i], # pytype: disable=attribute-error + self._sem, + ) + + def start(self): + """Starts the async copies.""" + for async_copy in self._async_copies: + async_copy.start() + + def _maybe_dequantize(self, x, x_scale, dtype=jnp.bfloat16): + if x_scale is None: + return x.astype(dtype) + return quantization_utils.from_int8(x, x_scale, dtype=dtype) + + def wait_and_get_loaded(self) -> jax.Array: + """Wait async copies and gets the loaded buffer as a jax.Array.""" + # Return value shape is [pages_per_compute_block*page_size, head_dim] + for async_copy in self._async_copies: + async_copy.wait() + head_dim = self._vmem_buffer.shape[-1] + jax_array = self._vmem_buffer[...].astype(jnp.float32) + if self._scales_vmem_buffer is not None: + scales_jax_array = self._scales_vmem_buffer[...].astype(jnp.float32) + else: + scales_jax_array = None + jax_array = self._maybe_dequantize(jax_array, scales_jax_array) + return jax_array.reshape(-1, head_dim) + + +def _calculate_num_tiles(x: int, tx: int) -> int: + tiles, rem = divmod(x, tx) + if rem: + raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") + return tiles + + +# https://github.com/jax-ml/jax/blob/9fb29766a2130e74a85cba30420cf777d185ea5a/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L79 +def make_sequence_metadata( + *, + cu_q_lens: jnp.ndarray, + m: int, + tm: int, + start_sequence: jnp.ndarray, + num_sequences: int, +): + """Create the metadata needed for ragged paged attention computation. + + Args: + cu_q_lens: : A 1d, jnp.ndarray with shape [num_seqs+1] and jnp.int32 dtype. + The cumulative query lengths. + m: The number of query tokens. + tm: The m-dimension tile size being used. + start_sequence: The sequence in cu_q_lens to start computing from. This is useful for when num_seqs is sharded. + num_sequences: The number of sequences to compute on. + + Returns: + tuple of: + seq_ids: A 1d, jnp.ndarray with shape [m_tiles + num_seqs] and + jnp.int32 dtype. seq_ids[i] indicates which sequence the grid index (num_logical_tiles_q) will work on. + physical_q_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_seqs] and + jnp.int32. physical_q_tile_ids[i] indicates which query-dim physical tile the grid index (num_logical_tiles_q) will work on. + + num_logical_q_tiles: The number of query-dim logical tiles to execute. + """ + end_sequence = start_sequence + num_sequences - 1 + + # We need the offset of each sequence from input, starting at zero. This metadata is + # similar to row offsets in a CSR matrix. The following properties hold: + # + # sequence_offsets.shape = [num_sequences + 1] + # sequence_offsets[0] = 0 + # sequence_offsets[num_sequences] = m + # + # The row at which sequence 'i' starts is sequence_offsets[i]. 
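+  #
+  # For example, with cu_q_lens = [0, 192, 448, 512, ...] and num_sequences = 3 (the
+  # values used in test_make_sequence_metadata), sequence 0 owns rows [0, 192),
+  # sequence 1 owns rows [192, 448), and sequence 2 owns rows [448, 512).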
+ sequence_ends = cu_q_lens[1:] + sequence_offsets = cu_q_lens + + # Assign a sequence id to each grid index. The grid index refers to the logical q tile index. + # + # If a sequence starts somewhere other than the start of a tile or ends somewhere + # other than the end of a tile we need to compute that full tile. Calculate + # the number of tiles for each sequence by rounding their end up to the nearest + # 'tm' and their start down to the nearest 'tm'. + + # (1) Round the sequence_ends up to the nearest multiple of 'tm'. + # + # NOTE: This does not change sequence_offsets[num_sequences], which is m + # (because we enforce m is divisible by tm). + rounded_sequence_ends = ((sequence_ends + tm - 1) // tm * tm).astype( + jnp.int32) + + # (2) Round the sequence_starts down to the nearest multiple of 'tm'. + sequence_starts = jnp.concatenate( + [jnp.zeros(1, dtype=jnp.int32), sequence_ends[:-1]]) + rounded_sequence_starts = sequence_starts // tm * tm + + # (3) Calculate the number of rows in each sequence. + rounded_sequence_sizes = rounded_sequence_ends - rounded_sequence_starts + + # (4) Convert the sequence sizes from units of rows to unit of 'tm' sized tiles. + # + # An m-dimension tile is 'owned' by sequence 'i' if the first row of the tile + # belongs to sequence 'i'. In addition to owned tiles, each sequence can have 0 or 1 + # initial partial tiles if it's first row does not occur in the first row of a + # tile. The '0-th' sequence never has a partial tile because it always starts at + # the 0-th row. + # + # If no sequence has a partial tile, the total number of tiles is equal to + # 'm // tm'. If every sequence has a partial except the 0-th sequence, the total + # number of tiles is equal to 'm // tm + num_sequences - 1'. Thus we know that + # + # tiles_m <= sequence_tiles.sum() <= tiles_m + num_sequences - 1 + # + # Where tiles_m = m // tm. + # + # NOTE: All sequence sizes are divisible by 'tm' because of the rounding in steps + # (1) and (2) so this division is exact. + sequence_tiles = rounded_sequence_sizes // tm + + # Create the sequence ids for each grid index based on the tile counts for each + # sequence. + # + # NOTE: This repeat(...) will pad sequence_ids with the final sequence id if + # sequence_tiles.sum() < tiles_m + num_sequences - 1. The kernel grid will be sized + # such that we only execute the necessary number of tiles. + tiles_m = _calculate_num_tiles(m, tm) + sequence_ids = jnp.repeat( + jnp.arange(num_sequences, dtype=jnp.int32), + sequence_tiles[:num_sequences], + total_repeat_length=tiles_m + num_sequences - 1, + ) + + # Assign an m-dimension tile id to each grid index. + # + # NOTE: Output tiles can only be re-visited consecutively. The following + # procedure guarantees that m-dimension tile indices respect this. + + # (1) Calculate how many times each m-dimension tile will be visited. + # + # Each tile is guaranteed to be visited once by the sequence that owns the tile. + # The remaining possible visits occur when a sequence starts inside of a tile at + # a position other than the first row. We can calculate which m-dimension tile + # each sequence starts in by floor-dividing its offset with `tm` and then count + # tile visits with a histogram. + # + # To avoid double counting tile visits from the sequence that owns the tile, + # filter these out by assigning their tile id to `tile_m` (one beyond the max) + # such that they're ignored by the subsequent histogram. 
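+  #
+  # Worked example (matching test_make_sequence_metadata): with tm=128, m=512, and
+  # cu_q_lens=[0, 192, 448, 512], sequences 1 and 2 start mid-tile (rows 192 and 448),
+  # so physical tiles 1 and 3 are visited twice each. The resulting sequence_ids are
+  # [0, 0, 1, 1, 1, 2] and m_tile_ids are [0, 1, 1, 2, 3, 3], i.e. 6 logical tiles.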
+ # + partial_tile_mask = ((sequence_offsets[:-1] % tm) == 0) + + partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, + sequence_offsets[:-1] // tm) + + tile_visits = ( + jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + + 1) + + # Create the m-dimension tile ids for each grid index based on the visit + # counts for each tile. + m_tile_ids = jnp.repeat( + jnp.arange(tiles_m, dtype=jnp.int32), + tile_visits.astype(jnp.int32), + total_repeat_length=tiles_m + num_sequences - 1, + ) + + # Account for sharding. + # + # Find the start of the sequences owned by our shard and shift the sequence_ids and + # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays. + # + first_tile_in_shard = (sequence_ids < start_sequence).sum() + sequence_ids = jnp.roll(sequence_ids, shift=-first_tile_in_shard, axis=0) + m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0) + + # Calculate the number of tiles we need to compute for our shard. + # + # Remove tile visits that belong to a sequence not in our shard. + iota = jnp.arange(num_sequences, dtype=jnp.int32) + active_sequence_mask = jnp.logical_and(iota <= end_sequence, + iota >= start_sequence) + sequence_tiles = jnp.where(active_sequence_mask, + sequence_tiles[:num_sequences], 0) + num_tiles = sequence_tiles.sum() + return (sequence_ids, m_tile_ids + ), num_tiles # (seq_ids, physical_q_tile_ids), num_logical_q_tiles + + +def check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs, num_kv_pages_per_block): + num_q_heads, num_tokens, head_dim = q.shape + num_kv_heads, _, _, head_dim_k = k_pages.shape + _, pages_per_sequence = page_indices.shape + if k_pages.shape != v_pages.shape: + raise ValueError( + f"k_pages and v_pages must have the same shape. Got {k_pages.shape} and" + f" {v_pages.shape}" # pytype: disable=attribute-error + ) + if head_dim_k != head_dim: + raise ValueError("head_dim of Q must be the same as that of K/V. Got" + f" {head_dim} and {head_dim_k}.") + if kv_lens.shape[0] != num_tokens: + raise ValueError("kv_lens.shape[0] must be the same as num_tokens. Got" + f" {kv_lens.shape[0]} and {num_tokens}") + if page_indices.shape[0] != num_tokens: + raise ValueError("page_indices.shape[0] must be the same as num_tokens. Got" + f" {page_indices.shape[0]} and {num_tokens}") + if cu_q_lens.shape[0] != num_tokens + 1: + raise ValueError( + "cu_q_lens.shape[0] must be the same as num_tokens + 1. Got" + f" {cu_q_lens.shape[0]} and {num_tokens + 1}") + for i in range(num_seqs): + cur_q_len = cu_q_lens[i + 1] - cu_q_lens[i] + cur_kv_len = kv_lens[i] + checkify.check( + cur_q_len <= cur_kv_len, + "cur_q_len must be less or equal to cur_kv_len. Got {} and {}", + cur_q_len, cur_kv_len) + if num_seqs > num_tokens: + raise ValueError( + f"num_seqs must be less or equal to num_tokens. Got {num_seqs} and {num_tokens}" + ) + if kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or cu_q_lens.dtype != jnp.int32: + raise ValueError( + f"The dtype of `lengths` must be int32. Got {kv_lens.dtype=}, " + f"{page_indices.dtype=}, {cu_q_lens.dtype=}") + if num_kv_pages_per_block > pages_per_sequence: + raise ValueError( + f"{num_kv_pages_per_block=} should be smaller or equal to {pages_per_sequence=}" + ) + if pages_per_sequence % num_kv_pages_per_block != 0: + raise ValueError( + "pages_per_sequence must be divisible by num_kv_pages_per_block. 
Got" + f" {pages_per_sequence=} and {num_kv_pages_per_block=}.") + if num_q_heads % num_kv_heads != 0: + raise ValueError( + "Number of Q heads must be divisible by number of KV heads. Got" + f" {num_q_heads} and {num_kv_heads}.") + + +# https://github.com/jax-ml/jax/blob/e3b3b913f7bcec3767e1442ace08999413f8703d/jax/experimental/pallas/ops/tpu/megablox/gmm.py#L269C1-L283C64 +def _get_store_mask( + *, + logical_q_blk_idx: jnp.ndarray, + sequence_offsets: jnp.ndarray, + sequence_ids: jnp.ndarray, + physical_q_tile_ids: jnp.ndarray, + tq: int, + tk: int, +) -> jnp.ndarray: + """Mask for rows that belong to the current sequence in the current physical q tile.""" + sequence_id = sequence_ids[logical_q_blk_idx] + sequence_start = sequence_offsets[sequence_id] + sequence_end = sequence_offsets[sequence_id + 1] + physical_q_tile_id = physical_q_tile_ids[logical_q_blk_idx] * tq + iota = jax.lax.broadcasted_iota(jnp.int32, (tq, tk), 0) + physical_q_tile_id + return jnp.logical_and(iota >= sequence_start, iota < sequence_end) + + +def _flash_attention( + q_head_idx_per_kv, # scalar, ranges from 0 to num_query_heads_per_kv_head + sequence_metadata_ref, # (seq_ids, physical_q_tile_ids) + effective_kv_lens_ref, # [num_tokens] + effective_cu_q_lens_ref, # [num_tokens + 1] + # kernel inputs + q_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + k, # [kv_blk_size, head_dim] + v, # [kv_blk_size, head_dim] + # outputs + o_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + # scratch space + l_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + acc_scratch_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + *, + num_tokens: int, + num_seqs: int, + num_kv_pages_per_block: int, + num_queries_per_block: int, + mask_value: float, + page_size: int, + head_dim: int, + num_q_heads_per_kv_head: int, +): + assert q_ref.shape == (num_q_heads_per_kv_head, num_queries_per_block, + head_dim) + kv_blk_size = page_size * num_kv_pages_per_block + assert k.shape == (kv_blk_size, head_dim) + assert v.shape == (kv_blk_size, head_dim) + + kv_head_idx, logical_q_blk_idx, kv_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + ) + seq_ids, physical_q_tile_ids = sequence_metadata_ref + + # If the q-dim physical tile is changed (meaning it is a new physical q-dim tile that has not visited before), initialize the acc_scratch_ref, m_scratch_ref, and l_scratch_ref to run the flash attention v2 algorithm. 
+ prev_logical_q_blk_idx = jnp.where(logical_q_blk_idx > 0, + logical_q_blk_idx - 1, 0) + is_first_processed_logical_q_blk = logical_q_blk_idx == 0 + physical_q_blk_changed = ( + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[prev_logical_q_blk_idx]) + first_time_seeing_physical_q_blk = jnp.logical_or( + is_first_processed_logical_q_blk, physical_q_blk_changed) + is_first_kv_blk = (kv_blk_idx == 0) + should_init_scratch_ref = jnp.logical_and(is_first_kv_blk, + first_time_seeing_physical_q_blk) + + @pl.when(should_init_scratch_ref) + def init_scratch_ref(): # pylint: disable=unused-variable + l_scratch_ref[q_head_idx_per_kv] = jnp.zeros( + l_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) + m_scratch_ref[q_head_idx_per_kv] = jnp.full( + m_scratch_ref[q_head_idx_per_kv].shape, -jnp.inf, jnp.float32) + acc_scratch_ref[q_head_idx_per_kv] = jnp.zeros( + acc_scratch_ref[q_head_idx_per_kv].shape, jnp.float32) + + m_prev = m_scratch_ref[ + q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + l_prev = l_scratch_ref[ + q_head_idx_per_kv] # [num_queries_per_block, MIN_BLOCK_SIZE] + + # Load the whole q_block that belongs to the current physical q_blk and compute the attention. When we write, we only write the part that belongs to the current sequence. + # Cannot just load only the part of q_block that belongs to the current sequence, because it results in dynamic shapes and then fails the JIT compilation. + # Note, q_ref.shape=[num_q_heads_per_kv_head, num_queries_per_block, head_dim] + q = q_ref[q_head_idx_per_kv, :, :].astype(jnp.float32) # [block_q, head_dim] + assert q.shape == (num_queries_per_block, head_dim) + s = jnp.einsum( + 'qd,td->qt', q, k, + preferred_element_type=jnp.float32) # [block_q, block_k] + assert s.shape == (num_queries_per_block, kv_blk_size) + + # Modify the mask accordingly: first form the mask. Then move the mask up/down to the right place. + cur_seq_idx = seq_ids[logical_q_blk_idx] + cur_seq_start = effective_cu_q_lens_ref[cur_seq_idx] + cur_seq_end = effective_cu_q_lens_ref[cur_seq_idx + 1] + physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] + q_index = physical_q_blk_idx * num_queries_per_block - cur_seq_start + kv_index = kv_blk_idx * kv_blk_size + effective_kv_len = effective_kv_lens_ref[cur_seq_idx] + effective_q_len = cur_seq_end - cur_seq_start + row_ids = (effective_kv_len - + effective_q_len) + q_index + jax.lax.broadcasted_iota( + jnp.int32, (num_queries_per_block, kv_blk_size), 0) + col_ids = kv_index + jax.lax.broadcasted_iota( + jnp.int32, (num_queries_per_block, kv_blk_size), 1) + causal_mask = jnp.where(row_ids < col_ids, mask_value, 0.) + assert causal_mask.shape == (num_queries_per_block, kv_blk_size) + + s = s + causal_mask # [block_q, block_k] + + m_curr = jnp.max(s, axis=1)[:, None] # Row max, shape [block_q, 1]. + m_next = jnp.maximum(m_prev, m_curr) # Shape [block_q, 128]. 
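+  # Online softmax update: alpha = exp(m_prev - m_next) rescales the running sum l_prev,
+  # and the accumulator is rescaled accordingly before adding this KV block's contribution.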
+ + block_k_repeats, rem = divmod(kv_blk_size, MIN_BLOCK_SIZE) + if rem: + raise NotImplementedError( + f"{kv_blk_size=} should be a multiple of {MIN_BLOCK_SIZE}") + p = jnp.exp( + s - pltpu.repeat(m_next, block_k_repeats, 1)) # Shape [block_q, block_k] + + alpha = jnp.exp(m_prev - m_next) # Shape [block_q, 128] + + l_corr = alpha * l_prev # Shape [block_q, 128] + + l_next = jnp.sum(p, axis=1)[:, None] + l_corr # Shape [block_q, 128] + + head_dim_repeats, rem = divmod(head_dim, MIN_BLOCK_SIZE) + l_broadcast = lambda l: pltpu.repeat(l, head_dim_repeats, 1) + if rem: + if head_dim_repeats == 0: + l_broadcast = lambda l: l[:, :head_dim] + else: + raise NotImplementedError( + f"{head_dim=} should be a multiple of {MIN_BLOCK_SIZE} if larger") + + # Need to store these l_next and m_next which will relay to the output. + # But only update the part that belongs to the current sequence we are working on. + lm_mask = _get_store_mask( + logical_q_blk_idx=logical_q_blk_idx, + sequence_offsets=effective_cu_q_lens_ref, + sequence_ids=seq_ids, + physical_q_tile_ids=physical_q_tile_ids, + tq=num_queries_per_block, + tk=MIN_BLOCK_SIZE, + ) + # Either jax.lax.select or jnp.where works here. + pl.store( + l_scratch_ref, + (pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), + l_next.reshape(1, *l_next.shape), # no-op here. + mask=lm_mask.reshape(1, *lm_mask.shape), + ) + pl.store( + m_scratch_ref, + (pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), + m_next.reshape(1, *m_next.shape), + mask=lm_mask.reshape(1, *lm_mask.shape), + ) + + l_next_inv_safe = jnp.where(l_next == 0.0, 1.0, + 1.0 / l_next) # [block_q, 128] + temp = acc_scratch_ref[q_head_idx_per_kv] * l_broadcast( + l_corr * l_next_inv_safe) + o_curr = jax.lax.dot( + p.astype(v.dtype), v, + preferred_element_type=jnp.float32) # [block_q, 128] + temp += o_curr * l_broadcast(l_next_inv_safe) + acc_mask = _get_store_mask( + logical_q_blk_idx=logical_q_blk_idx, + sequence_offsets=effective_cu_q_lens_ref, + sequence_ids=seq_ids, + physical_q_tile_ids=physical_q_tile_ids, + tq=num_queries_per_block, + tk=head_dim, + ) + pl.store( + acc_scratch_ref, + (pl.ds(q_head_idx_per_kv, 1), slice(None), slice(None)), + temp.reshape(1, *temp.shape), + mask=acc_mask.reshape(1, *acc_mask.shape), + ) + + # Store the result from VMEM to HBM only when it is the last kv_block and the next q-dim logical tile belongs to a different q-dim physical tile. 
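+  # For example, with sequence_ids [0, 0, 1, 1, 1, 2] and physical_q_tile_ids
+  # [0, 1, 1, 2, 3, 3], physical q tile 1 is written out only after logical tile 2
+  # (the last logical tile mapped to it) has processed its last kv block.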
+ is_last_kv_blk_idx = ( + kv_blk_idx == (pl.cdiv(effective_kv_len, kv_blk_size) - 1)) + num_logical_q_blks = pl.num_programs( + 1) # grid=(num_kv_heads, num_logical_q_tiles, num_kv_blks) + next_logical_q_blk_idx = jnp.where( + logical_q_blk_idx == num_logical_q_blks - 1, logical_q_blk_idx, + logical_q_blk_idx + 1) + is_last_logical_q_blk = (logical_q_blk_idx == num_logical_q_blks - 1) + physical_q_blk_will_change = ( + physical_q_tile_ids[logical_q_blk_idx] != + physical_q_tile_ids[next_logical_q_blk_idx]) + last_time_seeing_cur_physical_q_blk = jnp.logical_or( + is_last_logical_q_blk, physical_q_blk_will_change) + should_store_to_output = jnp.logical_and(is_last_kv_blk_idx, + last_time_seeing_cur_physical_q_blk) + + @pl.when(should_store_to_output) + def store_to_output(): # pylint: disable=unused-variable + o_ref[q_head_idx_per_kv] = acc_scratch_ref[q_head_idx_per_kv].astype( + o_ref.dtype) + l_ref[q_head_idx_per_kv] = l_scratch_ref[q_head_idx_per_kv].astype( + l_ref.dtype) + m_ref[q_head_idx_per_kv] = m_scratch_ref[q_head_idx_per_kv].astype( + m_ref.dtype) + + +def paged_flash_attention_kernel( + # prefetch refs + sequence_metadata_ref, # (seq_ids, physical_q_tile_ids) + effective_kv_lens_ref, # [num_tokens] + # 1d vector, results from page_indices.reshape(-1) where originally page_indices.shape=[num_tokens, pages_per_sequence] + page_indices_1d_ref, + effective_cu_q_lens_ref, # [num_tokens + 1] + buffer_index_ref, + step_ref, + # kernel inputs + # At caller, q.shape= [num_q_heads, num_tokens, head_dim] + q_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + k_pages_hbm_ref, # [num_kv_heads, total_num_pages, page_size, head_dim] + k_scales_pages_hbm_ref, + v_pages_hbm_ref, # [num_kv_heads, total_num_pages, page_size, head_dim] + v_scales_pages_hbm_ref, + # same shape as q_ref: [1, num_q_heads_per_kv_head, num_queries_per_block, head_dim], output + # outputs + o_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + # scratch space + k_vmem_buffer, # (2, num_kv_pages_per_block, num_kv_heads, head_dim) + k_scales_vmem_buffer, + v_vmem_buffer, # (2, num_kv_pages_per_block, num_kv_heads, head_dim) + v_scales_vmem_buffer, + sem, + l_scratch_ref, + m_scratch_ref, + acc_scratch_ref, + *, + # The following parameters are not passed to Mosaic and not in SMEM. They are static values. 
+ pages_per_sequence: int, # Note [bs, pages_per_sequence] = page_indices.shape + num_tokens: int, + num_seqs: int, + num_kv_pages_per_block: int, + mask_value: float, +): + kv_head_idx, logical_q_blk_idx, kv_blk_idx = ( + pl.program_id(0), + pl.program_id(1), + pl.program_id(2), + ) + num_logical_q_blks = pl.num_programs(1) + num_q_heads_per_kv_head, num_queries_per_block, head_dim = q_ref.shape + num_kv_heads, total_num_pages, page_size, head_dim = k_pages_hbm_ref.shape + kv_blk_size = page_size * num_kv_pages_per_block + + seq_ids, physical_q_tile_ids = sequence_metadata_ref + cur_seq_idx = seq_ids[logical_q_blk_idx] + effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] + should_run = (kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq) + + @pl.when(should_run) + def get_kv_and_run_flash_attention(): + # grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) + def compute_block_indices(kv_head_idx, logical_q_blk_idx, kv_blk_idx): + """Return next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx + + Note, k_pages has shape [num_kv_heads, total_num_pages, page_size, head_dim]. + To get the KV, it needs the kv_head_idx, then we need the sequence_idx + and the kv_blk_idx to get the offset. + """ + + def advance_kv_head_idx(): + next_kv_head_idx = kv_head_idx + 1 + return next_kv_head_idx, 0, 0 + + def advance_logical_q_blk_idx(): + next_logical_q_blk_idx = logical_q_blk_idx + 1 + return lax.cond( + next_logical_q_blk_idx < num_logical_q_blks, + lambda: (kv_head_idx, next_logical_q_blk_idx, 0), + advance_kv_head_idx, + ) + + cur_seq_idx = seq_ids[logical_q_blk_idx] + effective_kv_len_cur_seq = effective_kv_lens_ref[cur_seq_idx] + return lax.cond( + kv_blk_idx * kv_blk_size < effective_kv_len_cur_seq, + lambda: (kv_head_idx, logical_q_blk_idx, kv_blk_idx), + advance_logical_q_blk_idx, + ) + + def create_kv_async_copy_descriptors(seq_idx, kv_head_idx, kv_blk_idx, + buffer_index): + page_offset = seq_idx * pages_per_sequence + kv_blk_idx * num_kv_pages_per_block + pages_to_load = num_kv_pages_per_block + async_copy_k = MultiPageAsyncCopyDescriptor( + k_pages_hbm_ref, + k_scales_pages_hbm_ref, + k_vmem_buffer.at[buffer_index], + k_scales_vmem_buffer.at[buffer_index] + if k_scales_vmem_buffer is not None else None, + sem, + page_indices_1d_ref, # [batch_size*pages_per_sequence] + page_offset, + pages_to_load, + kv_head_idx, + ) + async_copy_v = MultiPageAsyncCopyDescriptor( + v_pages_hbm_ref, + v_scales_pages_hbm_ref, + v_vmem_buffer.at[buffer_index], + v_scales_vmem_buffer.at[buffer_index] + if v_scales_vmem_buffer is not None else None, + sem, + page_indices_1d_ref, + page_offset, + pages_to_load, + kv_head_idx, + ) + return async_copy_k, async_copy_v + + step = step_ref[0] + buffer_index = buffer_index_ref[0] + + @pl.when(step == 0) + def prefetch_first_block(): # pylint: disable=unused-variable + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index) + async_copy_k.start() + async_copy_v.start() + + next_kv_head_idx, next_logical_q_blk_idx, next_kv_blk_idx = compute_block_indices( + kv_head_idx, logical_q_blk_idx, kv_blk_idx + 1) + + @pl.when(next_kv_head_idx < num_kv_heads) + def prefetch_next_block(): # pylint: disable=unused-variable + next_buffer_index = jnp.where(buffer_index == 0, 1, 0) + next_seq_idx = seq_ids[next_logical_q_blk_idx] + async_copy_next_k, async_copy_next_v = create_kv_async_copy_descriptors( + next_seq_idx, next_kv_head_idx, next_kv_blk_idx, next_buffer_index) + async_copy_next_k.start() + 
async_copy_next_v.start() + buffer_index_ref[0] = next_buffer_index + + async_copy_k, async_copy_v = create_kv_async_copy_descriptors( + cur_seq_idx, kv_head_idx, kv_blk_idx, buffer_index) + k = async_copy_k.wait_and_get_loaded( + ) # [pages_per_compute_block*page_size,head_dim] + v = async_copy_v.wait_and_get_loaded() + assert k.shape == (num_kv_pages_per_block * page_size, head_dim) + assert v.shape == (num_kv_pages_per_block * page_size, head_dim) + + for q_head_idx in range(num_q_heads_per_kv_head): + _flash_attention( + q_head_idx, + sequence_metadata_ref, + effective_kv_lens_ref, + effective_cu_q_lens_ref, + # kernel inputs + q_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + k, + v, + # outputs + o_ref, # [num_q_heads_per_kv_head, num_queries_per_block, head_dim] + l_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + m_ref, # [num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE] + # scratch space + l_scratch_ref, + m_scratch_ref, + acc_scratch_ref, + num_tokens=num_tokens, + num_seqs=num_seqs, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + mask_value=mask_value, + page_size=page_size, + head_dim=head_dim, + num_q_heads_per_kv_head=num_q_heads_per_kv_head, + ) + step_ref[0] = step + 1 + # end of get_kv_and_run_flash_attention + + +MIN_BLOCK_SIZE = 128 + + +@checkify.checkify +@functools.partial( + jax.jit, + static_argnames=[ + "num_kv_pages_per_block", + "num_queries_per_block", + "mask_value", + "num_seqs", + ], +) +def ragged_paged_attention( + q: jax.Array, # [num_tokens, num_q_heads, head_dim] + k_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + v_pages: jax.Array, # [num_kv_heads, total_num_pages, page_size, head_dim] + kv_lens: jax.Array, # i32[num_tokens] + page_indices: jax.Array, # i32[num_tokens, pages_per_sequence] + cu_q_lens: jax.Array, # i32[num_tokens + 1] + num_seqs, # int + *, + mask_value: float = DEFAULT_MASK_VALUE, + num_kv_pages_per_block: int = 16, + num_queries_per_block: int = 128, +) -> jax.Array: + """Paged attention kernel with ragged input. + + Args: + q: A [num_tokens, num_q_heads, head_dim] jax.Array. + k_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + v_pages: A [num_kv_heads, total_num_pages, page_size, head_dim] jax.Array. + kv_lens: A i32[num_tokens] jax.Array the effective kv length of each + sequence. For example, if we have three sequences, lengths could be + [16, 3, 1024, x, x, x, x, ...] where x is any value for padding. While + lengths's shape is [num_tokens], only the first num_seqs values are valid. + The rest should be ignored. + page_indices: A i32[num_tokens, pages_per_sequence] jax.Array. Each entry + should be in the range of [0, total_num_pages), indicating where to locate + the page in `k_pages` or `v_pages`. Similar to kv_lens, only the first + num_seqs values are valid. + cu_q_lens: A i32[num_tokens+1] jax.Array the cumulative sum of the effective + query lengths. Similar to kv_lens, only the first num_seqs+1 values are + valid. + num_seqs: the number of sequences. + mask_value: The value used for padding in attention. By default it is a very + negative floating point number. + num_kv_pages_per_block: how many kv pages to be processed in one flash + attention block in the pallas kernel. + num_queries_per_block: how many queries to be processes in one flash + attention block in the pallas kernel. + + The num_tokens, num_seqs, and pages_per_sequence are dynamic. 
If they are + very dynamic, then the overhead could be high due to the recompilation. + + Returns: + The output of attention([num_tokens, num_q_heads, head_dim]). + """ + # TODO: consider remove the k_scales_pages and v_scales_pages during cleaning up. + if isinstance(k_pages, quantization_utils.QuantizedTensor): + k_pages, k_scales_pages = k_pages.weight, k_pages.scales + assert isinstance(k_scales_pages, jax.Array) # For typing. + k_scales_pages = jnp.broadcast_to( + k_scales_pages, (*k_scales_pages.shape[:-1], k_pages.shape[-1])) + else: + k_scales_pages = None + if isinstance(v_pages, quantization_utils.QuantizedTensor): + v_pages, v_scales_pages = v_pages.weight, v_pages.scales + assert isinstance(v_scales_pages, jax.Array) # For typing. + v_scales_pages = jnp.broadcast_to( + v_scales_pages, (*v_scales_pages.shape[:-1], v_pages.shape[-1])) + else: + v_scales_pages = None + + num_tokens, num_q_heads, head_dim = q.shape + # Why the permute_dims is needed? Before permute, q.shape=[num_tokens, num_q_heads, head_dim]; then when we apply the GridSpec, the 2nd last dimension is num_q_heads which is hard to be a multiple of 8. + # If permute_dims turns out to be expensive, try jnp.swapaxes. The compiler + # may optimize the copies away. + # Or consider unsqueeze a dimension at the 2nd last dimension and squeeze it + # out later so that num_q_heads doesn't have to be the 2nd last dimension and hence doesn't subject to the multiple of 8 constraint. + q = jnp.permute_dims(q, (1, 0, 2)) # [num_q_heads, num_tokens, head_dim] + num_kv_heads, total_num_pages, page_size, head_dim = k_pages.shape + check_kernel_input(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, + num_seqs, num_kv_pages_per_block) + num_q_heads_per_kv_head = num_q_heads // num_kv_heads + + sequence_metadata, num_logical_q_tiles = make_sequence_metadata( + cu_q_lens=cu_q_lens, + m=num_tokens, + tm=num_queries_per_block, + start_sequence=jnp.array([0]), + num_sequences=num_seqs, + ) + + pages_per_sequence = page_indices.shape[1] + num_kv_blks = pages_per_sequence // num_kv_pages_per_block + grid = (num_kv_heads, num_logical_q_tiles, num_kv_blks) + + # out_shape + o_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) + l = jax.ShapeDtypeStruct((num_q_heads, num_tokens, MIN_BLOCK_SIZE), + dtype=jnp.float32) + m = jax.ShapeDtypeStruct((num_q_heads, num_tokens, MIN_BLOCK_SIZE), + dtype=jnp.float32) + out_shape = (o_shape, l, m) + + # in-spec. Note currently q.shape=[num_q_heads, num_tokens, head_dim] + # Within the kernel, q.shape should be [num_q_heads_per_kv_head, q_block_size, head_dim] + def qo_index_map(kv_head_idx, logical_q_blk_idx, kv_blk_idx, + sequence_metadata, *_): + seq_ids, physical_q_tile_ids = sequence_metadata + del seq_ids + physical_q_blk_idx = physical_q_tile_ids[logical_q_blk_idx] + return (kv_head_idx, physical_q_blk_idx, 0) + + q_block_spec = pl.BlockSpec( + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), + qo_index_map, + ) + in_specs = [ + q_block_spec, + # Below 4 correspond to the 4 input: k_pages, k_scales_pages, q_pages, q_scales_pages. 
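+      # The None specs are for k_scales_pages and v_scales_pages, which are None here
+      # because the unquantized K/V path is used.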
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + None, + ] + + # out_spec + # o_specs should be the same as q_block_spec + o_specs = q_block_spec + # lm_index_map is same as qo_index_map + lm_index_map = qo_index_map + out_specs = [ + o_specs, + pl.BlockSpec( + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), + lm_index_map), # l + pl.BlockSpec( + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), + lm_index_map), # m + ] + + # scratch space. Note k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim] + l_scratch = pltpu.VMEM( + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), + jnp.float32) + m_scratch = pltpu.VMEM( + (num_q_heads_per_kv_head, num_queries_per_block, MIN_BLOCK_SIZE), + jnp.float32) + acc_scratch = pltpu.VMEM( + (num_q_heads_per_kv_head, num_queries_per_block, head_dim), jnp.float32) + scratch_shapes = [ + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_block, + page_size, + head_dim, + ), + k_pages.dtype, + ), # k_pages buffer, k_pages.shape=[num_kv_heads, total_num_pages, page_size, head_dim] + None, # k_scales_pages=None + pltpu.VMEM( + ( + 2, # For double buffering during DMA copies. + num_kv_pages_per_block, + page_size, + head_dim, + ), + v_pages.dtype, + ), # v_pages buffer + None, # v_scales_pages=None + pltpu.SemaphoreType.DMA, + l_scratch, + m_scratch, + acc_scratch, + ] + + kernel = pl.pallas_call( + functools.partial( + paged_flash_attention_kernel, + pages_per_sequence=pages_per_sequence, + num_tokens=num_tokens, + num_seqs=num_seqs, + num_kv_pages_per_block=num_kv_pages_per_block, + mask_value=mask_value, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=6, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + scratch_shapes=scratch_shapes, + ), + compiler_params=pltpu.TPUCompilerParams( + # due to compute_block_indices, we loop kv_head, q_blk, kv_blk, the order matters. + dimension_semantics=( + "arbitrary", + "arbitrary", + "arbitrary", + )), + out_shape=out_shape, + ) + # TODO: need to slice the page_indices later to avoid the SMEM OOM. + page_indices_1d = page_indices.reshape(-1) + buffer_index = jnp.zeros((1,), jnp.int32) + step = jnp.zeros((1,), jnp.int32) + + outputs = kernel( + # prefetch + sequence_metadata, + kv_lens, + page_indices_1d, + cu_q_lens, + buffer_index, + step, + # kernel inputs + q, + k_pages, + k_scales_pages, + v_pages, + v_scales_pages, + ) + ret = outputs[0] + return jnp.permute_dims(ret, (1, 0, 2)).astype(q.dtype)
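+
+
+# A minimal usage sketch (illustrative only; the shapes below are assumptions chosen to
+# satisfy the kernel's constraints: num_tokens divisible by num_queries_per_block,
+# pages_per_sequence divisible by num_kv_pages_per_block, q_len <= kv_len, and int32
+# metadata). It requires a TPU, since this is a Pallas TPU kernel.
+#
+#   import jax
+#   import jax.numpy as jnp
+#
+#   q = jax.random.normal(jax.random.key(0), (256, 4, 128), dtype=jnp.float32)
+#   k_pages = jax.random.normal(jax.random.key(1), (4, 1024, 16, 128), dtype=jnp.float32)
+#   v_pages = jax.random.normal(jax.random.key(2), (4, 1024, 16, 128), dtype=jnp.float32)
+#   kv_lens = jnp.array([320] + [0] * 255, dtype=jnp.int32)         # one sequence, kv_len=320
+#   page_indices = jnp.zeros((256, 32), dtype=jnp.int32)            # 32 pages per sequence
+#   cu_q_lens = jnp.array([0, 256] + [256] * 255, dtype=jnp.int32)  # q_len=256
+#
+#   err, out = ragged_paged_attention(q, k_pages, v_pages, kv_lens, page_indices,
+#                                     cu_q_lens, num_seqs=1)
+#   err.throw()  # no-op unless an input check failed
+#   # out has shape [256, 4, 128]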