Fix Microbenchmark Profiling Memory Issues #597
@@ -21,26 +21,13 @@
import sys

from jetstream.engine import token_utils

import max_utils
import maxengine
import maxtext_utils
import pyconfig


def summarize_pytree_data(params, name="Params"):
  """Generate basic metrics of a given Pytree."""
  num_params, total_param_size, avg_param_size = max_utils.summarize_size_from_pytree(params)
  num_params_in_billions = num_params / 1e9
  total_param_size_in_gb = total_param_size / 1e9
  print(
      f"{name} stats: \n"
      f"\tTotal number of params: {num_params_in_billions:.3f} billion \n"
      f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n"
      f"\tAvg size: {avg_param_size:.3f} bytes\n"
  )
  return num_params, total_param_size, avg_param_size


def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_length, iters, profile_name=""):
  """Inner loop for benchmarking prefill step."""
  max_utils.activate_profiler(config, profile_name)
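The local summarize_pytree_data helper above is dropped in favor of a max_utils version, which the rest of this diff calls, sometimes with a new raw flag. A minimal sketch of that relocated helper follows, assuming raw simply skips the billions/GB scaling for small tensors such as logits and next_pos; the real max_utils implementation is not shown in this PR.

# Sketch of the relocated helper as it might appear inside max_utils.py
# (the behavior of `raw` is an assumption, not taken from this diff):
def summarize_pytree_data(params, name="Params", raw=False):
  """Generate basic metrics of a given pytree."""
  num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params)
  if raw:
    # Assumed behavior: report unscaled counts/bytes for small tensors.
    print(
        f"{name} stats: \n"
        f"\tTotal number of params: {num_params} \n"
        f"\tTotal memory usage: {total_param_size} bytes \n"
        f"\tAvg size: {avg_param_size:.3f} bytes\n"
    )
  else:
    print(
        f"{name} stats: \n"
        f"\tTotal number of params: {num_params / 1e9:.3f} billion \n"
        f"\tTotal memory usage: {total_param_size / 1e9:.3f} GB \n"
        f"\tAvg size: {avg_param_size:.3f} bytes\n"
    )
  return num_params, total_param_size, avg_param_size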
@@ -49,24 +36,29 @@ def prefill_benchmark_loop(config, engine, decode_state, params, tokens, true_le
    slot = int(i % (jax.device_count() * config.per_device_batch_size))
    prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
    decode_state = engine.insert(prefill_result, decode_state, slot=slot)
    max_utils.delete_pytree(prefill_result)
  jax.block_until_ready(decode_state)
  end = datetime.datetime.now()
  max_utils.deactivate_profiler(config)
  return (end - start).total_seconds(), decode_state


def prefill_benchmark(
    config, engine, params, decode_state, tokens, true_length, iters=100, profile_name="", num_model_params=None
):
  """Handles init, warmup, running prefill benchmark, and printing results."""
  if num_model_params is None:
    num_model_params, _, _ = summarize_pytree_data(params, name="Params")
    num_model_params, _, _ = max_utils.summarize_pytree_data(params, name="Params")

  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  decode_state = engine.insert(prefill_result, decode_state, slot=0)
  jax.block_until_ready(decode_state)
  max_utils.summarize_pytree_data(prefill_result["logits"], name="Prefill Logits", raw=True)
  max_utils.summarize_pytree_data(prefill_result["cache"], name="Prefill Cache")
  max_utils.summarize_pytree_data(prefill_result["next_pos"], name="Prefill Next pos", raw=True)
  max_utils.summarize_pytree_data(prefill_result["generated_tokens"], name="Prefill Generated Tokens", raw=True)
  max_utils.delete_pytree(prefill_result)
  prefill_result = engine.prefill(params=params, padded_tokens=tokens, true_length=true_length)
  decode_state = engine.insert(prefill_result, decode_state, slot=0)
  max_utils.delete_pytree(prefill_result)
  jax.block_until_ready(decode_state)

  print(f"Prefill results for length {tokens.size}:\n")
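The delete_pytree calls above are where the memory fix lands: each prefill result is explicitly freed once it has been inserted into the decode state, instead of lingering on device until garbage collection. A minimal sketch of what such a helper could look like, assuming leaves are jax.Array values whose device buffers can be released with .delete(); the actual max_utils.delete_pytree implementation is not part of this diff.

import jax


def delete_pytree_sketch(pytree):
  """Free device memory held by every array leaf of a pytree.

  Hypothetical stand-in for max_utils.delete_pytree (an assumption, not the
  implementation used by this PR).
  """

  def _delete_leaf(leaf):
    if isinstance(leaf, jax.Array):
      leaf.delete()  # release the backing device buffers

  jax.tree_util.tree_map(_delete_leaf, pytree)

Once deleted, the buffers must not be read again, which is consistent with the loop above recomputing prefill on every iteration rather than reusing the freed result.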
@@ -106,9 +98,9 @@ def ar_benchmark_loop(config, engine, decode_state, params, iters, profile_name=
def ar_benchmark(config, engine, params, decode_state, cache_size=None, model_size=None, profile_name="", iters=100):
  """Handles init, warmup, running ar benchmark, and printing results."""
  if cache_size is None:
    _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
    _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  if model_size is None:
    _, model_size, _ = summarize_pytree_data(params, name="Params")
    _, model_size, _ = max_utils.summarize_pytree_data(params, name="Params")
  global_batch_size = jax.device_count() * config.per_device_batch_size

  # Warmup
@@ -165,43 +157,52 @@ def write_results(results, filename=""):


def print_results_for_analyze(results):
  prefill_bucket_size_to_ms = {}
  for k, v in results["Prefill"].items():
    prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
  print("\nFor usage in analyze_sharegpt.py :")
  print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")
  print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")
  if "Prefill" in results:
    prefill_bucket_size_to_ms = {}
    for k, v in results["Prefill"].items():
      prefill_bucket_size_to_ms[int(k)] = round(v["prefill_time_in_ms"], 3)
    print("\nFor usage in analyze_sharegpt.py :")
    print(f"PREFILL_BUCKET_SIZE_TO_MS = {prefill_bucket_size_to_ms}")

  if "AutoRegressive" in results:
    print(f"SYSTEM_TIME_PER_DECODE_TOKEN_MS = {results['AutoRegressive']['ar_step_in_ms_per_seq']}")


def main(config):
  engine = maxengine.MaxEngine(config)
  params = engine.load_params()
  prefill_lengths = [64, 128, 256, 512, 1024]
  benchmark_loop_iters = 10
  prefill_lengths = [int(l) for l in config.inference_microbenchmark_prefill_lengths.split(",")]
Review comment: Does this work if you pass in a command line param like …?

Reply: Yeah, example commands (including one that runs a single stage) are sketched below.
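A minimal sketch of such invocations, assuming the usual MaxText convention of appending key=value overrides after the base config file. The script path, config path, and values shown here are assumptions rather than the reply's exact commands; the key names are the ones read by main() in this diff.

# Run both stages with explicit prefill lengths and iteration count:
python MaxText/inference_microbenchmark.py MaxText/configs/base.yml \
  inference_microbenchmark_prefill_lengths="64,128,256,512,1024" \
  inference_microbenchmark_stages="prefill,generate" \
  inference_microbenchmark_loop_iters=10

# Run a single stage, e.g. only the generate benchmark:
python MaxText/inference_microbenchmark.py MaxText/configs/base.yml \
  inference_microbenchmark_stages="generate"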
  stages_to_benchmark = config.inference_microbenchmark_stages.split(",")
  benchmark_loop_iters = config.inference_microbenchmark_loop_iters

  text = config.prompt
  metadata = engine.get_tokenizer()
  vocab = token_utils.load_vocab(metadata.path, metadata.extra_ids)

  decode_state = engine.init_decode_state()
  _, cache_size, _ = summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = summarize_pytree_data(params, name="Model")

  benchmark_results = {"Prefill": {}}
  benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
      config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size
  )
  for prefill_length in prefill_lengths:
    tokens, true_length = token_utils.tokenize_and_pad(text, vocab, is_bos=True, prefill_lengths=[prefill_length])
    benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
        config,
        engine,
        params,
        decode_state,
        tokens,
        true_length,
        iters=benchmark_loop_iters,
        num_model_params=num_model_params,
    )
  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  num_model_params, model_size, _ = max_utils.summarize_pytree_data(params, name="Model")

  benchmark_results = {}
  if "prefill" in stages_to_benchmark:
    benchmark_results["Prefill"] = {}
    for prefill_length in prefill_lengths:
      tokens, true_length = token_utils.tokenize_and_pad(
          text, vocab, is_bos=True, prefill_lengths=[prefill_length])
      benchmark_results["Prefill"][prefill_length], decode_state = prefill_benchmark(
          config,
          engine,
          params,
          decode_state,
          tokens,
          true_length,
          iters=benchmark_loop_iters,
          num_model_params=num_model_params
      )

  if "generate" in stages_to_benchmark:
    benchmark_results["AutoRegressive"], decode_state = ar_benchmark(
        config, engine, params, decode_state, iters=benchmark_loop_iters, cache_size=cache_size, model_size=model_size)
Comment on lines +204 to +205

Review comment: For running just the generate benchmark, do you still need to populate the KV cache to produce proper perf numbers?

Reply: You will still need to initialize a decode_state for the generate step calculation.
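A small sketch of the generate-only flow implied by that exchange, assuming the engine, config, params, and helpers defined in this file; it mirrors the new main() rather than adding anything to it.

def run_generate_only(config, engine, params, model_size):
  """Illustrative only: the generate stage still needs a decode_state.

  The KV cache comes from init_decode_state() alone, since prefill never runs.
  """
  decode_state = engine.init_decode_state()
  _, cache_size, _ = max_utils.summarize_pytree_data(decode_state["cache"], name="Cache")
  ar_results, decode_state = ar_benchmark(
      config, engine, params, decode_state,
      iters=config.inference_microbenchmark_loop_iters,
      cache_size=cache_size, model_size=model_size)
  return ar_results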

  results = collate_results(config, benchmark_results, model_size, cache_size, num_model_params)
  write_results(results, filename="")
Review comment: Not in this PR, but I see a need for separate inference-specific config files in the future -- both base.yml and model-specific config.