Skip to content

Commit

Permalink
Cleanup MaskFnAttentionBias.target_positions. (#895)
Browse files Browse the repository at this point in the history
  • Loading branch information
apghml authored Dec 17, 2024
1 parent 4415bd2 commit 3ae8f9f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
33 changes: 25 additions & 8 deletions axlearn/common/attention_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,30 +473,47 @@ class MaskFnAttentionBias(BoolAttentionBias):
shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
# The positions in the query sequence that the mask should be computed for.
# I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
# `target_positions[batch, num_heads i]` may attend to.
# If None, set `target_positions[batch, num_heads, i] = i`.
# Shape: [batch].
# `target_positions[batch, i]` may attend to.
# If None, set `target_positions[batch, i] = i`.
# Shape: [batch] or [batch, target_len]`.
# This is typically used during decoding to specify the locations in the sequence being
# being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
# entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
# The motivation for supporting such shapes is for use cases where time_step in transformers
# is not necessarily contiguous. E.g., speculative decoding, non-contiguous prompts,
# various papers that need it.
target_positions: Optional[Tensor] = None

def _bool_value(self) -> Optional[Tensor]:
"""Return a tensor with the boolean values from `self.mask` before they have been converted
to biases.
Shape:
- If `target_positions` is None: [target_len, source_len]
- Else: [batch, target_len, source_len].
Shape: [batch, target_len, source_len].
Raises:
NotImplementedError. If `target_positions.ndim not in [1,2]`.
"""
target_positions, source_positions = jnp.indices(self.shape, sparse=True)
# Shape: [batch, target_len, source_len].
target_positions, source_positions = target_positions[None], source_positions[None]
if self.target_positions is not None:
target_positions = self.target_positions
if target_positions.ndim not in [1, 2]:
raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
if target_positions.ndim == 1:
# Shape: [batch, target_len].
# pylint: disable-next=unsubscriptable-object
target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
while target_positions.ndim < 3:
target_positions = target_positions[..., None]
elif target_positions.ndim == 2:
shape_with_batch_dim = (1, *self.shape)
# Raise an exception if shapes aren't compatible. We don't use the output.
jnp.broadcast_shapes(
(target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
)
else:
raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
target_positions = target_positions[..., None] # Shape: [batch, target_len, 1].

return self.mask(target_positions, source_positions) # pylint: disable=not-callable

@classmethod
Expand Down
20 changes: 20 additions & 0 deletions axlearn/common/attention_bias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,26 @@ def test_mask_fn_attention_bias(self):
expected = attention_bias.bool_to_bias(expected)[:, None, :]
self.assertNestedEqual(bias.value(), expected)

def test_mask_fn_attention_bias_target_positions_ndim(self):
"""Tests mask_fn_attention_bias` when `target_positions.ndim == 2."""
bias = attention_bias.MaskFnAttentionBias(
mask=attention_bias.causal_mask,
shape=(5, 5),
target_positions=jnp.asarray([[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]]),
)
expected = jnp.asarray(
[
[
attention_bias.causal_mask(*jnp.indices([5, 5])),
],
[
attention_bias.causal_mask(*jnp.indices([5, 5]))[::-1, :],
],
],
dtype=bool,
)
self.assertNestedEqual(bias.bool_value(), expected)

def test_bool_tensor_attention_bias(self):
bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
self.assertNestedEqual(
Expand Down

0 comments on commit 3ae8f9f

Please sign in to comment.