# 0-eval_pretrain.py (forked from jingyaogong/minimind)
# Original listing: 179 lines (147 loc), 5.72 KB.
import random
import time
import numpy as np
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import Transformer
from model.LMConfig import LMConfig
warnings.filterwarnings('ignore')
def count_parameters(model):
    """Count the elements of every trainable parameter of ``model``.

    Only parameters whose ``requires_grad`` flag is truthy contribute.

    Args:
        model: Object exposing ``parameters()``, which yields parameter
            objects that provide ``requires_grad`` and ``numel()``.

    Returns:
        The summed ``numel()`` of the trainable parameters (0 when there
        are none).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def init_model(lm_config):
    """Create the tokenizer and the model and move the model to ``device``.

    The tokenizer is always loaded from ``./model/minimind_tokenizer``.
    The model source is picked by the local ``model_from`` switch:
    ``1`` builds a ``Transformer`` from ``lm_config`` and loads weights from
    ``./out/pretrain_{dim}[_moe].pth``; any other value loads ``'minimind'``
    through ``AutoModelForCausalLM``.

    NOTE: reads the module-level ``device`` variable, which this file
    assigns in its ``__main__`` section before calling this function.

    Args:
        lm_config: Configuration object; ``use_moe`` and ``dim`` are read to
            build the checkpoint path, and the object is handed to
            ``Transformer``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
    model_from = 1  # 1: load local checkpoint weights, 2: load via transformers

    if model_from != 1:
        model = AutoModelForCausalLM.from_pretrained('minimind', trust_remote_code=True)
    else:
        suffix = '_moe' if lm_config.use_moe else ''
        checkpoint_path = f'./out/pretrain_{lm_config.dim}{suffix}.pth'

        model = Transformer(lm_config)
        weights = torch.load(checkpoint_path, map_location=device)

        # Strip the unwanted '_orig_mod.' prefix from checkpoint keys
        # (presumably left behind by a compiled-model wrapper -- TODO confirm).
        prefix = '_orig_mod.'
        for name in list(weights):
            if name.startswith(prefix):
                weights[name[len(prefix):]] = weights.pop(name)

        # Drop every entry whose key contains 'mask'.
        for name in [key for key in weights if 'mask' in key]:
            del weights[name]

        # Load into the model; strict=False is passed, so the key sets are
        # not required to match exactly.
        model.load_state_dict(weights, strict=False)

    model = model.to(device)

    # Report the trainable-parameter count in millions and in billions.
    n_params = count_parameters(model)
    print(f'模型参数: {n_params / 1e6} 百万 = {n_params / 1e9} B (Billion)')
    return model, tokenizer
def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random``, NumPy and PyTorch (CPU, the
    current GPU and all GPUs), then switches cuDNN to deterministic mode
    with auto-tuning disabled.

    Args:
        seed: Seed value passed unchanged to each seeding function.
    """
    seeders = (
        random.seed,                 # Python's built-in RNG
        np.random.seed,              # NumPy
        torch.manual_seed,           # PyTorch
        torch.cuda.manual_seed,      # current GPU (if any)
        torch.cuda.manual_seed_all,  # all GPUs (if any)
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # always return the same convolution algorithm
    cudnn.benchmark = False     # disable cuDNN auto-tuning to avoid nondeterminism
if __name__ == "__main__":
    # Script entry point: load the pretrained model, then answer either the
    # built-in prompt list or interactively typed questions, streaming the
    # generated text to stdout and printing the elapsed time per question.
    #
    # NOTE: this section intentionally stays at module level instead of being
    # wrapped in a main() function: init_model() reads `device` as a
    # module-level global.
    # -----------------------------------------------------------------------------
    out_dir = 'out'  # not referenced anywhere else in this script
    temperature = 0.7
    top_k = 8
    setup_seed(1337)
    # Use device = 'cpu' here to force CPU inference.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dtype = 'bfloat16'  # not referenced anywhere else in this script
    max_seq_len = 512
    lm_config = LMConfig()
    lm_config.max_seq_len = max_seq_len
    # Whether each turn carries the previous turns along (the current model is
    # too weak: a longer history mostly makes it produce nonsense).
    contain_history_chat = False
    # -----------------------------------------------------------------------------

    model, tokenizer = init_model(lm_config)
    model = model.eval()

    # To publish to the Hugging Face hub:
    #   model.push_to_hub("minimind")
    #   tokenizer.push_to_hub("minimind")

    # 0: run through the built-in prompts below; 1: read questions from stdin.
    answer_way = 0
    stream = True

    prompt_datas = [
        '椭圆和圆的区别',
        '中国关于马克思主义基本原理',
        '人类大脑的主要功能是',
        '万有引力是',
        '世界上人口最多的国家是',
        'DNA的全称是',
        '数学中π的值大约是',
        '世界上最高的山峰是',
        '太阳系中最大的行星是',
        '二氧化碳的化学分子式是',
        '地球上最大的动物是',
        '地球自转一圈大约需要',
        '杭州市的美食有',
        '江苏省的最好的大学',
    ]

    messages_origin = []
    messages = messages_origin

    qa_index = 0
    while True:
        started = time.time()

        # Without history, every question starts from a fresh message list.
        if not contain_history_chat:
            messages = messages_origin.copy()

        if answer_way == 1:
            prompt = input('用户:')
        else:
            if qa_index >= len(prompt_datas):
                break
            prompt = prompt_datas[qa_index]
            print('问题:', prompt)
            qa_index += 1

        messages.append({"role": "user", "content": prompt})

        # Chat-templated text, truncated to its last (max_seq_len - 1)
        # characters. It is only used as the initial value of `answer` and is
        # stripped from the answer when history is kept; the model itself is
        # fed the raw prompt below.
        new_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )[-(max_seq_len - 1):]

        x = tokenizer(prompt).data['input_ids']
        x = torch.tensor(x, dtype=torch.long, device=device)[None, ...]  # add a leading axis

        answer = new_prompt

        with torch.no_grad():
            res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len,
                                   temperature=temperature, top_k=top_k, stream=stream)
            print('回答:', end='')
            try:
                y = next(res_y)
            except StopIteration:
                print("No answer")
                continue

            # Number of characters of the decoded answer already printed.
            history_idx = 0
            while y is not None:
                answer = tokenizer.decode(y[0].tolist())

                # Hold back output while the decoded text is empty or ends in
                # U+FFFD (presumably a not-yet-complete multi-byte character
                # -- TODO confirm); just fetch the next chunk in that case.
                printable = bool(answer) and answer[-1] != '�'
                if printable:
                    print(answer[history_idx:], end='', flush=True)

                # Only generator exhaustion ends the stream. Ctrl-C and real
                # generation errors propagate instead of being swallowed.
                try:
                    y = next(res_y)
                except StopIteration:
                    break

                if printable:
                    history_idx = len(answer)
                    if not stream:
                        break

            print('\n')

        if contain_history_chat:
            assistant_answer = answer.replace(new_prompt, "")
            messages.append({"role": "assistant", "content": assistant_answer})

        print(time.time() - started, 's')