Make Cache a subclass of torch.Tensor #35792

Closed
wants to merge 19 commits into from
Changes from 13 commits
111 changes: 85 additions & 26 deletions src/transformers/cache_utils.py
@@ -1,5 +1,6 @@
import copy
import importlib.metadata
import inspect
import json
import os
from dataclasses import dataclass
@@ -9,12 +10,7 @@
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg


@@ -24,13 +20,88 @@
logger = logging.get_logger(__name__)


class Cache(torch.nn.Module):
class Cache(torch.Tensor):
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def __init__(self):
super().__init__()
@staticmethod
def __new__(cls, *args, **kwargs):
# We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method

wrapper_kwargs = {}
init_signature = inspect.signature(cls.__init__)
init_arguments = list(init_signature.parameters.keys())
init_defaults = {
k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
}

for argument in ["dtype", "device", "requires_grad"]:
if argument in init_arguments:
argument_index = init_arguments.index(argument)
if len(args) > argument_index:
wrapper_kwargs[argument] = args[argument_index]
elif argument in kwargs:
wrapper_kwargs[argument] = kwargs[argument]
elif argument in init_defaults:
wrapper_kwargs[argument] = init_defaults[argument]

self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs)
cls.__init__(self, *args, **kwargs)

return self

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
assert (
func.__name__ in cls.__dict__
), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
return getattr(cls, func.__name__)(*args, **kwargs)

def __repr__(self):
return f"{self.__class__.__name__}()"

def __bool__(self):
# In many places, `past_key_values` is truthiness-checked with `if past_key_values:`; make the wrapper
# always truthy so those checks keep working (ideally call sites would use `is not None` instead).
return self is not None  # always True for an instantiated cache

def __setattr__(self, name, value):
# for the many places where `self.device` or `self.dtype` is set
if name in ["dtype", "device"]:
self.to(value)
else:
return super().__setattr__(name, value)

def to(self, *args, **kwargs):
# start from the wrapper's current metadata
wrapper_kwargs = {
"dtype": getattr(self, "dtype", None),
"device": getattr(self, "device", None),
"requires_grad": getattr(self, "requires_grad", False),
}

# apply any dtype/device overrides passed to `to(...)`
for arg in list(args) + list(kwargs.values()):
if isinstance(arg, torch.device):
wrapper_kwargs["device"] = arg
elif isinstance(arg, torch.dtype):
wrapper_kwargs["dtype"] = arg

# return a new wrapper with the updated metadata, sharing this cache's attributes
new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
new_tensor_wrapper.__dict__ = self.__dict__
return new_tensor_wrapper

def clone(self):
wrapper_kwargs = {
"dtype": getattr(self, "dtype", None),
"device": getattr(self, "device", None),
"requires_grad": getattr(self, "requires_grad", False),
}
# deep-copy the attributes so the clone owns its own cache tensors
new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
new_tensor_wrapper.__dict__ = copy.deepcopy(self.__dict__)
return new_tensor_wrapper

def update(
self,
@@ -665,7 +736,6 @@ class QuantizedCache(DynamicCache):
"""

def __init__(self, cache_config: QuantizedCacheConfig) -> None:
super().__init__()
self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []

@@ -677,8 +747,6 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
self.compute_dtype = cache_config.compute_dtype
self.device = cache_config.device

super().__init__()

def update(
self,
key_states: torch.Tensor,
@@ -1110,13 +1178,12 @@ def __init__(
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)

self.dtype = dtype
self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

@@ -1144,18 +1211,10 @@ def __init__(
layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

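For readers unfamiliar with tensor wrapper subclasses, the standalone sketch below (hypothetical, not part of this diff; class and attribute names are illustrative) shows the `_make_wrapper_subclass` pattern the new `Cache.__new__` relies on: the wrapper is a shapeless tensor that only carries `dtype`/`device` metadata and passes `isinstance(..., torch.Tensor)` checks, while the real cached data lives in ordinary Python attributes.

import torch


class CacheSketch(torch.Tensor):
    """Minimal, hypothetical illustration of the wrapper-subclass pattern used by the new Cache base class."""

    @staticmethod
    def __new__(cls, dtype=torch.float32, device="cpu"):
        # A shapeless wrapper: no storage is allocated, only dtype/device metadata is carried,
        # mirroring how Cache.__new__ pulls `dtype`/`device` out of the subclass __init__ signature.
        return torch.Tensor._make_wrapper_subclass(cls, (), dtype=dtype, device=device)

    def __init__(self, dtype=torch.float32, device="cpu"):
        # The actual cached data lives in ordinary attributes, not in tensor storage.
        self.key_cache, self.value_cache = [], []

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Reject torch ops the subclass does not implement, like the assert in Cache.__torch_dispatch__.
        raise NotImplementedError(f"{cls.__name__} does not implement {func}")


cache = CacheSketch(dtype=torch.float16)
print(isinstance(cache, torch.Tensor))  # True: the cache can now flow through tensor-only code paths
print(cache.dtype, cache.device)        # torch.float16 cpu, carried by the wrapper itself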
11 changes: 6 additions & 5 deletions src/transformers/generation/utils.py
@@ -731,6 +731,7 @@ def _expand_dict_for_generation(dict_to_expand):
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and not isinstance(dict_to_expand[key], Cache)
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
"""
if data is None:
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format
elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, split_size, num_hidden_layers)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
@@ -4632,13 +4633,13 @@ def _concat(data):
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
if isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
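The reordering in `_split` and `_concat` above matters because a `Cache` instance now also satisfies `isinstance(data, torch.Tensor)`: the cache branch has to be tested before the generic tensor branch, and `_expand_dict_for_generation` gains an explicit `not isinstance(..., Cache)` guard for the same reason. A minimal illustration (assumes a transformers build with this PR's Cache-as-Tensor change; otherwise the first check prints False):

import torch
from transformers import DynamicCache

past_key_values = DynamicCache()
print(isinstance(past_key_values, torch.Tensor))  # True with this PR, so a Tensor check alone is ambiguous
if isinstance(past_key_values, DynamicCache):     # the cache check must come first, as in _split/_concat
    print("handled as a cache")
elif isinstance(past_key_values, torch.Tensor):   # only plain tensors should reach this branch
    print("handled as a tensor")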
29 changes: 16 additions & 13 deletions src/transformers/integrations/executorch.py
@@ -16,10 +16,7 @@


if is_torch_available():
from transformers import (
PreTrainedModel,
StaticCache,
)
from transformers import PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@@ -68,21 +65,22 @@ def __init__(self, model: PreTrainedModel):
)

self.model = model
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)

self.static_cache = StaticCache(
config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size,
max_cache_len=self.model.generation_config.cache_config.max_cache_len,
dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
)
self.register_buffer("mask", causal_mask, persistent=False)

@@ -108,15 +106,20 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape

attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache

outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
past_key_values=self.static_cache,
use_cache=True,
)

return outs.logits

@staticmethod
@@ -143,7 +146,7 @@ def generate(
prompt_token_len = prompt_token_ids.shape[-1]
max_generation_length = prompt_token_len + max_new_tokens
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
if buffer_name.startswith("key_cache"):
max_cache_len = buffer.shape[2]
max_generation_length = min(max_generation_length, max_cache_len)
break
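Because the cache is no longer an `nn.Module`, the export wrapper above registers the `StaticCache` key/value tensors as its own non-persistent buffers, which is why `generate` now looks for buffer names starting with `key_cache` instead of `static_cache.key_cache`. A standalone sketch of that lookup (hypothetical layer count and shapes, not taken from the PR):

import torch
from torch import nn


class ExportSketch(nn.Module):
    def __init__(self, num_layers=2, max_cache_len=8):
        super().__init__()
        for i in range(num_layers):
            # Non-persistent buffers: visible to named_buffers() and to export, but kept out of the state dict.
            self.register_buffer(f"key_cache_{i}", torch.zeros(1, 1, max_cache_len, 4), persistent=False)
            self.register_buffer(f"value_cache_{i}", torch.zeros(1, 1, max_cache_len, 4), persistent=False)


module = ExportSketch()
for buffer_name, buffer in module.named_buffers():
    if buffer_name.startswith("key_cache"):
        print(buffer_name, "max_cache_len =", buffer.shape[2])  # dimension 2 holds the cache length
        break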
1 change: 0 additions & 1 deletion src/transformers/models/jamba/modeling_jamba.py
@@ -214,7 +214,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):

def __init__(self, config, batch_size, dtype=torch.float16, device=None):
super().__init__()
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
intermediate_size = config.mamba_expand * config.hidden_size
5 changes: 1 addition & 4 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -128,7 +128,6 @@ class ZambaHybridDynamicCache(DynamicCache):
"""

def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
self.intermediate_size = config.mamba_expand * config.hidden_size
@@ -138,9 +137,7 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.conv_states = []
self.ssm_states = []
self.transformer_layers = []
self._modules = {}
self._parameters = {}
self._buffers = {}

for i in range(config.num_hidden_layers):
self.conv_states += [
torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)