jit entire prefill and insert for packed sequences
sixiang-google committed Jan 15, 2025
1 parent 5530f34 commit 1d38ffe
Showing 2 changed files with 91 additions and 85 deletions.
98 changes: 50 additions & 48 deletions MaxText/inference_mlperf/offline_inference.py
@@ -103,31 +103,42 @@ def warmup(self, max_length, warmup_samples):
.lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
.compile()
)
# input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
# example_seq_len=16
# num_prompts = max_length//length
# self._cached_pref_batch[length] = (
# jax.jit(self._prefill_insert_batch, donate_argnums=(4,))
# .lower(
# self.params,
# tokens=input_data_batch,
# slots=jnp.arange(0, example_seq_len),
# num_prompts = 16,
# decoder_positions = jnp.arange(0, max_length),
# decoder_segment_ids = jnp.ones(max_length),
# start_pos=jnp.arange(0, max_length, max_length//example_seq_len),
# padded_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
# true_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
# decode_state=self.decode_state)
# .compile()
# )
self.batch_inference(warmup_samples, desc="warmup")
if length == 64 or length == 1024:
continue
log.info(f"Compiling batched prefill: {length}")
input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
num_prompts = max_length // length
self._cached_pref_batch[length] = (
jax.jit(
self._prefill_insert_batch,
static_argnames=(
"num_prompts",
"padded_length",
),
donate_argnames=("decode_state",),
)
.lower(
self.params,
tokens=input_data_batch,
slots=jnp.arange(0, 8, dtype=int),
num_prompts=num_prompts,
decoder_positions=jnp.arange(0, max_length, dtype=int),
decoder_segment_ids=jnp.ones(max_length, dtype=int),
start_pos=jnp.arange(0, max_length, 128, dtype=int),
padded_length=length,
true_lengths=jnp.full(8, length, dtype=int),
decode_state=self.decode_state,
)
.compile()
)
self._cached_generate = (
jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
)
self.batch_inference(warmup_samples, desc="warmup")

def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
"""return decodestate."""
padded_len = tokens.shape[0]
prefill_result, first_token = self.engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
decode_state = self.engine.insert(prefill_result, decode_state, slot)
return first_token, decode_state
@@ -141,12 +152,12 @@ def _prefill_insert_batch(
decoder_positions,
decoder_segment_ids,
start_pos,
padded_lengths,
padded_length,
true_lengths,
decode_state,
):
"""return decodestate."""
prefill_results, first_tokens = self.engine.prefill_concat(
cache, prefill_results, first_tokens = self.engine.prefill_concat(
params=params,
padded_tokens=tokens,
decoder_positions=decoder_positions,
@@ -155,20 +166,15 @@
true_lengths=true_lengths,
num_prompts=num_prompts,
)
# decode_state = jax.lax.fori_loop(
# 0, num_prompts,
# lambda i, state: self.engine.insert(
# prefill_results[i],
# state,
# slot=slots[i],
# start_idx = start_pos[i],
# seq_len = padded_lengths[i]),
# decode_state
# )
for i in range(num_prompts):
decode_state = self.engine.insert_partial(
prefill_results[i], decode_state, slots[i], start_idx=start_pos[i].item(), seq_len=padded_lengths[i].item()
)
decode_state = self.engine.insert_partial(
prefill_results,
decode_state,
cache,
slots,
num_prompts=num_prompts,
start_indices=start_pos,
seq_len=padded_length,
)
return first_tokens, decode_state

def batch_inference_with_callback(
@@ -188,7 +194,7 @@ def prefill(prefill_bucket, prefill_len):
if self.dummy:
log.info("dummy prefill")
return 123
if not self.enable_batch_prefill or prefill_len * len(prefill_bucket) != 1024:
if not self.enable_batch_prefill or prefill_len in (64, 1024) or prefill_len * len(prefill_bucket) != 1024:
prefill_result = []
prefill_fn = self._prefill_insert
if (cached := self._cached_pref.get(prefill_len)) is not None:
@@ -219,7 +225,6 @@ def prefill(prefill_bucket, prefill_len):
tokens = jnp.concat([row.tokens for (slot, row) in prefill_bucket])

slots = [slot for (slot, row) in prefill_bucket]
padded_lengths = [row.tokens.shape[0] for (slot, row) in prefill_bucket]
true_lengths = [row.true_length for (slot, row) in prefill_bucket]
start_pos = np.cumsum([0] + [row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
start_pos = start_pos.tolist()
@@ -230,23 +235,20 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
array_to_pad.extend([0] * (pad_len - len(array_to_pad)))
return jnp.array(array_to_pad)

slots = pad_num_prompts_len_array(slots, 16)
padded_lengths = pad_num_prompts_len_array(padded_lengths, 16)
true_lengths = pad_num_prompts_len_array(true_lengths, 16)
start_pos = pad_num_prompts_len_array(start_pos, 16)

slots = pad_num_prompts_len_array(slots, 8)
true_lengths = pad_num_prompts_len_array(true_lengths, 8)
start_pos = pad_num_prompts_len_array(start_pos, 8)
# this lowered function has static input for num_prompts and padded_length
first_tokens, self.decode_state = prefill_fn(
self.params,
tokens=tokens,
slots=slots,
num_prompts=len(prefill_bucket),
decoder_positions=positions,
decoder_segment_ids=sequence_indicator,
start_pos=start_pos,
padded_lengths=padded_lengths,
true_lengths=true_lengths,
decode_state=self.decode_state,
)
)  # pytype: disable=missing-parameter
prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]

return prefill_result
@@ -345,9 +347,9 @@ def detokenize():
total_num_prefills += 1
log.info(f"Total num prefill: {total_num_prefills}")
slot = empty_slots.pop()
# directly prefill prompts with 64 or less tokens
if num_tokens == 64:
first_token, slot, row = prefill([(slot, row)], 64)[0]
# directly prefill prompts with 64 or less tokens, and with 1024 tokens
if num_tokens in (64, 1024) or not self.enable_batch_prefill:
first_token, slot, row = prefill([(slot, row)], num_tokens)[0]
self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
continue
self.prefill_buckets[num_tokens].append((slot, row))
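
Note: the warmup path above precompiles the batched prefill with jax.jit's ahead-of-time API (lower + compile), marking num_prompts and padded_length as static and donating decode_state. A minimal standalone sketch of that pattern follows; the function body, names, and shapes are illustrative, not the actual MaxText signatures.

import jax
import jax.numpy as jnp

def step(params, tokens, num_prompts, decode_state):
    # num_prompts is static: a Python int baked into the compiled executable.
    # decode_state is donated, so its buffers may be reused for the output.
    return decode_state + params * jnp.sum(tokens) * num_prompts

compiled = (
    jax.jit(step, static_argnames=("num_prompts",), donate_argnames=("decode_state",))
    .lower(
        jnp.float32(1.0),
        tokens=jax.ShapeDtypeStruct((1024,), jnp.dtype("int32")),  # abstract shape, no real data needed
        num_prompts=8,
        decode_state=jnp.zeros((16,), jnp.float32),
    )
    .compile()
)
# Caching the compiled executable (as in self._cached_pref_batch) pays the
# tracing/compilation cost once during warmup rather than on the request path.
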
78 changes: 41 additions & 37 deletions MaxText/maxengine.py
@@ -16,6 +16,7 @@
import copy as cp
import functools
from typing import Any, Optional, Tuple, Callable
from collections import defaultdict

import flax
from flax import linen as nn
@@ -402,19 +403,20 @@ def process_packed_logits_and_caches(packed_flat_logits, idx):
)
return {
"logits": selected_logits,
"cache": cache,
"next_pos": next_pos,
"generated_tokens": generated_tokens,
"tokens": first_generated_token,
}, result

prefill_results = []
prefill_results = defaultdict(list)
first_tokens = []
for idx in range(num_prompts):
prefill_result, first_token = process_packed_logits_and_caches(flat_logits, idx)
prefill_results.append(prefill_result)
for k, v in prefill_result.items():
prefill_results[k].append(v)
first_tokens.append(first_token)
return prefill_results, first_tokens
prefill_results = {k: jnp.stack(v) for k, v in prefill_results.items()}
return cache, prefill_results, first_tokens

@functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,))
def generate(
@@ -577,32 +579,33 @@ def copy(path, partial_cache, full_cache, annotations):
@functools.partial(
jax.jit,
static_argnums=(0,),
static_argnames=("seq_len",),
static_argnames=(
"num_prompts",
"seq_len",
),
donate_argnums=(
1,
2,
),
)
def insert_partial(
self,
prefix: Prefix,
prefix: PackedPrefix,
decode_state: DecodeState,
slot: int,
cache: Any,
slots: jax.Array,
*,
start_idx: int,
start_indices: jax.Array,
num_prompts: int,
seq_len: int,
) -> DecodeState:
"""Insert a sequence of several prefixes into KV cache."""
"""Insert into KV cache"""
unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)
cache_unboxed = max_utils.unbox_logicallypartioned(cache)
cache_unboxed = self._maybe_unstack_prefill_result_cache(cache_unboxed)
start_idx = 0
slot = slots[0]

unboxed_prefix["cache"] = self._maybe_unstack_prefill_result_cache(unboxed_prefix["cache"])

# jax.debug.print("Inserting cache slot {} start_idx {} seq_len {}", slot, start_idx, seq_len)
# example = unboxed_prefix["cache"]["decoder"]['layers_0']['self_attention']['AttentionOp_0']
# for key in example.keys():
# jax.debug.print("{} shape: {}", key, example[key].shape)
# jax.debug.print("-----------------------------")
# jax.debug.print(self.config.prefill_cache_axis_order)
def copy(path, partial_cache, full_cache, annotations):
path_key = path[-1].key
if path_key in [
@@ -645,11 +648,8 @@ def copy(path, partial_cache, full_cache, annotations):
"cached_prefill_value_scale",
]:
seqlen_index = self.config.prefill_cache_axis_order.split(",").index("1")
start_indices = jnp.zeros(4, dtype=int)
start_indices = jax.lax.dynamic_update_slice(
start_indices, jnp.array(start_idx, dtype=int, ndmin=1), (seqlen_index,)
)
# start_indices[seqlen_index] = start_idx
start_indices = [0, 0, 0, 0]
start_indices[seqlen_index] = start_idx
slice_size = list(partial_cache.shape)
slice_size[seqlen_index] = seq_len

@@ -661,21 +661,25 @@
else:
raise ValueError(f"We don't have a strategy for inserting {path_key}")

inserted_cache = jax.tree_util.tree_map_with_path(
copy,
unboxed_prefix["cache"],
decode_state["cache"],
self.kv_cache_annotations_named,
)
inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
decode_state["generated_tokens"],
unboxed_prefix["generated_tokens"],
slot,
0,
)
inserted_tokens = jax.lax.dynamic_update_index_in_dim(decode_state["tokens"], unboxed_prefix["tokens"], slot, 0)
inserted_cache = decode_state["cache"]
inserted_logits = decode_state["logits"]
inserted_next_pos = decode_state["next_pos"]
inserted_generated_tokens = decode_state["generated_tokens"]
inserted_tokens = decode_state["tokens"]

for i in range(num_prompts):
start_idx = start_indices[i]
slot = slots[i]
inserted_cache = jax.tree_util.tree_map_with_path(copy, cache_unboxed, inserted_cache, self.kv_cache_annotations_named)
inserted_logits = jax.lax.dynamic_update_index_in_dim(inserted_logits, unboxed_prefix["logits"][i, ...], slot, 0)
inserted_next_pos = jax.lax.dynamic_update_index_in_dim(inserted_next_pos, unboxed_prefix["next_pos"][i, ...], slot, 0)
inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
inserted_generated_tokens,
unboxed_prefix["generated_tokens"][i, ...],
slot,
0,
)
inserted_tokens = jax.lax.dynamic_update_index_in_dim(inserted_tokens, unboxed_prefix["tokens"][i, ...], slot, 0)

inserted_logits = jax.lax.with_sharding_constraint(inserted_logits, self.replicated_sharding)
inserted_generated_tokens = jax.lax.with_sharding_constraint(inserted_generated_tokens, self.replicated_sharding)
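
Note: the new insert_partial loop above scatters each prompt's stacked prefill results into its decode-state slot with jax.lax.dynamic_update_index_in_dim, unrolled over the static num_prompts. A minimal standalone sketch of that pattern follows; the shapes and names are illustrative, not the actual MaxText state layout.

import jax
import jax.numpy as jnp

def insert_rows(batched, per_prompt, slots, num_prompts):
    # batched: (num_slots, d) decode-state field, e.g. next_pos or tokens
    # per_prompt: (num_prompts, d) stacked per-prompt prefill results
    # slots: (num_prompts,) destination slot for each prompt
    # num_prompts is static, so this Python loop unrolls at trace time
    for i in range(num_prompts):
        batched = jax.lax.dynamic_update_index_in_dim(batched, per_prompt[i], slots[i], 0)
    return batched

state = jnp.zeros((8, 4), jnp.int32)                     # 8 decode slots
rows = jnp.arange(8, dtype=jnp.int32).reshape(2, 4)       # results for 2 prompts
slots = jnp.array([3, 5])                                 # target slots
state = jax.jit(insert_rows, static_argnames=("num_prompts",))(state, rows, slots, num_prompts=2)
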
