forked from JiaruiFeng/N2GNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_utils.py
247 lines (210 loc) · 10.5 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
Utils file for training.
"""
import argparse
import os
import shutil
import time
import torch
import data_utils
import yaml
from torch_geometric.data import Data, Dataset
from sklearn.model_selection import StratifiedKFold
from typing import Callable, Tuple
def args_setup() -> argparse.ArgumentParser:
    r"""Build the argument parser shared by the training and testing scripts.

    Only the common, training, data and model options are registered here.
    Other helpers in this file also read ``dataset_name`` (and optionally
    ``task``) from the parsed arguments; those are not added by this function.

    Returns:
        argparse.ArgumentParser: The configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser('arguments for training and testing')

    # common args
    parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
    parser.add_argument('--config_file', type=str, default=None,
                        help='Additional configuration file for different dataset and models.')
    parser.add_argument('--seed', type=int, default=234, help='Random seed for reproducibility.')

    # training args
    parser.add_argument('--drop_prob', type=float, default=0.0,
                        help='Probability of zeroing an activation in dropout models.')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size per GPU.')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-6, help='Minimum learning rate.')
    parser.add_argument('--l2_wd', type=float, default=0., help='L2 weight decay.')
    parser.add_argument('--num_epochs', type=int, default=500, help='Number of epochs.')
    parser.add_argument('--test_eval_interval', type=int, default=10,
                        help='Interval between validation on test dataset.')
    parser.add_argument('--factor', type=float, default=0.5,
                        help='Factor in the ReduceLROnPlateau learning rate scheduler.')
    parser.add_argument('--patience', type=int, default=20,
                        help='Patience in the ReduceLROnPlateau learning rate scheduler.')
    parser.add_argument("--offline", action="store_true", help="If true, save the wandb log offline. "
                                                               "Mainly use for debug.")

    # data args
    # Adjacent string literals are concatenated verbatim, so every fragment
    # that is followed by another one ends with an explicit space.
    parser.add_argument('--policy', default="dense_ego", choices=("dense_ego",
                                                                  "dense_noego",
                                                                  "sparse_ego",
                                                                  "sparse_noego"),
                        help="Policy of data generation in N2GNN. If dense, keep tuple that don't have any "
                             "aggregation. If ego, further restrict all tuple must have distance less than or "
                             "equal to num_hops.")
    parser.add_argument('--message_pool', default="plain", choices=("plain", "hierarchical"),
                        help="message pooling way in N2GNN, if set to plain, pooling all edges together. If set to "
                             "hierarchical, compute index during preprocessing for hierarchical pooling, must be "
                             "used with corresponding gnn convolutional layer.")
    parser.add_argument('--reprocess', action="store_true", help='Whether to reprocess the dataset')

    # model args
    # The hierarchical encoder names are accepted as well: update_args() requires
    # one of them whenever --message_pool is "hierarchical".
    parser.add_argument('--gnn_name', type=str, default="GINEM",
                        choices=("GINEC", "GINEM", "GINECH", "GINEMH"),
                        help='Name of base gnn encoder.')
    parser.add_argument('--model_name', type=str, default="N2GNN+",
                        choices=("N2GNN+", "N2GNN"), help='Name of GNN model.')
    parser.add_argument('--tuple_size', type=int, default=5, help="Length of tuple in tuple aggregation.")
    parser.add_argument('--num_hops', type=int, default=3, help="Number of hop in ego-net selection.")
    parser.add_argument("--hidden_channels", type=int, default=96, help="Hidden size of the model.")
    parser.add_argument("--inner_channels", type=int, default=32,
                        help="Inner channel size when doing tuple aggregation. Mainly used for reduce memory cost "
                             "during the aggregation and gradients saving.")
    parser.add_argument('--wo_node_feature', action='store_true',
                        help='If true, remove node feature from model.')
    parser.add_argument('--wo_edge_feature', action='store_true',
                        help='If true, remove edge feature from model.')
    parser.add_argument("--edge_dim", type=int, default=0, help="Number of edge type.")
    parser.add_argument("--num_layers", type=int, default=6, help="Number of layer for GNN.")
    parser.add_argument("--JK", type=str, default="last",
                        choices=("sum", "max", "mean", "attention", "last", "concat"),
                        help="Jumping knowledge method.")
    parser.add_argument("--residual", action="store_true",
                        help="If true, use residual connection between each layer.")
    parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN.")
    parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon is trainable.")
    parser.add_argument("--pooling_method", type=str, default="mean", choices=("mean", "sum", "attention"),
                        help="Pooling method in graph level tasks.")
    parser.add_argument('--norm_type', type=str, default="Batch",
                        choices=("Batch", "Layer", "Instance", "GraphSize", "Pair", "None"),
                        help="Normalization method in model.")
    parser.add_argument('--add_rd', action="store_true",
                        help="If true, additionally add resistance distance into model.")
    return parser
def get_exp_name(args: argparse.Namespace, add_task=True) -> str:
    """Build a timestamped experiment name from the parsed arguments.

    The name has the form
    ``[task_]dataset_gnn_model_layers_hidden_policy_hops[_residual][_rd]-YYYYmmddHHMMSS``.

    Args:
        args (Namespace): Parsed arguments. Must provide ``dataset_name``,
            ``gnn_name``, ``model_name``, ``num_layers``, ``hidden_channels``,
            ``policy``, ``num_hops``, ``residual`` and ``add_rd``; ``task`` is
            optional.
        add_task (bool): If true and ``args`` has a ``task`` attribute, the
            task is used as the first component of the name.

    Returns:
        str: The experiment name, suffixed with the current local time.
    """
    parts = []
    if add_task and "task" in args:
        parts.append(args.task)
    parts.extend([
        args.dataset_name,
        args.gnn_name,
        args.model_name,
        args.num_layers,
        args.hidden_channels,
        args.policy,
        args.num_hops,
    ])
    if args.residual:
        parts.append("residual")
    if args.add_rd:
        parts.append("rd")
    # Every component is converted to str, so values that arrive as non-strings
    # (e.g. a numeric dataset name loaded from a config file) cannot break the join.
    exp_name = "_".join(map(str, parts))
    return exp_name + f"-{time.strftime('%Y%m%d%H%M%S')}"
def update_args(args: argparse.Namespace, add_task=True) -> argparse.Namespace:
    r"""Merge the optional YAML config file into ``args`` and finalize them.

    Steps performed:
      1. If ``args.config_file`` is set, load it and copy every top-level key
         onto ``args``. Scalar values overwrite the attribute. List values are
         appended to an existing list attribute; if the attribute is missing or
         is not a list, it is set to a copy of the configured list.
      2. Set ``args.exp_name`` via ``get_exp_name``.
      3. Create ``args.save_dir`` if it does not exist.
      4. Check that ``gnn_name`` matches ``message_pool``.

    Args:
        args (Namespace): Parsed arguments; modified in place.
        add_task (bool): Forwarded to ``get_exp_name``.

    Returns:
        Namespace: The same ``args`` object, updated.

    Raises:
        AssertionError: If ``gnn_name`` is not valid for the chosen ``message_pool``.
    """
    if args.config_file is not None:
        with open(args.config_file, encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as "no overrides".
            cfg = yaml.safe_load(f) or {}
        for key, value in cfg.items():
            if isinstance(value, list):
                current = getattr(args, key, None)
                if isinstance(current, list):
                    # Build a new list instead of appending in place so that a
                    # default list shared with the parser is never mutated.
                    setattr(args, key, [*current, *value])
                else:
                    # Attribute missing (or not a list): previously the values
                    # were appended to a throwaway list and silently lost.
                    setattr(args, key, list(value))
            else:
                setattr(args, key, value)

    args.exp_name = get_exp_name(args, add_task)

    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(args.save_dir, exist_ok=True)

    if args.message_pool == "plain":
        assert args.gnn_name in ["GINEC", "GINEM"], \
            f"message_pool='plain' requires gnn_name in ['GINEC', 'GINEM'], got {args.gnn_name!r}"
    elif args.message_pool == "hierarchical":
        assert args.gnn_name in ["GINECH", "GINEMH"], \
            f"message_pool='hierarchical' requires gnn_name in ['GINECH', 'GINEMH'], got {args.gnn_name!r}"
    return args
def data_setup(args: argparse.ArgumentParser) -> Tuple[str, Callable, list]:
    r"""Derive the dataset cache path and preprocessing transform for a run.

    Args:
        args (ArgumentParser): Arguments dict from argparser. Reads
            ``model_name``, ``dataset_name``, ``num_hops``, ``policy``,
            ``message_pool``, ``add_rd`` and ``reprocess``.

    Returns:
        Tuple[str, Callable, list]: The dataset root path, the pre-transform
        returned by ``data_utils.get_data_transform`` and the ``follow_batch``
        list (always empty here).
    """
    # First five characters of the model name ("N2GNN" for both registered choices).
    update_style = args.model_name[:5]
    path_parts = [
        f"data/{args.dataset_name}",
        update_style,
        str(args.num_hops),
        args.policy,
        args.message_pool,
    ]

    # Decode the policy string into the two flags the transform expects.
    use_sparse = args.policy in ("sparse_ego", "sparse_noego")
    use_ego_net = args.policy not in ("sparse_noego", "dense_noego")
    use_hierarchical = args.message_pool == "hierarchical"
    use_rd = bool(args.add_rd)

    pre_transform = data_utils.get_data_transform(update_style,
                                                  args.num_hops,
                                                  use_sparse,
                                                  use_ego_net,
                                                  use_hierarchical,
                                                  use_rd)

    # NOTE(review): add_rd is not part of the path, so runs with and without it
    # resolve to the same directory - confirm this is intended.
    path = "_".join(path_parts)
    processed_dir = path + "/processed"
    # Drop the cached preprocessing output when a reprocess was requested.
    if os.path.exists(processed_dir) and args.reprocess:
        shutil.rmtree(processed_dir)

    follow_batch = []
    return path, pre_transform, follow_batch
class PostTransform(object):
    r"""Callable applied to each ``Data`` object after loading.

    Args:
        wo_node_feature (bool): If true, replace ``data.x`` with zeros of the same shape.
        wo_edge_feature (bool): If true, set ``data.edge_attr`` to ``None``.
        task (int): If given, keep only column ``task`` of ``data.y``.
    """

    def __init__(self,
                 wo_node_feature: bool,
                 wo_edge_feature: bool,
                 task: int = None):
        self.task = task
        self.wo_edge_feature = wo_edge_feature
        self.wo_node_feature = wo_node_feature

    def __call__(self,
                 data: Data) -> Data:
        # Graphs without node features get a single all-zero integer feature per node.
        has_node_feature = "x" in data
        if not has_node_feature:
            data.x = torch.zeros([data.num_nodes, 1], dtype=torch.long)

        if self.wo_edge_feature:
            data.edge_attr = None

        if self.wo_node_feature:
            # Keep shape and dtype, discard the content.
            data.x = torch.zeros_like(data.x)

        task = self.task
        if task is not None:
            # Select the requested target column from the target matrix.
            data.y = data.y[:, task]

        return data
def k_fold(dataset: Dataset,
           folds: int,
           seed: int) -> Tuple[list, list, list]:
    r"""Split a dataset into stratified train/test/validation index sets.

    For split ``i`` the test set is the ``i``-th stratified fold, the
    validation set is the test fold of split ``i - 1`` (wrapping around), and
    the training set is every remaining index.

    Args:
        dataset (Dataset): The dataset to be split.
        folds (int): Number of folds.
        seed (int): Random seed.

    Returns:
        Tuple[list, list, list]: ``(train_indices, test_indices, val_indices)``,
        each a list with one index tensor per fold.
    """
    splitter = StratifiedKFold(folds, shuffle=True, random_state=seed)
    num_samples = len(dataset)

    # Only the held-out part of each split is needed; labels drive the stratification.
    test_indices = [
        torch.from_numpy(held_out).long()
        for _, held_out in splitter.split(torch.zeros(num_samples),
                                          dataset.data.y[dataset.indices()])
    ]

    # Rotate by one: validation fold i is test fold i - 1.
    val_indices = test_indices[-1:] + test_indices[:-1]

    train_indices = []
    for test_idx, val_idx in zip(test_indices, val_indices):
        keep = torch.ones(num_samples, dtype=torch.bool)
        keep[test_idx] = False
        keep[val_idx] = False
        train_indices.append(keep.nonzero().view(-1))

    return train_indices, test_indices, val_indices
def get_seed(seed=234) -> int:
    r"""Return a time-dependent random seed.

    The current time in milliseconds is truncated to its low 32 bits, the four
    bytes are reversed, and the result is added to the base seed.

    Args:
        seed (int): base seed.
    """
    millis = int(time.time() * 1000.0)
    # Writing the low 32 bits big-endian and reading them back little-endian
    # reverses the byte order, so the fastest-changing byte becomes the most significant.
    low_word = millis & 0xffffffff
    swapped = int.from_bytes(low_word.to_bytes(4, "big"), "little")
    return seed + swapped