# train.py
import torch
import torch.nn.functional as F
from torch.optim import Adam
from sklearn.metrics import roc_auc_score, average_precision_score
import scipy.sparse as sp
import numpy as np
import os
import time
from input_data import load_data
from preprocessing import *
import args
import model
# Train on CPU (hide GPU) due to memory constraints
os.environ['CUDA_VISIBLE_DEVICES'] = ""


def _to_torch_sparse(sparse_tuple):
    """Convert a ``(coords, values, shape)`` triple into a sparse COO tensor.

    Args:
        sparse_tuple: indexable whose item 0 is transposed to form the index
            matrix, item 1 holds the values and item 2 the dense shape.
            (presumably the output of ``sparse_to_tuple`` -- verify against
            the ``preprocessing`` module.)

    Returns:
        A sparse tensor with int64 indices and float32 values.
    """
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, size).
    return torch.sparse_coo_tensor(
        torch.LongTensor(sparse_tuple[0].T),
        torch.FloatTensor(sparse_tuple[1]),
        torch.Size(sparse_tuple[2]),
    )


adj, features = load_data(args.dataset)

# Store original adjacency matrix (without diagonal entries) for later
adj_orig = adj
adj_orig = adj_orig - sp.dia_matrix((adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
adj_orig.eliminate_zeros()

# Split the edges; from here on `adj` refers to the training adjacency only.
adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
adj = adj_train

# Some preprocessing
adj_norm = preprocess_graph(adj)
num_nodes = adj.shape[0]
features = sparse_to_tuple(features.tocoo())
num_features = features[2][1]
features_nonzero = features[1].shape[0]

# Loss weighting derived from the ratio of zero to non-zero entries of the
# training adjacency matrix.
_num_entries = adj.shape[0] * adj.shape[0]
_adj_total = adj.sum()
pos_weight = float(_num_entries - _adj_total) / _adj_total
norm = _num_entries / float((_num_entries - _adj_total) * 2)

# Reconstruction target: training adjacency plus the identity matrix.
adj_label = adj_train + sp.eye(adj_train.shape[0])
adj_label = sparse_to_tuple(adj_label)

adj_norm = _to_torch_sparse(adj_norm)
adj_label = _to_torch_sparse(adj_label)
features = _to_torch_sparse(features)

# Per-entry loss weights: `pos_weight` where the flattened label equals 1,
# and 1 everywhere else.
weight_mask = adj_label.to_dense().view(-1) == 1
weight_tensor = torch.ones(weight_mask.size(0))
weight_tensor[weight_mask] = pos_weight

# init model and optimizer
# NOTE: this rebinds the name `model` from the imported module to the
# instantiated network selected by `args.model`.
model = getattr(model, args.model)(adj_norm)
optimizer = Adam(model.parameters(), lr=args.learning_rate)
def get_scores(edges_pos, edges_neg, adj_rec):
    """Score a reconstructed adjacency matrix on held-out edges.

    Every pair in ``edges_pos`` is treated as a positive example (label 1) and
    every pair in ``edges_neg`` as a negative example (label 0). The score of
    a pair ``e`` is ``sigmoid(adj_rec[e[0], e[1]])``.

    Args:
        edges_pos: iterable of index pairs ``(i, j)`` labelled 1.
        edges_neg: iterable of index pairs ``(i, j)`` labelled 0.
        adj_rec: 2-D tensor indexed as ``adj_rec[i, j]``; each element must
            support ``.item()``.

    Returns:
        Tuple ``(roc_score, ap_score)`` as computed by
        ``roc_auc_score`` and ``average_precision_score``.
    """
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    # NOTE(review): the training loop passes the same tensor straight to
    # binary_cross_entropy, which suggests it already holds probabilities.
    # The extra sigmoid is monotonic, so it should not affect these
    # rank-based metrics -- confirm before removing it.

    # Use .item() for both edge sets so the scores are plain Python floats
    # (the negative branch previously passed a tensor via `.data`).
    preds = [sigmoid(adj_rec[e[0], e[1]].item()) for e in edges_pos]
    preds_neg = [sigmoid(adj_rec[e[0], e[1]].item()) for e in edges_neg]

    preds_all = np.hstack([preds, preds_neg])
    labels_all = np.hstack([np.ones(len(preds)), np.zeros(len(preds_neg))])

    roc_score = roc_auc_score(labels_all, preds_all)
    ap_score = average_precision_score(labels_all, preds_all)
    return roc_score, ap_score
def get_acc(adj_rec, adj_label):
    """Return the fraction of entries where the thresholded prediction matches.

    Args:
        adj_rec: tensor of predicted scores; an entry counts as 1 when it is
            strictly greater than 0.5, otherwise 0.
        adj_label: tensor supporting ``to_dense()`` whose flattened entries,
            cast to integers, are the reference labels.

    Returns:
        A 0-dim float tensor: number of matching entries divided by the
        total number of entries.
    """
    target = adj_label.to_dense().view(-1).to(torch.long)
    predicted = adj_rec.gt(0.5).view(-1).to(torch.long)
    num_correct = torch.eq(predicted, target).sum()
    return num_correct.float() / target.size(0)
# train model
# The dense, flattened reconstruction target does not change between epochs,
# so build it once instead of re-materialising it on every iteration.
adj_label_dense = adj_label.to_dense().view(-1)

for epoch in range(args.num_epoch):
    t = time.time()

    A_pred = model(features)
    optimizer.zero_grad()

    # Weighted reconstruction loss over every entry of the flattened matrix.
    log_lik = norm * F.binary_cross_entropy(A_pred.view(-1), adj_label_dense, weight=weight_tensor)
    loss = log_lik
    if args.model == 'VGAE':
        # KL term built from the model's `mean` and `logstd` attributes.
        kl_divergence = 0.5 / A_pred.size(0) * (1 + 2*model.logstd - model.mean**2 - torch.exp(model.logstd)**2).sum(1).mean()
        # Out-of-place subtraction so `log_lik` is not mutated via aliasing.
        loss = log_lik - kl_divergence

    loss.backward()
    optimizer.step()

    # Metrics use this epoch's forward pass (computed before the step above).
    train_acc = get_acc(A_pred, adj_label)
    val_roc, val_ap = get_scores(val_edges, val_edges_false, A_pred)
    print("Epoch:", '%04d' % (epoch + 1), "train_loss=", "{:.5f}".format(loss.item()),
          "train_acc=", "{:.5f}".format(train_acc), "val_roc=", "{:.5f}".format(val_roc),
          "val_ap=", "{:.5f}".format(val_ap),
          "time=", "{:.5f}".format(time.time() - t))

# NOTE: `A_pred` is the output of the last training forward pass, i.e. it was
# computed before the final optimizer step; it is undefined if
# `args.num_epoch` is 0.
test_roc, test_ap = get_scores(test_edges, test_edges_false, A_pred)
print("End of training!", "test_roc=", "{:.5f}".format(test_roc),
      "test_ap=", "{:.5f}".format(test_ap))