tf_train.py

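"""Training script for LeelaZeroNet.

Parses hyperparameters from the command line, builds the network and
optimizer, wires up optional LR-schedule, checkpoint, and TensorBoard
callbacks, and trains on batches streamed from new_data_pipeline.
"""
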
from argparse import ArgumentParser
from pathlib import Path

import tensorflow as tf

from new_data_pipeline import ARRAY_SHAPES_WITHOUT_BATCH, make_callable
from tf_net import LeelaZeroNet


def get_schedule_function(
    starting_lr, reduce_lr_every_n_epochs, reduce_lr_factor, min_learning_rate
):
    """Return a step-decay schedule for tf.keras.callbacks.LearningRateScheduler:
    the learning rate is divided by reduce_lr_factor once per
    reduce_lr_every_n_epochs epochs, floored at min_learning_rate."""

    def scheduler(epoch, lr):
        # The schedule is stateless: it ignores the current lr and recomputes
        # the rate from the epoch number alone.
        num_reductions = int(epoch // reduce_lr_every_n_epochs)
        reduction_factor = reduce_lr_factor ** num_reductions
        return max(min_learning_rate, starting_lr / reduction_factor)

    return scheduler
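
# Illustrative schedule values with the defaults below (starting_lr=3e-4,
# reduce_lr_factor=3) and --reduce_lr_every_n_epochs 10:
#   epochs 0-9   -> 3e-4
#   epochs 10-19 -> 1e-4
#   epochs 20-29 -> ~3.3e-5, and so on, never dropping below min_learning_rate.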


if __name__ == "__main__":
    parser = ArgumentParser()
    # These parameters control the net and the training process
    parser.add_argument("--num_filters", type=int, default=128)
    parser.add_argument("--num_residual_blocks", type=int, default=10)
    parser.add_argument("--se_ratio", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--no_constrain_norms", action="store_true")
    parser.add_argument("--max_grad_norm", type=float, default=5.6)
    parser.add_argument("--mixed_precision", action="store_true")
    parser.add_argument("--reduce_lr_every_n_epochs", type=int)
    parser.add_argument("--reduce_lr_factor", type=int, default=3)
    parser.add_argument("--min_learning_rate", type=float, default=5e-6)
    parser.add_argument("--save_dir", type=Path)
    parser.add_argument("--tensorboard_dir", type=Path)
    # These parameters control the data pipeline
    parser.add_argument("--dataset_path", type=Path, required=True)
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--shuffle_buffer_size", type=int, default=2 ** 19)
    parser.add_argument("--skip_factor", type=int, default=32)
    parser.add_argument("--optimizer", type=str, choices=["adam", "lion"], default="adam")
    # These parameters control the loss calculation. They should not be changed
    # unless you know what you're doing, as the loss values you get will not be
    # comparable with other people's unless they are kept at the defaults.
    parser.add_argument("--policy_loss_weight", type=float, default=1.0)
    parser.add_argument("--value_loss_weight", type=float, default=1.6)
    parser.add_argument("--moves_left_loss_weight", type=float, default=0.5)
    parser.add_argument("--q_ratio", type=float, default=0.2)
    args = parser.parse_args()
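
    # Example invocation (a sketch; the dataset path is hypothetical):
    #   python tf_train.py --dataset_path /data/training_chunks \
    #       --mixed_precision --save_dir checkpoints --reduce_lr_every_n_epochs 10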

    if args.mixed_precision:
        tf.keras.mixed_precision.set_global_policy("mixed_float16")

    model = LeelaZeroNet(
        num_filters=args.num_filters,
        num_residual_blocks=args.num_residual_blocks,
        se_ratio=args.se_ratio,
        constrain_norms=not args.no_constrain_norms,
        policy_loss_weight=args.policy_loss_weight,
        value_loss_weight=args.value_loss_weight,
        moves_left_loss_weight=args.moves_left_loss_weight,
        q_ratio=args.q_ratio,
    )

    if args.optimizer == "lion":
        try:
            from lion_tf import Lion
        except ImportError:
            raise ImportError(
                "Lion optimizer not installed. Please install it with "
                "pip install git+https://github.com/Rocketknight1/lion-tf.git"
            )
        optimizer = Lion(args.learning_rate, global_clipnorm=args.max_grad_norm)
    else:
        optimizer = tf.keras.optimizers.Adam(args.learning_rate, global_clipnorm=args.max_grad_norm)
    if args.mixed_precision:
        # Loss scaling keeps small float16 gradients from underflowing to zero.
        optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)

    callbacks = []
    if args.reduce_lr_every_n_epochs is not None:
        scheduler = get_schedule_function(
            args.learning_rate,
            args.reduce_lr_every_n_epochs,
            args.reduce_lr_factor,
            args.min_learning_rate,
        )
        callbacks.append(tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1))
    if args.save_dir is not None:
        args.save_dir.mkdir(exist_ok=True, parents=True)
        checkpoint_path = args.save_dir / "checkpoint"
        # BackupAndRestore resumes from the last completed epoch if training is
        # interrupted and relaunched with the same save_dir.
        callbacks.append(
            tf.keras.callbacks.experimental.BackupAndRestore(checkpoint_path)
        )
    if args.tensorboard_dir is not None:
        args.tensorboard_dir.mkdir(exist_ok=True, parents=True)
        callbacks.append(
            tf.keras.callbacks.TensorBoard(
                log_dir=args.tensorboard_dir, update_freq="batch", histogram_freq=1
            )
        )

    model.compile(optimizer=optimizer, jit_compile=True)

    array_shapes = [
        tuple([args.batch_size] + list(shape)) for shape in ARRAY_SHAPES_WITHOUT_BATCH
    ]
    output_signature = tuple(
        [tf.TensorSpec(shape=shape, dtype=tf.float32) for shape in array_shapes]
    )
    callable_gen = make_callable(
        chunk_dir=args.dataset_path,
        batch_size=args.batch_size,
        skip_factor=args.skip_factor,
        num_workers=args.num_workers,
        shuffle_buffer_size=args.shuffle_buffer_size,
    )
    dataset = tf.data.Dataset.from_generator(
        callable_gen, output_signature=output_signature
    ).prefetch(tf.data.AUTOTUNE)
    # With epochs=999 and steps_per_epoch=8192, training is effectively
    # open-ended: each "epoch" is a fixed block of 8192 batches from the
    # streaming pipeline, and the run continues until manually stopped.
    model.fit(dataset, epochs=999, steps_per_epoch=8192, callbacks=callbacks)
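
    # To monitor a run (assuming --tensorboard_dir was passed):
    #   tensorboard --logdir <tensorboard_dir>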