cli_demo.py
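"""Interactive command-line chat demo for a (optionally fine-tuned) Baichuan
chat model. The model and tokenizer are loaded through the repo's
`load_pretrained` and `prepare_infer_args` helpers, and the REPL supports
streaming output, history clearing, and interrupting generation with CTRL+C."""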
import os
import platform

from colorama import Fore, Style
from transformers.generation.utils import GenerationConfig

from utils import load_pretrained, prepare_infer_args


def init_model():
    print("init model ...")
    model_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)
    # Start from the model's own generation config, then override it with the
    # sampling parameters parsed from the command line.
    model.generation_config = GenerationConfig.from_pretrained(
        model_args.model_name_or_path
    )
    model.generation_config.max_new_tokens = generating_args.max_new_tokens
    model.generation_config.temperature = generating_args.temperature
    model.generation_config.top_k = generating_args.top_k
    model.generation_config.top_p = generating_args.top_p
    model.generation_config.repetition_penalty = generating_args.repetition_penalty
    model.generation_config.do_sample = generating_args.do_sample
    model.generation_config.num_beams = generating_args.num_beams
    model.generation_config.length_penalty = generating_args.length_penalty
    return model, tokenizer


def clear_screen():
    """Clear the terminal, print the welcome banner, and reset the chat history."""
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")
    print(Fore.YELLOW + Style.BRIGHT
          + "Welcome to the Baichuan large language model. Type a message to chat, "
            "'clear' to reset the history, CTRL+C to interrupt generation, "
            "'stream' to toggle streaming generation, and 'exit' to quit.")
    return []


def main(stream=True):
    model, tokenizer = init_model()
    messages = clear_screen()
    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\nUser: " + Style.NORMAL)
        if prompt.strip() == "exit":
            break
        if prompt.strip() == "clear":
            messages = clear_screen()
            continue
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan: " + Style.NORMAL, end='')
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "(streaming generation {})\n".format(
                "on" if stream else "off"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            # model.chat yields the full response generated so far, so print
            # only the newly generated suffix on each iteration.
            response = ""  # ensure it exists if generation is interrupted immediately
            position = 0
            try:
                for response in model.chat(tokenizer, messages, stream=True):
                    print(response[position:], end='', flush=True)
                    position = len(response)
            except KeyboardInterrupt:
                # CTRL+C interrupts the current generation, not the demo itself.
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
        messages.append({"role": "assistant", "content": response})
        print(Style.RESET_ALL)


if __name__ == "__main__":
    main()
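
# Example invocation (a sketch, not the repo's documented command line: the
# accepted flags are whatever `prepare_infer_args` defines in utils.py, and
# the path below is a placeholder):
#
#   python cli_demo.py --model_name_or_path path/to/baichuan-model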