Make Cache a subclass of torch.Tensor #35792

Closed · wants to merge 19 commits
Changes from 11 commits
103 changes: 69 additions & 34 deletions src/transformers/cache_utils.py
@@ -1,5 +1,6 @@
import copy
import importlib.metadata
import inspect
import json
import os
from dataclasses import dataclass
@@ -9,12 +10,7 @@
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg


@@ -24,13 +20,70 @@
logger = logging.get_logger(__name__)


class Cache(torch.nn.Module):
class Cache(torch.Tensor):
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def __init__(self):
super().__init__()
@staticmethod
def __new__(cls, *args, **kwargs):
# We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method

wrapper_kwargs = {}
init_signature = inspect.signature(cls.__init__)
init_arguments = list(init_signature.parameters.keys())

for argument in ["dtype", "device", "requires_grad"]:
if argument in init_arguments:
argument_index = init_arguments.index(argument)
if len(args) > argument_index:
wrapper_kwargs[argument] = args[argument_index]
elif argument in kwargs:
wrapper_kwargs[argument] = kwargs[argument]

self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs)
cls.__init__(self, *args, **kwargs)

return self

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
assert (
func.__name__ in cls.__dict__
), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
return getattr(cls, func.__name__)(*args, **kwargs)

def __repr__(self):
return f"{self.__class__.__name__}()"

def __bool__(self):
# Many call sites check for an existing cache with `if past_key_values:` rather than
# `if past_key_values is not None:`; keep the wrapper truthy so those checks still pass.
return self is not None  # always True

def to(self, *args, **kwargs):
def recursive_to(elm):
if isinstance(elm, dict):
return {k: recursive_to(v) for k, v in elm.items()}
elif isinstance(elm, (list, tuple, set)):
return type(elm)(recursive_to(t) for t in elm)
elif isinstance(elm, torch.Tensor):
return elm.to(*args, **kwargs)
else:
return elm

self.__dict__ = recursive_to(self.__dict__)
return self

def clone(self):
wrapper_kwargs = {
"dtype": getattr(self, "dtype", None),
"device": getattr(self, "device", None),
"requires_grad": getattr(self, "requires_grad", None),
}
new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
new_self.__dict__ = copy.deepcopy(self.__dict__)
return new_self

def update(
self,
@@ -232,7 +285,6 @@ def __init__(
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
compute_dtype: Optional[torch.dtype] = torch.float16,
device: Optional[str] = "cpu",
):
self.backend = backend
self.nbits = nbits
@@ -241,7 +293,6 @@ def __init__
self.q_group_size = q_group_size
self.residual_length = residual_length
self.compute_dtype = compute_dtype
self.device = device

def validate(self):
"""Validates if the arguments passed are correct"""
@@ -304,10 +355,9 @@ class StaticCacheConfig(CacheConfig):

cache_implementation = "static"

def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
def __init__(self, batch_size: int, max_cache_len: int):
self.batch_size = batch_size
self.max_cache_len = max_cache_len
self.device = device

def validate(self):
"""Validates if the arguments passed are correct"""
@@ -675,9 +725,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
self.axis_key = cache_config.axis_key
self.axis_value = cache_config.axis_value
self.compute_dtype = cache_config.compute_dtype
self.device = cache_config.device

super().__init__()
self.to(cache_config.device)

def update(
self,
@@ -1125,8 +1173,6 @@ def __init__(
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1144,18 +1190,10 @@
layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

@@ -1643,7 +1681,6 @@ def __init__(
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
@@ -1846,7 +1883,7 @@ def __init__(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype

self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
@@ -1989,9 +2026,7 @@ def __init__(
super(Cache, self).__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
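For readers following the diff, the new base class builds on PyTorch's tensor wrapper-subclass mechanism: the Cache object itself is a zero-sized tensor, and the actual key/value storage lives in ordinary Python attributes. A minimal, self-contained sketch of that pattern (toy names, not part of this PR):

import torch


class MiniCache(torch.Tensor):
    """Toy stand-in for the Cache wrapper; the wrapper tensor itself holds no data."""

    @staticmethod
    def __new__(cls, dtype=torch.float32, device="cpu"):
        # Zero-sized wrapper tensor; dtype/device merely describe the cached data.
        self = torch.Tensor._make_wrapper_subclass(cls, (), dtype=dtype, device=torch.device(device))
        self.key_cache, self.value_cache = [], []
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Tensor ops are not expected to hit the wrapper directly.
        raise NotImplementedError(f"{cls.__name__} does not implement {func.__name__}")

    def __repr__(self):
        return f"{self.__class__.__name__}(layers={len(self.key_cache)})"


cache = MiniCache()
print(isinstance(cache, torch.Tensor))  # True, so tracing code can treat the cache as a tensor input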
11 changes: 6 additions & 5 deletions src/transformers/generation/utils.py
@@ -731,6 +731,7 @@ def _expand_dict_for_generation(dict_to_expand):
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and not isinstance(dict_to_expand[key], Cache)
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
"""
if data is None:
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format
elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, split_size, num_hidden_layers)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
@@ -4632,13 +4633,13 @@ def _concat(data):
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
if isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
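The reordering in `_split` and `_concat` above is required because a `Cache` is now also a `torch.Tensor`, so the plain-tensor branch would otherwise capture cache objects before the cache-specific branch runs. A small illustration with a hypothetical stand-in class (not the real `Cache`):

import torch


class FakeCache(torch.Tensor):
    """Hypothetical stand-in for transformers.Cache after this change."""

    @staticmethod
    def __new__(cls):
        return torch.Tensor._make_wrapper_subclass(cls, ())


def dispatch(data):
    if isinstance(data, FakeCache):     # cache check must come first
        return "cache path"
    if isinstance(data, torch.Tensor):  # would otherwise also match FakeCache instances
        return "tensor path"
    return "other"


print(dispatch(FakeCache()))        # cache path
print(dispatch(torch.zeros(2)))     # tensor path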
29 changes: 16 additions & 13 deletions src/transformers/integrations/executorch.py
@@ -16,10 +16,7 @@


if is_torch_available():
from transformers import (
PreTrainedModel,
StaticCache,
)
from transformers import PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@@ -68,21 +65,22 @@ def __init__(self, model: PreTrainedModel):
)

self.model = model
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)

self.static_cache = StaticCache(
config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
)
self.register_buffer("mask", causal_mask, persistent=False)

@@ -108,15 +106,20 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape

attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache

outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
past_key_values=self.static_cache,
use_cache=True,
)

return outs.logits

@staticmethod
Expand All @@ -143,7 +146,7 @@ def generate(
prompt_token_len = prompt_token_ids.shape[-1]
max_generation_length = prompt_token_len + max_new_tokens
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
if buffer_name.startswith("key_cache"):
max_cache_len = buffer.shape[2]
max_generation_length = min(max_generation_length, max_cache_len)
break
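Because `StaticCache` is no longer an `nn.Module` that registers its own buffers, the exportable wrapper now registers the per-layer cache tensors on itself, and exported buffer names lose the `static_cache.` prefix; that is why `generate` above matches `key_cache` instead of `static_cache.key_cache`. A toy module showing the resulting naming and the max-cache-length lookup (illustrative shapes, not the real export path):

import torch


class ToyExportable(torch.nn.Module):
    """Toy wrapper that registers per-layer cache buffers directly on itself."""

    def __init__(self, num_layers=2, max_cache_len=8, head_dim=4):
        super().__init__()
        for i in range(num_layers):
            self.register_buffer(f"key_cache_{i}", torch.zeros(1, 1, max_cache_len, head_dim), persistent=False)
            self.register_buffer(f"value_cache_{i}", torch.zeros(1, 1, max_cache_len, head_dim), persistent=False)


module = ToyExportable()
for name, buf in module.named_buffers():
    if name.startswith("key_cache"):
        print(name, "-> max_cache_len =", buf.shape[2])  # mirrors the lookup in `generate`
        break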
1 change: 0 additions & 1 deletion src/transformers/models/jamba/modeling_jamba.py
@@ -214,7 +214,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):

def __init__(self, config, batch_size, dtype=torch.float16, device=None):
super().__init__()
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
intermediate_size = config.mamba_expand * config.hidden_size
5 changes: 1 addition & 4 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -128,7 +128,6 @@ class ZambaHybridDynamicCache(DynamicCache):
"""

def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
self.intermediate_size = config.mamba_expand * config.hidden_size
@@ -138,9 +137,7 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.conv_states = []
self.ssm_states = []
self.transformer_layers = []
self._modules = {}
self._parameters = {}
self._buffers = {}

for i in range(config.num_hidden_layers):
self.conv_states += [
torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)