Add simlayerkv #22
Conversation
```diff
     StreamingLLMPress,
     TOVAPress,
 )
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401


 def test_presses_run(unit_test_model):  # noqa: F811
-    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
+    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress,SimLayerKVPress]:
```
missing space after last comma
Instead of creating a new notebook, could you add SimLayerKVPress to the imports of wikipedia_demo.ipynb?
```diff
@@ -134,4 +134,4 @@ def score(
         # Add back the sink tokens. Use max score to make sure they are not pruned.
         scores = F.pad(scores, (self.n_sink, 0), value=scores.max().item())

-        return scores
+        return scores
```
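The `F.pad` call in the diff above prepends `n_sink` entries filled with the current maximum score, so sink tokens are never selected for pruning. A standalone illustration (the shapes and values here are made up for the example, not taken from the press):

```python
import torch
import torch.nn.functional as F

n_sink = 2
# scores for the non-sink positions, shape (bsz, seq_len - n_sink)
scores = torch.tensor([[0.1, 0.5, 0.3]])

# pad (n_sink, 0) prepends n_sink entries on the last dim,
# each set to the current max score so sinks are never pruned
padded = F.pad(scores, (n_sink, 0), value=scores.max().item())
```

After padding, the two sink positions carry the maximum score (0.5) ahead of the original scores.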
Please make sure you merge with main; on main this line already exists.
Hi @dame-cell, thanks a lot for your work! Could you please merge your branch with main? I will read the paper and review your work. I already added a few minor comments.
Please rename to simlayerkv_press.py for consistency.
After quickly reading the paper, I think your current implementation does not implement SimLayerKV. Here is my understanding of the paper. SimLayerKV dynamically identifies what it calls "lazy layers". For lazy layers, it prunes the KV cache using the StreamingLLM strategy, i.e. it keeps only the first and last tokens. For non-lazy layers, the full KV cache is kept. To identify lazy layers, a procedure similar to SnapKV is used: compute the average attention weights for the last 32 tokens, then average these attention weights over the first tokens and the 1024 most recent tokens. The layer is lazy if this average exceeds a threshold, which is model dependent.

Your current implementation is the following:

```python
scores = torch.zeros(bsz, num_heads, seq_len, device=device, dtype=dtype)
if is_lazy_layer:
    # if the layer is lazy, only keep the initial and recent tokens; the rest is set to 0
    scores[:, :, : self.initial_tokens] = 1.0
    scores[:, :, -self.recent_tokens :] = 1.0
else:
    scores[:, :, :] = 1.0
return scores
```

What will happen with this:

This method would require a custom Cache (maybe similar to HybridCache in Gemma?). The compression ratio is also ill-defined for this press, but that's an issue with kvpress currently, which maybe focuses too much on it (we are thinking about a refacto, see #21). SimLayerKV also looks similar to the DuoAttention paper, which does this at the head level (that would require a specific attention kernel, as proposed in #7).
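To make the lazy-layer criterion described above concrete, here is a minimal sketch of the detection step. The function name, the default threshold, and the shapes are hypothetical (not part of kvpress or the paper's code); it only assumes an attention-probability tensor of shape `(bsz, num_heads, q_len, seq_len)`:

```python
import torch

def is_lazy_layer(attn_weights, n_sink=4, n_recent=1024, n_last=32, threshold=0.65):
    """Decide whether a layer is 'lazy' in the SimLayerKV sense (sketch).

    attn_weights: (bsz, num_heads, q_len, seq_len) attention probabilities.
    A layer is flagged lazy if, averaged over the last n_last query tokens
    (and over batch and heads), the attention mass on the sink tokens plus
    the n_recent most recent keys exceeds a model-dependent threshold.
    """
    seq_len = attn_weights.shape[-1]
    # average attention of the last n_last queries over batch, heads, queries
    avg = attn_weights[:, :, -n_last:, :].mean(dim=(0, 1, 2))  # (seq_len,)
    # attention mass on the sink tokens and on the most recent tokens
    mass = avg[:n_sink].sum() + avg[-min(n_recent, seq_len):].sum()
    return mass.item() > threshold
```

For example, uniform attention over a long context spreads mass away from the sink/recent window (not lazy), while attention concentrated on recent tokens triggers the lazy flag.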
@SimJeg, also, prefilling and decoding with SimLayerKV are somewhat different from the rest of the presses:

```python
# [SimLayerKV]
if (torch.tensor(bos_probs).sum() != 0) & (torch.tensor(bos_weights).sum() == 0):
    for layer_num in range(len(bos_probs)):
        if threshold_stream < 2:
            if bos_probs[layer_num] > threshold_stream:
                # lazy layer: keep the first 4 and last 1024 tokens of the KV cache
                next_decoder_cache.key_cache[layer_num] = torch.cat(
                    [next_decoder_cache.key_cache[layer_num][:, :, 0:4],
                     next_decoder_cache.key_cache[layer_num][:, :, -1024:]], dim=-2)
                next_decoder_cache.value_cache[layer_num] = torch.cat(
                    [next_decoder_cache.value_cache[layer_num][:, :, 0:4],
                     next_decoder_cache.value_cache[layer_num][:, :, -1024:]], dim=-2)
                past_key_values.key_cache[layer_num] = torch.cat(
                    [past_key_values.key_cache[layer_num][:, :, 0:4],
                     past_key_values.key_cache[layer_num][:, :, -1024:]], dim=-2)
                past_key_values.value_cache[layer_num] = torch.cat(
                    [past_key_values.value_cache[layer_num][:, :, 0:4],
                     past_key_values.value_cache[layer_num][:, :, -1024:]], dim=-2)
torch.cuda.empty_cache()
```
Yes, the original implementation had its own custom minicache.

I would appreciate any advice or guidance you can offer on how to proceed, and I'll keep up with developments on this front as well.
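The truncation applied to lazy layers in the snippet above (keep the first 4 sink tokens and the last 1024 tokens along the sequence dimension) can be illustrated on a toy cache tensor. `truncate_lazy_layer` is a hypothetical helper written for this sketch, not code from the original repository:

```python
import torch

def truncate_lazy_layer(kv: torch.Tensor, n_sink: int = 4, n_recent: int = 1024) -> torch.Tensor:
    """Keep only the first n_sink and last n_recent entries along the
    sequence dimension (dim=-2), mirroring the lazy-layer pruning above."""
    if kv.shape[-2] <= n_sink + n_recent:
        return kv  # cache already short enough, nothing to prune
    return torch.cat([kv[:, :, :n_sink], kv[:, :, -n_recent:]], dim=-2)

# toy key cache: (bsz, num_kv_heads, seq_len, head_dim)
keys = torch.randn(1, 8, 2000, 64)
pruned = truncate_lazy_layer(keys)
```

A 2000-token cache is reduced to 4 + 1024 = 1028 tokens, and the kept slices are exactly the original sink and recent entries.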
Hi @dame-cell, prior to the refacto I mentioned (see #21), I'd like to investigate the best way to integrate SimLayerKVPress. Let me investigate it based on your contributions. I might open another PR if the code is very different.
@dame-cell I will close this PR as I believe the code I shared in #28 is closer to what is proposed in the original paper. Please tell me if you disagree.
Add SimLayerKV as proposed in issue #19.

Regarding the compression rate, I wanted to highlight how the original implementation chose thresholds dynamically based on the model type. Here is how the original implementation handled it.

I've also added a notebook demonstrating how to generate outputs using SimLayerKV. All tests have passed successfully.

Let me know if there's anything else you'd like refined or if additional details are needed!