-
Notifications
You must be signed in to change notification settings - Fork 316
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #897 from AI-Hypercomputer:offline_inf
PiperOrigin-RevId: 678440353
- Loading branch information
Showing
9 changed files
with
1,149 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
|
||
## Create TPU VM. | ||
Follow these [instructions](https://cloud.google.com/tpu/docs/v5e-inference#tpu-vm) to create TPU v5e-8 VM and ssh into the VM | ||
|
||
|
||
## Clone repo | ||
``` | ||
git clone https://github.com/mlcommons/inference.git | ||
``` | ||
|
||
## Install loadgen | ||
``` | ||
apt-get install python3-dev | ||
apt-get install build-essential -y | ||
cd loadgen/ && pip install . | ||
``` | ||
|
||
## Install eval dependencies | ||
``` | ||
pip install \ | ||
transformers==4.31.0 \ | ||
nltk==3.8.1 \ | ||
evaluate==0.4.0 \ | ||
absl-py==1.4.0 \ | ||
rouge-score==0.1.2 \ | ||
sentencepiece==0.1.99 \ | ||
accelerate==0.21.0 | ||
``` | ||
|
||
## Download data file | ||
``` | ||
cd / | ||
export DATA_DISK_DIR=/loadgen_run_data | ||
mkdir -p ${DATA_DISK_DIR} | ||
cd ${DATA_DISK_DIR} | ||
gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.calibration_1000.pkl . | ||
mv open_orca_gpt4_tokenized_llama.calibration_1000.pkl processed-calibration-data.pkl | ||
gsutil cp gs://cloud-tpu-inference-public/mlcommons/inference/language/llama2-70b/data/processed-openorca/open_orca_gpt4_tokenized_llama.sampled_24576.pkl . | ||
mv open_orca_gpt4_tokenized_llama.sampled_24576.pkl processed-data.pkl | ||
cd /inference_mlperf4.1 | ||
``` | ||
|
||
## Install Maxtext | ||
``` | ||
cd / | ||
git clone [email protected]:google/maxtext.git | ||
cd maxtext | ||
git checkout offline_inf | ||
cd maxtext/MaxText | ||
``` | ||
|
||
## Checkpoint generation | ||
|
||
Steps to get a quantized llama2-70B checkpoint for v5e-8 | ||
|
||
Note llama2-70B model takes about 140G of memory and will not fit into a v5e-8. It must be downloaded onto a large machine (such as v5p-8) and quantized to a smaller quantized checkpoint to be loaded onto a v5e-8 machine. | ||
|
||
* Obtain a llama2-70b checkpoint and convert it to a maxtext inference checkpoint. Please follow maxtext instructions specified here: https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md | ||
|
||
* Convert the checkpoint into a quantized checkpoint | ||
|
||
To create an int8 DRQ checkpoint run the following step: | ||
|
||
1. Define paths to load maxtext checkpoint from and save quantized checkpoint to. | ||
|
||
``` | ||
export LOAD_PARAMS_PATH=gs://${USER}-bkt/llama2-70b-chat/param-only-decode-ckpt-maxtext/checkpoints/0/items | ||
export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat | ||
``` | ||
|
||
2. Run the following maxtext script to generate and save an int8 quantized checkpoint | ||
|
||
``` | ||
# Set appropriate tokenizer path. For example, LLama2 models tokenizer.llama2. You can find | ||
# other tokenizers under maxtext/assets/ directory. | ||
export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2 | ||
cd maxtext && \ | ||
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} | ||
``` | ||
|
||
Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. This is used to set `load_parameters_path` param below in `MAXENGINE_ARGS` env variable. | ||
|
||
## HF login | ||
``` | ||
huggingface-cli login | ||
``` | ||
|
||
## Loadgen settings | ||
``` | ||
cd Google/code/llama2-70b/tpu_v5e_8_jetstream_maxtext/scripts/ | ||
export API_URL=0.0.0.0:9000 | ||
export DATA_DISK_DIR=/loadgen_run_data | ||
export DATASET_TYPE=full # for calibration run, DATASET_TYPE=calibration | ||
export MODEL_NAME=llama70b | ||
export TOTAL_SAMPLE_COUNT=24576 # for calibration run, TOTAL_SAMPLE_COUNT=1000 | ||
export LOG_INTERVAL=1000 | ||
export BATCH_SIZE_EXP=8 | ||
export USER_CONFIG=user.conf | ||
``` | ||
|
||
## Offline Setup | ||
``` | ||
cd / | ||
git clone [email protected]:google/maxtext.git | ||
cd maxtext | ||
git checkout offline_inf | ||
cd maxtext/MaxText | ||
# For v5e use | ||
export BATCH_AND_PREFILL_LEN=“256,80|512,40|1024,20” | ||
# For v6 use | ||
export BATCH_AND_PREFILL_LEN=“256,216|512,108|1024,54” | ||
# Set appropriate tokenizer path. For example, LLama2 models tokenizer.llama2. You can find | ||
# other tokenizers under maxtext/assets/ directory. | ||
export TOKENIZER_PATH=maxtext/assets/tokenizer.llama2 | ||
export MAXENGINE_ARGS="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} quantization=int8 quantize_kvcache=True load_parameters_path=${SAVE_QUANT_PARAMS_PATH} checkpoint_is_quantized=True compute_axis_order=0,1,2,3 ar_cache_axis_order=0,1,2,3" | ||
``` | ||
|
||
## Run offline performance | ||
|
||
``` | ||
bash ./llama_offline_performance_run.sh | ||
``` | ||
|
||
## Run offline accuracy | ||
``` | ||
bash ./llama_offline_accuracy_run.sh | ||
``` | ||
|
||
## Run offline audit | ||
``` | ||
bash ./llama_offline_audit_run.sh | ||
``` | ||
|
||
## Run server performance | ||
``` | ||
bash ./generate_server_performance_run.sh | ||
``` | ||
|
||
## Run server accuracy | ||
``` | ||
bash ./generate_server_accuracy_run.sh | ||
``` | ||
|
||
## Run server audit | ||
``` | ||
bash ./generate_server_audit_run.sh | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" Evaluation script based on MLPerf requirements""" | ||
|
||
import argparse | ||
from transformers import AutoTokenizer | ||
import nltk | ||
import evaluate | ||
import numpy as np | ||
import json | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--checkpoint-path", required=True, help="Path to Llama2-70b-hf-chat checkpoint") | ||
parser.add_argument("--mlperf-accuracy-file", required=True, help="path to mlperf_log_accuracy.json") | ||
parser.add_argument("--dataset-file", required=True, help="path to processed openorca validation set") | ||
parser.add_argument("--verbose", action="store_true", help="verbose messages") | ||
parser.add_argument("--dtype", default="int64", help="dtype of the accuracy log", choices=["int32", "int64", "float"]) | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def get_groundtruth(processed_dataset_file): | ||
import pandas as pd | ||
|
||
data = pd.read_pickle(processed_dataset_file) | ||
ground_truths = data["output"] | ||
return ground_truths | ||
|
||
|
||
def postprocess_text(preds, targets): | ||
preds = [pred.strip() for pred in preds] | ||
targets = [target.strip() for target in targets] | ||
|
||
# rougeLSum expects newline after each sentence | ||
preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] | ||
targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] | ||
|
||
return preds, targets | ||
|
||
|
||
def main(): | ||
|
||
args = get_args() | ||
dataset_path = args.dataset_file | ||
checkpoint_path = args.checkpoint_path | ||
metric = evaluate.load("rouge") | ||
nltk.download("punkt") | ||
|
||
tokenizer = AutoTokenizer.from_pretrained( | ||
checkpoint_path, | ||
model_max_length=2048, | ||
padding_side="left", | ||
use_fast=False, | ||
) | ||
|
||
targets = get_groundtruth(args.dataset_file) | ||
|
||
target_required = [] | ||
preds_token_ids = [] | ||
|
||
eval_dtype = np.int64 | ||
if args.dtype == "int32": | ||
eval_dtype = np.int32 | ||
elif args.dtype == "float": | ||
eval_dtype = np.float32 | ||
|
||
with open(args.mlperf_accuracy_file, "r") as f: | ||
results = json.load(f) | ||
|
||
seen = set() | ||
gen_tok_len = 0 | ||
for pred in results: | ||
qsl_idx = pred["qsl_idx"] | ||
if qsl_idx in seen: | ||
continue | ||
|
||
seen.add(qsl_idx) | ||
target = targets[qsl_idx] | ||
target_required.append(target) | ||
pred = np.frombuffer(bytes.fromhex(pred["data"]), eval_dtype) | ||
if pred[0] > 32000 or pred[0] < 0: | ||
pred = [1, *pred[1:]] | ||
gen_tok_len += len(pred) | ||
preds_token_ids.append(pred) | ||
|
||
preds_decoded_text = tokenizer.batch_decode(preds_token_ids, skip_special_tokens=True) | ||
|
||
preds, targets = postprocess_text(preds_decoded_text, target_required) | ||
|
||
result = metric.compute(predictions=preds, references=targets, use_stemmer=True, use_aggregator=False) | ||
result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} | ||
prediction_lens = [len(pred) for pred in preds] | ||
gen_num = len(preds) | ||
|
||
result = { | ||
**result, | ||
"gen_len": np.sum(prediction_lens), | ||
"gen_num": gen_num, | ||
"gen_tok_len": gen_tok_len, | ||
"tokens_per_sample": round(gen_tok_len / gen_num, 1), | ||
} | ||
|
||
print("\nResults\n") | ||
print(result) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#!/usr/bin/env bash | ||
me=$(basename "$0") | ||
|
||
if [ -z "$BASEDIR"]; | ||
then | ||
BASEDIR=/home/$USER/inference_mlperf4.1 | ||
fi | ||
|
||
USER_CONFIG=$BASEDIR/language/llama2-70b/tpu/user.conf | ||
|
||
if [ -z "$DATA_DISK_DIR"]; | ||
then | ||
DATA_DISK_DIR=/home/$USER/loadgen_run_data | ||
fi | ||
|
||
DATASET_PATH=${DATA_DISK_DIR}/processed-data.pkl | ||
TOTAL_SAMPLE_COUNT=24576 | ||
LOG_INTERVAL=900 | ||
|
||
if [-z "$BATCH_AND_PREFILL_LEN"]; | ||
then | ||
BATCH_AND_PREFILL_LEN="256,80|512,40|1024,20" | ||
fi | ||
|
||
LOADGEN_RUN_TYPE=offline-accuracy | ||
MODEL_NAME=llama70b | ||
DATASET_TYPE=full | ||
|
||
LOADGEN_RUN_TIMESTAMP=$(TZ=America/Los_Angeles date +%Y%m%d%H%M%S%Z) | ||
OUTPUT_LOG_ID=${MODEL_NAME}-${DATASET_TYPE}-${LOADGEN_RUN_TYPE}-${LOADGEN_RUN_TIMESTAMP} | ||
OUTPUT_LOG_DIR=${DATA_DISK_DIR}/logs/${OUTPUT_LOG_ID} | ||
|
||
mkdir -p ${OUTPUT_LOG_DIR} && cp ${USER_CONFIG} ${OUTPUT_LOG_DIR} | ||
|
||
OUTPUT_ACCURACY_JSON_PATH=${OUTPUT_LOG_DIR}/mlperf_log_accuracy.json | ||
|
||
# LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" | ||
# makes subsequent runs faster | ||
export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache2" | ||
export LIBTPU_INIT_ARGS | ||
|
||
echo "LOADGEN_RUN_TYPE: ${LOADGEN_RUN_TYPE}" | ||
echo "LOADGEN_RUN_TIMESTAMP: ${LOADGEN_RUN_TIMESTAMP}" | ||
echo "DATASET_PATH: ${DATASET_PATH}" | ||
echo "TOTAL_SAMPLE_COUNT: ${TOTAL_SAMPLE_COUNT}" | ||
echo "BATCH_SIZE_EXP: ${BATCH_SIZE_EXP}" | ||
echo "OUTPUT_LOG_DIR: ${OUTPUT_LOG_DIR}" | ||
echo "USER_CONFIG: ${USER_CONFIG}" | ||
|
||
python -m offline_mode \ | ||
--mlperf_test_mode=accuracy \ | ||
--input_mode tokenized \ | ||
--output_mode tokenized \ | ||
--mlperf_conf $BASEDIR/mlperf.conf \ | ||
--user_conf ${USER_CONFIG} \ | ||
--audit_conf no_audit \ | ||
--total_sample_count ${TOTAL_SAMPLE_COUNT} \ | ||
--dataset_path ${DATASET_PATH} \ | ||
--prefill_lengths_and_batch_sizes ${BATCH_AND_PREFILL_LEN} \ | ||
--maxengine_args "${MAXENGINE_ARGS}" \ | ||
--output_log_dir ${OUTPUT_LOG_DIR} 2>&1 | tee ${OUTPUT_LOG_DIR}/offline_accuracy_log.log | ||
|
||
# Eval Run | ||
if [ -e ${OUTPUT_ACCURACY_JSON_PATH} ]; then | ||
python3 evaluate-accuracy.py \ | ||
--checkpoint-path meta-llama/Llama-2-70b-chat-hf \ | ||
--mlperf-accuracy-file ${OUTPUT_ACCURACY_JSON_PATH} \ | ||
--dataset-file ${DATASET_PATH} 2>&1 | tee ${OUTPUT_LOG_DIR}/evaluate_offline_accuracy_log.log | ||
fi | ||
|
Oops, something went wrong.