Introduce BaseAttentionBias.has_value() #920

Conversation
`BaseAttentionBias.value()` is also used to check whether a bias value exists. However, calling `value()` materializes the actual bias on the CPU, which is quadratic, O(T^2), in sequence length; this is especially painful when debugging long sequences (e.g., 32k). Hence, we introduce a lightweight `has_value()` method. Normally, unused tensors are pruned from the graph during XLA compilation, so using `BaseAttentionBias.value()` merely to check for bias presence relies on XLA pruning; but `value()` remains expensive on the CPU in unit tests and debugging, where `has_value()` saves Python runtime. The new `has_value()` checks whether `value()` returns an actual value by calling `jax.eval_shape`. Since `jax.eval_shape` invokes `value()` through a tracer, it never materializes the actual bias (proposed by John Peebles).
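A minimal sketch of the pattern described above, assuming a nullary `value()`; apart from `jax.eval_shape`, everything here is illustrative rather than the actual AXLearn code:

```python
from typing import Optional

import jax
from jax import numpy as jnp


class BaseAttentionBias:
    """Illustrative base class; the real AXLearn class has more machinery."""

    def value(self) -> Optional[jnp.ndarray]:
        """Returns the bias tensor, or None if this bias contributes nothing."""
        raise NotImplementedError

    def has_value(self) -> bool:
        """Returns True iff value() would return a tensor.

        jax.eval_shape traces value() abstractly, so the O(T^2) bias is
        never materialized on the CPU.
        """
        return jax.eval_shape(self.value) is not None
```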
```python
    Raises:
        ValueError: If the bias has no value.
    """
    if not self.has_value():
```
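To make the guard concrete, here is a hypothetical subclass (continuing the sketch above; not the real AXLearn subclass) showing that the presence check never builds the 32k x 32k tensor:

```python
import jax
from jax import numpy as jnp


class CausalBias(BaseAttentionBias):  # BaseAttentionBias as sketched above
    """Hypothetical bias that materializes an O(T^2) causal mask."""

    def __init__(self, seq_len: int):
        self.seq_len = seq_len

    def value(self) -> jnp.ndarray:
        # Expensive if run eagerly on CPU; free under a tracer.
        mask = jnp.tril(jnp.ones((self.seq_len, self.seq_len), dtype=bool))
        return jnp.where(mask, 0.0, -jnp.inf)


bias = CausalBias(seq_len=32_768)
assert bias.has_value()  # abstract trace only; no 32768 x 32768 tensor is built
print(jax.eval_shape(bias.value))  # ShapeDtypeStruct(shape=(32768, 32768), ...)
```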
I suppose multiple calls to jax.eval_shape are automatically cached?
I couldn't find a caching mechanism in the JAX implementation. May I ask how it relates to this PR? Can I merge it?
```python
@api_boundary
def eval_shape(fun: Callable, *args, **kwargs):
  """Compute the shape/dtype of ``fun`` without any FLOPs.

  This utility function is useful for performing shape inference. Its
  input/output behavior is defined by::
```
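For reference, a tiny standalone demonstration of the docstring's claim (the sizes are just for illustration):

```python
import jax
from jax import numpy as jnp


def big_bias():
    # Eagerly, this would allocate a 32768 x 32768 float32 array (~4 GiB).
    return jnp.zeros((32_768, 32_768), dtype=jnp.float32)


out = jax.eval_shape(big_bias)  # no FLOPs, no allocation
print(out.shape, out.dtype)     # (32768, 32768) float32
```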
Oh, in `eval_shape`, we make two calls to `jax.eval_shape(self.value)` (once in `has_value` and once in the return), so ideally we only trace once. I checked and it seems to be the case, so please feel free to ignore.
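One way such a check can be done (an illustrative probe, not code from the PR): Python side effects run only while tracing, so a counter reveals whether the second `jax.eval_shape` call re-traces.

```python
import jax
from jax import numpy as jnp

n_traces = 0


def value():
    global n_traces
    n_traces += 1  # runs once per trace, not per "execution"
    return jnp.zeros((4, 4))


jax.eval_shape(value)
jax.eval_shape(value)
# Prints 1 if JAX's internal trace cache kicks in, 2 if each call
# re-traces; the answer is version-dependent, which is what this
# thread is probing.
print(n_traces)
```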
Ah, I see. Good point. Thank you for checking!