
Commit

Merge pull request #1017 from AI-Hypercomputer:sixiang-inference
PiperOrigin-RevId: 694569324
maxtext authors committed Nov 8, 2024
2 parents dd2726c + fe3126a commit 1ff8505
Showing 1 changed file with 92 additions and 34 deletions.
126 changes: 92 additions & 34 deletions MaxText/inference_mlperf/offline_inference.py
@@ -18,6 +18,12 @@
import jax
from jax import numpy as jnp
import numpy as np
import queue
import os
import functools
import threading
import traceback
import signal

from jetstream.engine import engine_api

@@ -35,9 +41,21 @@ class InputData:
  true_length: int


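# Thread subclass that surfaces failures from background work: any uncaught
# exception is printed with a traceback and the whole process is killed, so a
# dead worker thread cannot silently stall the run.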
class JetThread(threading.Thread):

  def run(self):
    try:
      super().run()
    except Exception as e:  # pylint: disable=broad-exception-caught
      print(f"Thread {self.name} encountered an error: {e}")
      traceback.print_exc()
      os.kill(os.getpid(), signal.SIGKILL)


class OfflineInference:

  def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.Engine):
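    # Flag controlling the background detokenize loop; it is set to True for the
    # duration of each batch_inference_with_callback call.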
    self.live = False
    self.engine = engine
    self.decode_state = None
    if params is None:
@@ -55,6 +73,7 @@ def __init__(self, engine: engine_api.Engine, params, base_engine: engine_api.En

    self._cached_pref = {}
    self._cached_generate = None
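    # Bounded hand-off queue between the decode loop and the detokenize thread;
    # puts block once 10 result batches are waiting.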
    self.detokenize_backlog = queue.Queue(10)

  def init_decode_state(self):
    if self.decode_state is None:
@@ -75,8 +94,17 @@ def warmup(self, max_length, warmup_samples):
    for length in interesting_buckets:
      if length > max_length:
        break

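      # Ahead-of-time lower and compile the prefill+insert path for this bucket
      # length, caching the executable by length.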
      log.info(f"Compiling prefill: {length}")
      input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
      self._cached_pref[length] = (
          jax.jit(self._prefill_insert, donate_argnums=(4,))
          .lower(self.params, tokens=input_data, slot=0, true_length=length - 1, decode_state=self.decode_state)
          .compile()
      )
    self.batch_inference(warmup_samples, desc="warmup")
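    # Cache a compiled generate step as well; the decode state buffers are donated.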
    self._cached_generate = (
        jax.jit(self.engine.generate, donate_argnums=(1,)).lower(self.params, self.decode_state).compile()
    )

  def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
    """return decodestate."""
@@ -99,7 +127,7 @@ def batch_inference_with_callback(
    def prefill(slot, tokens, true_length):
      nonlocal self
      if self.dummy:
        log.debug("dummy prefill")
        log.info("dummy prefill")
        return 123

      prefill_fn = self._prefill_insert
@@ -109,7 +137,7 @@ def prefill(slot, tokens, true_length):
      first_token, self.decode_state = prefill_fn(
          self.params, tokens=tokens, slot=slot, true_length=true_length, decode_state=self.decode_state
      )
      return first_token.data[0][0].item()
      return first_token

    empty_slots = list(range(self.batch_size))
    slot_to_id = {}
@@ -119,12 +147,10 @@ def prefill(slot, tokens, true_length):
    dummy_length = 1

    def decode():
      log.debug("decode")
      nonlocal self
      nonlocal slot_to_id
      nonlocal dummy_length
      if self.dummy:
        log.debug("Dummy generate")
        log.info("Dummy generate")
        res = engine_api.ResultTokens(
            data=np.array([[123, 1, dummy_length]] * self.batch_size),
            tokens_idx=(0, 0),
@@ -138,51 +164,80 @@ def decode():
      gen_fn = self.engine.generate
      if self._cached_generate is not None:
        gen_fn = self._cached_generate
      self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
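      # Run several generate steps back to back, then convert the results to numpy
      # and hand them to the detokenize thread through the backlog queue.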
      result_tokens_l = []
      for i in range(5):
        self.decode_state, result_tokens = gen_fn(self.params, self.decode_state)
        result_tokens_l.append(result_tokens)
      for i in range(5):
        result_tokens = result_tokens_l[i].convert_to_numpy()
        self.detokenize_backlog.put((result_tokens, False, 0, 0), block=True)
        # log.info(f"Decode put result {i} to queue")

      result_tokens = result_tokens.convert_to_numpy()

      newly_empty = []
      for slot, id_ in slot_to_id.items():
        token, is_valid, length = result_tokens.data[slot]
        log.debug(f"slot is {slot}, length is {length}")
        should_finish = False
        if is_valid:
          should_finish = emit_token(id_, token.item())
        if should_finish or length >= self.max_decode_length:
          newly_empty.append(slot)

      # Return slots that are now empty to empty_slots
      for slot in newly_empty:
        del slot_to_id[slot]
        empty_slots.append(slot)
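    # Background consumer: pulls (result_tokens, is_first_token, row_id, slot)
    # tuples off the backlog, emits tokens through the callbacks, and returns
    # finished slots to empty_slots.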
    def detokenize():
      nonlocal self
      nonlocal slot_to_id
      nonlocal empty_slots
      while self.live:
        # log.info("Detokenize start")
        newly_empty = []
        result_tokens, is_first_token, row_id, _slot = self.detokenize_backlog.get(block=True)
        # log.info("Detokenize get from queue")
        if is_first_token:
          first_token = result_tokens.data[0][0].item()
          should_terminate = emit_first_token(row_id, first_token)
          if not should_terminate:
            slot_to_id[_slot] = row_id
          else:
            empty_slots.append(_slot)
          continue
        for slot, id_ in slot_to_id.items():
          token, is_valid, length = result_tokens.data[slot]
          log.debug(f"slot is {slot}, length is {length}")
          should_finish = False
          if is_valid:
            should_finish = emit_token(id_, token.item())
          if should_finish or length >= self.max_decode_length:
            newly_empty.append(slot)
            log.info(f"Detokenize free up {slot}, length {length}")
        # Return slots that are now empty to empty_slots
        for slot in newly_empty:
          del slot_to_id[slot]
          empty_slots.append(slot)
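        # Stop once the slots just freed leave nothing active and the backlog is empty.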
        if newly_empty and self.detokenize_backlog.qsize() == 0 and len(slot_to_id.items()) == 0:
          break

    detokenize_thread = JetThread(
        target=functools.partial(
            detokenize,
        ),
        name="detokenize",
    )
    self.live = True
    detokenize_thread.start()
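    # Producer loop: wait for a free slot (decoding until the detokenize thread
    # frees one), then prefill the next row into it.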
    for row in data:
      log.debug(f"empty_slots {len(empty_slots)}")
      while not empty_slots:
        # If slots are all full, decode until there are free slots
        # to insert
        num_decodes += 1
        log.debug(f"decode-{desc}-{num_decodes}")
        log.info(f"decode-{desc}-{num_decodes}")
        decode()
      # do one insert
      num_tokens = len(row.tokens)
      num_prefills[num_tokens] = 0 if num_tokens not in num_prefills else num_prefills[num_tokens] + 1
      log.debug(
          f"prefill-{desc}-{num_prefills} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
      log.info(
          f"prefill-{desc}-{num_prefills} num_prefills {sum(num_prefills.values())} num_tokens {num_tokens} true_length {row.true_length} num_empty_slots {len(empty_slots)} num_decodes {num_decodes}"
      )
      slot = empty_slots.pop()
      first_token = prefill(slot, row.tokens, row.true_length)
      should_terminate = emit_first_token(row.id, first_token)
      if not should_terminate:
        slot_to_id[slot] = row.id
      else:
        empty_slots.append(slot)  # don't use the slot
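      # First-token handling happens on the detokenize thread: enqueue the prefill
      # result together with the row id and slot.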
      self.detokenize_backlog.put((first_token, True, row.id, slot), block=True)

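    # Drain the remaining active slots, then shut the detokenize thread down.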
    while slot_to_id:
      log.debug(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
      log.info(f"decode-{desc}-{num_decodes} num_filled_slots {len(slot_to_id)}")
      num_decodes += 1
      decode()

    self.live = False
    detokenize_thread.join()
    log.info(f"summary-{desc}-prefills-{num_prefills}-decodes-{num_decodes} completed.")

  def batch_inference(self, data: List[InputData], desc=""):
@@ -191,7 +246,10 @@ def batch_inference(self, data: List[InputData], desc=""):

    def callback(id_, token):
      nonlocal res
      res[id_].append(token)
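      # Append the token unless the sequence already ended with EOS; returning True
      # on EOS tells the caller to stop generating for this request.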
      if token == self.tokenizer.eos_id:
        log.info(f"res[{id_}] eos")
      if not res[id_] or res[id_][-1] != self.tokenizer.eos_id:
        res[id_].append(token)
      return token == self.tokenizer.eos_id

    self.batch_inference_with_callback(data, emit_first_token=callback, emit_token=callback, desc=desc)
