Fix Microbenchmark Profiling Memory Issues #597

Merged: 1 commit, Apr 19, 2024
5 changes: 5 additions & 0 deletions MaxText/configs/base.yml
@@ -264,3 +264,8 @@ vertex_tensorboard_region: ""

Collaborator:

Not in this PR, but I see a need for separate inference-specific config files in the future -- both base.yml and model-specific configs.

# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
max_checkify: False

# Inference
inference_microbenchmark_prefill_lengths: "64,128,256,512,1024"
inference_microbenchmark_stages: "prefill,generate"
inference_microbenchmark_loop_iters: 10
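
For reference, a minimal sketch of how these comma-separated values get consumed (it mirrors the parsing added to inference_microbenchmark.py later in this diff; the variable names here are just illustrative):

  prefill_lengths = [int(l) for l in "64,128,256,512,1024".split(",")]  # -> [64, 128, 256, 512, 1024]
  stages_to_benchmark = "prefill,generate".split(",")                   # -> ["prefill", "generate"]
  benchmark_loop_iters = 10                                             # plain integer, used as-is
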
93 changes: 47 additions & 46 deletions MaxText/inference_microbenchmark.py
@@ -21,26 +21,13 @@
import sys

from jetstream.engine import token_utils

import max_utils
import maxengine
import maxtext_utils
import pyconfig


def summarize_pytree_data(params, name="Params"):
  """Generate basic metrics of a given Pytree."""
  num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params)
  num_params_in_billions = num_params / 1e9
  total_param_size_in_gb = total_param_size / 1e9
  print(
      f"{name} stats: \n"
      f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
      f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
      f"\tAvg size: {avg_param_size:.3f} bytes\n"
  )
  return num_params, total_param_size, avg_param_size


def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
"""Inner loop for benchmarking prefill step."""
max_utils.activate_profiler(config, profile_name)
@@ -49,24 +36,29 @@ def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
    slot = int(i % (jax.device_count() * config.per_device_batch_size))
    prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
    decode_state = engine.insert(prefill_result, decode_state, slot=slot)
    max_utils.delete_pytree(prefill_result)
  jax.block_until_ready(decode_state)
  end = datetime.datetime.now()
  max_utils.deactivate_profiler(config)
  return (end - start).total_seconds(), decode_state


def prefill_benchmark(
    config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
):
  """Handles init, warmup, running prefill benchmark, and printing results."""
  if num_model_params is None:
    num_model_params, _, _ = summarize_pytree_data(params, name="Params")
    num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")

  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  decode_state = engine.insert(prefill_result, decode_state, slot=0)
  jax.block_until_ready(decode_state)
  max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
  max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
  max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
  max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
  max_utils.delete_pytree(prefill_result)
  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  decode_state = engine.insert(prefill_result, decode_state, slot=0)
  max_utils.delete_pytree(prefill_result)
  jax.block_until_ready(decode_state)

  print(f"Prefill results for length {tokens.size}:\n")
@@ -106,9 +98,9 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=
def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
"""Handles init, warmup, running ar benchmark, and printing results."""
if cache_size is None:
_, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
_, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
if model_size is None:
_, model_size, _ = summarize_pytree_data(params, name="Params")
_, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
global_batch_size = jax.device_count() * config.per_device_batch_size

# Warmup
@@ -165,43 +157,52 @@ def write_results(results, filename=""):


def print_results_for_analyze(results):
  prefill_bucket_size_to_ms = {}
  for k, v in results["Prefill"].items():
    prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
  print("\nFor usage in analyze_sharegpt.py :")
  print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
  print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
  if "Prefill" in results:
    prefill_bucket_size_to_ms = {}
    for k, v in results["Prefill"].items():
      prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
    print("\nFor usage in analyze_sharegpt.py :")
    print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

  if "AutoRegressive" in results:
    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")


def main(config):
  engine = maxengine.MaxEngine(config)
  params = engine.load_params()
  prefill_lengths = [64, 128, 256, 512, 1024]
  benchmark_loop_iters = 10
  prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
Collaborator:
Does this work if you pass in a command line param like --inference-microbenchmark-prefill-lengths="512,1024" or something similar?

Collaborator (Author):
Yeah, example commands:
to run a single prefill length:

  inference_microbenchmark_prefill_lengths=1024

to run a single stage:

  inference_microbenchmark_stages=generate
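
For context, a full invocation would look something like the following (assuming the usual MaxText pattern of passing the config file path followed by key=value overrides; the exact paths are illustrative):

  python MaxText/inference_microbenchmark.py MaxText/configs/base.yml inference_microbenchmark_prefill_lengths=1024 inference_microbenchmark_stages=generate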

  stages_to_benchmark = config.inference_microbenchmark_stages.split(",")
  benchmark_loop_iters = config.inference_microbenchmark_loop_iters

  text = config.prompt
  metadata = engine.get_tokenizer()
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()
  _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = summarize_pytree_data(params, name="Model")

  benchmark_results = {"Prefill": {}}
  benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
      config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size
  )
  for prefill_length in prefill_lengths:
    tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[prefill_length])
    benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
        config,
        engine,
        params,
        decode_state,
        tokens,
        true_length,
        iters=benchmark_loop_iters,
        num_model_params=num_model_params,
    )
  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

  benchmark_results = {}
  if "prefill" in stages_to_benchmark:
    benchmark_results["Prefill"] = {}
    for prefill_length in prefill_lengths:
      tokens, true_length = token_utils.tokenize_and_pad(
          text, vocab, is_bos=True, prefill_lengths=[prefill_length])
      benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
          config,
          engine,
          params,
          decode_state,
          tokens,
          true_length,
          iters=benchmark_loop_iters,
          num_model_params=num_model_params
      )

  if "generate" in stages_to_benchmark:
    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
        config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)
Comment on lines +204 to +205
Collaborator:

For running just the generate benchmark, do you still need to populate the KV cache to produce proper perf numbers?

Collaborator (Author):

You will still need to initialize a decode_state for the generate-step calculation.
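
In other words, a generate-only run (sketched below by condensing the main() changes in this diff; not the verbatim merged code) still builds a decode_state, and it is that state's KV cache the autoregressive step reads and writes:

  decode_state = engine.init_decode_state()  # allocates the KV cache even though prefill is skipped
  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")
  benchmark_results = {}
  benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
      config, engine, params, decode_state, iters=benchmark_loop_iters,
      cache_size=cache_size, model_size=model_size)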


  results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
  write_results(results, filename="")
27 changes: 27 additions & 0 deletions MaxText/max_utils.py
@@ -652,3 +652,30 @@ def get_project():
max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project <project>'")
return None
return project_outputs[-1]


def delete_pytree(p):
  def delete_leaf(leaf):
    if isinstance(leaf, jax.Array):
      leaf.delete()
    del leaf

  jax.tree_map(delete_leaf, p)
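
A minimal usage sketch for delete_pytree (a hypothetical toy pytree; in this PR the real callers pass prefill results): deleting the leaves releases their device buffers immediately instead of waiting for Python garbage collection.

  import jax.numpy as jnp
  buffers = {"logits": jnp.ones((4, 32000)), "cache": jnp.zeros((2, 1024, 128))}
  delete_pytree(buffers)
  # Every jax.Array leaf has had .delete() called on it, so its HBM is freed;
  # any further computation on those arrays would raise an error.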


def summarize_pytree_data(params, name="Params", raw=False):
"""Generate basic metrics of a given Pytree."""
num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params)
if not raw:
num_params_in_billions = num_params / 1e9
total_param_size_in_gb = total_param_size / 1e9
print(f"{name} stats: \n"
f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
f"\tAvg size: {avg_param_size:.3f} bytes\n")
else:
print(f"{name} stats: \n"
f"\tTotal number of params: {num_params:.3f} \n"
f"\tTotal memory usage: {total_param_size:.3f} bytes \n"
f"\tAvg size: {avg_param_size:.3f} bytes\n")
return num_params, total_param_size, avg_param_size
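
And a small usage sketch for the raw flag (toy arrays, purely illustrative):

  import jax.numpy as jnp
  toy = {"w": jnp.ones((1000, 1000)), "b": jnp.zeros((1000,))}
  summarize_pytree_data(toy, name="Toy")            # reports params in billions and memory in GB
  summarize_pytree_data(toy, name="Toy", raw=True)  # reports raw parameter count and bytes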