DynamicCache does not support variable lengths, except for FA2 #35168
Comments
One thing to note is that we would need to know which indices were discarded from the cache after the prefill stage for each layer, so we can correctly slice the attention mask when there are padding tokens. Currently we crop the mask to the key length, but in the example code this failed because the cache was shorter than the keys in subsequent uncropped layers. So probably we would need to obtain
Let's also wait for @gante here |
I did not think about padding tokens. If we ignore them, the only thing to do would be to cut the causal mask to the right length, right? Is that what FA2 is doing? It could be useful to add a warning_once somewhere if that's the case. |
Yes, without padding it is quite easy, but we cannot assume users never pad their inputs. What FA2 does is not pass the mask at all (when there are no pad tokens in the input), so we don't have to strip pad tokens from the actual qkv, which is the case for the example code snippet. If you had a pad token in the example code, FA2 would start passing the attention mask, the same way as SDPA/eager, and probably cause errors. So FA2 works only when the inputs have no pad tokens and thus no zeros in the mask. So if we were to slice the causal mask differently on each layer, we'll need a way to handle padding tokens. For that, the only way I can think of is to hold a |
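For illustration, a minimal sketch of that "drop the mask when nothing is padded" behaviour (an assumed helper, not the actual FA2 integration code):

from typing import Optional
import torch

def maybe_drop_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # If every position is attended to (no pad tokens), the 2D mask carries no
    # information and can be dropped entirely; this is what lets the FA2 path
    # avoid any length check against the per-layer cache.
    if attention_mask is not None and bool(attention_mask.all()):
        return None
    return attention_mask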
Would it be possible to do the easy fix for the no-pad-token case, similar to FA2? It would solve the issue for us, and we can open another issue for fixing the pad-token case. |
@SimJeg for the no-padding case, SDPA should also work if we call So calling like Feel free to propose an initial design by opening a PR and we can move the discussion there :) |
For our use case, we do need to support new inputs with a sequence length greater than 1 :/ More context: kvpress benchmarks KV cache compression methods and introduces a transformers pipeline called |
Thanks for explaining, now it makes more sense where this fits.
Coming back here, if you want to start by testing/evaluating inputs of a single length without pad tokens, the easy way would be to prepare your own 4D mask so that it has the maximum possible length across all layers and pass it to the model. The shape should be
But since we want it to generalize to all input lengths, we'll need to somehow add logic to crop the correct indices (pad or non-pad ids) in the mask. I am not very keen on adding all that logic to the model code, given it is used only for special cache classes. So what I am thinking now is to allow users (i.e. kvpress) to prepare and pass their own attention mask for each layer, which means a mask of shape Would love to hear your opinion on this |
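As a purely hypothetical illustration of that idea (not an existing transformers API; names and shapes are assumed), one additive 4D mask per layer, sized to that layer's own KV length, could be built like this for the no-padding case:

import torch

def build_layer_masks(cache, batch_size, q_len, dtype, device):
    # One additive mask per layer of shape (batch, 1, q_len, past_len_layer + q_len):
    # new tokens see that layer's remaining past fully and each other causally.
    min_dtype = torch.finfo(dtype).min
    masks = {}
    for layer_idx in range(len(cache)):
        past_len = cache.get_seq_length(layer_idx)  # per-layer KV length after compression
        mask = torch.full((q_len, past_len + q_len), min_dtype, dtype=dtype, device=device)
        mask[:, :past_len] = 0.0  # past tokens of this layer are always visible
        mask[:, past_len:] = torch.triu(mask[:, past_len:], diagonal=1)  # causal block for new tokens
        masks[layer_idx] = mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    return masks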
@zucchini-nlp I'm wondering if the issue is not solved by this recent PR, see here:
def sdpa_attention_forward(...):
    causal_mask = attention_mask
    if attention_mask is not None:
        # slice the mask on its last (key-length) dimension to match the current keys
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
Will check after holidays! |
@SimJeg yeah, we've always had the slicing on the last dimension afaik. The issue here is that the mask's key-length dimension does not match the per-layer key length, because of the way the causal mask is prepared internally. Let me know if it works for you now out of the box, but I think preparing and passing a 4D mask outside of the |
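For concreteness, a tiny sketch of that mismatch (shapes borrowed from the error reported later in this thread: batch 3, 32 heads, 5 query tokens, a mask sized for 9 keys meeting a layer that still holds 10):

import torch

scores = torch.zeros(3, 32, 5, 10)  # attention scores of an uncropped layer (10 keys)
mask = torch.zeros(3, 1, 5, 9)      # 4D mask sized from the cropped layer-0 cache (9 keys)
key_len = scores.shape[-1]
# Slicing to the key length is a no-op here because the mask is already shorter,
# so combining it with the scores raises the shape-mismatch RuntimeError quoted in this thread.
scores + mask[:, :, :, :key_len]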
Hi @SimJeg 👋 We've updated the attention layers on our models, and I can no longer reproduce your issue on main. It seems then that this issue is sorted 🤗 @SimJeg can you please confirm on your end? If the issue is indeed solved, I'm planning to add a model-agnostic version of your script as a test, to ensure we don't regress. |
@gante yeah, thanks for the update! I re-installed the repo from main, but when I run the above script with attn_implementation = "sdpa" I still get the same error: RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 3. Target sizes: [3, 32, 5, 10]. Tensor sizes: [3, 1, 5, 9]. Hope my install is correct... |
derp, I copy-pasted your script and didn't change Let me catch up with this thread and see what can be done |
Hey all! Yes, the issue will remain with the above script because we prepare the causal attention mask based on the length of the first layer's cache. And when KVPress crops the first layer's cache to contain fewer tokens than the subsequent layers, we hit the error. So, what I proposed is for KVPress to pass a 4D causal mask into the forward pass, since the repo doesn't call generate directly anyway. I didn't have time to actually test it, but it should work for most cases. I can make a small example to test and share here tomorrow if that's OK |
Here is a small example that will run, though the mask is not fully correct. In the first cropped layer, the mask still attends to tokens it should not attend to, because we prepare one mask for all layers (a concrete illustration of this follows the script below). To fix that, we need to allow users to pass a mask of shape
import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, StaticCache
device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
attn_implementation = "eager"
cache = DynamicCache()
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto", attn_implementation=attn_implementation, device_map=device)
# Run first forward pass
inputs = model.dummy_inputs["input_ids"].to(device)
attention_mask = torch.ones_like(inputs)
cache_position = torch.arange(inputs.shape[1]).to(device)
with torch.no_grad():
    model(inputs, attention_mask=attention_mask, cache_position=cache_position, past_key_values=cache)
print("First forward pass")
print(f"Layer 0: {cache.key_cache[0].shape}")
print(f"Layer 1: {cache.key_cache[1].shape}")
# Update the cache size for the first layer
print("Removing 2 KV pairs for the first layer")
cache.key_cache[0] = cache.key_cache[0][:, :, :-2, :]
cache.value_cache[0] = cache.value_cache[0][:, :, :-2, :]
# Update attn mask to account for the INIT-PROMPT and QUESTION lengths assuming attention for the question is all ones
question_input_ids = inputs.clone()
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], question_input_ids.shape[1]))], dim=-1)
cache_position = torch.arange(
    cache_position[-1] + 1, cache_position[-1] + question_input_ids.shape[1] + 1, dtype=cache_position.dtype
).to(cache_position.device)
# Get the max past kv length from all layers
max_past_seen_tokens = 0
if cache is not None:
    max_past_seen_tokens = max([cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))])
# Adapted from `LlamaModel._update_causal_mask` (TODO: StaticCache might not work as expected so need to check)
dtype, device = model.dtype, model.device
batch_size, sequence_length = question_input_ids.shape
max_target_length = (
    attention_mask.shape[-1]
    if isinstance(attention_mask, torch.Tensor)
    else max_past_seen_tokens + sequence_length + 1
)
# Adapted from `LlamaModel._prepare_4d_causal_attention_mask_with_cache_position`
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
    (sequence_length, max_target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(max_target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_dtype
    )
# Pass the prepared 4D mask directly as `attention_mask` in the second forward pass
with torch.no_grad():
    model(question_input_ids, attention_mask=causal_mask, cache_position=cache_position, past_key_values=cache)
print("Second forward pass")
print(f"Layer 0: {cache.key_cache[0].shape}")
print(f"Layer 1: {cache.key_cache[1].shape}") |
(@SimJeg meanwhile, we're also exploring the possibility of solving the issue with 0 code changes on your end. We're assessing feasibility on our side.) |
Hi, the feature is not a very high priority for kvpress, so we're OK to wait a little if there is a 0-code-change solution! |
Hey @SimJeg 👋 (cc @zucchini-nlp) We've chatted internally about what we can do to better support KVPress, and we'll merge the changes that allow a solution with little code on your end :) TL;DR:
|
That's great news @gante! It's OK for us to raise an error for eager, as we only have one press that requires this attention module. Out-of-the-box support for sdpa is nonetheless great for users without access to Ampere GPUs. I'll update our repo once the feature is released. |
System Info
transformers version: 4.47.0
Who can help?
@ArthurZucker
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
The following code will reproduce the issue:
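(The original snippet is not reproduced here; a minimal sketch consistent with the discussion above, cropping layer 0 of a DynamicCache before a second forward pass, might look like this:)

import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

# Reconstructed sketch, not the original script: any decoder-only checkpoint should behave the same.
device = "cuda:0"
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype="auto",
    attn_implementation="sdpa",
    device_map=device,
)
cache = DynamicCache()
inputs = model.dummy_inputs["input_ids"].to(device)

with torch.no_grad():
    model(inputs, attention_mask=torch.ones_like(inputs), past_key_values=cache)  # prefill

# Simulate per-layer compression: drop the last KV pair of layer 0 only
cache.key_cache[0] = cache.key_cache[0][:, :, :-1, :]
cache.value_cache[0] = cache.value_cache[0][:, :, :-1, :]

# Second forward pass with new tokens; the mask length is derived from layer 0's (cropped)
# cache while other layers hold one more key, producing the shape-mismatch RuntimeError below.
new_mask = torch.ones(inputs.shape[0], cache.get_seq_length() + inputs.shape[1], device=device, dtype=torch.long)
with torch.no_grad():
    model(inputs, attention_mask=new_mask, past_key_values=cache)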
Expected behavior
In the kvpress repository, we implement various KV cache compression methods. For some of these methods, the KV cache compression is not the same from one layer to the other (see for instance NVIDIA/kvpress#28). We noticed it's not an issue when using attn_implementation="flash_attention_2", but it raises an error for attn_implementation="sdpa" or attn_implementation="eager".
Here is the error from the script above:
RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 3. Target sizes: [3, 32, 5, 10]. Tensor sizes: [3, 1, 5, 9]
The issue is that the causal_mask is not updated per-layer for sdpa and eager, while it is for flash attention.