train_reactivity.py
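"""Train the rxntorch reactivity network on a reaction graph dataset.

Example invocation (the dataset file name is a placeholder; only
-c/--train_dataset and -o/--output_name are required):

    python train_reactivity.py -c <train_dataset_file> -o rxntorch.model
"""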
import argparse
import logging
import os
from torch.utils.data import DataLoader, random_split
from rxntorch.containers.dataset import RxnGraphDataset as RxnGD
from rxntorch.models.reactivity_network import ReactivityNet as RxnNet, ReactivityTrainer as RxnTrainer
from rxntorch.utils import collate_fn
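# Command-line options: data locations, train/test split, model hyperparameters,
# Adam optimizer settings, and logging frequency.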
parser = argparse.ArgumentParser()
parser.add_argument("-p", "--dataset_path", type=str, default='./data/', help="train dataset")
parser.add_argument("-c", "--train_dataset", required=True, type=str, help="train dataset")
parser.add_argument("-t", "--test_dataset", type=str, default=None, help="test set")
parser.add_argument("-op", "--output_path", type=str, default='./saved_models/', help="saved model path")
parser.add_argument("-o", "--output_name", required=True, type=str, help="e.g. rxntorch.model")
parser.add_argument("-ds", "--test_split", type=float, default=0.2, help="Ratio of samples to reserve for test data")
parser.add_argument("-b", "--batch_size", type=int, default=20, help="number of batch_size")
parser.add_argument("-tb", "--test_batch_size", type=int, default=None, help="batch size for evaluation")
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-hs", "--hidden", type=int, default=300, help="hidden size of model layers")
parser.add_argument("-l", "--layers", type=int, default=3, help="number of layers")
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of the optimizer")
parser.add_argument("-lrd", "--lr_decay", type=float, default=0.9,
help="Decay factor for reducing the learning rate")
parser.add_argument("-lrs", "--lr_steps", type=int, default=10000,
help="Number of steps between learning rate decay")
parser.add_argument("--adam_weight_decay", type=float, default=0.0, help="weight_decay of adam")
parser.add_argument("--adam_beta1", type=float, default=0.9, help="adam first beta value")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="adam second beta value")
parser.add_argument("-gc", "--grad_clip", type=float, default=None, help="value for gradient clipping")
parser.add_argument("-pw", "--pos_weight", type=float, default=None, help="Weights positive samples for imbalance")
parser.add_argument("-w", "--num_workers", type=int, default=4, help="dataloader worker size")
parser.add_argument("--with_cuda", type=bool, default=True, help="training with CUDA: true, or false")
parser.add_argument("--cuda_devices", type=int, nargs='*', default=None, help="CUDA device ids")
parser.add_argument("--log_freq", type=int, default=50, help="printing loss every n iter: setting n")
args = parser.parse_args()
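# Create the output directory if needed and log to both a file and the console.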
if not os.path.exists(args.output_path):
    os.mkdir(args.output_path)
outputfile = os.path.join(args.output_path, args.output_name)
logfile = '.'.join((args.output_name, "log"))
logpath = os.path.join(args.output_path, logfile)
logging.basicConfig(level=logging.INFO, style='{', format="{asctime:s}: {message:s}",
                    datefmt="%m/%d/%y %H:%M:%S",
                    handlers=(logging.FileHandler(logpath), logging.StreamHandler()))
logging.info("{:-^80}".format("Dataset"))
dataset = RxnGD(args.train_dataset, path=args.dataset_path)
sample = dataset[0]
afeats_size, bfeats_size, binary_size = (sample["atom_feats"].shape[-1], sample["bond_feats"].shape[-1],
                                         sample["binary_feats"].shape[-1])
n_samples = len(dataset)
n_test = int(n_samples * args.test_split)
n_train = n_samples - n_test
logging.info("Splitting dataset into {:d} samples for training and {:d} samples for testing".format(
n_train, n_test))
train_set, test_set = random_split(dataset, (n_train, n_test))
logging.info("{:-^80}".format("Data loaders"))
logging.info("Batch size: {:d} Workers: {:d} Shuffle per epoch: {}".format(
args.batch_size, args.num_workers, True))
logging.info("Drop incomplete batches: {}".format(True))
train_dataloader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True,
                              collate_fn=collate_fn, drop_last=True)
test_batch_size = args.test_batch_size if args.test_batch_size is not None else args.batch_size
test_dataloader = DataLoader(test_set, batch_size=test_batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
logging.info("{:-^80}".format("Model"))
logging.info("Graph convolution layers: {} Hidden size: {}".format(
args.layers, args.hidden, args.batch_size, args.epochs))
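# Build the reactivity network from the feature sizes inferred above.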
net = RxnNet(depth=args.layers, afeats_size=afeats_size, bfeats_size=bfeats_size,
             hidden_size=args.hidden, binary_size=binary_size)
logging.info("Total Parameters: {:,d}".format(sum([p.nelement() for p in net.parameters()])))
logging.info("{:-^80}".format("Trainer"))
logging.info("Optimizer: {} Beta1: {} Beta2: {}".format("Adam", args.adam_beta1, args.adam_beta2))
logging.info("Learning rate: {} Learning rate decay: {} Steps between updates: {}".format(
args.lr, args.lr_decay, args.lr_steps))
logging.info("Weight decay: {} Gradient clipping: {} Positive sample weighting: {}".format(
args.adam_weight_decay, args.grad_clip, args.pos_weight))
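# The trainer wraps the network with the Adam optimizer, learning rate schedule,
# optional gradient clipping, positive-sample weighting, and CUDA placement.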
trainer = RxnTrainer(net, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                     with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq,
                     grad_clip=args.grad_clip, pos_weight=args.pos_weight, lr_decay=args.lr_decay,
                     lr_steps=args.lr_steps)
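# Train for the requested number of epochs, saving a checkpoint and evaluating on
# the test set after each one.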
for epoch in range(args.epochs):
    trainer.train_epoch(epoch, train_dataloader)
    trainer.save(epoch, args.output_name, args.output_path)
    trainer.test_epoch(epoch, test_dataloader)