From fe3126aa192e8f954a68f24eb5c5ec13d7e48ae6 Mon Sep 17 00:00:00 2001
From: Xiang Si
Date: Thu, 7 Nov 2024 21:55:46 +0000
Subject: [PATCH] optimizations for offline mlperf inference

---
 MaxText/inference_mlperf/offline_inference.py | 126 +++++++++++++-----
 1 file changed, 92 insertions(+), 34 deletions(-)

diff --git a/MaxText/inference_mlperf/offline_inference.py b/MaxText/inference_mlperf/offline_inference.py
index 9ae66cd36..b2c5ac69b 100644
--- a/MaxText/inference_mlperf/offline_inference.py
+++ b/MaxText/inference_mlperf/offline_inference.py
@@ -18,6 +18,12 @@
 import jax
 from jax import numpy as jnp
 import numpy as np
+import queue
+import os
+import functools
+import threading
+import traceback
+import signal
 
 from jetstream.engine import engine_api
 
@@ -35,9 +41,21 @@ class InputData:
   true_length: int
 
 
+class JetThread(threading.Thread):
+
+  def run(self):
+    try:
+      super().run()
+    except Exception as e:  # pylint: disable=broad-exception-caught
+      print(f"Thread {self.name} encountered an error: {e}")
+      traceback.print_exc()
+      os.kill(os.getpid(), signal.SIGKILL)
+
+
 class OfflineInference:
 
   def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine):
+    self.live = False
     self.engine = engine
     self.decode_state = None
     if params is None:
@@ -55,6 +73,7 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En
 
     self._cached_pref = {}
     self._cached_generate = None
+    self.detokenize_backlog = queue.Queue(10)
 
   def init_decode_state(self):
     if self.decode_state is None:
@@ -75,8 +94,17 @@ def warmup(self, max_length, warmup_samples):
     for length in interesting_buckets:
       if length > max_length:
         break
-
+      log.info(f"Compiling prefill: {length}")
+      input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
+      self._cached_pref[length] = (
+          jax.jit(self._prefill_insert, donate_argnums=(4,))
+          .lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
+          .compile()
+      )
     self.batch_inference(warmup_samples, desc="warmup")
+    self._cached_generate = (
+        jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
+    )
 
   def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
     """return decodestate."""
@@ -99,7 +127,7 @@ def batch_inference_with_callback(
     def prefill(slot, tokens, true_length):
       nonlocal self
       if self.dummy:
-        log.debug("dummy prefill")
+        log.info("dummy prefill")
         return 123
 
       prefill_fn = self._prefill_insert
@@ -109,7 +137,7 @@ def prefill(slot, tokens, true_length):
       first_token, self.decode_state = prefill_fn(
           self.params, tokens=tokens, slot=slot, true_length=true_length, decode_state=self.decode_state
       )
-      return first_token.data[0][0].item()
+      return first_token
 
     empty_slots = list(range(self.batch_size))
     slot_to_id = {}
@@ -119,12 +147,10 @@ def prefill(slot, tokens, true_length):
     dummy_length = 1
 
     def decode():
-      log.debug("decode")
       nonlocal self
-      nonlocal slot_to_id
       nonlocal dummy_length
       if self.dummy:
-        log.debug("Dummy generate")
+        log.info("Dummy generate")
         res = engine_api.ResultTokens(
             data=np.array([[123, 1, dummy_length]] * self.batch_size),
             tokens_idx=(0, 0),
@@ -138,51 +164,80 @@ def decode():
       gen_fn = self.engine.generate
       if self._cached_generate is not None:
         gen_fn = self._cached_generate
-      self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
+      result_tokens_l = []
+      for i in range(5):
+        self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
+        result_tokens_l.append(result_tokens)
+      for i in range(5):
+        result_tokens = result_tokens_l[i].convert_to_numpy()
+        self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
+        # log.info(f"Decode put result {i} to queue")
-      result_tokens = result_tokens.convert_to_numpy()
-
-      newly_empty = []
-      for slot, id_ in slot_to_id.items():
-        token, is_valid, length = result_tokens.data[slot]
-        log.debug(f"slot is {slot}, length is {length}")
-        should_finish = False
-        if is_valid:
-          should_finish = emit_token(id_, token.item())
-        if should_finish or length >= self.max_decode_length:
-          newly_empty.append(slot)
-
-      # Add slots of those that are empty to empty
-      for slot in newly_empty:
-        del slot_to_id[slot]
-        empty_slots.append(slot)
+    def detokenize():
+      nonlocal self
+      nonlocal slot_to_id
+      nonlocal empty_slots
+      while self.live:
+        # log.info("Detokenize start")
+        newly_empty = []
+        result_tokens, is_first_token, row_id, _slot = self.detokenize_backlog.get(block=True)
+        # log.info("Detokenize get from queue")
+        if is_first_token:
+          first_token = result_tokens.data[0][0].item()
+          should_terminate = emit_first_token(row_id, first_token)
+          if not should_terminate:
+            slot_to_id[_slot] = row_id
+          else:
+            empty_slots.append(_slot)
+          continue
+        for slot, id_ in slot_to_id.items():
+          token, is_valid, length = result_tokens.data[slot]
+          log.debug(f"slot is {slot}, length is {length}")
+          should_finish = False
+          if is_valid:
+            should_finish = emit_token(id_, token.item())
+          if should_finish or length >= self.max_decode_length:
+            newly_empty.append(slot)
+            log.info(f"Detokenize free up {slot}, length {length}")
+        # Add slots of those that are empty to empty
+        for slot in newly_empty:
+          del slot_to_id[slot]
+          empty_slots.append(slot)
+        if newly_empty and self.detokenize_backlog.qsize() == 0 and len(slot_to_id.items()) == 0:
+          break
 
+    detokenize_thread = JetThread(
+        target=functools.partial(
+            detokenize,
+        ),
+        name="detokenize",
+    )
+    self.live = True
+    detokenize_thread.start()
     for row in data:
-      log.debug(f"empty_slots {len(empty_slots)}")
       while not empty_slots:
         # If slots are all full, decode until there are free slots
         # to insert
         num_decodes += 1
-        log.debug(f"decode-{desc}-{num_decodes}")
+        log.info(f"decode-{desc}-{num_decodes}")
         decode()
       # do one insert
       num_tokens = len(row.tokens)
       num_prefills[num_tokens] = 0 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
-      log.debug(
-          f"prefill-{desc}-{num_prefills} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
+      log.info(
+          f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
       )
       slot = empty_slots.pop()
       first_token = prefill(slot, row.tokens, row.true_length)
-      should_terminate = emit_first_token(row.id, first_token)
-      if not should_terminate:
-        slot_to_id[slot] = row.id
-      else:
-        empty_slots.append(slot)  # dont use the slot
+      self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)
 
     while slot_to_id:
-      log.debug(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
+      log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
       num_decodes += 1
       decode()
+
+    self.live = False
+    detokenize_thread.join()
     log.info(f"summary-{desc}-prefills-{num_prefills}-decodes-{num_decodes} completed.")
 
   def batch_inference(self, data: List[InputData], desc=""):
@@ -191,7 +246,10 @@ def batch_inference(self, data: List[InputData], desc=""):
 
     def callback(id_, token):
       nonlocal res
-      res[id_].append(token)
+      if token == self.tokenizer.eos_id:
+        log.info(f"res[{id_}] eos")
+      if not res[id_] or res[id_][-1] != self.tokenizer.eos_id:
+        res[id_].append(token)
       return token == self.tokenizer.eos_id
 
     self.batch_inference_with_callback(data, emit_first_token=callback, emit_token=callback, desc=desc)
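
For context on the warmup change above: rather than letting the first real request pay the tracing and compilation cost, the patch compiles ahead of time, caching one executable per prefill length bucket through JAX's jit(...).lower(...).compile() AOT path. The sketch below shows that pattern in isolation; `step`, `cached`, and the shapes are illustrative stand-ins, not the MaxText functions.

import jax
import jax.numpy as jnp


def step(params, tokens):
  # Stand-in for prefill/generate; the real patch also donates the decode
  # state buffer via donate_argnums so its device memory can be reused.
  return jnp.dot(tokens, params)


params = jnp.ones((8, 8), jnp.float32)
cached = {}
for length in (4, 8):  # analogous to the patch's per-length prefill buckets
  tokens_spec = jax.ShapeDtypeStruct((length, 8), jnp.dtype("float32"))
  # lower() traces against abstract shapes only; compile() returns an
  # executable, so no compilation is left to happen on the serving path.
  cached[length] = jax.jit(step).lower(params, tokens_spec).compile()

out = cached[4](params, jnp.ones((4, 8), jnp.float32))
print(out.shape)  # (4, 8)

Compiling per bucket trades a longer warmup for a serving path with no recompiles, which matters when the benchmark measures steady-state latency.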
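The other structural change is the producer/consumer split: decode() now runs five generate steps back to back and pushes the raw results onto a bounded queue, while the JetThread consumer (detokenize) handles the CPU-side token emission, so host work overlaps with device steps. Below is a minimal runnable sketch of that shape, using a sentinel object for shutdown where the patch uses its self.live flag plus a queue-drain check; all names here are illustrative.

import queue
import threading

backlog = queue.Queue(maxsize=10)  # bounded, like the patch's Queue(10)
STOP = object()  # shutdown sentinel (the patch drains the queue and joins instead)


def detokenize():
  # Consumer: token emission happens off the critical path, overlapping
  # with the producer's next device step.
  while True:
    item = backlog.get(block=True)
    if item is STOP:
      break
    print(f"emit {item}")


consumer = threading.Thread(target=detokenize, name="detokenize")
consumer.start()

for step_result in range(3):  # stand-in for generate-step result batches
  backlog.put(step_result, block=True)  # blocks if the consumer falls behind

backlog.put(STOP)
consumer.join()

The bounded queue is the key design choice: if detokenization cannot keep up, put() blocks the producer instead of letting device results pile up without limit in host memory.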