[generation] Support cache-cropping methods #35591

Open · wants to merge 11 commits into main

Conversation

gante (Member) commented Jan 9, 2025

What does this PR do?

Adds the requirements to solve #35168

[EDITED after the discussion ending on this comment]

New dynamic cache-compression methods may crop the cache differently at each layer. Doing so breaks our implicit assumption about the shape of the dynamic cache, namely that all layers have the same length. As we can see in the issue tagged above, we currently get an exception when that happens.

We can enable this with small code changes in downstream libraries if we prepare the causal mask based on the maximum sequence length seen across all layers in the cache. A downstream library would then have to implement a custom attention forward pass that left-crops the attention mask accordingly (e.g. using the length of the key), as in the sketch below.
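
A minimal sketch of what such a downstream attention forward could look like (the function name and signature are hypothetical, not the transformers API; it only illustrates the left-cropping idea):

```python
# Hypothetical downstream attention forward: the prepared causal mask covers the
# longest layer in the cache, so each layer left-crops it to its own key length.
import math
import torch

def cropped_attention_forward(query, key, value, attention_mask=None):
    # query: (batch, heads, q_len, head_dim); key/value may hold fewer positions
    # in this layer than in other layers if the cache was compressed unevenly.
    kv_len = key.shape[-2]
    scaling = 1.0 / math.sqrt(query.shape[-1])
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Keep only the last `kv_len` key columns of the (max-length) additive mask.
        attn_weights = attn_weights + attention_mask[:, :, :, -kv_len:]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value)
```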


The following tests were run, with no regressions compared to main:

  • RUN_SLOW=1 py.test tests/utils/test_cache_utils.py -vv
  • RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv


zucchini-nlp (Member) left a comment:

As it solves the issue in #35168 with no major code changes on the KVPress side, LGTM.

My only concern is that it is slightly breaking if users were somehow relying on the previous right-side slicing behavior.

gante changed the title from "[draft] support cache cropping" to "Support cache-cropping methods" on Jan 10, 2025
gante changed the title from "Support cache-cropping methods" to "[generation] Support cache-cropping methods" on Jan 10, 2025
gante (Member, Author) commented Jan 10, 2025

@ArthurZucker, this draft PR makes a small modeling change to enable an advanced cache feature, as used e.g. in Nvidia's KVPress library -- see the PR header for full details.

If you agree with these changes, I will propagate them to other models before requesting a final review :)

ArthurZucker (Collaborator) left a comment:

Hey! Sorry but I don't think it makes sense to do this!
It's a very niche case, and is not transparent for people who just want to see how llama works!

gante (Member, Author) commented Jan 13, 2025

@ArthurZucker makes sense. In fact, downstream libraries can implement their own custom attention implementation and crop the mask there as they wish.

We would be able to unblock this cool use case in other libraries if we merge the non-modeling change, i.e. make the get_seq_length(layer=None) call return the longest sequence length across all layers, as opposed to the sequence length of layer 0 (the current behavior). The two are equivalent except in this advanced use case. Does this change look okay to you?

EDIT: chatted on Slack, moving forward with the non-modeling changes
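
A sketch of the intended get_seq_length behaviour (a hypothetical, simplified cache for illustration, not the actual transformers implementation):

```python
# Toy dynamic cache: without a layer index, get_seq_length() returns the longest
# length across layers instead of layer 0's length. The two only differ when
# layers were cropped unevenly.
from typing import List, Optional
import torch

class ToyDynamicCache:
    def __init__(self):
        # One (batch, heads, seq_len, head_dim) key tensor per layer.
        self.key_cache: List[torch.Tensor] = []

    def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
        if not self.key_cache:
            return 0
        if layer_idx is not None:
            return self.key_cache[layer_idx].shape[-2]
        # Prepare masks and positions for the longest layer.
        return max(k.shape[-2] for k in self.key_cache)
```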

@@ -1037,7 +1037,7 @@ def _update_causal_mask(
         target_length = (
             attention_mask.shape[-1]
             if isinstance(attention_mask, torch.Tensor)
-            else past_seen_tokens + sequence_length + 1
+            else past_seen_tokens + sequence_length
gante (Member, Author) commented on this change:

This comment applies to all models: the +1 is not needed, since we discard its corresponding data in the attention forward pass (e.g. here).

By removing it, we should expect a minimal improvement in memory use and inference time. We also ensure that the causal mask has exactly the (cached) input length, which is a requirement for the downstream use case this PR wants to enable.
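
A toy illustration of the point above (not the library code): the extra "+1" column in the prepared mask was sliced away in the attention forward before it could affect the scores.

```python
# The prepared mask is one column longer than the scores (old "+1" behaviour);
# slicing it to the key length discards that column before the addition.
import torch

batch, heads, q_len, kv_len = 1, 2, 1, 8
attn_weights = torch.zeros(batch, heads, q_len, kv_len)
attention_mask = torch.zeros(batch, 1, q_len, kv_len + 1)  # one column too long

causal_mask = attention_mask[:, :, :, :kv_len]  # extra column discarded here
attn_weights = attn_weights + causal_mask       # shapes now line up: (1, 2, 1, 8)
```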

gante marked this pull request as ready for review on January 13, 2025, 14:40
gante requested a review from ArthurZucker on January 13, 2025, 14:40
zucchini-nlp (Member) commented:

Btw, I am working with Mllama now and remembered that it has a weird cache where the cross-attention layers have a very large length. Can we make sure the mllama generation tests don't fail after this change?

gante (Member, Author) commented Jan 16, 2025

@zucchini-nlp correct, it does change. I'm adding a test to catch more models like this. EDIT: mllama is the only model that fails this test :)

btw, RUN_SLOW=1 py.test tests/models/mllama/test_modeling_mllama.py results in 11 failures on both main and this branch 😬

zucchini-nlp (Member) commented Jan 16, 2025

That is too much for the main branch. Afaik the mllama slow tests can't be run from here because we are in the EU and the repo is gated 😢

Currently I am removing the skip from many generation tests that checked hasattr(config, "use_cache"), because otherwise all of them are skipped for multimodal models (a PR will be submitted soon). In mllama's case that also skips all tests that need caching, I guess. From the code it seems that get_seq_length() is used only when preparing cache_position (if it is None) and when preparing the attention mask. So I think the model might fail whenever there is no mask to infer the length from, or when continuing generation from an "initial prompt".
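
For reference, a rough sketch of those two uses (simplified, with hypothetical helper names; not the exact transformers code):

```python
# get_seq_length() feeds both cache_position (when it is None) and the causal
# mask's target length, mirroring the diff above.
import torch

def prepare_cache_position(input_ids: torch.Tensor, cache) -> torch.Tensor:
    # Continue numbering positions after whatever is already cached.
    past_seen_tokens = cache.get_seq_length() if cache is not None else 0
    return torch.arange(past_seen_tokens, past_seen_tokens + input_ids.shape[1])

def mask_target_length(attention_mask, past_seen_tokens: int, sequence_length: int) -> int:
    if isinstance(attention_mask, torch.Tensor):
        return attention_mask.shape[-1]
    return past_seen_tokens + sequence_length
```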

For other models: on the VLM side that's the only one, and maybe ImageGPT is a bit peculiar. On the LLM side, all I can think of is the hybrid-cache models, where the sliding-window length is fixed and can be lower or higher than the static layers.
