[Executorch][Llama] Decouple input sequence length from kv cache context length

Pull Request resolved: #7927

Decouple the maximum sequence length used for shape dynamism in torch.export from the sequence length used to size the KV cache.
ghstack-source-id: 263491763

Differential Revision: [D68448334](https://our.internmc.facebook.com/intern/diff/D68448334/)
kimishpatel committed Jan 28, 2025
1 parent bdd3d9c commit 224d0d3
Showing 11 changed files with 95 additions and 64 deletions.
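
To make the renaming concrete, here is a minimal sketch (assumed shapes and names, not the ExecuTorch implementation) of the distinction this commit introduces: max_context_len sizes the KV cache buffers, while max_seq_len only bounds how many tokens may be fed to the exported model in one call. Before this change a single max_seq_len served both roles, so a long cache forced a correspondingly large export-time sequence bound.

```python
import torch


class KVCacheSketch(torch.nn.Module):
    """Hypothetical cache: capacity follows max_context_length, independent of
    how many tokens arrive per forward call (bounded elsewhere by max_seq_len)."""

    def __init__(self, max_batch_size, max_context_length, n_heads, head_dim):
        super().__init__()
        self.max_context_length = max_context_length
        cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape))
        self.register_buffer("v_cache", torch.zeros(cache_shape))

    def update(self, input_pos, k_val, v_val):
        start_pos = int(input_pos[0].item())
        seq_len = k_val.size(2)
        assert start_pos + seq_len <= self.max_context_length
        self.k_cache[:, :, start_pos : start_pos + seq_len] = k_val
        self.v_cache[:, :, start_pos : start_pos + seq_len] = v_val
        return self.k_cache, self.v_cache


max_seq_len = 128       # bound on the dynamic token dimension in torch.export
max_context_len = 2048  # bound on how many positions the cache remembers
cache = KVCacheSketch(max_batch_size=1, max_context_length=max_context_len,
                      n_heads=8, head_dim=64)
chunk = torch.randn(1, 8, 16, 64)               # a 16-token chunk, <= max_seq_len
cache.update(torch.tensor([64]), chunk, chunk)  # lands at position 64 of 2048
```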
19 changes: 17 additions & 2 deletions examples/models/llama/export_llama_lib.py
@@ -335,6 +335,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="maximum length sequence to evaluate",
)

parser.add_argument(
"--max_context_length",
type=int,
default=None,
help="maximum length of context for model to remember",
)

parser.add_argument("-2", "--fairseq2", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
@@ -579,6 +586,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
tokenizer_path=args.tokenizer_path,
verbose=args.verbose,
max_seq_len=args.max_seq_length,
max_context_len=args.max_context_length,
input_prune_map_path=args.input_prune_map,
output_prune_map_path=args.output_prune_map,
metadata_str=args.metadata,
@@ -637,6 +645,8 @@ def _validate_args(args):
"""
TODO: Combine all the backends under --backend args
"""
if args.max_context_length is None:
args.max_context_length = args.max_seq_length
if args.enable_dynamic_shape and (args.coreml or args.mps or args.qnn):
raise ValueError(
"Dynamic shape is not supported with coreml, MPS or qnn backends."
@@ -760,13 +770,13 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
atten = builder_exported_to_edge.model.layers[0].attention
if args.use_qnn_sha:
cache_shape = torch.Size(
(atten.max_batch_size, atten.max_seq_len, atten.head_dim)
(atten.max_batch_size, atten.max_context_len, atten.head_dim)
)
else:
cache_shape = torch.Size(
(
atten.max_batch_size,
atten.max_seq_len,
atten.max_context_len,
atten.n_kv_heads,
atten.head_dim,
)
@@ -861,6 +871,7 @@ def _load_llama_model_metadata(
use_sdpa_with_kv_cache: bool,
enable_dynamic_shape: bool,
max_seq_len: int,
max_context_len: int,
n_layers: int,
vocab_size: int,
metadata_str: Optional[str] = None,
@@ -870,6 +881,7 @@
"get_bos_id": 3 if is_fairseq2 else 1,
"get_eos_ids": [3] if is_fairseq2 else [2],
"get_max_seq_len": max_seq_len,
"get_max_context_len": max_context_len,
"get_n_layers": n_layers,
"get_vocab_size": vocab_size,
"use_kv_cache": use_kv_cache,
@@ -904,6 +916,7 @@ def _load_llama_model(
tokenizer_path: Optional[str] = None,
verbose: bool = False,
max_seq_len: int = 128,
max_context_len: int = 128,
input_prune_map_path: Optional[str] = None,
output_prune_map_path: Optional[str] = None,
metadata_str: Optional[str] = None,
@@ -948,6 +961,7 @@ def _load_llama_model(
generate_full_logits=generate_full_logits,
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
max_context_len=max_context_len,
enable_dynamic_shape=enable_dynamic_shape,
input_prune_map_path=input_prune_map_path,
output_prune_map_path=output_prune_map_path,
@@ -1006,6 +1020,7 @@ def _load_llama_model(
# pyre-fixme[6]: For 5th argument expected `ModelArgs` but got
# `Union[Tensor, Module]`.
model.max_seq_len,
model.max_context_len,
# pyre-fixme[6]: For 6th argument expected `int` but got `Union[Tensor,
# Module]`.
model.n_layers,
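As the hunks above show, the new --max_context_length flag is optional. A rough sketch of the resulting defaulting behavior (hypothetical helper, not code from the repository) is:

```python
from typing import Optional


def resolve_lengths(max_seq_length: int, max_context_length: Optional[int] = None):
    # Mirrors _validate_args plus the assert added in examples/models/llama/model.py:
    # omitting --max_context_length falls back to the old single-length behavior.
    if max_context_length is None:
        max_context_length = max_seq_length
    assert max_context_length >= max_seq_length, (
        f"max_context_len({max_context_length}) must be >= max_seq_len({max_seq_length})"
    )
    return max_seq_length, max_context_length


print(resolve_lengths(128))        # (128, 128): unchanged default behavior
print(resolve_lengths(128, 2048))  # (128, 2048): cache remembers more than one chunk
```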
32 changes: 17 additions & 15 deletions examples/models/llama/llama_transformer.py
@@ -91,6 +91,7 @@ class ModelArgs:
norm_eps: float = 1e-5
max_batch_size: int = 32
max_seq_len: int = 2048
max_context_len: int = 2048
moe: bool = False # True to enable the MoE (Mixture of Experts)
num_experts: int = 8 # Number of experts
num_activated_experts: int = 2 # Number of experts to activate
@@ -163,9 +164,9 @@ def __init__(self, params: ModelArgs):
freqs_cos, freqs_sin = self.precompute_freqs_cis(
self.params.head_dim,
(
self.params.max_seq_len # Normal llama2.
self.params.max_context_len # Normal llama2.
if self.params.ffn_dim_multiplier is None
else self.params.max_seq_len * 2 # Sharded checkpoint.
else self.params.max_context_len * 2 # Sharded checkpoint.
),
self.params.rope_freq_base,
)
@@ -205,7 +206,7 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
# when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos.
input_pos_item = input_pos[-1].item()
torch._check_is_size(input_pos_item)
torch._check(input_pos_item < self.params.max_seq_len)
torch._check(input_pos_item < self.params.max_context_len)
# pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor
freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len)
# pyre-ignore: Incompatible parameter type [6]
@@ -229,15 +230,15 @@ class KVCache(nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
max_context_length: int,
n_heads: int,
head_dim: int,
enable_dynamic_shape: bool,
dtype=torch.float32,
):
super().__init__()
self.max_seq_length = max_seq_length
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
self.max_context_length = max_context_length
cache_shape = (max_batch_size, n_heads, max_context_length, head_dim)

self.max_batch_size = max_batch_size
self.n_heads = n_heads
@@ -257,7 +258,7 @@ def update(
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_length)
torch._check(start_pos < self.max_context_length)
dim_to_slice = 2
seq_length = k_val.size(dim_to_slice)
# Replace the entry in the cache for this token
@@ -289,14 +290,14 @@ def __init__(
dim: int,
head_dim: int,
n_rep: int,
max_seq_len: int,
max_context_len: int,
enable_dynamic_shape: bool,
):
super().__init__()
self.dim = dim
self.head_dim = head_dim
self.n_rep = n_rep
self.max_seq_len = max_seq_len
self.max_context_len = max_context_len
self.enable_dynamic_shape = enable_dynamic_shape

def forward(
@@ -312,7 +313,7 @@ def forward(
if self.enable_dynamic_shape:
start_pos = input_pos[-1].item()
torch._check_is_size(start_pos)
torch._check(start_pos < self.max_seq_len)
torch._check(start_pos < self.max_context_len)
seq_length = q.size(2)
# pyre-ignore: Incompatible parameter type [6]
attn_mask = mask.narrow(0, start_pos, seq_length)
@@ -341,7 +342,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.head_dim
self.max_batch_size = args.max_batch_size
self.max_seq_len = args.max_seq_len
self.max_context_len = args.max_context_len
self.dim = args.dim
self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
@@ -354,8 +355,8 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
self.max_context_len,
self.max_context_len,
dtype=torch.bool,
device="cpu",
)
@@ -365,7 +366,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
if self.use_kv_cache:
self.kv_cache = KVCache(
args.max_batch_size,
args.max_seq_len,
args.max_context_len,
self.n_kv_heads,
self.head_dim,
args.enable_dynamic_shape,
@@ -374,7 +375,7 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
dim=self.n_local_heads * self.head_dim,
head_dim=self.head_dim,
n_rep=self.n_rep,
max_seq_len=self.max_seq_len,
max_context_len=self.max_context_len,
enable_dynamic_shape=args.enable_dynamic_shape,
)

@@ -528,6 +529,7 @@ def __init__(self, params: ModelArgs):
self.use_kv_cache = params.use_kv_cache
self.generate_full_logits = params.generate_full_logits
self.max_seq_len = params.max_seq_len
self.max_context_len = params.max_context_len
self.input_prune_map = params.input_prune_map
self.output_prune_map = params.output_prune_map

8 changes: 7 additions & 1 deletion examples/models/llama/model.py
@@ -52,8 +52,13 @@ def __init__(self, **kwargs):
self.input_prune_map_path = kwargs.get("input_prune_map_path", None)
self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
self.max_seq_len = kwargs.get("max_seq_len", 128)
self.max_context_len = kwargs.get("max_context_len", 128)
self.args = kwargs.get("args", None)

assert (
self.max_context_len >= self.max_seq_len
), f"max_context_len({self.max_context_len}) must be >= max_seq_len({self.max_seq_len})"

# The example is using a dummy small model with random weights for demo purpose only.
# Follow the instruction in https://github.com/facebookresearch/llama to download the model.
device = "cpu"
@@ -136,6 +141,7 @@ def __init__(self, **kwargs):

model_args: ModelArgs = ModelArgs(
max_seq_len=self.max_seq_len,
max_context_len=self.max_context_len,
max_batch_size=1,
use_kv_cache=self.use_kv_cache,
use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
@@ -219,7 +225,7 @@ def __init__(self, **kwargs):
window_size = int(attention_sink_params[1])
eviction_batch_size = int(attention_sink_params[2])

assert self.args.max_seq_length == sink_size + window_size
assert self.args.max_context_length == sink_size + window_size

self.model_ = enable_attention_sink(
module=self.model_,
16 changes: 8 additions & 8 deletions examples/models/llama/source_transformation/attention.py
@@ -32,15 +32,15 @@ class KVCacheSHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
max_context_length: int,
n_heads: int,
head_dim: int,
dtype=torch.float32,
):
super().__init__()

# a buffer per head
cache_shape = (max_batch_size, max_seq_length, head_dim)
cache_shape = (max_batch_size, max_context_length, head_dim)
for i in range(n_heads):
self.register_buffer(
f"past_k_caches_{i}",
@@ -79,7 +79,7 @@ class SDPASHA(torch.nn.Module):
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
max_context_length: int,
n_heads: int,
n_rep: int,
head_dim: int,
@@ -90,7 +90,7 @@ def __init__(
self.n_rep = n_rep
self.dim = dim
self.kv_cache = KVCacheSHA(
max_batch_size, max_seq_length, n_heads // n_rep, head_dim
max_batch_size, max_context_length, n_heads // n_rep, head_dim
)
self.scale_factor = math.sqrt(head_dim)

@@ -134,11 +134,11 @@ def __init__(self, attention_mha: nn.Module):
self.n_rep = self.n_heads // self.n_kv_heads
self.dim = attention_mha.dim
self.max_batch_size = attention_mha.max_batch_size
self.max_seq_len = attention_mha.max_seq_len
self.max_context_len = attention_mha.max_context_len
self.head_dim = attention_mha.dim // self.n_heads
self.SDPA = SDPASHA(
self.max_batch_size,
self.max_seq_len,
self.max_context_len,
self.n_heads,
self.n_rep,
self.head_dim,
@@ -184,8 +184,8 @@ def __init__(self, attention_mha: nn.Module):

causal_mask = torch.tril(
torch.ones(
self.max_seq_len,
self.max_seq_len,
self.max_context_len,
self.max_context_len,
dtype=torch.bool,
device="cpu",
)
20 changes: 13 additions & 7 deletions examples/models/llama/source_transformation/attention_sink.py
@@ -44,8 +44,8 @@ def __init__(
self.apply_rotary_emb_to_k = hf_apply_rotary_emb_to_k
else:
self.apply_rotary_emb_to_k = apply_rotary_emb_to_k
self.max_seq_length = window_size + sink_size
assert self.max_seq_length == self.params.max_seq_len
self.max_context_length = window_size + sink_size
assert self.max_context_length == self.params.max_context_len
self.eviction_batch_size = eviction_batch_size
self.position_shift = 0

@@ -54,11 +54,14 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):

input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
if input_pos_item + self.position_shift + seq_len > self.max_context_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
input_pos_item + self.position_shift - self.max_seq_length + seq_len,
input_pos_item
+ self.position_shift
- self.max_context_length
+ seq_len,
self.eviction_batch_size,
)
self.position_shift -= num_to_evict # pyre-ignore [8]
@@ -121,7 +124,7 @@ def __init__(
):
super().__init__(
max_batch_size=max_batch_size,
max_seq_length=window_size + sink_size,
max_context_length=window_size + sink_size,
n_heads=n_heads,
head_dim=head_dim,
enable_dynamic_shape=enable_dynamic_shape,
@@ -148,11 +151,14 @@ def evict_tokens(self, input_pos: torch.Tensor, seq_len: int) -> int:
"""
input_pos_item = input_pos.item()
torch._check_is_size(input_pos_item)
if input_pos_item + self.position_shift + seq_len > self.max_seq_length:
if input_pos_item + self.position_shift + seq_len > self.max_context_length:
# There are not enough spaces in the cache to store the new tokens.
# We need to evict some old tokens and shift some recent tokens.
num_to_evict = max(
input_pos_item + self.position_shift - self.max_seq_length + seq_len,
input_pos_item
+ self.position_shift
- self.max_context_length
+ seq_len,
self.eviction_batch_size,
)
num_to_keep = (
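
The eviction expression above is unchanged apart from the rename and line wrapping. As a hypothetical worked example (assumed numbers, not taken from the diff) with sink_size + window_size = max_context_length = 2048 and eviction_batch_size = 512:

```python
# Assumed values for illustration only.
max_context_length = 2048   # sink_size + window_size
eviction_batch_size = 512
input_pos_item, position_shift, seq_len = 2040, 0, 16

# The incoming 16 tokens would overflow the cache by 8 positions, but at least
# one full eviction batch is dropped, so 512 old tokens are evicted.
num_to_evict = max(
    input_pos_item + position_shift - max_context_length + seq_len,
    eviction_batch_size,
)
print(num_to_evict)  # 512
```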
(Diffs for the remaining changed files were not loaded on this page.)
