"""
@File : gin.py
@Time : 2022/03/07
@Author : Tianyu Zhao
"""
import os
os.environ['TL_BACKEND'] = 'tensorflow' # set your backend here, default 'tensorflow'
# sys.path.insert(0, os.path.abspath('../'))  # add the gammagl root to sys.path when running from the command line
import argparse
import tensorflow as tf
import tensorlayerx as tlx
from gammagl.datasets import TUDataset
from gammagl.loader import DataLoader
from gammagl.models import GINModel
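
# NOTE: the training loop below uses tf.GradientTape directly, so this script
# assumes TL_BACKEND='tensorflow'; with another backend the gradient step would
# need a backend-agnostic wrapper (e.g. TensorLayerX's TrainOneStep).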


def train(model, train_loader, optimizer, train_weights):
    model.set_train()
    correct = 0
    all_loss = 0
    for data in train_loader:
        with tf.GradientTape() as tape:
            out = model(data.x, data.edge_index, data.batch)
            train_loss = tlx.losses.softmax_cross_entropy_with_logits(out, tlx.cast(data.y, dtype=tlx.int32))
        all_loss += train_loss
        pred = tlx.argmax(out, axis=1)
        correct += tlx.reduce_sum(tlx.cast((pred == data.y), dtype=tlx.int64))
        grad = tape.gradient(train_loss, train_weights)
        optimizer.apply_gradients(zip(grad, train_weights))
    print("train loss: {:.4f}".format(float(all_loss)))
    print("train acc: {:.4f}".format(float(correct) / len(train_loader.dataset)))
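
# A backend-agnostic alternative, shown only as a sketch: other GammaGL
# examples wrap the model in tlx.model.WithLoss and step with
# tlx.model.TrainOneStep instead of using tf.GradientTape. 'GraphLoss' below
# is a hypothetical name introduced here for illustration.
#
# class GraphLoss(tlx.model.WithLoss):
#     def __init__(self, net, loss_fn):
#         super(GraphLoss, self).__init__(backbone=net, loss_fn=loss_fn)
#
#     def forward(self, data, y):
#         logits = self.backbone_network(data.x, data.edge_index, data.batch)
#         return self._loss_fn(logits, tlx.cast(y, dtype=tlx.int32))
#
# loss_func = GraphLoss(model, tlx.losses.softmax_cross_entropy_with_logits)
# train_one_step = tlx.model.TrainOneStep(loss_func, optimizer, train_weights)
# for data in train_loader:
#     train_loss = train_one_step(data, data.y)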


def evaluate(model, loader):
    model.set_eval()  # disable dropout
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = tlx.argmax(out, axis=1)
        correct += tlx.reduce_sum(tlx.cast((pred == data.y), dtype=tlx.int64))
    # convert to a Python float so callers can format it directly
    return float(correct) / len(loader.dataset)
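

# A DataLoader batch packs several graphs into one disjoint graph: data.x
# stacks the node features, data.edge_index stacks the edges (with node
# indices offset per graph), and data.batch maps each node to the index of
# its graph, which GINModel uses for graph-level pooling.
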
def main(args):
    # load the MUTAG dataset
    if args.dataset.lower() not in ['mutag']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    dataset = TUDataset(args.dataset_path, args.dataset, use_node_attr=True)
    # fixed split: 110 graphs for training, 40 for validation, the rest for testing
    train_dataset = dataset[:110]
    val_dataset = dataset[110:150]
    test_dataset = dataset[150:]
    # shuffle only the training set; evaluation order is irrelevant
    train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
    model = GINModel(input_dim=dataset.num_features,
                     hidden_dim=args.hidden_dim,
                     output_dim=dataset.num_classes,
                     final_dropout=args.final_dropout,
                     num_layers=args.num_layers,
                     num_mlp_layers=args.num_mlp_layers,
                     learn_eps=args.learn_eps,
                     graph_pooling_type=args.graph_pooling_type,
                     neighbor_pooling_type=args.neighbor_pooling_type,
                     name="GIN")
    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    train_weights = model.trainable_weights
    best_val_acc = 0
    print('start training...')
    for epoch in range(args.n_epoch):
        train(model, train_loader, optimizer, train_weights)
        train_acc = evaluate(model=model, loader=train_loader)
        val_acc = evaluate(model=model, loader=val_loader)
        print("Epoch [{:0>3d}] ".format(epoch + 1),
              " train acc: {:.4f}".format(train_acc),
              " val acc: {:.4f}".format(val_acc))
        # keep the weights that perform best on the validation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model.save_weights(args.best_model_path + model.name + ".npz", format='npz_dict')

    # restore the best checkpoint and report test accuracy
    model.load_weights(args.best_model_path + model.name + ".npz", format='npz_dict')
    test_acc = evaluate(model=model, loader=test_loader)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    # learning
    parser.add_argument("--lr", type=float, default=0.2, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=350, help="number of epochs")
    parser.add_argument("--final_dropout", type=float, default=0.01, help="final layer dropout")
    parser.add_argument("--l2_coef", type=float, default=5e-4, help="l2 loss coefficient")
    # graph
    parser.add_argument("--graph_pooling_type", type=str, default="sum", choices=["sum", "mean", "max"],
                        help="type of graph pooling: sum, mean or max")
    parser.add_argument("--neighbor_pooling_type", type=str, default="sum", choices=["sum", "mean", "max"],
                        help="type of neighbor pooling: sum, mean or max")
    # NOTE: argparse's type=bool treats any non-empty string as True, so a
    # store_true flag replaces the original type=bool argument
    parser.add_argument("--learn_eps", action="store_true", help="learn the epsilon weighting")
    # net
    parser.add_argument("--num_layers", type=int, default=5, help="number of layers")
    parser.add_argument("--num_mlp_layers", type=int, default=2, help="number of MLP layers; 1 means a linear model")
    parser.add_argument("--hidden_dim", type=int, default=64, help="dimension of hidden layers")
    parser.add_argument("--self_loops", type=int, default=1, help="number of graph self-loops")
    # dataset
    parser.add_argument('--dataset', type=str, default='MUTAG', help='dataset')
    parser.add_argument("--dataset_path", type=str, default=r'../', help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save best model")
    args = parser.parse_args()

    main(args)
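
# Example invocation (run from this directory; the dataset is downloaded to
# --dataset_path on first use):
#   python gin.py --dataset MUTAG --n_epoch 350 --hidden_dim 64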