
Commit

make cache class exportable and executorch compatible
IlyasMoutawwakil committed Jan 20, 2025
1 parent d269417 commit b67b6eb
Showing 3 changed files with 24 additions and 32 deletions.
23 changes: 5 additions & 18 deletions src/transformers/cache_utils.py
@@ -9,12 +9,7 @@
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
 from .utils.deprecation import deprecate_kwarg


@@ -1156,18 +1151,10 @@ def __init__(
                 layer_device = device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            # Notes:
-            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            #    it is not needed anyway)
-            # 2. `torch.export()` requires mutations to be registered as buffers.
-            if not is_torchdynamo_compiling():
-                setattr(self, f"key_cache_{idx}", new_layer_key_cache)
-                setattr(self, f"value_cache_{idx}", new_layer_value_cache)
-                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing compiled graph breaks when updating the cache.
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
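For readers skimming the hunk above, a standalone sketch of what the simplified allocation path does; the shapes, dtype, and layer count below are made up for illustration:

```python
# Sketch: preallocate per-layer key/value tensors once and tag their data pointers as
# static, so in-place cache updates don't break compiled graphs (toy dimensions assumed).
import torch

batch_size, num_heads, max_cache_len, head_dim = 1, 8, 128, 64
cache_shape = (batch_size, num_heads, max_cache_len, head_dim)

key_cache, value_cache = [], []
for _ in range(2):  # two layers, purely illustrative
    new_layer_key_cache = torch.zeros(cache_shape, dtype=torch.float32)
    new_layer_value_cache = torch.zeros(cache_shape, dtype=torch.float32)
    torch._dynamo.mark_static_address(new_layer_key_cache)
    torch._dynamo.mark_static_address(new_layer_value_cache)
    key_cache.append(new_layer_key_cache)
    value_cache.append(new_layer_value_cache)
```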

32 changes: 18 additions & 14 deletions src/transformers/integrations/executorch.py
@@ -16,10 +16,7 @@
 
 
 if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@@ -68,20 +65,21 @@ def __init__(self, model: PreTrainedModel):
             )
 
         self.model = model
+        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
+
         self.static_cache = StaticCache(
             config=self.model.config,
             batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
             dtype=self.model.dtype,
         )
-        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+
         if self.is_causal:
             causal_mask = torch.tril(
-                torch.ones(
-                    self.static_cache.max_cache_len,
-                    self.static_cache.max_cache_len,
-                    dtype=torch.bool,
-                )
+                torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
             )
             self.register_buffer("mask", causal_mask, persistent=False)

@@ -107,15 +105,20 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
         ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
         """
         _, seqlen = input_ids.shape
+
         attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        position_ids = cache_position.unsqueeze(0)
+        past_key_values = self.static_cache
+
         outs = self.model(
             input_ids=input_ids,
             attention_mask=attn_mask,
-            position_ids=cache_position.unsqueeze(0),
+            position_ids=position_ids,
+            past_key_values=past_key_values,
             cache_position=cache_position,
-            past_key_values=self.static_cache,
             use_cache=True,
         )
+
         return outs.logits
 
     @staticmethod
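A small self-contained illustration of the mask slicing used in `forward()` above (`attn_mask = self.mask[cache_position, :seqlen]`); the sizes here are made up:

```python
# Sketch: a lower-triangular boolean mask is built once, then rows are gathered by
# cache_position at call time to get the causal slice for the current input length.
import torch

max_cache_len, seqlen = 8, 3  # toy sizes
mask = torch.tril(torch.ones(max_cache_len, max_cache_len, dtype=torch.bool))

cache_position = torch.arange(seqlen)      # prefill: positions 0, 1, 2
attn_mask = mask[cache_position, :seqlen]  # (3, 3) causal slice
print(attn_mask)
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])
```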
Expand All @@ -142,7 +145,7 @@ def generate(
prompt_token_len = prompt_token_ids.shape[-1]
max_generation_length = prompt_token_len + max_new_tokens
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
if buffer_name.startswith("key_cache"):
max_cache_len = buffer.shape[2]
max_generation_length = min(max_generation_length, max_cache_len)
break
@@ -203,8 +206,9 @@ def convert_and_export_with_cache(
 
         # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
         # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
+        torch_exportable_module = TorchExportableModuleWithStaticCache(model)
         exported_program = torch.export._trace._export(
-            TorchExportableModuleWithStaticCache(model),
+            torch_exportable_module,
             args=(example_input_ids,),
             kwargs={"cache_position": example_cache_position},
             pre_dispatch=False,
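Putting this file's pieces together, a rough usage sketch of the export path. The `convert_and_export_with_cache` and `TorchExportableModuleWithStaticCache.generate` signatures are read off this diff, while the checkpoint name, cache sizes, and the `GenerationConfig`/`cache_config` plumbing are assumptions for illustration only:

```python
# Hedged sketch: export a causal LM with a StaticCache, then decode from the exported
# program. Checkpoint, batch size, and max_cache_len are placeholders, not requirements.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.integrations.executorch import (
    TorchExportableModuleWithStaticCache,
    convert_and_export_with_cache,
)

model_id = "meta-llama/Llama-3.2-1B"  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

# The wrapper reads batch_size / max_cache_len from generation_config.cache_config,
# so a static-cache config is assumed to be attached before export.
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 128},
)

exported_program = convert_and_export_with_cache(model)
prompt_token_ids = tokenizer("Hello", return_tensors="pt").input_ids
generated_ids = TorchExportableModuleWithStaticCache.generate(
    exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=10
)
print(generated_ids)  # token ids; decode with the tokenizer as needed
```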
1 change: 1 addition & 0 deletions tests/models/llama/test_modeling_llama.py
@@ -747,6 +747,7 @@ def test_compile_static_cache(self):
     @slow
     @require_read_token
     def test_export_static_cache(self):
+        # this test only runs with an accelerator, but it doesn't seem to need one?
         if version.parse(torch.__version__) < version.parse("2.4.0"):
             self.skipTest(reason="This test requires torch >= 2.4 to run.")
 
