From d4b631edd063d7a121cab064c9aa70e6cc99a8f3 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 14:17:28 +0100 Subject: [PATCH 01/16] use tensor cache instead of module cache --- src/transformers/cache_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index e616adbe6798..7b8601caf99d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -24,7 +24,7 @@ logger = logging.get_logger(__name__) -class Cache(torch.nn.Module): +class Cache(torch.Tensor): """ Base, abstract class for all caches. The actual data structure is specific to each subclass. """ From a77a94b2097338fe190ee535beb55d995e82a7d2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 14:43:41 +0100 Subject: [PATCH 02/16] unproxy cache --- src/transformers/utils/fx.py | 52 ++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 45fa3d9ca68c..834554472b78 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -823,28 +823,28 @@ def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: return cache_proxy_factory_fn -# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. -ProxyableCache = HFProxyableClassMeta( - "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache) -) -ProxyableDynamicCache = HFProxyableClassMeta( - "ProxyableDynamicCache", - (DynamicCache,), - {}, - proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache), -) -ProxyableSinkCache = HFProxyableClassMeta( - "ProxyableSinkCache", - (SinkCache,), - {}, - proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache), -) -ProxyableStaticCache = HFProxyableClassMeta( - "ProxyableStaticCache", - (StaticCache,), - {}, - proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache), -) +# # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. 
+# ProxyableCache = HFProxyableClassMeta( +# "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache) +# ) +# ProxyableDynamicCache = HFProxyableClassMeta( +# "ProxyableDynamicCache", +# (DynamicCache,), +# {}, +# proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache), +# ) +# ProxyableSinkCache = HFProxyableClassMeta( +# "ProxyableSinkCache", +# (SinkCache,), +# {}, +# proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache), +# ) +# ProxyableStaticCache = HFProxyableClassMeta( +# "ProxyableStaticCache", +# (StaticCache,), +# {}, +# proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache), +# ) def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None): @@ -879,10 +879,10 @@ class HFTracer(Tracer): "tril", ] _CLASSES_TO_PATCH = { - Cache: ProxyableCache, - DynamicCache: ProxyableDynamicCache, - SinkCache: ProxyableSinkCache, - StaticCache: ProxyableStaticCache, + # Cache: ProxyableCache, + # DynamicCache: ProxyableDynamicCache, + # SinkCache: ProxyableSinkCache, + # StaticCache: ProxyableStaticCache, } supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) From 45bb39bb803cab7a33b69d9dcc98d0e3f0ac56e3 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 17:01:49 +0100 Subject: [PATCH 03/16] torch tensor subclassing --- src/transformers/cache_utils.py | 31 +++++++++++++++++++--------- src/transformers/generation/utils.py | 11 +++++----- src/transformers/utils/fx.py | 20 +++++++++--------- tests/test_modeling_common.py | 2 +- 4 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7b8601caf99d..7538b791a9c1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -29,8 +29,24 @@ class Cache(torch.Tensor): Base, abstract class for all caches. The actual data structure is specific to each subclass. """ - def __init__(self): - super().__init__() + def __new__(cls, *args, dtype=None, device=None, **kwargs): + # We use a tensor wrapper to allow for torch script tracing when using the cache as an input to nn.Module + # dtype and device don't need to be in the subclass's __init__ (unless they are used for something) + # But they can be passed as arguments when instantiating the cache (e.g. `DynamicCache(dtype=dtype)`) + # And will be accessible as `cache.dtype` and `cache.device` + self = torch.Tensor._make_wrapper_subclass(cls, (), dtype=dtype, device=device, requires_grad=False) + self.__init__(*args, **kwargs) + return self + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + assert ( + func.__name__ in cls.__dict__ + ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}" + return getattr(cls, func.__name__)(*args, **kwargs) + + def __repr__(self): + return f"{self.__class__.__name__}()" def update( self, @@ -677,8 +693,6 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: self.compute_dtype = cache_config.compute_dtype self.device = cache_config.device - super().__init__() - def update( self, key_states: torch.Tensor, @@ -1121,7 +1135,6 @@ def __init__( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1145,8 +1158,8 @@ def __init__( # it is not needed anyway) # 2. 
`torch.export()` requires mutations to be registered as buffers. if not is_torchdynamo_compiling(): - self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device)) - self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device)) + setattr(self, f"key_cache_{idx}", new_layer_key_cache) + setattr(self, f"value_cache_{idx}", new_layer_value_cache) new_layer_key_cache = getattr(self, f"key_cache_{idx}") new_layer_value_cache = getattr(self, f"value_cache_{idx}") torch._dynamo.mark_static_address(new_layer_key_cache) @@ -1619,7 +1632,6 @@ def __init__( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) @@ -1804,7 +1816,7 @@ def __init__( f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " "v4.49. Use the more precisely named 'max_batch_size' argument instead." ) - self.dtype = dtype + self.max_batch_size = batch_size or max_batch_size self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size @@ -1934,7 +1946,6 @@ def __init__( self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len self.device = torch.device(device) if layer_device_map is None else layer_device_map[0] self.offload_device = torch.device(offload_device) - self.dtype = dtype if dtype is not None else torch.float32 # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 655a388cb70d..29cdf77f8981 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -731,6 +731,7 @@ def _expand_dict_for_generation(dict_to_expand): key != "cache_position" and dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor) + and not isinstance(dict_to_expand[key], Cache) ): dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) return dict_to_expand @@ -4552,13 +4553,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int = """ if data is None: return [None] * (full_batch_size // split_size) - if isinstance(data, torch.Tensor): - return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] # New cache format elif isinstance(data, DynamicCache) or ( isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) ): return data.batch_split(full_batch_size, split_size, num_hidden_layers) + if isinstance(data, torch.Tensor): + return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] elif isinstance(data, tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0], tuple): @@ -4665,13 +4666,13 @@ def _concat(data): """ if any(data is None for data in data): return None - if isinstance(data[0], torch.Tensor): - return torch.cat(data, dim=0) # New cache format - elif isinstance(data[0], DynamicCache): + if isinstance(data[0], DynamicCache): return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) elif isinstance(data[0], EncoderDecoderCache): return 
EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + elif isinstance(data[0], torch.Tensor): + return torch.cat(data, dim=0) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 834554472b78..885f145ff70e 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -811,16 +811,16 @@ def _proxies_to_metas(v): return v -def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]: - def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: - global _CURRENT_TRACER - if not isinstance(_CURRENT_TRACER, HFTracer): - raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.") - cache_proxy = HFCacheProxy(n, _CURRENT_TRACER) - cache_proxy.install_orig_cache_cls(orig_cache_cls) - return cache_proxy - - return cache_proxy_factory_fn +# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]: +# def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: +# global _CURRENT_TRACER +# if not isinstance(_CURRENT_TRACER, HFTracer): +# raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.") +# cache_proxy = HFCacheProxy(n, _CURRENT_TRACER) +# cache_proxy.install_orig_cache_cls(orig_cache_cls) +# return cache_proxy + +# return cache_proxy_factory_fn # # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 965d75936933..b9776fd77c69 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2397,7 +2397,7 @@ def recursive_check(tuple_object, dict_object): elif tuple_object is None: return # model might return non-tensors objects (e.g. 
Cache class) - elif isinstance(tuple_object, torch.Tensor): + elif isinstance(tuple_object, torch.Tensor) and not isinstance(tuple_object, Cache): self.assertTrue( torch.allclose( set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 From 8606594ad437b5334ff67108639cd1678d485b89 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 17:08:37 +0100 Subject: [PATCH 04/16] fix boolean evaluation --- src/transformers/cache_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 7538b791a9c1..ca0aaafded32 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -45,6 +45,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}" return getattr(cls, func.__name__)(*args, **kwargs) + def __bool__(self): + # in many places, past_key_values is checked for not being None using `if past_key_values:` + return True + def __repr__(self): return f"{self.__class__.__name__}()" From 95c1686ee0b589cdd2569d6cc6d89356feda25d9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 17:09:21 +0100 Subject: [PATCH 05/16] style --- src/transformers/utils/fx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 885f145ff70e..f88c04407350 100755 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -35,7 +35,7 @@ from torch.fx.proxy import ParameterProxy from .. import logging -from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache +from ..cache_utils import Cache from ..modeling_utils import PretrainedConfig, PreTrainedModel from ..models.auto import get_values from ..models.auto.modeling_auto import ( From d269417aab54c0cb06de4a64ac06697d7c597ba5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 17:21:49 +0100 Subject: [PATCH 06/16] fix zamba and jamba dynamic cache --- src/transformers/models/jamba/modeling_jamba.py | 1 - src/transformers/models/zamba/modeling_zamba.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fd6b1bae31b1..549d5bea9527 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -214,7 +214,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache): def __init__(self, config, batch_size, dtype=torch.float16, device=None): super().__init__() - self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba intermediate_size = config.mamba_expand * config.hidden_size diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 761c799bdcdc..8db417e61d0f 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -128,7 +128,6 @@ class ZambaHybridDynamicCache(DynamicCache): """ def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba self.intermediate_size = config.mamba_expand * config.hidden_size @@ -138,9 +137,7 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.conv_states = 
[] self.ssm_states = [] self.transformer_layers = [] - self._modules = {} - self._parameters = {} - self._buffers = {} + for i in range(config.num_hidden_layers): self.conv_states += [ torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype) From b67b6eb9b2cb447cec181c2ae3c5aa0110194b33 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 18:47:30 +0100 Subject: [PATCH 07/16] make cache class exportable and executorch compatible --- src/transformers/cache_utils.py | 23 ++++----------- src/transformers/integrations/executorch.py | 32 ++++++++++++--------- tests/models/llama/test_modeling_llama.py | 1 + 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ca0aaafded32..0074cc44c5d5 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -9,12 +9,7 @@ from packaging import version from .configuration_utils import PretrainedConfig -from .utils import ( - is_hqq_available, - is_optimum_quanto_available, - is_torchdynamo_compiling, - logging, -) +from .utils import is_hqq_available, is_optimum_quanto_available, logging from .utils.deprecation import deprecate_kwarg @@ -1156,18 +1151,10 @@ def __init__( layer_device = device new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device) - # Notes: - # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case - # it is not needed anyway) - # 2. `torch.export()` requires mutations to be registered as buffers. - if not is_torchdynamo_compiling(): - setattr(self, f"key_cache_{idx}", new_layer_key_cache) - setattr(self, f"value_cache_{idx}", new_layer_value_cache) - new_layer_key_cache = getattr(self, f"key_cache_{idx}") - new_layer_value_cache = getattr(self, f"value_cache_{idx}") - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, + # preventing compiled graph breaks when updating the cache. 
+ torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_key_cache) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 258017f14180..41c5f2ff1cf1 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -16,10 +16,7 @@ if is_torch_available(): - from transformers import ( - PreTrainedModel, - StaticCache, - ) + from transformers import PreTrainedModel, StaticCache from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3 @@ -68,20 +65,21 @@ def __init__(self, model: PreTrainedModel): ) self.model = model + self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) + self.static_cache = StaticCache( config=self.model.config, batch_size=self.model.generation_config.cache_config.batch_size, max_cache_len=self.model.generation_config.cache_config.max_cache_len, dtype=self.model.dtype, ) - self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures) + for i in range(len(self.static_cache.key_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False) + self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False) + if self.is_causal: causal_mask = torch.tril( - torch.ones( - self.static_cache.max_cache_len, - self.static_cache.max_cache_len, - dtype=torch.bool, - ) + torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool) ) self.register_buffer("mask", causal_mask, persistent=False) @@ -107,15 +105,20 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box. """ _, seqlen = input_ids.shape + attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None + position_ids = cache_position.unsqueeze(0) + past_key_values = self.static_cache + outs = self.model( input_ids=input_ids, attention_mask=attn_mask, - position_ids=cache_position.unsqueeze(0), + position_ids=position_ids, + past_key_values=past_key_values, cache_position=cache_position, - past_key_values=self.static_cache, use_cache=True, ) + return outs.logits @staticmethod @@ -142,7 +145,7 @@ def generate( prompt_token_len = prompt_token_ids.shape[-1] max_generation_length = prompt_token_len + max_new_tokens for buffer_name, buffer in exported_program.named_buffers(): - if buffer_name.startswith("static_cache.key_cache"): + if buffer_name.startswith("key_cache"): max_cache_len = buffer.shape[2] max_generation_length = min(max_generation_length, max_cache_len) break @@ -203,8 +206,9 @@ def convert_and_export_with_cache( # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. 
+ torch_exportable_module = TorchExportableModuleWithStaticCache(model) exported_program = torch.export._trace._export( - TorchExportableModuleWithStaticCache(model), + torch_exportable_module, args=(example_input_ids,), kwargs={"cache_position": example_cache_position}, pre_dispatch=False, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 664616306d88..06248b1e88bb 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -747,6 +747,7 @@ def test_compile_static_cache(self): @slow @require_read_token def test_export_static_cache(self): + # this test only run with an accelerator but it doesn't need an accelerator ? if version.parse(torch.__version__) < version.parse("2.4.0"): self.skipTest(reason="This test requires torch >= 2.4 to run.") From 4950a9e3f0125f668591408d7186d43e649b7306 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 13:49:01 +0100 Subject: [PATCH 08/16] extract wrapper kwargs from init signature to correctly instantate --- src/transformers/cache_utils.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 0074cc44c5d5..6acf0c256daf 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1,5 +1,6 @@ import copy import importlib.metadata +import inspect import json import os from dataclasses import dataclass @@ -24,13 +25,26 @@ class Cache(torch.Tensor): Base, abstract class for all caches. The actual data structure is specific to each subclass. """ - def __new__(cls, *args, dtype=None, device=None, **kwargs): + @staticmethod + def __new__(cls, *args, **kwargs): # We use a tensor wrapper to allow for torch script tracing when using the cache as an input to nn.Module # dtype and device don't need to be in the subclass's __init__ (unless they are used for something) - # But they can be passed as arguments when instantiating the cache (e.g. 
`DynamicCache(dtype=dtype)`) - # And will be accessible as `cache.dtype` and `cache.device` - self = torch.Tensor._make_wrapper_subclass(cls, (), dtype=dtype, device=device, requires_grad=False) - self.__init__(*args, **kwargs) + + wrapper_kwargs = {} + init_signature = inspect.signature(cls.__init__) + init_arguments = list(init_signature.parameters.keys()) + + for argument in ["dtype", "device", "requires_grad"]: + if argument in init_arguments: + argument_index = init_arguments.index(argument) + if len(args) > argument_index: + wrapper_kwargs[argument] = args[argument_index] + elif argument in kwargs: + wrapper_kwargs[argument] = kwargs[argument] + + self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs) + cls.__init__(self, *args, **kwargs) + return self @classmethod @@ -42,11 +56,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def __bool__(self): # in many places, past_key_values is checked for not being None using `if past_key_values:` + # I think `if past_key_values is not None:` should be used instead return True def __repr__(self): return f"{self.__class__.__name__}()" + def to(self, *args, **kwargs): + # We override this method to prevent the cache from being moved to a different device + # It can be implemented in a way that moves all contained tensors to the new device/dtype + return self + def update( self, key_states: torch.Tensor, From 6e9799c817fc38dbb6d088a6fba4e4978aa12118 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 15:42:43 +0100 Subject: [PATCH 09/16] add clone and to --- src/transformers/cache_utils.py | 44 ++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6acf0c256daf..eead49b33ba2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -27,8 +27,7 @@ class Cache(torch.Tensor): @staticmethod def __new__(cls, *args, **kwargs): - # We use a tensor wrapper to allow for torch script tracing when using the cache as an input to nn.Module - # dtype and device don't need to be in the subclass's __init__ (unless they are used for something) + # We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method wrapper_kwargs = {} init_signature = inspect.signature(cls.__init__) @@ -54,19 +53,38 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}" return getattr(cls, func.__name__)(*args, **kwargs) + def __repr__(self): + return f"{self.__class__.__name__}()" + def __bool__(self): # in many places, past_key_values is checked for not being None using `if past_key_values:` # I think `if past_key_values is not None:` should be used instead - return True - - def __repr__(self): - return f"{self.__class__.__name__}()" + return self is not None # True def to(self, *args, **kwargs): - # We override this method to prevent the cache from being moved to a different device - # It can be implemented in a way that moves all contained tensors to the new device/dtype + def reccursive_to(elm): + if isinstance(elm, dict): + return {k: reccursive_to(v) for k, v in elm.items()} + elif isinstance(elm, (list, tuple, set)): + return type(elm)(reccursive_to(t) for t in elm) + elif isinstance(elm, torch.Tensor): + return elm.to(*args, **kwargs) + else: + return elm + + self.__dict__ = reccursive_to(self.__dict__) return self + def clone(self): + wrapper_kwargs 
= { + "dtype": getattr(self, "dtype", None), + "device": getattr(self, "device", None), + "requires_grad": getattr(self, "requires_grad", None), + } + new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) + new_self.__dict__ = copy.deepcopy(self.__dict__) + return new_self + def update( self, key_states: torch.Tensor, @@ -267,7 +285,6 @@ def __init__( q_group_size: Optional[int] = 64, residual_length: Optional[int] = 128, compute_dtype: Optional[torch.dtype] = torch.float16, - device: Optional[str] = "cpu", ): self.backend = backend self.nbits = nbits @@ -276,7 +293,6 @@ def __init__( self.q_group_size = q_group_size self.residual_length = residual_length self.compute_dtype = compute_dtype - self.device = device def validate(self): """Validates if the arguments passed are correct""" @@ -339,10 +355,9 @@ class StaticCacheConfig(CacheConfig): cache_implementation = "static" - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + def __init__(self, batch_size: int, max_cache_len: int): self.batch_size = batch_size self.max_cache_len = max_cache_len - self.device = device def validate(self): """Validates if the arguments passed are correct""" @@ -710,7 +725,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: self.axis_key = cache_config.axis_key self.axis_value = cache_config.axis_value self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device + self.to(cache_config.device) def update( self, @@ -1955,7 +1970,8 @@ def __init__( ) -> None: self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) if layer_device_map is None else layer_device_map[0] + if layer_device_map is not None: + self.to(layer_device_map[0]) self.offload_device = torch.device(offload_device) # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads From da60604f2c9303f0edac5018ced27b34e978f95d Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 15:43:14 +0100 Subject: [PATCH 10/16] fix test_cache_utils --- src/transformers/integrations/executorch.py | 3 +-- tests/utils/test_cache_utils.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 41c5f2ff1cf1..50001e2155ab 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -206,9 +206,8 @@ def convert_and_export_with_cache( # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release. 
- torch_exportable_module = TorchExportableModuleWithStaticCache(model) exported_program = torch.export._trace._export( - torch_exportable_module, + TorchExportableModuleWithStaticCache(model), args=(example_input_ids,), kwargs={"cache_position": example_cache_position}, pre_dispatch=False, diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 053d2cf6397a..aa760ff9ed94 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -214,11 +214,11 @@ def test_static_cache_exportability(self): # Check if the exported model is configured with the `StaticCache` correctly n_static_key_caches = n_static_value_caches = 0 for buffer_name, buffer in exported_program.named_buffers(): - if buffer_name.startswith("static_cache.key_cache"): + if buffer_name.startswith("key_cache"): self.assertTrue(buffer.shape[0] == batch_size) self.assertTrue(buffer.shape[2] == max_cache_len) n_static_key_caches = n_static_key_caches + 1 - if buffer_name.startswith("static_cache.value_cache"): + if buffer_name.startswith("value_cache"): self.assertTrue(buffer.shape[0] == batch_size) self.assertTrue(buffer.shape[2] == max_cache_len) n_static_value_caches = n_static_value_caches + 1 @@ -362,7 +362,7 @@ def test_sink_cache_iterative_prompts(self): input_ids = gen_out # We went well beyond the cache length - self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5) + self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5) # And it still produces a coherent english decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True) @@ -380,8 +380,8 @@ def test_sink_cache_iterative_prompts(self): [ ("eager", "static"), ("sdpa", "static"), - ("eager", "offloaded-static"), - ("sdpa", "offloaded-static"), + ("eager", "offloaded_static"), + ("sdpa", "offloaded_static"), ] ) def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_implementation): @@ -427,8 +427,8 @@ def test_static_cache_greedy_decoding_pad_left(self, attn_implementation, cache_ [ ("eager", "static"), ("sdpa", "static"), - ("eager", "offloaded-static"), - ("sdpa", "offloaded-static"), + ("eager", "offloaded_static"), + ("sdpa", "offloaded_static"), ] ) def test_static_cache_greedy_decoding_pad_right(self, attn_implementation, cache_implementation): @@ -519,7 +519,7 @@ def test_dynamic_cache_extra_left_padding(self): @parameterized.expand( [ "static", - "offloaded-static", + "offloaded_static", ] ) def test_static_cache_extra_left_padding(self, cache_implementation): @@ -642,4 +642,4 @@ def test_cache_copy(self): "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week", 'You are a helpful assistant. 
What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the' ] # fmt: skip - self.assertTrue(responses == EXPECTED_DECODED_TEXT) + self.assertEqual(responses, EXPECTED_DECODED_TEXT) From 2bbbbbcf971068c2cc49d5b7883211dfda00d76f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 17:15:12 +0100 Subject: [PATCH 11/16] add device and dtype setters --- src/transformers/cache_utils.py | 66 ++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 22 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index f5701d97523f..dfb7db1cb367 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -32,6 +32,9 @@ def __new__(cls, *args, **kwargs): wrapper_kwargs = {} init_signature = inspect.signature(cls.__init__) init_arguments = list(init_signature.parameters.keys()) + init_defaults = { + k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty + } for argument in ["dtype", "device", "requires_grad"]: if argument in init_arguments: @@ -40,6 +43,8 @@ def __new__(cls, *args, **kwargs): wrapper_kwargs[argument] = args[argument_index] elif argument in kwargs: wrapper_kwargs[argument] = kwargs[argument] + elif argument in init_defaults: + wrapper_kwargs[argument] = init_defaults[argument] self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs) cls.__init__(self, *args, **kwargs) @@ -61,29 +66,42 @@ def __bool__(self): # I think `if past_key_values is not None:` should be used instead return self is not None # True + def __setattr__(self, name, value): + # for the many places where `self.device` or `self.dtype` is set + if name in ["dtype", "device"]: + self.to(value) + else: + return super().__setattr__(name, value) + def to(self, *args, **kwargs): - def reccursive_to(elm): - if isinstance(elm, dict): - return {k: reccursive_to(v) for k, v in elm.items()} - elif isinstance(elm, (list, tuple, set)): - return type(elm)(reccursive_to(t) for t in elm) - elif isinstance(elm, torch.Tensor): - return elm.to(*args, **kwargs) - else: - return elm + # originals + wrapper_kwargs = { + "dtype": getattr(self, "dtype", None), + "device": getattr(self, "device", None), + "requires_grad": getattr(self, "requires_grad", False), + } - self.__dict__ = reccursive_to(self.__dict__) + # overrides + for arg in list(args) + list(kwargs.values()): + if isinstance(arg, torch.device): + wrapper_kwargs["device"] = arg + elif isinstance(arg, torch.dtype): + wrapper_kwargs["dtype"] = arg + + new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) + new_tensor_wrapper.__dict__ = self.__dict__ + self = new_tensor_wrapper return self def clone(self): wrapper_kwargs = { "dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None), - "requires_grad": getattr(self, "requires_grad", None), + "requires_grad": getattr(self, "requires_grad", False), } - new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) - new_self.__dict__ = copy.deepcopy(self.__dict__) - return new_self + new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) + new_tensor_wrapper.__dict__ = copy.deepcopy(self.__dict__) + return new_tensor_wrapper def update( self, @@ -285,6 +303,7 @@ def __init__( q_group_size: Optional[int] = 64, residual_length: Optional[int] = 128, 
compute_dtype: Optional[torch.dtype] = torch.float16, + device: Optional[str] = "cpu", ): self.backend = backend self.nbits = nbits @@ -293,6 +312,7 @@ def __init__( self.q_group_size = q_group_size self.residual_length = residual_length self.compute_dtype = compute_dtype + self.device = device def validate(self): """Validates if the arguments passed are correct""" @@ -355,9 +375,10 @@ class StaticCacheConfig(CacheConfig): cache_implementation = "static" - def __init__(self, batch_size: int, max_cache_len: int): + def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): self.batch_size = batch_size self.max_cache_len = max_cache_len + self.device = device def validate(self): """Validates if the arguments passed are correct""" @@ -715,7 +736,6 @@ class QuantizedCache(DynamicCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: - super().__init__() self._quantized_key_cache: List[torch.Tensor] = [] self._quantized_value_cache: List[torch.Tensor] = [] @@ -725,7 +745,7 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: self.axis_key = cache_config.axis_key self.axis_value = cache_config.axis_value self.compute_dtype = cache_config.compute_dtype - self.to(cache_config.device) + self.device = cache_config.device def update( self, @@ -1158,13 +1178,12 @@ def __init__( max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super().__init__() if batch_size is not None: logger.warning_once( f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " "v4.49. Use the more precisely named 'max_batch_size' argument instead." ) - + self.dtype = dtype self.max_batch_size = batch_size or max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len @@ -1173,6 +1192,8 @@ def __init__( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) + self.dtype = dtype + self.device = torch.device(device) if device is not None else torch.device("meta") self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1657,7 +1678,7 @@ def __init__( config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: Union[torch.device, str] = None, + device: Union[torch.device, str] = torch.device("meta"), dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1681,11 +1702,11 @@ def __init__( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) + self.dtype = dtype self.num_key_value_heads = ( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) - self.device = torch.device(device) if device is not None else torch.device("meta") layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC self.is_sliding = torch.tensor( [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool @@ -1888,7 +1909,6 @@ def __init__( self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel - self.device = torch.device(device) if device is not None else torch.device("meta") self.conv_states: List[torch.Tensor] = [] self.ssm_states: List[torch.Tensor] = [] @@ -2026,7 
+2046,9 @@ def __init__( super(Cache, self).__init__() self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len + self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) self.offload_device = torch.device(offload_device) + self.dtype = dtype if dtype is not None else torch.float32 # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads From 485f959f85bcdf06309ce99503e6a43b1d747877 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 17:17:17 +0100 Subject: [PATCH 12/16] revert --- src/transformers/cache_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index dfb7db1cb367..829df4ffcd01 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1678,7 +1678,7 @@ def __init__( config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: Union[torch.device, str] = torch.device("meta"), + device: Union[torch.device, str] = None, dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1707,6 +1707,7 @@ def __init__( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + self.device = torch.device(device) if device is not None else torch.device("meta") layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC self.is_sliding = torch.tensor( [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool @@ -1904,11 +1905,12 @@ def __init__( f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " "v4.49. Use the more precisely named 'max_batch_size' argument instead." ) - + self.dtype = dtype self.max_batch_size = batch_size or max_batch_size self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel + self.device = torch.device(device) if device is not None else torch.device("meta") self.conv_states: List[torch.Tensor] = [] self.ssm_states: List[torch.Tensor] = [] From 2f4e0bc93e57892b782aa851b14654ce2737e024 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:18:28 +0100 Subject: [PATCH 13/16] Update src/transformers/cache_utils.py --- src/transformers/cache_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 829df4ffcd01..c54f52815662 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1183,7 +1183,6 @@ def __init__( f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " "v4.49. Use the more precisely named 'max_batch_size' argument instead." 
) - self.dtype = dtype self.max_batch_size = batch_size or max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len From 338f5954b9cd85c0a00f915713928dc41f3157b9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 17:29:48 +0100 Subject: [PATCH 14/16] more reverts --- src/transformers/cache_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c54f52815662..b33a6869d4a7 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -736,6 +736,7 @@ class QuantizedCache(DynamicCache): """ def __init__(self, cache_config: QuantizedCacheConfig) -> None: + super().__init__() self._quantized_key_cache: List[torch.Tensor] = [] self._quantized_value_cache: List[torch.Tensor] = [] @@ -747,6 +748,8 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: self.compute_dtype = cache_config.compute_dtype self.device = cache_config.device + super().__init__() + def update( self, key_states: torch.Tensor, @@ -1178,6 +1181,7 @@ def __init__( max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: + super().__init__() if batch_size is not None: logger.warning_once( f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " @@ -1217,6 +1221,8 @@ def __init__( self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) + print("init key_cache.device", self.key_cache[0].device) + def update( self, key_states: torch.Tensor, @@ -1261,6 +1267,9 @@ def update( # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place # operation, that avoids copies and uses less memory. try: + print("kout.device", k_out.device) + print("key_states.device", key_states.device) + print("cache_position.device", cache_position.device) k_out.index_copy_(2, cache_position, key_states) v_out.index_copy_(2, cache_position, value_states) except NotImplementedError: From 80b49d721b3d6f416381a219a52e5e4d785c4f15 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Jan 2025 17:31:39 +0100 Subject: [PATCH 15/16] rebased --- src/transformers/cache_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b33a6869d4a7..88bc17506ed8 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1221,8 +1221,6 @@ def __init__( self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) - print("init key_cache.device", self.key_cache[0].device) - def update( self, key_states: torch.Tensor, @@ -1267,9 +1265,6 @@ def update( # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place # operation, that avoids copies and uses less memory. 
try: - print("kout.device", k_out.device) - print("key_states.device", key_states.device) - print("cache_position.device", cache_position.device) k_out.index_copy_(2, cache_position, key_states) v_out.index_copy_(2, cache_position, value_states) except NotImplementedError: From 5ccb79c16d35970ed8ff4f4a1fdad897d1279d3c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Jan 2025 16:45:28 +0100 Subject: [PATCH 16/16] fixed dynamic cache --- src/transformers/cache_utils.py | 105 +++++++++++++++----------------- 1 file changed, 49 insertions(+), 56 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 88bc17506ed8..9fda187632a2 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -36,19 +36,28 @@ def __new__(cls, *args, **kwargs): k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty } - for argument in ["dtype", "device", "requires_grad"]: + for argument in ["dtype", "device"]: if argument in init_arguments: - argument_index = init_arguments.index(argument) - if len(args) > argument_index: - wrapper_kwargs[argument] = args[argument_index] - elif argument in kwargs: + arg_idx = init_arguments.index(argument) + if len(args) > arg_idx and args[arg_idx] is not None: + wrapper_kwargs[argument] = args[arg_idx] + elif kwargs.get(argument, None) is not None: wrapper_kwargs[argument] = kwargs[argument] - elif argument in init_defaults: + elif init_defaults[argument] is not None: wrapper_kwargs[argument] = init_defaults[argument] - self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs) - cls.__init__(self, *args, **kwargs) - + if "cache_config" in init_arguments: + cache_config_idx = init_arguments.index("cache_config") + if len(args) > cache_config_idx and args[cache_config_idx] is not None: + wrapper_kwargs["device"] = args[cache_config_idx].device + elif kwargs.get("cache_config", None) is not None: + wrapper_kwargs["device"] = kwargs["cache_config"].device + elif init_defaults["cache_config"] is not None: + wrapper_kwargs["device"] = init_defaults["cache_config"].device + + self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False) + # we create a dummy empty tensor for generic tensor flattening/unflattening + self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False) return self @classmethod @@ -66,42 +75,27 @@ def __bool__(self): # I think `if past_key_values is not None:` should be used instead return self is not None # True - def __setattr__(self, name, value): - # for the many places where `self.device` or `self.dtype` is set - if name in ["dtype", "device"]: - self.to(value) - else: - return super().__setattr__(name, value) - def to(self, *args, **kwargs): # originals - wrapper_kwargs = { - "dtype": getattr(self, "dtype", None), - "device": getattr(self, "device", None), - "requires_grad": getattr(self, "requires_grad", False), - } + wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)} # overrides for arg in list(args) + list(kwargs.values()): - if isinstance(arg, torch.device): + if isinstance(arg, (torch.device, str, int)): wrapper_kwargs["device"] = arg elif isinstance(arg, torch.dtype): wrapper_kwargs["dtype"] = arg - new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) - new_tensor_wrapper.__dict__ = self.__dict__ - self = new_tensor_wrapper - return self + # new wrapper + new_self = 
torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) + new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]} + return new_self def clone(self): - wrapper_kwargs = { - "dtype": getattr(self, "dtype", None), - "device": getattr(self, "device", None), - "requires_grad": getattr(self, "requires_grad", False), - } - new_tensor_wrapper = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs) - new_tensor_wrapper.__dict__ = copy.deepcopy(self.__dict__) - return new_tensor_wrapper + wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)} + new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False) + new_self.__dict__ = copy.deepcopy(self.__dict__) + return new_self def update( self, @@ -375,7 +369,7 @@ class StaticCacheConfig(CacheConfig): cache_implementation = "static" - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): + def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")): self.batch_size = batch_size self.max_cache_len = max_cache_len self.device = device @@ -432,6 +426,16 @@ class DynamicCache(Cache): ``` """ + def __tensor_flatten__(self): + return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens} + + @staticmethod + def __tensor_unflatten__(inner_tensors, meta, _, __): + cache = DynamicCache() + cache._seen_tokens = meta["_seen_tokens"] + cache._empty_tensor = inner_tensors["_empty_tensor"] + return cache + @deprecate_kwarg("num_hidden_layers", version="4.47.0") def __init__(self, num_hidden_layers: Optional[int] = None) -> None: super().__init__() @@ -519,7 +523,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it or len(self.key_cache[layer_idx]) == 0 # the layer has no cache ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0) return layer_seq_length def get_max_cache_shape(self) -> Optional[int]: @@ -746,9 +750,6 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None: self.axis_key = cache_config.axis_key self.axis_value = cache_config.axis_value self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device - - super().__init__() def update( self, @@ -848,7 +849,7 @@ def __init__(self, cache_config: CacheConfig) -> None: raise ImportError( f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." 
) - from optimum.quanto import MaxOptimizer, qint2, qint4 + from optimum.quanto import MaxOptimizer, qint2, qint4 # type: ignore if self.nbits not in [2, 4]: raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") @@ -867,7 +868,7 @@ def __init__(self, cache_config: CacheConfig) -> None: def _quantize(self, tensor, axis): # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight + from optimum.quanto import quantize_weight # type: ignore scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) @@ -1176,7 +1177,7 @@ def __init__( config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: torch.device = None, + device: Union[torch.device, str] = torch.device("meta"), dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1195,8 +1196,6 @@ def __init__( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads ) - self.dtype = dtype - self.device = torch.device(device) if device is not None else torch.device("meta") self.num_key_value_heads = ( config.num_attention_heads if getattr(config, "num_key_value_heads", None) is None @@ -1366,7 +1365,7 @@ def __init__( config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: torch.device = None, + device: Union[torch.device, str] = torch.device("meta"), dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1681,7 +1680,7 @@ def __init__( config: PretrainedConfig, batch_size: int = None, max_cache_len: int = None, - device: Union[torch.device, str] = None, + device: Union[torch.device, str] = torch.device("meta"), dtype: torch.dtype = torch.float32, max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, @@ -1710,7 +1709,6 @@ def __init__( config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) - self.device = torch.device(device) if device is not None else torch.device("meta") layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC self.is_sliding = torch.tensor( [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool @@ -1843,7 +1841,7 @@ def batch_size(self): return self.max_batch_size -class MambaCache: +class MambaCache(Cache): """ Cache for mamba model which does not have attention mechanism and key value states. @@ -1900,7 +1898,7 @@ def __init__( config: PretrainedConfig, batch_size: int = None, dtype: torch.dtype = torch.float16, - device: Optional[Union[torch.device, str]] = None, + device: Union[torch.device, str] = torch.device("meta"), max_batch_size: Optional[int] = None, ): if batch_size is not None: @@ -1908,12 +1906,10 @@ def __init__( f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " "v4.49. Use the more precisely named 'max_batch_size' argument instead." 
) - self.dtype = dtype self.max_batch_size = batch_size or max_batch_size self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel - self.device = torch.device(device) if device is not None else torch.device("meta") self.conv_states: List[torch.Tensor] = [] self.ssm_states: List[torch.Tensor] = [] @@ -2043,17 +2039,14 @@ def __init__( config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int], - device: Union[str, torch.device], - dtype: Optional[torch.dtype] = None, + device: Union[torch.device, str] = torch.device("meta"), + dtype: torch.dtype = torch.float32, offload_device: Union[str, torch.device] = torch.device("cpu"), layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: - super(Cache, self).__init__() self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len - self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) self.offload_device = torch.device(offload_device) - self.dtype = dtype if dtype is not None else torch.float32 # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
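
Editor's note: the core pattern running through these patches is re-basing `Cache` on `torch.Tensor` via `torch.Tensor._make_wrapper_subclass`, so a cache instance passes `isinstance(x, torch.Tensor)` checks and can be traced/exported, while its real state lives in ordinary Python attributes rather than tensor storage. The sketch below is a minimal, standalone illustration of that pattern, assuming PyTorch >= 2.0; the class name `MinimalCache` and its attributes are illustrative only and are not part of the patches above.

import torch


class MinimalCache(torch.Tensor):
    """Tensor-wrapper data holder: isinstance(x, torch.Tensor) is True, but no storage is allocated."""

    @staticmethod
    def __new__(cls, dtype=None, device=None):
        # A 0-dim wrapper subclass: it carries dtype/device metadata only, no data buffer.
        return torch.Tensor._make_wrapper_subclass(
            cls, (), dtype=dtype, device=device, requires_grad=False
        )

    def __init__(self, dtype=None, device=None):
        # Real state lives in plain Python attributes, not in tensor storage.
        self.key_cache = []
        self.value_cache = []

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Ordinary tensor ops are not meaningful on the wrapper itself.
        raise NotImplementedError(
            f"{cls.__name__} is a tensor wrapper and does not implement {func.__name__}"
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(layers={len(self.key_cache)})"


cache = MinimalCache(dtype=torch.float16, device="cpu")
print(isinstance(cache, torch.Tensor))  # True
print(cache.dtype, cache.device)        # torch.float16 cpu
print(cache)                            # MinimalCache(layers=0)

Because the wrapper owns no storage, per-layer key/value tensors stay in plain attributes and generic tensor ops on the wrapper are rejected in `__torch_dispatch__`, which mirrors the assertion introduced in PATCH 03; the later patches then layer `__new__` signature inspection, `to`/`clone`, and export-friendly buffer handling on top of this same wrapper.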