From 224d0d38d568b10af39305c98b65a96854618568 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Tue, 28 Jan 2025 11:54:24 -0800
Subject: [PATCH] [Executorch][Llama] Decouple input sequence length from kv cache context length

Pull Request resolved: https://github.com/pytorch/executorch/pull/7927

Decouple the max sequence length, used for shape dynamism in torch.export, from the sequence length used for kv cache sizing.

ghstack-source-id: 263491763

Differential Revision: [D68448334](https://our.internmc.facebook.com/intern/diff/D68448334/)
---
 examples/models/llama/export_llama_lib.py          | 19 +++++++++--
 examples/models/llama/llama_transformer.py         | 32 ++++++++++---------
 examples/models/llama/model.py                     |  8 ++++-
 .../llama/source_transformation/attention.py       | 16 +++++-----
 .../source_transformation/attention_sink.py        | 20 ++++++++----
 .../quantized_kv_cache.py                          | 24 +++++++-------
 .../llama/source_transformation/sdpa.py            | 18 +++++------
 .../test_attention_sink.py                         |  4 +--
 .../test_quantized_kv_cache.py                     |  4 +--
 .../test_sdpa_with_quantized_kv_cache.py           |  6 ++--
 .../models/llama/tests/test_simple_sdpa.py         |  8 ++---
 11 files changed, 95 insertions(+), 64 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index c25dce6ffc..ae40230b25 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -335,6 +335,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length sequence to evaluate",
     )
 
+    parser.add_argument(
+        "--max_context_length",
+        type=int,
+        default=None,
+        help="maximum length of context for model to remember",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(
@@ -579,6 +586,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
         tokenizer_path=args.tokenizer_path,
         verbose=args.verbose,
         max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
         input_prune_map_path=args.input_prune_map,
         output_prune_map_path=args.output_prune_map,
         metadata_str=args.metadata,
@@ -637,6 +645,8 @@ def _validate_args(args):
     """
     TODO: Combine all the backends under --backend args
     """
+    if args.max_context_length is None:
+        args.max_context_length = args.max_seq_length
     if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
         raise ValueError(
             "Dynamic shape is not supported with coreml, MPS or qnn backends."
@@ -760,13 +770,13 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         atten = builder_exported_to_edge.model.layers[0].attention
         if args.use_qnn_sha:
             cache_shape = torch.Size(
-                (atten.max_batch_size, atten.max_seq_len, atten.head_dim)
+                (atten.max_batch_size, atten.max_context_len, atten.head_dim)
             )
         else:
             cache_shape = torch.Size(
                 (
                     atten.max_batch_size,
-                    atten.max_seq_len,
+                    atten.max_context_len,
                     atten.n_kv_heads,
                     atten.head_dim,
                 )
@@ -861,6 +871,7 @@ def _load_llama_model_metadata(
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
     max_seq_len: int,
+    max_context_len: int,
     n_layers: int,
     vocab_size: int,
     metadata_str: Optional[str] = None,
@@ -870,6 +881,7 @@ def _load_llama_model_metadata(
         "get_bos_id": 3 if is_fairseq2 else 1,
         "get_eos_ids": [3] if is_fairseq2 else [2],
         "get_max_seq_len": max_seq_len,
+        "get_max_context_len": max_context_len,
         "get_n_layers": n_layers,
         "get_vocab_size": vocab_size,
         "use_kv_cache": use_kv_cache,
@@ -904,6 +916,7 @@ def _load_llama_model(
     tokenizer_path: Optional[str] = None,
     verbose: bool = False,
     max_seq_len: int = 128,
+    max_context_len: int = 128,
     input_prune_map_path: Optional[str] = None,
     output_prune_map_path: Optional[str] = None,
     metadata_str: Optional[str] = None,
@@ -948,6 +961,7 @@ def _load_llama_model(
             generate_full_logits=generate_full_logits,
             fairseq2=weight_type == WeightType.FAIRSEQ2,
             max_seq_len=max_seq_len,
+            max_context_len=max_context_len,
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
@@ -1006,6 +1020,7 @@ def _load_llama_model(
             # pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
             #  `Union[Tensor, Module]`.
             model.max_seq_len,
+            model.max_context_len,
             # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
             #  Module]`.
             model.n_layers,
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index d5661ae400..cc6b81edc1 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -91,6 +91,7 @@ class ModelArgs:
     norm_eps: float = 1e-5
     max_batch_size: int = 32
     max_seq_len: int = 2048
+    max_context_len: int = 2048
     moe: bool = False  # True to enable the MoE (Mixture of Experts)
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
@@ -163,9 +164,9 @@ def __init__(self, params: ModelArgs):
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (
-                self.params.max_seq_len  # Normal llama2.
+                self.params.max_context_len  # Normal llama2.
                 if self.params.ffn_dim_multiplier is None
-                else self.params.max_seq_len * 2  # Sharded checkpoint.
+                else self.params.max_context_len * 2  # Sharded checkpoint.
             ),
             self.params.rope_freq_base,
         )
@@ -205,7 +206,7 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
             # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
             input_pos_item = input_pos[-1].item()
             torch._check_is_size(input_pos_item)
-            torch._check(input_pos_item < self.params.max_seq_len)
+            torch._check(input_pos_item < self.params.max_context_len)
             # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
             freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
             # pyre-ignore: Incompatible parameter type [6]
@@ -229,15 +230,15 @@ class KVCache(nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         enable_dynamic_shape: bool,
         dtype=torch.float32,
     ):
         super().__init__()
-        self.max_seq_length = max_seq_length
-        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.max_context_length = max_context_length
+        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
@@ -257,7 +258,7 @@ def update(
         if self.enable_dynamic_shape:
             start_pos = input_pos[0].item()
             torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_length)
+            torch._check(start_pos < self.max_context_length)
             dim_to_slice = 2
             seq_length = k_val.size(dim_to_slice)
             # Replace the entry in the cache for this token
@@ -289,14 +290,14 @@ def __init__(
         dim: int,
         head_dim: int,
         n_rep: int,
-        max_seq_len: int,
+        max_context_len: int,
         enable_dynamic_shape: bool,
     ):
         super().__init__()
         self.dim = dim
         self.head_dim = head_dim
         self.n_rep = n_rep
-        self.max_seq_len = max_seq_len
+        self.max_context_len = max_context_len
         self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
@@ -312,7 +313,7 @@ def forward(
         if self.enable_dynamic_shape:
             start_pos = input_pos[-1].item()
             torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_seq_len)
+            torch._check(start_pos < self.max_context_len)
             seq_length = q.size(2)
             # pyre-ignore: Incompatible parameter type [6]
             attn_mask = mask.narrow(0, start_pos, seq_length)
@@ -341,7 +342,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.head_dim
         self.max_batch_size = args.max_batch_size
-        self.max_seq_len = args.max_seq_len
+        self.max_context_len = args.max_context_len
         self.dim = args.dim
         self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
@@ -354,8 +355,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
 
         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_context_len,
+                self.max_context_len,
                 dtype=torch.bool,
                 device="cpu",
             )
@@ -365,7 +366,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         if self.use_kv_cache:
             self.kv_cache = KVCache(
                 args.max_batch_size,
-                args.max_seq_len,
+                args.max_context_len,
                 self.n_kv_heads,
                 self.head_dim,
                 args.enable_dynamic_shape,
@@ -374,7 +375,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 dim=self.n_local_heads * self.head_dim,
                 head_dim=self.head_dim,
                 n_rep=self.n_rep,
-                max_seq_len=self.max_seq_len,
+                max_context_len=self.max_context_len,
                 enable_dynamic_shape=args.enable_dynamic_shape,
             )
 
@@ -528,6 +529,7 @@ def __init__(self, params: ModelArgs):
         self.use_kv_cache = params.use_kv_cache
         self.generate_full_logits = params.generate_full_logits
         self.max_seq_len = params.max_seq_len
+        self.max_context_len = params.max_context_len
         self.input_prune_map = params.input_prune_map
         self.output_prune_map = params.output_prune_map
 
diff --git a/examples/models/llama/model.py b/examples/models/llama/model.py
index 9f7994916a..00f59df286 100644
--- a/examples/models/llama/model.py
+++ b/examples/models/llama/model.py
@@ -52,8 +52,13 @@ def __init__(self, **kwargs):
         self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
+        self.max_context_len = kwargs.get("max_context_len", 128)
         self.args = kwargs.get("args", None)
 
+        assert (
+            self.max_context_len >= self.max_seq_len
+        ), f"max_context_len({self.max_context_len}) must be >= max_seq_len({self.max_seq_len})"
+
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
@@ -136,6 +141,7 @@ def __init__(self, **kwargs):
 
         model_args: ModelArgs = ModelArgs(
             max_seq_len=self.max_seq_len,
+            max_context_len=self.max_context_len,
             max_batch_size=1,
             use_kv_cache=self.use_kv_cache,
             use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
@@ -219,7 +225,7 @@ def __init__(self, **kwargs):
             window_size = int(attention_sink_params[1])
             eviction_batch_size = int(attention_sink_params[2])
 
-            assert self.args.max_seq_length == sink_size + window_size
+            assert self.args.max_context_length == sink_size + window_size
 
             self.model_ = enable_attention_sink(
                 module=self.model_,
diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py
index 7dc9003f13..f1d40b7042 100644
--- a/examples/models/llama/source_transformation/attention.py
+++ b/examples/models/llama/source_transformation/attention.py
@@ -32,7 +32,7 @@ class KVCacheSHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
@@ -40,7 +40,7 @@ def __init__(
         super().__init__()
 
         # a buffer per head
-        cache_shape = (max_batch_size, max_seq_length, head_dim)
+        cache_shape = (max_batch_size, max_context_length, head_dim)
         for i in range(n_heads):
             self.register_buffer(
                 f"past_k_caches_{i}",
@@ -79,7 +79,7 @@ class SDPASHA(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         n_rep: int,
         head_dim: int,
@@ -90,7 +90,7 @@ def __init__(
         self.n_rep = n_rep
         self.dim = dim
         self.kv_cache = KVCacheSHA(
-            max_batch_size, max_seq_length, n_heads // n_rep, head_dim
+            max_batch_size, max_context_length, n_heads // n_rep, head_dim
         )
         self.scale_factor = math.sqrt(head_dim)
 
@@ -134,11 +134,11 @@ def __init__(self, attention_mha: nn.Module):
         self.n_rep = self.n_heads // self.n_kv_heads
         self.dim = attention_mha.dim
         self.max_batch_size = attention_mha.max_batch_size
-        self.max_seq_len = attention_mha.max_seq_len
+        self.max_context_len = attention_mha.max_context_len
         self.head_dim = attention_mha.dim // self.n_heads
         self.SDPA = SDPASHA(
             self.max_batch_size,
-            self.max_seq_len,
+            self.max_context_len,
             self.n_heads,
             self.n_rep,
             self.head_dim,
@@ -184,8 +184,8 @@ def __init__(self, attention_mha: nn.Module):
 
         causal_mask = torch.tril(
             torch.ones(
-                self.max_seq_len,
-                self.max_seq_len,
+                self.max_context_len,
+                self.max_context_len,
                 dtype=torch.bool,
                 device="cpu",
             )
diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 5b3bfba9ad..d710773d00 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -44,8 +44,8 @@ def __init__(
             self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
         else:
             self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
-        self.max_seq_length = window_size + sink_size
-        assert self.max_seq_length == self.params.max_seq_len
+        self.max_context_length = window_size + sink_size
+        assert self.max_context_length == self.params.max_context_len
         self.eviction_batch_size = eviction_batch_size
         self.position_shift = 0
 
@@ -54,11 +54,14 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
 
         input_pos_item = input_pos.item()
         torch._check_is_size(input_pos_item)
-        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+        if input_pos_item + self.position_shift + seq_len > self.max_context_length:
             # There are not enough spaces in the cache to store the new tokens.
             # We need to evict some old tokens and shift some recent tokens.
             num_to_evict = max(
-                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                input_pos_item
+                + self.position_shift
+                - self.max_context_length
+                + seq_len,
                 self.eviction_batch_size,
             )
             self.position_shift -= num_to_evict  # pyre-ignore [8]
@@ -121,7 +124,7 @@ def __init__(
     ):
         super().__init__(
             max_batch_size=max_batch_size,
-            max_seq_length=window_size + sink_size,
+            max_context_length=window_size + sink_size,
             n_heads=n_heads,
             head_dim=head_dim,
             enable_dynamic_shape=enable_dynamic_shape,
@@ -148,11 +151,14 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
         """
         input_pos_item = input_pos.item()
         torch._check_is_size(input_pos_item)
-        if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
+        if input_pos_item + self.position_shift + seq_len > self.max_context_length:
             # There are not enough spaces in the cache to store the new tokens.
             # We need to evict some old tokens and shift some recent tokens.
             num_to_evict = max(
-                input_pos_item + self.position_shift - self.max_seq_length + seq_len,
+                input_pos_item
+                + self.position_shift
+                - self.max_context_length
+                + seq_len,
                 self.eviction_batch_size,
             )
             num_to_keep = (
diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 90ec9879e5..650546b6db 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -33,7 +33,7 @@ class QuantizedKVCache(nn.Module):
     def __init__(
         self,
         max_batch_size,
-        max_seq_length,
+        max_context_length,
         n_heads,
         head_dim,
         cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric,
@@ -52,8 +52,8 @@ def __init__(
         self.use_custom_update_cache_op = use_custom_update_cache_op
         self.quantized_cache_dtype = torch.int8
         self.cache_fp_type = torch.float32
-        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
-        scale_shape = (max_batch_size, max_seq_length, n_heads, 1)
+        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
+        scale_shape = (max_batch_size, max_context_length, n_heads, 1)
         self.register_buffer(
             "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
         )
@@ -161,13 +161,15 @@ def from_float(
         cache_type: QuantizedCacheType,
         use_custom_update_cache_op: bool = False,
     ):
-        max_batch_size, n_heads, max_seq_length, head_dim = kv_cache.k_cache.shape
+        max_batch_size, n_heads, max_context_length, head_dim = kv_cache.k_cache.shape
         if isinstance(kv_cache, CustomKVCache):
             # If replacing custom kv cache, then the shape is [B, S, H, D]
-            max_batch_size, max_seq_length, n_heads, head_dim = kv_cache.k_cache.shape
+            max_batch_size, max_context_length, n_heads, head_dim = (
+                kv_cache.k_cache.shape
+            )
         return cls(
             max_batch_size,
-            max_seq_length,
+            max_context_length,
             n_heads,
             head_dim,
             cache_type,
@@ -226,14 +228,14 @@ class CustomKVCache(nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
     ):
         super().__init__()
-        self.max_seq_length = max_seq_length
-        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        self.max_context_length = max_context_length
+        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
@@ -275,13 +277,13 @@ def replace_kv_cache_with_custom_kv_cache(module):
         if isinstance(child, KVCache):
             cache_shape = child.k_cache.shape
             cache_dtype = child.k_cache.dtype
-            max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
+            max_batch_size, n_heads, max_context_length, head_dim = cache_shape
             setattr(
                 module,
                 name,
                 CustomKVCache(
                     max_batch_size,
-                    max_seq_length,
+                    max_context_length,
                     n_heads,
                     head_dim,
                     dtype=cache_dtype,
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 6a54d6a119..f3c297dd40 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -268,14 +268,14 @@ class KVCacheCoreML(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
     ):
         super().__init__()
-        self.max_seq_length = max_seq_length
-        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.max_context_length = max_context_length
+        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
 
         self.max_batch_size = max_batch_size
         self.n_heads = n_heads
@@ -303,7 +303,7 @@ def replace_kv_cache_with_coreml_kv_cache(module: torch.nn.Module):
             name,
             KVCacheCoreML(
                 child.max_batch_size,
-                child.max_seq_length,
+                child.max_context_length,
                 child.n_heads,
                 child.head_dim,
                 child.k_cache.dtype,
@@ -318,13 +318,13 @@ class KVCacheSimple(torch.nn.Module):
     def __init__(
         self,
         max_batch_size: int,
-        max_seq_length: int,
+        max_context_length: int,
         n_heads: int,
         head_dim: int,
         dtype=torch.float32,
     ):
         super().__init__()
-        cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim)
+        cache_shape = (max_batch_size, max_context_length, n_heads, head_dim)
         self.register_buffer(
             "past_k_caches",
             torch.zeros(cache_shape, dtype=dtype, device="cpu"),
@@ -358,7 +358,7 @@ def replace_kv_cache_with_simple_kv_cache(module: torch.nn.Module):
             name,
             KVCacheSimple(
                 child.max_batch_size,
-                child.max_seq_length,
+                child.max_context_length,
                 child.n_heads,
                 child.head_dim,
                 child.k_cache.dtype,
@@ -373,9 +373,9 @@ def replace_causal_mask(module: torch.nn.Module):
     for buffer_fqn_name, buffer in module.named_buffers():
         buffer_name = buffer_fqn_name.split(".")[-1]
         if buffer_name == "mask":
-            max_seq_len = buffer.shape[-1]
+            max_context_len = buffer.shape[-1]
             mask = torch.full(
-                (max_seq_len, max_seq_len),
+                (max_context_len, max_context_len),
                 float("-inf"),
                 device="cpu",
             )
diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py
index 4dd522dff2..5ecf3d162e 100644
--- a/examples/models/llama/source_transformation/test_attention_sink.py
+++ b/examples/models/llama/source_transformation/test_attention_sink.py
@@ -29,7 +29,7 @@ def _init_rope(self, params: ModelArgs, eviction_batch_size: int):
     def setUp(self):
         torch.manual_seed(42)
         self.params = ModelArgs(
-            use_kv_cache=True, enable_dynamic_shape=True, max_seq_len=256
+            use_kv_cache=True, enable_dynamic_shape=True, max_context_len=256
         )
         self.rope_with_attention_sink = self._init_rope(
             params=self.params, eviction_batch_size=1
@@ -135,7 +135,7 @@ def _init_cache(self, sink_size, eviction_batch_size):
         self.params = ModelArgs(
             use_kv_cache=True,
             enable_dynamic_shape=True,
-            max_seq_len=self.window_size + sink_size,
+            max_context_len=self.window_size + sink_size,
         )
         self.rope_with_attention_sink = RopeWithAttentionSink(
             params=self.params,
diff --git a/examples/models/llama/source_transformation/test_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_quantized_kv_cache.py
index 67ebbc7b3f..fac62e7366 100644
--- a/examples/models/llama/source_transformation/test_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_quantized_kv_cache.py
@@ -20,7 +20,7 @@ class QuantizedKVCacheTest(unittest.TestCase):
     def _init_cache(self):
         self.kv_cache = KVCache(
             self.max_batch_size,
-            self.max_seq_len,
+            self.max_context_len,
             self.n_kv_heads,
             self.head_dim,
             self.enable_dynamic_shape,
@@ -36,7 +36,7 @@ def _init_kv(self):
     def setUp(self):
         torch.manual_seed(42)
         self.max_batch_size = 1
-        self.max_seq_len = 5
+        self.max_context_len = 5
         self.n_kv_heads = 8
         self.head_dim = 17
         self.enable_dynamic_shape = False
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index 0081c5072c..6a1cdac32e 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -23,7 +23,7 @@ class SDPAWithQuantizedKVCacheTest(unittest.TestCase):
     def _init_cache(self):
         self.kv_cache = KVCache(
             self.max_batch_size,
-            self.max_seq_len,
+            self.max_context_len,
             self.n_kv_heads,
             self.head_dim,
             self.enable_dynamic_shape,
@@ -40,7 +40,7 @@ def _init_cache(self):
         # as a sequence of token positions
         self.custom_kv_cache = CustomKVCache(
             self.max_batch_size,
-            self.max_seq_len,
+            self.max_context_len,
             self.n_kv_heads,
             self.head_dim,
             dtype=self.dtype,
@@ -57,7 +57,7 @@ def _init_kv(self):
     def setUp(self):
         torch.manual_seed(42)
        self.max_batch_size = 1
-        self.max_seq_len = 5
+        self.max_context_len = 5
         self.n_kv_heads = 4
         self.n_heads = 8
         self.head_dim = 17
diff --git a/examples/models/llama/tests/test_simple_sdpa.py b/examples/models/llama/tests/test_simple_sdpa.py
index 4088165c71..3ad9f634cc 100644
--- a/examples/models/llama/tests/test_simple_sdpa.py
+++ b/examples/models/llama/tests/test_simple_sdpa.py
@@ -15,7 +15,7 @@ class SDPATest(unittest.TestCase):
     def test_simple_sdpa(self):
         # Verify the correctness between the simple SDPA and the original SDPA module defined in llama_transformer.py
         max_batch_size = 1
-        max_seq_length = 128
+        max_context_length = 128
         n_heads = 8
         head_dim = 8
         dim = 64
@@ -25,7 +25,7 @@ def test_simple_sdpa(self):
         n_local_heads = n_heads
         kv_cache = KVCache(
             max_batch_size=max_batch_size,
-            max_seq_length=max_seq_length,
+            max_context_length=max_context_length,
             n_heads=n_heads,
             head_dim=head_dim,
             enable_dynamic_shape=False,
@@ -34,14 +34,14 @@ def test_simple_sdpa(self):
             dim=dim,
             head_dim=head_dim,
             n_rep=n_rep,
-            max_seq_len=max_seq_length,
+            max_context_len=max_context_length,
             enable_dynamic_shape=False,
         )
         input_pos = torch.tensor([0])
         query = torch.randn(1, 1, n_local_heads, head_dim)
         key = torch.randn(1, 1, n_local_heads, head_dim)
         value = torch.randn(1, 1, n_local_heads, head_dim)
-        mask = torch.randn(max_seq_length, max_seq_length)
+        mask = torch.randn(max_context_length, max_context_length)
         query = query.transpose(1, 2)
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)
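Illustration (not part of the diff): the patch separates two limits that were previously the same value. max_seq_len now only bounds how many new tokens a single forward call may take (the dimension torch.export treats as dynamic), while the new max_context_len sizes the KV caches, the causal mask, and the rope frequency tables; _validate_args defaults --max_context_length to --max_seq_length, and model.py asserts max_context_len >= max_seq_len. The minimal eager-mode sketch below shows why the two can differ. ToyKVCache and its dimensions are hypothetical stand-ins, not the KVCache class touched by the diff.

import torch


class ToyKVCache(torch.nn.Module):
    """Illustrative cache whose capacity follows max_context_len."""

    def __init__(self, max_context_len: int, n_heads: int, head_dim: int):
        super().__init__()
        # Capacity is max_context_len, mirroring how the patch sizes k_cache/v_cache.
        self.register_buffer(
            "k_cache", torch.zeros(1, n_heads, max_context_len, head_dim)
        )

    def update(self, input_pos: torch.Tensor, k_val: torch.Tensor) -> torch.Tensor:
        start = int(input_pos[0])
        seq_len = k_val.size(2)
        # Write the new entries into [start, start + seq_len), as KVCache.update does.
        self.k_cache.narrow(2, start, seq_len).copy_(k_val)
        return self.k_cache


max_seq_len = 4       # per-call bound; this is what the export-time dynamic shape limits
max_context_len = 16  # total positions the cache remembers; must be >= max_seq_len
cache = ToyKVCache(max_context_len, n_heads=2, head_dim=8)

pos = 0
while pos < max_context_len:
    chunk = min(max_seq_len, max_context_len - pos)  # no single call exceeds max_seq_len
    cache.update(torch.tensor([pos]), torch.randn(1, 2, chunk, 8))
    pos += chunk

# The cache now holds max_context_len positions even though every call stayed within
# max_seq_len, which is the decoupling the patch introduces.
print(cache.k_cache.shape)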