24_game_train1.py
"""
Trains a GPT (generative pre-trained transformer language model) to play 24 game.
Some background info about 24 game:
Given 4 numbers ranging from [0, 9], use "+ - * /" 4 kinds of operators to try to get 24.
For example:
Given 4, 2, 5, 3
Some valid reasonings are (there are 34 possible reasonings):
[4, 2, 5, 3]: 5 - 2 = 3, 3 + 3 = 6, 6 * 4 = 24
[4, 2, 5, 3]: 4 + 3 = 7, 5 + 7 = 12, 12 * 2 = 24
......
[4, 2, 5, 3]: 5 + 3 = 8, 8 + 4 = 12, 12 * 2 = 24
I trained a GPT model to predicate reasoning according to a
prompt (in source code, I use 'problem' to represent a prompt).
The trained model takes a prompt "[3, 7, 5, 5]: " as input, we expect
the mode to output something like "7 + 5 = 12, 5 - 3 = 2, 2 * 12 = 24".
As we use a language model to solve game 24, we use
"[3, 7, 5, 5]: 7 + 5 = 12, 5 - 3 = 2, 2 * 12 = 24" and
"[3, 5, 5, 7]: 7 + 5 = 12, 5 - 3 = 2, 2 * 12 = 24" as different data samples.
We use 40% of all data samples (see 24_game_data_generator.py about how to generate all possible data samples) to train the model.
The model runs on all data samples and achieves 99.1% accuracy.
"""
import os
import sys
import json
import torch
from torch.utils.data.dataloader import DataLoader
from model import GPT
from trainer import Trainer
from utils import set_seed, ConfigNode
from dataset import DatasetOf24Game
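
# The helper below is an illustrative sketch only: it is not part of the training
# pipeline and is not taken from 24_game_data_generator.py. Assuming a sample line
# follows the exact format shown in the docstring above (integer intermediate
# results), it checks that every step is arithmetically correct, that each step
# only consumes numbers that are still available, and that the final value is 24.
def is_valid_reasoning(line):
    numbers_part, steps_part = line.split(':', 1)
    pool = [int(t) for t in numbers_part.strip(' []').split(',')]  # the four given numbers
    value = None
    for step in steps_part.split(','):
        lhs, rhs = step.split('=')
        a_str, op, b_str = lhs.split()
        a, b, expected = int(a_str), int(b_str), int(rhs)
        for operand in (a, b):  # each operand must still be available
            if operand not in pool:
                return False
            pool.remove(operand)
        value = {'+': a + b, '-': a - b, '*': a * b, '/': a / b}[op]
        if value != expected:
            return False
        pool.append(expected)  # the intermediate result can be reused in a later step
    return value == 24 and len(pool) == 1

assert is_valid_reasoning("[3, 7, 5, 5]: 7 + 5 = 12, 5 - 3 = 2, 2 * 12 = 24")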
# -----------------------------------------------------------------------------
def get_config():

    C = ConfigNode()

    # system
    C.system = ConfigNode()
    C.system.seed = 515
    C.system.work_dir = './out/data1_9_v3_vf'

    # model
    C.model = GPT.get_default_config()
    C.model.model_type = 'gpt-mini'

    # trainer
    C.trainer = Trainer.get_default_config()
    C.trainer.learning_rate = 5e-4  # the model we're using is so small that we can go a bit faster
    C.trainer.batch_size = 512
    C.trainer.max_iters = 5000

    return C
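
# Every value set in get_config() can be overridden from the command line via
# config.merge_from_args below. Assuming ConfigNode follows the usual minGPT-style
# CfgNode convention of "--path.to.key=value" arguments, a run might look like:
#   python 24_game_train1.py --trainer.max_iters=10000 --trainer.learning_rate=1e-4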
if __name__ == '__main__':

    # get default config and overrides from the command line, if any
    config = get_config()
    config.merge_from_args(sys.argv[1:])

    # create the work directory if it doesn't already exist
    os.makedirs(config.system.work_dir, exist_ok=True)
    # log the args (if any)
    with open(os.path.join(config.system.work_dir, 'args.txt'), 'w') as f:
        f.write(' '.join(sys.argv))
    # log the config itself
    with open(os.path.join(config.system.work_dir, 'config.json'), 'w') as f:
        f.write(json.dumps(config.to_dict(), indent=4))

    # set random seeds everywhere
    set_seed(config.system.seed)

    # construct train and test datasets
    train_dataset = DatasetOf24Game(split='train')
    test_dataset = DatasetOf24Game(split='test')

    # construct the model
    config.model.vocab_size = DatasetOf24Game.get_vocab_size()
    config.model.block_size = DatasetOf24Game.get_block_size()
    print(config)
    model = GPT(config.model)
    # print(model)

    # construct the trainer object
    trainer = Trainer(config.trainer, model, train_dataset)

    # helper function for the evaluation of a model
    def eval_split(trainer, split, max_batches=None):
        dataset = {'train': train_dataset, 'test': test_dataset}[split]
        num_right = 0
        mistakes_printed_already = 0
        loader = DataLoader(dataset, batch_size=500, num_workers=0, drop_last=False)
        for b, (x, y) in enumerate(loader):
            x = x.to(trainer.device)
            y = y.to(trainer.device)
            # isolate the problem part (the prompt) of the input sequence
            problem = x[:, :len(DatasetOf24Game.one_problem_sample)]
            # let the model sample the rest of the sequence
            problem_result_pred = model.generate(problem, len(DatasetOf24Game.one_result_sample), do_sample=False)  # using greedy argmax, not sampling
            # isolate the generated result part of the sampled sequence
            result_pred = problem_result_pred[:, -len(DatasetOf24Game.one_result_sample):]
            # evaluate the correctness of the results in this batch
            for i in range(x.size(0)):
                r = torch.cat((x[i, :len(DatasetOf24Game.one_problem_sample)], result_pred[i]), 0)
                r = r.tolist()
                r = "".join([DatasetOf24Game.itoc[t] for t in r])
                # r = r.rstrip('\n').replace(" ", "")
                r = r.rstrip('\n')
                num_right += int(r in DatasetOf24Game.all_data_set)  # any valid reasoning counts as correct
                if r not in DatasetOf24Game.all_data_set and mistakes_printed_already < 3:  # only print up to 3 mistakes to get a sense
                    mistakes_printed_already += 1
                    print("GPT claims that " + r)
            if max_batches is not None and b + 1 >= max_batches:
                break
        accuracy = num_right / len(dataset)
        print(f"{split} accuracy: {num_right} / {len(dataset)} = {accuracy:.3f}")
        return accuracy
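
    # Hypothetical usage (not part of the original script): to re-evaluate a saved
    # checkpoint without retraining, one could load the weights and call eval_split
    # directly, e.g.:
    #   model.load_state_dict(torch.load(os.path.join(config.system.work_dir, "model.pt")))
    #   model.eval()
    #   with torch.no_grad():
    #       eval_split(trainer, 'test', max_batches=None)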
    # iteration callback
    top_score = 0
    last_save_iter = 0
    def batch_end_callback(trainer):
        global top_score, last_save_iter

        if trainer.iter_num > 0 and trainer.iter_num % 200 == 0:
            # evaluate both the train and test score
            model.eval()
            with torch.no_grad():
                # train_score = eval_split(trainer, 'train', max_batches=5)
                train_score = 0
                test_score = eval_split(trainer, 'test', max_batches=None)  # on 10 x 100 test data
            score = train_score + test_score
            # save the model if this is the best score we've seen so far
            if score > top_score:
                top_score = score
                last_save_iter = trainer.iter_num
                print(f"saving model with new top score of {score}")
                ckpt_path = os.path.join(config.system.work_dir, "model.pt")
                torch.save(model.state_dict(), ckpt_path)
            if trainer.iter_num == 2000:
                print("\nsaving model at 2000 iter")
                ckpt_path = os.path.join(config.system.work_dir, "model_2k.pt")
                torch.save(model.state_dict(), ckpt_path)
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}, last save iter {last_save_iter}")
            # revert model to training mode
            model.train()

    trainer.set_callback('on_batch_end', batch_end_callback)

    # run the optimization
    trainer.run()