Skip to content

Commit

Permalink
Introduce BaseAttentionBias.has_value(). (#920)
Browse files Browse the repository at this point in the history
`BaseAttentionBias.value()` is also used to check whether a bias value exists.
However, calling `value()` creates the actual bias on the CPU, which is
quadratic O(T^2) — especially when debugging long sequence lengths
(e.g., 32k). Hence, we introduce a lightweight `has_value()` method.

Normally, unused tensors are pruned from the graph during XLA compilation, and
using `BaseAttentionBias.value()` to check for bias presence relies on XLA
pruning. But `BaseAttentionBias.value()` is still expensive on the CPU for
unittests and debugging, so `has_value()` saves Python runtime.

The new `has_value()` method checks whether `value()` actually exists by
calling `jax.eval_shape`. Since `jax.eval_shape` invokes `value()` through a
tracer, it doesn’t materialize the actual value (proposed by John Peebles).
  • Loading branch information
ds-hwang authored Jan 13, 2025
1 parent 3f36108 commit feb8357
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 18 deletions.
29 changes: 24 additions & 5 deletions axlearn/common/attention_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,28 @@ class BaseAttentionBias:
# If None, do not cast the dtype.
dtype: Optional[jnp.dtype] = struct.field(kw_only=True, default=None, pytree_node=False)

@final
def eval_shape(self) -> tuple[int, int, int, int]:
"""Return the shape of the bias tensor.
Note: this doesn't materialize the value. jax.eval_shape calls value(), but it only does so
using tracers.
Returns
shape: [batch or 1, num_heads or 1, target_len, source_len].
Raises:
ValueError: If the bias has no value.
"""
if not self.has_value():
raise ValueError("AttentionBias has no value.")
return jax.eval_shape(self.value).shape

@final
def has_value(self) -> bool:
"""Return whether to the bias has a value."""
return jax.eval_shape(self.value) is not None

@final
def value(self) -> Optional[Tensor]:
"""Return a tensor with the biases or None if there are no biases.
Expand Down Expand Up @@ -116,9 +138,6 @@ def _broadcast_value(cls, value: OpT) -> OpT:
return value[:, None, :, :]
raise ValueError(f"Invalid attention_logit_biases shape: {value.shape}.")

def eval_shape(self):
return jax.eval_shape(self.value).shape

def partition_spec(
self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
) -> Union["BaseAttentionBias", PartitionSpec]:
Expand Down Expand Up @@ -233,7 +252,7 @@ def _nonzero(self) -> Sequence[BaseAttentionBias]:
Returned biases are not guaranteed to be nonzero, but are guaranteed to not return None.
"""
filt = lambda b: b.value() is not None
filt = lambda b: b.has_value()
return list(filter(filt, self.biases))

def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
Expand All @@ -260,7 +279,7 @@ def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
send_residual_to = remaining_biases
else:
send_residual_to = residuals
if bias_and_residual.residual.value() is not None:
if bias_and_residual.residual.has_value():
send_residual_to.append(bias_and_residual.residual)
return BiasAndResidual(
bias=cls.from_sequence(cls_biases), residual=CompositeAttentionBias(residuals)
Expand Down
67 changes: 63 additions & 4 deletions axlearn/common/attention_bias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@


class AttentionBiasTest(test_utils.TestCase):
@parameterized.parameters(
[attention_bias.ZeroAttentionBias(), False],
[attention_bias.CausalAttentionBias(shape=(5, 5)), True],
[attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)), True],
[attention_bias.TensorAttentionBias.from_tensor(jnp.ones((5, 5))), True],
)
def test_has_bias(self, bias, expected):
self.assertEqual(bias.has_value(), expected)

def test_causal_attention_bias(self):
bias = attention_bias.CausalAttentionBias(shape=(5, 5))
chex.assert_trees_all_close(bias.value(), attention_bias.make_causal_biases(5)[None, None])
Expand All @@ -45,19 +54,19 @@ def test_base_attention_bias_value(self):
# pylint: disable=function-redefined

class TestAttentionBias(attention_bias.BaseAttentionBias):
def _value(self) -> Optional[Tensor]:
def _value(self) -> Tensor:
return jnp.ones((5, 7))

self.assertEqual(TestAttentionBias().value().shape, (1, 1, 5, 7))

class TestAttentionBias(attention_bias.BaseAttentionBias):
def _value(self) -> Optional[Tensor]:
def _value(self) -> Tensor:
return jnp.ones((3, 5, 7))

self.assertEqual(TestAttentionBias().value().shape, (3, 1, 5, 7))

class TestAttentionBias(attention_bias.BaseAttentionBias):
def _value(self) -> Optional[Tensor]:
def _value(self) -> Tensor:
return jnp.ones((2, 3, 5, 7))

self.assertEqual(TestAttentionBias().value().shape, (2, 3, 5, 7))
Expand All @@ -77,6 +86,56 @@ def test_base_attention_bias_and_residual(self):
bias.bias_and_residual(int), attention_bias.BiasAndResidual(bias=None, residual=bias)
)

@parameterized.parameters(
[
attention_bias.CompositeAttentionBias(
[attention_bias.ZeroAttentionBias(), attention_bias.ZeroAttentionBias()]
),
False,
],
[
attention_bias.CompositeAttentionBias(
[
attention_bias.CausalAttentionBias(shape=(5, 5)),
attention_bias.CausalAttentionBias(shape=(5, 5)),
]
),
True,
],
[
attention_bias.CompositeAttentionBias(
[
attention_bias.CausalAttentionBias(shape=(5, 5)),
attention_bias.ZeroAttentionBias(),
]
),
True,
],
[
attention_bias.CompositeAttentionBias(
[
attention_bias.ZeroAttentionBias(),
attention_bias.CausalAttentionBias(shape=(5, 5)),
]
),
True,
],
)
def test_composite_attention_has_bias(self, bias, expected):
self.assertEqual(bias.has_value(), expected)

def test_bias_and_residual_has_bias(self):
bias = attention_bias.CompositeAttentionBias(
[
attention_bias.CausalAttentionBias(shape=(5, 5)),
attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)),
]
)
bias_and_residual = bias.bias_and_residual(attention_bias.CausalAttentionBias)
self.assertTrue(bias_and_residual.has_value())
bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias)
self.assertTrue(bias_and_residual.has_value())

def test_composite_attention_bias_zero(self):
# Test handling of zero biases.
bias = attention_bias.CompositeAttentionBias(
Expand Down Expand Up @@ -191,7 +250,7 @@ def test_split_subsets(
attention_bias.SegmentIdAttentionBias,
attention_bias.MaskFnAttentionBias,
)
new_bias_list = [b if b.value() is not None else None for b in new_bias_list]
new_bias_list = [b if b.has_value() else None for b in new_bias_list]
expected = [causal, segment_ids, mask, None]
for b1, b2 in jax.util.safe_zip(new_bias_list, expected):
self.assertIs(b1, b2)
Expand Down
6 changes: 3 additions & 3 deletions axlearn/common/flash_attention/tpu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _legacy_tpu_flash_attention(
NotImplementedError: If a custom (non-causal, non-full) mask is specified.
"""
causal = isinstance(mask, CausalAttentionBias)
if not causal and mask.value() is not None:
if not causal and mask.has_value():
bias = apply_attention_logit_biases(mask.value(), bias)

context = pallas_tpu_flash_attention(
Expand Down Expand Up @@ -291,7 +291,7 @@ def check_tpu_splash_attention(
"The public API for SplashAttention that we "
"currently use does not support segment ids."
)
if mask.value() is not None:
if mask.has_value():
assert isinstance(mask, MaskFnAttentionBias)
if target_len != source_len:
raise SplashAttentionUnsupportedError(
Expand All @@ -311,7 +311,7 @@ def _to_splash_mask(
q_seq_shards: int = 1,
) -> splash_attention_mask.Mask:
"""Converts a mask to a splash mask."""
if mask.value() is None:
if not mask.has_value():
return splash_attention_mask.FullMask(mask_shape)
assert isinstance(mask, MaskFnAttentionBias)
if isinstance(mask, CausalAttentionBias):
Expand Down
12 changes: 6 additions & 6 deletions axlearn/common/flash_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
"""Return the segment ids Tensor from the sequence of segment ids attention
biases or None if there are no segment ids.
"""
if segment_ids is None or segment_ids.value() is None:
if not segment_ids.has_value():
return None
if query.shape[1] != key.shape[1]:
raise ValueError(
Expand Down Expand Up @@ -220,8 +220,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
# - explicit_bias is not empty, or
# - query/key/value is in float32.
if (
segment_ids.value() is not None
or explicit_bias.value() is not None
segment_ids.has_value()
or explicit_bias.has_value()
or jnp.float32 in (query.dtype, key.dtype, value.dtype)
or query.shape[1] != key.shape[1]
or dropout_rate != 0.0
Expand All @@ -235,7 +235,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
segment_ids=get_segment_ids(segment_ids),
prng_key=prng_key,
softmax_scale=softmax_scale,
causal=causal.value() is not None,
causal=causal.has_value(),
dropout_rate=dropout_rate,
)
else:
Expand All @@ -246,7 +246,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
value,
bias=explicit_bias.value(),
softmax_scale=softmax_scale,
causal=causal.value() is not None,
causal=causal.has_value(),
dropout_rate=0.0,
)

Expand Down Expand Up @@ -295,7 +295,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
bias=explicit_bias.value(),
segment_ids=get_segment_ids(segment_ids),
prng_key=prng_key,
causal=causal.value() is not None,
causal=causal.has_value(),
softmax_scale=softmax_scale,
dropout_rate=dropout_rate,
)
Expand Down

0 comments on commit feb8357

Please sign in to comment.