
Add simlayerkv #22

Closed
wants to merge 10 commits

Conversation

dame-cell

Add SimLayerKV as proposed in issue #19.

Regarding the compression rate, I wanted to highlight how the original implementation chose thresholds dynamically based on the model type. Here is how it handled this:

```python
if 'llama3' in out_path:
    threshold = 0.9
elif 'llama2' in out_path:
    threshold = 0.65
elif 'mistral' in out_path:
    threshold = 0.8
elif 'qwen' in out_path:
    threshold = 0.85
```
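For illustration, the same per-model defaults could be captured in a small helper; the `DEFAULT_THRESHOLDS` mapping reuses the values above, but the helper function and its fallback value are hypothetical, not part of the original repo or this PR:

```python
# Hypothetical helper mirroring the original repo's per-model thresholds.
# The mapping values come from the snippet above; the fallback is an assumption.
DEFAULT_THRESHOLDS = {"llama3": 0.9, "llama2": 0.65, "mistral": 0.8, "qwen": 0.85}

def default_threshold(model_name: str, fallback: float = 0.85) -> float:
    for key, value in DEFAULT_THRESHOLDS.items():
        if key in model_name.lower():
            return value
    return fallback
```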

I’ve also added a notebook demonstrating how to generate outputs using SimLayerKV. All tests have passed successfully.
Let me know if there’s anything else you’d like refined or if additional details are needed!

```diff
     StreamingLLMPress,
     TOVAPress,
 )
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401


 def test_presses_run(unit_test_model):  # noqa: F811
-    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
+    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress,SimLayerKVPress]:
```
Collaborator

missing space after last comma

Collaborator

Instead of creating a new notebook, could you add SimLayerKVPress to the imports of wikipedia_demo.ipynb?

```diff
@@ -134,4 +134,4 @@ def score(
     # Add back the sink tokens. Use max score to make sure they are not pruned.
     scores = F.pad(scores, (self.n_sink, 0), value=scores.max().item())

-    return scores
+    return scores
```
Collaborator

Please make sure you merge with main; this line already exists on main.


SimJeg commented Dec 4, 2024

Hi @dame-cell, thanks a lot for your work! Could you please merge your branch with main? I will read the paper and review your work. I have already added a few minor comments.

Collaborator

Please rename simlayerkv_press.py for consistency.


SimJeg commented Dec 4, 2024

After quickly reading the paper, I think your current implementation does not implement SimLayerKV. Here is my understanding of the paper.

SimLayerKV dynamically identifies what it calls "lazy layers". For lazy layers, it prunes the KV cache using the StreamingLLM strategy, i.e. keeps only the first and last tokens. For non-lazy layers, the full KV cache is kept.

To identify lazy layers, a procedure similar to SnapKV is used: compute the average attention weights of the last 32 tokens. Then average these attention weights over the first tokens and the 1024 most recent tokens. The layer is lazy if this average exceeds a model-dependent threshold.
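As a minimal sketch of that detection step, assuming attention weights of shape (batch, num_heads, q_len, kv_len) and non-overlapping windows (the function name, default window sizes, and batch reduction are my assumptions, not the paper's code):

```python
import torch

def is_lazy_layer(attn_weights: torch.Tensor, n_initial: int = 4,
                  n_recent: int = 1024, n_last: int = 32,
                  threshold: float = 0.9) -> bool:
    # attn_weights: (batch, num_heads, q_len, kv_len) for one layer,
    # assuming kv_len > n_initial + n_recent so the two windows do not overlap.
    # Average the attention of the last n_last queries over heads and queries.
    avg = attn_weights[:, :, -n_last:, :].mean(dim=(1, 2))  # (batch, kv_len)
    # Attention mass placed on the initial and the most recent keys.
    mass = avg[:, :n_initial].sum(dim=-1) + avg[:, -n_recent:].sum(dim=-1)
    return bool((mass.mean() > threshold).item())
```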

Your current implementation is the following:

```python
scores = torch.zeros(bsz, num_heads, seq_len, device=device, dtype=dtype)

if is_lazy_layer:
    # if the layer is lazy, only keep the initial and recent tokens; the rest stays 0
    scores[:, :, : self.initial_tokens] = 1.0
    scores[:, :, -self.recent_tokens :] = 1.0
else:
    scores[:, :, :] = 1.0

return scores
```

What will happen with this:

  • If the layer is not lazy, all scores are 1. Thus, for a given compression ratio, the top-k values (torch.topk in the forward hook of the BasePress) will be somewhat arbitrary (torch.ones(10).topk(3) returns indices 8, 6, 7, for instance); see the snippet below.
  • If the layer is lazy then, depending on the length of the prompt, the initial and recent tokens will indeed be kept, but the remaining tokens will be pruned at random.
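A quick way to reproduce the tie-breaking behavior (the exact indices returned are implementation-defined):

```python
import torch

scores = torch.ones(10)
# all scores tie, so topk selects arbitrary indices, e.g. tensor([8, 6, 7])
print(scores.topk(3).indices)
```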

This method would require using a custom Cache (maybe similar to HybridCache in Gemma?). The compression ratio is also ill-defined for this press, but that is a broader issue with kvpress, which currently focuses perhaps too much on it (we are thinking about a refactor, see #21).
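For illustration, a toy, framework-agnostic version of such a cache could look like this; `LazyLayerCache` and all of its parameters are a hypothetical sketch, not the transformers Cache API:

```python
import torch

class LazyLayerCache:
    """Toy per-layer KV store (hypothetical sketch): lazy layers keep only the
    first n_initial and last n_recent tokens, non-lazy layers keep everything."""

    def __init__(self, lazy_layers, n_initial=4, n_recent=1024):
        self.lazy_layers = set(lazy_layers)  # indices of lazy layers
        self.n_initial, self.n_recent = n_initial, n_recent
        self.key_cache, self.value_cache = {}, {}

    def update(self, layer_idx, keys, values):
        # keys / values: (batch, num_heads, seq_len, head_dim)
        if layer_idx in self.key_cache:
            keys = torch.cat([self.key_cache[layer_idx], keys], dim=-2)
            values = torch.cat([self.value_cache[layer_idx], values], dim=-2)
        if layer_idx in self.lazy_layers and keys.shape[-2] > self.n_initial + self.n_recent:
            # StreamingLLM-style pruning: keep first n_initial and last n_recent tokens
            keep = lambda t: torch.cat(
                [t[:, :, : self.n_initial], t[:, :, -self.n_recent :]], dim=-2
            )
            keys, values = keep(keys), keep(values)
        self.key_cache[layer_idx], self.value_cache[layer_idx] = keys, values
        return keys, values
```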

SimLayerKV also looks similar to the DuoAttention paper, which does this at the head level (this would require a specific attention kernel, as proposed in #7).


dame-cell commented Dec 4, 2024

@SimJeg
Yeah, I actually made this a draft to discuss exactly this; I was unsure how to handle it. For instance, the original repo did it by rewriting the entire Llama and Qwen modeling code (almost) from scratch.

Prefilling and decoding with SimLayerKV also work somewhat differently from the rest of the presses. Here is how the original implementation of SimLayerKV handles it:

```python
# [SimLayerKV]
if (torch.tensor(bos_probs).sum() != 0) & (torch.tensor(bos_weights).sum() == 0):
    for layer_num in range(len(bos_probs)):
        if threshold_stream < 2:
            if bos_probs[layer_num] > threshold_stream:
                # lazy layer: keep only the first 4 and the last 1024 tokens
                next_decoder_cache.key_cache[layer_num] = torch.cat(
                    [next_decoder_cache.key_cache[layer_num][:, :, 0:4],
                     next_decoder_cache.key_cache[layer_num][:, :, -1024:]], dim=-2)
                next_decoder_cache.value_cache[layer_num] = torch.cat(
                    [next_decoder_cache.value_cache[layer_num][:, :, 0:4],
                     next_decoder_cache.value_cache[layer_num][:, :, -1024:]], dim=-2)
                past_key_values.key_cache[layer_num] = torch.cat(
                    [past_key_values.key_cache[layer_num][:, :, 0:4],
                     past_key_values.key_cache[layer_num][:, :, -1024:]], dim=-2)
                past_key_values.value_cache[layer_num] = torch.cat(
                    [past_key_values.value_cache[layer_num][:, :, 0:4],
                     past_key_values.value_cache[layer_num][:, :, -1024:]], dim=-2)
                torch.cuda.empty_cache()
```

> This method would require using a custom Cache (maybe similar to HybridCache in Gemma?). The compression ratio is also ill-defined for this press, but that is a broader issue with kvpress, which currently focuses perhaps too much on it (we are thinking about a refactor, see #21).

Yes, the original implementation has its own custom MiniCache; see SimLayerKV_llama.py, line 478.

> SimLayerKV also looks similar to the DuoAttention paper, which does this at the head level (this would require a specific attention kernel, as proposed in #7).

I would appreciate any advice or guidance you can offer on how to proceed, and I'll keep up with developments on this front as well.


SimJeg commented Dec 9, 2024

Hi @dame-cell, prior to the refactor I mentioned (see #21), I'd like to investigate the best way to integrate SimLayerKVPress. Let me look into it, building on your contribution. I might open another PR if the code ends up very different.

SimJeg mentioned this pull request on Dec 9, 2024

SimJeg commented Dec 9, 2024

@dame-cell, I will close this PR as I believe the code I shared in #28 is closer to what is proposed in the original paper. Please tell me if you disagree.

SimJeg closed this on Dec 9, 2024