diff --git a/axlearn/common/attention_bias.py b/axlearn/common/attention_bias.py
index 9830b077a..589edc408 100644
--- a/axlearn/common/attention_bias.py
+++ b/axlearn/common/attention_bias.py
@@ -59,6 +59,28 @@ class BaseAttentionBias:
     # If None, do not cast the dtype.
     dtype: Optional[jnp.dtype] = struct.field(kw_only=True, default=None, pytree_node=False)
 
+    @final
+    def eval_shape(self) -> tuple[int, int, int, int]:
+        """Return the shape of the bias tensor.
+
+        Note: this doesn't materialize the value. jax.eval_shape calls value(), but it only does so
+        using tracers.
+
+        Returns:
+            shape: [batch or 1, num_heads or 1, target_len, source_len].
+
+        Raises:
+            ValueError: If the bias has no value.
+        """
+        if not self.has_value():
+            raise ValueError("AttentionBias has no value.")
+        return jax.eval_shape(self.value).shape
+
+    @final
+    def has_value(self) -> bool:
+        """Return whether the bias has a value."""
+        return jax.eval_shape(self.value) is not None
+
     @final
     def value(self) -> Optional[Tensor]:
         """Return a tensor with the biases or None if there are no biases.
@@ -116,9 +138,6 @@ def _broadcast_value(cls, value: OpT) -> OpT:
             return value[:, None, :, :]
         raise ValueError(f"Invalid attention_logit_biases shape: {value.shape}.")
 
-    def eval_shape(self):
-        return jax.eval_shape(self.value).shape
-
     def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union["BaseAttentionBias", PartitionSpec]:
@@ -233,7 +252,7 @@ def _nonzero(self) -> Sequence[BaseAttentionBias]:
 
         Returned biases are not guaranteed to be nonzero, but are guaranteed to not return None.
         """
-        filt = lambda b: b.value() is not None
+        filt = lambda b: b.has_value()
         return list(filter(filt, self.biases))
 
     def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
@@ -260,7 +279,7 @@ def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
                 send_residual_to = remaining_biases
             else:
                 send_residual_to = residuals
-            if bias_and_residual.residual.value() is not None:
+            if bias_and_residual.residual.has_value():
                 send_residual_to.append(bias_and_residual.residual)
         return BiasAndResidual(
             bias=cls.from_sequence(cls_biases), residual=CompositeAttentionBias(residuals)
diff --git a/axlearn/common/attention_bias_test.py b/axlearn/common/attention_bias_test.py
index 358c62911..873852853 100644
--- a/axlearn/common/attention_bias_test.py
+++ b/axlearn/common/attention_bias_test.py
@@ -21,6 +21,15 @@
 
 
 class AttentionBiasTest(test_utils.TestCase):
+    @parameterized.parameters(
+        [attention_bias.ZeroAttentionBias(), False],
+        [attention_bias.CausalAttentionBias(shape=(5, 5)), True],
+        [attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)), True],
+        [attention_bias.TensorAttentionBias.from_tensor(jnp.ones((5, 5))), True],
+    )
+    def test_has_bias(self, bias, expected):
+        self.assertEqual(bias.has_value(), expected)
+
     def test_causal_attention_bias(self):
         bias = attention_bias.CausalAttentionBias(shape=(5, 5))
         chex.assert_trees_all_close(bias.value(), attention_bias.make_causal_biases(5)[None, None])
@@ -45,19 +54,19 @@ def test_base_attention_bias_value(self):
         # pylint: disable=function-redefined
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (1, 1, 5, 7))
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((3, 5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (3, 1, 5, 7))
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((2, 3, 5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (2, 3, 5, 7))
@@ -77,6 +86,56 @@ def test_base_attention_bias_and_residual(self):
             bias.bias_and_residual(int), attention_bias.BiasAndResidual(bias=None, residual=bias)
         )
 
+    @parameterized.parameters(
+        [
+            attention_bias.CompositeAttentionBias(
+                [attention_bias.ZeroAttentionBias(), attention_bias.ZeroAttentionBias()]
+            ),
+            False,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                ]
+            ),
+            True,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                    attention_bias.ZeroAttentionBias(),
+                ]
+            ),
+            True,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.ZeroAttentionBias(),
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                ]
+            ),
+            True,
+        ],
+    )
+    def test_composite_attention_has_bias(self, bias, expected):
+        self.assertEqual(bias.has_value(), expected)
+
+    def test_bias_and_residual_has_bias(self):
+        bias = attention_bias.CompositeAttentionBias(
+            [
+                attention_bias.CausalAttentionBias(shape=(5, 5)),
+                attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)),
+            ]
+        )
+        bias_and_residual = bias.bias_and_residual(attention_bias.CausalAttentionBias)
+        self.assertTrue(bias_and_residual.has_value())
+        bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias)
+        self.assertTrue(bias_and_residual.has_value())
+
     def test_composite_attention_bias_zero(self):
         # Test handling of zero biases.
         bias = attention_bias.CompositeAttentionBias(
@@ -191,7 +250,7 @@ def test_split_subsets(
             attention_bias.SegmentIdAttentionBias,
             attention_bias.MaskFnAttentionBias,
         )
-        new_bias_list = [b if b.value() is not None else None for b in new_bias_list]
+        new_bias_list = [b if b.has_value() else None for b in new_bias_list]
         expected = [causal, segment_ids, mask, None]
         for b1, b2 in jax.util.safe_zip(new_bias_list, expected):
             self.assertIs(b1, b2)
diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index 4de717b4b..750dec74c 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -227,7 +227,7 @@ def _legacy_tpu_flash_attention(
         NotImplementedError: If a custom (non-causal, non-full) mask is specified.
     """
     causal = isinstance(mask, CausalAttentionBias)
-    if not causal and mask.value() is not None:
+    if not causal and mask.has_value():
         bias = apply_attention_logit_biases(mask.value(), bias)
 
     context = pallas_tpu_flash_attention(
@@ -291,7 +291,7 @@ def check_tpu_splash_attention(
             "The public API for SplashAttention that we "
             "currently use does not support segment ids."
         )
-    if mask.value() is not None:
+    if mask.has_value():
         assert isinstance(mask, MaskFnAttentionBias)
         if target_len != source_len:
             raise SplashAttentionUnsupportedError(
@@ -311,7 +311,7 @@ def _to_splash_mask(
     q_seq_shards: int = 1,
 ) -> splash_attention_mask.Mask:
     """Converts a mask to a splash mask."""
-    if mask.value() is None:
+    if not mask.has_value():
         return splash_attention_mask.FullMask(mask_shape)
     assert isinstance(mask, MaskFnAttentionBias)
     if isinstance(mask, CausalAttentionBias):
diff --git a/axlearn/common/flash_attention/utils.py b/axlearn/common/flash_attention/utils.py
index 4115da579..2b4546969 100644
--- a/axlearn/common/flash_attention/utils.py
+++ b/axlearn/common/flash_attention/utils.py
@@ -162,7 +162,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
         """Return the segment ids Tensor from the sequence of segment ids attention biases or None
         if there are no segment ids.
        """
-        if segment_ids is None or segment_ids.value() is None:
+        if not segment_ids.has_value():
            return None
         if query.shape[1] != key.shape[1]:
             raise ValueError(
@@ -220,8 +220,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
             # - explicit_bias is not empty, or
             # - query/key/value is in float32.
             if (
-                segment_ids.value() is not None
-                or explicit_bias.value() is not None
+                segment_ids.has_value()
+                or explicit_bias.has_value()
                 or jnp.float32 in (query.dtype, key.dtype, value.dtype)
                 or query.shape[1] != key.shape[1]
                 or dropout_rate != 0.0
@@ -235,7 +235,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                     segment_ids=get_segment_ids(segment_ids),
                     prng_key=prng_key,
                     softmax_scale=softmax_scale,
-                    causal=causal.value() is not None,
+                    causal=causal.has_value(),
                     dropout_rate=dropout_rate,
                 )
             else:
@@ -246,7 +246,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                     value,
                     bias=explicit_bias.value(),
                     softmax_scale=softmax_scale,
-                    causal=causal.value() is not None,
+                    causal=causal.has_value(),
                     dropout_rate=0.0,
                 )
 
@@ -295,7 +295,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
                 bias=explicit_bias.value(),
                 segment_ids=get_segment_ids(segment_ids),
                 prng_key=prng_key,
-                causal=causal.value() is not None,
+                causal=causal.has_value(),
                 softmax_scale=softmax_scale,
                 dropout_rate=dropout_rate,
             )
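
A minimal usage sketch of the new helpers (illustrative only, not part of the patch; it assumes the classes are importable from axlearn.common.attention_bias as in the files modified above):

from axlearn.common.attention_bias import (
    CausalAttentionBias,
    CompositeAttentionBias,
    ZeroAttentionBias,
)

causal = CausalAttentionBias(shape=(8, 8))
zero = ZeroAttentionBias()

# has_value() checks for a bias via jax.eval_shape, i.e. without materializing value().
assert causal.has_value()
assert not zero.has_value()

# eval_shape() returns [batch or 1, num_heads or 1, target_len, source_len], also via tracers.
assert causal.eval_shape() == (1, 1, 8, 8)

# A composite bias reports a value as soon as any sub-bias has one.
assert CompositeAttentionBias([zero, causal]).has_value()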