From 4b0733fac2425a2b0c804c8dd45f443ef6dc1b64 Mon Sep 17 00:00:00 2001
From: Dongseong Hwang <dhwang2@apple.com>
Date: Mon, 6 Jan 2025 14:29:53 -0800
Subject: [PATCH] Introduce `BaseAttentionBias.has_value()`.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

`BaseAttentionBias.value()` is also used to check whether a bias value exists.
However, calling `value()` materializes the actual bias on the CPU, which is
quadratic, O(T^2), in the sequence length and prohibitively slow when debugging
long sequences (e.g., 32k). Hence, we introduce a lightweight `has_value()`
method.

Normally, unused tensors are pruned from the graph during XLA compilation, so
checking for bias presence via `BaseAttentionBias.value()` relies on XLA
pruning to stay cheap. However, `BaseAttentionBias.value()` is still expensive
on the CPU in unit tests and while debugging, so `has_value()` saves Python
runtime there.

The new `has_value()` method checks whether `value()` actually returns a value
by calling `jax.eval_shape`. Since `jax.eval_shape` invokes `value()` with
tracers, it does not materialize the actual bias (proposed by John Peebles).
---
 axlearn/common/attention_bias.py              | 29 ++++++--
 axlearn/common/attention_bias_test.py         | 67 +++++++++++++++++--
 .../common/flash_attention/tpu_attention.py   |  6 +-
 axlearn/common/flash_attention/utils.py       | 12 ++--
 4 files changed, 96 insertions(+), 18 deletions(-)

diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index 9830b077a..589edc408 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -59,6 +59,28 @@ class BaseAttentionBias:
     # If None, do not cast the dtype.
     dtype: Optional[jnp.dtype] = struct.field(kw_only=True, default=None, pytree_node=False)
 
+    @final
+    def eval_shape(self) -> tuple[int, int, int, int]:
+        """Return the shape of the bias tensor.
+
+        Note: this doesn't materialize the value. jax.eval_shape calls value(), but it only does so
+        using tracers.
+
+        Returns:
+            shape: [batch or 1, num_heads or 1, target_len, source_len].
+
+        Raises:
+            ValueError: If the bias has no value.
+        """
+        if not self.has_value():
+            raise ValueError("AttentionBias has no value.")
+        return jax.eval_shape(self.value).shape
+
+    @final
+    def has_value(self) -> bool:
+        """Return whether to the bias has a value."""
+        return jax.eval_shape(self.value) is not None
+
     @final
     def value(self) -> Optional[Tensor]:
         """Return a tensor with the biases or None if there are no biases.
@@ -116,9 +138,6 @@ def _broadcast_value(cls, value: OpT) -> OpT:
             return value[:, None, :, :]
         raise ValueError(f"Invalid attention_logit_biases shape: {value.shape}.")
 
-    def eval_shape(self):
-        return jax.eval_shape(self.value).shape
-
     def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union["BaseAttentionBias", PartitionSpec]:
@@ -233,7 +252,7 @@ def _nonzero(self) -> Sequence[BaseAttentionBias]:
 
         Returned biases are not guaranteed to be nonzero, but are guaranteed to not return None.
         """
-        filt = lambda b: b.value() is not None
+        filt = lambda b: b.has_value()
         return list(filter(filt, self.biases))
 
     def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
@@ -260,7 +279,7 @@ def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
                 send_residual_to = remaining_biases
             else:
                 send_residual_to = residuals
-            if bias_and_residual.residual.value() is not None:
+            if bias_and_residual.residual.has_value():
                 send_residual_to.append(bias_and_residual.residual)
         return BiasAndResidual(
             bias=cls.from_sequence(cls_biases), residual=CompositeAttentionBias(residuals)
diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py
index 358c62911..873852853 100644
--- a/axlearn/common/attention_bias_test.py
+++ b/axlearn/common/attention_bias_test.py
@@ -21,6 +21,15 @@
 
 
 class AttentionBiasTest(test_utils.TestCase):
+    @parameterized.parameters(
+        [attention_bias.ZeroAttentionBias(), False],
+        [attention_bias.CausalAttentionBias(shape=(5, 5)), True],
+        [attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)), True],
+        [attention_bias.TensorAttentionBias.from_tensor(jnp.ones((5, 5))), True],
+    )
+    def test_has_bias(self, bias, expected):
+        self.assertEqual(bias.has_value(), expected)
+
     def test_causal_attention_bias(self):
         bias = attention_bias.CausalAttentionBias(shape=(5, 5))
         chex.assert_trees_all_close(bias.value(), attention_bias.make_causal_biases(5)[None, None])
@@ -45,19 +54,19 @@ def test_base_attention_bias_value(self):
         # pylint: disable=function-redefined
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (1, 1, 5, 7))
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((3, 5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (3, 1, 5, 7))
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((2, 3, 5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (2, 3, 5, 7))
@@ -77,6 +86,56 @@ def test_base_attention_bias_and_residual(self):
             bias.bias_and_residual(int), attention_bias.BiasAndResidual(bias=None, residual=bias)
         )
 
+    @parameterized.parameters(
+        [
+            attention_bias.CompositeAttentionBias(
+                [attention_bias.ZeroAttentionBias(), attention_bias.ZeroAttentionBias()]
+            ),
+            False,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                ]
+            ),
+            True,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                    attention_bias.ZeroAttentionBias(),
+                ]
+            ),
+            True,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.ZeroAttentionBias(),
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                ]
+            ),
+            True,
+        ],
+    )
+    def test_composite_attention_has_bias(self, bias, expected):
+        self.assertEqual(bias.has_value(), expected)
+
+    def test_bias_and_residual_has_bias(self):
+        bias = attention_bias.CompositeAttentionBias(
+            [
+                attention_bias.CausalAttentionBias(shape=(5, 5)),
+                attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)),
+            ]
+        )
+        bias_and_residual = bias.bias_and_residual(attention_bias.CausalAttentionBias)
+        self.assertTrue(bias_and_residual.has_value())
+        bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias)
+        self.assertTrue(bias_and_residual.has_value())
+
     def test_composite_attention_bias_zero(self):
         # Test handling of zero biases.
         bias = attention_bias.CompositeAttentionBias(
@@ -191,7 +250,7 @@ def test_split_subsets(
             attention_bias.SegmentIdAttentionBias,
             attention_bias.MaskFnAttentionBias,
         )
-        new_bias_list = [b if b.value() is not None else None for b in new_bias_list]
+        new_bias_list = [b if b.has_value() else None for b in new_bias_list]
         expected = [causal, segment_ids, mask, None]
         for b1, b2 in jax.util.safe_zip(new_bias_list, expected):
             self.assertIs(b1, b2)
diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index 4de717b4b..750dec74c 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -227,7 +227,7 @@ def _legacy_tpu_flash_attention(
         NotImplementedError: If a custom (non-causal, non-full) mask is specified.
     """
     causal = isinstance(mask, CausalAttentionBias)
-    if not causal and mask.value() is not None:
+    if not causal and mask.has_value():
         bias = apply_attention_logit_biases(mask.value(), bias)
 
     context = pallas_tpu_flash_attention(
@@ -291,7 +291,7 @@ def check_tpu_splash_attention(
             "The public API for SplashAttention that we "
             "currently use does not support segment ids."
         )
-    if mask.value() is not None:
+    if mask.has_value():
         assert isinstance(mask, MaskFnAttentionBias)
         if target_len != source_len:
             raise SplashAttentionUnsupportedError(
@@ -311,7 +311,7 @@ def _to_splash_mask(
     q_seq_shards: int = 1,
 ) -> splash_attention_mask.Mask:
     """Converts a mask to a splash mask."""
-    if mask.value() is None:
+    if not mask.has_value():
         return splash_attention_mask.FullMask(mask_shape)
     assert isinstance(mask, MaskFnAttentionBias)
     if isinstance(mask, CausalAttentionBias):
diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index 4115da579..2b4546969 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -162,7 +162,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             """Return the segment ids Tensor from the sequence of segment ids attention
             biases or None if there are no segment ids.
             """
-            if segment_ids is None or segment_ids.value() is None:
+            if not segment_ids.has_value():
                 return None
             if query.shape[1] != key.shape[1]:
                 raise ValueError(
@@ -220,8 +220,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             # - explicit_bias is not empty, or
             # - query/key/value is in float32.
             if (
-                segment_ids.value() is not None
-                or explicit_bias.value() is not None
+                segment_ids.has_value()
+                or explicit_bias.has_value()
                 or jnp.float32 in (query.dtype, key.dtype, value.dtype)
                 or query.shape[1] != key.shape[1]
                 or dropout_rate != 0.0
@@ -235,7 +235,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                     segment_ids=get_segment_ids(segment_ids),
                     prng_key=prng_key,
                     softmax_scale=softmax_scale,
-                    causal=causal.value() is not None,
+                    causal=causal.has_value(),
                     dropout_rate=dropout_rate,
                 )
             else:
@@ -246,7 +246,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                     value,
                     bias=explicit_bias.value(),
                     softmax_scale=softmax_scale,
-                    causal=causal.value() is not None,
+                    causal=causal.has_value(),
                     dropout_rate=0.0,
                 )
 
@@ -295,7 +295,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                 bias=explicit_bias.value(),
                 segment_ids=get_segment_ids(segment_ids),
                 prng_key=prng_key,
-                causal=causal.value() is not None,
+                causal=causal.has_value(),
                 softmax_scale=softmax_scale,
                 dropout_rate=dropout_rate,
             )