diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 6076beba..a738b8b1 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -79,6 +79,7 @@
 import pandas
 
 from eval_accuracy import eval_accuracy
+from transformers import AutoTokenizer
 
 
 def str2bool(v: str) -> bool:
@@ -156,16 +157,29 @@ def to_dict(self):
     }
 
 
-def get_tokenizer(model_id: str, tokenizer_name: str) -> Any:
+def get_tokenizer(
+    model_id: str,
+    tokenizer_name: str,
+    use_hf_tokenizer: bool,
+) -> Any:
   """Return a tokenizer or a tokenizer placholder."""
   if tokenizer_name == "test":
+    print("Using test tokenizer")
     return "test"
+  elif use_hf_tokenizer:
+    # Please accept agreement to access private/gated models in HF, and
+    # follow up instructions below to set up access token
+    # https://huggingface.co/docs/transformers.js/en/guides/private
+    print(f"Using HuggingFace tokenizer: {tokenizer_name}")
+    return AutoTokenizer.from_pretrained(tokenizer_name)
   elif model_id == "llama-3":
     # Llama 3 uses a tiktoken tokenizer.
+    print(f"Using llama-3 tokenizer: {tokenizer_name}")
     return llama3_tokenizer.Tokenizer(tokenizer_name)
   else:
     # Use JetStream tokenizer util. It's using the sentencepiece wrapper in
     # seqio library.
+    print(f"Using tokenizer: {tokenizer_name}")
     vocab = load_vocab(tokenizer_name)
     return vocab.tokenizer
 
@@ -563,10 +577,11 @@ def main(args: argparse.Namespace):
 
   model_id = args.model
   tokenizer_id = args.tokenizer
+  use_hf_tokenizer = args.use_hf_tokenizer
 
   api_url = f"{args.server}:{args.port}"
 
-  tokenizer = get_tokenizer(model_id, tokenizer_id)
+  tokenizer = get_tokenizer(model_id, tokenizer_id, use_hf_tokenizer)
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(
         args.total_mock_requests
@@ -716,6 +731,15 @@ def main(args: argparse.Namespace):
           " default value)"
       ),
   )
+  parser.add_argument(
+      "--use-hf-tokenizer",
+      type=str2bool,
+      default=False,
+      help=(
+          "Whether to use tokenizer from HuggingFace. If so, set this flag"
+          " to True, and provide name of the tokenizer in the tokenizer flag."
+      ),
+  )
   parser.add_argument(
       "--num-prompts",
       type=int,