From a1fe0d5eb416369a64b5606a10964d62b80beef8 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Mon, 6 Jan 2025 12:41:55 -0800 Subject: [PATCH 1/3] Support cross attention and dropout --- .../common/flash_attention/gpu_attention.py | 932 ++++++++---------- .../gpu_attention_benchmark.py | 138 +-- .../flash_attention/gpu_attention_test.py | 270 +++-- .../common/flash_attention/gpu_decoding.py | 4 +- axlearn/common/flash_attention/layer.py | 28 +- axlearn/common/flash_attention/layer_test.py | 60 +- axlearn/common/flash_attention/utils.py | 37 +- axlearn/common/layers.py | 39 +- 8 files changed, 807 insertions(+), 701 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 43e62d8b..5f04e71b 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -12,19 +12,21 @@ """Implements FlashAttention for GPU in JAX with logit bias support. -This implementation follows the original closely: -https://github.com/HazyResearch/flash-attention/blob/9818f85fee29ac6b60c9214bce841f8109a18b1b/flash_attn/flash_attn_triton.py -https://github.com/google/jax/blob/jaxlib-v0.4.25/jax/experimental/pallas/ops/attention.py - -As well as the original paper: https://arxiv.org/abs/2205.14135 - -Due to the caveats mentioned in the above link, we make several simplifying assumptions: -* Sequence length is a multiple of block size (128). -* No dropout is applied. -* 4-d bias tensor is supported. -* Currently only tested on A100/H100. +This implementation is ported from +https://github.com/jax-ml/jax/blob/ed4e9823b19591f8a4c98b1f895c284775d6e0c7/jax/experimental/pallas/ops/gpu/attention.py +and follows the original papers closely: +FlashAttention: https://arxiv.org/abs/2205.14135 +FlashAttention2: https://arxiv.org/abs/2307.08691 + +Caveats of this implementation: +* Sequence length must be a multiple of block size (128). +* Only tested on A100/H100. + +Compared to the implementation in the JAX repo, we made the following enhancements: +* Support kv_seq_len != q_seq_len. +* Support 2D/4D bias. +* Support dropout. """ -# pylint: disable=wrong-import-position,missing-param-doc,differing-param-doc import functools from collections.abc import Sequence from typing import Any, Optional @@ -34,19 +36,28 @@ from jax import lax from jax._src.cudnn.fused_attention_stablehlo import MaskType, dot_product_attention from jax.experimental import pallas as pl -from jax.experimental.pallas import gpu as plgpu -from axlearn.common.attention_bias import NEG_INF +from axlearn.common.attention import NEG_INF +from axlearn.common.layers import get_dropout_mask Tensor = jax.Array +class NoPopDict(dict): + """A dict that doesn't delete after pop. + + Used to workaround https://github.com/jax-ml/jax/issues/25714. + """ + + def pop(self, *args, **kwargs): + return super().get(*args, **kwargs) + + def _segment_mask( q_segment_ids: Tensor, kv_segment_ids: Tensor, ): - """ - Build the segment mask for the given query and key bias ids. + """Build the segment mask for the given query and key bias ids. If mask[..., i, j] == True, query position i and key position j are in the same segment. @@ -59,20 +70,20 @@ def _segment_mask( def _mha_forward_kernel( - # Inputs. q_ref, k_ref, v_ref, b_ref, s_ref, + dropout_mask_ref, # Outputs. o_ref, # Residual outputs. *residual_refs, softmax_scale: float, causal: bool, + dropout_rate: float, block_q: int, - block_d: int, block_k: int, ): """Computes attention outputs for the given block. 
@@ -91,99 +102,100 @@ def _mha_forward_kernel( s_ref: Input segment_ids ref. o_ref: Output ref. *residual_refs: Residual output refs, e.g. softmax statistics. - softmax_scale: Softmax scale. - causal: Whether to apply causal mask. - block_q: Block size for query seq dim. - block_d: Block size for head dim. - block_k: Block size for key seq dim. + **kwargs: See `flash_attention`. """ - seq_len = q_ref.shape[0] + kv_seq_len = k_ref.shape[0] + block_d = q_ref.shape[-1] start_q = pl.program_id(0) + precision = ( + lax.Precision.HIGHEST + if jnp.float32 in (q_ref.dtype, k_ref.dtype, v_ref.dtype) + else lax.Precision.DEFAULT + ) - # acc is the buffer where we accumulate the output on sram. + # o is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. - m_i = jnp.zeros(block_q, dtype=jnp.float32) + NEG_INF + m_i = jnp.full(block_q, NEG_INF, dtype=jnp.float32) l_i = jnp.zeros(block_q, dtype=jnp.float32) # acc is the buffer where we accumulate the output on sram. - acc = jnp.zeros((block_q, block_d), dtype=jnp.float32) + o = jnp.zeros((block_q, block_d), dtype=jnp.float32) # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. curr_q_slice = pl.dslice(start_q * block_q, block_q) - q = pl.load(q_ref, (curr_q_slice, pl.dslice(None))) - - # Effectively a segment id for padding mask. - if s_ref is not None: - q_segment_ids = pl.load(s_ref, (curr_q_slice,)) + q = q_ref[...] + q_segment_ids = None if s_ref is None else pl.load(s_ref, (curr_q_slice,)) # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size - # Bc == block_k here), and fast over blocks of q (size Br == block_q here). - # Here we only loop over blocks of kv to process entire seq_len, the loop over + # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). + # Here we only loop over blocks of kv to process entire kv_seq_len, the loop over # blocks of q is carried out by the grid. def body(start_k, carry): - acc, m_prev, l_prev = carry - # This is slow loop over kv, essentially a scan through. + o_prev, m_prev, l_prev = carry curr_k_slice = pl.dslice(start_k * block_k, block_k) - k = pl.load(k_ref, (curr_k_slice, pl.dslice(None))) - qk = pl.dot(q, k.T) # [block_q, block_k]. + + k = pl.load(k_ref, (curr_k_slice, slice(None))) + qk = pl.dot(q, k.T, precision=precision) # [block_q, block_k]. if softmax_scale != 1.0: qk *= softmax_scale # [block_q, block_k]. - if b_ref is not None: - b = pl.load( - b_ref, - (curr_q_slice, curr_k_slice), - ) - qk += b - - if s_ref is not None: - kv_segment_ids = pl.load(s_ref, (curr_k_slice,)) - mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk += pl.load(b_ref, (slice(None), curr_k_slice)) + qk = jnp.maximum(qk, NEG_INF) + + if causal or s_ref is not None: + mask = None + if s_ref is not None: + kv_segment_ids = pl.load(s_ref, (curr_k_slice,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + span_k = start_k * block_k + jnp.arange(block_k) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + # Apply mask to qk. 
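        # Worked example of the combined mask (with block_q = block_k = 2 and
        # start_q = start_k = 0): q_segment_ids = [0, 1], kv_segment_ids = [0, 0]
        # give a segment mask [[T, T], [F, F]]; the causal mask is [[T, F], [T, T]],
        # so their logical AND keeps only logit (0, 0) and the jnp.where below sets
        # every other entry to NEG_INF.
        # The loop then carries out the FlashAttention-2 online-softmax update:
        #   m_next = max(m_prev, rowmax(qk))
        #   l_next = exp(m_prev - m_next) * l_prev + rowsum(exp(qk - m_next))
        #   o_next = exp(m_prev - m_next) * o_prev + exp(qk - m_next) @ v
        # with o divided by the final l_i only once, after the loop.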
qk = jnp.where(mask, qk, NEG_INF) - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - span_k = start_k * block_k + jnp.arange(block_k) - mask = span_q[:, None] >= span_k[None, :] - qk = jnp.where(mask, qk, NEG_INF) - - # Bring closer to XLA:GPU numerics. - # These casts are needed to avoid precision issues. - qk = qk.astype(jnp.float32) m_curr = qk.max(axis=-1) - m_curr = jnp.maximum(m_curr, m_prev) - l_prev *= jnp.exp(m_prev - m_curr) - p = jnp.exp(qk - m_curr[:, None]) - l_curr = jnp.sum(p, axis=1) + l_prev - l_rcp = 1.0 / l_curr - p = p * l_rcp[:, None] - acc_prev = (l_prev * l_rcp)[:, None] * acc - + m_next = jnp.maximum(m_prev, m_curr) + correction = jnp.exp(m_prev - m_next) + l_prev_corr = correction * l_prev + s_curr = jnp.exp( + qk - m_next[:, None] + ) # Use m_next instead of m_curr to avoid a correction on l_curr + l_curr = s_curr.sum(axis=-1) + l_next = l_prev_corr + l_curr + o_prev_corr = correction[:, None] * o_prev v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d))) - acc_curr = pl.dot(p.astype(v.dtype), v) - acc_next = acc_prev + acc_curr - return acc_next, m_curr, l_curr + if dropout_rate > 0: + dropout_mask = pl.load(dropout_mask_ref, (slice(None), curr_k_slice)) + s_curr = jnp.where(dropout_mask, 0, s_curr / (1 - dropout_rate)) + o_curr = pl.dot(s_curr.astype(v.dtype), v, precision=precision) + + o_next = o_prev_corr + o_curr + return o_next, m_next, l_next if causal: - upper_bound = lax.div(block_q * start_q, block_k) + 1 + upper_bound = jnp.minimum( + lax.div((start_q + 1) * block_q, block_k), pl.cdiv(kv_seq_len, block_k) + ) else: - upper_bound = pl.cdiv(seq_len, block_k) - acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, (acc, m_i, l_i)) + upper_bound = pl.cdiv(kv_seq_len, block_k) + o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) - if residual_refs: - l_ref, m_ref = residual_refs - pl.store(l_ref, (curr_q_slice,), l_i) - pl.store(m_ref, (curr_q_slice,), m_i) + # We keep an unscaled version of o during the scan over kv_seq_len. Scaling it + # by the last l_i gives us the correct final output. See section 3.1.1 in the + # FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691. + o /= l_i[:, None] + if residual_refs: + lse_ref = residual_refs[0] + lse_ref[...] = m_i + jnp.log(l_i) # Write output to dram. - acc = acc.astype(o_ref.dtype) - pl.store(o_ref, (curr_q_slice, pl.dslice(None)), acc) + o_ref[...] = o.astype(o_ref.dtype) -# TODO(kelvin-zou): may decide to deprecate the triton backend if we can fully move to -# more low-level CUDA kernels. -@functools.partial(jax.custom_vjp, nondiff_argnums=[5, 6, 7, 8, 9, 10, 11, 12, 13, 14]) +@functools.partial(jax.custom_vjp, nondiff_argnums=[6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) @functools.partial( jax.jit, static_argnames=[ @@ -191,12 +203,13 @@ def body(start_k, carry): "causal", "block_q", "block_k", - "backward_pass_impl", "num_warps", "num_stages", "grid", "interpret", "debug", + "dropout_rate", + "output_activations", ], ) def flash_attention( @@ -205,16 +218,19 @@ def flash_attention( value: Tensor, bias: Optional[Tensor] = None, segment_ids: Optional[Tensor] = None, + prng_key: Optional[Tensor] = None, softmax_scale: float = 1.0, causal: bool = False, + dropout_rate: float = 0.0, block_q: int = 128, block_k: int = 128, - backward_pass_impl: str = "triton", num_warps: Optional[int] = None, num_stages: Optional[int] = None, grid: Optional[Sequence[int]] = None, interpret: bool = False, debug: bool = False, + # output_activations has to be the last arg for custom vjp to work. 
+ output_activations: bool = False, ): """Computes attention outputs following FlashAttention. @@ -226,579 +242,429 @@ def flash_attention( value: Value of shape [batch_size, source_length, num_heads, per_head_dim]. bias: Optional logit biases of shape [batch_size, num_heads, target_length, source_length]. segment_ids: Optional segment ids of shape [batch_size, target_length]. + prng_key: PRNG key used for dropout. Must be specified when dropout_rate > 0.0. softmax_scale: Optional scale to apply to softmax. Defaults to 1. causal: Whether to apply causal mask. + dropout_rate: Dropout rate. Default to 0.0 (no dropout). + output_activations: Whether to output activations for backward. Default to False. **kwargs: Pallas/triton kwargs. Returns: The attention outputs of shape [batch_size, target_length, num_heads, per_head_dim]. """ - del backward_pass_impl - # Configure the grid and triton kernel specs. - batch_size, seq_len, num_heads, head_dim = query.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) + batch_size, q_seq_len, num_heads, head_dim = query.shape + kv_seq_len = key.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) + assert q_seq_len % block_q == 0 + assert kv_seq_len % block_k == 0 # Heuristics. grid_ = grid if grid_ is None: - grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) - # Bias. - bias_block_spec = None - if bias is not None: - assert bias.ndim == 4 - - def bias_index_map(_, j, k): - return (j if bias.shape[0] != 1 else 0, k if bias.shape[1] != 1 else 0, 0, 0) - - bias_block_spec = pl.BlockSpec( - index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) - ) - # Segment Ids - segment_ids_block_spec = None - if segment_ids is not None: - assert segment_ids.ndim == 2 - segment_ids_block_spec = pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0)), block_shape=(None, seq_len) - ) - - num_warps_ = num_warps - if num_warps_ is None: - num_warps_ = 4 if head_dim <= 64 else 8 - num_stages_ = num_stages - if num_stages_ is None: - num_stages_ = ( + grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads) + if num_stages is None: + num_stages = ( 2 if bias is None and jnp.float32 not in (query.dtype, key.dtype, value.dtype) else 1 ) + if num_warps is None: + num_warps = 4 if head_dim <= 64 else 8 kernel = functools.partial( _mha_forward_kernel, softmax_scale=softmax_scale, causal=causal, + dropout_rate=dropout_rate, block_q=block_q, block_k=block_k, - block_d=head_dim, ) - out_shape = jax.ShapeDtypeStruct(shape=query.shape, dtype=query.dtype) - - return pl.pallas_call( - kernel, - grid=grid_, - in_specs=[ - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # query - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # key - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # value - bias_block_spec, # bias - segment_ids_block_spec, # segment_ids - ], - out_specs=pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim) - ), - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps_, num_stages=num_stages_), - out_shape=out_shape, - debug=debug, - interpret=interpret, - name="mha_forward", - )(query, key, value, bias, segment_ids) - - -def _mha_forward( - query: Tensor, - key: Tensor, - value: Tensor, - bias: Optional[Tensor], - segment_ids: Optional[Tensor], - softmax_scale: 
float, - causal: bool, - block_q: int, - block_k: int, - backward_pass_impl: str, - num_warps: Optional[int], - num_stages: int, - grid: Any, - interpret: bool, - debug: bool, -): - """Calls `_mha_forward_kernel`.""" - del backward_pass_impl - # Configure the grid and triton kernel specs. - batch_size, seq_len, num_heads, head_dim = query.shape - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) - # Heuristics. - grid_ = grid - if grid_ is None: - grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads) - - # Bias. - bias_block_spec = None + out_shape = jax.ShapeDtypeStruct(shape=query.shape, dtype=query.dtype) # out + in_specs = [ + pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), + pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), + ] if bias is not None: assert bias.ndim == 4 - - def bias_index_map(_, j, k): - return (j if bias.shape[0] != 1 else 0, k if bias.shape[1] != 1 else 0, 0, 0) - - bias_block_spec = pl.BlockSpec( - index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + in_specs.append( + pl.BlockSpec( + index_map=lambda i, j, k: ( + j if bias.shape[0] != 1 else 0, + k if bias.shape[1] != 1 else 0, + i, + 0, + ), + block_shape=(None, None, block_q, kv_seq_len), + ) ) - - # Segment Ids. - segment_ids_block_spec = None - if segment_ids is not None: - assert segment_ids.ndim == 2 - segment_ids_block_spec = pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0)), block_shape=(None, seq_len) + else: + in_specs.append(None) + in_specs.append( + None if segment_ids is None else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0)) + ) + if dropout_rate > 0: + assert dropout_rate < 1 + assert prng_key is not None + # TODO(hanzhi-zhou): Switch to in-kernel RNG when pallas supports it. 
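        # The dropout mask is precomputed on the host by `get_dropout_mask` with shape
        # [batch_size, num_heads, q_seq_len, kv_seq_len]; inside the kernel, True entries
        # are zeroed out and the surviving probabilities are rescaled by
        # 1 / (1 - dropout_rate).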
+ dropout_mask = get_dropout_mask( + (batch_size, num_heads, q_seq_len, kv_seq_len), prng_key=prng_key, rate=dropout_rate ) - - num_warps_ = num_warps - if num_warps_ is None: - num_warps_ = 4 if head_dim <= 64 else 8 - num_stages_ = num_stages - if num_stages_ is None: - num_stages_ = ( - 2 if bias is None and jnp.float32 not in (query.dtype, key.dtype, value.dtype) else 1 + in_specs.append( + pl.BlockSpec((None, None, block_q, kv_seq_len), lambda i, j, k: (j, k, i, 0)) ) - kernel = functools.partial( - _mha_forward_kernel, - softmax_scale=softmax_scale, - causal=causal, - block_q=block_q, - block_k=block_k, - block_d=head_dim, - ) - out_shape = [ - jax.ShapeDtypeStruct(shape=query.shape, dtype=query.dtype), # out - jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), dtype=jnp.float32), # l - jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), dtype=jnp.float32), # m - ] - - out, l, m = pl.pallas_call( + else: + dropout_mask = None + in_specs.append(None) + out_specs = pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)) + if output_activations: + out_specs = [out_specs, pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i))] + out_shape = [ + out_shape, + jax.ShapeDtypeStruct( + shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 + ), # lse + ] + pallas_out = pl.pallas_call( kernel, grid=grid_, - in_specs=[ - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # query - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # key - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # value - bias_block_spec, # bias - segment_ids_block_spec, # segment_ids - ], - out_specs=[ - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), - pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), - ], - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps_, num_stages=num_stages_), + in_specs=in_specs, + out_specs=out_specs, + compiler_params=NoPopDict(triton=NoPopDict(num_warps=num_warps, num_stages=num_stages)), out_shape=out_shape, debug=debug, interpret=interpret, name="mha_forward", - )(query, key, value, bias, segment_ids) - return out, (query, key, value, bias, segment_ids, out, l, m) + )(query, key, value, bias, segment_ids, dropout_mask) + if output_activations: + out, lse = pallas_out + return out, (query, key, value, bias, segment_ids, prng_key, out, lse) + return pallas_out -def _preprocess_backward_kernel( - out_ref, - dout_ref, - l_ref, - new_dout_ref, - delta_ref, - *, - block_q: int, -): - """Precomputes Di for the attention backwards pass. +def _mha_forward(*args: Any): + """Wraps flash_attention for custom vjp.""" + return flash_attention(*args[:-1], output_activations=True) - This optimization is described in https://arxiv.org/abs/2205.14135 Appendix B.4 observation 2. - """ - pid_m = pl.program_id(0) - - off_m = pl.ds(pid_m * block_q, block_q) - # Load. - o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32) - do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32) - denom = pl.load(l_ref, (off_m,)).astype(jnp.float32) - # Compute. - do = do / denom[:, None] - delta = jnp.sum(o * do, axis=1) - # Write-back. 
- pl.store(new_dout_ref, (off_m, slice(None)), do.astype(new_dout_ref.dtype)) - pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) - - -@jax.named_scope("preprocess_backward") -def _preprocess_backward( - out, - do, - l, - block_q: int, - debug: bool, - interpret: bool, -): - """Calls `_preprocess_backward_kernel`.""" - batch_size, seq_len, num_heads, head_dim = out.shape - out_shape = [ - jax.ShapeDtypeStruct(do.shape, do.dtype), - jax.ShapeDtypeStruct(l.shape, l.dtype), - ] - do_scaled, delta = pl.pallas_call( - functools.partial(_preprocess_backward_kernel, block_q=block_q), - grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), - in_specs=[ - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), - ], - out_specs=[ - pl.BlockSpec( - index_map=(lambda _, j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), - ], - compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), - out_shape=out_shape, - debug=debug, - interpret=interpret, - name="mha_preprocess_backward", - )(out, do, l) - return do_scaled, delta - -def _mha_backward_kernel( +def _mha_backward_kernel_dkdv( # Inputs. q_ref, k_ref, v_ref, b_ref, s_ref, - out_ref, + dropout_mask_ref, do_scaled_ref, - l_ref, - m_ref, + lse_ref, delta_ref, # Outputs. - dq_ref, dk_ref, dv_ref, *, softmax_scale: float, causal: bool, + dropout_rate: float, block_q: int, - block_d: int, block_k: int, ): - """Computes the backward pass. - - This algorithm is described in https://arxiv.org/abs/2205.14135 Appendix B.4 Algorithm 4. - Jax reference implementation: - https://github.com/jax-ml/jax/blob/0995bc231c51e2ee66995be8ee2b31adf9236509/jax/experimental/pallas/ops/gpu/attention.py#L343 - - See also `_mha_forward_kernel` for the forward pass. - - The main difference between ours and jax reference implementation is that it supports 4-d bias, - and it supports float32 in the input dtype. - - Args: - q_ref: Input query ref. - k_ref: Input key ref. - v_ref: Input value ref. - b_ref: Input bias ref. - s_ref: Input segment_ids ref. - out_ref: Input forward output ref. - do_scaled_ref: Preprocessed dOut ref. See `_preprocess_backward_kernel`. - l_ref: Input l ref. - m_ref: Input m ref. - delta_ref: Input delta ref. See `_preprocess_backward_kernel`. - dq_ref: Output dQuery ref. - dk_ref: Output dKey ref. - dv_ref: Output dValue ref. - softmax_scale: Softmax scale. - bias_type: Type of bias matrix. - causal: Whether to apply causal mask. - block_q: Block size for query seq dim. - block_d: Block size for head dim. - block_k: Block size for key seq dim. + """Computes dK and dV. + 1. Load a block of K and V of size (block_k, head_dim) in SMEM. + 2. Iterate through Q in chunks of (block_q, head_dim) to accumulate + dK and dV. """ - del out_ref, l_ref # Not needed - seq_len = q_ref.shape[0] + q_seq_len = q_ref.shape[0] + block_d = q_ref.shape[-1] + precision = ( + lax.Precision.HIGHEST + if jnp.float32 in (q_ref.dtype, k_ref.dtype, v_ref.dtype) + else lax.Precision.DEFAULT + ) - # Parallelize over k/v's seq dimension. - # Load a block of K and V of size (block_k, block_d). - # Iterate through Q in chunks of (block_q, block_d) to accumulate dK and dV. 
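    # Recurrences accumulated in the loop below (cf. FlashAttention-2, Algorithm 2),
    # ignoring dropout for brevity:
    #   P  = exp(softmax_scale * Q @ K^T - lse)   (recomputed, never stored in HBM)
    #   dV += P^T @ dO
    #   dP  = dO @ V^T
    #   dS  = P * (dP - delta)
    #   dK += softmax_scale * dS^T @ Q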
start_k = pl.program_id(2) - slice_k = pl.ds(start_k * block_k, block_k) + curr_k_slice = pl.dslice(start_k * block_k, block_k) + dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) - k = pl.load(k_ref, (slice_k, slice(None))) - v = pl.load(v_ref, (slice_k, slice(None))) + + v = pl.load(v_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None))) span_k = start_k * block_k + jnp.arange(block_k) - kv_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_k,)) + kv_segment_ids = None if s_ref is None else pl.load(s_ref, (curr_k_slice,)) - def inner_loop_dk_dv(start_q, carry): + def inner_loop_dkdv(start_q, carry): dv, dk = carry - slice_q = pl.ds(start_q * block_q, block_q) - q = pl.load(q_ref, (slice_q, slice(None))) - qk = pl.dot(q, k.T) - # These casts are needed to avoid precision issues. - qk = qk.astype(jnp.float32) + curr_q_slice = pl.dslice(start_q * block_q, block_q) + q = pl.load(q_ref, (curr_q_slice, slice(None))) + qk = pl.dot(q, k.T, precision=precision) # type: ignore if softmax_scale != 1.0: qk *= softmax_scale - if b_ref is not None: - # Load bias in transposed order, for hopefully better cache efficiency. - b = pl.load( - b_ref, - (slice_k, slice_q), - ) - b = b.astype(jnp.float32) - qk += b.T # Transpose back. - if s_ref is not None: - q_segment_ids = pl.load(s_ref, (slice_q,)) - mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk += pl.load(b_ref, (curr_q_slice, curr_k_slice)) + qk = jnp.maximum(qk, NEG_INF) + + if causal or s_ref is not None: + mask = None + if s_ref is not None: + q_segment_ids = pl.load(s_ref, (curr_q_slice,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) qk = jnp.where(mask, qk, NEG_INF) - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - mask = span_q[:, None] >= span_k[None, :] - qk = jnp.where(mask, qk, NEG_INF) - m = pl.load(m_ref, (slice_q,)) - p = jnp.exp(qk - m[:, None]) - do = pl.load(do_scaled_ref, (slice_q, slice(None))) - dv = dv + pl.dot(p.astype(do.dtype).T, do) - di = pl.load(delta_ref, (slice_q,)) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) + + lse = pl.load(lse_ref, (curr_q_slice,)) + di = pl.load(delta_ref, (curr_q_slice,)) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + + p = p_dropped = jnp.exp(qk - lse[:, None]) + dp = dp_dropped = pl.dot(do, v.T, precision=precision) # type: ignore + if dropout_rate > 0: + dropout_mask = pl.load(dropout_mask_ref, (curr_q_slice, curr_k_slice)) + p_dropped = jnp.where(dropout_mask, 0, p / (1 - dropout_rate)) + dp = jnp.where(dropout_mask, 0, dp_dropped / (1 - dropout_rate)) + dv = dv + pl.dot(p_dropped.astype(do.dtype).T, do, precision=precision) + dp = dp - di[:, None] ds = p * dp if softmax_scale != 1.0: ds = ds * softmax_scale - dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q, precision=precision) return dv, dk lower_bound = lax.div(start_k * block_k, block_q) if causal else 0 - dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop_dk_dv, (dv, dk)) - pl.store(dv_ref, (slice_k, slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (slice_k, slice(None)), dk.astype(dk_ref.dtype)) - # Free up memory. - del dv, dk - - # Parallelize over q's seq dimension. - # 1. 
Load a block of Q of size (block_q, block_d). - # 2. Iterate through K and V in chunks of (block_k, block_d) to accumulate dQ. + dv, dk = lax.fori_loop(lower_bound, pl.cdiv(q_seq_len, block_q), inner_loop_dkdv, (dv, dk)) + pl.store(dv_ref, (curr_k_slice, slice(None)), dv.astype(dv_ref.dtype)) + pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype)) + + +# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence +# length. Inspired by the triton tutorial: +# https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py +def _mha_backward_kernel_dq( + # Inputs. + q_ref, + k_ref, + v_ref, + b_ref, + s_ref, + dropout_mask_ref, + do_scaled_ref, + lse_ref, + delta_ref, + # Outputs. + dq_ref, + *, + softmax_scale: float, + causal: bool, + dropout_rate: float, + block_q: int, + block_k: int, +): + """Computes dQ. + 1. Load a block of Q of size (block_q, head_dim) in SMEM. + 2. Iterate through K and V in chunks of (block_k, head_dim) to + accumulate dQ. + """ + kv_seq_len = k_ref.shape[0] + block_d = q_ref.shape[-1] + precision = ( + lax.Precision.HIGHEST + if jnp.float32 in (q_ref.dtype, k_ref.dtype, v_ref.dtype) + else lax.Precision.DEFAULT + ) + start_q = pl.program_id(2) - slice_q = pl.ds(start_q * block_q, block_q) - q = pl.load(q_ref, (slice_q, slice(None))) - dq = jnp.zeros([block_q, block_d], dtype=jnp.float32) - q_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_q,)) + curr_q_slice = pl.ds(start_q * block_q, block_q) span_q = start_q * block_q + jnp.arange(block_q) - m = pl.load(m_ref, (slice_q,)) - di = pl.load(delta_ref, (slice_q,)) - do = pl.load(do_scaled_ref, (slice_q, slice(None))) - - def inner_loop_dq(start_k, carry): - dq = carry - slice_k = pl.ds(start_k * block_k, block_k) - k = pl.load(k_ref, (slice_k, slice(None))) - v = pl.load(v_ref, (slice_k, slice(None))) - qk = pl.dot(q, k.T) + dq = jnp.zeros([block_q, block_d], dtype=jnp.float32) - # These casts are needed to avoid precision issues. - qk = qk.astype(jnp.float32) + q = pl.load(q_ref, (curr_q_slice, slice(None))) + q_segment_ids = None if s_ref is None else pl.load(s_ref, (curr_q_slice,)) + lse = pl.load(lse_ref, (curr_q_slice,)) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + di = pl.load(delta_ref, (curr_q_slice,)) + def inner_loop_dq(start_k, dq): + curr_k_slice = pl.dslice(start_k * block_k, block_k) + k = pl.load(k_ref, (curr_k_slice, slice(None))) + v = pl.load(v_ref, (curr_k_slice, slice(None))) + qk = pl.dot(q, k.T, precision=precision) if softmax_scale != 1.0: qk *= softmax_scale if b_ref is not None: - # Load bias in transposed order, for hopefully better cache efficiency. - b = pl.load( - b_ref, - (slice_k, slice_q), - ) - b = b.astype(jnp.float32) - qk += b.T # Transpose back. 
- if s_ref is not None: - kv_segment_ids = pl.load(s_ref, (slice_k,)) - mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk += pl.load(b_ref, (curr_q_slice, curr_k_slice)) + qk = jnp.maximum(qk, NEG_INF) + + if causal or s_ref is not None: + mask = None + if s_ref is not None: + kv_segment_ids = pl.load(s_ref, (curr_k_slice,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_k = start_k * block_k + jnp.arange(block_k) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) qk = jnp.where(mask, qk, NEG_INF) - if causal: - span_k = start_k * block_k + jnp.arange(block_k) - mask = span_q[:, None] >= span_k[None, :] - qk = jnp.where(mask, qk, NEG_INF) - p = jnp.exp(qk - m[:, None]) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) + + p = jnp.exp(qk - lse[:, None]) + dp = dp_dropped = pl.dot(do, v.T, precision=precision) + if dropout_rate > 0: + dropout_mask = pl.load(dropout_mask_ref, (curr_q_slice, curr_k_slice)) + dp = jnp.where(dropout_mask, 0, dp_dropped / (1 - dropout_rate)) + dp = dp - di[:, None] ds = p * dp if softmax_scale != 1.0: ds = ds * softmax_scale - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + dq = dq + pl.dot(ds.astype(k.dtype), k, precision=precision) return dq if causal: - upper_bound = lax.div((start_q + 1) * block_q, block_k) + upper_bound = jnp.minimum( + pl.cdiv((start_q + 1) * block_q, block_k), pl.cdiv(kv_seq_len, block_k) + ) else: - upper_bound = pl.cdiv(seq_len, block_k) + upper_bound = pl.cdiv(kv_seq_len, block_k) dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) - pl.store(dq_ref, (slice_q, slice(None)), dq.astype(dq_ref.dtype)) + pl.store(dq_ref, (curr_q_slice, slice(None)), dq.astype(dq_ref.dtype)) def _mha_backward( softmax_scale: float, causal: bool, + dropout_rate: float, block_q: int, block_k: int, - backward_pass_impl: str, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, debug: bool, + output_activations: bool, res, do, ): - """Calls `_mha_backward_kernel`.""" - del num_warps, num_stages, grid - q, k, v, b, s, out, l, m = res - - # NOTE: temporarily removed the "xla" branch, which seems unused. - if backward_pass_impl == "triton": - # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. - if jnp.float32 in (q.dtype, k.dtype, v.dtype, jnp.bfloat16 if b is None else b.dtype): - block_q = block_k = 32 - - batch_size, seq_len, num_heads, head_dim = q.shape - # Backward heuristics, using the same block size for block q and block k. - block_q = min(block_q, seq_len) - block_k = min(block_k, seq_len) - # Very tiny amount of time, not worth using pallas_call. - do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) - # We accumulate into dq so we need to initialize it to zeros. - out_shapes = [ - jax.ShapeDtypeStruct(q.shape, q.dtype), - jax.ShapeDtypeStruct(k.shape, k.dtype), - jax.ShapeDtypeStruct(v.shape, v.dtype), - ] - - # Bias. - bias_block_spec = None - if b is not None: - assert b.ndim == 4 - b = jnp.moveaxis(b, -1, -2) + """Calls Pallas kernels to compute dQ, dK and dV. - def bias_index_map(j, k, _): - return (j if b.shape[0] != 1 else 0, k if b.shape[1] != 1 else 0, 0, 0) + Note: separating dKdV and dQ loops into two kernels in flash backward improved performance by + 10~15% when head_dim >= 128. 
Note that technically fusing dKdVdQ into a single loop and use + atomic add for dQ is the fastest solution, but pallas atomics are extremely slow according + to empirical testing. + """ + del num_warps, grid, output_activations + q, k, v, bias, segment_ids, prng_key, out, lse = res + # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. + if jnp.float32 in (q.dtype, k.dtype, v.dtype, jnp.bfloat16 if bias is None else bias.dtype): + block_q = block_k = 32 + + batch_size, q_seq_len, num_heads, head_dim = q.shape + kv_seq_len = k.shape[1] + block_q = min(block_q, q_seq_len) + block_k = min(block_k, kv_seq_len) + # Compute delta (D) as in Algorithm 2 Line 4 of FlashAttention2. + delta = ( + (out.astype(jnp.float32) * do.astype(jnp.float32)) + .sum(axis=3) + .transpose((0, 2, 1)) + .astype(lse.dtype) + ) - bias_block_spec = pl.BlockSpec( - index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + if dropout_rate > 0: + dropout_mask = get_dropout_mask( + (batch_size, num_heads, q_seq_len, kv_seq_len), prng_key=prng_key, rate=dropout_rate + ) + else: + dropout_mask = None + + in_specs = [ + pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)), # q + pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)), # k + pl.BlockSpec((None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)), # v + ( + None + if bias is None + else pl.BlockSpec( + index_map=lambda i, j, _: ( + i if bias.shape[0] != 1 else 0, + j if bias.shape[1] != 1 else 0, + 0, + 0, + ), + block_shape=(None, None, q_seq_len, kv_seq_len), ) + ), + None if segment_ids is None else pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0)), + ( + None + if dropout_mask is None + else pl.BlockSpec((None, None, q_seq_len, kv_seq_len), lambda i, j, _: (i, j, 0, 0)) + ), + pl.BlockSpec((None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)), # do + pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), # lse + pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)), # delta + ] - # Segment Ids. - segment_ids_block_spec = None - if s is not None: - assert s.ndim == 2 - segment_ids_block_spec = pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0)), block_shape=(None, seq_len) - ) - grid = (batch_size, num_heads, pl.cdiv(seq_len, block_q)) - # Add some proof check against SRAM for float32 inputs or huge bias input. 
- num_warps = 8 - num_stages = 2 if b is None and jnp.float32 not in (q.dtype, k.dtype, v.dtype) else 1 - dq, dk, dv = pl.pallas_call( + num_warps = 8 + if num_stages is None: + num_stages = 2 if bias is None and jnp.float32 not in (q.dtype, k.dtype, v.dtype) else 1 + + def call_kernel(*, kernel, grid, out_shape, out_specs): + return pl.pallas_call( functools.partial( - _mha_backward_kernel, + kernel, softmax_scale=softmax_scale, causal=causal, + dropout_rate=dropout_rate, block_q=block_q, - block_d=head_dim, block_k=block_k, ), + out_shape=out_shape, + in_specs=in_specs, grid=grid, - out_shape=out_shapes, - in_specs=[ - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # query - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # key - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), # value - bias_block_spec, # bias - segment_ids_block_spec, # segment_ids - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec( - index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) - ), - pl.BlockSpec( - index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) - ), - pl.BlockSpec( - index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) - ), - ], - out_specs=[ - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - pl.BlockSpec( - index_map=(lambda j, k, _: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), - ), - ], - name="mha_backward", + out_specs=out_specs, + name=kernel.__name__, debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages), - )(q, k, v, b, s, out, do_scaled, l, m, delta) - else: - raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") - return dq.astype(q.dtype), dk, dv, None, None + compiler_params=NoPopDict(triton=NoPopDict(num_warps=num_warps, num_stages=num_stages)), + )(q, k, v, bias, segment_ids, dropout_mask, do, lse, delta) + + dk, dv = call_kernel( + kernel=_mha_backward_kernel_dkdv, + grid=(batch_size, num_heads, pl.cdiv(kv_seq_len, block_k)), + out_shape=[ + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), + ], + out_specs=[ + pl.BlockSpec( + (None, kv_seq_len, None, head_dim), + lambda i, j, _: (i, 0, j, 0), # dk + ), + pl.BlockSpec( + (None, kv_seq_len, None, head_dim), + lambda i, j, _: (i, 0, j, 0), # dv + ), + ], + ) + + dq = call_kernel( + kernel=_mha_backward_kernel_dq, + grid=(batch_size, num_heads, pl.cdiv(q_seq_len, block_q)), + out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype), + out_specs=pl.BlockSpec( + (None, q_seq_len, None, head_dim), + lambda i, j, _: (i, 0, j, 0), # dq + ), + ) + return dq, dk, dv, None, None, None flash_attention.defvjp(_mha_forward, _mha_backward) # Interface to cuDNN's dot product attention. -# TODO(kelvin-zou): Verify dropout rate functions. # TODO(kelvin-zou): Add support for segment IDs. 
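# A minimal usage sketch of the `flash_attention` entry point above, exercising cross
# attention (kv_seq_len != q_seq_len) and dropout. It mirrors the calling convention of
# gpu_attention_test.py: the first six arguments are positional, the static kernel
# options are keywords. Shapes, dtypes, and the dropout rate are illustrative only.
import jax
import jax.numpy as jnp

batch, q_len, kv_len, heads, dim = 2, 256, 512, 8, 128
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(k1, (batch, q_len, heads, dim), dtype=jnp.float16)
k = jax.random.normal(k2, (batch, kv_len, heads, dim), dtype=jnp.float16)
v = jax.random.normal(k3, (batch, kv_len, heads, dim), dtype=jnp.float16)
out = flash_attention(
    q, k, v, None, None, k4,  # query, key, value, bias, segment_ids, prng_key
    softmax_scale=dim**-0.5, dropout_rate=0.1,
)
assert out.shape == (batch, q_len, heads, dim)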
def cudnn_dot_product_attention( query: Tensor, diff --git a/axlearn/common/flash_attention/gpu_attention_benchmark.py b/axlearn/common/flash_attention/gpu_attention_benchmark.py index b7c1d79a..dc442bb0 100644 --- a/axlearn/common/flash_attention/gpu_attention_benchmark.py +++ b/axlearn/common/flash_attention/gpu_attention_benchmark.py @@ -10,95 +10,95 @@ """FlashAttention kernel benchmarks. Tor run: python3 gpu_attention_benchmark.py > out.txt -Requires Jax >= 0.4.36. Sample numbers on H100 SXM5: +Requires Jax >= 0.4.36. Sample numbers on H100 SXM5 with Jax == 0.4.36: is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn -bs=1,seq_len=1024 0.020608 0.018656 0.023680 -bs=1,seq_len=4096 0.037856 0.022784 0.056704 -bs=1,seq_len=8192 0.033792 0.032768 0.104448 -bs=1,seq_len=131072 0.227808 0.198816 1.486752 -bs=4,seq_len=1024 0.021440 0.022208 0.024032 -bs=4,seq_len=4096 0.069728 0.054624 0.059584 -bs=4,seq_len=8192 0.081952 0.076064 0.105920 -bs=4,seq_len=131072 0.823104 0.705056 1.488832 -bs=8,seq_len=1024 0.032544 0.030688 0.024608 -bs=8,seq_len=4096 0.089728 0.071648 0.063584 -bs=8,seq_len=8192 0.129184 0.114944 0.109856 -bs=8,seq_len=131072 1.616800 1.376288 1.503360 -bs=16,seq_len=1024 0.050976 0.048608 0.037504 -bs=16,seq_len=4096 0.136768 0.117312 0.104224 -bs=16,seq_len=8192 0.234688 0.200128 0.190944 -bs=16,seq_len=131072 3.211200 2.727040 2.779872 -bs=32,seq_len=1024 0.078656 0.072992 0.061440 -bs=32,seq_len=4096 0.236576 0.204512 0.190752 -bs=32,seq_len=8192 0.443488 0.372352 0.361216 -bs=32,seq_len=131072 6.392320 5.453344 5.495488 +bs=1,seq_len=1024 0.020832 0.017536 0.024128 +bs=1,seq_len=4096 0.037472 0.021248 0.058656 +bs=1,seq_len=8192 0.034016 0.032576 0.108576 +bs=1,seq_len=131072 0.229856 0.198944 1.558464 +bs=4,seq_len=1024 0.021632 0.023296 0.024352 +bs=4,seq_len=4096 0.068064 0.055168 0.061312 +bs=4,seq_len=8192 0.080352 0.075968 0.109696 +bs=4,seq_len=131072 0.824576 0.703360 1.560768 +bs=8,seq_len=1024 0.033536 0.030304 0.024448 +bs=8,seq_len=4096 0.089056 0.071712 0.062944 +bs=8,seq_len=8192 0.128960 0.114848 0.112736 +bs=8,seq_len=131072 1.620032 1.373088 1.566208 +bs=16,seq_len=1024 0.050368 0.048064 0.036608 +bs=16,seq_len=4096 0.134816 0.116320 0.104320 +bs=16,seq_len=8192 0.234880 0.200384 0.191936 +bs=16,seq_len=131072 3.219008 2.726912 2.784768 +bs=32,seq_len=1024 0.078112 0.070816 0.061568 +bs=32,seq_len=4096 0.235648 0.203296 0.191936 +bs=32,seq_len=8192 0.442080 0.371936 0.365152 +bs=32,seq_len=131072 6.404832 5.448480 5.541504 is_decode=True, use_bwd=False, num_heads=8, seq_len=32768, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn -bs=1,num_kv_heads=1 0.049280 0.059296 0.378304 -bs=1,num_kv_heads=8 0.076352 0.070912 0.377344 -bs=8,num_kv_heads=1 0.111072 0.080480 0.377696 -bs=8,num_kv_heads=8 0.425536 0.368576 0.386880 +bs=1,num_kv_heads=1 0.027648 0.058464 0.398816 +bs=1,num_kv_heads=8 0.076096 0.070368 0.398912 +bs=8,num_kv_heads=1 0.101696 0.078560 0.399040 +bs=8,num_kv_heads=8 0.426656 0.367616 0.403360 is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128 jax axlearn jax-cudnn -bs=1,seq_len=131072,sw_sz=-1 0.228640 0.199040 1.476928 -bs=1,seq_len=131072,sw_sz=4096 0.232320 0.053824 4.441376 -bs=1,seq_len=131072,sw_sz=16384 0.233696 0.061120 4.420992 -bs=8,seq_len=131072,sw_sz=-1 1.621696 1.374080 1.496224 -bs=8,seq_len=131072,sw_sz=4096 1.626016 0.193792 4.463296 -bs=8,seq_len=131072,sw_sz=16384 1.628704 0.318176 4.451648 +bs=1,seq_len=131072,sw_sz=-1 0.230336 0.199968 
1.559168 +bs=1,seq_len=131072,sw_sz=4096 0.235296 0.051296 4.414048 +bs=1,seq_len=131072,sw_sz=16384 0.235904 0.062976 4.385216 +bs=8,seq_len=131072,sw_sz=-1 1.619008 1.372768 1.570272 +bs=8,seq_len=131072,sw_sz=4096 1.635424 0.194720 4.390976 +bs=8,seq_len=131072,sw_sz=16384 1.632832 0.321280 4.361984 is_decode=False, use_bwd=False, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -bs=2 3.502944 0.915360 0.467744 0.845792 -bs=4 6.969376 1.753152 0.890496 1.617280 -bs=8 13.962816 3.415232 1.735232 3.150752 +bs=2 3.583424 0.894912 0.488480 0.852960 +bs=4 7.107168 1.712448 0.922592 1.629888 +bs=8 14.202400 3.341568 1.801920 3.184064 is_decode=False, use_bwd=False, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -num_heads=12 1.262560 0.393536 0.205952 0.362304 -num_heads=16 1.786816 0.498304 0.257664 0.459936 -num_heads=32 3.507488 2.591456 0.468672 2.443296 -num_heads=48 5.246336 1.338272 0.675968 1.231328 -num_heads=72 7.866848 1.961152 0.995712 1.805376 +num_heads=12 1.287712 0.383200 0.214400 0.365120 +num_heads=16 1.803232 0.485408 0.270496 0.463040 +num_heads=32 3.578208 0.896576 0.488544 2.468096 +num_heads=48 5.346112 1.305856 0.707872 1.241728 +num_heads=72 8.001568 1.915776 1.035200 1.820288 is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -seq_len=128 0.030592 0.011584 0.013024 0.012960 -seq_len=256 0.051520 0.015648 0.016640 0.015744 -seq_len=512 0.118720 0.038976 0.028224 0.037152 -seq_len=1024 0.310880 0.096256 0.054784 0.090368 -seq_len=2048 0.931072 0.277312 0.150784 0.256928 -seq_len=4096 3.516672 2.595872 0.465568 2.448128 +seq_len=256 0.049184 0.015360 0.016352 0.015488 +seq_len=512 0.110400 0.038624 0.028480 0.037760 +seq_len=1024 0.302304 0.094560 0.056736 0.090464 +seq_len=2048 0.936832 0.269856 0.154304 0.258944 +seq_len=4096 3.584800 0.895776 0.487104 2.462560 +seq_len=8192 14.260608 3.268320 1.742048 3.104640 is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -per_head_dim=16 3.220960 0.487808 0.332928 0.478720 -per_head_dim=32 3.277824 0.530240 0.334624 0.515040 -per_head_dim=64 3.345376 0.696480 0.338944 0.631296 -per_head_dim=128 3.515616 2.594208 0.465824 2.442784 +per_head_dim=16 3.262592 0.518912 0.356544 0.477120 +per_head_dim=32 3.323552 0.563520 0.358944 0.533344 +per_head_dim=64 3.411744 0.690464 0.360192 0.635296 +per_head_dim=128 3.585920 0.896032 0.488416 2.461696 is_decode=False, use_bwd=True, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -bs=2 10.780096 4.573344 2.080672 4.487104 -bs=4 21.426336 9.336192 3.988224 9.159904 -bs=8 42.808033 18.926559 7.975296 18.075487 +bs=2 10.878624 3.924992 2.123008 4.504256 +bs=4 21.626017 8.043040 4.071552 9.186080 +bs=8 43.269279 16.195999 8.124896 18.184799 is_decode=False, use_bwd=True, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -num_heads=12 4.128352 1.738016 0.882976 1.696704 -num_heads=16 5.467808 2.307488 1.120608 2.247904 -num_heads=32 10.782432 4.559456 2.082592 4.488448 -num_heads=48 16.119776 6.958272 3.027808 6.858144 -num_heads=72 24.140833 10.706656 4.560288 10.279136 +num_heads=12 4.159424 1.519680 0.898816 1.711808 +num_heads=16 5.486912 2.001952 1.142144 2.256960 +num_heads=32 10.886848 3.928896 2.114496 4.502976 
+num_heads=48 16.224319 6.085408 3.093696 6.888640 +num_heads=72 24.367489 9.190560 4.642720 10.323552 is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -seq_len=128 0.058944 0.037824 0.039040 0.036384 -seq_len=256 0.100384 0.069024 0.052608 0.067872 -seq_len=512 0.317056 0.159904 0.111840 0.158912 -seq_len=1024 0.906400 0.431104 0.244160 0.421792 -seq_len=2048 2.861056 1.319648 0.655840 1.297728 -seq_len=4096 10.762560 4.576864 2.079904 4.489056 +seq_len=256 0.094496 0.060096 0.053184 0.065760 +seq_len=512 0.297440 0.139328 0.112736 0.161664 +seq_len=1024 0.886304 0.361536 0.246848 0.418720 +seq_len=2048 2.857952 1.118368 0.675168 1.294144 +seq_len=4096 10.880512 3.914048 2.119808 4.503936 +seq_len=8192 43.000095 14.913824 7.484128 16.730017 is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1 jax axlearn jax-cudnn jax-pallas -per_head_dim=16 10.084800 1.744640 1.263264 1.711296 -per_head_dim=32 10.204480 2.098816 1.291104 2.041184 -per_head_dim=64 10.374720 2.649888 1.335200 2.510304 -per_head_dim=128 10.779680 4.568096 2.079264 4.489792 +per_head_dim=16 10.150080 1.826656 1.288192 1.718688 +per_head_dim=32 10.277440 2.028608 1.316512 2.048864 +per_head_dim=64 10.463904 2.569408 1.364448 2.540512 +per_head_dim=128 10.875328 3.929568 2.124192 4.502912 """ # pylint: enable=line-too-long import itertools @@ -365,8 +365,8 @@ def bench_flash_attention_fwd_bwd(use_bwd: bool): libraries = ["jax", "axlearn", "jax-cudnn", "jax-pallas"] benchmark_sweep(libraries, common_kwargs, bs=[2, 4, 8]) benchmark_sweep(libraries, common_kwargs, num_heads=[12, 16, 32, 48, 72]) - # 128 to 4096. - benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(7, 13)]) + # 256 to 8192. 
+ benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(8, 14)]) benchmark_sweep(libraries, common_kwargs, per_head_dim=[16, 32, 64, 128]) diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 9a1e8561..181a6888 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -43,6 +43,8 @@ (2, 384, 8, 128), ], ) +@pytest.mark.parametrize("kv_seq_len", [-1, 512]) +@pytest.mark.parametrize("dropout_rate", [0, 0.1]) @pytest.mark.parametrize("block_size", [64, 128]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("softmax_scale", [1.0, 0.123]) @@ -54,6 +56,8 @@ def test_triton_fwd_only_against_ref( seq_len: int, num_heads: int, per_head_dim: int, + kv_seq_len: int, + dropout_rate: float, block_size: int, causal: bool, softmax_scale: float, @@ -61,15 +65,21 @@ def test_triton_fwd_only_against_ref( use_segment_ids: bool, input_dtype: jnp.dtype, ): - k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4) + if kv_seq_len == -1: + kv_seq_len = seq_len + if kv_seq_len != seq_len and use_segment_ids: + pytest.skip() + k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5) q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) - k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) - v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype) if attention_bias_type == "4d": - bias = jax.random.normal(k4, (batch_size, num_heads, seq_len, seq_len), dtype=input_dtype) + bias = jax.random.normal( + k4, (batch_size, num_heads, seq_len, kv_seq_len), dtype=input_dtype + ) elif attention_bias_type == "2d": - bias = jax.random.normal(k4, (1, 1, seq_len, seq_len), dtype=input_dtype) + bias = jax.random.normal(k4, (1, 1, seq_len, kv_seq_len), dtype=input_dtype) else: bias = None @@ -79,21 +89,34 @@ def test_triton_fwd_only_against_ref( jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None ) - @jax.jit - def impl(q, k, v, bias, segment_ids): - fn = functools.partial( - flash_attention, - block_q=block_size, - block_k=block_size, - causal=causal, - softmax_scale=softmax_scale, - ) - out, _ = jax.vjp(fn, q, k, v, bias, segment_ids) - return out - - o = impl(q, k, v, bias, segment_ids) - o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale) - chex.assert_trees_all_close(o, o_ref, atol=0.07) + o = flash_attention( + q, + k, + v, + bias, + segment_ids, + k5, + block_q=block_size, + block_k=block_size, + causal=causal, + softmax_scale=softmax_scale, + dropout_rate=dropout_rate, + ) + o_ref = mha_reference( + q, + k, + v, + bias, + segment_ids, + k5, + causal=causal, + softmax_scale=softmax_scale, + dropout_rate=dropout_rate, + ) + if input_dtype == jnp.float16: + chex.assert_trees_all_close(o, o_ref, atol=0.07) + elif input_dtype == jnp.float32: + chex.assert_trees_all_close(o, o_ref, atol=0.03) class FlashDecodingTest(TestCase): @@ -192,6 +215,8 @@ def test_decode_against_ref( (2, 8, 384, 128), ], ) +@pytest.mark.parametrize("kv_seq_len", [-1, 512]) +@pytest.mark.parametrize("dropout_rate", [0, 0.1]) @pytest.mark.parametrize("attention_bias_type", 
[None, "2d", "4d"]) @pytest.mark.parametrize("use_segment_ids", [True, False]) @pytest.mark.parametrize("block_size", [64, 128]) @@ -202,28 +227,29 @@ def test_triton_against_xla_ref( num_heads: int, seq_len: int, per_head_dim: int, + kv_seq_len: int, attention_bias_type: Literal["2d", "4d", None], use_segment_ids: bool, + dropout_rate: float, block_size: int, causal: bool, input_dtype: jnp.dtype, ): - q = jax.random.normal( - jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype - ) - k = jax.random.normal( - jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype - ) - v = jax.random.normal( - jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype - ) + if kv_seq_len == -1: + kv_seq_len = seq_len + if kv_seq_len != seq_len and use_segment_ids: + pytest.skip() + k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5) + q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype) + k = jax.random.normal(k2, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype) + v = jax.random.normal(k3, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype) if attention_bias_type == "4d": bias = jax.random.normal( - jax.random.PRNGKey(3), (batch_size, num_heads, seq_len, seq_len), dtype=input_dtype + k4, (batch_size, num_heads, seq_len, kv_seq_len), dtype=input_dtype ) elif attention_bias_type == "2d": - bias = jax.random.normal(jax.random.PRNGKey(3), (1, 1, seq_len, seq_len), dtype=input_dtype) + bias = jax.random.normal(k4, (1, 1, seq_len, kv_seq_len), dtype=input_dtype) else: bias = None @@ -236,48 +262,66 @@ def test_triton_against_xla_ref( softmax_scale = q.shape[-1] ** -0.5 # Compare outputs. - jax_out = flash_attention( + call_flash = functools.partial( + flash_attention, + causal=causal, + softmax_scale=softmax_scale, + block_q=block_size, + block_k=block_size, + dropout_rate=dropout_rate, + ) + jax_out = call_flash( q, k, v, bias, segment_ids, - causal=causal, - softmax_scale=softmax_scale, - block_q=block_size, - block_k=block_size, + k5, ) jax_ref_out = mha_reference( - q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale + q, + k, + v, + bias, + segment_ids, + k5, + causal=causal, + softmax_scale=softmax_scale, + dropout_rate=dropout_rate, ) if input_dtype == jnp.float16: chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.005) elif input_dtype == jnp.float32: - chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.05) + chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.005) else: raise ValueError(f"Unsupported dtype: {input_dtype}") - def fn(q, k, v, bias, segment_ids): - return flash_attention( + def fn(q, k, v, bias, segment_ids, k5): + return call_flash( q, k, v, bias, segment_ids, - causal=causal, - softmax_scale=softmax_scale, - block_q=block_size, - block_k=block_size, + k5, ).sum() - def ref_fn(q, k, v, bias, segment_ids): + def ref_fn(q, k, v, bias, segment_ids, k5): return mha_reference( - q, k, v, bias, segment_ids, causal=causal, softmax_scale=softmax_scale + q, + k, + v, + bias, + segment_ids, + k5, + causal=causal, + softmax_scale=softmax_scale, + dropout_rate=dropout_rate, ).sum() # Compare gradients. 
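    # Both implementations derive their dropout mask from the same PRNG key (k5), so even
    # with dropout enabled the outputs and gradients below stay elementwise comparable.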
- jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids) - jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids) + jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids, k5) + jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v, bias, segment_ids, k5) chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.05) @@ -302,15 +346,10 @@ def test_cudnn_against_triton_ref( causal: bool, dtype: jnp.dtype, ): - q = jax.random.normal( - jax.random.PRNGKey(0), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype - ) - k = jax.random.normal( - jax.random.PRNGKey(1), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype - ) - v = jax.random.normal( - jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype - ) + k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) + q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype) + k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype) + v = jax.random.normal(k3, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype) softmax_scale = q.shape[-1] ** -0.5 @@ -346,3 +385,120 @@ def ref_fn(q, k, v): chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.05, rtol=1e-5) else: raise ValueError(f"Unsupported dtype: {dtype}") + + +@pytest.mark.parametrize( + "batch_size,num_heads,seq_len,per_head_dim", + [ + (1, 1, 128, 128), + (2, 4, 128, 128), + (1, 2, 64, 64), + (2, 8, 64, 64), + ], +) +@pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16]) +@pytest.mark.parametrize("dropout_rate", [0.1, 0.25]) +def test_cudnn_dropout_against_xla_dropout( + batch_size: int, + num_heads: int, + seq_len: int, + per_head_dim: int, + causal: bool, + dtype: jnp.dtype, + dropout_rate: float, +): + """Tests that cudnn dropout works as expected. + + Since cuDNN uses a different kind of RNG than Jax, we retrieve the mask generated by cuDNN + by setting V to the identity matrix. However, this only works when seq_len == per_head_dim, + i.e. when the shape of output is the same as the shape of the dropout mask. + """ + qkv_shape = (batch_size, seq_len, num_heads, per_head_dim) + softmax_scale = 1.0 + cudnn_attn = functools.partial( + cudnn_dot_product_attention, + bias=None, + causal=causal, + softmax_scale=softmax_scale, + dropout_rate=dropout_rate, + ) + + dropout_mask = ( + cudnn_attn( + jnp.zeros(qkv_shape, dtype=dtype), + jnp.zeros(qkv_shape, dtype=dtype), + jnp.broadcast_to(jnp.eye(per_head_dim, dtype=dtype)[None, :, None], qkv_shape), + ) + == 0.0 + ).swapaxes(1, 2) + # Clear the compilation cache to reset cudnn RNG offset, so the next invocation will generate + # the same mask. + jax.clear_caches() + + k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3) + q = jax.random.normal(k1, qkv_shape, dtype=dtype) + k = jax.random.normal(k2, qkv_shape, dtype=dtype) + v = jax.random.normal(k3, qkv_shape, dtype=dtype) + + ref_attn = functools.partial( + mha_reference, + bias=None, + causal=causal, + softmax_scale=softmax_scale, + dropout_mask=dropout_mask, + dropout_rate=dropout_rate, + ) + # Compare outputs. + jax_out = cudnn_attn(q, k, v) + jax_ref_out = ref_attn(q, k, v) + if dtype == jnp.bfloat16: + # We relax the atol to support bf16 in the unit test. 
+        chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.25, rtol=1e-3)
+    elif dtype == jnp.float16:
+        chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.05, rtol=1e-3)
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+    def fn(q, k, v):
+        return cudnn_attn(q, k, v).mean()
+
+    def ref_fn(q, k, v):
+        return ref_attn(q, k, v).mean()
+
+    # Compare gradients.
+    jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v)
+    jax_ref_grads = jax.grad(ref_fn, argnums=(0, 1, 2))(q, k, v)
+    # The diff between grads is expected to be larger than that of the forward pass.
+    if dtype == jnp.bfloat16:
+        # We relax the rtol to support bf16 in the unit test.
+        chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.05, rtol=1e-2)
+    elif dtype == jnp.float16:
+        chex.assert_trees_all_close(jax_grads, jax_ref_grads, atol=0.05, rtol=1e-5)
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def test_cudnn_dropout_determinism():
+    """Tests that cuDNN dropout produces identical outputs across runs."""
+    k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
+    q = jax.random.normal(k1, (1, 128, 2, 64), dtype=jnp.float16)
+    k = jax.random.normal(k2, (1, 128, 2, 64), dtype=jnp.float16)
+    v = jax.random.normal(k3, (1, 128, 2, 64), dtype=jnp.float16)
+    outputs = []
+    grads = []
+
+    def fn(q, k, v):
+        return cudnn_dot_product_attention(q, k, v, dropout_rate=0.1).mean()
+
+    for i in range(10):
+        outputs.append(cudnn_dot_product_attention(q, k, v, dropout_rate=0.1))
+        grads.append(jax.grad(fn, argnums=(0, 1, 2))(q, k, v))
+
+    jax.clear_caches()
+
+    for i in range(10):
+        chex.assert_trees_all_equal(
+            cudnn_dot_product_attention(q, k, v, dropout_rate=0.1), outputs[i]
+        )
+        chex.assert_trees_all_equal(jax.grad(fn, argnums=(0, 1, 2))(q, k, v), grads[i])
diff --git a/axlearn/common/flash_attention/gpu_decoding.py b/axlearn/common/flash_attention/gpu_decoding.py
index 90e795b3..3201bc2d 100644
--- a/axlearn/common/flash_attention/gpu_decoding.py
+++ b/axlearn/common/flash_attention/gpu_decoding.py
@@ -48,9 +48,9 @@
 from jax import lax
 from jax._src.cudnn.fused_attention_stablehlo import check_compute_capability
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
 
 from axlearn.common.attention import NEG_INF, MaskFn, Tensor
+from axlearn.common.flash_attention.gpu_attention import NoPopDict
 
 
 # Note: split_k_seq_len must be a multiple of block_k.
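
Note (illustration only, not part of the patch): the identity-V trick in
test_cudnn_dropout_against_xla_dropout above works because, with V set to the identity matrix and
kv_seq_len == per_head_dim, the attention output equals the post-dropout probability matrix, so
zeros in the output mark exactly the entries cuDNN dropped. A minimal single-head sketch with toy
shapes and a hypothetical keep-pattern:

import jax.numpy as jnp

rate = 0.25
probs = jnp.full((4, 4), 0.25)  # Post-softmax attention probabilities, all strictly positive.
keep = jnp.array(
    [[1.0, 0.0, 1.0, 1.0],
     [1.0, 1.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 1.0],
     [1.0, 1.0, 1.0, 0.0]]
)  # Hypothetical keep-pattern (0 = dropped); in the real test cuDNN's RNG decides this.
dropped = probs * keep / (1 - rate)  # Dropout zeroes some entries and rescales the rest.
context = dropped @ jnp.eye(4)       # With V = I, the output equals `dropped` itself.
dropout_mask = context == 0.0        # True precisely where dropout removed an entry.
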
@@ -238,7 +238,7 @@ def _decode_attn_unbatched(
             pl.BlockSpec((None, None, block_h), lambda kv_h, q_h, k: (kv_h, k, q_h)),  # l
             pl.BlockSpec((None, None, block_h), lambda kv_h, q_h, k: (kv_h, k, q_h)),  # m
         ],
-        compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages),
+        compiler_params=NoPopDict(triton=NoPopDict(num_warps=num_warps, num_stages=num_stages)),
         out_shape=[
             jax.ShapeDtypeStruct(
                 shape=(num_kvheads, k_splits, *q.shape[1:]), dtype=jnp.float32
diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py
index 2cb8da18..457233b6 100644
--- a/axlearn/common/flash_attention/layer.py
+++ b/axlearn/common/flash_attention/layer.py
@@ -11,7 +11,7 @@
 from jax.interpreters.pxla import thread_resources
 from jax.sharding import PartitionSpec
 
-from axlearn.common.attention import GroupedQueryAttention
+from axlearn.common.attention import Dropout, GroupedQueryAttention
 from axlearn.common.attention_bias import BaseAttentionBias
 from axlearn.common.config import config_class
 from axlearn.common.flash_attention.utils import (
@@ -69,9 +69,14 @@ def __init__(self, cfg: Config, *, parent: Module):
         cfg = self.config
         if getattr(cfg, "atten_logit_cap", None) is not None:
             raise NotImplementedError("cfg.atten_logit_cap is not supported.")
-        # TODO(kelvinzou): enable dropout for flash attention.
-        if cfg.dropout.rate:
-            raise NotImplementedError("cfg.dropout.rate is not supported.")
+        # We're checking for an exact class match here.
+        # pylint: disable-next=unidiomatic-typecheck
+        if type(self.dropout) is not Dropout:
+            raise NotImplementedError(
+                f"Only {Dropout.__module__}.{Dropout.__qualname__} is supported for "
+                "FlashAttention. Got "
+                f"{type(self.dropout).__module__}.{type(self.dropout).__qualname__}"
+            )
         if cfg.tpu_block_size % 128 != 0:
             raise ValueError("cfg.tpu_block_size must divide 128.")
 
@@ -113,7 +118,7 @@ def _compute_attention(
         v_proj: Tensor,
         attention_logit_biases: BaseAttentionBias,
     ) -> tuple[Tensor, Tensor]:
-        cfg = self.config
+        cfg: FlashAttention.Config = self.config
         backend = self._backend()
 
         batch, target_len, num_heads, _ = q_proj.shape
@@ -125,6 +130,7 @@ def _compute_attention(
             backend=backend,
             softmax_scale=1.0,
             block_size=cfg.tpu_block_size,
+            dropout_rate=cfg.dropout.rate,
         )
 
         attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases)
@@ -156,6 +162,8 @@ def _compute_attention(
                 cfg.mha_dim_to_partition_spec["bsnh"],
                 # Bias that can broadcast to [batch_size, num_heads, seq_len, seq_len].
                 attention_logit_biases_spec,
+                # PRNG Key.
+                PartitionSpec(None),
             ),
             # O [batch_size, seq_len, num_heads, per_head_dim].
             out_specs=cfg.mha_dim_to_partition_spec["btnh"],
@@ -165,7 +173,15 @@ def _compute_attention(
         )
 
         outputs = with_sharding_constraint(
-            partitioned_mha(q_proj, k_proj, v_proj, attention_logit_biases),
+            partitioned_mha(
+                # Note: we use the dropout layer's prng_key so that the dropout result is
+                # identical to calling self.dropout.forward, since both produce the same mask.
+ q_proj, + k_proj, + v_proj, + attention_logit_biases, + self.dropout.get_prng_key(), + ), cfg.output_dim_to_partition_spec["btnh"], ) diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index 1d0df493..a3330a82 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -24,7 +24,7 @@ from jax.experimental import mesh_utils from jax.sharding import Mesh -from axlearn.common.attention import GroupedQueryAttention +from axlearn.common.attention import Dropout, GroupedQueryAttention from axlearn.common.attention_bias import ( CompositeAttentionBias, SegmentIdAttentionBias, @@ -98,6 +98,7 @@ def _prepare_layers( sliding_window_size, inference=False, set_layer_bias_recursively=False, + dropout_rate=0.0, ): hidden_dim = num_heads * per_head_dim kwargs = dict( @@ -106,6 +107,7 @@ def _prepare_layers( value_dim=hidden_dim, num_heads=num_heads, dtype=jnp.bfloat16, + dropout=Dropout.default_config().set(rate=dropout_rate), ) ref_cfg = GroupedQueryAttention.default_config().set(**kwargs) @@ -294,6 +296,22 @@ class TestFlashAttention(TestCase): ), ] + def test_dropout_support(self): + """Tests that FlashAttention errors out when custom dropout is used.""" + + class OtherDropout(Dropout): + pass + + required_kwargs = dict(query_dim=128, key_dim=128, value_dim=128, num_heads=2, name="test") + FlashAttention.default_config().set( + dropout=Dropout.default_config(), **required_kwargs + ).instantiate(parent=None) + + with self.assertRaises(NotImplementedError): + FlashAttention.default_config().set( + dropout=OtherDropout.default_config(), **required_kwargs + ).instantiate(parent=None) + @parameterized.parameters( [kwargs for kwargs in _TEST_CONFIGS if math.prod(kwargs["mesh"]) == 1] ) @@ -373,6 +391,7 @@ def as_partition_spec(pytree: CompositeAttentionBias) -> PartitionSpec: use_bias=[False, True], use_segment_ids=[False, True], input_dtype=[jnp.bfloat16, jnp.float32], + dropout_rate=[0.0, 0.1], ) def test_forward( self, @@ -388,6 +407,7 @@ def test_forward( use_bias, use_segment_ids, input_dtype, + dropout_rate, ): if not is_supported_mesh_shape(mesh): pytest.skip(reason=f"Unsupported mesh {mesh}.") @@ -401,6 +421,8 @@ def test_forward( # Data=1 with bias matrix in all fp32 format would OOM the H100 SRAM. if use_bias and mesh[mesh_axis_names.index("data")] == 1 and input_dtype == jnp.float32: pytest.skip(reason="Unsupported large bias matrix in fp32 format.") + if dropout_rate > 0.0 and jax.default_backend() == "tpu": + pytest.skip("Dropout is implemented for GPU only.") with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names): test_layer, ref_layer, params, hidden_dim = _prepare_layers( @@ -409,12 +431,8 @@ def test_forward( mesh_axis_names=mesh_axis_names, causal=causal, sliding_window_size=sliding_window_size, + dropout_rate=dropout_rate, ) - # pylint: disable-next=protected-access - if test_layer._backend() == "gpu" and query_len_multiplier != 1: - pytest.skip( - reason="GPU flash attention does not support different query and key lengths." - ) query_len = int(query_len_multiplier * seq_len) inputs = _fake_inputs( @@ -444,7 +462,10 @@ def test_forward( is_training=True, ) # TODO(markblee): Test probs. - self.assertNestedAllClose(ref_out.data, test_out.data, atol=0.05) + # Note: cannot compare results when dropout_rate > 0 and not using segment ids, because + # cudnn dropout will be used and it uses different PRNG than ours. 
+        if dropout_rate == 0.0 or use_segment_ids:
+            self.assertNestedAllClose(ref_out.data, test_out.data, atol=0.05)
         jax.extend.backend.clear_backends()
 
     @parameterized.product(
@@ -455,6 +476,7 @@ def test_forward(
         use_bias=[False, True],
         use_segment_ids=[False, True],
         set_layer_bias_recursively=[False, True],
+        dropout_rate=[0.0, 0.1],
     )
     def test_backward(
         self,
@@ -470,6 +492,7 @@ def test_backward(
         use_bias,
         use_segment_ids,
         set_layer_bias_recursively,
+        dropout_rate,
     ):
         if not is_supported_mesh_shape(mesh):
            pytest.skip(reason=f"Unsupported mesh {mesh}.")
@@ -477,6 +500,12 @@ def test_backward(
             pytest.skip("Segment IDs are not supported for Q and K with different lengths.")
         if not causal and sliding_window_size is not None:
             pytest.skip(reason="Sliding window attention must be causal.")
+        if sliding_window_size is not None and query_len_multiplier > 1:
+            # When sliding window is enabled and q_len > kv_len, there might be fully masked
+            # rows.
+            pytest.skip(reason="Sliding window attention does not make sense when q_len > kv_len.")
+        if dropout_rate > 0.0 and jax.default_backend() == "tpu":
+            pytest.skip("Dropout is implemented for GPU only.")
 
         if causal and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
@@ -526,6 +555,7 @@ def forward(self, *, query, key, value, attention_logit_biases, segment_ids):
             dtype=jnp.bfloat16,
             causal=causal and (mask_fn is None),
             mask=mask_fn,
+            dropout=Dropout.default_config().set(rate=dropout_rate),
         )
         ref_cfg = DummyModel.default_config().set(
             layer=GroupedQueryAttention.default_config().set(**kwargs),
@@ -544,11 +574,6 @@ def forward(self, *, query, key, value, attention_logit_biases, segment_ids):
         set_bias_recursively(test_cfg, set_layer_bias_recursively)
         ref_layer = ref_cfg.set(name="ref").instantiate(parent=None)
         test_layer = test_cfg.set(name="test").instantiate(parent=None)
-        # pylint: disable-next=protected-access
-        if test_layer.layer._backend() == "gpu" and query_len_multiplier != 1:
-            pytest.skip(
-                reason="GPU flash attention does not support different query and key lengths."
-            )
 
         # Use the same params for both. Only attention implementation differs.
         params = ref_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123))
@@ -581,13 +606,18 @@ def loss(params, inputs, layer):
         # pylint: disable-next=protected-access
         if set_layer_bias_recursively and test_layer.layer._backend() == "gpu":
             atol, rtol = 5e-4, 5e-2
-
+        # pylint: disable-next=protected-access
+        elif dropout_rate > 0.0 and test_layer.layer._backend() == "gpu":
+            atol, rtol = 2.5e-4, 1e-3
         # Can be 1e-5 on x86_64/GPU/TPU, needed to be slightly higher on ARM.
         else:
             atol, rtol = 1e-4, 1e-3
-        self.assertNestedAllClose(ref_value, test_value, atol=atol, rtol=rtol)
-        self.assertNestedAllClose(ref_grads, test_grads, atol=atol, rtol=rtol)
+        # Note: cannot compare results when dropout_rate > 0 and not using segment ids, because
+        # cudnn dropout will be used and it uses a different PRNG than ours.
+ if dropout_rate == 0.0 or use_segment_ids: + self.assertNestedAllClose(ref_value, test_value, atol=atol, rtol=rtol) + self.assertNestedAllClose(ref_grads, test_grads, atol=atol, rtol=rtol) jax.extend.backend.clear_backends() @parameterized.product( diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index 938a4708..fa057cf8 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -24,10 +24,11 @@ from axlearn.common.flash_attention.gpu_attention import flash_attention as gpu_flash_attention from axlearn.common.flash_attention.gpu_decoding import flash_decoding from axlearn.common.flash_attention.tpu_attention import tpu_flash_attention +from axlearn.common.layers import dropout from axlearn.common.utils import Tensor -@functools.partial(jax.jit, static_argnames=["causal", "softmax_scale"]) +@functools.partial(jax.jit, static_argnames=["causal", "softmax_scale", "dropout_rate"]) @jax.default_matmul_precision("bfloat16") def mha_reference( q: Tensor, @@ -35,9 +36,12 @@ def mha_reference( v: Tensor, bias: Optional[Tensor] = None, segment_ids: Optional[Tensor] = None, + prng_key: Optional[Tensor] = None, *, causal: bool = False, softmax_scale: float = 1.0, + dropout_rate: float = 0.0, + dropout_mask: Optional[Tensor] = None, ) -> Tensor: """Reference multi-headed attention implementation. @@ -48,9 +52,10 @@ def mha_reference( bias: bias tensor with a shape that can broadcast to [batch_size, num_heads, seq_len, seq_len], e.g. [1, 1, seq_len, seq_len]. segment_ids: segment ids tensor with shape [batch_size, seq_len]. + prng_key: prng key for dropout. causal: whether the attention is causal. softmax_scale: a scalar value applied to the logits before softmax. - + dropout_rate: dropout rate. Returns: A tensor with shape [batch_size, seq_len, num_heads, per_head_dim]. """ @@ -77,6 +82,9 @@ def mha_reference( logits = jnp.where(mask, NEG_INF, logits) probs = softmax_with_biases(logits, bias) + if dropout_rate > 0: + probs = dropout(probs, prng_key=prng_key, rate=dropout_rate, mask=dropout_mask) + context = jnp.einsum("bnts,bsnh->btnh", probs, v).astype(v.dtype) return context @@ -93,8 +101,8 @@ def _repeat_kv_heads(num_q_heads: int, key_or_value: Tensor) -> Tensor: return jnp.repeat(key_or_value, num_head_repeats, axis=-2) -# Accepts [query, key, value, attention_bias, segment_ids] tensors and returns the context Tensor. -MultiHeadAttentionImpl = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor] +# Accepts [query, key, value, attention_bias, prng_key] tensors and returns the context Tensor. +MultiHeadAttentionImpl = Callable[[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]], Tensor] def flash_attention_implementation( @@ -102,6 +110,7 @@ def flash_attention_implementation( *, softmax_scale: float, block_size: int = 128, + dropout_rate: Optional[float] = 0.0, ) -> MultiHeadAttentionImpl: """Returns a jitted "flash" multihead-attention implementation for the given backend. @@ -118,6 +127,8 @@ def flash_attention_implementation( Raises: NotImplementedError: If implementation for the backend is not available. """ + if dropout_rate is None: + dropout_rate = 0.0 # shard_map-decorated function needs to be jitted. @jax.jit @@ -126,6 +137,7 @@ def jit_attn( key: Tensor, value: Tensor, bias: BaseAttentionBias, + prng_key: Optional[Tensor] = None, *, backend: str = backend, ) -> Tensor: @@ -141,6 +153,8 @@ def jit_attn( # TODO(senyut): Support TPU decoding. 
backend = "xla" bias = TensorAttentionBias(bias.value()) + if dropout_rate != 0.0 and backend not in ("gpu", "xla", "cpu"): + raise NotImplementedError("Dropout is only implemented for GPU, CPU and XLA.") bias = CompositeAttentionBias([bias]) @@ -191,14 +205,6 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: softmax_scale=softmax_scale, ) - if query.shape[1] != key.shape[1]: - # TODO(xuan-zou): Generalize GPU Flash Attention for q_len != kv_len. - # Remove pytest.skip corresponding to q_len != kv_len in layer_test.py once fixed. - raise NotImplementedError( - f"Query length {query.shape[1]} must be equal to KV length " - f"{key.shape[1]} for correctly supported GPU flash attention usage." - ) - key = _repeat_kv_heads(query.shape[2], key) value = _repeat_kv_heads(query.shape[2], value) @@ -217,6 +223,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: segment_ids.value() is not None or explicit_bias.value() is not None or jnp.float32 in (query.dtype, key.dtype, value.dtype) + or query.shape[1] != key.shape[1] ): logging.warning("Flash attention falling back to Triton GPU kernel.") return gpu_flash_attention( @@ -225,8 +232,10 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: value, bias=explicit_bias.value(), segment_ids=get_segment_ids(segment_ids), + prng_key=prng_key, softmax_scale=softmax_scale, causal=causal.value() is not None, + dropout_rate=dropout_rate, ) else: explicit_bias += segment_ids @@ -237,7 +246,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: bias=explicit_bias.value(), softmax_scale=softmax_scale, causal=causal.value() is not None, - dropout_rate=0.0, + dropout_rate=dropout_rate, ) elif backend == "tpu": @@ -282,8 +291,10 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: value, bias=explicit_bias.value(), segment_ids=get_segment_ids(segment_ids), + prng_key=prng_key, causal=causal.value() is not None, softmax_scale=softmax_scale, + dropout_rate=dropout_rate, ) raise NotImplementedError(f"Backend ({backend}) does not have an implementation.") diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index d02ce5eb..5c95fee0 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -138,6 +138,35 @@ def _redirect(self, *args, redirection_target_method: str, **kwargs) -> Any: return getattr(shared_module.module, redirection_target_method)(*args, **kwargs) +def get_dropout_mask(shape: tuple[int, ...], *, prng_key: Tensor, rate: float): + """Returns a bool dropout mask for the specified tensor shape where True indicates dropout.""" + return jax.random.bernoulli(prng_key, rate, shape) + + +def dropout( + x: Tensor, *, rate: float, prng_key: Optional[Tensor] = None, mask: Optional[Tensor] = None +): + """Performs dropout on `x` according to dropout rate or mask. + + After dropout, `x` will be rescaled by 1 / (1 - rate). If `mask` is provided, use `mask`. + Otherwise, generate a dropout mask using `prng_key` and `rate`. + + Args: + x: Input tensor. + rate: Dropout rate. + prng_key: PRNG key used for mask generation. Required if `mask` is None. + mask: A boolean mask with the same shape as x. If provided, `prng_key` will be ignored. + Any values in `x` where `mask` is True will be dropped. + """ + if not 0 < rate < 1: + raise ValueError(f"Dropout rate must be between 0 and 1. 
Got {rate=}") + if mask is None: + if prng_key is None: + raise ValueError("prng_key must be provided when mask is not specified.") + mask = get_dropout_mask(x.shape, prng_key=prng_key, rate=rate) + return jnp.where(mask, 0, x) / (1 - rate) + + class Dropout(BaseLayer): """The dropout layer.""" @@ -149,12 +178,10 @@ def forward(self, x: Tensor) -> Tensor: cfg = self.config if not self.is_training or cfg.rate is None or cfg.rate == 0: return x - assert 0 < cfg.rate < 1 - samples = jax.random.uniform( - self.prng_key, shape=x.shape, dtype=x.dtype, minval=0.0, maxval=1.0 - ) - dropout = jnp.floor(1 - cfg.rate + samples) - return x * dropout / (1.0 - cfg.rate) + return dropout(x, prng_key=self.prng_key, rate=cfg.rate) + + def get_prng_key(self) -> Tensor: + return self.prng_key class DropToken(BaseLayer): From 2aa60af0e3e1f4ff7f929e03040ea00d7151d2db Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Mon, 6 Jan 2025 12:53:42 -0800 Subject: [PATCH 2/3] Fix comments --- axlearn/common/flash_attention/gpu_attention.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 5f04e71b..42110673 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -438,9 +438,6 @@ def inner_loop_dkdv(start_q, carry): pl.store(dk_ref, (curr_k_slice, slice(None)), dk.astype(dk_ref.dtype)) -# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence -# length. Inspired by the triton tutorial: -# https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py def _mha_backward_kernel_dq( # Inputs. q_ref, @@ -549,9 +546,9 @@ def _mha_backward( """Calls Pallas kernels to compute dQ, dK and dV. Note: separating dKdV and dQ loops into two kernels in flash backward improved performance by - 10~15% when head_dim >= 128. Note that technically fusing dKdVdQ into a single loop and use - atomic add for dQ is the fastest solution, but pallas atomics are extremely slow according - to empirical testing. + 10~15% when head_dim >= 128. Technically, fusing dKdVdQ into a single loop and use atomic add + for dQ is the fastest solution, but pallas atomics are extremely slow according to empirical + testing. """ del num_warps, grid, output_activations q, k, v, bias, segment_ids, prng_key, out, lse = res From d0bbde15cdc74512fab2c7e24a154484a6dfad2a Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Wed, 8 Jan 2025 12:50:58 -0800 Subject: [PATCH 3/3] Disable cudnn dropout --- axlearn/common/flash_attention/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py index fa057cf8..b9235371 100644 --- a/axlearn/common/flash_attention/utils.py +++ b/axlearn/common/flash_attention/utils.py @@ -224,6 +224,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: or explicit_bias.value() is not None or jnp.float32 in (query.dtype, key.dtype, value.dtype) or query.shape[1] != key.shape[1] + or dropout_rate != 0.0 ): logging.warning("Flash attention falling back to Triton GPU kernel.") return gpu_flash_attention( @@ -246,7 +247,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]: bias=explicit_bias.value(), softmax_scale=softmax_scale, causal=causal.value() is not None, - dropout_rate=dropout_rate, + dropout_rate=0.0, ) elif backend == "tpu":