Add simlayerkv #22
Review comment: please rename simlayerkv_press.py for consistency
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


from dataclasses import dataclass

import math
import inspect

import torch
from torch import nn
import torch.nn.functional as F

from transformers.models.llama.modeling_llama import repeat_kv, rotate_half
from kvpress.presses.snapkv_press import BasePress


@dataclass
class SimLayerKVPress(BasePress):
    initial_tokens: int = 4  # Number of initial tokens to always keep in the KV cache
    recent_tokens: int = 1024  # Number of recent tokens to always keep in the KV cache
    w_last: int = 32  # Window size for analyzing the last tokens in the prefilling stage
    compression_ratio: float = 0.9  # Threshold for identifying lazy layers
    window_size: int = 1  # Size of the window used to compute attention weights (1 for testing)

    def compute_window_attention(self, module, hidden_states, keys, window_size):
        """
        Compute the last window_size queries and the associated attention weights
        for the first q_len - window_size keys.
        """
        bsz, q_len, _ = hidden_states.shape

        # Get the last window_size queries
        if hasattr(module, "q_proj"):
            query_states = module.q_proj(hidden_states[:, -window_size:])
        elif hasattr(module, "qkv_proj"):
            qkv = module.qkv_proj(hidden_states[:, -window_size:])
            query_states = qkv[..., : module.num_heads * module.head_dim]
        else:
            raise NotImplementedError(f"SimLayerKV not yet implemented for {module.__class__}.")

        query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2)

        # Apply RoPE
        if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
            position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device)
            cos, sin = module.rotary_emb(query_states, position_ids)
        else:
            cos, sin = module.rotary_emb(query_states, q_len)
            cos, sin = cos[-window_size:].unsqueeze(0), sin[-window_size:].unsqueeze(0)
        query_states = (query_states * cos) + (rotate_half(query_states) * sin)

        # Compute attention of the last window_size queries over the first q_len - window_size keys
        key_states = repeat_kv(keys, module.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
        attention_mask = torch.ones_like(attn_weights) * float("-inf")
        attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)
        attn_weights += attention_mask
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = attn_weights[..., :-window_size]

        return attn_weights

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs,
    ) -> torch.Tensor:
        # Get basic dimensions
        bsz, num_heads, seq_len, _ = keys.shape
        device, dtype = keys.device, keys.dtype

        # Get attention weights
        if attentions is None:
            attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size)
        else:
            attn_weights = attentions

        # Attention mass on the initial and recent tokens, used to identify lazy layers
        initial_recent_attn = (
            attn_weights[:, :, :, : self.initial_tokens].sum(dim=-1)
            + attn_weights[:, :, :, -self.recent_tokens :].sum(dim=-1)
        )

        # Average attention of the last w_last tokens
        avg_attn = initial_recent_attn[:, :, -self.w_last :].mean()

        # The layer is lazy if this average exceeds the threshold
        is_lazy_layer = avg_attn > self.compression_ratio

        # Create the scores tensor
        scores = torch.zeros(bsz, num_heads, seq_len, device=device, dtype=dtype)

        if is_lazy_layer:
            # Lazy layer: keep only the initial and recent tokens, score the rest 0
            scores[:, :, : self.initial_tokens] = 1.0
            scores[:, :, -self.recent_tokens :] = 1.0
        else:
            scores[:, :, :] = 1.0

        return scores
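For context, here is a minimal, self-contained sketch of the lazy-layer criterion that `score` implements above. The shapes and parameter values are illustrative assumptions, not taken from the PR:

import torch

# Hypothetical attention weights: batch=1, heads=2, queries=16, keys=128
attn = torch.softmax(torch.randn(1, 2, 16, 128), dim=-1)

initial_tokens, recent_tokens, w_last, threshold = 4, 64, 8, 0.9

# Attention mass each query places on the initial + recent keys
mass = attn[..., :initial_tokens].sum(-1) + attn[..., -recent_tokens:].sum(-1)

# Average that mass over the last w_last queries: if the recent queries mostly
# attend to initial/recent tokens, the layer is "lazy" and the remaining
# entries of its KV cache can be dropped
avg_mass = mass[:, :, -w_last:].mean().item()
print(f"average mass = {avg_mass:.3f}, lazy layer = {avg_mass > threshold}")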
Review comment: Instead of creating a new notebook, could you add
@@ -0,0 +1,121 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from transformers import pipeline\n",
    "\n",
    "from kvpress import SimLayerKVPress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pipeline\n",
    "device = \"cuda:0\"\n",
    "ckpt = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "pipe = pipeline(\n",
    "    \"kv-press-text-generation\",\n",
    "    model=ckpt,\n",
    "    device=device,\n",
    "    torch_dtype=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test data for both the prefilling and decoding phases\n",
    "context = \"\"\"SimLayerKV is a method for efficient transformer inference that identifies and optimizes\n",
    "lazy attention layers. It works in two phases: prefilling and decoding. During prefilling, it analyzes\n",
    "the last w_last tokens to identify lazy layers. During decoding, it examines the attention patterns of\n",
    "the first generated token.\"\"\"\n",
    "\n",
    "question = \"\\nWhat are the two phases of SimLayerKV?\"\n",
    "\n",
    "# Tokenize\n",
    "tokens = pipe.tokenizer(context, return_tensors=\"pt\").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the prefilling phase\n",
    "press = SimLayerKVPress(\n",
    "    initial_tokens=4,\n",
    "    recent_tokens=1024,\n",
    "    w_last=32,\n",
    "    window_size=32,\n",
    "    compression_ratio=0.85,  # threshold used by the original implementation for Qwen models\n",
    ")\n",
    "\n",
    "print(\"Testing Prefilling Phase:\")\n",
    "print(\"-\" * 50)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs_without_press = pipe.model(**tokens, output_hidden_states=True)\n",
    "\n",
    "with torch.no_grad(), press(pipe.model):\n",
    "    output_with_press = pipe.model(**tokens)\n",
    "\n",
    "print(f\"Original cache shape: {outputs_without_press.past_key_values[0][0].shape}\")\n",
    "print(f\"Compressed cache shape: {output_with_press.past_key_values[0][0].shape}\")\n",
    "\n",
    "# Test the decoding phase\n",
    "print(\"\\nTesting Decoding Phase:\")\n",
    "print(\"-\" * 50)\n",
    "\n",
    "# Generate with the press\n",
    "output = pipe(\n",
    "    context,\n",
    "    question=question,\n",
    "    press=press,\n",
    "    max_new_tokens=150,\n",
    ")\n",
    "print(\"Generated Answer:\")\n",
    "print(output[\"answer\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
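As a quick sanity check on top of the notebook, the achieved compression could also be reported from the two cache shapes printed above. This is a sketch reusing the notebook's variables; it assumes the sequence axis is dimension 2 of the cached key tensor, consistent with the notebook's own shape prints:

# Sequence length is dimension 2 of the (batch, heads, seq_len, head_dim) key cache
orig_len = outputs_without_press.past_key_values[0][0].shape[2]
pressed_len = output_with_press.past_key_values[0][0].shape[2]
print(f"KV cache compression: {1 - pressed_len / orig_len:.1%}")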
@@ -13,14 +13,15 @@
     ObservedAttentionPress,
     RandomPress,
     SnapKVPress,
+    SimLayerKVPress,
     StreamingLLMPress,
     TOVAPress,
 )
 from tests.fixtures import unit_test_model, unit_test_model_output_attention  # noqa: F401


 def test_presses_run(unit_test_model):  # noqa: F811
-    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
+    for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress,SimLayerKVPress]:
Review comment: missing space after last comma
         for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
             press = cls(compression_ratio=compression_ratio)
             if cls == SnapKVPress:
Review comment: please make sure you merge with main, on main this line exists already