Skip to content

Commit

Permalink
Merge pull request #1056 from AI-Hypercomputer:patemotter-offline-benchmark
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 699195652
  • Loading branch information
maxtext authors committed Nov 22, 2024
2 parents 261a8be + 8308e0f commit 1411510
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 19 deletions.
13 changes: 11 additions & 2 deletions MaxText/inference_mlperf/llama_offline_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,18 @@ then
LAYOUT_CFG="compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3"
MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG} ${LAYOUT_CFG}"
fi

# Default locations for the MLPerf inference repo and the loadgen output
# directory; either may be overridden by pre-setting the variable in the
# environment (":=" assigns only when the variable is unset or empty,
# matching the original [ -z ... ] checks).
: "${BASEDIR:=/home/${USER}/inference}"

: "${DATA_DISK_DIR:=/home/${USER}/loadgen_run_data}"

export LOADGEN_RUN_TIMESTAMP=$(TZ=America/Los_Angeles date +%Y%m%d%H%M%S%Z)
# Export the values chosen above (environment override or default) instead
# of unconditionally re-assigning hard-coded paths here, which clobbered
# any user-supplied BASEDIR/DATA_DISK_DIR and made the preceding
# default-setting blocks dead code.
export BASEDIR
export DATA_DISK_DIR
export API_URL=0.0.0.0:9000
if "$test_run"; then
export DATASET_TYPE=test
Expand Down
1 change: 1 addition & 0 deletions MaxText/inference_mlperf/offline_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def main(argv):
settings.use_token_latencies = True

os.makedirs(FLAGS.output_log_dir, exist_ok=True)
log.info(f"Logging to {FLAGS.output_log_dir}")
log_output_settings = lg.LogOutputSettings()
log_output_settings.outdir = FLAGS.output_log_dir
log_output_settings.copy_summary_to_stdout = True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env bash

# Run command:
# bash benchmarks_llama2-70b-trillium_2x4.sh
# bash benchmarks_llama2-70b-trillium_2x4.sh [-b benchmark_type]
# benchmark_type can be: performance, audit, accuracy, or all (default)

run_name="trillium_llama2-70b"
dry_run=false
Expand All @@ -10,21 +11,43 @@ enable_xla_flags=false
# Default values for the command-line flags parsed below.  The stale
# duplicate "while getopts" header that appeared here (left over from a
# merge; it had no do/done body and is superseded by the full getopts loop
# further down) has been removed — it made the script unparseable.
single_bucket=false
token_multiplier=3.0
test_mode=false

benchmark_type="all"

helpFunction()
{
  # Print usage for this benchmark driver and abort with status 1.
  # printf is used instead of echo -e; the emitted bytes are identical.
  printf '\n'
  printf 'Usage: %s [-n] [-p] [-t] [-s] [-x] [-r run_name] [-m token_multiplier] [-b benchmark_type]\n' "$0"
  printf '\t-n Dry run mode\n'
  printf '\t-p Enable profiler\n'
  printf '\t-t Test mode\n'
  printf '\t-s Single bucket mode\n'
  printf '\t-x Enable XLA flags\n'
  printf '\t-r Specify run name\n'
  printf '\t-m Specify token multiplier\n'
  printf '\t-b Specify benchmark type (performance|audit|accuracy|all)\n'
  exit 1
}

# Parse command-line flags; any unrecognized flag routes to helpFunction
# (getopts also prints its own diagnostic in non-silent mode).
while getopts nptsxr:m:b: flag; do
  case "$flag" in
    n) dry_run=true ;;
    p) enable_profiler=true ;;
    t) test_mode=true ;;
    s) single_bucket=true ;;
    x) enable_xla_flags=true ;;
    r) run_name="$OPTARG" ;;
    m) token_multiplier="$OPTARG" ;;
    b) benchmark_type="$OPTARG" ;;
    ?) helpFunction ;;
  esac
done

# Validate benchmark type before doing any work.  The diagnostic goes to
# stderr (not stdout) and names the offending value.
case "$benchmark_type" in
  performance|audit|accuracy|all) ;;
  *)
    echo "Invalid benchmark type '${benchmark_type}'. Must be: performance, audit, accuracy, or all" >&2
    exit 1
    ;;
esac

if "$dry_run"; then
cmd=echo
Expand All @@ -41,8 +64,6 @@ if "$test_mode"; then
RUN_OPTIONS="${RUN_OPTIONS} -t "
fi



if "$single_bucket"; then
export BATCH_AND_PREFILL_LEN="1024,54"
else
Expand All @@ -58,21 +79,42 @@ fi

export TOK_OUTLEN_MULTIPLIER=${token_multiplier}

# CHECKPOINT and TOKENIZER_PATH may be supplied via the environment; fall
# back to the public defaults only when unset or empty.  (The previous
# unconditional assignments immediately above these checks made the
# [[ -z ... ]] branches dead code and clobbered any user-provided values.)
if [[ -z ${CHECKPOINT} ]] ; then
  export CHECKPOINT="gs://inference-benchmarks/models/llama2-70b-chat/quant/int8_"
fi

if [[ -z ${TOKENIZER_PATH} ]] ; then
  export TOKENIZER_PATH="/home/${USER}/maxtext/assets/tokenizer.llama2"
fi

# Engine configuration handed to llama_offline_run.sh via MAXENGINE_ARGS.
BASE_CFG="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT}"
QUANT_CFG="quantization=int8 quantize_kvcache=True checkpoint_is_quantized=True"
LAYOUT_CFG="compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3"
export MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG} ${LAYOUT_CFG}"

# Human-readable tag describing this run's configuration.
RUN_DESC=int8_kv_${batch_and_prefill_str}_${token_multiplier}_flags_${enable_xla_flags}

# Move up to the directory containing llama_offline_run.sh.  This was
# duplicated ("cd .." appeared twice, ascending two levels) and was
# followed by an unconditional performance run that duplicated the
# run_benchmark dispatch below; both are removed.
$cmd cd ..

# Dispatch a single MLPerf benchmark run of the requested flavor
# (performance|audit|accuracy); unknown flavors are a silent no-op,
# matching the original case with no default arm.
run_benchmark() {
  local flavor="$1"
  local extra_flags=""
  case "$flavor" in
    performance) extra_flags="${RUN_OPTIONS}" ;;
    audit)       extra_flags="-d" ;;
    accuracy)    extra_flags="-a" ;;
    *) return 0 ;;
  esac
  # extra_flags stays unquoted on purpose so RUN_OPTIONS word-splits into
  # separate arguments, exactly as the original command lines did.
  $cmd bash llama_offline_run.sh -r benchmarks_${flavor}_${RUN_DESC} ${extra_flags}
}

# Run the requested benchmark flavor; "all" runs the full MLPerf sequence.
case "$benchmark_type" in
  all)
    run_benchmark "performance"
    run_benchmark "audit"
    run_benchmark "accuracy"
    ;;
  *)
    run_benchmark "$benchmark_type"
    ;;
esac

0 comments on commit 1411510

Please sign in to comment.