main.py
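"""Train and evaluate HAN (Heterogeneous Graph Attention Network).

Run ``python main.py`` to train on a list of pre-built per-metapath graphs, or
``python main.py --hetero`` to work directly on DGL's own heterogeneous dataset.
Hyperparameters that main() reads but argparse does not expose (hidden_units,
num_heads, dropout, lr, weight_decay, num_epochs, patience, device) are
presumably filled in by utils.setup.
"""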
import torch
from sklearn.metrics import f1_score
from utils import load_data, EarlyStopping
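

# Turn logits into hard class predictions and report accuracy, micro-F1 and macro-F1.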
def score(logits, labels):
    _, indices = torch.max(logits, dim=1)
    prediction = indices.long().cpu().numpy()
    labels = labels.cpu().numpy()
    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average='micro')
    macro_f1 = f1_score(labels, prediction, average='macro')

    return accuracy, micro_f1, macro_f1
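

# Run the model on `g`/`features` without tracking gradients and score the
# nodes selected by `mask`.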
def evaluate(model, g, features, labels, mask, loss_func):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
    loss = loss_func(logits[mask], labels[mask])
    accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])

    return loss, accuracy, micro_f1, macro_f1
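

# End-to-end pipeline: load the dataset, build the requested HAN variant, train
# with early stopping on validation loss/accuracy, then report test metrics.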
def main(args):
    # If args['hetero'] is True, g would be a heterogeneous graph.
    # Otherwise, it will be a list of homogeneous graphs.
    g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
    val_mask, test_mask = load_data(args['dataset'])

    if hasattr(torch, 'BoolTensor'):
        train_mask = train_mask.bool()
        val_mask = val_mask.bool()
        test_mask = test_mask.bool()

    features = features.to(args['device'])
    labels = labels.to(args['device'])
    train_mask = train_mask.to(args['device'])
    val_mask = val_mask.to(args['device'])
    test_mask = test_mask.to(args['device'])
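
    # Both HAN variants are called as model(g, features); the heterogeneous one
    # consumes the raw heterogeneous graph plus metapaths (presumably 'pa'/'ap'
    # = paper-author/author-paper and 'pf'/'fp' = paper-field/field-paper edge
    # types), while the homogeneous one takes a list with one pre-built graph
    # per metapath.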
    if args['hetero']:
        from model_hetero import HAN
        model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']],
                    in_size=features.shape[1],
                    hidden_size=args['hidden_units'],
                    out_size=num_classes,
                    num_heads=args['num_heads'],
                    dropout=args['dropout']).to(args['device'])
        g = g.to(args['device'])
    else:
        from model import HAN
        model = HAN(num_meta_paths=len(g),
                    in_size=features.shape[1],
                    hidden_size=args['hidden_units'],
                    out_size=num_classes,
                    num_heads=args['num_heads'],
                    dropout=args['dropout']).to(args['device'])
        g = [graph.to(args['device']) for graph in g]
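
    # EarlyStopping (from utils) is assumed to checkpoint the model whenever
    # stopper.step() sees an improvement and to return True once `patience`
    # epochs pass without one; stopper.load_checkpoint() restores that state
    # before the final test evaluation.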
    stopper = EarlyStopping(patience=args['patience'])
    loss_fcn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
                                 weight_decay=args['weight_decay'])

    for epoch in range(args['num_epochs']):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)
        early_stop = stopper.step(val_loss.data.item(), val_acc, model)

        print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
              'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
                  epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.item(), val_micro_f1, val_macro_f1))

        if early_stop:
            break

    stopper.load_checkpoint(model)
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)
    print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
        test_loss.item(), test_micro_f1, test_macro_f1))
if __name__ == '__main__':
    import argparse
    from utils import setup

    parser = argparse.ArgumentParser('HAN')
    parser.add_argument('-s', '--seed', type=int, default=1,
                        help='Random seed')
    parser.add_argument('-ld', '--log-dir', type=str, default='results',
                        help='Dir for saving training results')
    parser.add_argument('--hetero', action='store_true',
                        help='Use metapath coalescing with DGL\'s own dataset')
    args = parser.parse_args().__dict__

    args = setup(args)

    main(args)