-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathlearn
executable file
·180 lines (143 loc) · 6.56 KB
/
learn
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# This file is part of the gunyanza.
# Copyright (C) 2014- Erik Bernhardsson
# Copyright (C) 2015- Tasuku SUENAGA <[email protected]>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals, print_function
# Passed to train() as ``use_gpu`` when this file is run as a script.
USE_GPU = True
# Passed to train() as ``not_pickle_duration``: how many new-best test losses
# are let through without dumping the model before the next dump happens.
NOT_PICKLE_DURATION = 10
import numpy as np
from chainer import cuda, Function, FunctionSet, gradient_check, Variable, optimizers
import chainer.functions as F
import itertools
import os
import sys
import h5py
import pickle
from datetime import datetime
from sklearn.cross_validation import train_test_split
from network import GunyaNetwork
# Fetch board information from HDF5 files (opens each file in a directory)
def load_data_from_hdf5(input_dir_path):
    """Lazily open every ``.hdf5`` file found directly inside a directory.

    Directory entries whose name does not end with ``.hdf5`` are skipped.
    The path of each file is printed right before it is opened.

    Args:
        input_dir_path: Directory to scan (not recursive).

    Yields:
        ``h5py.File`` objects opened read-only, in ``os.listdir`` order.
        This generator never closes them; that is left to the consumer.
    """
    hdf5_names = (name for name in os.listdir(input_dir_path)
                  if name.endswith('.hdf5'))
    for name in hdf5_names:
        path = os.path.join(input_dir_path, name)
        print('Read from {0}'.format(path))
        yield h5py.File(path, 'r')
# Fetch board information from HDF5 and stack it into one array per series
def feed_data_from_hdf5(input_dir_path):
    """Read the 'x', 'xr' and 'xp' datasets of every HDF5 file in a directory.

    Each file produced by ``load_data_from_hdf5`` is read completely into
    memory and then closed. The per-file arrays of each series are stacked
    vertically with ``np.vstack``.

    Args:
        input_dir_path: Directory containing the ``.hdf5`` files.

    Returns:
        A list ``[x, xr, xp]`` with one stacked array per series, in that
        order.

    Raises:
        ValueError: If no data was read at all (e.g. the directory holds no
            ``.hdf5`` file).
        Exception: Any error raised while opening or reading a file is
            propagated unchanged.
    """
    series = ['x', 'xr', 'xp']
    data = [[] for _ in series]
    for f in load_data_from_hdf5(input_dir_path):
        try:
            for chunks, name in zip(data, series):
                # ``dataset[()]`` loads the whole dataset; it is the supported
                # spelling of the deprecated ``dataset.value`` attribute.
                chunks.append(f[name][()])
        finally:
            # Everything needed is already in memory (or reading failed),
            # so release the file handle either way.
            f.close()
    if not data[0]:
        # Same exception type np.vstack([]) would raise, but with a message
        # that points at the actual cause.
        raise ValueError('no .hdf5 data found in {0}'.format(input_dir_path))
    return [np.vstack(chunks) for chunks in data]
# Fetch board information from HDF5 and split it into a train set and a test set
def feed_train_and_test_data_from_hdf5(input_dir_path, test_size_ratio):
    """Load all series from a directory and split them into train/test parts.

    Args:
        input_dir_path: Directory handed to ``feed_data_from_hdf5``.
        test_size_ratio: Fraction of the entries to put in the test set; it
            is converted to an absolute count (truncated towards zero)
            before being passed to ``train_test_split``.

    Returns:
        The sequence returned by ``train_test_split``: for each input series
        in turn, its train part followed by its test part.
    """
    arrays = feed_data_from_hdf5(input_dir_path)
    total = len(arrays[0])
    print('Splitting', total, 'entries into train/test set')
    test_count = int(test_size_ratio * total)
    split = train_test_split(*arrays, test_size=test_count)
    # split[0] / split[1] are the train / test parts of the first series.
    print(split[0].shape[0], 'train set', split[1].shape[0], 'test set')
    return split
def forward_triplet(model, parent, observed, random, coefficient_a, coefficient_b, train, use_gpu):
    """Evaluate the model on a (parent, observed, random) batch and build the loss.

    The loss is the sum of two weighted terms:

    * ``loss_a`` pushes the value of the observed position above the value
      of the random position: sum of ``exp(random - observed)`` divided by
      the number of elements, times ``coefficient_a``.
    * ``loss_b`` pulls the observed value towards the parent value: mean
      squared error between the two, times ``coefficient_b``.

    The three loss values are printed on every call.

    Args:
        model: Object exposing ``forward(batch, use_gpu, train)``.
        parent, observed, random: Input batches for the three positions.
        coefficient_a: Weight of the ranking term.
        coefficient_b: Weight of the consistency term.
        train: Forwarded to ``model.forward``.
        use_gpu: Forwarded to ``model.forward``.

    Returns:
        The combined loss object (``loss_a + loss_b``).
    """
    parent_value, observed_value, random_value = [
        model.forward(batch, use_gpu, train)
        for batch in (parent, observed, random)
    ]

    # observed_value should be bigger than random_value
    margin_penalty = F.sum(F.exp(random_value - observed_value))
    loss_a = coefficient_a * margin_penalty / random_value.data.size

    # observed_value should be the same as parent_value
    consistency = F.mean_squared_error(observed_value, parent_value)
    loss_b = coefficient_b * consistency

    loss = loss_a + loss_b
    print('loss: {0} loss_a: {1} loss_b: {2}'.format(loss.data, loss_a.data, loss_b.data))
    return loss
def train_epoch(model, optimizer, coefficient_a, coefficient_b, parent_data, observed_data, random_data, parent_test, observed_test, random_test, batch_size, use_gpu):
    """Run one optimisation pass over the train set, then score the test set.

    Both sets are walked in consecutive slices of ``batch_size`` rows. On
    the train set every batch triggers a gradient reset, a backward pass and
    an optimizer update; on the test set only the loss is computed.

    Args:
        model: Network passed through to ``forward_triplet``.
        optimizer: Object exposing ``zero_grads()`` and ``update()``.
        coefficient_a, coefficient_b: Loss weights for ``forward_triplet``.
        parent_data, observed_data, random_data: Train arrays, sliced along
            their first axis (assumed to have the same number of rows).
        parent_test, observed_test, random_test: Test arrays, same layout.
        batch_size: Number of rows per batch.
        use_gpu: Forwarded to ``forward_triplet``.

    Returns:
        ``(train_loss, test_loss)``: the per-batch losses summed (not
        averaged) over the train set and the test set respectively.
    """
    # train
    print('train')
    train_loss = 0.0
    # Batches are row slices, so the offsets must cover the number of rows
    # (len(parent_data)), not the length of the first row.
    for start in range(0, len(parent_data), batch_size):
        stop = start + batch_size
        optimizer.zero_grads()
        loss = forward_triplet(model, parent_data[start:stop], observed_data[start:stop], random_data[start:stop], coefficient_a, coefficient_b, True, use_gpu)
        loss.backward()
        optimizer.update()
        train_loss += cuda.to_cpu(loss.data)

    # test
    print('test')
    test_loss = 0.0
    for start in range(0, len(parent_test), batch_size):
        stop = start + batch_size
        loss = forward_triplet(model, parent_test[start:stop], observed_test[start:stop], random_test[start:stop], coefficient_a, coefficient_b, False, use_gpu)
        test_loss += cuda.to_cpu(loss.data)

    loss = train_loss + test_loss
    print('loss: {0} train_loss: {1} test_loss: {2}'.format(loss, train_loss, test_loss))
    return train_loss, test_loss
def train(input_dir_path, output_dir_path, max_epoch, batch_size, test_size_ratio, use_gpu, not_pickle_duration):
    """Train a ``GunyaNetwork`` on HDF5 board data and pickle improving models.

    The data is loaded and split once, then ``max_epoch`` epochs are run
    (numbered 1..max_epoch). Whenever an epoch reaches a new best test loss,
    a counter decides whether the model is dumped: the first
    ``not_pickle_duration`` improvements are only counted down; the next one
    writes the model to ``gunyanza.model.pickle`` in ``output_dir_path`` and
    resets the counter.

    Args:
        input_dir_path: Directory with the ``.hdf5`` training files.
        output_dir_path: Directory receiving the pickled model; created if
            it does not exist.
        max_epoch: Number of epochs to run.
        batch_size: Rows per batch, forwarded to ``train_epoch``.
        test_size_ratio: Fraction of the data held out as test set; must be
            strictly between 0.0 and 1.0.
        use_gpu: If true, initialise CUDA and move the model to the GPU.
        not_pickle_duration: Number of new-best test losses to skip between
            two model dumps.

    Raises:
        ValueError: If ``test_size_ratio`` is not strictly between 0 and 1.
    """
    if test_size_ratio >= 1.0 or test_size_ratio <= 0.0:
        raise ValueError('test_size_ratio should be between 0.0 and 1.0')

    # Make sure the dump target exists now rather than failing on the first
    # dump, which can be many epochs into the run.
    if not os.path.isdir(output_dir_path):
        os.makedirs(output_dir_path)
    model_pickle_path = os.path.join(output_dir_path, 'gunyanza.model.pickle')

    if use_gpu:
        print('Use CUDA')
        cuda.init()

    # x?: board info with mochigoma
    # ?o: observed
    # ?r: random
    # ?p: parent
    xo_train, xo_test, xr_train, xr_test, xp_train, xp_test = feed_train_and_test_data_from_hdf5(input_dir_path, test_size_ratio)

    # Flip to True to print one observed and one parent board for debugging:
    # nine rows of nine cells, followed by the entries at indices 81..94.
    print_sample_boards = False
    if print_sample_boards:
        for board in [xo_train[0], xp_train[0]]:
            for row in range(9):
                print(''.join('%3d' % x for x in board[(row*9):((row+1)*9)]))
            print('')
            for piece_count in range(81, 95):
                print(board[piece_count], end='')
            print('')

    model = GunyaNetwork()
    if use_gpu:
        model = model.to_gpu()
    optimizer = optimizers.Adam()
    optimizer.setup(model.collect_parameters())

    # FIXME: reducing coefficients along learning progress
    coefficient_a = 1.0
    coefficient_b = 10.0

    best_test_loss = float('inf')
    not_pickle_count = not_pickle_duration
    # Inclusive upper bound so that exactly max_epoch epochs are run.
    for epoch in range(1, max_epoch + 1):
        print('epoch {0} at {1}'.format(epoch, datetime.now().strftime('%Y/%m/%d %H:%M:%S')))
        train_loss, test_loss = train_epoch(model, optimizer, coefficient_a, coefficient_b, xp_train, xo_train, xr_train, xp_test, xo_test, xr_test, batch_size, use_gpu)
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            if not_pickle_count > 0:
                not_pickle_count -= 1
                print('Best test loss. {0} times rest to save'.format(not_pickle_count))
            else:
                print('Best test loss. Dump the model to {0}'.format(model_pickle_path))
                # Pickle streams are bytes: the file must be opened in
                # binary mode.
                with open(model_pickle_path, 'wb') as f:
                    pickle.dump(model, f)
                not_pickle_count = not_pickle_duration
# Script entry point: expects exactly five command-line arguments and hands
# them, converted to the right types, to train() together with the
# module-level USE_GPU / NOT_PICKLE_DURATION settings.
if __name__ == '__main__':
    if len(sys.argv) != 6:
        print('Usage: {0} input_directory output_directory max_epoch batch_size test_set_ratio'.format(sys.argv[0]))
        # Report the usage error to the calling shell instead of exiting 0.
        sys.exit(1)
    train(sys.argv[1], sys.argv[2], int(sys.argv[3]), int(sys.argv[4]), float(sys.argv[5]), USE_GPU, NOT_PICKLE_DURATION)