Skip to content

Commit

Permalink
Merge pull request #600 from google:patemotter_attn_calc_fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626415915
  • Loading branch information
maxtext authors committed Apr 19, 2024
2 parents 0e1c078 + 2806017 commit 1377756
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
8 changes: 4 additions & 4 deletions MaxText/inference_microbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ def prefill_benchmark(
config, engine, decode_state, params, tokens, true_length, iters, profile_name=profile_name
)
prefill_average_ms = 1000 * time_in_s / iters
total_prefill_tflops, _, _ = maxtext_utils.calculate_tflops_prefill(num_model_params, tokens.size, config)
tflops_per_sec_per_device = total_prefill_tflops / jax.device_count() / prefill_average_ms * 1000.0
prefill_tflops_per_device, _, _ = maxtext_utils.calculate_prefill_tflops_per_device(num_model_params, tokens.size, config)
tflops_per_sec_per_device = prefill_tflops_per_device / prefill_average_ms * 1000.0
print(
f"\tPrefill step average time: {prefill_average_ms:.3f}ms\n"
f"\tPrefill total TFLOPs: {total_prefill_tflops:.3f}\n"
f"\tPrefill total TFLOPs/device: {prefill_tflops_per_device:.3f}\n"
f"\tPrefill TFLOPs/sec/device: {tflops_per_sec_per_device:.3f}\n\n\n\n"
)
result_dict = {
"prefill_time_in_ms": prefill_average_ms,
"prefill_total_tflops": total_prefill_tflops,
"prefill_total_tflops_per_device": prefill_tflops_per_device,
"prefill_tflops_per_sec_per_device": tflops_per_sec_per_device,
}
return result_dict, decode_state
Expand Down
20 changes: 10 additions & 10 deletions MaxText/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,29 +118,29 @@ def calculate_tflops_training_per_device(num_model_parameters, config, log=True)


# https://arxiv.org/pdf/2204.02311.pdf Appendix B
def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, config, log=True):
  """Calculate per-device TFLOPs for a single prefill step.

  Args:
    num_model_parameters: total number of learnable model parameters.
    prefill_length: number of prompt tokens processed during prefill.
    config: model config; must provide num_query_heads, num_decoder_layers,
      head_dim and per_device_batch_size.
    log: if True, print a per-device FLOP breakdown.

  Returns:
    Tuple (total_tflops, learnable_weight_tflops, causal_attention_tflops),
    each expressed per device.
  """
  # 2 FLOPs (multiply + accumulate) per parameter per token, sharded evenly
  # across all devices.
  learnable_weight_tflops = 2 * num_model_parameters * prefill_length / jax.device_count() / 1e12
  # Attention FLOPs without the causal mask, i.e. assuming every token
  # attends to the full prefill_length sequence.
  noncausal_attention_flops = (
      4
      * config.num_query_heads
      * config.num_decoder_layers
      * config.head_dim
      * prefill_length**2
      * config.per_device_batch_size
      / jax.device_count()
      / 1e12
  )
  causal_attention_tflops = noncausal_attention_flops / 2  # due to causality in attention
  total_tflops = learnable_weight_tflops + causal_attention_tflops

  if log:
    print(
        "Per prefill step per device: \n",
        f"\tTotal TFLOPs: {total_tflops:.2f} \n",
        f"\t\tLearnable weight TFLOPs: {learnable_weight_tflops:.2f} ",
        f"({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n",
        f"\t\tCausal attention TFLOPs: {causal_attention_tflops:.2f} ",
        f"({100 * causal_attention_tflops/total_tflops:.2f})% of Total",
    )
  return total_tflops, learnable_weight_tflops, causal_attention_tflops

Expand Down

0 comments on commit 1377756

Please sign in to comment.