Introduce BaseAttentionBias.has_value(). #920

Merged
merged 1 commit into apple:main on Jan 13, 2025

Conversation

ds-hwang (Contributor)

`BaseAttentionBias.value()` is also used to check whether a bias value exists.
However, calling `value()` materializes the actual bias on the CPU, which costs
O(T^2) in the sequence length T; this is especially painful when debugging long
sequences (e.g., 32k). Hence, we introduce a lightweight `has_value()` method.

Normally, unused tensors are pruned from the graph during XLA compilation, and
using `BaseAttentionBias.value()` to check for bias presence relies on XLA
pruning. But `BaseAttentionBias.value()` is still expensive on the CPU for
unittests and debugging, so `has_value()` saves Python runtime.

The new `has_value()` method checks whether `value()` returns anything by
calling `jax.eval_shape`. Since `jax.eval_shape` invokes `value()` through a
tracer, it never materializes the actual value (approach proposed by John Peebles).
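The described approach can be sketched as follows. This is an illustrative reconstruction, not the code from this PR: the class names and method bodies below are assumptions, and the real `BaseAttentionBias` hierarchy is more involved.

```python
import jax
import jax.numpy as jnp


class DenseBias:
    """Illustrative stand-in for a bias subclass whose value() is O(T^2)."""

    def __init__(self, seq_len: int):
        self.seq_len = seq_len

    def value(self):
        # Materializes a [T, T] array if actually executed -- expensive on
        # the CPU for long sequences (e.g., T = 32k).
        t = self.seq_len
        return jnp.zeros((t, t))

    def has_value(self) -> bool:
        # jax.eval_shape runs value() under a tracer, computing only shapes
        # and dtypes, so the O(T^2) array is never materialized.
        return jax.eval_shape(self.value) is not None


class EmptyBias(DenseBias):
    """Illustrative stand-in for a bias with no value."""

    def value(self):
        return None
```

With this sketch, `DenseBias(32768).has_value()` returns `True` without ever allocating a 32k x 32k array, while `EmptyBias(1).has_value()` returns `False`.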
@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners January 13, 2025 20:23
@ds-hwang (Contributor, Author)

@ruomingp, @markblee Could you take a look? From 972

@ds-hwang ds-hwang enabled auto-merge January 13, 2025 20:25
```python
        Raises:
            ValueError: If the bias has no value.
        """
        if not self.has_value():
```
Contributor

I suppose multiple calls to `jax.eval_shape` are automatically cached?

Contributor (Author)

I couldn't find a caching mechanism in the JAX implementation. Would you mind clarifying how it relates to this PR? Can I merge it?

```python
@api_boundary
def eval_shape(fun: Callable, *args, **kwargs):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by::
```
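For context, here is a minimal standalone use of `jax.eval_shape` (not from this PR) showing shape inference without executing any FLOPs:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Would cost roughly 2 * 4096^3 FLOPs if actually executed.
    return x @ x

# ShapeDtypeStruct stands in for a real array: shape and dtype, no data.
x = jax.ShapeDtypeStruct((4096, 4096), jnp.float32)
out = jax.eval_shape(f, x)
print(out.shape, out.dtype)  # (4096, 4096) float32
```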

Contributor

Oh, in `eval_shape`, we make two calls to `jax.eval_shape(self.value)` (once in `has_value` and once in the return), so ideally we only trace once. I checked and it seems to be the case, so please feel free to ignore.
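Whether repeated `jax.eval_shape` calls re-trace the same callable can be checked empirically with a trace counter. This sketch is illustrative and deliberately avoids assuming a particular caching behavior:

```python
import jax
import jax.numpy as jnp

calls = {"n": 0}

def value():
    # Python side effects run at trace time, so this counts traces.
    calls["n"] += 1
    return jnp.zeros((4, 4))

first = jax.eval_shape(value)
second = jax.eval_shape(value)
# Both results carry the same shape/dtype; calls["n"] reveals whether the
# second call re-traced value() or reused a cached trace.
print(first.shape, second.shape, calls["n"])
```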

Contributor (Author)

Ah, I see. Good point. Thank you for checking!

@ds-hwang ds-hwang added this pull request to the merge queue Jan 13, 2025
github-merge-queue bot pushed a commit that referenced this pull request Jan 13, 2025
@ds-hwang ds-hwang removed this pull request from the merge queue due to a manual request Jan 13, 2025
@ds-hwang ds-hwang added this pull request to the merge queue Jan 13, 2025
Merged via the queue into apple:main with commit feb8357 Jan 13, 2025
6 checks passed
@ds-hwang ds-hwang deleted the bias branch January 13, 2025 22:28