Add simlayerkv #22 (Closed)

wants to merge 10 commits
2 changes: 2 additions & 0 deletions kvpress/__init__.py
@@ -10,6 +10,7 @@
from kvpress.presses.observed_attention_press import ObservedAttentionPress
from kvpress.presses.random_press import RandomPress
from kvpress.presses.snapkv_press import SnapKVPress
from kvpress.presses.simlayerkv import SimLayerKVPress
from kvpress.presses.streaming_llm_press import StreamingLLMPress
from kvpress.presses.tova_press import TOVAPress

@@ -20,6 +21,7 @@
"ObservedAttentionPress",
"RandomPress",
"SnapKVPress",
"SimLayerKVPress",
"StreamingLLMPress",
"TOVAPress",
"KVPressTextGenerationPipeline",
2 changes: 1 addition & 1 deletion kvpress/presses/expected_attention_press.py
@@ -134,4 +134,4 @@ def score(
# Add back the sink tokens. Use max score to make sure they are not pruned.
scores = F.pad(scores, (self.n_sink, 0), value=scores.max().item())

return scores
return scores
Collaborator: please make sure you merge with main; on main this line already exists.

104 changes: 104 additions & 0 deletions kvpress/presses/simlayerkv.py
Collaborator: please rename the file to simlayerkv_press.py for consistency.

@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


import inspect
import math
from dataclasses import dataclass

import torch
from torch import nn
from transformers.models.llama.modeling_llama import repeat_kv, rotate_half

from kvpress.presses.base_press import BasePress


@dataclass
class SimLayerKVPress(BasePress):
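    """
    SimLayerKV (https://arxiv.org/abs/2410.13846) observes that in "lazy" layers,
    attention mass concentrates on the initial and most recent tokens, so the KV
    cache of those layers can be trimmed to just these tokens, while non-lazy
    layers keep their full cache.
    """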
initial_tokens: int = 4 # Number of initial tokens to always keep in KV cache
recent_tokens: int = 1024 # Number of recent tokens to always keep in KV cache
w_last: int = 32 # Window size for analyzing last tokens in prefilling stage
compression_ratio: float = 0.9 # Threshold for identifying lazy layers
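    # Note: in this press, compression_ratio is used as the lazy-layer threshold
    # from the paper rather than as the fraction of tokens to prune (its meaning
    # in the other presses)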
window_size: int = 1 # Size of window for computing attention weights (1 for testing)


    def compute_window_attention(self, module, hidden_states, keys, window_size):
"""
Compute the last window_size queries and associated attention weights for the first q_len - window_size keys.
"""

bsz, q_len, _ = hidden_states.shape

# Get last window_size queries
        if hasattr(module, "q_proj"):
            query_states = module.q_proj(hidden_states[:, -window_size:])
        elif hasattr(module, "qkv_proj"):
            qkv = module.qkv_proj(hidden_states[:, -window_size:])
            query_states = qkv[..., : module.num_heads * module.head_dim]
        else:
            raise NotImplementedError(f"SimLayerKV not yet implemented for {module.__class__}.")

query_states = query_states.view(bsz, window_size, module.num_heads, module.head_dim).transpose(1, 2)

# Apply RoPE
if "position_ids" in inspect.signature(module.rotary_emb.forward).parameters:
position_ids = torch.arange(q_len - window_size, q_len).unsqueeze(0).to(query_states.device)
cos, sin = module.rotary_emb(query_states, position_ids)
else:
cos, sin = module.rotary_emb(query_states, q_len)
            cos, sin = cos[-window_size:].unsqueeze(0), sin[-window_size:].unsqueeze(0)
query_states = (query_states * cos) + (rotate_half(query_states) * sin)

# Compute attention for first q_len - window_size tokens
key_states = repeat_kv(keys, module.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(module.head_dim)
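        # Causal mask: the i-th of the last window_size queries may only attend
        # to keys at positions <= q_len - window_size + i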
attention_mask = torch.ones_like(attn_weights) * float("-inf")
attention_mask = torch.triu(attention_mask, diagonal=q_len - window_size + 1)
attn_weights += attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = attn_weights[..., :-window_size]

return attn_weights

def score(
self,
module: nn.Module,
hidden_states: torch.Tensor,
keys: torch.Tensor,
values: torch.Tensor,
attentions: torch.Tensor,
        kwargs: dict,
) -> torch.Tensor:
# Get basic dimensions
bsz, num_heads, seq_len, _ = keys.shape
device, dtype = keys.device, keys.dtype

# Get attention weights
if attentions is None:
attn_weights = self.compute_window_attention(module, hidden_states, keys, self.window_size)
else:
attn_weights = attentions
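        # (with eager attention weights, only the last w_last query rows are used below)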

        # Identify lazy layers: attention mass falling on the initial and recent tokens
        initial_recent_attn = (
            attn_weights[..., : self.initial_tokens].sum(dim=-1)
            + attn_weights[..., -self.recent_tokens :].sum(dim=-1)
        )

        # Average that mass over the last w_last query positions
        avg_attn = initial_recent_attn[:, :, -self.w_last :].mean()

        # The layer is lazy if the average exceeds the threshold
        is_lazy_layer = avg_attn > self.compression_ratio

        # Build scores: 1.0 marks tokens to keep
        scores = torch.zeros(bsz, num_heads, seq_len, device=device, dtype=dtype)

        if is_lazy_layer:
            # Lazy layer: keep only the initial and recent tokens, the rest scores 0
            scores[:, :, : self.initial_tokens] = 1.0
            scores[:, :, -self.recent_tokens :] = 1.0
        else:
            # Non-lazy layer: keep everything
            scores[:, :, :] = 1.0

return scores
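
For intuition, here is a minimal, self-contained sketch of the lazy-layer test that score() applies, with toy shapes and values (illustrative only, not taken from the PR):

import torch

# Toy attention map: 1 batch, 2 heads, 8 queries, 16 keys; each row sums to 1
attn = torch.softmax(torch.randn(1, 2, 8, 16), dim=-1)
initial_tokens, recent_tokens, w_last, threshold = 2, 4, 4, 0.9

# Attention mass falling on the initial and recent keys, per query
mass = attn[..., :initial_tokens].sum(-1) + attn[..., -recent_tokens:].sum(-1)

# Average over the last w_last queries; the layer is "lazy" if it exceeds the threshold
is_lazy = bool(mass[:, :, -w_last:].mean() > threshold)
print(is_lazy)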
121 changes: 121 additions & 0 deletions notebooks/simlayerkv.ipynb
Collaborator: instead of creating a new notebook, could you add SimLayerKVPress to the imports of wikipedia_demo.ipynb?

@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import pipeline\n",
"\n",
"from kvpress import SimLayerKVPress"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load pipeline\n",
"device = \"cuda:0\"\n",
"ckpt = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
"pipe = pipeline(\n",
"    \"kv-press-text-generation\",\n",
"    model=ckpt,\n",
"    device=device,\n",
"    torch_dtype=\"auto\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test data for both prefilling and decoding\n",
"context = \"\"\"SimLayerKV is a method for efficient transformer inference that identifies and optimizes \n",
"lazy attention layers. It works in two phases: prefilling and decoding. During prefilling, it analyzes \n",
"the last w_last tokens to identify lazy layers. During decoding, it examines the attention patterns of \n",
"the first generated token.\"\"\"\n",
"\n",
"question = \"\\nWhat are the two phases of SimLayerKV?\"\n",
"\n",
"# Tokenize\n",
"tokens = pipe.tokenizer(context, return_tensors=\"pt\").to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test prefilling phase\n",
"press = SimLayerKVPress(\n",
"    initial_tokens=4,\n",
"    recent_tokens=1024,\n",
"    w_last=32,\n",
"    window_size=32,\n",
"    compression_ratio=0.85,  # threshold of 0.85 for Qwen models, per the original implementation\n",
")\n",
"\n",
"print(\"Testing Prefilling Phase:\")\n",
"print(\"-\" * 50)\n",
"\n",
"with torch.no_grad():\n",
"    outputs_without_press = pipe.model(**tokens, output_hidden_states=True)\n",
"\n",
"with torch.no_grad(), press(pipe.model):\n",
"    output_with_press = pipe.model(**tokens)\n",
"\n",
"print(f\"Original cache shape: {outputs_without_press.past_key_values[0][0].shape}\")\n",
"print(f\"Compressed cache shape: {output_with_press.past_key_values[0][0].shape}\")\n",
"\n",
"# Test decoding phase\n",
"print(\"\\nTesting Decoding Phase:\")\n",
"print(\"-\" * 50)\n",
"\n",
"# Generate with press\n",
"output = pipe(\n",
"    context,\n",
"    question=question,\n",
"    press=press,\n",
"    max_new_tokens=150,\n",
")\n",
"print(\"Generated Answer:\")\n",
"print(output[\"answer\"])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 2 additions & 1 deletion tests/presses/test_presses.py
@@ -13,14 +13,15 @@
ObservedAttentionPress,
RandomPress,
SnapKVPress,
SimLayerKVPress,
StreamingLLMPress,
TOVAPress,
)
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401


def test_presses_run(unit_test_model): # noqa: F811
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress,SimLayerKVPress]:
Collaborator: missing space after the last comma.

for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
press = cls(compression_ratio=compression_ratio)
if cls == SnapKVPress: