Merge pull request #897 from AI-Hypercomputer:offline_inf
PiperOrigin-RevId: 678440353
maxtext authors committed Sep 24, 2024
2 parents f5f0e29 + c4a7d46 commit 5ae1b71
Showing 9 changed files with 1,149 additions and 2 deletions.
155 changes: 155 additions & 0 deletions MaxText/inference_mlperf/README.md

## Create TPU VM
Follow these [instructions](https://cloud.google.com/tpu/docs/v5e-inference#tpu-vm) to create a TPU v5e-8 VM and SSH into it.


## Clone repo
```
git clone https://github.com/mlcommons/inference.git
```

## Install loadgen
```
apt-get install -y python3-dev build-essential
cd inference/loadgen && pip install .
```
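
To confirm the install, you can try importing the module (an optional sanity check, not part of the original steps):
```
python3 -c "import mlperf_loadgen; print(mlperf_loadgen)"
```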

## Install eval dependencies
```
pip install \
transformers==4.31.0 \
nltk==3.8.1 \
evaluate==0.4.0 \
absl-py==1.4.0 \
rouge-score==0.1.2 \
sentencepiece==0.1.99 \
accelerate==0.21.0
```
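
The evaluation script in this directory calls `nltk.download("punkt")` at runtime; if you prefer, you can fetch the tokenizer data ahead of time (optional):
```
python3 -c "import nltk; nltk.download('punkt')"
```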

## Download data file
```
cd /
export DATA_DISK_DIR=/loadgen_run_data
mkdir -p ${DATA_DISK_DIR}
cd ${DATA_DISK_DIR}
gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl .
mv open_orca_gpt4_tokenized_llama.calibration_1000.pkl processed-calibration-data.pkl
gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl .
mv open_orca_gpt4_tokenized_llama.sampled_24576.pkl processed-data.pkl
cd /inference_mlperf4.1
```
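
The processed files are pandas pickles; `evaluate-accuracy.py` later reads them with `pd.read_pickle` and uses the `output` column as ground truth. An optional quick peek, assuming that DataFrame layout:
```
python3 - <<'EOF'
import pandas as pd

data = pd.read_pickle("/loadgen_run_data/processed-data.pkl")
print(data.shape)          # (num_samples, num_columns)
print(list(data.columns))  # expect an "output" column with ground-truth text
EOF
```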

## Install MaxText
```
cd /
git clone git@github.com:google/maxtext.git
cd maxtext
git checkout offline_inf
cd MaxText
```

## Checkpoint generation

Steps to get a quantized llama2-70B checkpoint for v5e-8

Note: the llama2-70B model takes about 140 GB of memory and will not fit on a v5e-8. It must be downloaded onto a machine with more memory (such as a v5p-8) and quantized into a smaller checkpoint that can then be loaded onto a v5e-8.

* Obtain a llama2-70b checkpoint and convert it to a MaxText inference checkpoint, following the MaxText instructions here: https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md

* Convert the checkpoint into a quantized checkpoint.

To create an int8 DRQ checkpoint, run the following steps:

1. Define the path to load the MaxText checkpoint from and the path to save the quantized checkpoint to.

```
export LOAD_PARAMS_PATH=gs://${USER}-bkt/llama2-70b-chat/param-only-decode-ckpt-maxtext/checkpoints/0/items
export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat
```

2. Run the following MaxText script to generate and save an int8 quantized checkpoint.

```
# Set the appropriate tokenizer path. For example, Llama2 models use tokenizer.llama2.
# You can find other tokenizers under the maxtext/assets/ directory.
export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2
cd maxtext && \
python MaxText/decode.py MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMS_PATH} \
  max_prefill_predict_length=1024 \
  max_target_length=2048 \
  model_name=llama2-70b \
  ici_fsdp_parallelism=1 \
  ici_autoregressive_parallelism=1 \
  ici_tensor_parallelism=-1 \
  scan_layers=false \
  weight_dtype=bfloat16 \
  per_device_batch_size=11 \
  attention=dot_product \
  quantization=int8 \
  save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH}
```

Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. This path is used to set the `load_parameters_path` parameter in the `MAXENGINE_ARGS` environment variable below.
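
Optionally, verify that the quantized checkpoint was written before moving on:
```
gsutil ls -r ${SAVE_QUANT_PARAMS_PATH}
```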

## HF login
```
huggingface-cli login
```
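
The accuracy evaluation loads the `meta-llama/Llama-2-70b-chat-hf` tokenizer, which requires a Hugging Face account with access to the Llama2 weights. In non-interactive environments you can pass a token directly (here `HF_TOKEN` is a placeholder for your own token):
```
huggingface-cli login --token ${HF_TOKEN}
```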

## Loadgen settings
```
cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/
export API_URL=0.0.0.0:9000
export DATA_DISK_DIR=/loadgen_run_data
export DATASET_TYPE=full # for calibration run, DATASET_TYPE=calibration
export MODEL_NAME=llama70b
export TOTAL_SAMPLE_COUNT=24576 # for calibration run, TOTAL_SAMPLE_COUNT=1000
export LOG_INTERVAL=1000
export BATCH_SIZE_EXP=8
export USER_CONFIG=user.conf
```
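
`user.conf` follows the standard MLPerf LoadGen `model.scenario.key = value` format. A minimal illustrative sketch (placeholder values, not tuned settings):
```
llama2-70b.Offline.target_qps = 4.0
llama2-70b.Server.target_qps = 2.0
```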

## Offline Setup
```
cd /
git clone git@github.com:google/maxtext.git
cd maxtext
git checkout offline_inf
cd MaxText
# For v5e use:
export BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20"
# For v6 use:
export BATCH_AND_PREFILL_LEN="256,216|512,108|1024,54"
# Set the appropriate tokenizer path. For example, Llama2 models use tokenizer.llama2.
# You can find other tokenizers under the maxtext/assets/ directory.
export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3"
```
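
Each `|`-separated entry in `BATCH_AND_PREFILL_LEN` pairs a prefill length with a batch size; the run scripts pass it via `--prefill_lengths_and_batch_sizes`. A small sketch of how the pairs read:
```
IFS='|' read -ra PAIRS <<< "${BATCH_AND_PREFILL_LEN}"
for pair in "${PAIRS[@]}"; do
  IFS=',' read -r prefill_len batch_size <<< "${pair}"
  echo "prefill_length=${prefill_len} batch_size=${batch_size}"
done
```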

## Run offline performance

```
bash ./llama_offline_performance_run.sh
```
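
Each run writes its LoadGen output under `${DATA_DISK_DIR}/logs/<model>-<dataset>-<run_type>-<timestamp>` (see `OUTPUT_LOG_DIR` in the run scripts). To locate the most recent run:
```
ls -lt ${DATA_DISK_DIR}/logs | head
```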

## Run offline accuracy
```
bash ./llama_offline_accuracy_run.sh
```

## Run offline audit
```
bash ./llama_offline_audit_run.sh
```

## Run server performance
```
bash ./generate_server_performance_run.sh
```

## Run server accuracy
```
bash ./generate_server_accuracy_run.sh
```

## Run server audit
```
bash ./generate_server_audit_run.sh
```

121 changes: 121 additions & 0 deletions MaxText/inference_mlperf/evaluate-accuracy.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Evaluation script based on MLPerf requirements"""

import argparse
from transformers import AutoTokenizer
import nltk
import evaluate
import numpy as np
import json


def get_args():
  parser = argparse.ArgumentParser()
  parser.add_argument("--checkpoint-path", required=True, help="Path to Llama2-70b-hf-chat checkpoint")
  parser.add_argument("--mlperf-accuracy-file", required=True, help="path to mlperf_log_accuracy.json")
  parser.add_argument("--dataset-file", required=True, help="path to processed openorca validation set")
  parser.add_argument("--verbose", action="store_true", help="verbose messages")
  parser.add_argument("--dtype", default="int64", help="dtype of the accuracy log", choices=["int32", "int64", "float"])
  args = parser.parse_args()
  return args


def get_groundtruth(processed_dataset_file):
  import pandas as pd

  data = pd.read_pickle(processed_dataset_file)
  ground_truths = data["output"]
  return ground_truths


def postprocess_text(preds, targets):
  preds = [pred.strip() for pred in preds]
  targets = [target.strip() for target in targets]

  # rougeLSum expects newline after each sentence
  preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
  targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]

  return preds, targets


def main():

  args = get_args()
  dataset_path = args.dataset_file
  checkpoint_path = args.checkpoint_path
  metric = evaluate.load("rouge")
  nltk.download("punkt")

  tokenizer = AutoTokenizer.from_pretrained(
      checkpoint_path,
      model_max_length=2048,
      padding_side="left",
      use_fast=False,
  )

  targets = get_groundtruth(args.dataset_file)

  target_required = []
  preds_token_ids = []

  eval_dtype = np.int64
  if args.dtype == "int32":
    eval_dtype = np.int32
  elif args.dtype == "float":
    eval_dtype = np.float32

  with open(args.mlperf_accuracy_file, "r") as f:
    results = json.load(f)

  seen = set()
  gen_tok_len = 0
  for pred in results:
    qsl_idx = pred["qsl_idx"]
    if qsl_idx in seen:
      continue

    seen.add(qsl_idx)
    target = targets[qsl_idx]
    target_required.append(target)
    pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype)
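    # Llama2's vocab size is 32000; a first token id outside that range
    # (likely a logging artifact) is replaced with BOS (id 1) below.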
    if pred[0] > 32000 or pred[0] < 0:
      pred = [1, *pred[1:]]
    gen_tok_len += len(pred)
    preds_token_ids.append(pred)

  preds_decoded_text = tokenizer.batch_decode(preds_token_ids, skip_special_tokens=True)

  preds, targets = postprocess_text(preds_decoded_text, target_required)

  result = metric.compute(predictions=preds, references=targets, use_stemmer=True, use_aggregator=False)
  result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
  prediction_lens = [len(pred) for pred in preds]
  gen_num = len(preds)

  result = {
      **result,
      "gen_len": np.sum(prediction_lens),
      "gen_num": gen_num,
      "gen_tok_len": gen_tok_len,
      "tokens_per_sample": round(gen_tok_len / gen_num, 1),
  }

  print("\nResults\n")
  print(result)


if __name__ == "__main__":
  main()
70 changes: 70 additions & 0 deletions MaxText/inference_mlperf/llama_offline_accuracy_run.sh
#!/usr/bin/env bash
me=$(basename "$0")

if [ -z "$BASEDIR"];
then
BASEDIR=/home/$USER/inference_mlperf4.1
fi

USER_CONFIG=$BASEDIR/language/llama2-70b/tpu/user.conf

if [ -z "$DATA_DISK_DIR"];
then
DATA_DISK_DIR=/home/$USER/loadgen_run_data
fi

DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl
TOTAL_SAMPLE_COUNT=24576
LOG_INTERVAL=900

if [-z "$BATCH_AND_PREFILL_LEN"];
then
BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20"
fi

LOADGEN_RUN_TYPE=offline-accuracy
MODEL_NAME=llama70b
DATASET_TYPE=full

LOADGEN_RUN_TIMESTAMP=$(TZ=America/Los_Angeles date +%Y%m%d%H%M%S%Z)
OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP}
OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID}

mkdir -p ${OUTPUT_LOG_DIR} && cp ${USER_CONFIG} ${OUTPUT_LOG_DIR}

OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json

# LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
# makes subsequent runs faster
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2"
export LIBTPU_INIT_ARGS

echo "LOADGEN_RUN_TYPE: ${LOADGEN_RUN_TYPE}"
echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}"
echo "DATASET_PATH: ${DATASET_PATH}"
echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}"
echo "BATCH_SIZE_EXP: ${BATCH_SIZE_EXP}"
echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}"
echo "USER_CONFIG: ${USER_CONFIG}"

python -m offline_mode \
--mlperf_test_mode=accuracy \
--input_mode tokenized \
--output_mode tokenized \
--mlperf_conf $BASEDIR/mlperf.conf \
--user_conf ${USER_CONFIG} \
--audit_conf no_audit \
--total_sample_count ${TOTAL_SAMPLE_COUNT} \
--dataset_path ${DATASET_PATH} \
--prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \
--maxengine_args "${MAXENGINE_ARGS}" \
--output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/offline_accuracy_log.log

# Eval Run
if [ -e ${OUTPUT_ACCURACY_JSON_PATH} ]; then
python3 evaluate-accuracy.py \
--checkpoint-path meta-llama/Llama-2-70b-chat-hf \
--mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \
--dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log
fi
