diff --git a/examples/llm_serving/textgen.py b/examples/llm_serving/textgen.py index 2ef2fd2e7..3e55a9982 100644 --- a/examples/llm_serving/textgen.py +++ b/examples/llm_serving/textgen.py @@ -26,7 +26,7 @@ def main(args): # Load the model model = get_model(model_name=args.model, - path="~/opt_weights", + path=args.path, batch_size=args.n_prompts, **generate_params) @@ -53,7 +53,8 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--model", type=str, default="alpa/opt-1.3b") + parser.add_argument('--model', type=str, default='alpa/opt-1.3b') + parser.add_argument('--path', type=str, default='~/opt_weights') parser.add_argument('--do-sample', action='store_true') parser.add_argument('--num-beams', type=int, default=1) parser.add_argument('--num-return-sequences', type=int, default=1)