forked from tensorflow/minigo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
227 lines (184 loc) · 8.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a network.
Usage:
BOARD_SIZE=19 python train.py tfrecord1 tfrecord2 tfrecord3
"""
import logging
from absl import app, flags
import numpy as np
import tensorflow as tf
import bigtable_input
import dual_net
import preprocessing
import utils
# See www.moderndescartes.com/essays/shuffle_viz for discussion on sizing
flags.DEFINE_integer('shuffle_buffer_size', 2000,
                     'Size of buffer used to shuffle train examples.')

flags.DEFINE_integer('steps_to_train', None,
                     'Number of training steps to take. If not set, iterates '
                     'once over training data.')

flags.DEFINE_integer('window_size', 500000,
                     'Number of games to include in the window')

flags.DEFINE_float('filter_amount', 1.0,
                   'Fraction of positions to filter from golden chunks, '
                   'default 1.0 (no filter)')

flags.DEFINE_string('export_path', None,
                    'Where to export the model after training.')

flags.DEFINE_bool('use_bt', False,
                  'Whether to use Bigtable as input. '
                  '(Only supported with --use_tpu, currently.)')

flags.DEFINE_bool('freeze', False,
                  'Whether to freeze the graph at the end of training.')

# train() only builds a Bigtable input_fn on the TPU path, so --use_bt is
# rejected unless --use_tpu is also set.
flags.register_multi_flags_validator(
    ['use_bt', 'use_tpu'],
    lambda flag_values: not flag_values['use_bt'] or flag_values['use_tpu'],
    '`use_bt` flag only valid with `use_tpu` as well')
@flags.multi_flags_validator(
    ['use_bt', 'cbt_project', 'cbt_instance', 'cbt_table'],
    message='Cloud Bigtable configuration flags not correct')
def _bt_checker(flags_dict):
    """Validate that the Cloud Bigtable flags are all set when --use_bt is on.

    Args:
      flags_dict: mapping from the four flag names above to their values.

    Returns:
      True when --use_bt is off; otherwise the result of `and`-ing the
      cbt_project, cbt_instance and cbt_table values (truthy only if all
      three are set).
    """
    if flags_dict['use_bt']:
        return (flags_dict['cbt_project'] and flags_dict['cbt_instance']
                and flags_dict['cbt_table'])
    return True
# Flags defined in dual_net.py, declared here as key flags of this module.
for _key_flag_name in ('work_dir', 'train_batch_size', 'num_tpu_cores',
                       'use_tpu'):
    flags.declare_key_flag(_key_flag_name)
del _key_flag_name

FLAGS = flags.FLAGS
class EchoStepCounterHook(tf.train.StepCounterHook):
    """StepCounterHook that also echoes steps/second to the Python logger."""

    def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
        """Log the measured rate, then delegate to the base implementation.

        Args:
          elapsed_steps: steps completed since the previous measurement.
          elapsed_time: seconds elapsed since the previous measurement.
          global_step: current value of the global step.
        """
        steps_per_second = elapsed_steps / elapsed_time
        message = "{}: {:.3f} steps per second".format(
            global_step, steps_per_second)
        logging.info(message)
        super()._log_and_record(elapsed_steps, elapsed_time, global_step)
def compute_update_ratio(weight_tensors, before_weights, after_weights):
    """Compute ||after - before|| / ||before|| (Frobenius norm) per variable.

    Args:
      weight_tensors: the variables whose values were fetched; only their
        `.name` is used, to tag each summary value.
      before_weights: fetched variable values before a training step
        (presumably numpy arrays from `Session.run`), in the same order as
        `weight_tensors`.
      after_weights: fetched variable values after that step, same order.

    Returns:
      A `tf.Summary` with one `update_ratios/<variable name>` simple value
      per variable. Variables whose pre-update norm is zero (e.g. ones still
      at an all-zero initial value) are omitted, because their ratio would be
      inf or NaN and would distort the TensorBoard plot.
    """
    all_summaries = []
    for tensor, before, after in zip(weight_tensors, before_weights,
                                     after_weights):
        weight_norm = np.linalg.norm(before.ravel())
        if weight_norm == 0:
            # Ratio is undefined; skip instead of emitting inf/NaN.
            continue
        delta_norm = np.linalg.norm((after - before).ravel())
        all_summaries.append(
            tf.Summary.Value(tag='update_ratios/' + tensor.name,
                             simple_value=delta_norm / weight_norm))
    return tf.Summary(value=all_summaries)
class UpdateRatioSessionHook(tf.train.SessionRunHook):
    """A hook that logs ||weight update|| / ||weights|| (Frobenius norm).

    Whenever the global step is a multiple of `every_n_steps`, the hook
    snapshots all trainable variables before the training step and, after
    that step, writes per-variable `compute_update_ratio` summaries to
    `output_dir`.
    """

    def __init__(self, output_dir, every_n_steps=1000):
        self.output_dir = output_dir
        self.every_n_steps = every_n_steps
        # Variable values captured in before_run; None when no snapshot is
        # pending for the current step.
        self.before_weights = None
        # The following are set in begin(), once the graph exists.
        self.file_writer = None
        self.weight_tensors = None
        self.global_step = None

    def begin(self):
        # These calls only work because the SessionRunHook API guarantees
        # this will get called within a graph context containing our model.
        self.file_writer = tf.summary.FileWriterCache.get(self.output_dir)
        self.weight_tensors = tf.trainable_variables()
        self.global_step = tf.train.get_or_create_global_step()

    def before_run(self, run_context):
        """Snapshot the weights if this step should be measured."""
        global_step = run_context.session.run(self.global_step)
        if global_step % self.every_n_steps == 0:
            self.before_weights = run_context.session.run(self.weight_tensors)

    def after_run(self, run_context, run_values):
        """Write update-ratio summaries if before_run took a snapshot."""
        if self.before_weights is None:
            # Nothing to measure this step; avoid extra session.run calls.
            return
        global_step = run_context.session.run(self.global_step)
        after_weights = run_context.session.run(self.weight_tensors)
        weight_update_summaries = compute_update_ratio(
            self.weight_tensors, self.before_weights, after_weights)
        self.file_writer.add_summary(weight_update_summaries, global_step)
        self.before_weights = None
def train(*tf_records: "Records to train on"):
    """Train the dual_net estimator on examples.

    With --use_tpu and --use_bt, examples are read from Cloud Bigtable.
    Otherwise they are read from the given TFRecord files.

    Args:
      *tf_records: paths of TFRecord files to train on. Not used by the
        Bigtable input path.

    Raises:
      Any exception raised during training (including KeyboardInterrupt) is
      re-raised; in Bigtable mode `require_fresh_games(0)` is called first.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    estimator = dual_net.get_estimator()

    effective_batch_size = FLAGS.train_batch_size
    if FLAGS.use_tpu:
        # The reported per-step example count scales with the TPU cores.
        effective_batch_size *= FLAGS.num_tpu_cores

        if FLAGS.use_bt:
            def _input_fn(params):
                games = bigtable_input.GameQueue(
                    FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
                games_nr = bigtable_input.GameQueue(
                    FLAGS.cbt_project, FLAGS.cbt_instance,
                    FLAGS.cbt_table + '-nr')
                return preprocessing.get_tpu_bt_input_tensors(
                    games,
                    games_nr,
                    params['batch_size'],
                    number_of_games=FLAGS.window_size,
                    random_rotation=True)
        else:
            def _input_fn(params):
                return preprocessing.get_tpu_input_tensors(
                    params['batch_size'],
                    tf_records,
                    random_rotation=True)
        # Hooks are broken with TPUEstimator at the moment.
        hooks = []
    else:
        def _input_fn():
            return preprocessing.get_input_tensors(
                FLAGS.train_batch_size,
                tf_records,
                filter_amount=FLAGS.filter_amount,
                shuffle_buffer_size=FLAGS.shuffle_buffer_size,
                random_rotation=True)

        hooks = [UpdateRatioSessionHook(FLAGS.work_dir),
                 EchoStepCounterHook(output_dir=FLAGS.work_dir)]

    steps = FLAGS.steps_to_train
    logging.info("Training, steps = %s, batch = %s -> %s examples",
                 steps or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')

    if FLAGS.use_bt:
        games = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
        # Read the wait cell once (and once more only if we just set it), so
        # the value printed below is the one used for index_from.
        wait_cell = games.read_wait_cell()
        if not wait_cell:
            games.require_fresh_games(20000)
            wait_cell = games.read_wait_cell()
        latest_game = games.latest_game_number
        # NOTE(review): assumes the wait cell is set after
        # require_fresh_games; max() raises TypeError if it is still None.
        index_from = max(latest_game, wait_cell)
        print("== Last game before training:", latest_game, flush=True)
        print("== Wait cell:", wait_cell, flush=True)

    try:
        estimator.train(_input_fn, steps=steps, hooks=hooks)
        if FLAGS.use_bt:
            bigtable_input.set_fresh_watermark(games, index_from,
                                               FLAGS.window_size)
    except BaseException:
        # Deliberately broad (also covers KeyboardInterrupt): reset the
        # fresh-games requirement before propagating the failure.
        if FLAGS.use_bt:
            games.require_fresh_games(0)
        raise
def main(argv):
    """Train on examples and export the updated model weights.

    Args:
      argv: positional command-line arguments as passed by `app.run`. argv[0]
        is the program name and argv[1:] are TFRecord paths to train on. The
        path list may be empty, e.g. with --use_bt, whose input ignores it.

    Raises:
      app.UsageError: if --freeze is given without --export_path.
    """
    if FLAGS.freeze and not FLAGS.export_path:
        # Fail before training instead of after it: freezing operates on the
        # exported model at --export_path.
        raise app.UsageError('--freeze requires --export_path')

    tf_records = argv[1:]
    if tf_records:
        logging.info("Training on %s records: %s to %s",
                     len(tf_records), tf_records[0], tf_records[-1])
    else:
        logging.info("Training on 0 TFRecord files")

    with utils.logged_timer("Training"):
        train(*tf_records)

    if FLAGS.export_path:
        dual_net.export_model(FLAGS.export_path)
        if FLAGS.freeze:
            if FLAGS.use_tpu:
                dual_net.freeze_graph_tpu(FLAGS.export_path)
            else:
                dual_net.freeze_graph(FLAGS.export_path)


if __name__ == "__main__":
    app.run(main)