# comp9.py
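'''
gc, mol + gcn/gin

Graph-level prediction (presumably "gc" = graph classification) on OGB
molecule datasets ("mol", e.g. ogbg-molhiv) using GCN or GIN layers.
'''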
import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder,BondEncoder
from torch_geometric.utils import degree
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from torch_geometric.loader import DataLoader
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from utils import load_dataset, init_layers
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN-style message-passing layer that mixes bond embeddings into messages.

    Each message is ``relu(x_j + bond_embedding)``; messages are aggregated
    with ``aggr="add"``, added to ``(1 + eps) * x`` and fed through an MLP
    (Linear -> BatchNorm1d -> ReLU -> Linear).
    """

    def __init__(self, emb_dim):
        """
        emb_dim (int): node embedding dimensionality
        """
        super().__init__(aggr="add")

        mlp_layers = [
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, emb_dim),
        ]
        self.mlp = torch.nn.Sequential(*mlp_layers)
        # Learnable scalar weighting the node's own features; starts at 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return updated node embeddings for node features ``x``.

        ``edge_attr`` is embedded by the bond encoder before being combined
        with neighbour features inside ``message``.
        """
        bond_emb = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=bond_emb)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        """Combine a neighbour's features with the bond embedding of the edge."""
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out
### GCN convolution along the graph structure
class GCNConv(MessagePassing):
    """GCN-style message-passing layer with bond embeddings and a root term.

    Node features are linearly projected; each message is
    ``norm * relu(x_j + bond_embedding)`` where ``norm`` is the product of
    the inverse square-root degrees (degree counted from ``edge_index[0]``,
    plus one) of the edge's two endpoints. A per-node term
    ``relu(x + root_emb) / deg`` is added to the aggregated messages.
    """

    def __init__(self, emb_dim):
        """
        emb_dim (int): node embedding dimensionality
        """
        super().__init__(aggr="add")

        self.lin = torch.nn.Linear(emb_dim, emb_dim)
        # Single learnable vector added to every node in the root term.
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return updated node embeddings for node features ``x``."""
        x = self.lin(x)
        bond_emb = self.bond_encoder(edge_attr)

        first, second = edge_index

        # "+ 1" keeps every degree >= 1, so the inf guard below is a safety net.
        deg = degree(first, x.size(0), dtype=x.dtype) + 1
        inv_sqrt_deg = deg.pow(-0.5)
        inv_sqrt_deg[inv_sqrt_deg == float("inf")] = 0

        norm = inv_sqrt_deg[first] * inv_sqrt_deg[second]

        neighbour_term = self.propagate(edge_index, x=x, edge_attr=bond_emb, norm=norm)
        root_term = F.relu(x + self.root_emb.weight) / deg.view(-1, 1)
        return neighbour_term + root_term

    def message(self, x_j, edge_attr, norm):
        """Scale ``relu(x_j + edge_attr)`` by the per-edge normalisation."""
        activated = F.relu(x_j + edge_attr)
        return norm.view(-1, 1) * activated

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out
### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Stack of GNN layers that turns atom features into node embeddings.

    Output:
        node representations
    """

    # Supported ways of combining the per-layer embeddings.
    _JK_MODES = ("last", "sum")

    def __init__(self, num_layer, emb_dim, drop_ratio = 0.5, JK = "last", residual = False, gnn_type = 'gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
            drop_ratio (float): dropout probability applied after every layer
            JK (str): "last" returns the final layer's embeddings, "sum" adds
                the input embedding and every layer's output
            residual (bool): add each layer's input to its output
            gnn_type (str): 'gin' or 'gcn'

            Raises ValueError for num_layer < 2, an unknown JK mode or an
            unknown gnn_type.
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        # Fail at construction time instead of with an UnboundLocalError in forward.
        if JK not in self._JK_MODES:
            raise ValueError('Undefined JK mode called {}'.format(JK))

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for _ in range(num_layer):
            if gnn_type == 'gin':
                self.convs.append(GINConv(emb_dim))
            elif gnn_type == 'gcn':
                self.convs.append(GCNConv(emb_dim))
            else:
                raise ValueError('Undefined GNN type called {}'.format(gnn_type))

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        '''
            Compute node embeddings for a batch.

            batched_data must expose ``x``, ``edge_index`` and ``edge_attr``.
            Returns the node representation selected by ``self.JK``; raises
            ValueError if ``self.JK`` is not a supported mode.
        '''
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        last_layer = self.num_layer - 1

        for layer in range(self.num_layer):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer != last_layer:
                # no relu on the last layer
                h = F.relu(h)
            h = F.dropout(h, self.drop_ratio, training = self.training)

            if self.residual:
                # Out-of-place on purpose: when dropout is a no-op it appears to
                # hand back its input tensor, and an in-place add would then
                # overwrite the ReLU output that autograd needs for backward.
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            # h_list holds the input embedding plus one entry per layer.
            node_representation = sum(h_list)
        else:
            raise ValueError('Undefined JK mode called {}'.format(self.JK))

        return node_representation
class GNN(torch.nn.Module):
    """Graph-level predictor: node GNN, then graph pooling, then a linear head."""

    def __init__(self, num_tasks, num_layer = 5, emb_dim = 300,
                    gnn_type = 'gin', residual = False, drop_ratio = 0.5, JK = "last", graph_pooling = "mean"):
        '''
            num_tasks (int): number of labels to be predicted
            num_layer (int): number of message passing layers (at least 2)
            emb_dim (int): node embedding dimensionality
            gnn_type, residual, drop_ratio, JK: forwarded to GNN_node
            graph_pooling (str): "sum", "mean", "max", "attention" or "set2set"
        '''
        super().__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(
            num_layer,
            emb_dim,
            JK = JK,
            drop_ratio = drop_ratio,
            residual = residual,
            gnn_type = gnn_type,
        )

        ### Pooling function to generate whole-graph embeddings
        pooling = self.graph_pooling
        if pooling == "sum":
            self.pool = global_add_pool
        elif pooling == "mean":
            self.pool = global_mean_pool
        elif pooling == "max":
            self.pool = global_max_pool
        elif pooling == "attention":
            gate = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, 2*emb_dim),
                torch.nn.BatchNorm1d(2*emb_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(2*emb_dim, 1),
            )
            self.pool = GlobalAttention(gate_nn = gate)
        elif pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps = 2)
        else:
            raise ValueError("Invalid graph pooling type.")

        # The set2set readout is given a head twice as wide as the others.
        head_in = 2*self.emb_dim if graph_pooling == "set2set" else self.emb_dim
        self.graph_pred_linear = torch.nn.Linear(head_in, self.num_tasks)

    def forward(self, batched_data):
        '''Return one prediction row per graph in ``batched_data``.'''
        node_emb = self.gnn_node(batched_data)
        graph_emb = self.pool(node_emb, batched_data.batch)
        return self.graph_pred_linear(graph_emb)
def train(model, device, loader, optimizer, task_type):
    """Run one optimisation pass over ``loader``.

    model: network called as ``model(batch)``
    device: device each batch is moved to
    loader: iterable of batches exposing ``x``, ``batch`` and ``y``
    optimizer: optimizer stepped once per processed batch
    task_type (str): BCE-with-logits loss is used when it contains
        "classification", MSE loss otherwise
    """
    model.train()

    bce_loss = torch.nn.BCEWithLogitsLoss()
    mse_loss = torch.nn.MSELoss()

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        # Skip batches with a single node, or whose last node belongs to
        # graph 0 (presumably a one-graph batch -- TODO confirm intent).
        single_node = batch.x.shape[0] == 1
        if single_node or batch.batch[-1] == 0:
            continue

        pred = model(batch)
        optimizer.zero_grad()

        ## ignore nan targets (unlabeled) when computing training loss.
        # NaN != NaN, so this mask is False exactly at NaN targets.
        is_labeled = batch.y == batch.y

        criterion = bce_loss if "classification" in task_type else mse_loss
        loss = criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])

        loss.backward()
        optimizer.step()
def eval(model, device, loader, evaluator):
    """Score ``model`` on every batch of ``loader`` with ``evaluator``.

    Batches containing a single node are skipped. Targets are reshaped to
    the prediction shape, everything is gathered on the CPU, and the result
    of ``evaluator.eval({"y_true": ..., "y_pred": ...})`` is returned.
    """
    model.eval()

    targets = []
    outputs = []

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        if batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred = model(batch)

        targets.append(batch.y.view(pred.shape).detach().cpu())
        outputs.append(pred.detach().cpu())

    input_dict = {
        "y_true": torch.cat(targets, dim = 0).numpy(),
        "y_pred": torch.cat(outputs, dim = 0).numpy(),
    }
    return evaluator.eval(input_dict)
def pipe(gName='molhiv', gnn='gcn', epochs=200, lr=1e-3, dropout=0.5, wd=0.0, init='virgo', num_g_samples=-1,
         num_workers=1, batch_size=32, device='cuda'):
    """Train and evaluate a GNN on the OGB dataset ``ogbg-<gName>``.

    gName: dataset suffix, also passed to ``load_dataset``
    gnn: layer type forwarded to ``GNN`` and ``init_layers``
    epochs, lr, wd: number of epochs, Adam learning rate and weight decay
    dropout: dropout ratio of the model
    init, num_g_samples: forwarded to the project helpers ``init_layers`` /
        ``load_dataset`` (their semantics are not visible here)
    num_workers, batch_size: data-loader settings (validation and test
        loaders always use a batch size of 256)
    device: device the model and batches are moved to

    Returns ``(train_curve, valid_curve, test_curve)``: one score per epoch,
    followed by one extra entry repeating the score at the best validation
    epoch.
    """
    ogb_name = f'ogbg-{gName}'
    dataset = PygGraphPropPredDataset(name = ogb_name, root='/mnt/jiahanli/datasets')
    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(ogb_name)

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=batch_size, shuffle=True, num_workers = num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=256, shuffle=False, num_workers = num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=256, shuffle=False, num_workers = num_workers)
    eval_loaders = (train_loader, valid_loader, test_loader)

    # g is only handed to init_layers (original note: "only used for calculating C2").
    g, _ = load_dataset(f'{gName}', num_samples=num_g_samples, rand_feat=False, drop_feat=False, device=device)
    model = GNN(gnn_type = gnn, num_tasks = dataset.num_tasks, num_layer = 5, emb_dim = 300, drop_ratio=dropout)
    init_layers(g, gnn, model.gnn_node.convs, init, num_classes=300, pyg=True)

    # Drop the helper graph before the model is moved to the device.
    del g
    torch.cuda.empty_cache()

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    train_curve, valid_curve, test_curve = [], [], []
    metric = dataset.eval_metric

    for epoch in range(epochs):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer, dataset.task_type)

        print('Evaluating...')
        train_perf, valid_perf, test_perf = [
            eval(model, device, loader, evaluator)[metric] for loader in eval_loaders
        ]
        print(f"Train: {100 * train_perf:.2f}% | Validation: {100 * valid_perf:.2f}% | Test: {100 * test_perf:.2f}%")

        train_curve.append(train_perf)
        valid_curve.append(valid_perf)
        test_curve.append(test_perf)

    # Higher is better for classification metrics, lower otherwise.
    select_best = np.argmax if 'classification' in dataset.task_type else np.argmin
    best_val_epoch = select_best(np.array(valid_curve))

    best_train = train_curve[best_val_epoch]
    best_valid = valid_curve[best_val_epoch]
    best_test = test_curve[best_val_epoch]

    print('------')
    print(f'Final Train:{100 * best_train:.2f}% | '
          f'Valid:{100 * best_valid:.2f}% | '
          f'Test:{100 * best_test:.2f}%')

    # The trailing entry of each curve is the score at the best validation epoch.
    train_curve.append(best_train)
    valid_curve.append(best_valid)
    test_curve.append(best_test)

    return train_curve, valid_curve, test_curve
def run_test():
    """Smoke-run ``pipe`` for one epoch on molhiv with a small fixed config."""
    settings = dict(
        gName='molhiv',
        epochs=1,
        num_workers=4,
        batch_size=64,
        num_g_samples=10,
        dropout=0.5,
        wd=5e-6,
    )
    print(settings)
    pipe(**settings)
if __name__ == "__main__":
    # Script entry point: run the single-epoch smoke test.
    run_test()
    # NOTE(review): indentation was lost in this copy of the file; this print
    # is assumed to sit inside the guard (looks like a leftover "finished"
    # marker) -- confirm against the original.
    print(1)