Merge pull request #953 from AI-Hypercomputer:msingh-cleanup
PiperOrigin-RevId: 684483781
maxtext authors committed Oct 10, 2024
2 parents a4c0001 + d236f10 commit e64f681
Showing 4 changed files with 114 additions and 111 deletions.
124 changes: 50 additions & 74 deletions MaxText/inference_mlperf/README.md
@@ -1,56 +1,46 @@
## Run offline performance benchmarks.

## Create TPU VM.

### Create TPU VM.
Follow these [instructions](https://cloud.google.com/tpu/docs/v5e-inference#tpu-vm) to create a TPU v5e-8 VM and ssh into it.


## Clone repo
```
git clone https://github.com/mlcommons/inference.git
```
### Setup a virtual env
```
sudo apt install python3.10-venv
python -m venv .env
source .env/bin/activate
```

## Install loadgen
### Install loadgen
```
apt-get install python3-dev
apt-get install build-essential -y
sudo apt-get install python3-dev
sudo apt-get install build-essential -y
git clone https://github.com/mlcommons/inference.git
cd inference/
cd loadgen/ && pip install .
```
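To confirm the install, you can import the Python bindings from the virtual env; this is a quick sketch assuming the bindings install under the usual `mlperf_loadgen` module name:

```
# Sanity check: should print the installed module path rather than raising ImportError.
python3 -c "import mlperf_loadgen; print(mlperf_loadgen.__file__)"
```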

## Install eval dependencies
```
pip install \
transformers==4.31.0 \
nltk==3.8.1 \
evaluate==0.4.0 \
absl-py==1.4.0 \
rouge-score==0.1.2 \
sentencepiece==0.1.99 \
accelerate==0.21.0
```

## Download data file
### Download datasets
```
cd /
export DATA_DISK_DIR=/loadgen_run_data
export DATA_DISK_DIR=~/loadgen_run_data
mkdir -p ${DATA_DISK_DIR}
cd ${DATA_DISK_DIR}
gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl .
mv open_orca_gpt4_tokenized_llama.calibration_1000.pkl processed-calibration-data.pkl
gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl .
mv open_orca_gpt4_tokenized_llama.sampled_24576.pkl processed-data.pkl
cd /inference_mlperf4.1
```
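As a quick check that both pickles are where the run scripts expect them (under the names used above):

```
# Both files should exist and be non-trivially sized.
ls -lh ${DATA_DISK_DIR}/processed-calibration-data.pkl ${DATA_DISK_DIR}/processed-data.pkl
```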

## Install Maxtext
### Install Maxtext
```
cd /
cd ~
git clone [email protected]:google/maxtext.git
cd maxtext
git checkout offline_inf
cd maxtext/MaxText
bash setup.sh
pip install -r MaxText/inference_mlperf/requirements.txt
```
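After `setup.sh` finishes, a minimal sketch of verifying that JAX can see the TPU chips, assuming setup installed a TPU-enabled jax build; on a v5e-8 VM this should list 8 devices:

```
# Expect a list of 8 TPU devices on v5e-8.
python3 -c "import jax; print(jax.devices())"
```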

## Checkpoint generation
### Generate quantized checkpoint

Steps to get a quantized llama2-70B checkpoint for v5e-8

@@ -82,74 +72,60 @@ python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PAT

Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. It is used to set the `load_parameters_path` parameter in the `MAXENGINE_ARGS` environment variable below.
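For example, a minimal sketch of wiring the exported checkpoint into the engine arguments (the bucket path below is illustrative; use the path produced by your quantization run):

```
# Illustrative path only - substitute the checkpoint generated above.
export SAVE_QUANT_PARAMS_PATH=gs://<your-bucket>/quantized/llama2-70b/int8
export TOKENIZER_PATH=/home/${USER}/maxtext/assets/tokenizer.llama2
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True"
```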

## HF login
### HuggingFace login
```
huggingface-cli login
export HUGGING_FACE_TOKEN=<your_hugging_face_token>
huggingface-cli login --token $HUGGING_FACE_TOKEN
```

## Loadgen settings
### Offline Server - Test Run
```
cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/
export API_URL=0.0.0.0:9000
export DATA_DISK_DIR=/loadgen_run_data
export DATASET_TYPE=full # for calibration run, DATASET_TYPE=calibration
cd ~/maxtext/MaxText/inference_mlperf
export TOKENIZER_PATH="/home/${USER}/maxtext/assets/tokenizer.llama2"
export BATCH_AND_PREFILL_LEN="1024,20"
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True"
export MODEL_NAME=llama70b
export TOTAL_SAMPLE_COUNT=24576 # for calibration run, TOTAL_SAMPLE_COUNT=1000
export LOG_INTERVAL=1000
export BATCH_SIZE_EXP=8
export USER_CONFIG=user.conf
bash ./llama_offline_run.sh -p -t
```
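The `-p -t` flags above are parsed by the getopts block in `llama_offline_run.sh` (shown later in this commit); as a quick reference based on that script:

```
#   -p  run the loadgen performance benchmark
#   -a  run the loadgen accuracy benchmark
#   -d  run the loadgen audit benchmark
#   -t  test run (TOTAL_SAMPLE_COUNT=100, user100.conf)
#   -s  skip warmup
#   -e  enable the profiler
#   -n  dry run: echo the commands instead of executing them
#   -r <name>  override the default run name
bash ./llama_offline_run.sh -p -t   # small test-sized performance run
```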

## Offline Setup
```
cd /
git clone [email protected]:google/maxtext.git
cd maxtext
git checkout offline_inf
cd maxtext/MaxText
# For v5e use
export BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20"
# For v6 use
export BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
# Set the appropriate tokenizer path. For example, Llama2 models use tokenizer.llama2. You can find
# other tokenizers under the maxtext/assets/ directory.
export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2
```
### Offline Benchmarks

#### For v5e
```
export BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20"
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3"
```

## Run offline performance

#### For v6
```
bash ./llama_offline_performance_run.sh
export BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3"
```
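Each `prefill_length,batch_size` pair in `BATCH_AND_PREFILL_LEN` defines one bucket; `offline_mode.py` splits the value on `|` and then on `,`. As an illustration of the v6 setting above:

```
# "256,216|512,108|1024,54" configures three buckets:
#   prefill length  256 -> batch size 216
#   prefill length  512 -> batch size 108
#   prefill length 1024 -> batch size  54
export BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
```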

## Run offline accuracy
```
bash ./llama_offline_accuracy_run.sh
```
#### Run offline performance benchmark

## Run offline audit
```
bash ./llama_offline_audit_run.sh
bash ./llama_offline_run.sh -p
```

## Run server performance
#### Run offline accuracy benchmark
```
bash ./generate_server_performance_run.sh
bash ./llama_offline_run.sh -a
```

## Run server accuracy
```
bash ./generate_server_accuracy_run.sh
```
#### Run offline audit benchmark
```
bash ./llama_offline_run.sh -d
```
## Run server audit
```
bash ./generate_server_audit_run.sh
```

### Profiling

```
# Capture profile
bash ./llama_offline_run.sh -p -e
python -m jax.collect_profile 9999 2000 --log_dir /tmp/profiles --no_perfetto_link
# View profile
tensorboard --logdir /tmp/profiles
```
65 changes: 37 additions & 28 deletions MaxText/inference_mlperf/llama_offline_run.sh
@@ -6,60 +6,64 @@
# enable profiling using -e option and capture using
# tensorboard --logdir /tmp/tensorboard/

run_name="test_int8_kv_bs_216-108-54"
dry_run=false
skip_warmup=false
test_run=false
enable_profiler=false
performance=false
audit=false
accuracy=false


while getopts "ntspr:" opt
while getopts "ntsepdar:" opt
do
case "$opt" in
n ) dry_run=true ;;
t ) test_run=true ;;
s ) skip_warmup=true;;
p ) enable_profiler=true;;
t ) test_run=true ;;
s ) skip_warmup=true ;;
e ) enable_profiler=true ;;
p ) performance=true ;;
d ) audit=true ;;
a ) accuracy=true ;;
r ) run_name="$OPTARG" ;;
? ) helpFunction ;; # Print helpFunction in case parameter is non-existent
esac
done


if "$dry_run"; then
cmd=echo
else
cmd=''
fi

SKIP_WARMUP_OPTION=""
if "$skip_warmup"; then
SKIP_WARMUP_OPTION="--skip_warmup"
else
SKIP_WARMUP_OPTION=""
fi

PROFILER_OPTION=""
if "$enable_profiler"; then
PROFILER_OPTION="--enable_profile"
else
PROFILER_OPTION=""
fi

if [ -z "$TOKENIZER_PATH"];
then
if [ -z "$TOKENIZER_PATH" ]; then
TOKENIZER_PATH=/home/${USER}/maxtext/assets/tokenizer.llama2
fi

BATCH_STR=""
if [ -z "$BATCH_AND_PREFILL_LEN"];
if [ -z "$BATCH_AND_PREFILL_LEN" ];
then
BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
fi

if [ -z "$TOK_OUTLEN_MULTIPLIER"];
if [ -z "$TOK_OUTLEN_MULTIPLIER" ];
then
TOK_OUTLEN_MULTIPLIER="3.0"
TOK_OUTLEN_MULTIPLIER="2.5"
fi

if [ -z "$MAXENGINE_ARGS"];
if [ -z "$MAXENGINE_ARGS" ];
then
CHECKPOINT="gs://msingh-bkt/checkpoints/quant_llama2-70b-chat/mlperf_070924/int8_"
BASE_CFG="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT}"
@@ -68,14 +72,14 @@ then
MAXENGINE_ARGS="${BASE_CFG} ${QUANT_CFG} ${LAYOUT_CFG}"
fi
export LOADGEN_RUN_TIMESTAMP=$(TZ=America/Los_Angeles date +%Y%m%d%H%M%S%Z)
export BASEDIR=/home/msingh/inference_mlperf4.1
export DATA_DISK_DIR=/home/msingh/loadgen_run_data
export BASEDIR=/home/${USER}/inference
export DATA_DISK_DIR=/home/${USER}/loadgen_run_data
export API_URL=0.0.0.0:9000
if "$test_run"; then
export DATASET_TYPE=test
export DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
export TOTAL_SAMPLE_COUNT=100
export USER_CONFIG=user100.conf
export USER_CONFIG=user${TOTAL_SAMPLE_COUNT}.conf
else
export DATASET_TYPE=full
export DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
@@ -150,15 +154,20 @@ run_loadgen_accuracy () {
fi
}

if "$performance"; then
echo
echo "Starting loadgen performance run"
run_loadgen_performance
fi

echo
echo "Starting loadgen performance run"
run_loadgen_performance

echo
echo "Starting loadgen audit"
run_loadgen_audit
if "$audit"; then
echo
echo "Starting loadgen audit"
run_loadgen_audit
fi

echo
echo "Starting loadgen accuracy"
run_loadgen_accuracy
if "$accuracy"; then
echo
echo "Starting loadgen accuracy"
run_loadgen_accuracy
fi
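With the run types now behind flags, invocations can be combined or previewed; a small sketch based on the options parsed above:

```
# Preview the generated commands without running them (-n sets cmd=echo).
bash ./llama_offline_run.sh -n -p
# Performance and accuracy in a single invocation, skipping warmup.
bash ./llama_offline_run.sh -p -a -s
# Performance run with profiling enabled (-e adds --enable_profile).
bash ./llama_offline_run.sh -p -e
```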
27 changes: 18 additions & 9 deletions MaxText/inference_mlperf/offline_mode.py
@@ -380,15 +380,22 @@ def make_response(id_, response_token_ids):

def _estimated_counts_by_bucket(dataset):
  total_len = dataset.tok_input_length + dataset.tok_output_length
  group1 = (total_len <= 512) & (dataset.tok_input_length <= 256)
  group2 = (total_len <= 1024) & (dataset.tok_input_length <= 512)
  query_batches = _init_query_batches()
  prefix_lens = [l for l, b in list(query_batches.keys())]
  prefix_lens.sort()

  # with 5 percent extra
  mult = FLAGS.total_sample_count / len(dataset) * 1.05
  prev_len = 0
  total_count = 0
  estimates = {}
  estimates["<256"] = math.ceil(len(dataset[group1]) * mult)
  estimates["256-512"] = math.ceil(len(dataset[~group1 & group2]) * mult)
  estimates[">512"] = math.ceil(len(dataset[~group1 & ~group2]) * mult)
  for prefix_len in prefix_lens[:-1]:
    target_len = 2 * prefix_len
    condition = (total_len <= target_len) & (dataset.tok_input_length <= prefix_len)
    count = len(dataset[condition])
    estimates[f"{prev_len}-{prefix_len}"] = math.ceil((count - total_count) * mult)
    total_count = count
    prev_len = prefix_len  # advance the lower bound so the next bucket label is a disjoint range
  estimates[f">{prefix_lens[-1]}"] = math.ceil((len(dataset) - total_count) * mult)
  return estimates


@@ -412,13 +419,15 @@ def main(argv):

log.info("dataset path: %s", FLAGS.dataset_path)
dataset = pd.read_pickle(FLAGS.dataset_path)
rows = list(dataset.iterrows())
if FLAGS.total_sample_count < len(dataset):
dataset = dataset.sample(n=FLAGS.total_sample_count)
estimated_counts_by_bucket = _estimated_counts_by_bucket(dataset)
log.info(f"Estimated counts by bucket {estimated_counts_by_bucket}")
log.info(f"Dataset len {len(dataset)}, estimated counts by bucket {estimated_counts_by_bucket}")

rows = list(dataset.iterrows())
len_batch_str = FLAGS.prefill_lengths_and_batch_sizes
log.info(f"Prefill lengths and Batch sizes: {len_batch_str}")
log.info(f"Maxengine args: {FLAGS.maxengine_args}")
length_and_batch = [tuple(map(int, lb.split(","))) for lb in len_batch_str.split("|")]

log.info("Get warmup samples")
warmup_samples = get_warmup_samples(dataset)
@@ -484,7 +493,7 @@ def main(argv):
)
log.info("Starting Benchmark run")
lg.StartTestWithLogSettings(lgSUT, qsl, settings, log_settings, FLAGS.audit_conf)
log.info(f"query counts {[len(q) for q in sut._query_batches]}")
log.info(f"query counts {[len(sut._query_batches[q]) for q in sut._query_batches]}")
log.info("Run Completed!")
log.info("Destroying SUT...")
lg.DestroySUT(lgSUT)
9 changes: 9 additions & 0 deletions MaxText/inference_mlperf/requirements.txt
@@ -0,0 +1,9 @@
transformers==4.31.0
nltk==3.8.1
evaluate==0.4.0
absl-py==1.4.0
rouge-score==0.1.2
sentencepiece==0.1.99
accelerate==0.21.0
orbax-checkpoint==0.5.20
aqtp==0.7.5
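These pins replace the inline `pip install` list removed from the README above and are installed from the maxtext checkout with:

```
pip install -r MaxText/inference_mlperf/requirements.txt
```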
