trainer.py
import time
from collections import defaultdict
from multiprocessing import Pool
import random
import numpy as np
import torch
import copy
from tqdm import tqdm
from sys import stdout
# local import
from src import Evaluator, DotsAndBoxesGame, MCTS, AZNeuralNetwork, AZDualRes, AZFeedForward, \
AlphaBetaPlayer, NeuralNetworkPlayer, RandomPlayer, Checkpoint, functions
class Trainer:
"""
Executes the training loop, where each iteration consists of
1) self-play (using MCTS)
2) model learning (using generated data)
3) model comparison (using evaluator)
Attributes
----------
game_size : int
board size (width & height) of a Dots-and-Boxes game
mcts_parameters : dict
hyperparameters concerning the MCTS
model_parameters, optimizer_parameters, data_parameters : dict, dict, dict
hyperparameters concerning the neural network (architecture, optimizer, data)
evaluator_parameters : dict
hyperparameters concerning the evaluator
n_workers : int
number of worker processes utilized during self-play (each worker performs games of self-play) and during evaluation
inference_device : torch.device
device on which model inference is performed during self-play and evaluation
training_device : torch.device
device on which model training is performed
model_name : str
name of the neural network that is to be trained (if existing, the checkpoint may be loaded from local files)
model : AZNeuralNetwork
the neural network that is updated during model training in each iteration.
We follow the AlphaZero approach, i.e., a single neural network is maintained and updated continually,
rather than evaluating a candidate network against the current best after each iteration (as in AlphaGo Zero)
"""
def __init__(self, config: dict, n_workers: int, inference_device: str, training_device: str, checkpoint: Checkpoint):
self.checkpoint = checkpoint
self.n_workers = n_workers
functions.print_parameters(config)
self.config = config
self.game_size = config["game_size"]
self.n_iterations = config["n_iterations"]
self.mcts_parameters = config["mcts_parameters"]
self.model_parameters = config["model_parameters"]
self.optimizer_parameters = config["optimizer_parameters"]
self.data_parameters = config["data_parameters"]
self.evaluator_parameters = config["evaluator_parameters"]
# utilize gpu if possible
if "cuda" in [inference_device, training_device]:
assert torch.cuda.is_available()
self.inference_device = torch.device(inference_device)
self.training_device = torch.device(training_device)
print(f"\nModel inference device: {self.inference_device}")
print(f"Model training device: {self.training_device}")
# initialize models
AZModel = None
if self.model_parameters["name"] == "FeedForward":
AZModel = AZFeedForward
pass
elif self.model_parameters["name"] == "DualRes":
AZModel = AZDualRes
self.model = AZModel(
game_size=self.game_size,
inference_device=self.inference_device,
model_parameters=self.model_parameters,
).float()
print("Model has {0:,} trainable parameters".format(
sum(p.numel() for p in self.model.parameters() if p.requires_grad)
))
def loop(self):
"""
Perform iterations of self-play + model training + model evaluation.
The training data which is generated by self-play is saved and loaded from local files between iterations in
order to limit RAM allocation.
"""
# checkpoint: continue or start new training?
if self.checkpoint.is_new_training():
# new training
evaluation_results = defaultdict(list)
train_losses = []
iteration_losses = defaultdict(list)
starting_iteration = 1
self.checkpoint.save_config(self.config)
else:
# continue training
self.model.load_checkpoint(self.checkpoint.model)
evaluation_results = self.checkpoint.load_evaluation_results()
train_losses = self.checkpoint.load_train_losses()
iteration_losses = self.checkpoint.load_iteration_losses()
starting_iteration = len(train_losses) + 1
total_time = time.time()
for iteration in range(starting_iteration, self.n_iterations + 1):
print(f"\n#################### Iteration {iteration}/{self.n_iterations} #################### ")
# 1) perform games of self-play to obtain training data
print("------------ Self-Play using MCTS ------------")
train_examples_per_game = self.perform_self_plays(
game_size=self.game_size,
model=self.model,
mcts_parameters=self.mcts_parameters,
n_workers=self.n_workers,
device=self.inference_device
)
# training data = (augmented) examples carried over from previous iterations + newly generated (augmented) examples
train_examples_per_game_augmented = self.checkpoint.load_train_examples() if iteration > 1 else []
train_examples_per_game_augmented.extend(
self.augment_data(train_examples_per_game)
)
train_examples_per_game = [] # free
# cut dataset to desired size and save data for next iteration
while len(train_examples_per_game_augmented) > self.data_parameters["game_buffer"]:
train_examples_per_game_augmented.pop(0)
self.checkpoint.save_train_examples(train_examples_per_game_augmented)
# 2) model learning
print("\n---------- Neural Network Training -----------")
train_loss, iteration_loss = self.perform_model_training(
model=self.model,
train_examples_per_game_augmented=train_examples_per_game_augmented,
data_parameters=self.data_parameters,
optimizer_parameters=self.optimizer_parameters,
device=self.training_device
)
train_examples_per_game_augmented = [] # free
self.model.to(self.inference_device)
torch.cuda.empty_cache() # free
train_losses.append(train_loss)
iteration_losses["p_loss"].append(iteration_loss["p_loss"])
iteration_losses["v_loss"].append(iteration_loss["v_loss"])
iteration_losses["loss"].append(iteration_loss["loss"])
# 3) evaluator: model comparison against non-neural network players
print("\n-------------- Model Comparison --------------")
neural_network_player = NeuralNetworkPlayer(
model=self.model,
name=f"UpdatedModel",
mcts_parameters=self.mcts_parameters,
device=self.inference_device
)
for opponent in [RandomPlayer(), AlphaBetaPlayer(depth=1), AlphaBetaPlayer(depth=2), AlphaBetaPlayer(depth=3)]:
results = Evaluator(
game_size=self.game_size,
player1=neural_network_player,
player2=opponent,
n_games=self.evaluator_parameters["n_games"],
n_workers=self.n_workers
).compare()
evaluation_results[opponent.name].append(results)
# save model and train/evaluation results
self.model.save_checkpoint(self.checkpoint.model)
self.checkpoint.save_evaluation_results(evaluation_results)
self.checkpoint.save_train_losses(train_losses)
self.checkpoint.save_iteration_losses(iteration_losses)
print("\nTotal time in training loop: {0:.2f}s".format(time.time() - total_time))
print("###########################################################")
@staticmethod
def perform_self_plays(game_size: int, model: AZNeuralNetwork, mcts_parameters: dict, n_workers: int, device: torch.device):
"""
Perform games of self-play using MCTS.
Parameters
----------
game_size : int
board size (width & height) of a Dots-and-Boxes game
model : AZNeuralNetwork
the neural network with which board positions are evaluated during self-play
mcts_parameters : dict
hyperparameters concerning the MCTS
n_workers : int
number of worker processes utilized during self-play (each worker performs games of self-play)
device : torch.device
device on which model inference is performed
Returns
-------
train_examples_per_game : [[[np.ndarray, np.ndarray, [float], float]]]
list (per game) of list of training examples (l, b, p, v) (from the current player's POV)
"""
# model inference
model.eval()
model.to(device)
n_games = mcts_parameters["n_games"]
train_examples_per_game = []
if n_workers > 1:
args = (game_size, model, mcts_parameters)
with Pool(processes=n_workers) as pool:
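# note: multiprocessing.Pool provides no istarmap() in the standard library; the call below presumably relies on a
# project-level patch (an iterator variant of Pool.starmap defined elsewhere in src), so that tqdm can report
# progress as individual games finish rather than only once the whole map has completed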
for train_examples in pool.istarmap(Trainer.perform_self_play, tqdm([args] * n_games, file=stdout, smoothing=0.0)):
train_examples_per_game.append(train_examples)
else:
for _ in tqdm(range(n_games), file=stdout):
train_examples = Trainer.perform_self_play(game_size, model, mcts_parameters)
train_examples_per_game.append(train_examples)
print("{0:,} games of Self-Play resulted in {1:,} new training examples (without augmentations).".format(
n_games, len([t for l in train_examples_per_game for t in l])))
return train_examples_per_game
@staticmethod
def perform_self_play(game_size: int, model: AZNeuralNetwork, mcts_parameters: dict):
"""
Perform a single game of self-play using MCTS. The data for the game is stored as (l, b, p, v) at each
time-step (i.e., each turn results in a training example), with l (lines vector) and b (boxes matrix)
representing the game state, and p (policy vector) and v (value scalar) being the parameters to be predicted.
Returns
-------
train_examples : [[np.ndarray, np.ndarray, [float], float]]
list of training examples (l, b, p, v) (from the current player's POV)
"""
game = DotsAndBoxesGame(game_size)
n_moves = 0
train_examples = []
# one self-play corresponds with one tree
mcts = MCTS(
model=model,
s=copy.deepcopy(game),
mcts_parameters=mcts_parameters
)
# once temperature_move_threshold moves have been played during self-play, the temperature parameter is set
# from 1 to 0: the first moves are sampled proportionally to the MCTS visit counts (ensuring that a diverse
# set of positions is encountered), while all later moves are selected greedily as the most-visited action
temperature_move_threshold = mcts_parameters["temperature_move_threshold"]
# iterate over time-steps t during the game. At each time-step, an MCTS search is executed using the current
# neural network, and a move is played by sampling from the resulting search probabilities
while game.is_running():
temp = 1 if n_moves < temperature_move_threshold else 0
n_moves += 1
# execute MCTS for next move
probs = mcts.play(temp=temp)
train_examples.append([
game.get_canonical_lines(),
game.get_canonical_boxes(),
probs,
game.current_player # correct v is determined later
])
# sample and play move from probability distribution
move = np.random.choice(
a=list(range(game.N_LINES)),
p=probs
)
game.execute_move(move)
# child node corresponding to the played action becomes the new root. The subtree below this child is
# retained along with all its statistics, while the remainder of the tree is discarded
mcts.root = mcts.root.get_child_by_move(move)
# determine correct value v for the active player in each example
assert game.result is not None, "Game not yet finished. Unable to determine value v"
for i, (_, _, _, current_player) in enumerate(train_examples):
if current_player == game.result:
train_examples[i][-1] = 1
elif game.result == 0:
train_examples[i][-1] = 0
else:
train_examples[i][-1] = -1
return train_examples
@staticmethod
def perform_model_training(model: AZNeuralNetwork, train_examples_per_game_augmented: list,
data_parameters: dict, optimizer_parameters: dict, device: torch.device):
"""
Update the already existing neural network using the training data which was generated from self-play.
Loss Function: "The neural network is adjusted to minimize the error between the predicted value and the self-play
winner, and to maximize the similarity of the neural network move probabilities to the search probabilities": The
parameters are adjusted by gradient descent on a loss function that sums over the mean-squared error
and cross-entropy losses. The cross-entropy and MSE losses are weighted equally. L2 weight regularization is
used to prevent overfitting.
Optimization: The neural network parameters are optimized by stochastic gradient descent with momentum (without
learning rate annealingas opposed to the original paper).
Parameters
----------
model : AZNeuralNetwork
the neural network which is updated
train_examples_per_game_augmented : [[[np.ndarray, np.ndarray, [float], float]]]
list (per game) of list of training examples (l, b, p, v) (from the current player's POV)
data_parameters, optimizer_parameters : dict
training hyperparameters
device : torch.device
device on which model training is performed
"""
# prepare data
game_buffer = data_parameters["game_buffer"]
n_batches = data_parameters["n_batches"]
batch_size = data_parameters["batch_size"]
# sample specific number of batches
print("Encoding train examples for given model .. ")
train_examples = [t for t_list in train_examples_per_game_augmented for t in t_list]
train_examples = [(model.encode(lines, boxes), p, v) for lines, boxes, p, v in train_examples] # encode s=(l, b) for given model
for s, p, v in train_examples:
# for feature planes representation, batching s of shape [4, n, n] should result in batch of shape [batch_size, 4, n, n]
# if no dimension is added, resulting shape would be [4*batch_size, n, n] which would ignore one necessary dimension
s.shape = (1,) + s.shape
print(f"Batches are sampled from {len(train_examples):,} training examples (incl. augmentations) from the "
f"{len(train_examples_per_game_augmented):,}/{game_buffer:,} most recent games.")
print("Preparing batches .. ")
batches = []
for _ in tqdm(range(n_batches), file=stdout):
batch = random.sample(train_examples, batch_size)
x, p, v = [list(t) for t in zip(*batch)]
batches.append((np.vstack(x), np.vstack(p), v))
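# each prepared batch is a tuple (x, p, v): x is stacked to shape [batch_size, *encoded_state_shape], p is stacked
# to shape [batch_size, N_LINES], and v is a plain list of batch_size scalar values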
# loss functions and optimizer
CrossEntropyLoss = torch.nn.CrossEntropyLoss()
MSELoss = torch.nn.MSELoss()
# optimizer = torch.optim.SGD(
# model.parameters(),
# lr=optimizer_parameters["learning_rate"],
# momentum=optimizer_parameters["momentum"],
# weight_decay=optimizer_parameters["weight_decay"],
# )
optimizer = torch.optim.Adam(
model.parameters(),
lr=optimizer_parameters["learning_rate"],
weight_decay=optimizer_parameters["weight_decay"]
)
print("Updating model .. ")
# model update
model.train()
model.to(device)
train_loss = defaultdict(list)
for i in tqdm(range(n_batches), file=stdout):
optimizer.zero_grad()
# to not run out of memory on gpu, move data to device sequentially
x, p_gt, v_gt = [torch.tensor(e, dtype=torch.float32, device=device) for e in batches[i]] # batch = (x, p, v)
p, v = model.forward(x)
# loss and model & optimizer update
p_loss = CrossEntropyLoss(p, p_gt)
v_loss = MSELoss(v, v_gt)
loss = p_loss + v_loss
loss.backward()
optimizer.step()
# logging
train_loss["p_loss"].append(p_loss.item())
train_loss["v_loss"].append(v_loss.item())
train_loss["loss"].append(loss.item())
# evaluate model on same data
print("Evaluating model .. ")
model.eval()
with torch.no_grad():
# calculate the average policy and value loss over all batches
p_loss, v_loss = 0, 0
for i in tqdm(range(n_batches), file=stdout):
x, p_gt, v_gt = [torch.tensor(e, dtype=torch.float32, device=device) for e in batches[i]] # batch = (x, p, v)
p, v = model.forward(x)
p_loss += CrossEntropyLoss(p, p_gt)
v_loss += MSELoss(v, v_gt)
p_loss = p_loss / n_batches
v_loss = v_loss / n_batches
print("Policy Loss: {0:.5f} (avg.)".format(p_loss))
print("Value Loss: {0:.5f} (avg.)".format(v_loss))
print("Loss: {0:.5f} (avg.)".format(p_loss + v_loss))
return train_loss, {"p_loss": p_loss.item(), "v_loss": v_loss.item(), "loss": p_loss.item() + v_loss.item()}
@staticmethod
def augment_data(train_examples_per_game: list):
"""
Augment the training data by adding rotations and reflections of each position.
Parameters
----------
train_examples_per_game : [[[np.ndarray, np.ndarray, [float], float]]]
list (per game) of list of training examples (l, b, p, v) (from the current player's POV)
Returns
-------
train_examples_per_game_augmented : [[[np.ndarray, np.ndarray, [float], float]]]
augmented dataset, i.e., including rotations and reflections of each position
"""
data_augmented = []
for train_examples in train_examples_per_game:
train_examples_augmented = []
for lines, boxes, p, v in train_examples:
train_examples_augmented.extend(zip(
DotsAndBoxesGame.get_rotations_and_reflections_lines(lines),
DotsAndBoxesGame.get_rotations_and_reflections_boxes(boxes),
DotsAndBoxesGame.get_rotations_and_reflections_lines(np.asarray(p)),
[v] * 8
))
data_augmented.append(train_examples_augmented)
return data_augmented
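# --------------------------------------------------------------------------------------------------------------------
# Usage sketch (illustrative only): the nested hyperparameter values below are placeholders, not the values used by
# this repository, and only the configuration keys that are read in this file are listed; MCTS, the model classes and
# the evaluator read additional keys from their respective sub-dictionaries. The Checkpoint constructor is defined
# elsewhere in src, so its arguments are not shown here.
#
# example_config = {
#     "game_size": 3,
#     "n_iterations": 10,
#     "mcts_parameters": {"n_games": 50, "temperature_move_threshold": 8},
#     "model_parameters": {"name": "DualRes"},
#     "optimizer_parameters": {"learning_rate": 1e-3, "weight_decay": 1e-4},
#     "data_parameters": {"game_buffer": 1000, "n_batches": 500, "batch_size": 256},
#     "evaluator_parameters": {"n_games": 40},
# }
# trainer = Trainer(
#     config=example_config,
#     n_workers=4,
#     inference_device="cuda",
#     training_device="cuda",
#     checkpoint=Checkpoint(...),  # fill in according to the Checkpoint implementation
# )
# trainer.loop()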