diff --git a/MaxText/inference_mlperf/offline_inference.py b/MaxText/inference_mlperf/offline_inference.py
index f5645985a..82699c188 100644
--- a/MaxText/inference_mlperf/offline_inference.py
+++ b/MaxText/inference_mlperf/offline_inference.py
@@ -103,31 +103,42 @@ def warmup(self, max_length, warmup_samples):
           .lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
           .compile()
       )
-      # input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
-      # example_seq_len=16
-      # num_prompts = max_length//length
-      # self._cached_pref_batch[length] = (
-      #   jax.jit(self._prefill_insert_batch, donate_argnums=(4,))
-      #   .lower(
-      #     self.params,
-      #     tokens=input_data_batch,
-      #     slots=jnp.arange(0, example_seq_len),
-      #     num_prompts = 16,
-      #     decoder_positions = jnp.arange(0, max_length),
-      #     decoder_segment_ids = jnp.ones(max_length),
-      #     start_pos=jnp.arange(0, max_length, max_length//example_seq_len),
-      #     padded_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
-      #     true_lengths=jnp.arange(0, max_length, max_length//example_seq_len),
-      #     decode_state=self.decode_state)
-      #   .compile()
-      # )
-    self.batch_inference(warmup_samples, desc="warmup")
+      if length == 64 or length == 1024:
+        continue
+      log.info(f"Compiling batched prefill: {length}")
+      input_data_batch = jax.ShapeDtypeStruct((max_length,), jnp.dtype("int32"))
+      num_prompts = max_length // length
+      self._cached_pref_batch[length] = (
+          jax.jit(
+              self._prefill_insert_batch,
+              static_argnames=(
+                  "num_prompts",
+                  "padded_length",
+              ),
+              donate_argnames=("decode_state",),
+          )
+          .lower(
+              self.params,
+              tokens=input_data_batch,
+              slots=jnp.arange(0, 8, dtype=int),
+              num_prompts=num_prompts,
+              decoder_positions=jnp.arange(0, max_length, dtype=int),
+              decoder_segment_ids=jnp.ones(max_length, dtype=int),
+              start_pos=jnp.arange(0, max_length, 128, dtype=int),
+              padded_length=length,
+              true_lengths=jnp.full(8, length, dtype=int),
+              decode_state=self.decode_state,
+          )
+          .compile()
+      )
     self._cached_generate = (
         jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
     )
+    self.batch_inference(warmup_samples, desc="warmup")
 
   def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
     """return decodestate."""
+    padded_len = tokens.shape[0]
     prefill_result, first_token = self.engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
     decode_state = self.engine.insert(prefill_result, decode_state, slot)
     return first_token, decode_state
@@ -141,12 +152,12 @@ def _prefill_insert_batch(
       decoder_positions,
       decoder_segment_ids,
       start_pos,
-      padded_lengths,
+      padded_length,
       true_lengths,
       decode_state,
   ):
     """return decodestate."""
-    prefill_results, first_tokens = self.engine.prefill_concat(
+    cache, prefill_results, first_tokens = self.engine.prefill_concat(
         params=params,
         padded_tokens=tokens,
         decoder_positions=decoder_positions,
@@ -155,20 +166,15 @@ def _prefill_insert_batch(
         true_lengths=true_lengths,
         num_prompts=num_prompts,
     )
-    # decode_state = jax.lax.fori_loop(
-    #     0, num_prompts,
-    #     lambda i, state: self.engine.insert(
-    #       prefill_results[i],
-    #       state,
-    #       slot=slots[i],
-    #       start_idx = start_pos[i],
-    #       seq_len = padded_lengths[i]),
-    #     decode_state
-    #     )
-    for i in range(num_prompts):
-      decode_state = self.engine.insert_partial(
-          prefill_results[i], decode_state, slots[i], start_idx=start_pos[i].item(), seq_len=padded_lengths[i].item()
-      )
+    decode_state = self.engine.insert_partial(
+        prefill_results,
+        decode_state,
+        cache,
+        slots,
+        num_prompts=num_prompts,
+        start_indices=start_pos,
+        seq_len=padded_length,
+    )
     return first_tokens, decode_state
 
   def batch_inference_with_callback(
@@ -188,7 +194,7 @@ def prefill(prefill_bucket, prefill_len):
       if self.dummy:
         log.info("dummy prefill")
         return 123
-      if not self.enable_batch_prefill or prefill_len * len(prefill_bucket) != 1024:
+      if not self.enable_batch_prefill or prefill_len in (64, 1024) or prefill_len * len(prefill_bucket) != 1024:
         prefill_result = []
         prefill_fn = self._prefill_insert
         if (cached := self._cached_pref.get(prefill_len)) is not None:
@@ -219,7 +225,6 @@ def prefill(prefill_bucket, prefill_len):
 
       tokens = jnp.concat([row.tokens for (slot, row) in prefill_bucket])
       slots = [slot for (slot, row) in prefill_bucket]
-      padded_lengths = [row.tokens.shape[0] for (slot, row) in prefill_bucket]
       true_lengths = [row.true_length for (slot, row) in prefill_bucket]
       start_pos = np.cumsum([0] + [row.tokens.shape[0] for (slot, row) in prefill_bucket])[:-1]
       start_pos = start_pos.tolist()
@@ -230,23 +235,20 @@ def pad_num_prompts_len_array(array_to_pad, pad_len):
           array_to_pad.extend([0] * (pad_len - len(array_to_pad)))
         return jnp.array(array_to_pad)
 
-      slots = pad_num_prompts_len_array(slots, 16)
-      padded_lengths = pad_num_prompts_len_array(padded_lengths, 16)
-      true_lengths = pad_num_prompts_len_array(true_lengths, 16)
-      start_pos = pad_num_prompts_len_array(start_pos, 16)
-
+      slots = pad_num_prompts_len_array(slots, 8)
+      true_lengths = pad_num_prompts_len_array(true_lengths, 8)
+      start_pos = pad_num_prompts_len_array(start_pos, 8)
+      # this lowered function has static input for num_prompts and padded_length
       first_tokens, self.decode_state = prefill_fn(
          self.params,
          tokens=tokens,
          slots=slots,
-          num_prompts=len(prefill_bucket),
          decoder_positions=positions,
          decoder_segment_ids=sequence_indicator,
          start_pos=start_pos,
-          padded_lengths=padded_lengths,
          true_lengths=true_lengths,
          decode_state=self.decode_state,
-      )
+      )  # pytype: disable=missing-parameter
       prefill_result = [(first_tokens[idx], slot, row) for (idx, (slot, row)) in enumerate(prefill_bucket)]
       return prefill_result
 
@@ -345,9 +347,9 @@ def detokenize():
       total_num_prefills += 1
       log.info(f"Total num prefill: {total_num_prefills}")
       slot = empty_slots.pop()
-      # directly prefill prompts with 64 or less tokens
-      if num_tokens == 64:
-        first_token, slot, row = prefill([(slot, row)], 64)[0]
+      # directly prefill prompts with 64 or less tokens, and with 1024 tokens
+      if num_tokens in (64, 1024) or not self.enable_batch_prefill:
+        first_token, slot, row = prefill([(slot, row)], num_tokens)[0]
         self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
         continue
       self.prefill_buckets[num_tokens].append((slot, row))
diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py
index 988244d2d..82bd6ce78 100644
--- a/MaxText/maxengine.py
+++ b/MaxText/maxengine.py
@@ -16,6 +16,7 @@
 import copy as cp
 import functools
 from typing import Any, Optional, Tuple, Callable
+from collections import defaultdict
 
 import flax
 from flax import linen as nn
@@ -402,19 +403,20 @@ def process_packed_logits_and_caches(packed_flat_logits, idx):
       )
       return {
           "logits": selected_logits,
-          "cache": cache,
           "next_pos": next_pos,
           "generated_tokens": generated_tokens,
           "tokens": first_generated_token,
       }, result
 
-    prefill_results = []
+    prefill_results = defaultdict(list)
     first_tokens = []
     for idx in range(num_prompts):
       prefill_result, first_token = process_packed_logits_and_caches(flat_logits, idx)
-      prefill_results.append(prefill_result)
+      for k, v in prefill_result.items():
+        prefill_results[k].append(v)
       first_tokens.append(first_token)
-    return prefill_results, first_tokens
+    prefill_results = {k: jnp.stack(v) for k, v in prefill_results.items()}
+    return cache, prefill_results, first_tokens
 
   @functools.partial(jax.jit, static_argnums=(0,), donate_argnums=(2,))
   def generate(
@@ -577,7 +579,10 @@ def copy(path, partial_cache, full_cache, annotations):
   @functools.partial(
       jax.jit,
       static_argnums=(0,),
-      static_argnames=("seq_len",),
+      static_argnames=(
+          "num_prompts",
+          "seq_len",
+      ),
       donate_argnums=(
           1,
           2,
@@ -585,24 +590,22 @@ def copy(path, partial_cache, full_cache, annotations):
   )
   def insert_partial(
       self,
-      prefix: Prefix,
+      prefix: PackedPrefix,
       decode_state: DecodeState,
-      slot: int,
+      cache: Any,
+      slots: jax.Array,
       *,
-      start_idx: int,
+      start_indices: jax.Array,
+      num_prompts: int,
       seq_len: int,
   ) -> DecodeState:
-    """Insert a sequence of several prefixes into KV cache."""
+    """Insert into KV cache"""
     unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)
+    cache_unboxed = max_utils.unbox_logicallypartioned(cache)
+    cache_unboxed = self._maybe_unstack_prefill_result_cache(cache_unboxed)
+    start_idx = 0
+    slot = slots[0]
 
-    unboxed_prefix["cache"] = self._maybe_unstack_prefill_result_cache(unboxed_prefix["cache"])
-
-    # jax.debug.print("Inserting cache slot {} start_idx {} seq_len {}", slot, start_idx, seq_len)
-    # example = unboxed_prefix["cache"]["decoder"]['layers_0']['self_attention']['AttentionOp_0']
-    # for key in example.keys():
-    #   jax.debug.print("{} shape: {}", key, example[key].shape)
-    # jax.debug.print("-----------------------------")
-    # jax.debug.print(self.config.prefill_cache_axis_order)
     def copy(path, partial_cache, full_cache, annotations):
       path_key = path[-1].key
       if path_key in [
@@ -645,11 +648,8 @@ def copy(path, partial_cache, full_cache, annotations):
          "cached_prefill_value_scale",
      ]:
        seqlen_index = self.config.prefill_cache_axis_order.split(",").index("1")
-        start_indices = jnp.zeros(4, dtype=int)
-        start_indices = jax.lax.dynamic_update_slice(
-            start_indices, jnp.array(start_idx, dtype=int, ndmin=1), (seqlen_index,)
-        )
-        # start_indices[seqlen_index] = start_idx
+        start_indices = [0, 0, 0, 0]
+        start_indices[seqlen_index] = start_idx
 
        slice_size = list(partial_cache.shape)
        slice_size[seqlen_index] = seq_len
@@ -661,21 +661,25 @@ def copy(path, partial_cache, full_cache, annotations):
       else:
         raise ValueError(f"We don't have a strategy for inserting {path_key}")
 
-    inserted_cache = jax.tree_util.tree_map_with_path(
-        copy,
-        unboxed_prefix["cache"],
-        decode_state["cache"],
-        self.kv_cache_annotations_named,
-    )
-    inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
-    inserted_next_pos = jax.lax.dynamic_update_index_in_dim(decode_state["next_pos"], unboxed_prefix["next_pos"], slot, 0)
-    inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
-        decode_state["generated_tokens"],
-        unboxed_prefix["generated_tokens"],
-        slot,
-        0,
-    )
-    inserted_tokens = jax.lax.dynamic_update_index_in_dim(decode_state["tokens"], unboxed_prefix["tokens"], slot, 0)
+    inserted_cache = decode_state["cache"]
+    inserted_logits = decode_state["logits"]
+    inserted_next_pos = decode_state["next_pos"]
+    inserted_generated_tokens = decode_state["generated_tokens"]
+    inserted_tokens = decode_state["tokens"]
+
+    for i in range(num_prompts):
+      start_idx = start_indices[i]
+      slot = slots[i]
+      inserted_cache = jax.tree_util.tree_map_with_path(copy, cache_unboxed, inserted_cache, self.kv_cache_annotations_named)
+      inserted_logits = jax.lax.dynamic_update_index_in_dim(inserted_logits, unboxed_prefix["logits"][i, ...], slot, 0)
+      inserted_next_pos = jax.lax.dynamic_update_index_in_dim(inserted_next_pos, unboxed_prefix["next_pos"][i, ...], slot, 0)
+      inserted_generated_tokens = jax.lax.dynamic_update_index_in_dim(
+          inserted_generated_tokens,
+          unboxed_prefix["generated_tokens"][i, ...],
+          slot,
+          0,
+      )
+      inserted_tokens = jax.lax.dynamic_update_index_in_dim(inserted_tokens, unboxed_prefix["tokens"][i, ...], slot, 0)
 
     inserted_logits = jax.lax.with_sharding_constraint(inserted_logits, self.replicated_sharding)
     inserted_generated_tokens = jax.lax.with_sharding_constraint(inserted_generated_tokens, self.replicated_sharding)