From ac20c1a4632995f5f31d37060f4a990e273b140b Mon Sep 17 00:00:00 2001
From: yang_starry_sky
Date: Sat, 20 Aug 2022 17:13:49 +0800
Subject: [PATCH 01/25] Create ogb_node.py

support ogbn dataset
---
 gammagl/datasets/ogb_node.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)
 create mode 100644 gammagl/datasets/ogb_node.py

diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py
new file mode 100644
index 00000000..a0aeed70
--- /dev/null
+++ b/gammagl/datasets/ogb_node.py
@@ -0,0 +1,16 @@
+from ogb.nodeproppred import NodePropPredDataset
+from gammagl.data import Graph
+def ogbn_dataset(name, root='dataset', meta_dict=None):
+    '''
+    - name (str): name of the dataset
+    - root (str): root directory to store the dataset folder
+    - meta_dict: dictionary that stores all the meta-information about data. Default is None,
+        but when something is passed, it uses its information. Useful for debugging for external contributors.
+    '''
+    dataset = NodePropPredDataset(name=name, root=root, meta_dict=meta_dict)
+    dataset = dataset[0]
+    data = Graph(edge_index=dataset[0]['edge_index'], x=dataset[0]['node_feat'], y=dataset[1])
+    data.num_nodes = dataset[0]['num_nodes']
+    data.edge_attr = dataset[0]['edge_feat']
+    data.tensor()
+    return data
\ No newline at end of file

From d4ef8220af874536d279677c481ba5595a0f8d2f Mon Sep 17 00:00:00 2001
From: yang_starry_sky
Date: Thu, 25 Aug 2022 14:59:32 +0800
Subject: [PATCH 02/25] support ogb node dataset

---
 gammagl/datasets/master.csv  |  16 +
 gammagl/datasets/ogb_node.py | 218 ++++++++++++++++++++++++++++++++---
 gammagl/io/read_ogb_raw.py   | 653 +++++++++++++++++++++++++++++++++++
 gammagl/utils/ogb_url.py     |  91 +++++
 4 files changed, 963 insertions(+), 15 deletions(-)
 create mode 100644 gammagl/datasets/master.csv
 create mode 100644 gammagl/io/read_ogb_raw.py
 create mode 100644 gammagl/utils/ogb_url.py

diff --git a/gammagl/datasets/master.csv b/gammagl/datasets/master.csv
new file mode 100644
index 00000000..026ded02
--- /dev/null
+++ b/gammagl/datasets/master.csv
@@ -0,0 +1,16 @@
+,ogbn-proteins,ogbn-products,ogbn-arxiv,ogbn-mag,ogbn-papers100M
+num tasks,112,1,1,1,1
+num classes,2,47,40,349,172
+eval metric,rocauc,acc,acc,acc,acc
+task type,binary classification,multiclass classification,multiclass classification,multiclass classification,multiclass classification
+download_name,proteins,products,arxiv,mag,papers100M-bin
+version,1,1,1,2,1
+url,http://snap.stanford.edu/ogb/data/nodeproppred/proteins.zip,http://snap.stanford.edu/ogb/data/nodeproppred/products.zip,http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip,http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip,http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip
+add_inverse_edge,True,True,False,False,False
+has_node_attr,False,True,True,True,True
+has_edge_attr,True,False,False,False,False
+split,species,sales_ranking,time,time,time
+additional node files,node_species,None,node_year,node_year,node_year
+additional edge files,None,None,None,edge_reltype,None
+is hetero,False,False,False,True,False
+binary,False,False,False,False,True
diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py
index a0aeed70..b3c3f290 100644
--- a/gammagl/datasets/ogb_node.py
+++ b/gammagl/datasets/ogb_node.py
@@ -1,16 +1,204 @@
-from ogb.nodeproppred import NodePropPredDataset
 from gammagl.data import Graph
-def ogbn_dataset(name, root='dataset', meta_dict=None):
-    '''
-    - name (str): name of the dataset
-    - root (str): root directory to store the dataset folder
-    - meta_dict: dictionary that stores all the meta-information about data. Default is None,
-        but when something is passed, it uses its information. Useful for debugging for external contributors.
-    '''
-    dataset = NodePropPredDataset(name=name, root=root, meta_dict=meta_dict)
-    dataset = dataset[0]
-    data = Graph(edge_index=dataset[0]['edge_index'], x=dataset[0]['node_feat'], y=dataset[1])
-    data.num_nodes = dataset[0]['num_nodes']
-    data.edge_attr = dataset[0]['edge_feat']
-    data.tensor()
-    return data
\ No newline at end of file
+import pandas as pd
+import shutil, os
+import os.path as osp
+from gammagl.utils.ogb_url import decide_download, download_url, extract_zip
+from gammagl.io.read_ogb_raw import read_csv_graph_raw, read_csv_heterograph_raw, \
+    read_node_label_hetero, read_nodesplitidx_split_hetero, \
+    read_binary_graph_raw, read_binary_heterograph_raw
+
+import torch
+import numpy as np
+
+class OgbNodeDataset(object):
+    def __init__(self, name, root='dataset', meta_dict=None):
+        '''
+        - name (str): name of the dataset
+        - root (str): root directory to store the dataset folder
+        - meta_dict: dictionary that stores all the meta-information about data. Default is None,
+            but when something is passed, it uses its information. Useful for debugging for external contributers.
+        '''
+
+        self.name = name  ## original name, e.g., ogbn-proteins
+
+        if meta_dict is None:
+            self.dir_name = '_'.join(name.split('-'))  ## replace hyphen with underline, e.g., ogbn_proteins
+            self.original_root = root
+            self.root = osp.join(root, self.dir_name)
+
+            master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'master.csv'), index_col=0)
+            if not self.name in master:
+                error_mssg = 'Invalid dataset name {}.\n'.format(self.name)
+                error_mssg += 'Available datasets are as follows:\n'
+                error_mssg += '\n'.join(master.keys())
+                raise ValueError(error_mssg)
+            self.meta_info = master[self.name]
+
+        else:
+            self.dir_name = meta_dict['dir_path']
+            self.original_root = ''
+            self.root = meta_dict['dir_path']
+            self.meta_info = meta_dict
+
+        # check version
+        # First check whether the dataset has been already downloaded or not.
+        # If so, check whether the dataset version is the newest or not.
+        # If the dataset is not the newest version, notify this to the user.
+        if osp.isdir(self.root) and (
+                not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))):
+            print(self.name + ' has been updated.')
+            if input('Will you update the dataset now? 
(y/N)\n').lower() == 'y': + shutil.rmtree(self.root) + + self.download_name = self.meta_info['download_name'] ## name of downloaded file, e.g., tox21 + + self.num_tasks = int(self.meta_info['num tasks']) + self.task_type = self.meta_info['task type'] + self.eval_metric = self.meta_info['eval metric'] + self.num_classes = int(self.meta_info['num classes']) + self.is_hetero = self.meta_info['is hetero'] == 'True' + self.binary = self.meta_info['binary'] == 'True' + + super(OgbNodeDataset, self).__init__() + + self.pre_process() + + def pre_process(self): + processed_dir = osp.join(self.root, 'processed') + pre_processed_file_path = osp.join(processed_dir, 'data_processed') + + if osp.exists(pre_processed_file_path): + loaded_dict = torch.load(pre_processed_file_path) + self.graph, self.labels = loaded_dict['graph'], loaded_dict['labels'] + self.data = Graph(edge_index=self.graph['edge_index'], x=self.graph['node_feat'], y=self.labels) + self.data.num_nodes = self.graph['num_nodes'] + self.data.edge_attr = self.graph['edge_feat'] + self.data.tensor() + + else: + ### check download + if self.binary: + # npz format + has_necessary_file_simple = osp.exists(osp.join(self.root, 'raw', 'data.npz')) and (not self.is_hetero) + has_necessary_file_hetero = osp.exists( + osp.join(self.root, 'raw', 'edge_index_dict.npz')) and self.is_hetero + else: + # csv file + has_necessary_file_simple = osp.exists(osp.join(self.root, 'raw', 'edge.csv.gz')) and ( + not self.is_hetero) + has_necessary_file_hetero = osp.exists( + osp.join(self.root, 'raw', 'triplet-type-list.csv.gz')) and self.is_hetero + + has_necessary_file = has_necessary_file_simple or has_necessary_file_hetero + + if not has_necessary_file: + url = self.meta_info['url'] + if decide_download(url): + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + # delete folder if there exists + try: + shutil.rmtree(self.root) + except: + pass + shutil.move(osp.join(self.original_root, self.download_name), self.root) + else: + print('Stop download.') + exit(-1) + + raw_dir = osp.join(self.root, 'raw') + + ### pre-process and save + add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' + + if self.meta_info['additional node files'] == 'None': + additional_node_files = [] + else: + additional_node_files = self.meta_info['additional node files'].split(',') + + if self.meta_info['additional edge files'] == 'None': + additional_edge_files = [] + else: + additional_edge_files = self.meta_info['additional edge files'].split(',') + + if self.is_hetero: + if self.binary: + self.graph = read_binary_heterograph_raw(raw_dir, add_inverse_edge=add_inverse_edge)[ + 0] # only a single graph + + tmp = np.load(osp.join(raw_dir, 'node-label.npz')) + self.labels = {} + for key in list(tmp.keys()): + self.labels[key] = tmp[key] + del tmp + else: + self.graph = read_csv_heterograph_raw(raw_dir, add_inverse_edge=add_inverse_edge, + additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files)[ + 0] # only a single graph + self.labels = read_node_label_hetero(raw_dir) + + else: + if self.binary: + self.graph = read_binary_graph_raw(raw_dir, add_inverse_edge=add_inverse_edge)[ + 0] # only a single graph + self.labels = np.load(osp.join(raw_dir, 'node-label.npz'))['node_label'] + else: + self.graph = read_csv_graph_raw(raw_dir, add_inverse_edge=add_inverse_edge, + additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files)[ + 0] # only a single graph + self.labels 
= pd.read_csv(osp.join(raw_dir, 'node-label.csv.gz'), compression='gzip', + header=None).values + + print('Saving...') + self.data = Graph(edge_index=self.graph['edge_index'], x=self.graph['node_feat'], y=self.labels) + self.data.num_nodes = self.graph['num_nodes'] + self.data.edge_attr = self.graph['edge_feat'] + self.data.tensor() + torch.save(self.data, pre_processed_file_path, pickle_protocol=4) + + def get_idx_split(self, split_type=None): + if split_type is None: + split_type = self.meta_info['split'] + + path = osp.join(self.root, 'split', split_type) + + # short-cut if split_dict.pt exists + if os.path.isfile(os.path.join(path, 'split_dict.pt')): + return torch.load(os.path.join(path, 'split_dict.pt')) + + if self.is_hetero: + train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(path) + for nodetype in train_idx_dict.keys(): + train_idx_dict[nodetype] = train_idx_dict[nodetype] + valid_idx_dict[nodetype] = valid_idx_dict[nodetype] + test_idx_dict[nodetype] = test_idx_dict[nodetype] + + return {'train': train_idx_dict, 'valid': valid_idx_dict, 'test': test_idx_dict} + + else: + train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0] + valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0] + test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0] + + return {'train': train_idx, 'valid': valid_idx, 'test': test_idx} + + def __getitem__(self, idx): + assert idx == 0, 'This dataset has only one graph' + return self.data + + def __len__(self): + return 1 + + def __repr__(self): # pragma: no cover + return '{}({})'.format(self.__class__.__name__, len(self)) + + +if __name__ == '__main__': + dataset = OgbNodeDataset(name='ogbn-mag') + print(dataset.num_classes) + split_index = dataset.get_idx_split() + print(dataset[0]) + print(split_index) \ No newline at end of file diff --git a/gammagl/io/read_ogb_raw.py b/gammagl/io/read_ogb_raw.py new file mode 100644 index 00000000..b76e099b --- /dev/null +++ b/gammagl/io/read_ogb_raw.py @@ -0,0 +1,653 @@ +import pandas as pd +import os.path as osp +import os +import numpy as np +from gammagl.utils.ogb_url import decide_download, download_url, extract_zip +from tqdm import tqdm + +### reading raw files from a directory. +### for homogeneous graph +def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = []): + ''' + raw_dir: path to the raw directory + add_inverse_edge (bool): whether to add inverse edge or not + + return: graph_list, which is a list of graphs. + Each graph is a dictionary, containing edge_index, edge_feat, node_feat, and num_nodes + edge_feat and node_feat are optional: if a graph does not contain it, we will have None. + + additional_node_files and additional_edge_files must be in the raw directory. 
+ - The name should be {additional_node_file, additional_edge_file}.csv.gz + - The length should be num_nodes or num_edges + + additional_node_files must start from 'node_' + additional_edge_files must start from 'edge_' + + + ''' + + print('Loading necessary files...') + print('This might take a while.') + # loading necessary files + try: + edge = pd.read_csv(osp.join(raw_dir, 'edge.csv.gz'), compression='gzip', header = None).values.T.astype(np.int64) # (2, num_edge) numpy array + num_node_list = pd.read_csv(osp.join(raw_dir, 'num-node-list.csv.gz'), compression='gzip', header = None).astype(np.int64)[0].tolist() # (num_graph, ) python list + num_edge_list = pd.read_csv(osp.join(raw_dir, 'num-edge-list.csv.gz'), compression='gzip', header = None).astype(np.int64)[0].tolist() # (num_edge, ) python list + + except FileNotFoundError: + raise RuntimeError('No necessary file') + + try: + node_feat = pd.read_csv(osp.join(raw_dir, 'node-feat.csv.gz'), compression='gzip', header = None).values + if 'int' in str(node_feat.dtype): + node_feat = node_feat.astype(np.int64) + else: + # float + node_feat = node_feat.astype(np.float32) + except FileNotFoundError: + node_feat = None + + try: + edge_feat = pd.read_csv(osp.join(raw_dir, 'edge-feat.csv.gz'), compression='gzip', header = None).values + if 'int' in str(edge_feat.dtype): + edge_feat = edge_feat.astype(np.int64) + else: + #float + edge_feat = edge_feat.astype(np.float32) + + except FileNotFoundError: + edge_feat = None + + + additional_node_info = {} + for additional_file in additional_node_files: + assert(additional_file[:5] == 'node_') + + # hack for ogbn-proteins + if additional_file == 'node_species' and osp.exists(osp.join(raw_dir, 'species.csv.gz')): + os.rename(osp.join(raw_dir, 'species.csv.gz'), osp.join(raw_dir, 'node_species.csv.gz')) + + temp = pd.read_csv(osp.join(raw_dir, additional_file + '.csv.gz'), compression='gzip', header = None).values + + if 'int' in str(temp.dtype): + additional_node_info[additional_file] = temp.astype(np.int64) + else: + # float + additional_node_info[additional_file] = temp.astype(np.float32) + + additional_edge_info = {} + for additional_file in additional_edge_files: + assert(additional_file[:5] == 'edge_') + temp = pd.read_csv(osp.join(raw_dir, additional_file + '.csv.gz'), compression='gzip', header = None).values + + if 'int' in str(temp.dtype): + additional_edge_info[additional_file] = temp.astype(np.int64) + else: + # float + additional_edge_info[additional_file] = temp.astype(np.float32) + + + graph_list = [] + num_node_accum = 0 + num_edge_accum = 0 + + print('Processing graphs...') + for num_node, num_edge in tqdm(zip(num_node_list, num_edge_list), total=len(num_node_list)): + + graph = dict() + + ### handling edge + if add_inverse_edge: + ### duplicate edge + duplicated_edge = np.repeat(edge[:, num_edge_accum:num_edge_accum+num_edge], 2, axis = 1) + duplicated_edge[0, 1::2] = duplicated_edge[1,0::2] + duplicated_edge[1, 1::2] = duplicated_edge[0,0::2] + + graph['edge_index'] = duplicated_edge + + if edge_feat is not None: + graph['edge_feat'] = np.repeat(edge_feat[num_edge_accum:num_edge_accum+num_edge], 2, axis = 0) + else: + graph['edge_feat'] = None + + for key, value in additional_edge_info.items(): + graph[key] = np.repeat(value[num_edge_accum:num_edge_accum+num_edge], 2, axis = 0) + + else: + graph['edge_index'] = edge[:, num_edge_accum:num_edge_accum+num_edge] + + if edge_feat is not None: + graph['edge_feat'] = edge_feat[num_edge_accum:num_edge_accum+num_edge] + else: + 
graph['edge_feat'] = None + + for key, value in additional_edge_info.items(): + graph[key] = value[num_edge_accum:num_edge_accum+num_edge] + + num_edge_accum += num_edge + + ### handling node + if node_feat is not None: + graph['node_feat'] = node_feat[num_node_accum:num_node_accum+num_node] + else: + graph['node_feat'] = None + + for key, value in additional_node_info.items(): + graph[key] = value[num_node_accum:num_node_accum+num_node] + + + graph['num_nodes'] = num_node + num_node_accum += num_node + + graph_list.append(graph) + + return graph_list + + +### reading raw files from a directory. +### npz ver +### for homogeneous graph +def read_binary_graph_raw(raw_dir, add_inverse_edge = False): + ''' + raw_dir: path to the raw directory + add_inverse_edge (bool): whether to add inverse edge or not + + return: graph_list, which is a list of graphs. + Each graph is a dictionary, containing edge_index, edge_feat, node_feat, and num_nodes + edge_feat and node_feat are optional: if a graph does not contain it, we will have None. + + raw_dir must contain data.npz + - edge_index + - num_nodes_list + - num_edges_list + - node_** (optional, node_feat is the default node features) + - edge_** (optional, edge_feat is the default edge features) + ''' + + if add_inverse_edge: + raise RuntimeError('add_inverse_edge is depreciated in read_binary') + + print('Loading necessary files...') + print('This might take a while.') + data_dict = np.load(osp.join(raw_dir, 'data.npz')) + + edge_index = data_dict['edge_index'] + num_nodes_list = data_dict['num_nodes_list'] + num_edges_list = data_dict['num_edges_list'] + + # storing node and edge features + node_dict = {} + edge_dict = {} + + for key in list(data_dict.keys()): + if key == 'edge_index' or key == 'num_nodes_list' or key == 'num_edges_list': + continue + + if key[:5] == 'node_': + node_dict[key] = data_dict[key] + elif key[:5] == 'edge_': + edge_dict[key] = data_dict[key] + else: + raise RuntimeError(f"Keys in graph object should start from either \'node_\' or \'edge_\', but found \'{key}\'.") + + graph_list = [] + num_nodes_accum = 0 + num_edges_accum = 0 + + print('Processing graphs...') + for num_nodes, num_edges in tqdm(zip(num_nodes_list, num_edges_list), total=len(num_nodes_list)): + + graph = dict() + + graph['edge_index'] = edge_index[:, num_edges_accum:num_edges_accum+num_edges] + + for key, feat in edge_dict.items(): + graph[key] = feat[num_edges_accum:num_edges_accum+num_edges] + + if 'edge_feat' not in graph: + graph['edge_feat'] = None + + for key, feat in node_dict.items(): + graph[key] = feat[num_nodes_accum:num_nodes_accum+num_nodes] + + if 'node_feat' not in graph: + graph['node_feat'] = None + + graph['num_nodes'] = num_nodes + + num_edges_accum += num_edges + num_nodes_accum += num_nodes + + graph_list.append(graph) + + return graph_list + + +### reading raw files from a directory. +### for heterogeneous graph +def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = []): + ''' + raw_dir: path to the raw directory + add_inverse_edge (bool): whether to add inverse edge or not + + return: graph_list, which is a list of heterogeneous graphs. 
+ Each graph is a dictionary, containing the following keys: + - edge_index_dict + edge_index_dict[(head, rel, tail)] = edge_index for (head, rel, tail) + + - edge_feat_dict + edge_feat_dict[(head, rel, tail)] = edge_feat for (head, rel, tail) + + - node_feat_dict + node_feat_dict[nodetype] = node_feat for nodetype + + - num_nodes_dict + num_nodes_dict[nodetype] = num_nodes for nodetype + + * edge_feat_dict and node_feat_dict are optional: if a graph does not contain it, we will simply have None. + + We can also have additional node/edge features. For example, + - edge_reltype_dict + edge_reltype_dict[(head, rel, tail)] = edge_reltype for (head, rel, tail) + + - node_year_dict + node_year_dict[nodetype] = node_year + + ''' + + print('Loading necessary files...') + print('This might take a while.') + + # loading necessary files + try: + num_node_df = pd.read_csv(osp.join(raw_dir, 'num-node-dict.csv.gz'), compression='gzip') + num_node_dict = {nodetype: num_node_df[nodetype].astype(np.int64).tolist() for nodetype in num_node_df.keys()} + nodetype_list = sorted(list(num_node_dict.keys())) + + ## read edge_dict, num_edge_dict + triplet_df = pd.read_csv(osp.join(raw_dir, 'triplet-type-list.csv.gz'), compression='gzip', header = None) + triplet_list = sorted([(head, relation, tail) for head, relation, tail in zip(triplet_df[0].tolist(), triplet_df[1].tolist(), triplet_df[2].tolist())]) + + edge_dict = {} + num_edge_dict = {} + + for triplet in triplet_list: + subdir = osp.join(raw_dir, 'relations', '___'.join(triplet)) + + edge_dict[triplet] = pd.read_csv(osp.join(subdir, 'edge.csv.gz'), compression='gzip', header = None).values.T.astype(np.int64) + num_edge_dict[triplet] = pd.read_csv(osp.join(subdir, 'num-edge-list.csv.gz'), compression='gzip', header = None).astype(np.int64)[0].tolist() + + # check the number of graphs coincide + assert(len(num_node_dict[nodetype_list[0]]) == len(num_edge_dict[triplet_list[0]])) + + num_graphs = len(num_node_dict[nodetype_list[0]]) + + except FileNotFoundError: + raise RuntimeError('No necessary file') + + node_feat_dict = {} + for nodetype in nodetype_list: + subdir = osp.join(raw_dir, 'node-feat', nodetype) + + try: + node_feat = pd.read_csv(osp.join(subdir, 'node-feat.csv.gz'), compression='gzip', header = None).values + if 'int' in str(node_feat.dtype): + node_feat = node_feat.astype(np.int64) + else: + # float + node_feat = node_feat.astype(np.float32) + + node_feat_dict[nodetype] = node_feat + except FileNotFoundError: + pass + + edge_feat_dict = {} + for triplet in triplet_list: + subdir = osp.join(raw_dir, 'relations', '___'.join(triplet)) + + try: + edge_feat = pd.read_csv(osp.join(subdir, 'edge-feat.csv.gz'), compression='gzip', header = None).values + if 'int' in str(edge_feat.dtype): + edge_feat = edge_feat.astype(np.int64) + else: + #float + edge_feat = edge_feat.astype(np.float32) + + edge_feat_dict[triplet] = edge_feat + + except FileNotFoundError: + pass + + + additional_node_info = {} + # e.g., additional_node_info['node_year'] = node_feature_dict for node_year + for additional_file in additional_node_files: + additional_feat_dict = {} + assert(additional_file[:5] == 'node_') + + for nodetype in nodetype_list: + subdir = osp.join(raw_dir, 'node-feat', nodetype) + + try: + node_feat = pd.read_csv(osp.join(subdir, additional_file + '.csv.gz'), compression='gzip', header = None).values + if 'int' in str(node_feat.dtype): + node_feat = node_feat.astype(np.int64) + else: + # float + node_feat = node_feat.astype(np.float32) + + 
assert(len(node_feat) == sum(num_node_dict[nodetype]))
+
+                additional_feat_dict[nodetype] = node_feat
+
+            except FileNotFoundError:
+                pass
+
+        additional_node_info[additional_file] = additional_feat_dict
+
+    additional_edge_info = {}
+    # e.g., additional_edge_info['edge_reltype'] = edge_feat_dict for edge_reltype
+    for additional_file in additional_edge_files:
+        assert(additional_file[:5] == 'edge_')
+        additional_feat_dict = {}
+        for triplet in triplet_list:
+            subdir = osp.join(raw_dir, 'relations', '___'.join(triplet))
+
+            try:
+                edge_feat = pd.read_csv(osp.join(subdir, additional_file + '.csv.gz'), compression='gzip', header = None).values
+                if 'int' in str(edge_feat.dtype):
+                    edge_feat = edge_feat.astype(np.int64)
+                else:
+                    # float
+                    edge_feat = edge_feat.astype(np.float32)
+
+                assert(len(edge_feat) == sum(num_edge_dict[triplet]))
+
+                additional_feat_dict[triplet] = edge_feat
+
+            except FileNotFoundError:
+                pass
+
+        additional_edge_info[additional_file] = additional_feat_dict
+
+    graph_list = []
+    num_node_accum_dict = {nodetype: 0 for nodetype in nodetype_list}
+    num_edge_accum_dict = {triplet: 0 for triplet in triplet_list}
+
+    print('Processing graphs...')
+    for i in tqdm(range(num_graphs)):
+
+        graph = dict()
+
+        ### set up default attribute
+        graph['edge_index_dict'] = {}
+        graph['edge_feat_dict'] = {}
+        graph['node_feat_dict'] = {}
+        graph['num_nodes_dict'] = {}
+
+        ### set up additional node/edge attributes
+        for key in additional_node_info.keys():
+            graph[key] = {}
+
+        for key in additional_edge_info.keys():
+            graph[key] = {}
+
+        ### handling edge
+        for triplet in triplet_list:
+            edge = edge_dict[triplet]
+            num_edge = num_edge_dict[triplet][i]
+            num_edge_accum = num_edge_accum_dict[triplet]
+
+            if add_inverse_edge:
+                ### add edge_index
+                # duplicate edge
+                duplicated_edge = np.repeat(edge[:, num_edge_accum:num_edge_accum + num_edge], 2, axis = 1)
+                duplicated_edge[0, 1::2] = duplicated_edge[1,0::2]
+                duplicated_edge[1, 1::2] = duplicated_edge[0,0::2]
+                graph['edge_index_dict'][triplet] = duplicated_edge
+
+                ### add default edge feature
+                if len(edge_feat_dict) > 0:
+                    # if edge_feat exists for some triplet
+                    if triplet in edge_feat_dict:
+                        graph['edge_feat_dict'][triplet] = np.repeat(edge_feat_dict[triplet][num_edge_accum:num_edge_accum + num_edge], 2, axis = 0)
+
+                else:
+                    # if edge_feat is not given for any triplet
+                    graph['edge_feat_dict'] = None
+
+                ### add additional edge feature
+                for key, value in additional_edge_info.items():
+                    if triplet in value:
+                        graph[key][triplet] = np.repeat(value[triplet][num_edge_accum : num_edge_accum + num_edge], 2, axis = 0)
+
+            else:
+                ### add edge_index
+                graph['edge_index_dict'][triplet] = edge[:, num_edge_accum:num_edge_accum+num_edge]
+
+                ### add default edge feature
+                if len(edge_feat_dict) > 0:
+                    # if edge_feat exists for some triplet
+                    if triplet in edge_feat_dict:
+                        graph['edge_feat_dict'][triplet] = edge_feat_dict[triplet][num_edge_accum:num_edge_accum + num_edge]
+
+                else:
+                    # if edge_feat is not given for any triplet
+                    graph['edge_feat_dict'] = None
+
+                ### add additional edge feature
+                for key, value in additional_edge_info.items():
+                    if triplet in value:
+                        graph[key][triplet] = value[triplet][num_edge_accum : num_edge_accum + num_edge]
+
+            num_edge_accum_dict[triplet] += num_edge
+
+        ### handling node
+        for nodetype in nodetype_list:
+            num_node = num_node_dict[nodetype][i]
+            num_node_accum = num_node_accum_dict[nodetype]
+
+            ### add default node feature
+            if len(node_feat_dict) > 0:
+                # if node_feat exists for some node type
+                if nodetype in 
node_feat_dict: + graph['node_feat_dict'][nodetype] = node_feat_dict[nodetype][num_node_accum:num_node_accum + num_node] + + else: + graph['node_feat_dict'] = None + + ### add additional node feature + for key, value in additional_node_info.items(): + if nodetype in value: + graph[key][nodetype] = value[nodetype][num_node_accum : num_node_accum + num_node] + + graph['num_nodes_dict'][nodetype] = num_node + num_node_accum_dict[nodetype] += num_node + + graph_list.append(graph) + + return graph_list + + +def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): + ''' + raw_dir: path to the raw directory + add_inverse_edge (bool): whether to add inverse edge or not + + return: graph_list, which is a list of heterogeneous graphs. + Each graph is a dictionary, containing the following keys: + - edge_index_dict + edge_index_dict[(head, rel, tail)] = edge_index for (head, rel, tail) + + - edge_feat_dict + edge_feat_dict[(head, rel, tail)] = edge_feat for (head, rel, tail) + + - node_feat_dict + node_feat_dict[nodetype] = node_feat for nodetype + + - num_nodes_dict + num_nodes_dict[nodetype] = num_nodes for nodetype + + * edge_feat_dict and node_feat_dict are optional: if a graph does not contain it, we will simply have None. + + We can also have additional node/edge features. For example, + - edge_** + - node_** + + ''' + + if add_inverse_edge: + raise RuntimeError('add_inverse_edge is depreciated in read_binary') + + print('Loading necessary files...') + print('This might take a while.') + + # loading necessary files + try: + num_nodes_dict = read_npz_dict(osp.join(raw_dir, 'num_nodes_dict.npz')) + tmp = read_npz_dict(osp.join(raw_dir, 'num_edges_dict.npz')) + num_edges_dict = {tuple(key.split('___')): tmp[key] for key in tmp.keys()} + del tmp + tmp = read_npz_dict(osp.join(raw_dir, 'edge_index_dict.npz')) + edge_index_dict = {tuple(key.split('___')): tmp[key] for key in tmp.keys()} + del tmp + + ent_type_list = sorted(list(num_nodes_dict.keys())) + triplet_type_list = sorted(list(num_edges_dict.keys())) + + num_graphs = len(num_nodes_dict[ent_type_list[0]]) + + except FileNotFoundError: + raise RuntimeError('No necessary file') + + # storing node and edge features + # mapping from the name of the features to feat_dict + node_feat_dict_dict = {} + edge_feat_dict_dict = {} + + for filename in os.listdir(raw_dir): + if '.npz' not in filename: + continue + if filename in ['num_nodes_dict.npz', 'num_edges_dict.npz', 'edge_index_dict.npz']: + continue + + # do not read target label information here + if '-label.npz' in filename: + continue + + feat_name = filename.split('.')[0] + + if 'node_' in feat_name: + feat_dict = read_npz_dict(osp.join(raw_dir, filename)) + node_feat_dict_dict[feat_name] = feat_dict + elif 'edge_' in feat_name: + tmp = read_npz_dict(osp.join(raw_dir, filename)) + feat_dict = {tuple(key.split('___')): tmp[key] for key in tmp.keys()} + del tmp + edge_feat_dict_dict[feat_name] = feat_dict + else: + raise RuntimeError(f"Keys in graph object should start from either \'node_\' or \'edge_\', but found \'{feat_name}\'.") + + graph_list = [] + num_nodes_accum_dict = {ent_type: 0 for ent_type in ent_type_list} + num_edges_accum_dict = {triplet: 0 for triplet in triplet_type_list} + + print('Processing graphs...') + for i in tqdm(range(num_graphs)): + + graph = dict() + + ### set up default atribute + graph['edge_index_dict'] = {} + graph['num_nodes_dict'] = {} + + for feat_name in node_feat_dict_dict.keys(): + graph[feat_name] = {} + + for feat_name in 
edge_feat_dict_dict.keys(): + graph[feat_name] = {} + + if not 'edge_feat_dict' in graph: + graph['edge_feat_dict'] = None + + if not 'node_feat_dict' in graph: + graph['node_feat_dict'] = None + + ### handling edge + for triplet in triplet_type_list: + edge_index = edge_index_dict[triplet] + num_edges = num_edges_dict[triplet][i] + num_edges_accum = num_edges_accum_dict[triplet] + + ### add edge_index + graph['edge_index_dict'][triplet] = edge_index[:, num_edges_accum:num_edges_accum+num_edges] + + ### add edge feature + for feat_name in edge_feat_dict_dict.keys(): + if triplet in edge_feat_dict_dict[feat_name]: + feat = edge_feat_dict_dict[feat_name][triplet] + graph[feat_name][triplet] = feat[num_edges_accum : num_edges_accum + num_edges] + + num_edges_accum_dict[triplet] += num_edges + + ### handling node + for ent_type in ent_type_list: + num_nodes = num_nodes_dict[ent_type][i] + num_nodes_accum = num_nodes_accum_dict[ent_type] + + ### add node feature + for feat_name in node_feat_dict_dict.keys(): + if ent_type in node_feat_dict_dict[feat_name]: + feat = node_feat_dict_dict[feat_name][ent_type] + graph[feat_name][ent_type] = feat[num_nodes_accum : num_nodes_accum + num_nodes] + + graph['num_nodes_dict'][ent_type] = num_nodes + num_nodes_accum_dict[ent_type] += num_nodes + + graph_list.append(graph) + + return graph_list + +def read_npz_dict(path): + tmp = np.load(path) + dict = {} + for key in tmp.keys(): + dict[key] = tmp[key] + del tmp + return dict + +def read_node_label_hetero(raw_dir): + df = pd.read_csv(osp.join(raw_dir, 'nodetype-has-label.csv.gz')) + label_dict = {} + for nodetype in df.keys(): + has_label = df[nodetype].values[0] + if has_label: + label_dict[nodetype] = pd.read_csv(osp.join(raw_dir, 'node-label', nodetype, 'node-label.csv.gz'), compression='gzip', header = None).values + + if len(label_dict) == 0: + raise RuntimeError('No node label file found.') + + return label_dict + + +def read_nodesplitidx_split_hetero(split_dir): + df = pd.read_csv(osp.join(split_dir, 'nodetype-has-split.csv.gz')) + train_dict = {} + valid_dict = {} + test_dict = {} + for nodetype in df.keys(): + has_label = df[nodetype].values[0] + if has_label: + train_dict[nodetype] = pd.read_csv(osp.join(split_dir, nodetype, 'train.csv.gz'), compression='gzip', header = None).values.T[0] + valid_dict[nodetype] = pd.read_csv(osp.join(split_dir, nodetype, 'valid.csv.gz'), compression='gzip', header = None).values.T[0] + test_dict[nodetype] = pd.read_csv(osp.join(split_dir, nodetype, 'test.csv.gz'), compression='gzip', header = None).values.T[0] + + if len(train_dict) == 0: + raise RuntimeError('No split file found.') + + return train_dict, valid_dict, test_dict + +if __name__ == '__main__': + pass + + diff --git a/gammagl/utils/ogb_url.py b/gammagl/utils/ogb_url.py new file mode 100644 index 00000000..e16b01df --- /dev/null +++ b/gammagl/utils/ogb_url.py @@ -0,0 +1,91 @@ +import urllib.request as ur +import zipfile +import os +import os.path as osp +from six.moves import urllib +import errno +from tqdm import tqdm + +GBFACTOR = float(1 << 30) + +def decide_download(url): + d = ur.urlopen(url) + size = int(d.info()["Content-Length"])/GBFACTOR + + ### confirm if larger than 1GB + if size > 1: + return input("This will download %.2fGB. Will you proceed? 
(y/N)\n" % (size)).lower() == "y" + else: + return True + +def makedirs(path): + try: + os.makedirs(osp.expanduser(osp.normpath(path))) + except OSError as e: + if e.errno != errno.EEXIST and osp.isdir(path): + raise e + +def download_url(url, folder, log=True): + r"""Downloads the content of an URL to a specific folder. + Args: + url (string): The url. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + + filename = url.rpartition('/')[2] + path = osp.join(folder, filename) + + if osp.exists(path) and osp.getsize(path) > 0: # pragma: no cover + if log: + print('Using exist file', filename) + return path + + if log: + print('Downloading', url) + + makedirs(folder) + data = ur.urlopen(url) + + size = int(data.info()["Content-Length"]) + + chunk_size = 1024*1024 + num_iter = int(size/chunk_size) + 2 + + downloaded_size = 0 + + try: + with open(path, 'wb') as f: + pbar = tqdm(range(num_iter)) + for i in pbar: + chunk = data.read(chunk_size) + downloaded_size += len(chunk) + pbar.set_description("Downloaded {:.2f} GB".format(float(downloaded_size)/GBFACTOR)) + f.write(chunk) + except: + if os.path.exists(path): + os.remove(path) + raise RuntimeError('Stopped downloading due to interruption.') + + + return path + +def maybe_log(path, log=True): + if log: + print('Extracting', path) + +def extract_zip(path, folder, log=True): + r"""Extracts a zip archive to a specific folder. + Args: + path (string): The path to the tar archive. + folder (string): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + maybe_log(path, log) + with zipfile.ZipFile(path, 'r') as f: + f.extractall(folder) + +if __name__ == "__main__": + pass \ No newline at end of file From 84a9fa2084a73b0bd98281789adce47263e5313e Mon Sep 17 00:00:00 2001 From: yang_starry_sky Date: Tue, 30 Aug 2022 16:59:33 +0800 Subject: [PATCH 03/25] extend InMemoryDataset --- gammagl/datasets/ogb_node.py | 272 ++++++++++++++++++----------------- gammagl/io/read_ogb_pyg.py | 111 ++++++++++++++ 2 files changed, 255 insertions(+), 128 deletions(-) create mode 100644 gammagl/io/read_ogb_pyg.py diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index b3c3f290..0f102a5d 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -1,28 +1,36 @@ -from gammagl.data import Graph import pandas as pd import shutil, os import os.path as osp -from gammagl.utils.ogb_url import decide_download, download_url, extract_zip -from gammagl.io.read_ogb_raw import read_csv_graph_raw, read_csv_heterograph_raw, \ - read_node_label_hetero, read_nodesplitidx_split_hetero, \ - read_binary_graph_raw, read_binary_heterograph_raw - import torch import numpy as np +# from gammagl.data import Graph +from gammagl.data import InMemoryDataset +from gammagl.utils.url import decide_download, download_url, extract_zip +from gammagl.io.read_graph_pyg import read_graph_pyg, read_heterograph_pyg +from gammagl.io.read_graph_raw import read_node_label_hetero, read_nodesplitidx_split_hetero + -class OgbNodeDataset(object): - def __init__(self, name, root='dataset', meta_dict=None): +class OgbNodeDataset(InMemoryDataset): + def __init__(self, name, root='dataset', transform=None, pre_transform=None, meta_dict=None): ''' - name (str): name of the dataset - root (str): root directory to store the dataset folder - - meta_dict: dictionary that stores all the meta-information about data. 
Default is None, + - transform, pre_transform (optional): transform/pre-transform graph objects + + - meta_dict: dictionary that stores all the meta-information about data. Default is None, but when something is passed, it uses its information. Useful for debugging for external contributers. ''' self.name = name ## original name, e.g., ogbn-proteins if meta_dict is None: - self.dir_name = '_'.join(name.split('-')) ## replace hyphen with underline, e.g., ogbn_proteins + self.dir_name = '_'.join(name.split('-')) + + # check if previously-downloaded folder exists. + # If so, use that one. + if osp.exists(osp.join(root, self.dir_name + '_pyg')): + self.dir_name = self.dir_name + '_pyg' + self.original_root = root self.root = osp.join(root, self.dir_name) @@ -43,9 +51,9 @@ def __init__(self, name, root='dataset', meta_dict=None): # check version # First check whether the dataset has been already downloaded or not. # If so, check whether the dataset version is the newest or not. - # If the dataset is not the newest version, notify this to the user. + # If the dataset is not the newest version, notify this to the user. if osp.isdir(self.root) and ( - not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))): + not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))): print(self.name + ' has been updated.') if input('Will you update the dataset now? (y/N)\n').lower() == 'y': shutil.rmtree(self.root) @@ -55,109 +63,12 @@ def __init__(self, name, root='dataset', meta_dict=None): self.num_tasks = int(self.meta_info['num tasks']) self.task_type = self.meta_info['task type'] self.eval_metric = self.meta_info['eval metric'] - self.num_classes = int(self.meta_info['num classes']) + self.__num_classes__ = int(self.meta_info['num classes']) self.is_hetero = self.meta_info['is hetero'] == 'True' self.binary = self.meta_info['binary'] == 'True' - super(OgbNodeDataset, self).__init__() - - self.pre_process() - - def pre_process(self): - processed_dir = osp.join(self.root, 'processed') - pre_processed_file_path = osp.join(processed_dir, 'data_processed') - - if osp.exists(pre_processed_file_path): - loaded_dict = torch.load(pre_processed_file_path) - self.graph, self.labels = loaded_dict['graph'], loaded_dict['labels'] - self.data = Graph(edge_index=self.graph['edge_index'], x=self.graph['node_feat'], y=self.labels) - self.data.num_nodes = self.graph['num_nodes'] - self.data.edge_attr = self.graph['edge_feat'] - self.data.tensor() - - else: - ### check download - if self.binary: - # npz format - has_necessary_file_simple = osp.exists(osp.join(self.root, 'raw', 'data.npz')) and (not self.is_hetero) - has_necessary_file_hetero = osp.exists( - osp.join(self.root, 'raw', 'edge_index_dict.npz')) and self.is_hetero - else: - # csv file - has_necessary_file_simple = osp.exists(osp.join(self.root, 'raw', 'edge.csv.gz')) and ( - not self.is_hetero) - has_necessary_file_hetero = osp.exists( - osp.join(self.root, 'raw', 'triplet-type-list.csv.gz')) and self.is_hetero - - has_necessary_file = has_necessary_file_simple or has_necessary_file_hetero - - if not has_necessary_file: - url = self.meta_info['url'] - if decide_download(url): - path = download_url(url, self.original_root) - extract_zip(path, self.original_root) - os.unlink(path) - # delete folder if there exists - try: - shutil.rmtree(self.root) - except: - pass - shutil.move(osp.join(self.original_root, self.download_name), self.root) - else: - print('Stop download.') - exit(-1) - - raw_dir = 
osp.join(self.root, 'raw') - - ### pre-process and save - add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' - - if self.meta_info['additional node files'] == 'None': - additional_node_files = [] - else: - additional_node_files = self.meta_info['additional node files'].split(',') - - if self.meta_info['additional edge files'] == 'None': - additional_edge_files = [] - else: - additional_edge_files = self.meta_info['additional edge files'].split(',') - - if self.is_hetero: - if self.binary: - self.graph = read_binary_heterograph_raw(raw_dir, add_inverse_edge=add_inverse_edge)[ - 0] # only a single graph - - tmp = np.load(osp.join(raw_dir, 'node-label.npz')) - self.labels = {} - for key in list(tmp.keys()): - self.labels[key] = tmp[key] - del tmp - else: - self.graph = read_csv_heterograph_raw(raw_dir, add_inverse_edge=add_inverse_edge, - additional_node_files=additional_node_files, - additional_edge_files=additional_edge_files)[ - 0] # only a single graph - self.labels = read_node_label_hetero(raw_dir) - - else: - if self.binary: - self.graph = read_binary_graph_raw(raw_dir, add_inverse_edge=add_inverse_edge)[ - 0] # only a single graph - self.labels = np.load(osp.join(raw_dir, 'node-label.npz'))['node_label'] - else: - self.graph = read_csv_graph_raw(raw_dir, add_inverse_edge=add_inverse_edge, - additional_node_files=additional_node_files, - additional_edge_files=additional_edge_files)[ - 0] # only a single graph - self.labels = pd.read_csv(osp.join(raw_dir, 'node-label.csv.gz'), compression='gzip', - header=None).values - - print('Saving...') - self.data = Graph(edge_index=self.graph['edge_index'], x=self.graph['node_feat'], y=self.labels) - self.data.num_nodes = self.graph['num_nodes'] - self.data.edge_attr = self.graph['edge_feat'] - self.data.tensor() - torch.save(self.data, pre_processed_file_path, pickle_protocol=4) + super(OgbNodeDataset, self).__init__(self.root, transform, pre_transform) + self.data, self.slices = torch.load(self.processed_paths[0]) def get_idx_split(self, split_type=None): if split_type is None: @@ -172,33 +83,138 @@ def get_idx_split(self, split_type=None): if self.is_hetero: train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(path) for nodetype in train_idx_dict.keys(): - train_idx_dict[nodetype] = train_idx_dict[nodetype] - valid_idx_dict[nodetype] = valid_idx_dict[nodetype] - test_idx_dict[nodetype] = test_idx_dict[nodetype] + train_idx_dict[nodetype] = torch.from_numpy(train_idx_dict[nodetype]).to(torch.long) + valid_idx_dict[nodetype] = torch.from_numpy(valid_idx_dict[nodetype]).to(torch.long) + test_idx_dict[nodetype] = torch.from_numpy(test_idx_dict[nodetype]).to(torch.long) return {'train': train_idx_dict, 'valid': valid_idx_dict, 'test': test_idx_dict} else: - train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0] - valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0] - test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0] + train_idx = torch.from_numpy( + pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0]).to(torch.long) + valid_idx = torch.from_numpy( + pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0]).to(torch.long) + test_idx = torch.from_numpy( + pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0]).to(torch.long) return {'train': train_idx, 'valid': valid_idx, 
'test': test_idx} + @property + def num_classes(self): + return self.__num_classes__ + + @property + def raw_file_names(self): + if self.binary: + if self.is_hetero: + return ['edge_index_dict.npz'] + else: + return ['data.npz'] + else: + if self.is_hetero: + return ['num-node-dict.csv.gz', 'triplet-type-list.csv.gz'] + else: + file_names = ['edge'] + if self.meta_info['has_node_attr'] == 'True': + file_names.append('node-feat') + if self.meta_info['has_edge_attr'] == 'True': + file_names.append('edge-feat') + return [file_name + '.csv.gz' for file_name in file_names] + + @property + def processed_file_names(self): + return osp.join('geometric_data_processed.pt') + + def download(self): + url = self.meta_info['url'] + if decide_download(url): + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + shutil.rmtree(self.root) + shutil.move(osp.join(self.original_root, self.download_name), self.root) + else: + print('Stop downloading.') + shutil.rmtree(self.root) + exit(-1) + + def process(self): + add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' + + if self.meta_info['additional node files'] == 'None': + additional_node_files = [] + else: + additional_node_files = self.meta_info['additional node files'].split(',') + + if self.meta_info['additional edge files'] == 'None': + additional_edge_files = [] + else: + additional_edge_files = self.meta_info['additional edge files'].split(',') + + if self.is_hetero: + data = read_heterograph_pyg(self.raw_dir, add_inverse_edge=add_inverse_edge, + additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files, binary=self.binary)[0] + + if self.binary: + tmp = np.load(osp.join(self.raw_dir, 'node-label.npz')) + node_label_dict = {} + for key in list(tmp.keys()): + node_label_dict[key] = tmp[key] + del tmp + else: + node_label_dict = read_node_label_hetero(self.raw_dir) + + data.y_dict = {} + if 'classification' in self.task_type: + for nodetype, node_label in node_label_dict.items(): + # detect if there is any nan + if np.isnan(node_label).any(): + data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.float32) + else: + data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.long) + else: + for nodetype, node_label in node_label_dict.items(): + data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.float32) + + else: + data = \ + read_graph_pyg(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files, binary=self.binary)[0] + ### adding prediction target + if self.binary: + node_label = np.load(osp.join(self.raw_dir, 'node-label.npz'))['node_label'] + else: + node_label = pd.read_csv(osp.join(self.raw_dir, 'node-label.csv.gz'), compression='gzip', + header=None).values + ''' + if 'classification' in self.task_type: + # detect if there is any nan + if np.isnan(node_label).any(): + data.y = torch.from_numpy(node_label).to(torch.float32) + else: + data.y = torch.from_numpy(node_label).to(torch.long) + + else: + data.y = torch.from_numpy(node_label).to(torch.float32) + ''' + data.y = node_label + data = data if self.pre_transform is None else self.pre_transform(data) + self.data = data + print('Saving...') + torch.save(self.collate([data]), self.processed_paths[0]) + def __getitem__(self, idx): assert idx == 0, 'This dataset has only one graph' return self.data - def __len__(self): - return 1 - - def __repr__(self): # pragma: no cover - return 
'{}({})'.format(self.__class__.__name__, len(self)) + def __repr__(self): + return '{}()'.format(self.__class__.__name__) if __name__ == '__main__': - dataset = OgbNodeDataset(name='ogbn-mag') - print(dataset.num_classes) - split_index = dataset.get_idx_split() - print(dataset[0]) - print(split_index) \ No newline at end of file + pyg_dataset = PygNodePropPredDataset(name='ogbn-mag') + print(pyg_dataset[0]) + split_index = pyg_dataset.get_idx_split() + # print(split_index) + diff --git a/gammagl/io/read_ogb_pyg.py b/gammagl/io/read_ogb_pyg.py new file mode 100644 index 00000000..052fb9fd --- /dev/null +++ b/gammagl/io/read_ogb_pyg.py @@ -0,0 +1,111 @@ +import pandas as pd +import torch +from torch_geometric.data import Data +import os.path as osp +import numpy as np +from gammagl.data import Graph +from gammagl.io.read_ogb_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw +from tqdm import tqdm + +def read_graph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = [], binary = False): + + if binary: + # npz + graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge) + else: + # csv + graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files) + + pyg_graph_list = [] + + print('Converting graphs into PyG objects...') + + for graph in tqdm(graph_list): + g = Graph() + g.num_nodes = graph['num_nodes'] + g.edge_index = graph['edge_index'] + + del graph['num_nodes'] + del graph['edge_index'] + + if graph['edge_feat'] is not None: + g.edge_attr = graph['edge_feat'] + del graph['edge_feat'] + + if graph['node_feat'] is not None: + g.x = graph['node_feat'] + del graph['node_feat'] + + for key in additional_node_files: + g[key] = graph[key] + del graph[key] + + for key in additional_edge_files: + g[key] = graph[key] + del graph[key] + + pyg_graph_list.append(g) + + return pyg_graph_list + + +def read_heterograph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = [], binary = False): + + if binary: + # npz + graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge) + else: + # csv + graph_list = read_csv_heterograph_raw(raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files) + + pyg_graph_list = [] + + print('Converting graphs into PyG objects...') + + for graph in tqdm(graph_list): + g = Data() + + g.__num_nodes__ = graph['num_nodes_dict'] + g.num_nodes_dict = graph['num_nodes_dict'] + + # add edge connectivity + g.edge_index_dict = {} + for triplet, edge_index in graph['edge_index_dict'].items(): + g.edge_index_dict[triplet] = torch.from_numpy(edge_index) + + del graph['edge_index_dict'] + + if graph['edge_feat_dict'] is not None: + g.edge_attr_dict = {} + for triplet in graph['edge_feat_dict'].keys(): + g.edge_attr_dict[triplet] = torch.from_numpy(graph['edge_feat_dict'][triplet]) + + del graph['edge_feat_dict'] + + if graph['node_feat_dict'] is not None: + g.x_dict = {} + for nodetype in graph['node_feat_dict'].keys(): + g.x_dict[nodetype] = torch.from_numpy(graph['node_feat_dict'][nodetype]) + + del graph['node_feat_dict'] + + for key in additional_node_files: + g[key] = {} + for nodetype in graph[key].keys(): + g[key][nodetype] = torch.from_numpy(graph[key][nodetype]) + + del graph[key] + + for key in additional_edge_files: + g[key] = {} + for triplet in graph[key].keys(): + g[key][triplet] 
= torch.from_numpy(graph[key][triplet])
+
+            del graph[key]
+
+        pyg_graph_list.append(g)
+
+    return pyg_graph_list
+
+if __name__ == '__main__':
+    pass

From ae6e3956bd42fb22517d4645a062947086fa02da Mon Sep 17 00:00:00 2001
From: yang_starry_sky
Date: Tue, 30 Aug 2022 18:11:32 +0800
Subject: [PATCH 04/25] do not rely on torch

---
 gammagl/datasets/ogb_node.py                |  73 +------
 gammagl/io/{read_ogb_raw.py => read_ogb.py} | 226 +++++++++++++-------
 gammagl/io/read_ogb_pyg.py                  | 111 ----------
 3 files changed, 155 insertions(+), 255 deletions(-)
 rename gammagl/io/{read_ogb_raw.py => read_ogb.py} (76%)
 delete mode 100644 gammagl/io/read_ogb_pyg.py

diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py
index 0f102a5d..bb463c77 100644
--- a/gammagl/datasets/ogb_node.py
+++ b/gammagl/datasets/ogb_node.py
@@ -1,28 +1,26 @@
 import pandas as pd
 import shutil, os
 import os.path as osp
-import torch
 import numpy as np
 # from gammagl.data import Graph
 from gammagl.data import InMemoryDataset
-from gammagl.utils.url import decide_download, download_url, extract_zip
-from gammagl.io.read_graph_pyg import read_graph_pyg, read_heterograph_pyg
-from gammagl.io.read_graph_raw import read_node_label_hetero, read_nodesplitidx_split_hetero
+from gammagl.utils.ogb_url import decide_download, download_url, extract_zip
+from gammagl.io.read_ogb import read_node_label_hetero, read_nodesplitidx_split_hetero, read_graph, read_heterograph


-class OgbNodeDataset(InMemoryDataset):
+class PygNodePropPredDataset(InMemoryDataset):
     def __init__(self, name, root='dataset', transform=None, pre_transform=None, meta_dict=None):
         '''
        - name (str): name of the dataset
        - root (str): root directory to store the dataset folder
        - transform, pre_transform (optional): transform/pre-transform graph objects

-        - meta_dict: dictionary that stores all the meta-information about data. Default is None,
+        - meta_dict: dictionary that stores all the meta-information about data. Default is None,
            but when something is passed, it uses its information. Useful for debugging for external contributers.
         '''

         self.name = name ## original name, e.g., ogbn-proteins

         if meta_dict is None:
             self.dir_name = '_'.join(name.split('-'))
@@ -49,7 +49,7 @@
         # check version
        # First check whether the dataset has been already downloaded or not.
        # If so, check whether the dataset version is the newest or not.
-        # If the dataset is not the newest version, notify this to the user.
+        # If the dataset is not the newest version, notify this to the user. 
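+        # (The RELEASE_v<version>.txt marker file inside the dataset folder records
+        # which release is on disk; when it is missing or outdated, the folder is
+        # removed below so that the newest version is downloaded again.)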
if osp.isdir(self.root) and ( not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))): print(self.name + ' has been updated.') @@ -67,37 +65,8 @@ def __init__(self, name, root='dataset', transform=None, pre_transform=None, met self.is_hetero = self.meta_info['is hetero'] == 'True' self.binary = self.meta_info['binary'] == 'True' - super(OgbNodeDataset, self).__init__(self.root, transform, pre_transform) - self.data, self.slices = torch.load(self.processed_paths[0]) - - def get_idx_split(self, split_type=None): - if split_type is None: - split_type = self.meta_info['split'] - - path = osp.join(self.root, 'split', split_type) - - # short-cut if split_dict.pt exists - if os.path.isfile(os.path.join(path, 'split_dict.pt')): - return torch.load(os.path.join(path, 'split_dict.pt')) - - if self.is_hetero: - train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(path) - for nodetype in train_idx_dict.keys(): - train_idx_dict[nodetype] = torch.from_numpy(train_idx_dict[nodetype]).to(torch.long) - valid_idx_dict[nodetype] = torch.from_numpy(valid_idx_dict[nodetype]).to(torch.long) - test_idx_dict[nodetype] = torch.from_numpy(test_idx_dict[nodetype]).to(torch.long) - - return {'train': train_idx_dict, 'valid': valid_idx_dict, 'test': test_idx_dict} - - else: - train_idx = torch.from_numpy( - pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header=None).values.T[0]).to(torch.long) - valid_idx = torch.from_numpy( - pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header=None).values.T[0]).to(torch.long) - test_idx = torch.from_numpy( - pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header=None).values.T[0]).to(torch.long) - - return {'train': train_idx, 'valid': valid_idx, 'test': test_idx} + super(PygNodePropPredDataset, self).__init__(self.root, transform, pre_transform) + self.data, self.slices = self.load_data(self.processed_paths[0]) @property def num_classes(self): @@ -152,31 +121,7 @@ def process(self): additional_edge_files = self.meta_info['additional edge files'].split(',') if self.is_hetero: - data = read_heterograph_pyg(self.raw_dir, add_inverse_edge=add_inverse_edge, - additional_node_files=additional_node_files, - additional_edge_files=additional_edge_files, binary=self.binary)[0] - - if self.binary: - tmp = np.load(osp.join(self.raw_dir, 'node-label.npz')) - node_label_dict = {} - for key in list(tmp.keys()): - node_label_dict[key] = tmp[key] - del tmp - else: - node_label_dict = read_node_label_hetero(self.raw_dir) - - data.y_dict = {} - if 'classification' in self.task_type: - for nodetype, node_label in node_label_dict.items(): - # detect if there is any nan - if np.isnan(node_label).any(): - data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.float32) - else: - data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.long) - else: - for nodetype, node_label in node_label_dict.items(): - data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.float32) - + pass else: data = \ read_graph_pyg(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, @@ -202,7 +147,7 @@ def process(self): data = data if self.pre_transform is None else self.pre_transform(data) self.data = data print('Saving...') - torch.save(self.collate([data]), self.processed_paths[0]) + self.save_data(self.collate([data]), self.processed_paths[0]) def __getitem__(self, idx): assert idx == 0, 'This dataset has only one graph' diff --git 
a/gammagl/io/read_ogb_raw.py b/gammagl/io/read_ogb.py similarity index 76% rename from gammagl/io/read_ogb_raw.py rename to gammagl/io/read_ogb.py index b76e099b..106cfd49 100644 --- a/gammagl/io/read_ogb_raw.py +++ b/gammagl/io/read_ogb.py @@ -4,10 +4,57 @@ import numpy as np from gammagl.utils.ogb_url import decide_download, download_url, extract_zip from tqdm import tqdm +from gammagl.data import Graph + +def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], binary=False): + if binary: + # npz + graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge) + else: + # csv + graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files) + + pyg_graph_list = [] + + print('Converting graphs into PyG objects...') + + for graph in tqdm(graph_list): + g = Graph() + g.num_nodes = graph['num_nodes'] + g.edge_index = graph['edge_index'] + + del graph['num_nodes'] + del graph['edge_index'] + + if graph['edge_feat'] is not None: + g.edge_attr = graph['edge_feat'] + del graph['edge_feat'] + + if graph['node_feat'] is not None: + g.x = graph['node_feat'] + del graph['node_feat'] + + for key in additional_node_files: + g[key] = graph[key] + del graph[key] + + for key in additional_edge_files: + g[key] = graph[key] + del graph[key] + + pyg_graph_list.append(g) + + return pyg_graph_list + + +def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], binary=False): + pass + ### reading raw files from a directory. ### for homogeneous graph -def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = []): +def read_csv_graph_raw(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[]): ''' raw_dir: path to the raw directory add_inverse_edge (bool): whether to add inverse edge or not @@ -23,22 +70,27 @@ def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files additional_node_files must start from 'node_' additional_edge_files must start from 'edge_' - + ''' print('Loading necessary files...') print('This might take a while.') # loading necessary files try: - edge = pd.read_csv(osp.join(raw_dir, 'edge.csv.gz'), compression='gzip', header = None).values.T.astype(np.int64) # (2, num_edge) numpy array - num_node_list = pd.read_csv(osp.join(raw_dir, 'num-node-list.csv.gz'), compression='gzip', header = None).astype(np.int64)[0].tolist() # (num_graph, ) python list - num_edge_list = pd.read_csv(osp.join(raw_dir, 'num-edge-list.csv.gz'), compression='gzip', header = None).astype(np.int64)[0].tolist() # (num_edge, ) python list + edge = pd.read_csv(osp.join(raw_dir, 'edge.csv.gz'), compression='gzip', header=None).values.T.astype( + np.int64) # (2, num_edge) numpy array + num_node_list = \ + pd.read_csv(osp.join(raw_dir, 'num-node-list.csv.gz'), compression='gzip', header=None).astype(np.int64)[ + 0].tolist() # (num_graph, ) python list + num_edge_list = \ + pd.read_csv(osp.join(raw_dir, 'num-edge-list.csv.gz'), compression='gzip', header=None).astype(np.int64)[ + 0].tolist() # (num_edge, ) python list except FileNotFoundError: raise RuntimeError('No necessary file') try: - node_feat = pd.read_csv(osp.join(raw_dir, 'node-feat.csv.gz'), compression='gzip', header = None).values + node_feat = pd.read_csv(osp.join(raw_dir, 'node-feat.csv.gz'), compression='gzip', header=None).values if 'int' in str(node_feat.dtype): node_feat = 
node_feat.astype(np.int64) else: @@ -48,26 +100,25 @@ def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files node_feat = None try: - edge_feat = pd.read_csv(osp.join(raw_dir, 'edge-feat.csv.gz'), compression='gzip', header = None).values + edge_feat = pd.read_csv(osp.join(raw_dir, 'edge-feat.csv.gz'), compression='gzip', header=None).values if 'int' in str(edge_feat.dtype): edge_feat = edge_feat.astype(np.int64) else: - #float + # float edge_feat = edge_feat.astype(np.float32) except FileNotFoundError: edge_feat = None - - additional_node_info = {} + additional_node_info = {} for additional_file in additional_node_files: - assert(additional_file[:5] == 'node_') + assert (additional_file[:5] == 'node_') # hack for ogbn-proteins if additional_file == 'node_species' and osp.exists(osp.join(raw_dir, 'species.csv.gz')): os.rename(osp.join(raw_dir, 'species.csv.gz'), osp.join(raw_dir, 'node_species.csv.gz')) - temp = pd.read_csv(osp.join(raw_dir, additional_file + '.csv.gz'), compression='gzip', header = None).values + temp = pd.read_csv(osp.join(raw_dir, additional_file + '.csv.gz'), compression='gzip', header=None).values if 'int' in str(temp.dtype): additional_node_info[additional_file] = temp.astype(np.int64) @@ -75,10 +126,10 @@ def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files # float additional_node_info[additional_file] = temp.astype(np.float32) - additional_edge_info = {} + additional_edge_info = {} for additional_file in additional_edge_files: - assert(additional_file[:5] == 'edge_') - temp = pd.read_csv(osp.join(raw_dir, additional_file + '.csv.gz'), compression='gzip', header = None).values + assert (additional_file[:5] == 'edge_') + temp = pd.read_csv(osp.join(raw_dir, additional_file + '.csv.gz'), compression='gzip', header=None).values if 'int' in str(temp.dtype): additional_edge_info[additional_file] = temp.astype(np.int64) @@ -86,7 +137,6 @@ def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files # float additional_edge_info[additional_file] = temp.astype(np.float32) - graph_list = [] num_node_accum = 0 num_edge_accum = 0 @@ -99,42 +149,41 @@ def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files ### handling edge if add_inverse_edge: ### duplicate edge - duplicated_edge = np.repeat(edge[:, num_edge_accum:num_edge_accum+num_edge], 2, axis = 1) - duplicated_edge[0, 1::2] = duplicated_edge[1,0::2] - duplicated_edge[1, 1::2] = duplicated_edge[0,0::2] + duplicated_edge = np.repeat(edge[:, num_edge_accum:num_edge_accum + num_edge], 2, axis=1) + duplicated_edge[0, 1::2] = duplicated_edge[1, 0::2] + duplicated_edge[1, 1::2] = duplicated_edge[0, 0::2] graph['edge_index'] = duplicated_edge if edge_feat is not None: - graph['edge_feat'] = np.repeat(edge_feat[num_edge_accum:num_edge_accum+num_edge], 2, axis = 0) + graph['edge_feat'] = np.repeat(edge_feat[num_edge_accum:num_edge_accum + num_edge], 2, axis=0) else: graph['edge_feat'] = None for key, value in additional_edge_info.items(): - graph[key] = np.repeat(value[num_edge_accum:num_edge_accum+num_edge], 2, axis = 0) + graph[key] = np.repeat(value[num_edge_accum:num_edge_accum + num_edge], 2, axis=0) else: - graph['edge_index'] = edge[:, num_edge_accum:num_edge_accum+num_edge] + graph['edge_index'] = edge[:, num_edge_accum:num_edge_accum + num_edge] if edge_feat is not None: - graph['edge_feat'] = edge_feat[num_edge_accum:num_edge_accum+num_edge] + graph['edge_feat'] = edge_feat[num_edge_accum:num_edge_accum + num_edge] else: 
graph['edge_feat'] = None for key, value in additional_edge_info.items(): - graph[key] = value[num_edge_accum:num_edge_accum+num_edge] + graph[key] = value[num_edge_accum:num_edge_accum + num_edge] num_edge_accum += num_edge ### handling node if node_feat is not None: - graph['node_feat'] = node_feat[num_node_accum:num_node_accum+num_node] + graph['node_feat'] = node_feat[num_node_accum:num_node_accum + num_node] else: graph['node_feat'] = None for key, value in additional_node_info.items(): - graph[key] = value[num_node_accum:num_node_accum+num_node] - + graph[key] = value[num_node_accum:num_node_accum + num_node] graph['num_nodes'] = num_node num_node_accum += num_node @@ -147,7 +196,7 @@ def read_csv_graph_raw(raw_dir, add_inverse_edge = False, additional_node_files ### reading raw files from a directory. ### npz ver ### for homogeneous graph -def read_binary_graph_raw(raw_dir, add_inverse_edge = False): +def read_binary_graph_raw(raw_dir, add_inverse_edge=False): ''' raw_dir: path to the raw directory add_inverse_edge (bool): whether to add inverse edge or not @@ -188,7 +237,8 @@ def read_binary_graph_raw(raw_dir, add_inverse_edge = False): elif key[:5] == 'edge_': edge_dict[key] = data_dict[key] else: - raise RuntimeError(f"Keys in graph object should start from either \'node_\' or \'edge_\', but found \'{key}\'.") + raise RuntimeError( + f"Keys in graph object should start from either \'node_\' or \'edge_\', but found \'{key}\'.") graph_list = [] num_nodes_accum = 0 @@ -199,16 +249,16 @@ def read_binary_graph_raw(raw_dir, add_inverse_edge = False): graph = dict() - graph['edge_index'] = edge_index[:, num_edges_accum:num_edges_accum+num_edges] + graph['edge_index'] = edge_index[:, num_edges_accum:num_edges_accum + num_edges] for key, feat in edge_dict.items(): - graph[key] = feat[num_edges_accum:num_edges_accum+num_edges] + graph[key] = feat[num_edges_accum:num_edges_accum + num_edges] if 'edge_feat' not in graph: - graph['edge_feat'] = None + graph['edge_feat'] = None for key, feat in node_dict.items(): - graph[key] = feat[num_nodes_accum:num_nodes_accum+num_nodes] + graph[key] = feat[num_nodes_accum:num_nodes_accum + num_nodes] if 'node_feat' not in graph: graph['node_feat'] = None @@ -225,7 +275,7 @@ def read_binary_graph_raw(raw_dir, add_inverse_edge = False): ### reading raw files from a directory. ### for heterogeneous graph -def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = []): +def read_csv_heterograph_raw(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[]): ''' raw_dir: path to the raw directory add_inverse_edge (bool): whether to add inverse edge or not @@ -240,10 +290,10 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ - node_feat_dict node_feat_dict[nodetype] = node_feat for nodetype - + - num_nodes_dict num_nodes_dict[nodetype] = num_nodes for nodetype - + * edge_feat_dict and node_feat_dict are optional: if a graph does not contain it, we will simply have None. We can also have additional node/edge features. 
For example, @@ -252,7 +302,7 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ - node_year_dict node_year_dict[nodetype] = node_year - + ''' print('Loading necessary files...') @@ -265,8 +315,9 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ nodetype_list = sorted(list(num_node_dict.keys())) ## read edge_dict, num_edge_dict - triplet_df = pd.read_csv(osp.join(raw_dir, 'triplet-type-list.csv.gz'), compression='gzip', header = None) - triplet_list = sorted([(head, relation, tail) for head, relation, tail in zip(triplet_df[0].tolist(), triplet_df[1].tolist(), triplet_df[2].tolist())]) + triplet_df = pd.read_csv(osp.join(raw_dir, 'triplet-type-list.csv.gz'), compression='gzip', header=None) + triplet_list = sorted([(head, relation, tail) for head, relation, tail in + zip(triplet_df[0].tolist(), triplet_df[1].tolist(), triplet_df[2].tolist())]) edge_dict = {} num_edge_dict = {} @@ -274,11 +325,14 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ for triplet in triplet_list: subdir = osp.join(raw_dir, 'relations', '___'.join(triplet)) - edge_dict[triplet] = pd.read_csv(osp.join(subdir, 'edge.csv.gz'), compression='gzip', header = None).values.T.astype(np.int64) - num_edge_dict[triplet] = pd.read_csv(osp.join(subdir, 'num-edge-list.csv.gz'), compression='gzip', header = None).astype(np.int64)[0].tolist() + edge_dict[triplet] = pd.read_csv(osp.join(subdir, 'edge.csv.gz'), compression='gzip', + header=None).values.T.astype(np.int64) + num_edge_dict[triplet] = \ + pd.read_csv(osp.join(subdir, 'num-edge-list.csv.gz'), compression='gzip', header=None).astype(np.int64)[ + 0].tolist() # check the number of graphs coincide - assert(len(num_node_dict[nodetype_list[0]]) == len(num_edge_dict[triplet_list[0]])) + assert (len(num_node_dict[nodetype_list[0]]) == len(num_edge_dict[triplet_list[0]])) num_graphs = len(num_node_dict[nodetype_list[0]]) @@ -288,9 +342,9 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ node_feat_dict = {} for nodetype in nodetype_list: subdir = osp.join(raw_dir, 'node-feat', nodetype) - + try: - node_feat = pd.read_csv(osp.join(subdir, 'node-feat.csv.gz'), compression='gzip', header = None).values + node_feat = pd.read_csv(osp.join(subdir, 'node-feat.csv.gz'), compression='gzip', header=None).values if 'int' in str(node_feat.dtype): node_feat = node_feat.astype(np.int64) else: @@ -306,11 +360,11 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ subdir = osp.join(raw_dir, 'relations', '___'.join(triplet)) try: - edge_feat = pd.read_csv(osp.join(subdir, 'edge-feat.csv.gz'), compression='gzip', header = None).values + edge_feat = pd.read_csv(osp.join(subdir, 'edge-feat.csv.gz'), compression='gzip', header=None).values if 'int' in str(edge_feat.dtype): edge_feat = edge_feat.astype(np.int64) else: - #float + # float edge_feat = edge_feat.astype(np.float32) edge_feat_dict[triplet] = edge_feat @@ -318,25 +372,25 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ except FileNotFoundError: pass - additional_node_info = {} # e.g., additional_node_info['node_year'] = node_feature_dict for node_year for additional_file in additional_node_files: additional_feat_dict = {} - assert(additional_file[:5] == 'node_') + assert (additional_file[:5] == 'node_') for nodetype in nodetype_list: subdir = osp.join(raw_dir, 'node-feat', nodetype) try: - node_feat = pd.read_csv(osp.join(subdir, 
additional_file + '.csv.gz'), compression='gzip', header = None).values + node_feat = pd.read_csv(osp.join(subdir, additional_file + '.csv.gz'), compression='gzip', + header=None).values if 'int' in str(node_feat.dtype): node_feat = node_feat.astype(np.int64) else: # float node_feat = node_feat.astype(np.float32) - assert(len(node_feat) == sum(num_node_dict[nodetype])) + assert (len(node_feat) == sum(num_node_dict[nodetype])) additional_feat_dict[nodetype] = node_feat @@ -348,20 +402,21 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ additional_edge_info = {} # e.g., additional_edge_info['edge_reltype'] = edge_feat_dict for edge_reltype for additional_file in additional_edge_files: - assert(additional_file[:5] == 'edge_') + assert (additional_file[:5] == 'edge_') additional_feat_dict = {} for triplet in triplet_list: subdir = osp.join(raw_dir, 'relations', '___'.join(triplet)) - + try: - edge_feat = pd.read_csv(osp.join(subdir, additional_file + '.csv.gz'), compression='gzip', header = None).values + edge_feat = pd.read_csv(osp.join(subdir, additional_file + '.csv.gz'), compression='gzip', + header=None).values if 'int' in str(edge_feat.dtype): edge_feat = edge_feat.astype(np.int64) else: # float edge_feat = edge_feat.astype(np.float32) - assert(len(edge_feat) == sum(num_edge_dict[triplet])) + assert (len(edge_feat) == sum(num_edge_dict[triplet])) additional_feat_dict[triplet] = edge_feat @@ -401,16 +456,17 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ if add_inverse_edge: ### add edge_index # duplicate edge - duplicated_edge = np.repeat(edge[:, num_edge_accum:num_edge_accum + num_edge], 2, axis = 1) - duplicated_edge[0, 1::2] = duplicated_edge[1,0::2] - duplicated_edge[1, 1::2] = duplicated_edge[0,0::2] + duplicated_edge = np.repeat(edge[:, num_edge_accum:num_edge_accum + num_edge], 2, axis=1) + duplicated_edge[0, 1::2] = duplicated_edge[1, 0::2] + duplicated_edge[1, 1::2] = duplicated_edge[0, 0::2] graph['edge_index_dict'][triplet] = duplicated_edge ### add default edge feature if len(edge_feat_dict) > 0: # if edge_feat exists for some triplet if triplet in edge_feat_dict: - graph['edge_feat_dict'][triplet] = np.repeat(edge_feat_dict[triplet][num_edge:num_edge + num_edge], 2, axis = 0) + graph['edge_feat_dict'][triplet] = np.repeat( + edge_feat_dict[triplet][num_edge:num_edge + num_edge], 2, axis=0) else: # if edge_feat is not given for any triplet @@ -419,11 +475,12 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ ### add additional edge feature for key, value in additional_edge_info.items(): if triplet in value: - graph[key][triplet] = np.repeat(value[triplet][num_edge_accum : num_edge_accum + num_edge], 2, axis = 0) + graph[key][triplet] = np.repeat(value[triplet][num_edge_accum: num_edge_accum + num_edge], 2, + axis=0) else: ### add edge_index - graph['edge_index_dict'][triplet] = edge[:, num_edge_accum:num_edge_accum+num_edge] + graph['edge_index_dict'][triplet] = edge[:, num_edge_accum:num_edge_accum + num_edge] ### add default edge feature if len(edge_feat_dict) > 0: @@ -438,7 +495,7 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ ### add additional edge feature for key, value in additional_edge_info.items(): if triplet in value: - graph[key][triplet] = value[triplet][num_edge_accum : num_edge_accum + num_edge] + graph[key][triplet] = value[triplet][num_edge_accum: num_edge_accum + num_edge] num_edge_accum_dict[triplet] += num_edge @@ 
-451,15 +508,16 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ if len(node_feat_dict) > 0: # if node_feat exists for some node type if nodetype in node_feat_dict: - graph['node_feat_dict'][nodetype] = node_feat_dict[nodetype][num_node_accum:num_node_accum + num_node] - + graph['node_feat_dict'][nodetype] = node_feat_dict[nodetype][ + num_node_accum:num_node_accum + num_node] + else: - graph['node_feat_dict'] = None + graph['node_feat_dict'] = None - ### add additional node feature + ### add additional node feature for key, value in additional_node_info.items(): if nodetype in value: - graph[key][nodetype] = value[nodetype][num_node_accum : num_node_accum + num_node] + graph[key][nodetype] = value[nodetype][num_node_accum: num_node_accum + num_node] graph['num_nodes_dict'][nodetype] = num_node num_node_accum_dict[nodetype] += num_node @@ -469,7 +527,7 @@ def read_csv_heterograph_raw(raw_dir, add_inverse_edge = False, additional_node_ return graph_list -def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): +def read_binary_heterograph_raw(raw_dir, add_inverse_edge=False): ''' raw_dir: path to the raw directory add_inverse_edge (bool): whether to add inverse edge or not @@ -484,16 +542,16 @@ def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): - node_feat_dict node_feat_dict[nodetype] = node_feat for nodetype - + - num_nodes_dict num_nodes_dict[nodetype] = num_nodes for nodetype - + * edge_feat_dict and node_feat_dict are optional: if a graph does not contain it, we will simply have None. We can also have additional node/edge features. For example, - edge_** - node_** - + ''' if add_inverse_edge: @@ -511,7 +569,7 @@ def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): tmp = read_npz_dict(osp.join(raw_dir, 'edge_index_dict.npz')) edge_index_dict = {tuple(key.split('___')): tmp[key] for key in tmp.keys()} del tmp - + ent_type_list = sorted(list(num_nodes_dict.keys())) triplet_type_list = sorted(list(num_edges_dict.keys())) @@ -546,7 +604,8 @@ def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): del tmp edge_feat_dict_dict[feat_name] = feat_dict else: - raise RuntimeError(f"Keys in graph object should start from either \'node_\' or \'edge_\', but found \'{feat_name}\'.") + raise RuntimeError( + f"Keys in graph object should start from either \'node_\' or \'edge_\', but found \'{feat_name}\'.") graph_list = [] num_nodes_accum_dict = {ent_type: 0 for ent_type in ent_type_list} @@ -580,13 +639,13 @@ def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): num_edges_accum = num_edges_accum_dict[triplet] ### add edge_index - graph['edge_index_dict'][triplet] = edge_index[:, num_edges_accum:num_edges_accum+num_edges] + graph['edge_index_dict'][triplet] = edge_index[:, num_edges_accum:num_edges_accum + num_edges] ### add edge feature for feat_name in edge_feat_dict_dict.keys(): - if triplet in edge_feat_dict_dict[feat_name]: + if triplet in edge_feat_dict_dict[feat_name]: feat = edge_feat_dict_dict[feat_name][triplet] - graph[feat_name][triplet] = feat[num_edges_accum : num_edges_accum + num_edges] + graph[feat_name][triplet] = feat[num_edges_accum: num_edges_accum + num_edges] num_edges_accum_dict[triplet] += num_edges @@ -599,7 +658,7 @@ def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): for feat_name in node_feat_dict_dict.keys(): if ent_type in node_feat_dict_dict[feat_name]: feat = node_feat_dict_dict[feat_name][ent_type] - graph[feat_name][ent_type] = feat[num_nodes_accum : 
num_nodes_accum + num_nodes] + graph[feat_name][ent_type] = feat[num_nodes_accum: num_nodes_accum + num_nodes] graph['num_nodes_dict'][ent_type] = num_nodes num_nodes_accum_dict[ent_type] += num_nodes @@ -608,6 +667,7 @@ def read_binary_heterograph_raw(raw_dir, add_inverse_edge = False): return graph_list + def read_npz_dict(path): tmp = np.load(path) dict = {} @@ -616,13 +676,15 @@ def read_npz_dict(path): del tmp return dict + def read_node_label_hetero(raw_dir): df = pd.read_csv(osp.join(raw_dir, 'nodetype-has-label.csv.gz')) label_dict = {} for nodetype in df.keys(): has_label = df[nodetype].values[0] if has_label: - label_dict[nodetype] = pd.read_csv(osp.join(raw_dir, 'node-label', nodetype, 'node-label.csv.gz'), compression='gzip', header = None).values + label_dict[nodetype] = pd.read_csv(osp.join(raw_dir, 'node-label', nodetype, 'node-label.csv.gz'), + compression='gzip', header=None).values if len(label_dict) == 0: raise RuntimeError('No node label file found.') @@ -638,15 +700,19 @@ def read_nodesplitidx_split_hetero(split_dir): for nodetype in df.keys(): has_label = df[nodetype].values[0] if has_label: - train_dict[nodetype] = pd.read_csv(osp.join(split_dir, nodetype, 'train.csv.gz'), compression='gzip', header = None).values.T[0] - valid_dict[nodetype] = pd.read_csv(osp.join(split_dir, nodetype, 'valid.csv.gz'), compression='gzip', header = None).values.T[0] - test_dict[nodetype] = pd.read_csv(osp.join(split_dir, nodetype, 'test.csv.gz'), compression='gzip', header = None).values.T[0] + train_dict[nodetype] = \ + pd.read_csv(osp.join(split_dir, nodetype, 'train.csv.gz'), compression='gzip', header=None).values.T[0] + valid_dict[nodetype] = \ + pd.read_csv(osp.join(split_dir, nodetype, 'valid.csv.gz'), compression='gzip', header=None).values.T[0] + test_dict[nodetype] = \ + pd.read_csv(osp.join(split_dir, nodetype, 'test.csv.gz'), compression='gzip', header=None).values.T[0] if len(train_dict) == 0: raise RuntimeError('No split file found.') return train_dict, valid_dict, test_dict + if __name__ == '__main__': pass diff --git a/gammagl/io/read_ogb_pyg.py b/gammagl/io/read_ogb_pyg.py deleted file mode 100644 index 052fb9fd..00000000 --- a/gammagl/io/read_ogb_pyg.py +++ /dev/null @@ -1,111 +0,0 @@ -import pandas as pd -import torch -from torch_geometric.data import Data -import os.path as osp -import numpy as np -from gammagl.data import Graph -from gammagl.io.read_ogb_raw import read_csv_graph_raw, read_csv_heterograph_raw, read_binary_graph_raw, read_binary_heterograph_raw -from tqdm import tqdm - -def read_graph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = [], binary = False): - - if binary: - # npz - graph_list = read_binary_graph_raw(raw_dir, add_inverse_edge) - else: - # csv - graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files) - - pyg_graph_list = [] - - print('Converting graphs into PyG objects...') - - for graph in tqdm(graph_list): - g = Graph() - g.num_nodes = graph['num_nodes'] - g.edge_index = graph['edge_index'] - - del graph['num_nodes'] - del graph['edge_index'] - - if graph['edge_feat'] is not None: - g.edge_attr = graph['edge_feat'] - del graph['edge_feat'] - - if graph['node_feat'] is not None: - g.x = graph['node_feat'] - del graph['node_feat'] - - for key in additional_node_files: - g[key] = graph[key] - del graph[key] - - for key in additional_edge_files: - g[key] = graph[key] - del graph[key] - - 
pyg_graph_list.append(g) - - return pyg_graph_list - - -def read_heterograph_pyg(raw_dir, add_inverse_edge = False, additional_node_files = [], additional_edge_files = [], binary = False): - - if binary: - # npz - graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge) - else: - # csv - graph_list = read_csv_heterograph_raw(raw_dir, add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files) - - pyg_graph_list = [] - - print('Converting graphs into PyG objects...') - - for graph in tqdm(graph_list): - g = Data() - - g.__num_nodes__ = graph['num_nodes_dict'] - g.num_nodes_dict = graph['num_nodes_dict'] - - # add edge connectivity - g.edge_index_dict = {} - for triplet, edge_index in graph['edge_index_dict'].items(): - g.edge_index_dict[triplet] = torch.from_numpy(edge_index) - - del graph['edge_index_dict'] - - if graph['edge_feat_dict'] is not None: - g.edge_attr_dict = {} - for triplet in graph['edge_feat_dict'].keys(): - g.edge_attr_dict[triplet] = torch.from_numpy(graph['edge_feat_dict'][triplet]) - - del graph['edge_feat_dict'] - - if graph['node_feat_dict'] is not None: - g.x_dict = {} - for nodetype in graph['node_feat_dict'].keys(): - g.x_dict[nodetype] = torch.from_numpy(graph['node_feat_dict'][nodetype]) - - del graph['node_feat_dict'] - - for key in additional_node_files: - g[key] = {} - for nodetype in graph[key].keys(): - g[key][nodetype] = torch.from_numpy(graph[key][nodetype]) - - del graph[key] - - for key in additional_edge_files: - g[key] = {} - for triplet in graph[key].keys(): - g[key][triplet] = torch.from_numpy(graph[key][triplet]) - - del graph[key] - - pyg_graph_list.append(g) - - return pyg_graph_list - -if __name__ == '__main__': - pass From 21f69790f6c0b40c6800af30d479241b8e0e884d Mon Sep 17 00:00:00 2001 From: yang_starry_sky Date: Tue, 30 Aug 2022 18:14:36 +0800 Subject: [PATCH 05/25] update --- gammagl/datasets/ogb_node.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index bb463c77..69854d7d 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -8,7 +8,7 @@ from read_ogb import read_node_label_hetero, read_nodesplitidx_split_hetero,read_graph, read_heterograph -class PygNodePropPredDataset(InMemoryDataset): +class OgbNodeDataset(InMemoryDataset): def __init__(self, name, root='dataset', transform=None, pre_transform=None, meta_dict=None): ''' - name (str): name of the dataset @@ -65,7 +65,7 @@ def __init__(self, name, root='dataset', transform=None, pre_transform=None, met self.is_hetero = self.meta_info['is hetero'] == 'True' self.binary = self.meta_info['binary'] == 'True' - super(PygNodePropPredDataset, self).__init__(self.root, transform, pre_transform) + super(OgbNodeDataset, self).__init__(self.root, transform, pre_transform) self.data, self.slices = self.load_data(self.processed_paths[0]) @property @@ -158,8 +158,6 @@ def __repr__(self): if __name__ == '__main__': - pyg_dataset = PygNodePropPredDataset(name='ogbn-mag') - print(pyg_dataset[0]) - split_index = pyg_dataset.get_idx_split() - # print(split_index) + data = OgbNodeDataset(name='ogbn-arxiv') + print(data[0]) From 09321cfd7f6a05c77907910c8e070d29f6a22c7e Mon Sep 17 00:00:00 2001 From: yang_starry_sky Date: Tue, 30 Aug 2022 18:16:58 +0800 Subject: [PATCH 06/25] update --- gammagl/datasets/ogb_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gammagl/datasets/ogb_node.py 
b/gammagl/datasets/ogb_node.py index 69854d7d..6ae80bbd 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -124,7 +124,7 @@ def process(self): pass else: data = \ - read_graph_pyg(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, + read_graph(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files, binary=self.binary)[0] ### adding prediction target if self.binary: From bd694bccb57509d244030cfaa41f256d358f8d21 Mon Sep 17 00:00:00 2001 From: yang_starry_sky Date: Thu, 1 Sep 2022 15:29:17 +0800 Subject: [PATCH 07/25] support heterograph --- gammagl/datasets/ogb_node.py | 28 ++++++++++++++++- gammagl/io/read_ogb.py | 60 ++++++++++++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index 6ae80bbd..ae764852 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -121,7 +121,33 @@ def process(self): additional_edge_files = self.meta_info['additional edge files'].split(',') if self.is_hetero: - pass + data = read_heterograph(self.raw_dir, add_inverse_edge=add_inverse_edge, + additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files, binary=self.binary)[0] + + if self.binary: + tmp = np.load(osp.join(self.raw_dir, 'node-label.npz')) + node_label_dict = {} + for key in list(tmp.keys()): + node_label_dict[key] = tmp[key] + del tmp + else: + node_label_dict = read_node_label_hetero(self.raw_dir) + + data.y_dict = {} + if 'classification' in self.task_type: + for nodetype, node_label in node_label_dict.items(): + # detect if there is any nan + ''' + if np.isnan(node_label).any(): + data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.float32) + else: + data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.long) + ''' + data.y_dict[nodetype] = node_label + else: + for nodetype, node_label in node_label_dict.items(): + data.y_dict[nodetype] = node_label else: data = \ read_graph(self.raw_dir, add_inverse_edge=add_inverse_edge, additional_node_files=additional_node_files, diff --git a/gammagl/io/read_ogb.py b/gammagl/io/read_ogb.py index 106cfd49..71b6f566 100644 --- a/gammagl/io/read_ogb.py +++ b/gammagl/io/read_ogb.py @@ -48,8 +48,64 @@ def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additi return pyg_graph_list -def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], binary=False): - pass +def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], + binary=False): + if binary: + # npz + graph_list = read_binary_heterograph_raw(raw_dir, add_inverse_edge) + else: + # csv + graph_list = read_csv_heterograph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, + additional_edge_files=additional_edge_files) + + pyg_graph_list = [] + + print('Converting graphs into PyG objects...') + + for graph in tqdm(graph_list): + g = HeteroGraph() + + g.__num_nodes__ = graph['num_nodes_dict'] + g.num_nodes_dict = graph['num_nodes_dict'] + + # add edge connectivity + g.edge_index_dict = {} + for triplet, edge_index in graph['edge_index_dict'].items(): + g.edge_index_dict[triplet] = edge_index + + del graph['edge_index_dict'] + + if graph['edge_feat_dict'] is not None: + g.edge_attr_dict = {} + for triplet in graph['edge_feat_dict'].keys(): + 
g.edge_attr_dict[triplet] = graph['edge_feat_dict'][triplet] + + del graph['edge_feat_dict'] + + if graph['node_feat_dict'] is not None: + g.x_dict = {} + for nodetype in graph['node_feat_dict'].keys(): + g.x_dict[nodetype] = graph['node_feat_dict'][nodetype] + + del graph['node_feat_dict'] + + for key in additional_node_files: + g[key] = {} + for nodetype in graph[key].keys(): + g[key][nodetype] = graph[key][nodetype] + + del graph[key] + + for key in additional_edge_files: + g[key] = {} + for triplet in graph[key].keys(): + g[key][triplet] = graph[key][triplet] + + del graph[key] + + pyg_graph_list.append(g) + + return pyg_graph_list ### reading raw files from a directory. From ef7e5d068ee111c2e141edce716deb938b20184c Mon Sep 17 00:00:00 2001 From: yang_starry_sky Date: Mon, 5 Sep 2022 22:34:34 +0800 Subject: [PATCH 08/25] support ogb graph dataset --- gammagl/datasets/OgbGraphData.csv | 16 ++ .../datasets/{master.csv => OgbNodeData.csv} | 0 gammagl/datasets/ogb_graph.py | 161 ++++++++++++++++++ gammagl/datasets/ogb_node.py | 5 +- 4 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 gammagl/datasets/OgbGraphData.csv rename gammagl/datasets/{master.csv => OgbNodeData.csv} (100%) create mode 100644 gammagl/datasets/ogb_graph.py diff --git a/gammagl/datasets/OgbGraphData.csv b/gammagl/datasets/OgbGraphData.csv new file mode 100644 index 00000000..467d0626 --- /dev/null +++ b/gammagl/datasets/OgbGraphData.csv @@ -0,0 +1,16 @@ +,ogbg-molbace,ogbg-molbbbp,ogbg-molclintox,ogbg-molmuv,ogbg-molpcba,ogbg-molsider,ogbg-moltox21,ogbg-moltoxcast,ogbg-molhiv,ogbg-molesol,ogbg-molfreesolv,ogbg-mollipo,ogbg-molchembl,ogbg-ppa,ogbg-code2 +num tasks,1,1,2,17,128,27,12,617,1,1,1,1,1310,1,1 +eval metric,rocauc,rocauc,rocauc,ap,ap,rocauc,rocauc,rocauc,rocauc,rmse,rmse,rmse,rocauc,acc,F1 +download_name,bace,bbbp,clintox,muv,pcba,sider,tox21,toxcast,hiv,esol,freesolv,lipophilicity,chembl,ogbg_ppi_medium,code2 +version,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 +url,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/bace.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/bbbp.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/clintox.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/muv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/pcba.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/sider.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/tox21.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/toxcast.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/esol.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/freesolv.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/lipophilicity.zip,http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/chembl.zip,http://snap.stanford.edu/ogb/data/graphproppred/ogbg_ppi_medium.zip,http://snap.stanford.edu/ogb/data/graphproppred/code2.zip +add_inverse_edge,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False +data type,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,mol,, +has_node_attr,True,True,True,True,True,True,True,True,True,True,True,True,True,False,True +has_edge_attr,True,True,True,True,True,True,True,True,True,True,True,True,True,True,False +task type,binary classification,binary classification,binary classification,binary classification,binary 
classification,binary classification,binary classification,binary classification,binary classification,regression,regression,regression,binary classification,multiclass classification,subtoken prediction +num classes,2,2,2,2,2,2,2,2,2,-1,-1,-1,2,37,-1 +split,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,scaffold,species,project +additional node files,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"node_is_attributed,node_dfs_order,node_depth" +additional edge files,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None +binary,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False diff --git a/gammagl/datasets/master.csv b/gammagl/datasets/OgbNodeData.csv similarity index 100% rename from gammagl/datasets/master.csv rename to gammagl/datasets/OgbNodeData.csv diff --git a/gammagl/datasets/ogb_graph.py b/gammagl/datasets/ogb_graph.py new file mode 100644 index 00000000..0ea09c07 --- /dev/null +++ b/gammagl/datasets/ogb_graph.py @@ -0,0 +1,161 @@ +import pandas as pd +import shutil, os +import os.path as osp +import numpy as np +from gammagl.data import InMemoryDataset +from gammagl.utils.url import decide_download, download_url, extract_zip +from gammagl.io.read_ogb import read_graph + + +class OgbGraphDataset(InMemoryDataset): + def __init__(self, name, root = 'dataset', transform=None, pre_transform = None, meta_dict = None): + ''' + - name (str): name of the dataset + - root (str): root directory to store the dataset folder + - transform, pre_transform (optional): transform/pre-transform graph objects + + - meta_dict: dictionary that stores all the meta-information about data. Default is None, + but when something is passed, it uses its information. Useful for debugging for external contributors. + ''' + + self.name = name ## original name, e.g., ogbg-molhiv + + if meta_dict is None: + self.dir_name = '_'.join(name.split('-')) + + # check if previously-downloaded folder exists. + # If so, use that one. + if osp.exists(osp.join(root, self.dir_name + '_pyg')): + self.dir_name = self.dir_name + '_pyg' + + self.original_root = root + self.root = osp.join(root, self.dir_name) + + master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'OgbGraphData.csv'), index_col = 0) + if not self.name in master: + error_mssg = 'Invalid dataset name {}.\n'.format(self.name) + error_mssg += 'Available datasets are as follows:\n' + error_mssg += '\n'.join(master.keys()) + raise ValueError(error_mssg) + self.meta_info = master[self.name] + + else: + self.dir_name = meta_dict['dir_path'] + self.original_root = '' + self.root = meta_dict['dir_path'] + self.meta_info = meta_dict + + # check version + # First check whether the dataset has been already downloaded or not. + # If so, check whether the dataset version is the newest or not. + # If the dataset is not the newest version, notify this to the user. + if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))): + print(self.name + ' has been updated.') + if input('Will you update the dataset now? 
(y/N)\n').lower() == 'y': + shutil.rmtree(self.root) + + self.download_name = self.meta_info['download_name'] ## name of downloaded file, e.g., tox21 + + self.num_tasks = int(self.meta_info['num tasks']) + self.eval_metric = self.meta_info['eval metric'] + self.task_type = self.meta_info['task type'] + self.__num_classes__ = int(self.meta_info['num classes']) + self.binary = self.meta_info['binary'] == 'True' + + super(PygGraphPropPredDataset, self).__init__(self.root, transform, pre_transform) + + self.data, self.slices = self.load_data(self.processed_paths[0]) + + def get_idx_split(self, split_type = None): + if split_type is None: + split_type = self.meta_info['split'] + + path = osp.join(self.root, 'split', split_type) + + # short-cut if split_dict.pt exists + if os.path.isfile(os.path.join(path, 'split_dict.pt')): + return self.load_data(os.path.join(path, 'split_dict.pt')) + + train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header = None).values.T[0] + valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header = None).values.T[0] + test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header = None).values.T[0] + + return {'train': train_idx, 'valid': valid_idx, 'test': test_idx} + + @property + def num_classes(self): + return self.__num_classes__ + + @property + def raw_file_names(self): + if self.binary: + return ['data.npz'] + else: + file_names = ['edge'] + if self.meta_info['has_node_attr'] == 'True': + file_names.append('node-feat') + if self.meta_info['has_edge_attr'] == 'True': + file_names.append('edge-feat') + return [file_name + '.csv.gz' for file_name in file_names] + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def download(self): + url = self.meta_info['url'] + if decide_download(url): + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + shutil.rmtree(self.root) + shutil.move(osp.join(self.original_root, self.download_name), self.root) + + else: + print('Stop downloading.') + shutil.rmtree(self.root) + exit(-1) + + def process(self): + ### read pyg graph list + add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' + + if self.meta_info['additional node files'] == 'None': + additional_node_files = [] + else: + additional_node_files = self.meta_info['additional node files'].split(',') + + if self.meta_info['additional edge files'] == 'None': + additional_edge_files = [] + else: + additional_edge_files = self.meta_info['additional edge files'].split(',') + + data_list = read_graph(self.raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=self.binary) + + if self.task_type == 'subtoken prediction': + graph_label_notparsed = pd.read_csv(osp.join(self.raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values + graph_label = [str(graph_label_notparsed[i][0]).split(' ') for i in range(len(graph_label_notparsed))] + + for i, g in enumerate(data_list): + g.y = graph_label[i] + + else: + if self.binary: + graph_label = np.load(osp.join(self.raw_dir, 'graph-label.npz'))['graph_label'] + else: + graph_label = pd.read_csv(osp.join(self.raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values + + has_nan = np.isnan(graph_label).any() + + for i, g in enumerate(data_list): + g.y = graph_label[i] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in 
data_list] + + data, slices = self.collate(data_list) + + print('Saving...') + self.save_data((data, slices), self.processed_paths[0]) + + diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index ae764852..2bbec61c 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -2,10 +2,9 @@ import shutil, os import os.path as osp import numpy as np -# from gammagl.data import Graph from gammagl.data import InMemoryDataset from gammagl.utils.ogb_url import decide_download, download_url, extract_zip -from read_ogb import read_node_label_hetero, read_nodesplitidx_split_hetero,read_graph, read_heterograph +from gammagl.io.read_ogb import read_node_label_hetero, read_nodesplitidx_split_hetero,read_graph, read_heterograph @@ -32,7 +31,7 @@ def __init__(self, name, root='dataset', transform=None, pre_transform=None, met self.original_root = root self.root = osp.join(root, self.dir_name) - master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'master.csv'), index_col=0) + master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'OgbNodeData.csv'), index_col=0) if not self.name in master: error_mssg = 'Invalid dataset name {}.\n'.format(self.name) error_mssg += 'Available datasets are as follows:\n' From ec483a72aaf9cff8aa66e5874ea51f18e8697748 Mon Sep 17 00:00:00 2001 From: yang_starry_sky Date: Fri, 9 Sep 2022 11:23:40 +0800 Subject: [PATCH 09/25] support ogb link dataset --- gammagl/datasets/OgbLinkData.csv | 14 +++ gammagl/datasets/ogb_graph.py | 2 +- gammagl/datasets/ogb_link.py | 142 +++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 gammagl/datasets/OgbLinkData.csv create mode 100644 gammagl/datasets/ogb_link.py diff --git a/gammagl/datasets/OgbLinkData.csv b/gammagl/datasets/OgbLinkData.csv new file mode 100644 index 00000000..02597183 --- /dev/null +++ b/gammagl/datasets/OgbLinkData.csv @@ -0,0 +1,14 @@ +,ogbl-ppa,ogbl-collab,ogbl-citation2,ogbl-wikikg2,ogbl-ddi,ogbl-biokg,ogbl-vessel +eval metric,hits@100,hits@50,mrr,mrr,hits@20,mrr,rocauc +task type,link prediction,link prediction,link prediction,KG completion,link prediction,KG completion,link prediction +download_name,ppassoc,collab,citation-v2,wikikg-v2,ddi,biokg,vessel +version,1,1,1,1,1,1,1 +url,http://snap.stanford.edu/ogb/data/linkproppred/ppassoc.zip,http://snap.stanford.edu/ogb/data/linkproppred/collab.zip,http://snap.stanford.edu/ogb/data/linkproppred/citation-v2.zip,http://snap.stanford.edu/ogb/data/linkproppred/wikikg-v2.zip,http://snap.stanford.edu/ogb/data/linkproppred/ddi.zip,http://snap.stanford.edu/ogb/data/linkproppred/biokg.zip,http://snap.stanford.edu/ogb/data/linkproppred/vessel.zip +add_inverse_edge,True,True,False,False,True,False,False +has_node_attr,True,True,True,False,False,False,True +has_edge_attr,False,False,False,False,False,False,True +split,throughput,time,time,time,target,random,spatial +additional node files,None,None,node_year,None,None,None,None +additional edge files,None,"edge_weight,edge_year",None,edge_reltype,None,edge_reltype,None +is hetero,False,False,False,False,False,True,False +binary,False,False,False,False,False,False,True diff --git a/gammagl/datasets/ogb_graph.py b/gammagl/datasets/ogb_graph.py index 0ea09c07..3d20f183 100644 --- a/gammagl/datasets/ogb_graph.py +++ b/gammagl/datasets/ogb_graph.py @@ -3,7 +3,7 @@ import os.path as osp import numpy as np from gammagl.data import InMemoryDataset -from gammagl.utils.url import decide_download, download_url, extract_zip +from gammagl.utils.ogb_url import decide_download, download_url, extract_zip from gammagl.io.read_ogb import read_graph diff --git a/gammagl/datasets/ogb_link.py b/gammagl/datasets/ogb_link.py new file mode 100644 index 00000000..cb4f083c --- /dev/null +++ b/gammagl/datasets/ogb_link.py @@ -0,0 +1,142 @@ +import pandas as pd +import shutil, os +import os.path as osp +import numpy as np +from gammagl.data import InMemoryDataset +from gammagl.utils.ogb_url import decide_download, download_url, extract_zip +from gammagl.io.read_ogb import read_graph, read_heterograph + +class OgbLinkDataset(InMemoryDataset): + def __init__(self, name, root = 'dataset', transform=None, pre_transform=None, meta_dict = None): + ''' + - name (str): name of the dataset + - root (str): root directory to store the dataset folder + + - meta_dict: dictionary that stores all the meta-information about data. Default is None, + but when something is passed, it uses its information. Useful for debugging for external contributors. + ''' + + self.name = name ## original name, e.g., ogbl-ppa + + if meta_dict is None: + self.dir_name = '_'.join(name.split('-')) + + # check if previously-downloaded folder exists. + # If so, use that one. + if osp.exists(osp.join(root, self.dir_name + '_pyg')): + self.dir_name = self.dir_name + '_pyg' + + self.original_root = root + self.root = osp.join(root, self.dir_name) + + master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'OgbLinkData.csv'), index_col = 0) + if not self.name in master: + error_mssg = 'Invalid dataset name {}.\n'.format(self.name) + error_mssg += 'Available datasets are as follows:\n' + error_mssg += '\n'.join(master.keys()) + raise ValueError(error_mssg) + self.meta_info = master[self.name] + + else: + self.dir_name = meta_dict['dir_path'] + self.original_root = '' + self.root = meta_dict['dir_path'] + self.meta_info = meta_dict + + # check version + # First check whether the dataset has been already downloaded or not. + # If so, check whether the dataset version is the newest or not. + # If the dataset is not the newest version, notify this to the user. + if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))): + print(self.name + ' has been updated.') + if input('Will you update the dataset now? 
(y/N)\n').lower() == 'y': + shutil.rmtree(self.root) + + self.download_name = self.meta_info['download_name'] ## name of downloaded file, e.g., ppassoc + + self.task_type = self.meta_info['task type'] + self.eval_metric = self.meta_info['eval metric'] + self.is_hetero = self.meta_info['is hetero'] == 'True' + self.binary = self.meta_info['binary'] == 'True' + + super(OgbLinkDataset, self).__init__(self.root, transform, pre_transform) + self.data, self.slices = self.load_data(self.processed_paths[0]) + + def get_edge_split(self, split_type = None): + if split_type is None: + split_type = self.meta_info['split'] + + path = osp.join(self.root, 'split', split_type) + + # short-cut if split_dict.pt exists + if os.path.isfile(os.path.join(path, 'split_dict.pt')): + return self.load_data(os.path.join(path, 'split_dict.pt')) + + train = self.load_data(osp.join(path, 'train.pt')) + valid = self.load_data(osp.join(path, 'valid.pt')) + test = self.load_data(osp.join(path, 'test.pt')) + + return {'train': train, 'valid': valid, 'test': test} + + @property + def raw_file_names(self): + if self.binary: + if self.is_hetero: + return ['edge_index_dict.npz'] + else: + return ['data.npz'] + else: + if self.is_hetero: + return ['num-node-dict.csv.gz', 'triplet-type-list.csv.gz'] + else: + file_names = ['edge'] + if self.meta_info['has_node_attr'] == 'True': + file_names.append('node-feat') + if self.meta_info['has_edge_attr'] == 'True': + file_names.append('edge-feat') + return [file_name + '.csv.gz' for file_name in file_names] + + @property + def processed_file_names(self): + return osp.join('geometric_data_processed.pt') + + def download(self): + url = self.meta_info['url'] + if decide_download(url): + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + shutil.rmtree(self.root) + shutil.move(osp.join(self.original_root, self.download_name), self.root) + else: + print('Stop downloading.') + shutil.rmtree(self.root) + exit(-1) + + def process(self): + add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' + + if self.meta_info['additional node files'] == 'None': + additional_node_files = [] + else: + additional_node_files = self.meta_info['additional node files'].split(',') + + if self.meta_info['additional edge files'] == 'None': + additional_edge_files = [] + else: + additional_edge_files = self.meta_info['additional edge files'].split(',') + + if self.is_hetero: + data = read_heterograph(self.raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=self.binary)[0] + else: + data = read_graph(self.raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files, binary=self.binary)[0] + + data = data if self.pre_transform is None else self.pre_transform(data) + + print('Saving...') + self.save_data(self.collate([data]), self.processed_paths[0]) + + def __repr__(self): + return '{}()'.format(self.__class__.__name__) + + From f987fa4ebefd64711d8e1c8d6eefe52bcbd1bd76 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Thu, 15 Sep 2022 19:27:19 +0800 Subject: [PATCH 10/25] add test file --- gammagl/datasets/ogb_graph.py | 3 +-- tests/datasets/test_ogb_graph.py | 8 ++++++++ tests/datasets/test_ogb_link.py | 7 +++++++ tests/datasets/test_ogb_node.py | 7 +++++++ 4 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/datasets/test_ogb_graph.py create mode 100644 
tests/datasets/test_ogb_link.py create mode 100644 tests/datasets/test_ogb_node.py diff --git a/gammagl/datasets/ogb_graph.py b/gammagl/datasets/ogb_graph.py index 3d20f183..16202f65 100644 --- a/gammagl/datasets/ogb_graph.py +++ b/gammagl/datasets/ogb_graph.py @@ -62,7 +62,7 @@ def __init__(self, name, root = 'dataset', transform=None, pre_transform = None, self.__num_classes__ = int(self.meta_info['num classes']) self.binary = self.meta_info['binary'] == 'True' - super(PygGraphPropPredDataset, self).__init__(self.root, transform, pre_transform) + super(OgbGraphDataset, self).__init__(self.root, transform, pre_transform) self.data, self.slices = self.load_data(self.processed_paths[0]) @@ -117,7 +117,6 @@ def download(self): exit(-1) def process(self): - ### read pyg graph list add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' if self.meta_info['additional node files'] == 'None': diff --git a/tests/datasets/test_ogb_graph.py b/tests/datasets/test_ogb_graph.py new file mode 100644 index 00000000..0afd592f --- /dev/null +++ b/tests/datasets/test_ogb_graph.py @@ -0,0 +1,8 @@ +from gammagl.datasets.ogb_graph import OgbGraphDataset + +def test(): + data=OgbGraphDataset('ogbg-molhiv') + print(data) + print(data[0]) + +test() \ No newline at end of file diff --git a/tests/datasets/test_ogb_link.py b/tests/datasets/test_ogb_link.py new file mode 100644 index 00000000..b76c317b --- /dev/null +++ b/tests/datasets/test_ogb_link.py @@ -0,0 +1,7 @@ +from gammagl.datasets.ogb_link import OgbLinkDataset + +def test(): + data=OgbLinkDataset('ogbl-ppa') + print(data[0]) + +test() \ No newline at end of file diff --git a/tests/datasets/test_ogb_node.py b/tests/datasets/test_ogb_node.py new file mode 100644 index 00000000..a0544417 --- /dev/null +++ b/tests/datasets/test_ogb_node.py @@ -0,0 +1,7 @@ +from gammagl.datasets.ogb_node import OgbNodeDataset + +def test(): + data=OgbNodeDataset('ogbn-arxiv') + print(data[0]) + +test() \ No newline at end of file From 974f0688f6e2ba25c9342203a61bc597b171b179 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Thu, 15 Sep 2022 19:48:17 +0800 Subject: [PATCH 11/25] del pyg --- gammagl/io/read_ogb.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/gammagl/io/read_ogb.py b/gammagl/io/read_ogb.py index 71b6f566..8ef686af 100644 --- a/gammagl/io/read_ogb.py +++ b/gammagl/io/read_ogb.py @@ -15,9 +15,7 @@ def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additi graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files) - pyg_graph_list = [] - - print('Converting graphs into PyG objects...') + graph_list = [] for graph in tqdm(graph_list): g = Graph() @@ -43,9 +41,9 @@ def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additi g[key] = graph[key] del graph[key] - pyg_graph_list.append(g) + graph_list.append(g) - return pyg_graph_list + return graph_list def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], @@ -58,9 +56,7 @@ def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], graph_list = read_csv_heterograph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files) - pyg_graph_list = [] - - print('Converting graphs into PyG objects...') + graph_list = [] for graph in tqdm(graph_list): g = HeteroGraph() @@ -103,9 +99,9 @@ def read_heterograph(raw_dir, 
add_inverse_edge=False, additional_node_files=[], del graph[key] - pyg_graph_list.append(g) + graph_list.append(g) - return pyg_graph_list + return graph_list ### reading raw files from a directory. From dafdd1b4d2ec16f80023f17e559b752e12ed9aaf Mon Sep 17 00:00:00 2001 From: starry_sky Date: Thu, 15 Sep 2022 21:44:05 +0800 Subject: [PATCH 12/25] Update ogb_node.py --- gammagl/datasets/ogb_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index 2bbec61c..8870692b 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -25,8 +25,8 @@ def __init__(self, name, root='dataset', transform=None, pre_transform=None, met # check if previously-downloaded folder exists. # If so, use that one. - if osp.exists(osp.join(root, self.dir_name + '_pyg')): - self.dir_name = self.dir_name + '_pyg' + if osp.exists(osp.join(root, self.dir_name + '_gammagl')): + self.dir_name = self.dir_name + '_gammagl' self.original_root = root self.root = osp.join(root, self.dir_name) From f1aa221f77e46a32d186addb99b73ba527c241a7 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Fri, 16 Sep 2022 09:53:13 +0800 Subject: [PATCH 13/25] change pyg to gammagl --- gammagl/datasets/ogb_graph.py | 4 ++-- gammagl/datasets/ogb_link.py | 4 ++-- gammagl/datasets/ogb_node.py | 18 ------------------ 3 files changed, 4 insertions(+), 22 deletions(-) diff --git a/gammagl/datasets/ogb_graph.py b/gammagl/datasets/ogb_graph.py index 16202f65..86875077 100644 --- a/gammagl/datasets/ogb_graph.py +++ b/gammagl/datasets/ogb_graph.py @@ -25,8 +25,8 @@ def __init__(self, name, root = 'dataset', transform=None, pre_transform = None, # check if previously-downloaded folder exists. # If so, use that one. - if osp.exists(osp.join(root, self.dir_name + '_pyg')): - self.dir_name = self.dir_name + '_pyg' + if osp.exists(osp.join(root, self.dir_name + '_gammagl')): + self.dir_name = self.dir_name + '_gammagl' self.original_root = root self.root = osp.join(root, self.dir_name) diff --git a/gammagl/datasets/ogb_link.py b/gammagl/datasets/ogb_link.py index cb4f083c..1693428b 100644 --- a/gammagl/datasets/ogb_link.py +++ b/gammagl/datasets/ogb_link.py @@ -23,8 +23,8 @@ def __init__(self, name, root = 'dataset', transform=None, pre_transform=None, m # check if previously-downloaded folder exists. # If so, use that one. 
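# Editorial sketch, not part of the patch series: the effect of the rename below.
# From this patch on, the loaders look for (and reuse) a previously processed
# '<dir_name>_gammagl' folder instead of the old '_pyg' one. The root and dataset
# name used here are hypothetical examples.
import os.path as osp

root, dir_name = 'dataset', 'ogbl_ppa'
if osp.exists(osp.join(root, dir_name + '_gammagl')):   # left behind by an earlier run
    dir_name = dir_name + '_gammagl'                    # reuse it instead of re-processing
print(osp.join(root, dir_name))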
- if osp.exists(osp.join(root, self.dir_name + '_pyg')): - self.dir_name = self.dir_name + '_pyg' + if osp.exists(osp.join(root, self.dir_name + '_gammagl')): + self.dir_name = self.dir_name + '_gammagl' self.original_root = root self.root = osp.join(root, self.dir_name) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index 8870692b..843ab0f0 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -136,13 +136,6 @@ def process(self): data.y_dict = {} if 'classification' in self.task_type: for nodetype, node_label in node_label_dict.items(): - # detect if there is any nan - ''' - if np.isnan(node_label).any(): - data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.float32) - else: - data.y_dict[nodetype] = torch.from_numpy(node_label).to(torch.long) - ''' data.y_dict[nodetype] = node_label else: for nodetype, node_label in node_label_dict.items(): @@ -157,17 +150,6 @@ def process(self): else: node_label = pd.read_csv(osp.join(self.raw_dir, 'node-label.csv.gz'), compression='gzip', header=None).values - ''' - if 'classification' in self.task_type: - # detect if there is any nan - if np.isnan(node_label).any(): - data.y = torch.from_numpy(node_label).to(torch.float32) - else: - data.y = torch.from_numpy(node_label).to(torch.long) - - else: - data.y = torch.from_numpy(node_label).to(torch.float32) - ''' data.y = node_label data = data if self.pre_transform is None else self.pre_transform(data) self.data = data From c0ddf813a4dc2db5a969f38f2d74baa64af45204 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Sat, 17 Sep 2022 13:13:38 +0800 Subject: [PATCH 14/25] do not use ogb_url --- gammagl/datasets/ogb_graph.py | 18 +++---- gammagl/datasets/ogb_link.py | 19 +++----- gammagl/datasets/ogb_node.py | 21 ++++---- gammagl/io/read_ogb.py | 5 +- gammagl/utils/ogb_url.py | 91 ----------------------------------- 5 files changed, 27 insertions(+), 127 deletions(-) delete mode 100644 gammagl/utils/ogb_url.py diff --git a/gammagl/datasets/ogb_graph.py b/gammagl/datasets/ogb_graph.py index 86875077..ea6f8f2a 100644 --- a/gammagl/datasets/ogb_graph.py +++ b/gammagl/datasets/ogb_graph.py @@ -3,7 +3,8 @@ import os.path as osp import numpy as np from gammagl.data import InMemoryDataset -from gammgl.utils.ogb_url import decide_download, download_url, extract_zip +from gammagl.data.download import download_url +from gammagl.data.extract import extract_zip from gammagl.io.read_ogb import read_graph @@ -104,17 +105,12 @@ def processed_file_names(self): def download(self): url = self.meta_info['url'] - if decide_download(url): - path = download_url(url, self.original_root) - extract_zip(path, self.original_root) - os.unlink(path) - shutil.rmtree(self.root) - shutil.move(osp.join(self.original_root, self.download_name), self.root) + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + shutil.rmtree(self.root) + shutil.move(osp.join(self.original_root, self.download_name), self.root) - else: - print('Stop downloading.') - shutil.rmtree(self.root) - exit(-1) def process(self): add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' diff --git a/gammagl/datasets/ogb_link.py b/gammagl/datasets/ogb_link.py index 1693428b..60226a56 100644 --- a/gammagl/datasets/ogb_link.py +++ b/gammagl/datasets/ogb_link.py @@ -3,7 +3,8 @@ import os.path as osp import numpy as np from gammagl.data import InMemoryDataset -from gammgl.utils.ogb_url import decide_download, download_url, extract_zip +from gammagl.data.download 
import download_url +from gammagl.data.extract import extract_zip from gammagl.io.read_ogb import read_graph, read_heterograph class OgbLinkDataset(InMemoryDataset): @@ -102,16 +103,12 @@ def processed_file_names(self): def download(self): url = self.meta_info['url'] - if decide_download(url): - path = download_url(url, self.original_root) - extract_zip(path, self.original_root) - os.unlink(path) - shutil.rmtree(self.root) - shutil.move(osp.join(self.original_root, self.download_name), self.root) - else: - print('Stop downloading.') - shutil.rmtree(self.root) - exit(-1) + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + shutil.rmtree(self.root) + shutil.move(osp.join(self.original_root, self.download_name), self.root) + def process(self): add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index 843ab0f0..174339ad 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -3,8 +3,9 @@ import os.path as osp import numpy as np from gammagl.data import InMemoryDataset -from gammgl.utils.ogb_url import decide_download, download_url, extract_zip -from gammagl.io.read_ogb import read_node_label_hetero, read_nodesplitidx_split_hetero,read_graph, read_heterograph +from gammagl.data.download import download_url +from gammagl.data.extract import extract_zip +from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph class OgbNodeDataset(InMemoryDataset): @@ -95,16 +96,12 @@ def processed_file_names(self): def download(self): url = self.meta_info['url'] - if decide_download(url): - path = download_url(url, self.original_root) - extract_zip(path, self.original_root) - os.unlink(path) - shutil.rmtree(self.root) - shutil.move(osp.join(self.original_root, self.download_name), self.root) - else: - print('Stop downloading.') - shutil.rmtree(self.root) - exit(-1) + path = download_url(url, self.original_root) + extract_zip(path, self.original_root) + os.unlink(path) + shutil.rmtree(self.root) + shutil.move(osp.join(self.original_root, self.download_name), self.root) + def process(self): add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True' diff --git a/gammagl/io/read_ogb.py b/gammagl/io/read_ogb.py index 8ef686af..6cdaad23 100644 --- a/gammagl/io/read_ogb.py +++ b/gammagl/io/read_ogb.py @@ -2,9 +2,10 @@ import os.path as osp import os import numpy as np -from gammagl.utils.ogb_url import decide_download, download_url, extract_zip +from gammagl.data.download import download_url +from gammagl.data.extract import extract_zip from tqdm import tqdm -from gammagl.data import Graph +from gammagl.data import Graph,HeteroGraph def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], binary=False): if binary: diff --git a/gammagl/utils/ogb_url.py b/gammagl/utils/ogb_url.py deleted file mode 100644 index e16b01df..00000000 --- a/gammagl/utils/ogb_url.py +++ /dev/null @@ -1,91 +0,0 @@ -import urllib.request as ur -import zipfile -import os -import os.path as osp -from six.moves import urllib -import errno -from tqdm import tqdm - -GBFACTOR = float(1 << 30) - -def decide_download(url): - d = ur.urlopen(url) - size = int(d.info()["Content-Length"])/GBFACTOR - - ### confirm if larger than 1GB - if size > 1: - return input("This will download %.2fGB. Will you proceed? 
(y/N)\n" % (size)).lower() == "y" - else: - return True - -def makedirs(path): - try: - os.makedirs(osp.expanduser(osp.normpath(path))) - except OSError as e: - if e.errno != errno.EEXIST and osp.isdir(path): - raise e - -def download_url(url, folder, log=True): - r"""Downloads the content of an URL to a specific folder. - Args: - url (string): The url. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ - - filename = url.rpartition('/')[2] - path = osp.join(folder, filename) - - if osp.exists(path) and osp.getsize(path) > 0: # pragma: no cover - if log: - print('Using exist file', filename) - return path - - if log: - print('Downloading', url) - - makedirs(folder) - data = ur.urlopen(url) - - size = int(data.info()["Content-Length"]) - - chunk_size = 1024*1024 - num_iter = int(size/chunk_size) + 2 - - downloaded_size = 0 - - try: - with open(path, 'wb') as f: - pbar = tqdm(range(num_iter)) - for i in pbar: - chunk = data.read(chunk_size) - downloaded_size += len(chunk) - pbar.set_description("Downloaded {:.2f} GB".format(float(downloaded_size)/GBFACTOR)) - f.write(chunk) - except: - if os.path.exists(path): - os.remove(path) - raise RuntimeError('Stopped downloading due to interruption.') - - - return path - -def maybe_log(path, log=True): - if log: - print('Extracting', path) - -def extract_zip(path, folder, log=True): - r"""Extracts a zip archive to a specific folder. - Args: - path (string): The path to the tar archive. - folder (string): The folder. - log (bool, optional): If :obj:`False`, will not print anything to the - console. (default: :obj:`True`) - """ - maybe_log(path, log) - with zipfile.ZipFile(path, 'r') as f: - f.extractall(folder) - -if __name__ == "__main__": - pass \ No newline at end of file From 6902fe7761e2963e51afd7883802098712af7264 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Tue, 20 Sep 2022 10:47:59 +0800 Subject: [PATCH 15/25] Update read_ogb.py --- gammagl/io/read_ogb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gammagl/io/read_ogb.py b/gammagl/io/read_ogb.py index 6cdaad23..c7b7bc37 100644 --- a/gammagl/io/read_ogb.py +++ b/gammagl/io/read_ogb.py @@ -16,7 +16,7 @@ def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additi graph_list = read_csv_graph_raw(raw_dir, add_inverse_edge, additional_node_files=additional_node_files, additional_edge_files=additional_edge_files) - graph_list = [] + result_list = [] for graph in tqdm(graph_list): g = Graph() @@ -42,9 +42,9 @@ def read_graph(raw_dir, add_inverse_edge=False, additional_node_files=[], additi g[key] = graph[key] del graph[key] - graph_list.append(g) + result_list.append(g) - return graph_list + return result_list def read_heterograph(raw_dir, add_inverse_edge=False, additional_node_files=[], additional_edge_files=[], From bc801127b057e27e470833f083952fc98916ecba Mon Sep 17 00:00:00 2001 From: starry_sky Date: Wed, 21 Sep 2022 14:37:27 +0800 Subject: [PATCH 16/25] Update ogb_node.py --- gammagl/datasets/ogb_node.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index 174339ad..ab429fa2 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -67,7 +67,27 @@ def __init__(self, name, root='dataset', transform=None, pre_transform=None, met super(OgbNodeDataset, self).__init__(self.root, transform, pre_transform) self.data, self.slices = 
self.load_data(self.processed_paths[0]) + def get_idx_split(self, split_type = None): + if split_type is None: + split_type = self.meta_info['split'] + path = osp.join(self.root, 'split', split_type) + + if self.is_hetero: + train_idx_dict, valid_idx_dict, test_idx_dict = read_nodesplitidx_split_hetero(path) + for nodetype in train_idx_dict.keys(): + train_idx_dict[nodetype] = train_idx_dict[nodetype] + valid_idx_dict[nodetype] = valid_idx_dict[nodetype] + test_idx_dict[nodetype] = test_idx_dict[nodetype] + + return {'train': train_idx_dict, 'valid': valid_idx_dict, 'test': test_idx_dict} + + else: + train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header = None).values.T[0] + valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header = None).values.T[0] + test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header = None).values.T[0] + + return {'train': train_idx, 'valid': valid_idx, 'test': test_idx} @property def num_classes(self): return self.__num_classes__ From 23d842e7525887b775693a92c1716548538e6061 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Wed, 21 Sep 2022 14:53:47 +0800 Subject: [PATCH 17/25] update ogb_node --- gammagl/datasets/ogb_node.py | 2 +- tests/datasets/{test_ogb_graph.py => test_ogbgraphdataset.py} | 0 tests/datasets/{test_ogb_link.py => test_ogblinkdataset.py} | 0 tests/datasets/{test_ogb_node.py => test_ogbnodedataset.py} | 0 4 files changed, 1 insertion(+), 1 deletion(-) rename tests/datasets/{test_ogb_graph.py => test_ogbgraphdataset.py} (100%) rename tests/datasets/{test_ogb_link.py => test_ogblinkdataset.py} (100%) rename tests/datasets/{test_ogb_node.py => test_ogbnodedataset.py} (100%) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index ab429fa2..7a868860 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -5,7 +5,7 @@ from gammagl.data import InMemoryDataset from gammagl.data.download import download_url from gammagl.data.extract import extract_zip -from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph +from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph, read_nodesplitidx_split_hetero class OgbNodeDataset(InMemoryDataset): diff --git a/tests/datasets/test_ogb_graph.py b/tests/datasets/test_ogbgraphdataset.py similarity index 100% rename from tests/datasets/test_ogb_graph.py rename to tests/datasets/test_ogbgraphdataset.py diff --git a/tests/datasets/test_ogb_link.py b/tests/datasets/test_ogblinkdataset.py similarity index 100% rename from tests/datasets/test_ogb_link.py rename to tests/datasets/test_ogblinkdataset.py diff --git a/tests/datasets/test_ogb_node.py b/tests/datasets/test_ogbnodedataset.py similarity index 100% rename from tests/datasets/test_ogb_node.py rename to tests/datasets/test_ogbnodedataset.py From 646dad4e892d27d3d505f7533b4742ae9410a469 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Wed, 21 Sep 2022 16:49:49 +0800 Subject: [PATCH 18/25] Update ogb_node.py --- gammagl/datasets/ogb_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gammagl/datasets/ogb_node.py b/gammagl/datasets/ogb_node.py index 7a868860..f82f357f 100644 --- a/gammagl/datasets/ogb_node.py +++ b/gammagl/datasets/ogb_node.py @@ -5,7 +5,7 @@ from gammagl.data import InMemoryDataset from gammagl.data.download import download_url from gammagl.data.extract import extract_zip -from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph, 
read_nodesplitidx_split_hetero
+from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph,read_nodesplitidx_split_hetero
 
 
 class OgbNodeDataset(InMemoryDataset):
From 1c48b273d1c5d22e115d693b262fa4 Mon Sep 17 00:00:00 2001
From: starry_sky
Date: Thu, 27 Oct 2022 10:59:47 +0800
Subject: [PATCH 19/25] upload unimp

---
 examples/unimp/readme.md        |  32 ++++++
 examples/unimp/unimp_trainer.py | 168 ++++++++++++++++++++++++++++++++
 2 files changed, 200 insertions(+)
 create mode 100644 examples/unimp/readme.md
 create mode 100644 examples/unimp/unimp_trainer.py

diff --git a/examples/unimp/readme.md b/examples/unimp/readme.md
new file mode 100644
index 00000000..15cf3c0d
--- /dev/null
+++ b/examples/unimp/readme.md
@@ -0,0 +1,32 @@
+# Unified Message Passing Model (UniMP)
+
+- Paper link: [https://arxiv.org/abs/2009.03509](https://arxiv.org/abs/2009.03509)
+
+# Dataset Statistics
+
+| Dataset  | # Nodes | # Edges | # Classes |
+|----------|---------|---------|-----------|
+| Cora     | 2,708   | 10,556  | 7         |
+| Citeseer | 3,327   | 9,228   | 6         |
+| Pubmed   | 19,717  | 88,651  | 3         |
+
+Refer to [Planetoid](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Planetoid).
+
+Results
+-------
+
+```bash
+# available dataset: "cora", "citeseer", "pubmed"
+TL_BACKEND="tensorflow" python unimp_trainer.py --dataset cora
+TL_BACKEND="tensorflow" python unimp_trainer.py --dataset citeseer
+TL_BACKEND="tensorflow" python unimp_trainer.py --dataset pubmed
+TL_BACKEND="torch" python unimp_trainer.py --dataset cora
+TL_BACKEND="torch" python unimp_trainer.py --dataset citeseer
+TL_BACKEND="torch" python unimp_trainer.py --dataset pubmed
+```
+
+| Dataset  | Our(tf)    | Our(torch) |
+|----------|------------|------------|
+| cora     | 83.10±1.12 | 82.30±0.67 |
+| citeseer | 79.90±0.68 | 78.53±0.18 |
+| pubmed   | 74.10±1.08 | 73.63±0.12 |
diff --git a/examples/unimp/unimp_trainer.py b/examples/unimp/unimp_trainer.py
new file mode 100644
index 00000000..9110985d
--- /dev/null
+++ b/examples/unimp/unimp_trainer.py
@@ -0,0 +1,168 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES']='0'
+import random
+import argparse
+import tensorlayerx as tlx
+import tensorlayerx.nn as nn
+from gammagl.utils import segment_softmax
+from gammagl.datasets import Planetoid
+from gammagl.layers.conv import MessagePassing
+from gammagl.utils import add_self_loops, mask_to_index
+from tensorlayerx.model import TrainOneStep, WithLoss
+
+class CrossEntropyLoss(WithLoss):
+    def __init__(self, model, loss_func):
+        super(CrossEntropyLoss, self).__init__(model,loss_func)
+
+    def forward(self, data, label):
+        out = self.backbone_network(data['x'], data['edge_index'])
+        out = tlx.gather(out, data['val_idx'])
+        label = tlx.reshape(tlx.gather(label, data['val_idx']),shape=(-1,))
+        #print(out[0])
+        #print(label[0])
+        loss = self._loss_fn(out, label)
+        return loss
+
+
+class MultiHead(MessagePassing):
+    def __init__(self, in_features, out_features, n_heads,num_nodes):
+        super().__init__()
+        self.heads=n_heads
+        self.num_nodes=num_nodes
+        self.out_channels=out_features
+        self.linear = tlx.layers.Linear(out_features=out_features* n_heads,
+                                        in_features=in_features)
+
+        init = tlx.initializers.RandomNormal()
+        self.att_src = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
+        self.att_dst = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
+
+        self.leaky_relu = tlx.layers.LeakyReLU(0.2)
+        self.dropout = tlx.layers.Dropout()
+
+    def message(self, x, edge_index):
+        node_src = edge_index[0, :]
+        node_dst =
edge_index[1, :] + weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src) + weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst) + weight = self.leaky_relu(weight_src + weight_dst) + + alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes)) + x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1) + return x + + + def forward(self, x, edge_index): + x = tlx.reshape(self.linear(x), shape=(-1,self.heads, self.out_channels)) + x = self.propagate(x, edge_index, num_nodes=self.num_nodes) + x=tlx.ops.reduce_mean(x,axis=1) + + return x + + +class Unimp(tlx.nn.Module): + def __init__(self,dataset): + super(Unimp, self).__init__() + + out_layer1=int(dataset.num_node_features/2) + self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes) + self.norm1=nn.LayerNorm(out_layer1) + self.relu1=nn.ReLU() + + self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes) + self.norm2=nn.LayerNorm(dataset.num_classes) + self.relu2=nn.ReLU() + def forward(self, x, edge_index): + out1 = self.layer1(x, edge_index) + out2=self.norm1(out1) + out3=self.relu1(out2) + out4=self.layer2(out3,edge_index) + out5 = self.norm2(out4) + out6 = self.relu2(out5) + return out6 + +def calculate_acc(logits, y, metrics): + metrics.update(logits, y) + rst = metrics.result() + metrics.reset() + return rst +def get_label_mask(label,node,dtype): + mask=[1 for i in range(node['train_node1'])]+[0 for i in range(node['train_node2'])] + random.shuffle(mask) + label_mask=[] + for i in range(node['train_node']): + if mask[i]==0: + label_mask.append([-1]) + else: + label_mask.append([(int)(label[i])]) + label_mask+=[[0] for i in range(node['num_node']-node['train_node'])] + return tlx.ops.convert_to_tensor(label_mask,dtype=dtype) + +def merge_feature_label(label,feature): + return tlx.ops.concat([label,feature],axis=1) +def main(args): + dataset = Planetoid(root='./',name=args.dataset) + graph=dataset[0] + feature=graph.x + edge_index=graph.edge_index + label=graph.y + train_node=int(graph.num_nodes * 0.3) + train_node1=int(graph.num_nodes * 0.1) + node = { + 'train_node': train_node, + 'train_node1': train_node1, + 'train_node2': train_node-train_node1, + 'num_node': graph.num_nodes + } + val_mask = tlx.ops.concat( + [tlx.ops.zeros((train_node, 1),dtype=tlx.int32), + tlx.ops.ones((train_node-train_node1, 1),dtype=tlx.int32)],axis=0) + test_mask=graph.test_mask + model=Unimp(dataset) + loss = tlx.losses.softmax_cross_entropy_with_logits + optimizer = tlx.optimizers.Adam(lr=0.01, weight_decay=5e-4) + train_weights = model.trainable_weights + loss_func = CrossEntropyLoss(model, loss) + train_one_step = TrainOneStep(loss_func, optimizer, train_weights) + val_idx = mask_to_index(val_mask) + test_idx = mask_to_index(test_mask) + metrics = tlx.metrics.Accuracy() + data = { + "x": feature, + "y": label, + "edge_index": edge_index, + "val_idx":val_idx, + "test_idx": test_idx, + "num_nodes": graph.num_nodes, + } + + epochs=args.epochs + best_val_acc=0 + for epoch in range(epochs): + model.set_train() + label_mask=get_label_mask(label,node,feature[0].dtype) + data['x']=merge_feature_label(label_mask,feature) + train_loss = train_one_step(data, graph.y) + + model.set_eval() + logits = model(data['x'], data['edge_index']) + test_logits = tlx.gather(logits, data['test_idx']) + test_y = tlx.gather(data['y'], data['test_idx']) + test_acc = calculate_acc(test_logits, test_y, metrics) + + print("Epoch [{:0>3d}] ".format(epoch + 1) + + " train loss: 
{:.4f}".format(train_loss.item()) + + " val acc: {:.4f}".format(test_acc)) + + # save best model on evaluation set + if test_acc > best_val_acc: + best_val_acc = test_acc + model.save_weights('./'+ 'unimp' + ".npz", format='npz_dict') + print("The Best ACC : {:.4f}".format(best_val_acc)) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--epochs", type=int, default=200, help="number of epoch") + parser.add_argument('--dataset', type=str, default='cora', help='dataset') + args = parser.parse_args() + main(args) \ No newline at end of file From de1095e4be973c55cbe7c11bf3088270b2b1fd29 Mon Sep 17 00:00:00 2001 From: starry_sky Date: Thu, 27 Oct 2022 11:06:12 +0800 Subject: [PATCH 20/25] change function name --- tests/datasets/test_ogbgraphdataset.py | 4 ++-- tests/datasets/test_ogblinkdataset.py | 4 ++-- tests/datasets/test_ogbnodedataset.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/datasets/test_ogbgraphdataset.py b/tests/datasets/test_ogbgraphdataset.py index 0afd592f..43936a25 100644 --- a/tests/datasets/test_ogbgraphdataset.py +++ b/tests/datasets/test_ogbgraphdataset.py @@ -1,8 +1,8 @@ from gammagl.datasets.ogb_graph import OgbGraphDataset -def test(): +def test_ogbgraphdataset(): data=OgbGraphDataset('ogbg-molhiv') print(data) print(data[0]) -test() \ No newline at end of file +test_ogbgraphdataset() \ No newline at end of file diff --git a/tests/datasets/test_ogblinkdataset.py b/tests/datasets/test_ogblinkdataset.py index b76c317b..2f22b3ff 100644 --- a/tests/datasets/test_ogblinkdataset.py +++ b/tests/datasets/test_ogblinkdataset.py @@ -1,7 +1,7 @@ from gammagl.datasets.ogb_link import OgbLinkDataset -def test(): +def test_ogblinkdataset(): data=OgbLinkDataset('ogbl-ppa') print(data[0]) -test() \ No newline at end of file +test_ogblinkdataset() \ No newline at end of file diff --git a/tests/datasets/test_ogbnodedataset.py b/tests/datasets/test_ogbnodedataset.py index a0544417..8be9198b 100644 --- a/tests/datasets/test_ogbnodedataset.py +++ b/tests/datasets/test_ogbnodedataset.py @@ -1,7 +1,7 @@ from gammagl.datasets.ogb_node import OgbNodeDataset -def test(): +def test_ogbnodedataset(): data=OgbNodeDataset('ogbn-arxiv') print(data[0]) -test() \ No newline at end of file +test_ogbnodedataset() \ No newline at end of file From 62c8e26eff8cfe3225166a2e78c09f966a4eed6c Mon Sep 17 00:00:00 2001 From: starry_sky Date: Fri, 28 Oct 2022 11:32:48 +0800 Subject: [PATCH 21/25] slipt unimp_trainer to three file --- examples/unimp/unimp_trainer.py | 57 +++---------------------------- gammagl/layers/conv/multi_head.py | 30 ++++++++++++++++ gammagl/models/unimp.py | 24 +++++++++++++ 3 files changed, 58 insertions(+), 53 deletions(-) create mode 100644 gammagl/layers/conv/multi_head.py create mode 100644 gammagl/models/unimp.py diff --git a/examples/unimp/unimp_trainer.py b/examples/unimp/unimp_trainer.py index 9110985d..f9ef9c5a 100644 --- a/examples/unimp/unimp_trainer.py +++ b/examples/unimp/unimp_trainer.py @@ -3,11 +3,9 @@ import random import argparse import tensorlayerx as tlx -import tensorlayerx.nn as nn -from gammagl.utils import segment_softmax +from gammagl.models.unimp import Unimp from gammagl.datasets import Planetoid -from gammagl.layers.conv import MessagePassing -from gammagl.utils import add_self_loops, mask_to_index +from gammagl.utils import mask_to_index from tensorlayerx.model import TrainOneStep, WithLoss class CrossEntropyLoss(WithLoss): @@ -24,34 +22,6 @@ def forward(self, data, label): return loss 
-class MultiHead(MessagePassing): - def __init__(self, in_features, out_features, n_heads,num_nodes): - super().__init__() - self.heads=n_heads - self.num_nodes=num_nodes - self.out_channels=out_features - self.linear = tlx.layers.Linear(out_features=out_features* n_heads, - in_features=in_features) - - init = tlx.initializers.RandomNormal() - self.att_src = init(shape=(1, n_heads, out_features), dtype=tlx.float32) - self.att_dst = init(shape=(1, n_heads, out_features), dtype=tlx.float32) - - self.leaky_relu = tlx.layers.LeakyReLU(0.2) - self.dropout = tlx.layers.Dropout() - - def message(self, x, edge_index): - node_src = edge_index[0, :] - node_dst = edge_index[1, :] - weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src) - weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst) - weight = self.leaky_relu(weight_src + weight_dst) - - alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes)) - x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1) - return x - - def forward(self, x, edge_index): x = tlx.reshape(self.linear(x), shape=(-1,self.heads, self.out_channels)) x = self.propagate(x, edge_index, num_nodes=self.num_nodes) @@ -60,32 +30,12 @@ def forward(self, x, edge_index): return x -class Unimp(tlx.nn.Module): - def __init__(self,dataset): - super(Unimp, self).__init__() - - out_layer1=int(dataset.num_node_features/2) - self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes) - self.norm1=nn.LayerNorm(out_layer1) - self.relu1=nn.ReLU() - - self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes) - self.norm2=nn.LayerNorm(dataset.num_classes) - self.relu2=nn.ReLU() - def forward(self, x, edge_index): - out1 = self.layer1(x, edge_index) - out2=self.norm1(out1) - out3=self.relu1(out2) - out4=self.layer2(out3,edge_index) - out5 = self.norm2(out4) - out6 = self.relu2(out5) - return out6 - def calculate_acc(logits, y, metrics): metrics.update(logits, y) rst = metrics.result() metrics.reset() return rst + def get_label_mask(label,node,dtype): mask=[1 for i in range(node['train_node1'])]+[0 for i in range(node['train_node2'])] random.shuffle(mask) @@ -100,6 +50,7 @@ def get_label_mask(label,node,dtype): def merge_feature_label(label,feature): return tlx.ops.concat([label,feature],axis=1) + def main(args): dataset = Planetoid(root='./',name=args.dataset) graph=dataset[0] diff --git a/gammagl/layers/conv/multi_head.py b/gammagl/layers/conv/multi_head.py new file mode 100644 index 00000000..6910e7b0 --- /dev/null +++ b/gammagl/layers/conv/multi_head.py @@ -0,0 +1,30 @@ +import tensorlayerx as tlx +from gammagl.layers.conv import MessagePassing +from gammagl.utils import segment_softmax + +class MultiHead(MessagePassing): + def __init__(self, in_features, out_features, n_heads,num_nodes): + super().__init__() + self.heads=n_heads + self.num_nodes=num_nodes + self.out_channels=out_features + self.linear = tlx.layers.Linear(out_features=out_features* n_heads, + in_features=in_features) + + init = tlx.initializers.RandomNormal() + self.att_src = init(shape=(1, n_heads, out_features), dtype=tlx.float32) + self.att_dst = init(shape=(1, n_heads, out_features), dtype=tlx.float32) + + self.leaky_relu = tlx.layers.LeakyReLU(0.2) + self.dropout = tlx.layers.Dropout() + + def message(self, x, edge_index): + node_src = edge_index[0, :] + node_dst = edge_index[1, :] + weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src) + weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), 
node_dst)
+        weight = self.leaky_relu(weight_src + weight_dst)
+
+        alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes))
+        x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
+        return x
\ No newline at end of file
diff --git a/gammagl/models/unimp.py b/gammagl/models/unimp.py
new file mode 100644
index 00000000..36e60566
--- /dev/null
+++ b/gammagl/models/unimp.py
@@ -0,0 +1,24 @@
+import tensorlayerx as tlx
+import tensorlayerx.nn as nn
+from gammagl.layers.conv.multi_head import MultiHead
+
+class Unimp(tlx.nn.Module):
+    def __init__(self,dataset):
+        super(Unimp, self).__init__()
+
+        out_layer1=int(dataset.num_node_features/2)
+        self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes)
+        self.norm1=nn.LayerNorm(out_layer1)
+        self.relu1=nn.ReLU()
+
+        self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes)
+        self.norm2=nn.LayerNorm(dataset.num_classes)
+        self.relu2=nn.ReLU()
+    def forward(self, x, edge_index):
+        out1 = self.layer1(x, edge_index)
+        out2=self.norm1(out1)
+        out3=self.relu1(out2)
+        out4=self.layer2(out3,edge_index)
+        out5 = self.norm2(out4)
+        out6 = self.relu2(out5)
+        return out6
\ No newline at end of file
From 64fbefc306561c28e4b5aa2753c27b62b898e090 Mon Sep 17 00:00:00 2001
From: starry_sky
Date: Fri, 28 Oct 2022 13:48:45 +0800
Subject: [PATCH 22/25] fix bug

---
 examples/unimp/unimp_trainer.py   | 8 --------
 gammagl/layers/conv/multi_head.py | 7 +++++++
 2 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/examples/unimp/unimp_trainer.py b/examples/unimp/unimp_trainer.py
index f9ef9c5a..d8b071b4 100644
--- a/examples/unimp/unimp_trainer.py
+++ b/examples/unimp/unimp_trainer.py
@@ -22,14 +22,6 @@ def forward(self, data, label):
         return loss
 
 
-    def forward(self, x, edge_index):
-        x = tlx.reshape(self.linear(x), shape=(-1,self.heads, self.out_channels))
-        x = self.propagate(x, edge_index, num_nodes=self.num_nodes)
-        x=tlx.ops.reduce_mean(x,axis=1)
-
-        return x
-
-
 def calculate_acc(logits, y, metrics):
     metrics.update(logits, y)
     rst = metrics.result()
diff --git a/gammagl/layers/conv/multi_head.py b/gammagl/layers/conv/multi_head.py
index 6910e7b0..29153762 100644
--- a/gammagl/layers/conv/multi_head.py
+++ b/gammagl/layers/conv/multi_head.py
@@ -27,4 +27,11 @@ def message(self, x, edge_index):
 
         alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes))
         x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
+        return x
+
+    def forward(self, x, edge_index):
+        x = tlx.reshape(self.linear(x), shape=(-1,self.heads, self.out_channels))
+        x = self.propagate(x, edge_index, num_nodes=self.num_nodes)
+        x=tlx.ops.reduce_mean(x,axis=1)
+
         return x
\ No newline at end of file
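After PATCH 22, the trainer, the `MultiHead` layer and the `Unimp` model live in separate modules. A minimal sketch of how the pieces are now meant to fit together (assuming the state of the tree at this point in the series; the zero label column below is a stand-in for the masked-label channel that `unimp_trainer.py` builds with `get_label_mask`):

```python
import tensorlayerx as tlx
from gammagl.datasets import Planetoid
from gammagl.models.unimp import Unimp

# load Cora and take its single graph, exactly as unimp_trainer.py does
dataset = Planetoid(root='./', name='cora')
graph = dataset[0]

# Unimp expects num_node_features + 1 input channels: the extra channel
# carries the (partially masked) labels, so prepend a zero column here
label_channel = tlx.ops.zeros((graph.num_nodes, 1), dtype=graph.x.dtype)
x = tlx.ops.concat([label_channel, graph.x], axis=1)

model = Unimp(dataset)
model.set_eval()
logits = model(x, graph.edge_index)  # shape: [num_nodes, num_classes]
```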
From 9a12bf4e2001384f64e72a7f5f1cbd6d2c38c134 Mon Sep 17 00:00:00 2001
From: starry_sky
Date: Fri, 28 Oct 2022 16:56:59 +0800
Subject: [PATCH 23/25] add doc

---
 gammagl/layers/conv/multi_head.py | 41 +++++++++++++++++++++++++++++++
 gammagl/models/unimp.py           | 23 +++++++++++++++++
 2 files changed, 64 insertions(+)

diff --git a/gammagl/layers/conv/multi_head.py b/gammagl/layers/conv/multi_head.py
index 29153762..57d3f51b 100644
--- a/gammagl/layers/conv/multi_head.py
+++ b/gammagl/layers/conv/multi_head.py
@@ -3,6 +3,47 @@ from gammagl.utils import segment_softmax
 
 class MultiHead(MessagePassing):
+
+    r"""A module for attention mechanisms which runs through an attention mechanism several times in parallel.
+
+    The independent attention outputs are then concatenated and linearly transformed into the expected dimension.
+
+    Intuitively, multiple attention heads allow for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies).
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
+        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
+
+    where the attention coefficients :math:`\alpha_{i,j}` are computed as
+
+    .. math::
+        \alpha_{i,j} =
+        \frac{
+        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
+        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
+        \right)\right)}
+        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
+        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
+        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
+        \right)\right)}.
+
+    Parameters
+    ----------
+    in_features: int
+        Size of each input sample, or :obj:`-1` to
+        derive the size from the first input(s) to the forward method.
+        A tuple corresponds to the sizes of source and target
+        dimensionalities.
+    out_features: int
+        Size of each output sample.
+    n_heads: int
+        Number of multi-head-attentions.
+        (default: :obj:`1`)
+    num_nodes: int
+        Number of nodes
+
+    """
+
     def __init__(self, in_features, out_features, n_heads,num_nodes):
         super().__init__()
         self.heads=n_heads
diff --git a/gammagl/models/unimp.py b/gammagl/models/unimp.py
index 36e60566..af9ab867 100644
--- a/gammagl/models/unimp.py
+++ b/gammagl/models/unimp.py
@@ -3,6 +3,29 @@ from gammagl.layers.conv.multi_head import MultiHead
 
 class Unimp(tlx.nn.Module):
+
+    r"""The UniMP model from the `"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification"
+    <https://arxiv.org/abs/2009.03509>`_ paper.
+
+    Parameters
+    ----------
+    dataset:
+        Dataset whose meta information sizes the model; the constructor
+        reads ``dataset.num_node_features`` (input feature dimension),
+        ``dataset.num_classes`` (output dimension) and
+        ``dataset[0].num_nodes`` (number of nodes in the graph).
+
+    The graph tensors passed to :meth:`forward` are expected to be:
+
+    x: [num_nodes, num_node_features + 1]
+        Node features with the masked-label channel prepended
+    edge_index: [2, num_edges]
+        Graph connectivity in COO format
+    """
+
     def __init__(self,dataset):
         super(Unimp, self).__init__()
 
From a8a2bc90c2ca03f8d43e88be6f4cb470d7bcdc82 Mon Sep 17 00:00:00 2001
From: starry_sky
Date: Sat, 19 Nov 2022 17:12:58 +0800
Subject: [PATCH 24/25] change multi-head

---
 gammagl/layers/conv/multi_head.py | 116 +++++++++++++++++-------------
 1 file changed, 67 insertions(+), 49 deletions(-)

diff --git a/gammagl/layers/conv/multi_head.py b/gammagl/layers/conv/multi_head.py
index 57d3f51b..69cafc59 100644
--- a/gammagl/layers/conv/multi_head.py
+++ b/gammagl/layers/conv/multi_head.py
@@ -1,47 +1,55 @@
 import tensorlayerx as tlx
 from gammagl.layers.conv import MessagePassing
 from gammagl.utils import segment_softmax
-
+import math
 class MultiHead(MessagePassing):
+    r"""The graph transformer operator from the `"Masked Label Prediction:
+    Unified Message Passing Model for Semi-Supervised Classification"
+    <https://arxiv.org/abs/2009.03509>`_ paper
 
-    r"""A module for attention mechanisms which runs through an attention mechanism several times in parallel.
+    .. math::
+        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
+        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},
 
-    The independent attention outputs are then concatenated and linearly transformed into the expected dimension.
-
-    Intuitively, multiple attention heads allow for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies).
-
-    ..
math::
-        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
-        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
+    where the attention coefficients :math:`\alpha_{i,j}` are computed via
+    multi-head dot product attention:
+
+    .. math::
+        \alpha_{i,j} = \textrm{softmax} \left(
+        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}
+        {\sqrt{d}} \right)
 
-    where the attention coefficients :math:`\alpha_{i,j}` are computed as
+    Parameters
+    ----------
+    in_features: int
+        Size of each input sample, or :obj:`-1` to
+        derive the size from the first input(s) to the forward method.
+    out_features: int
+        Size of each output sample.
+    n_heads: int
+        Number of multi-head-attentions.
+        (default: :obj:`1`)
+    num_nodes: int
+        Number of nodes
+    beta: bool
+        If set, combines aggregation and skip information via
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
+        (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
+        \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}
 
-    .. math::
-        \alpha_{i,j} =
-        \frac{
-        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
-        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
-        \right)\right)}
-        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
-        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
-        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
-        \right)\right)}.
+    with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
+    [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1
+    \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`True`)
 
-    Parameters
-    ----------
-    in_features: int
-        Size of each input sample, or :obj:`-1` to
-        derive the size from the first input(s) to the forward method.
-        A tuple corresponds to the sizes of source and target
-        dimensionalities.
-    out_features: int
-        Size of each output sample.
-    n_heads: int
-        Number of multi-head-attentions.
-        (default: :obj:`1`)
-    num_nodes: int
-        Number of nodes
-
+    If edge features are supplied, the update becomes
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
+        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
+        \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
+        \right),
+
+    where the attention coefficients :math:`\alpha_{i,j}` are now
+    computed via:
+
+    ..
math::
+        \alpha_{i,j} = \textrm{softmax} \left(
+        \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
+        (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
+        {\sqrt{d}} \right)
     """
 
     def __init__(self, in_features, out_features, n_heads,num_nodes):
         super().__init__()
         self.heads=n_heads
         self.num_nodes=num_nodes
         self.out_channels=out_features
         self.linear = tlx.layers.Linear(out_features=out_features* n_heads,
                                         in_features=in_features)
+        self.lin_key = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
+        self.lin_query = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
+        self.lin_value = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
+        self.lin_skip = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
         init = tlx.initializers.RandomNormal()
         self.att_src = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
         self.att_dst = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
 
         self.leaky_relu = tlx.layers.LeakyReLU(0.2)
         self.dropout = tlx.layers.Dropout()
+        self.reset_parameters()
 
-    def message(self, x, edge_index):
-        node_src = edge_index[0, :]
-        node_dst = edge_index[1, :]
-        weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src)
-        weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst)
-        weight = self.leaky_relu(weight_src + weight_dst)
+    def reset_parameters(self):
+        self.lin_key.reset_parameters()
+        self.lin_query.reset_parameters()
+        self.lin_value.reset_parameters()
+        self.lin_skip.reset_parameters()
 
-        alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes))
-        x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
-        return x
+    def message(self, query, key, value, edge_index):
+        node_src = edge_index[0, :]
+        node_dst = edge_index[1, :]
+        # scaled dot-product score per edge, normalized over each destination node
+        alpha = tlx.reduce_sum(tlx.gather(query, node_dst) * tlx.gather(key, node_src), -1) / math.sqrt(self.out_channels)
+        alpha = self.dropout(segment_softmax(alpha, node_dst, self.num_nodes))
+        out = tlx.gather(value, node_src) * tlx.expand_dims(alpha, -1)
+        return out
 
     def forward(self, x, edge_index):
-        x = tlx.reshape(self.linear(x), shape=(-1,self.heads, self.out_channels))
-        x = self.propagate(x, edge_index, num_nodes=self.num_nodes)
-        x=tlx.ops.reduce_mean(x,axis=1)
-
-        return x
\ No newline at end of file
+        H, C = self.heads, self.out_channels
+        query = tlx.reshape(self.lin_query(x), shape=(-1, H, C))
+        key = tlx.reshape(self.lin_key(x), shape=(-1, H, C))
+        value = tlx.reshape(self.lin_value(x), shape=(-1, H, C))
+        # propagate is assumed to route extra keyword arguments to message()
+        out = self.propagate(x, edge_index, query=query, key=key, value=value, num_nodes=self.num_nodes)
+        out = tlx.reshape(out, shape=(-1, H * C))
+        return out
\ No newline at end of file
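PATCH 24 swaps the additive GAT-style scoring for scaled dot-product attention, still normalized per destination node by `segment_softmax`. A small backend-free sketch of what that normalization computes (toy numbers, purely illustrative):

```python
import math

# toy graph: three edges; edges 0 and 1 point into node 0, edge 2 into node 1
scores = [0.2, 1.3, 0.7]  # one scaled dot-product score per edge
dst = [0, 0, 1]           # destination node of each edge

# softmax over the edges sharing a destination, i.e. what
# segment_softmax(scores, dst, num_nodes) produces edge-wise
alpha = [
    math.exp(s) / sum(math.exp(t) for t, d in zip(scores, dst) if d == dst[i])
    for i, s in enumerate(scores)
]

print(alpha)  # attention of edges into node 0 sums to 1.0; edge 2 gets 1.0
```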
From 9b2a3d5fc6b10db1bef635270cb1c19ac1a87454 Mon Sep 17 00:00:00 2001
From: starry_sky
Date: Mon, 21 Nov 2022 17:00:20 +0800
Subject: [PATCH 25/25] Update multi_head.py

---
 gammagl/layers/conv/multi_head.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/gammagl/layers/conv/multi_head.py b/gammagl/layers/conv/multi_head.py
index 69cafc59..decd185a 100644
--- a/gammagl/layers/conv/multi_head.py
+++ b/gammagl/layers/conv/multi_head.py
@@ -52,24 +52,21 @@ class MultiHead(MessagePassing):
         {\sqrt{d}} \right)
     """
 
-    def __init__(self, in_features, out_features, n_heads,num_nodes):
+    def __init__(self, in_features, out_features, n_heads, num_nodes, beta=True):
         super().__init__()
+        self.beta=beta
         self.heads=n_heads
         self.num_nodes=num_nodes
         self.out_channels=out_features
         self.linear = tlx.layers.Linear(out_features=out_features* n_heads,
                                         in_features=in_features)
         self.lin_key = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
         self.lin_query = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
         self.lin_value = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
         self.lin_skip = tlx.layers.Linear(in_features=in_features, out_features=n_heads * out_features)
-        init = tlx.initializers.RandomNormal()
-        self.att_src = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
-        self.att_dst = init(shape=(1, n_heads, out_features), dtype=tlx.float32)
-
-        self.leaky_relu = tlx.layers.LeakyReLU(0.2)
         self.dropout = tlx.layers.Dropout()
+        if beta:
+            # gate projection without bias (b_init=None disables the bias in tlx)
+            self.lin_beta = tlx.layers.Linear(in_features=3 * n_heads * out_features, out_features=1, b_init=None)
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -77,6 +74,8 @@ def reset_parameters(self):
         self.lin_query.reset_parameters()
         self.lin_value.reset_parameters()
         self.lin_skip.reset_parameters()
+        if self.beta:
+            self.lin_beta.reset_parameters()
 
     def message(self, query, key, value, edge_index):
         node_src = edge_index[0, :]
@@ -93,4 +92,9 @@ def forward(self, x, edge_index):
         value = tlx.reshape(self.lin_value(x), shape=(-1, H, C))
         out = self.propagate(x, edge_index, query=query, key=key, value=value, num_nodes=self.num_nodes)
         out = tlx.reshape(out, shape=(-1, H * C))
+        if self.beta:
+            # gated residual: mix the skip connection with the aggregated message
+            x_r = self.lin_skip(x)
+            beta = tlx.ops.sigmoid(self.lin_beta(tlx.ops.concat([out, x_r, out - x_r], axis=-1)))
+            out = beta * x_r + (1 - beta) * out
         return out
\ No newline at end of file
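Taken together, the dataset patches earlier in this series give `OgbNodeDataset` a small, self-contained interface. A minimal sketch of the intended end state (run from a machine that can download the OGB archives; the printed shapes are indicative only):

```python
from gammagl.datasets.ogb_node import OgbNodeDataset

# first construction downloads and processes the raw files
# (see download() and process() in ogb_node.py above)
dataset = OgbNodeDataset('ogbn-arxiv')
graph = dataset[0]

# train/valid/test node indices, read from the split CSVs by get_idx_split()
split = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split['train'], split['valid'], split['test']

print(graph)  # Graph(x=[...], edge_index=[2, ...], y=[...])
print(len(train_idx), len(valid_idx), len(test_idx))
```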