[generation] Support cache-cropping methods #35591
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
As it solves the issue for #35168 with no major code changes on the KVPress side, LGTM.
My only concern is that it is slightly breaking if users were somehow relying on this right-side slicing feature before.
@ArthurZucker, this draft PR makes a small modeling change to enable an advanced cache feature, as used e.g. in Nvidia's KVPress library -- see the PR header for full details. If you agree with these changes, I will propagate them to other models before requesting a final review :)
Hey! Sorry but I don't think it makes sense to do this!
It's a very niche case, and is not transparent for people who just want to see how llama works!
@ArthurZucker makes sense. In fact, downstream libraries can implement their own custom attention implementation and crop the mask as they wish there. We would be able to unblock this cool use case in other libs if we merge the non-modeling change, i.e. make the … EDIT: chatted on Slack, moving forward with the non-model changes
Force-pushed from 3a58fca to cf263b8
@@ -1037,7 +1037,7 @@ def _update_causal_mask(
         target_length = (
             attention_mask.shape[-1]
             if isinstance(attention_mask, torch.Tensor)
-            else past_seen_tokens + sequence_length + 1
+            else past_seen_tokens + sequence_length
This comment applies to all models: the `+1` is not needed, since we discard its corresponding data in the attn forward pass anyway (e.g. here).
By removing it, we should expect a minimal improvement w.r.t. memory and inference time. We also ensure that the causal mask has exactly the length of the (cached) input, which is a requirement for the downstream use case this PR wants to enable.
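For context, here is a minimal sketch of the slicing step referenced above; the names `causal_mask` and `key_states` and the tensor layout are illustrative, not the verbatim modeling code:

```python
import torch


def slice_mask_to_kv_length(causal_mask: torch.Tensor, key_states: torch.Tensor) -> torch.Tensor:
    # causal_mask: [batch, 1, query_len, target_len]; key_states: [batch, heads, kv_len, head_dim].
    # The attention forward only consumes the first kv_len mask columns, so the extra
    # column produced by the old `+ 1` was always discarded at this point.
    return causal_mask[:, :, :, : key_states.shape[-2]]
```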
Btw, I am working with Mllama now and remembered it has a weird cache where cross-attention layers have a very large length. Can we make sure mllama generation tests don't fail after this change?
@zucchini-nlp correct, it does change. I'm adding a test to catch more models like this. EDIT: mllama is the only model that fails this test :) btw, …
That is too much for the … Currently I am removing the skip from many generation tests that checked … For other models: on the VLM side that's the only one, and maybe ImageGPT is a bit peculiar. On the LLM side, all I can think of is the Hybrid Cache models, where the sliding-window length is fixed and can be lower/higher than the static layers.
Force-pushed from 0a27d11 to d17ee05
What does this PR do?
Adds the requirements to solve #35168
[EDITED after the discussion ending on this comment]
New dynamic cache compression methods may crop the cache differently at each layer. Doing so breaks our implicit assumption about the shape of the dynamic cache, namely that all layers have the same length. As we can see in the issue tagged above, cropping layers to different lengths currently raises an exception.
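As an illustration only, here is a sketch of the failing pattern. It assumes a tiny test checkpoint and the `DynamicCache.key_cache` / `value_cache` per-layer tensor lists, and the per-layer lengths are made up (a real compression method such as KVPress would choose them):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Tiny checkpoint used purely for illustration.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
cache = DynamicCache()
model(**inputs, past_key_values=cache, use_cache=True)  # prefill populates the cache

# Crop each layer to a *different* length, mimicking a per-layer compression decision.
for i in range(len(cache.key_cache)):
    keep = max(1, cache.key_cache[i].shape[-2] - i)
    cache.key_cache[i] = cache.key_cache[i][..., :keep, :]
    cache.value_cache[i] = cache.value_cache[i][..., :keep, :]

# Resuming generation with this unevenly cropped cache is what raised the exception,
# since the causal mask assumed every layer shares the same cached length.
```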
We can enable it with small code changes in downstream libraries if we prepare the causal mask based on the maximum sequence length seen across all layers in the cache. A downstream library would then have to implement a custom attention forward pass to left-crop the attention mask accordingly (e.g. using the length of the `key`); see the sketch after the test list below.

The following tests were run, with no regressions compared to `main`:

RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv
RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv
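To make the downstream side concrete, below is a hedged sketch of the mask cropping that such a custom attention forward could apply before computing attention weights. This is not KVPress code; the helper name and the right-aligned reading of "left-crop" are assumptions:

```python
from typing import Optional

import torch


def left_crop_mask(attention_mask: Optional[torch.Tensor], key_states: torch.Tensor) -> Optional[torch.Tensor]:
    """Keep only the rightmost mask columns matching this layer's (possibly compressed) key length.

    attention_mask: [batch, 1, query_len, max_cache_len], built from the longest layer in the cache.
    key_states:     [batch, num_kv_heads, kv_len, head_dim] for the current layer.
    """
    if attention_mask is None:
        return None
    kv_len = key_states.shape[-2]
    # Drop the leftmost columns so the mask width matches this layer's kv_len,
    # instead of the default right-side slice attention_mask[..., :kv_len].
    return attention_mask[:, :, :, -kv_len:]
```

A downstream library would call something like this inside its own attention implementation (e.g. one registered via `ALL_ATTENTION_FUNCTIONS`), so that each layer sees a mask matching its own cropped cache length.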