diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index adaf15b55..df3ece6bb 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -264,3 +264,8 @@ vertex_tensorboard_region: ""
 
 # If set to True, MaxText will perform extra checks using jax.checkify. Note that this will effect performance.
 max_checkify: False
+
+# Inference
+inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
+inference_microbenchmark_stages: "prefill,generate"
+inference_microbenchmark_loop_iters: 10
diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py
index 8fe46b3e2..232129932 100644
--- a/MaxText/inference_microbenchmark.py
+++ b/MaxText/inference_microbenchmark.py
@@ -21,26 +21,13 @@
 import sys
 
 from jetstream.engine import token_utils
+
 import max_utils
 import maxengine
 import maxtext_utils
 import pyconfig
 
 
-def summarize_pytree_data(params, name="Params"):
-  """Generate basic metrics of a given Pytree."""
-  num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params)
-  num_params_in_billions = num_params / 1e9
-  total_param_size_in_gb = total_param_size / 1e9
-  print(
-      f"{name} stats: \n"
-      f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
-      f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
-      f"\tAvg size: {avg_param_size:.3f} bytes\n"
-  )
-  return num_params, total_param_size, avg_param_size
-
-
 def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
   """Inner loop for benchmarking prefill step."""
   max_utils.activate_profiler(config, profile_name)
@@ -49,24 +36,29 @@ def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_le
     slot = int(i % (jax.device_count() * config.per_device_batch_size))
     prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
     decode_state = engine.insert(prefill_result, decode_state, slot=slot)
+    max_utils.delete_pytree(prefill_result)
   jax.block_until_ready(decode_state)
   end = datetime.datetime.now()
   max_utils.deactivate_profiler(config)
   return (end - start).total_seconds(), decode_state
 
 
-
 def prefill_benchmark(
     config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
 ):
   """Handles init, warmup, running prefill benchmark, and printing results."""
   if num_model_params is None:
-    num_model_params, _, _ = summarize_pytree_data(params, name="Params")
+    num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")
   prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
   decode_state = engine.insert(prefill_result, decode_state, slot=0)
-  jax.block_until_ready(decode_state)
+  max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
+  max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
+  max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
+  max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
+  max_utils.delete_pytree(prefill_result)
   prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
   decode_state = engine.insert(prefill_result, decode_state, slot=0)
+  max_utils.delete_pytree(prefill_result)
   jax.block_until_ready(decode_state)
 
   print(f"Prefill results for length {tokens.size}:\n")
@@ -106,9 +98,9 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=
 def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
   """Handles init, warmup, running ar benchmark, and printing results."""
   if cache_size is None:
-    _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
+    _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
   if model_size is None:
-    _, model_size, _ = summarize_pytree_data(params, name="Params")
+    _, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
   global_batch_size = jax.device_count() * config.per_device_batch_size
 
   # Warmup
@@ -165,43 +157,52 @@ def write_results(results, filename=""):
 
 
 def print_results_for_analyze(results):
-  prefill_bucket_size_to_ms = {}
-  for k, v in results["Prefill"].items():
-    prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
-  print("\nFor usage in analyze_sharegpt.py :")
-  print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
-  print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
+  if "Prefill" in results:
+    prefill_bucket_size_to_ms = {}
+    for k, v in results["Prefill"].items():
+      prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
+    print("\nFor usage in analyze_sharegpt.py :")
+    print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
+
+  if "AutoRegressive" in results:
+    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
 
 
 def main(config):
   engine = maxengine.MaxEngine(config)
   params = engine.load_params()
 
-  prefill_lengths = [64, 128, 256, 512, 1024]
-  benchmark_loop_iters = 10
+  prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
+  stages_to_benchmark = config.inference_microbenchmark_stages.split(",")
+  benchmark_loop_iters = config.inference_microbenchmark_loop_iters
+
   text = config.prompt
   metadata = engine.get_tokenizer()
   vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)
   decode_state = engine.init_decode_state()
-  _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
-  num_model_params, model_size, _ = summarize_pytree_data(params, name="Model")
-
-  benchmark_results = {"Prefill": {}}
-  benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
-      config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size
-  )
-  for prefill_length in prefill_lengths:
-    tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[prefill_length])
-    benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
-        config,
-        engine,
-        params,
-        decode_state,
-        tokens,
-        true_length,
-        iters=benchmark_loop_iters,
-        num_model_params=num_model_params,
-    )
+  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
+  num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")
+
+  benchmark_results = {}
+  if "prefill" in stages_to_benchmark:
+    benchmark_results["Prefill"] = {}
+    for prefill_length in prefill_lengths:
+      tokens, true_length = token_utils.tokenize_and_pad(
+          text, vocab, is_bos=True, prefill_lengths=[prefill_length])
+      benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
+          config,
+          engine,
+          params,
+          decode_state,
+          tokens,
+          true_length,
+          iters=benchmark_loop_iters,
+          num_model_params=num_model_params
+      )
+
+  if "generate" in stages_to_benchmark:
+    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
+        config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)
 
   results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
   write_results(results, filename="")
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index 2e218c48a..d4de84b41 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -652,3 +652,30 @@ def get_project():
     max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '")
     return None
   return project_outputs[-1]
+
+
+def delete_pytree(p):
+  def delete_leaf(leaf):
+    if isinstance(leaf, jax.Array):
+      leaf.delete()
+    del leaf
+
+  jax.tree_map(delete_leaf, p)
+
+
+def summarize_pytree_data(params, name="Params", raw=False):
+  """Generate basic metrics of a given Pytree."""
+  num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params)
+  if not raw:
+    num_params_in_billions = num_params / 1e9
+    total_param_size_in_gb = total_param_size / 1e9
+    print(f"{name} stats: \n"
+          f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
+          f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
+          f"\tAvg size: {avg_param_size:.3f} bytes\n")
+  else:
+    print(f"{name} stats: \n"
+          f"\tTotal number of params: {num_params:.3f} \n"
+          f"\tTotal memory usage: {total_param_size:.3f} bytes \n"
+          f"\tAvg size: {avg_param_size:.3f} bytes\n")
+  return num_params, total_param_size, avg_param_size
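
As a quick illustration of the pytree helpers this change moves into max_utils.py, here is a standalone sketch, not MaxText code: it mirrors the summarize_pytree_data(..., raw=...) / delete_pytree pattern on a toy prefill result, and the helper names report_pytree and free_pytree are made up for the example.

# Standalone sketch (not part of the diff above): summarize a pytree's size,
# then eagerly free its device buffers, mirroring the max_utils helpers.
# report_pytree / free_pytree and the toy shapes are hypothetical.
import jax
import jax.numpy as jnp


def report_pytree(tree, name="Params", raw=False):
  """Print leaf/byte totals for a pytree, analogous to summarize_pytree_data."""
  leaves = jax.tree_util.tree_leaves(tree)
  num_params = sum(leaf.size for leaf in leaves)
  total_bytes = sum(leaf.nbytes for leaf in leaves)
  avg_bytes = total_bytes / max(num_params, 1)
  if raw:
    print(f"{name}: {num_params} params, {total_bytes} bytes, {avg_bytes:.3f} bytes/param")
  else:
    print(f"{name}: {num_params / 1e9:.3f}B params, {total_bytes / 1e9:.3f} GB, {avg_bytes:.3f} bytes/param")
  return num_params, total_bytes, avg_bytes


def free_pytree(tree):
  """Eagerly delete device buffers, analogous to delete_pytree."""
  def _delete(leaf):
    if isinstance(leaf, jax.Array):
      leaf.delete()  # releases the backing device memory immediately
  jax.tree_util.tree_map(_delete, tree)


if __name__ == "__main__":
  # Toy stand-in for a prefill result; shapes are arbitrary, not MaxText's.
  fake_prefill_result = {
      "logits": jnp.zeros((1, 1024, 32000), dtype=jnp.bfloat16),
      "next_pos": jnp.array([[1024]], dtype=jnp.int32),
  }
  report_pytree(fake_prefill_result["logits"], name="Prefill Logits", raw=True)
  report_pytree(fake_prefill_result, name="Prefill Result")
  free_pytree(fake_prefill_result)  # buffers are gone; do not reuse these arrays

The explicit delete step appears to be the rationale for the max_utils.delete_pytree calls the diff adds after each engine.insert: every prefill produces a fresh set of output arrays, and freeing them as soon as they are inserted into the decode state keeps repeated benchmark iterations from accumulating unused buffers in accelerator memory.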