Introduce BaseAttentionBias.has_value(). #920

Merged
merged 1 commit on Jan 13, 2025
29 changes: 24 additions & 5 deletions axlearn/common/attention_bias.py
@@ -59,6 +59,28 @@ class BaseAttentionBias:
# If None, do not cast the dtype.
dtype: Optional[jnp.dtype] = struct.field(kw_only=True, default=None, pytree_node=False)

@final
def eval_shape(self) -> tuple[int, int, int, int]:
"""Return the shape of the bias tensor.

Note: this doesn't materialize the value. jax.eval_shape calls value(), but it only does so
using tracers.

Returns:
shape: [batch or 1, num_heads or 1, target_len, source_len].

Raises:
ValueError: If the bias has no value.
"""
if not self.has_value():
Contributor:

I suppose multiple calls to jax.eval_shape are automatically cached?

Contributor (Author):

I couldn't find a caching mechanism in the jax implementation. May I ask how this is related to this PR? Can I merge it?

@api_boundary
def eval_shape(fun: Callable, *args, **kwargs):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by::

Contributor:

Oh, in eval_shape, we make two calls to jax.eval_shape(self.value) (once in has_value and once in return) -- so ideally we only trace once. I checked and it seems to be the case, so please feel free to ignore.

Contributor (Author):

Ah, I see. Good point. Thank you for checking!
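
A minimal sketch of the behavior discussed in this thread (illustrative only, not part of the PR): jax.eval_shape traces the callable with abstract values, so the bias below is staged out rather than computed, and only its shape/dtype comes back.

import jax
import jax.numpy as jnp

def value():
    # Stands in for BaseAttentionBias.value(); the 4096 x 4096 bias is never
    # allocated or computed, because eval_shape only traces this function.
    return jnp.ones((1, 1, 4096, 4096))

out = jax.eval_shape(value)  # No FLOPs, no device allocation.
print(out)                   # ShapeDtypeStruct(shape=(1, 1, 4096, 4096), dtype=float32)
print(out.shape)             # (1, 1, 4096, 4096)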

raise ValueError("AttentionBias has no value.")
return jax.eval_shape(self.value).shape

@final
def has_value(self) -> bool:
"""Return whether to the bias has a value."""
return jax.eval_shape(self.value) is not None

@final
def value(self) -> Optional[Tensor]:
"""Return a tensor with the biases or None if there are no biases.
@@ -116,9 +138,6 @@ def _broadcast_value(cls, value: OpT) -> OpT:
return value[:, None, :, :]
raise ValueError(f"Invalid attention_logit_biases shape: {value.shape}.")

def eval_shape(self):
return jax.eval_shape(self.value).shape

def partition_spec(
self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
) -> Union["BaseAttentionBias", PartitionSpec]:
@@ -233,7 +252,7 @@ def _nonzero(self) -> Sequence[BaseAttentionBias]:

Returned biases are not guaranteed to be nonzero, but are guaranteed to not return None.
"""
filt = lambda b: b.value() is not None
filt = lambda b: b.has_value()
return list(filter(filt, self.biases))

def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
@@ -260,7 +279,7 @@ def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
send_residual_to = remaining_biases
else:
send_residual_to = residuals
if bias_and_residual.residual.value() is not None:
if bias_and_residual.residual.has_value():
send_residual_to.append(bias_and_residual.residual)
return BiasAndResidual(
bias=cls.from_sequence(cls_biases), residual=CompositeAttentionBias(residuals)
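A hedged usage sketch of the new surface (not part of the diff): the module path and constructors follow the tests below, and the expected results are inferred from the change above rather than guaranteed by it. It also shows the pytree-structure detail that has_value() relies on: jax.eval_shape of a function returning None is None, while a function returning a Tensor yields a ShapeDtypeStruct.

import jax
import jax.numpy as jnp
from axlearn.common import attention_bias

# The mechanism behind has_value(): eval_shape preserves the output pytree structure.
print(jax.eval_shape(lambda: None))               # None
print(jax.eval_shape(lambda: jnp.zeros((5, 5))))  # ShapeDtypeStruct(shape=(5, 5), ...)

zero = attention_bias.ZeroAttentionBias()
causal = attention_bias.CausalAttentionBias(shape=(5, 5))

assert not zero.has_value()                 # No bias tensor behind this bias.
assert causal.has_value()                   # A bias tensor exists.
assert causal.eval_shape() == (1, 1, 5, 5)  # Shape inference only; nothing is materialized.

try:
    zero.eval_shape()
except ValueError:
    pass  # eval_shape() raises for a bias without a value.
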
67 changes: 63 additions & 4 deletions axlearn/common/attention_bias_test.py
@@ -21,6 +21,15 @@


class AttentionBiasTest(test_utils.TestCase):
@parameterized.parameters(
[attention_bias.ZeroAttentionBias(), False],
[attention_bias.CausalAttentionBias(shape=(5, 5)), True],
[attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)), True],
[attention_bias.TensorAttentionBias.from_tensor(jnp.ones((5, 5))), True],
)
def test_has_bias(self, bias, expected):
self.assertEqual(bias.has_value(), expected)

def test_causal_attention_bias(self):
bias = attention_bias.CausalAttentionBias(shape=(5, 5))
chex.assert_trees_all_close(bias.value(), attention_bias.make_causal_biases(5)[None, None])
@@ -45,19 +54,19 @@ def test_base_attention_bias_value(self):
# pylint: disable=function-redefined

class TestAttentionBias(attention_bias.BaseAttentionBias):
def _value(self) -> Optional[Tensor]:
def _value(self) -> Tensor:
return jnp.ones((5, 7))

self.assertEqual(TestAttentionBias().value().shape, (1, 1, 5, 7))

class TestAttentionBias(attention_bias.BaseAttentionBias):
def _value(self) -> Optional[Tensor]:
def _value(self) -> Tensor:
return jnp.ones((3, 5, 7))

self.assertEqual(TestAttentionBias().value().shape, (3, 1, 5, 7))

class TestAttentionBias(attention_bias.BaseAttentionBias):
def _value(self) -> Optional[Tensor]:
def _value(self) -> Tensor:
return jnp.ones((2, 3, 5, 7))

self.assertEqual(TestAttentionBias().value().shape, (2, 3, 5, 7))
@@ -77,6 +86,56 @@ def test_base_attention_bias_and_residual(self):
bias.bias_and_residual(int), attention_bias.BiasAndResidual(bias=None, residual=bias)
)

@parameterized.parameters(
[
attention_bias.CompositeAttentionBias(
[attention_bias.ZeroAttentionBias(), attention_bias.ZeroAttentionBias()]
),
False,
],
[
attention_bias.CompositeAttentionBias(
[
attention_bias.CausalAttentionBias(shape=(5, 5)),
attention_bias.CausalAttentionBias(shape=(5, 5)),
]
),
True,
],
[
attention_bias.CompositeAttentionBias(
[
attention_bias.CausalAttentionBias(shape=(5, 5)),
attention_bias.ZeroAttentionBias(),
]
),
True,
],
[
attention_bias.CompositeAttentionBias(
[
attention_bias.ZeroAttentionBias(),
attention_bias.CausalAttentionBias(shape=(5, 5)),
]
),
True,
],
)
def test_composite_attention_has_bias(self, bias, expected):
self.assertEqual(bias.has_value(), expected)

def test_bias_and_residual_has_bias(self):
bias = attention_bias.CompositeAttentionBias(
[
attention_bias.CausalAttentionBias(shape=(5, 5)),
attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)),
]
)
bias_and_residual = bias.bias_and_residual(attention_bias.CausalAttentionBias)
self.assertTrue(bias_and_residual.has_value())
bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias)
self.assertTrue(bias_and_residual.has_value())

def test_composite_attention_bias_zero(self):
# Test handling of zero biases.
bias = attention_bias.CompositeAttentionBias(
@@ -191,7 +250,7 @@ def test_split_subsets(
attention_bias.SegmentIdAttentionBias,
attention_bias.MaskFnAttentionBias,
)
new_bias_list = [b if b.value() is not None else None for b in new_bias_list]
new_bias_list = [b if b.has_value() else None for b in new_bias_list]
expected = [causal, segment_ids, mask, None]
for b1, b2 in jax.util.safe_zip(new_bias_list, expected):
self.assertIs(b1, b2)
6 changes: 3 additions & 3 deletions axlearn/common/flash_attention/tpu_attention.py
@@ -227,7 +227,7 @@ def _legacy_tpu_flash_attention(
NotImplementedError: If a custom (non-causal, non-full) mask is specified.
"""
causal = isinstance(mask, CausalAttentionBias)
if not causal and mask.value() is not None:
if not causal and mask.has_value():
bias = apply_attention_logit_biases(mask.value(), bias)

context = pallas_tpu_flash_attention(
@@ -291,7 +291,7 @@ def check_tpu_splash_attention(
"The public API for SplashAttention that we "
"currently use does not support segment ids."
)
if mask.value() is not None:
if mask.has_value():
assert isinstance(mask, MaskFnAttentionBias)
if target_len != source_len:
raise SplashAttentionUnsupportedError(
@@ -311,7 +311,7 @@ def _to_splash_mask(
q_seq_shards: int = 1,
) -> splash_attention_mask.Mask:
"""Converts a mask to a splash mask."""
if mask.value() is None:
if not mask.has_value():
return splash_attention_mask.FullMask(mask_shape)
assert isinstance(mask, MaskFnAttentionBias)
if isinstance(mask, CausalAttentionBias):
12 changes: 6 additions & 6 deletions axlearn/common/flash_attention/utils.py
@@ -162,7 +162,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
"""Return the segment ids Tensor from the sequence of segment ids attention
biases or None if there are no segment ids.
"""
if segment_ids is None or segment_ids.value() is None:
if not segment_ids.has_value():
return None
if query.shape[1] != key.shape[1]:
raise ValueError(
@@ -220,8 +220,8 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
# - explicit_bias is not empty, or
# - query/key/value is in float32.
if (
segment_ids.value() is not None
or explicit_bias.value() is not None
segment_ids.has_value()
or explicit_bias.has_value()
or jnp.float32 in (query.dtype, key.dtype, value.dtype)
or query.shape[1] != key.shape[1]
or dropout_rate != 0.0
@@ -235,7 +235,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
segment_ids=get_segment_ids(segment_ids),
prng_key=prng_key,
softmax_scale=softmax_scale,
causal=causal.value() is not None,
causal=causal.has_value(),
dropout_rate=dropout_rate,
)
else:
@@ -246,7 +246,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
value,
bias=explicit_bias.value(),
softmax_scale=softmax_scale,
causal=causal.value() is not None,
causal=causal.has_value(),
dropout_rate=0.0,
)

@@ -295,7 +295,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
bias=explicit_bias.value(),
segment_ids=get_segment_ids(segment_ids),
prng_key=prng_key,
causal=causal.value() is not None,
causal=causal.has_value(),
softmax_scale=softmax_scale,
dropout_rate=dropout_rate,
)