DynamicCache does not support variable lengths, except for FA2 #35168

Open · 2 of 4 tasks
SimJeg opened this issue Dec 9, 2024 · 19 comments
Comments

@SimJeg commented Dec 9, 2024

System Info

  • transformers version: 4.47.0
  • Using GPU in script?: yes
  • GPU type: NVIDIA A100-SXM4-80GB

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The following code will reproduce the issue:

import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache


device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
attn_implementation = "flash_attention_2"
cache = DynamicCache()
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto", attn_implementation=attn_implementation).to(
    device
)


# Run first forward pass
inputs = model.dummy_inputs["input_ids"].to(device)
with torch.no_grad():
    model(inputs, past_key_values=cache)
    print("First forward pass")
    print(f"Layer 0: {cache.key_cache[0].shape}")
    print(f"Layer 1: {cache.key_cache[1].shape}")

# Update the cache size for the first layer
print("Removing 2 KV pairs for the first layer")
cache.key_cache[0] = cache.key_cache[0][:, :, :-2, :]
cache.value_cache[0] = cache.value_cache[0][:, :, :-2, :]

with torch.no_grad():
    model(inputs, past_key_values=cache)
    print("Second forward pass")
    print(f"Layer 0: {cache.key_cache[0].shape}")
    print(f"Layer 1: {cache.key_cache[1].shape}")

Expected behavior

In the kvpress repository, we implement various KV cache compression methods. For some of these methods, the KV cache compression is not the same from one layer to another (see for instance NVIDIA/kvpress#28). We noticed this is not an issue when using attn_implementation="flash_attention_2", but it raises an error with attn_implementation="sdpa" or attn_implementation="eager".

Here is the error from the script above:

File .../transformers/models/llama/modeling_llama.py

--> 336     attn_weights = attn_weights + causal_mask
    338     # upcast attention to fp32
    339     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

RuntimeError: The size of tensor a (10) must match the size of tensor b (9) at non-singleton dimension 3

The issue is that the causal_mask is not updated per-layer for sdpa and eager while it is for flash attention.
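
To make the shape mismatch concrete, here is a standalone sketch in plain PyTorch (illustrative shapes only, chosen to match the error above): the mask is built from the cropped layer-0 cache, while an uncropped layer produces attention weights with a longer key dimension.

import torch

bs, heads, q_len = 3, 32, 5
key_len_uncropped = 10  # keys of a layer whose cache was NOT cropped
mask_len = 9            # mask length derived from the shorter, cropped layer-0 cache

attn_weights = torch.randn(bs, heads, q_len, key_len_uncropped)
causal_mask = torch.zeros(bs, 1, q_len, mask_len)

# RuntimeError: The size of tensor a (10) must match the size of tensor b (9)
# at non-singleton dimension 3
attn_weights = attn_weights + causal_mask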

SimJeg added the bug label Dec 9, 2024
@Rocketknight1 (Member)

cc @gante @zucchini-nlp

@zucchini-nlp (Member)

One thing to note is that we would need to know which indices were discarded from the cache after the prefill stage for each layer, so we can correctly slice the attention mask when there are padding tokens.

Currently we crop the mask to the key length, but in the example code this fails because the mask, built from the cropped layer-0 cache, is shorter than the keys of the subsequent, uncropped layers. So we would probably need to keep the indices of the remaining cache entries and slice the attention mask with those indices, though I'm not sure we can store those indices as a per-layer cache attribute without causing compile errors for StaticCache.

causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

Let's also wait for @gante here.

@SimJeg (Author) commented Dec 11, 2024

I did not think about padding tokens. If we ignore them, the only thing to do would be to cut the causal mask to the right length, right? Is that what FA2 is doing? It could be useful to add a warning_once somewhere if that's the case.

@zucchini-nlp (Member)

Yes, without padding it is quite easy, but we cannot assume users never pad their inputs.

What FA2 does is skip passing the mask entirely when there are no pad tokens in the input, which is the case for the example code snippet. If the example had pad tokens, FA2 would start passing the attention mask the same way as SDPA/eager and would probably hit the same errors. So FA2 only works here because the inputs have no pad tokens and thus no zeros in the mask.

So if we were to slice the causal mask differently on each layer, we would need a way to handle padding tokens. The only way I can think of is to hold a cache_indices attribute in each cache layer and do causal_mask[:, :, :, cache_indices]. We would also need to make sure it isn't an obstacle for torch.compile.
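
For illustration, a rough sketch of the cache_indices idea (cache_indices is hypothetical, not an existing cache attribute): each layer would remember which positions survived compression and use them to slice the shared mask.

import torch

bs, q_len, full_len = 2, 4, 10
# Additive mask prepared once over the full, uncompressed key length
causal_mask = torch.zeros(bs, 1, q_len, full_len)

# Hypothetical per-layer attribute: this layer kept positions 0-5 and 8-9 after compression
cache_indices = torch.tensor([0, 1, 2, 3, 4, 5, 8, 9])

# Slice the shared mask down to the positions this layer actually kept
layer_mask = causal_mask[:, :, :, cache_indices]
print(layer_mask.shape)  # torch.Size([2, 1, 4, 8])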

@SimJeg (Author) commented Dec 12, 2024

Would it be possible to do the easy fix for the no-pad-token case, similar to FA2? It would solve the issue for us, and we could open another issue for fixing the pad-token case.

@zucchini-nlp (Member)

@SimJeg for the no-padding case, SDPA should also work if we call generate(). The reason it fails in the example above is that the full inputs are passed as-is in the second call. If we were actually generating, we would pass only one token, the one newly generated at the end of the prefill phase. In that case, SDPA also ignores the attention mask and lets is_causal=True/False handle everything internally.

So calling model(inputs[:, -1:], past_key_values=cache) for the second pass works with SDPA. For eager attention we would need a PR, though.
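
For reference, this is roughly what that decode-style second call looks like when applied to the reproduction script above (a sketch; it sidesteps the error only because a single new token is passed):

# Second forward pass, decode-style: only the last token is fed, so SDPA can
# rely on is_causal and ignore the attention mask
with torch.no_grad():
    model(inputs[:, -1:], past_key_values=cache)
    print(f"Layer 0: {cache.key_cache[0].shape}")
    print(f"Layer 1: {cache.key_cache[1].shape}")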

Feel free to propose an initial design by opening a PR and we can move the discussion there :)

@SimJeg (Author) commented Dec 13, 2024

For our use case, we do need to support new inputs with a sequence length greater than 1 :/

More context: kvpress benchmarks KV cache compression methods and introduces a transformers pipeline called KVPressTextGeneration. The inputs of the pipeline are a context to be compressed (e.g. a newspaper article) and, optionally, a question (e.g. a question about the article). Hence, before decoding, we need to run inference on the question tokens with the cache of the compressed context. And it appears that many KV cache compression methods do layer-wise compression. Supporting SDPA would be enough, as we only have one press that uses eager attention.

@zucchini-nlp (Member)

Thanks for explaining, it now makes more sense where this fits.

Would it be possible to do the easy fix for the no pad token case similar to FA2 ?

Coming back here: if you want to start by testing/evaluating with inputs of a single length and no pad tokens, the easy way would be to prepare your own 4D mask (here) whose key dimension spans the maximum possible length across all layers, and pass it to the model. The shape should be [bs, 1, input-len, max-key-len-from-all-cache-layers].

https://github.com/NVIDIA/kvpress/blob/2b350b05ab3e8dd92a7143ee9f042509a310daa9/kvpress/pipeline.py#L239-L245
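
A minimal sketch of this single-mask workaround for the no-padding case (the function name and exact construction are mine, not an existing API; as noted further down the thread, one shared mask is only approximately correct for the cropped layers):

import torch

def build_shared_4d_mask(cache, input_len, batch_size, dtype, device):
    # Key dimension spans the longest cache layer plus the new tokens
    max_key_len = max(cache.get_seq_length(i) for i in range(len(cache))) + input_len
    min_dtype = torch.finfo(dtype).min
    # Positions of the new tokens, appended after the longest past
    new_positions = torch.arange(max_key_len - input_len, max_key_len, device=device)
    mask = torch.full((input_len, max_key_len), min_dtype, dtype=dtype, device=device)
    # Allow attending to every past key and causally to the new keys
    allowed = torch.arange(max_key_len, device=device)[None, :] <= new_positions[:, None]
    mask = mask.masked_fill(allowed, 0.0)
    return mask[None, None, :, :].expand(batch_size, 1, -1, -1)  # [bs, 1, input_len, max_key_len]

Depending on the setup, cache_position may also need to be passed explicitly, as in the longer example later in this thread.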

For our use case, we do need to support new inputs with a sequence length greater than 1 :/

But since we want it to generalize to all input lengths, we'll need to add logic somewhere to crop the correct indices (pad or non-pad) in the mask. I am not very keen on adding all that logic to the model code, given it is only used by special cache classes. So what I am thinking now is to allow users (i.e. kvpress) to prepare and pass their own attention mask for each layer, which means a mask of shape [num-layers, bs, heads, input-len, max-key-len]. If the lengths differ per layer, it can be padded on the last dimension, and the model will take care of cropping the mask to the same length as the keys of that specific layer.
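
A sketch of what such a stacked per-layer mask could look like (a purely hypothetical interface: the model does not accept this shape today):

import torch
import torch.nn.functional as F

def stack_per_layer_masks(per_layer_masks, min_dtype):
    # per_layer_masks: list of additive masks, one per layer, each of shape
    # (bs, 1, input_len, key_len_layer)
    max_key_len = max(m.shape[-1] for m in per_layer_masks)
    padded = [
        F.pad(m, (0, max_key_len - m.shape[-1]), value=min_dtype)  # right-pad the key dim as "masked"
        for m in per_layer_masks
    ]
    return torch.stack(padded, dim=0)  # (num_layers, bs, 1, input_len, max_key_len)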

Would love to hear your opinion on this

@SimJeg (Author) commented Dec 20, 2024

@zucchini-nlp I'm wondering if the issue isn't solved by this recent PR, see here:

def sdpa_attention_forward(...):
    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

Will check after the holidays!

@zucchini-nlp (Member)

@SimJeg yeah, we've always had the slicing on the last dimension afaik. The issue here is that the mask length doesn't match the key length, because of the way the causal mask is prepared internally (from the first layer's cache).

Let me know if it works for you now out of the box, but I think preparing and passing a 4D mask outside of the forward call can solve most cases.

@gante (Member) commented Jan 8, 2025

Hi @SimJeg 👋

We've updated the attention layers on our models, and I can no longer reproduce your issue on main. Looking at the updated code on llama, we now slice the causal mask based on the key length, which factors in any custom KV cache slicing.

It seems then that this issue is sorted 🤗 @SimJeg can you please confirm on your end?

If the issue is indeed solved, I'm planning to add a model-agnostic version of your script as a test, to ensure we don't regress.

@SimJeg (Author) commented Jan 8, 2025

@gante yeah, thanks for the update! I re-installed the repo from main, but when I run the above script with attn_implementation = "sdpa" I still get the same error:

RuntimeError: The expanded size of the tensor (10) must match the existing size (9) at non-singleton dimension 3.  Target sizes: [3, 32, 5, 10].  Tensor sizes: [3, 1, 5, 9]

Hope my install is correct...

@gante (Member) commented Jan 8, 2025

derp, I copy pasted your script and didn't change attn_implementation 👀

let me catch up with this thread, and see what can be done

@zucchini-nlp (Member)

Hey all!

Yes, the issue will remain with the above script because we prepare the causal mask based on the length of the first layer's cache. When KVPress crops the first layer's cache to contain fewer tokens than the subsequent layers, we hit the error.

So what I proposed is for KVPress to pass a 4D causal mask into the forward pass, since the repo doesn't call generate directly anyway. I didn't have time to actually test it, but it should work for most cases. I can put together a small example to test and share here tomorrow, if that's okay.

@zucchini-nlp (Member)

@SimJeg @gante

Here is a small example that will run, though the mask is not fully correct: in the cropped first layer, the mask still attends to tokens it should not, because we prepare one mask for all layers. To fix that, we would need to allow users to pass a mask of shape (num_layers, *4D mask), so that a causal mask can be prepared for each layer individually, padded to the right if the cache lengths differ.

import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, StaticCache


device = "cuda:0"
ckpt = "meta-llama/Meta-Llama-3.1-8B-Instruct"
attn_implementation = "eager"
cache = DynamicCache()
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype="auto", attn_implementation=attn_implementation, device_map=device)


# Run first forward pass
inputs = model.dummy_inputs["input_ids"].to(device)
attention_mask = torch.ones_like(inputs)
cache_position = torch.arange(inputs.shape[1]).to(device)

with torch.no_grad():
    model(inputs, attention_mask=attention_mask, cache_position=cache_position, past_key_values=cache)
    print("First forward pass")
    print(f"Layer 0: {cache.key_cache[0].shape}")
    print(f"Layer 1: {cache.key_cache[1].shape}")


# Update the cache size for the first layer
print("Removing 2 KV pairs for the first layer")
cache.key_cache[0] = cache.key_cache[0][:, :, :-2, :]
cache.value_cache[0] = cache.value_cache[0][:, :, :-2, :]


# Update attn mask to account for the INIT-PROMPT and QUESTION lengths assuming attention for the question is all ones
question_input_ids = inputs.clone()
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], question_input_ids.shape[1]))], dim=-1)
cache_position = torch.arange(
    cache_position[-1] + 1, cache_position[-1] + question_input_ids.shape[1] + 1, dtype=cache_position.dtype
).to(cache_position.device)


# Get the max past kv length from all layers
max_past_seen_tokens = 0
if cache is not None:
    max_past_seen_tokens = max([cache.get_seq_length(layer_idx) for layer_idx in range(len(cache))])

# Adapted from `LlamaModel._update_causal_mask` (TODO: StaticCache might not work as expected so need to check)
dtype, device = model.dtype, model.device
batch_size, sequence_length = question_input_ids.shape
max_target_length = (
    attention_mask.shape[-1]
    if isinstance(attention_mask, torch.Tensor)
    else max_past_seen_tokens + sequence_length + 1
)

# Adapted from `LlamaModel._prepare_4d_causal_attention_mask_with_cache_position`
min_dtype = torch.finfo(dtype).min    
causal_mask = torch.full(
    (sequence_length, max_target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
    causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(max_target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_dtype
    )
# NOTE: a full per-layer solution would store one mask per layer,
# e.g. causal_mask_batch[layer_idx] = causal_mask; here a single shared mask is passed instead

with torch.no_grad():
    model(question_input_ids, attention_mask=causal_mask, cache_position=cache_position, past_key_values=cache)
    print("Second forward pass")
    print(f"Layer 0: {cache.key_cache[0].shape}")
    print(f"Layer 1: {cache.key_cache[1].shape}")

@gante (Member) commented Jan 9, 2025

(@SimJeg meanwhile, we're also exploring the possibility of solving the issue with zero code changes on your end. We're assessing feasibility on our side.)

@SimJeg (Author) commented Jan 9, 2025 via email

@gante (Member) commented Jan 13, 2025

Hey @SimJeg 👋 (cc @zucchini-nlp)

We've chatted internally about what we can do to better support KVPress, and we'll merge the changes that allow a solution with little code on your end :)

TL;DR:

  1. [generation] Support cache-cropping methods #35591 will fix how we create the causal mask, using the length of the longest layer (as opposed to the length of layer 0)
  2. On your end, you will either have to monkey-patch eager_attention_forward or define a new attention forward pass; both should be ~10 LOC (see the sketch after this list).
  3. After [generation] Support cache-cropping methods #35591 is merged, I will share here a working example of 2.
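
For what it's worth, here is a rough, hypothetical sketch of the monkey-patch route in point 2, written against the module-level eager_attention_forward introduced by the attention refactor (names and the exact per-layer adjustment are assumptions; gante's working example after #35591 remains the reference):

import torch
from transformers.models.llama import modeling_llama

_original_eager = modeling_llama.eager_attention_forward

def patched_eager_attention_forward(module, query, key, value, attention_mask, *args, **kwargs):
    # Adjust the shared causal mask to this layer's actual key length before running
    # the original eager attention. A real KVPress patch would slice with the indices
    # the layer actually kept, not just the leading prefix.
    if attention_mask is not None:
        attention_mask = attention_mask[..., : key.shape[-2]]
    return _original_eager(module, query, key, value, attention_mask, *args, **kwargs)

modeling_llama.eager_attention_forward = patched_eager_attention_forward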

@SimJeg (Author) commented Jan 13, 2025

That's great news @gante! It's OK for us to raise an error for eager, as we only have one press that requires this attention module. Out-of-the-box support for SDPA is great, however, for users without access to Ampere GPUs. I'll update our repo once the feature is released.
