diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..baf11ba Binary files /dev/null and b/.DS_Store differ diff --git a/README.md b/README.md index f00756e..54ce413 100644 --- a/README.md +++ b/README.md @@ -1 +1,31 @@ -# GraphCTA \ No newline at end of file +# GraphCTA +Collaborate to Adapt: Source-Free Graph Domain Adaptation via Bi-directional Adaptation (WWW 2024) + +![](https://github.com/cszhangzhen/GraphCTA/blob/main/fig/model.png) + +This is a PyTorch implementation of the GraphCTA algorithm, which addresses the graph domain adaptation problem without access to the labelled source graph. It performs model adaptation and graph adaptation collaboratively through a series of procedures: (1) conduct model adaptation based on each node's neighborhood predictions in the target graph, considering both local and global information; (2) perform graph adaptation by updating the graph structure and node attributes via neighborhood contrastive learning; and (3) feed the updated graph into the subsequent iteration of model adaptation, thereby establishing a collaborative loop between model adaptation and graph adaptation. + + +## Requirements +* python3.8 +* pytorch==2.0.0 +* torch-scatter==2.1.1+pt20cu118 +* torch-sparse==0.6.17+pt20cu118 +* torch-cluster==1.6.1+pt20cu118 +* torch-geometric==2.3.1 +* numpy==1.24.3 +* scipy==1.10.1 +* tqdm==4.65.0 + +## Datasets +Datasets used in the paper are all publicly available. You can find [Elliptic](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set), [Twitch](https://github.com/benedekrozemberczki/datasets#twitch-social-networks) and [Citation](https://github.com/yuntaodu/ASN/tree/main/data) via the links. + +## Quick Start +First, execute the following command for source model pre-training: +``` +python train_source.py +``` +Then, execute the following command for adaptation: +``` +python train_target.py +``` diff --git a/data/Citation.zip b/data/Citation.zip new file mode 100644 index 0000000..967ba4e Binary files /dev/null and b/data/Citation.zip differ diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..073845f --- /dev/null +++ b/datasets.py @@ -0,0 +1,271 @@ +import os.path as osp +import torch +import numpy as np +from torch_geometric.data import InMemoryDataset, Data +from torch_geometric.io import read_txt_array +import torch.nn.functional as F + +import scipy +import pickle as pkl +import csv +import json + +import warnings +warnings.filterwarnings('ignore', category=DeprecationWarning) + + +class CitationDataset(InMemoryDataset): + def __init__(self, + root, + name, + transform=None, + pre_transform=None, + pre_filter=None): + self.name = name + self.root = root + super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter) + + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return ["docs.txt", "edgelist.txt", "labels.txt"] + + @property + def processed_file_names(self): + return ['data.pt'] + + def download(self): + pass + + def process(self): + edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name)) + edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t() + + docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name)) + f = open(docs_path, 'rb') + content_list = [] + for line in f.readlines(): + line = str(line, encoding="utf-8") + content_list.append(line.split(",")) + x = np.array(content_list, dtype=float) + x = torch.from_numpy(x).to(torch.float) + +
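+ # The docs file is assumed to hold one comma-separated attribute vector per line, + # with row order matching the node ids used in {name}_edgelist.txt and in the + # labels file read below.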
label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name)) + f = open(label_path, 'rb') + content_list = [] + for line in f.readlines(): + line = str(line, encoding="utf-8") + line = line.replace("\r", "").replace("\n", "") + content_list.append(line) + y = np.array(content_list, dtype=int) + y = torch.from_numpy(y).to(torch.int64) + + data_list = [] + data = Data(edge_index=edge_index, x=x, y=y) + + random_node_indices = np.random.permutation(y.shape[0]) + training_size = int(len(random_node_indices) * 0.8) + val_size = int(len(random_node_indices) * 0.1) + train_node_indices = random_node_indices[:training_size] + val_node_indices = random_node_indices[training_size:training_size + val_size] + test_node_indices = random_node_indices[training_size + val_size:] + + train_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + train_masks[train_node_indices] = 1 + val_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + val_masks[val_node_indices] = 1 + test_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + test_masks[test_node_indices] = 1 + + data.train_mask = train_masks + data.val_mask = val_masks + data.test_mask = test_masks + + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + data, slices = self.collate([data]) + + torch.save((data, slices), self.processed_paths[0]) + + +class EllipticDataset(InMemoryDataset): + def __init__(self, + root, + name, + transform=None, + pre_transform=None, + pre_filter=None): + self.name = name + self.root = root + super(EllipticDataset, self).__init__(root, transform, pre_transform, pre_filter) + + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def raw_file_names(self): + return [".pkl"] + + @property + def processed_file_names(self): + return ['data.pt'] + + def download(self): + pass + + def process(self): + path = osp.join(self.raw_dir, '{}.pkl'.format(self.name)) + result = pkl.load(open(path, 'rb')) + A, label, features = result + label = label + 1 + edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long) + features = np.array(features) + x = torch.from_numpy(features).to(torch.float) + y = torch.tensor(label).to(torch.int64) + + data_list = [] + data = Data(edge_index=edge_index, x=x, y=y) + + random_node_indices = np.random.permutation(y.shape[0]) + training_size = int(len(random_node_indices) * 0.8) + val_size = int(len(random_node_indices) * 0.1) + train_node_indices = random_node_indices[:training_size] + val_node_indices = random_node_indices[training_size:training_size + val_size] + test_node_indices = random_node_indices[training_size + val_size:] + + train_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + train_masks[train_node_indices] = 1 + val_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + val_masks[val_node_indices] = 1 + test_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + test_masks[test_node_indices] = 1 + + data.train_mask = train_masks + data.val_mask = val_masks + data.test_mask = test_masks + + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + data, slices = self.collate([data]) + + torch.save((data, slices), self.processed_paths[0]) + + +class TwitchDataset(InMemoryDataset): + def __init__(self, + root, + name, + transform=None, + pre_transform=None, + pre_filter=None): + self.name = name + self.root = root + super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter) + + self.data, self.slices = torch.load(self.processed_paths[0]) + + 
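+ # Assumed raw layout: musae_{lang}_edges.csv, musae_{lang}_features.json and + # musae_{lang}_target.csv under raw_dir. load_twitch() below builds 3170-dimensional + # binary indicator features from the json file and appears to take the boolean + # column of target.csv (row[2] == "True") as the node label.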
@property + def raw_file_names(self): + return ["edges.csv", "features.json", "target.csv"] + + @property + def processed_file_names(self): + return ['data.pt'] + + def download(self): + pass + + def load_twitch(self, lang): + assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset' + filepath = self.raw_dir + label = [] + node_ids = [] + src = [] + targ = [] + uniq_ids = set() + with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f: + reader = csv.reader(f) + next(reader) + for row in reader: + node_id = int(row[5]) + # handle FR case of non-unique rows + if node_id not in uniq_ids: + uniq_ids.add(node_id) + label.append(int(row[2]=="True")) + node_ids.append(int(row[5])) + + node_ids = np.array(node_ids, dtype=np.int32) + + with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f: + reader = csv.reader(f) + next(reader) + for row in reader: + src.append(int(row[0])) + targ.append(int(row[1])) + + with open(f"{filepath}/musae_{lang}_features.json", 'r') as f: + j = json.load(f) + + src = np.array(src) + targ = np.array(targ) + label = np.array(label) + + inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)} + reorder_node_ids = np.zeros_like(node_ids) + for i in range(label.shape[0]): + reorder_node_ids[i] = inv_node_ids[i] + + n = label.shape[0] + A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n,n)) + features = np.zeros((n,3170)) + for node, feats in j.items(): + if int(node) >= n: + continue + features[int(node), np.array(feats, dtype=int)] = 1 + new_label = label[reorder_node_ids] + label = new_label + + return A, label, features + + def process(self): + A, label, features = self.load_twitch(self.name) + edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long) + features = np.array(features) + x = torch.from_numpy(features).to(torch.float) + y = torch.from_numpy(label).to(torch.int64) + + data_list = [] + data = Data(edge_index=edge_index, x=x, y=y) + + random_node_indices = np.random.permutation(y.shape[0]) + training_size = int(len(random_node_indices) * 0.8) + val_size = int(len(random_node_indices) * 0.1) + train_node_indices = random_node_indices[:training_size] + val_node_indices = random_node_indices[training_size:training_size + val_size] + test_node_indices = random_node_indices[training_size + val_size:] + + train_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + train_masks[train_node_indices] = 1 + val_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + val_masks[val_node_indices] = 1 + test_masks = torch.zeros([y.shape[0]], dtype=torch.bool) + test_masks[test_node_indices] = 1 + + data.train_mask = train_masks + data.val_mask = val_masks + data.test_mask = test_masks + + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + data, slices = self.collate([data]) + + torch.save((data, slices), self.processed_paths[0]) \ No newline at end of file diff --git a/fig/model.png b/fig/model.png new file mode 100644 index 0000000..1727a25 Binary files /dev/null and b/fig/model.png differ diff --git a/layer.py b/layer.py new file mode 100644 index 0000000..3eb1e43 --- /dev/null +++ b/layer.py @@ -0,0 +1,536 @@ +from typing import Optional, Union, Tuple, Callable + +from torch import Tensor +from torch_sparse import SparseTensor, matmul, set_diag +import torch.nn.functional as F +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.typing import Adj, PairTensor, OptPairTensor, Size, OptTensor, NoneType, SparseTensor +from torch_scatter import 
scatter +from torch_geometric.nn import GCNConv +from torch.nn import Sequential +from torch_geometric.nn.dense.linear import Linear +import torch.nn as nn +import torch +from torch.nn import Parameter +from torch_geometric.utils import remove_self_loops, add_self_loops, softmax +from torch_geometric.nn.inits import glorot, zeros, reset + + +class NeighborPropagate(MessagePassing): + def __init__(self, aggr: str = 'mean', **kwargs): + kwargs['aggr'] = aggr if aggr != 'lstm' else None + super().__init__(**kwargs) + + def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, + size: Size = None) -> Tensor: + """""" + if isinstance(x, Tensor): + x: OptPairTensor = (x, x) + + # propagate_type: (x: OptPairTensor) + out = self.propagate(edge_index, x=x, size=size) + + return out + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) -> Tensor: + return scatter(x, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr) + + +class GCN(nn.Module): + def __init__(self, args): + super(GCN, self).__init__() + self.args = args + self.num_features = args.num_features + self.nhid = args.nhid + self.num_classes = args.num_classes + self.dropout_ratio = args.dropout_ratio + self.num_layers = args.num_layers + + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + + self.convs.append(GCNConv(self.num_features, self.nhid)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + for _ in range(self.num_layers - 1): + self.convs.append(GCNConv(self.nhid, self.nhid)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + self.cls = torch.nn.Linear(self.nhid, self.num_classes) + + self.activation = F.relu + self.use_bn = args.use_bn + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for bn in self.bns: + bn.reset_parameters() + + def forward(self, x, edge_index, edge_weight=None): + x = self.feat_bottleneck(x, edge_index, edge_weight) + x = self.feat_classifier(x) + + return x + + def feat_bottleneck(self, x, edge_index, edge_weight=None): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index, edge_weight) + if self.use_bn: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout_ratio, training=self.training) + return x + + def feat_classifier(self, x): + x = self.cls(x) + + return x + + +class SAGEConv(MessagePassing): + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, normalize: bool = False, + bias: bool = True, **kwargs): # yapf: disable + kwargs.setdefault('aggr', 'mean') + super(SAGEConv, self).__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.normalize = normalize + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_l = Linear(in_channels[0], out_channels, bias=bias) + self.lin_r = Linear(in_channels[1], out_channels, bias=False) + + self.reset_parameters() + + def reset_parameters(self): + self.lin_l.reset_parameters() + self.lin_r.reset_parameters() + + def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, + size: Size = None) -> Tensor: + """""" + if isinstance(x, Tensor): + x: OptPairTensor = (x, x) + out = self.lin_l(x[0]) + # propagate_type: (x: OptPairTensor) + out = 
self.propagate(edge_index, x=(out, out), size=size) + + x_r = x[1] + if x_r is not None: + out += self.lin_r(x_r) + + if self.normalize: + out = F.normalize(out, p=2., dim=-1) + + return out + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def message_and_aggregate(self, adj_t: SparseTensor, + x: OptPairTensor) -> Tensor: + # Deleted the following line to make propagation differentiable + # adj_t = adj_t.set_value(None, layout=None) + return matmul(adj_t, x[0], reduce=self.aggr) + + def __repr__(self): + return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, + self.out_channels) + + +class SAGE(nn.Module): + def __init__(self, args): + super(SAGE, self).__init__() + self.args = args + self.num_features = args.num_features + self.nhid = args.nhid + self.num_classes = args.num_classes + self.dropout_ratio = args.dropout_ratio + self.num_layers = args.num_layers + + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + + self.convs.append(SAGEConv(self.num_features, self.nhid)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + for _ in range(self.num_layers - 1): + self.convs.append(SAGEConv(self.nhid, self.nhid)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + self.cls = torch.nn.Linear(self.nhid, self.num_classes) + + self.activation = F.relu + self.use_bn = args.use_bn + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for bn in self.bns: + bn.reset_parameters() + + def forward(self, x, edge_index, edge_weight=None): + x = self.feat_bottleneck(x, edge_index, edge_weight) + x = self.feat_classifier(x) + + return x + + def feat_bottleneck(self, x, edge_index, edge_weight=None): + if edge_weight is not None: + adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t() + + for i, conv in enumerate(self.convs): + if edge_weight is not None: + x = conv(x, adj) + else: + x = conv(x, edge_index, edge_weight) + if self.use_bn: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout_ratio, training=self.training) + return x + + def feat_classifier(self, x): + x = self.cls(x) + + return x + + +class GATConv(MessagePassing): + _alpha: OptTensor + + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, heads: int = 1, concat: bool = True, + negative_slope: float = 0.2, dropout: float = 0.0, + add_self_loops: bool = True, bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super(GATConv, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + self.dropout = dropout + self.add_self_loops = add_self_loops + + # In case we are operating in bipartite graphs, we apply separate + # transformations 'lin_src' and 'lin_dst' to source and target nodes: + if isinstance(in_channels, int): + self.lin_src = Linear(in_channels, heads * out_channels, + bias=False, weight_initializer='glorot') + self.lin_dst = self.lin_src + else: + self.lin_src = Linear(in_channels[0], heads * out_channels, False, + weight_initializer='glorot') + self.lin_dst = Linear(in_channels[1], heads * out_channels, False, + weight_initializer='glorot') + + # The learnable parameters to compute attention coefficients: + self.att_src = Parameter(torch.Tensor(1, heads, out_channels)) + self.att_dst = Parameter(torch.Tensor(1, heads, out_channels)) + + if bias and concat: + self.bias = Parameter(torch.Tensor(heads * out_channels)) + elif bias and 
not concat: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + self._alpha = None + + self.reset_parameters() + self.edge_weight = None + + def reset_parameters(self): + self.lin_src.reset_parameters() + self.lin_dst.reset_parameters() + glorot(self.att_src) + glorot(self.att_dst) + zeros(self.bias) + + def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, + size: Size = None, return_attention_weights=None, edge_weight=None): + # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa + # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa + # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa + # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa + r""" + Args: + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + H, C = self.heads, self.out_channels + + # We first transform the input node features. If a tuple is passed, we + # transform source and target node features via separate weights: + if isinstance(x, Tensor): + assert x.dim() == 2, "Static graphs not supported in 'GATConv'" + x_src = x_dst = self.lin_src(x).view(-1, H, C) + else: # Tuple of source and target node features: + x_src, x_dst = x + assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" + x_src = self.lin_src(x_src).view(-1, H, C) + if x_dst is not None: + x_dst = self.lin_dst(x_dst).view(-1, H, C) + + x = (x_src, x_dst) + + # Next, we compute node-level attention coefficients, both for source + # and target nodes (if present): + alpha_src = (x_src * self.att_src).sum(dim=-1) + alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) + alpha = (alpha_src, alpha_dst) + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + # We only want to add self-loops for nodes that appear both as + # source and target nodes: + num_nodes = x_src.size(0) + if x_dst is not None: + num_nodes = min(num_nodes, x_dst.size(0)) + num_nodes = min(size) if size is not None else num_nodes + # edge_index, _ = remove_self_loops(edge_index) + # edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) + edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) + edge_index, edge_weight = add_self_loops(edge_index, edge_weight, num_nodes=num_nodes) + self.edge_weight = edge_weight + # if edge_index.size(1) != self.edge_weight.shape[0]: + # self.edge_weight = None + + elif isinstance(edge_index, SparseTensor): + edge_index = set_diag(edge_index) + + # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) + out = self.propagate(edge_index, x=x, alpha=alpha, size=size) + + alpha = self._alpha + assert alpha is not None + self._alpha = None + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(dim=1) + + if self.bias is not None: + out += self.bias + + if isinstance(return_attention_weights, bool): + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + elif isinstance(edge_index, SparseTensor): + return out, edge_index.set_value(alpha, layout='coo') + else: + return out + + def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, + index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: + # Given edge-level attention 
coefficients for source and target nodes, + # we simply need to sum them up to "emulate" concatenation: + alpha = alpha_j if alpha_i is None else alpha_j + alpha_i + + alpha = F.leaky_relu(alpha, self.negative_slope) + alpha = softmax(alpha, index, ptr, size_i) + self._alpha = alpha # Save for later use. + alpha = F.dropout(alpha, p=self.dropout, training=self.training) + + if self.edge_weight is not None: + x_j = self.edge_weight.view(-1, 1, 1) * x_j + return x_j * alpha.unsqueeze(-1) + + def __repr__(self): + return '{}({}, {}, heads={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.heads) + + +class GAT(nn.Module): + def __init__(self, args): + super(GAT, self).__init__() + self.args = args + self.num_features = args.num_features + self.nhid = args.nhid + self.num_classes = args.num_classes + self.dropout_ratio = args.dropout_ratio + self.num_layers = args.num_layers + + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + + self.convs.append(GATConv(self.num_features, self.nhid, heads=1, concat=False)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + for _ in range(self.num_layers - 1): + self.convs.append(GATConv(self.nhid, self.nhid, heads=1, concat=False)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + self.cls = torch.nn.Linear(self.nhid, self.num_classes) + + self.activation = F.relu + self.use_bn = args.use_bn + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for bn in self.bns: + bn.reset_parameters() + + def forward(self, x, edge_index, edge_weight=None): + x = self.feat_bottleneck(x, edge_index, edge_weight) + x = self.feat_classifier(x) + + return x + + def _ensure_contiguousness(self, x, edge_idx, edge_weight): + if not x.is_sparse: + x = x.contiguous() + if hasattr(edge_idx, 'contiguous'): + edge_idx = edge_idx.contiguous() + if edge_weight is not None: + edge_weight = edge_weight.contiguous() + + return x, edge_idx, edge_weight + + def feat_bottleneck(self, x, edge_index, edge_weight=None): + x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight) + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index, edge_weight=edge_weight) + if self.use_bn: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout_ratio, training=self.training) + return x + + def feat_classifier(self, x): + x = self.cls(x) + + return x + + +class GINConv(MessagePassing): + def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, + **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + self.nn = nn + self.initial_eps = eps + if train_eps: + self.eps = torch.nn.Parameter(torch.Tensor([eps])) + else: + self.register_buffer('eps', torch.Tensor([eps])) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.nn) + self.eps.data.fill_(self.initial_eps) + + def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, + size: Size = None) -> Tensor: + + if isinstance(x, Tensor): + x: OptPairTensor = (x, x) + + # propagate_type: (x: OptPairTensor) + out = self.propagate(edge_index, x=x, size=size) + + x_r = x[1] + if x_r is not None: + out = out + (1 + self.eps) * x_r + + return self.nn(out) + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def message_and_aggregate(self, adj_t: SparseTensor, + x: OptPairTensor) -> Tensor: + # if isinstance(adj_t, SparseTensor): + # adj_t = adj_t.set_value(None, layout=None) + return matmul(adj_t, x[0], reduce=self.aggr) + + def 
__repr__(self) -> str: + return f'{self.__class__.__name__}(nn={self.nn})' + + +class GIN(nn.Module): + def __init__(self, args): + super(GIN, self).__init__() + self.args = args + self.num_features = args.num_features + self.nhid = args.nhid + self.num_classes = args.num_classes + self.dropout_ratio = args.dropout_ratio + self.num_layers = args.num_layers + + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + + self.lin = torch.nn.Linear(self.num_features, self.nhid) + + self.convs.append(GINConv(Sequential(Linear(self.nhid, self.nhid)), train_eps=True)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + for _ in range(self.num_layers - 1): + self.convs.append(GINConv(Sequential(Linear(self.nhid, self.nhid)), train_eps=True)) + self.bns.append(nn.BatchNorm1d(self.nhid)) + + self.cls = torch.nn.Linear(self.nhid, self.num_classes) + + self.activation = F.relu + self.use_bn = args.use_bn + + def reset_parameters(self): + for conv in self.convs: + conv.reset_parameters() + for bn in self.bns: + bn.reset_parameters() + + def forward(self, x, edge_index, edge_weight=None): + x = self.feat_bottleneck(x, edge_index, edge_weight) + x = self.feat_classifier(x) + + return x + + def feat_bottleneck(self, x, edge_index, edge_weight=None): + x = self.lin(x) + + if edge_weight is not None: + adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t() + + for i, conv in enumerate(self.convs): + if edge_weight is not None: + x = conv(x, adj) + else: + x = conv(x, edge_index, edge_weight) + if self.use_bn: + x = self.bns[i](x) + x = self.activation(x) + x = F.dropout(x, p=self.dropout_ratio, training=self.training) + return x + + def feat_classifier(self, x): + x = self.cls(x) + + return x diff --git a/model.py b/model.py new file mode 100644 index 0000000..3514844 --- /dev/null +++ b/model.py @@ -0,0 +1,40 @@ +import torch +import torch.nn.functional as F +from layer import * + + +class Model(torch.nn.Module): + def __init__(self, args): + super(Model, self).__init__() + self.args = args + + if args.gnn == 'gcn': + self.gnn = GCN(args) + elif args.gnn == 'sage': + self.gnn = SAGE(args) + elif args.gnn == 'gat': + self.gnn = GAT(args) + elif args.gnn == 'gin': + self.gnn = GIN(args) + else: + assert args.gnn in ('gcn', 'sage', 'gat', 'gin'), 'Invalid gnn' + + self.reset_parameters() + + def reset_parameters(self): + self.gnn.reset_parameters() + + def forward(self, x, edge_index, edge_weight=None): + x = self.feat_bottleneck(x, edge_index, edge_weight) + x = self.feat_classifier(x) + + return F.log_softmax(x, dim=1) + + def feat_bottleneck(self, x, edge_index, edge_weight=None): + x = self.gnn.feat_bottleneck(x, edge_index, edge_weight) + return x + + def feat_classifier(self, x): + x = self.gnn.feat_classifier(x) + + return x diff --git a/train_source.py b/train_source.py new file mode 100644 index 0000000..aaac6ed --- /dev/null +++ b/train_source.py @@ -0,0 +1,133 @@ +import argparse +import glob +import os +import time + +import torch +import torch.nn.functional as F +from model import * +from utils import * +from datasets import * +import numpy as np + +parser = argparse.ArgumentParser() + +parser.add_argument('--seed', type=int, default=200, help='random seed') +parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') +parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay') +parser.add_argument('--nhid', type=int, default=128, help='hidden size') +parser.add_argument('--dropout_ratio', type=float, default=0.1, 
help='dropout ratio') +parser.add_argument('--device', type=str, default='cuda:2', help='specify cuda devices') +parser.add_argument('--source', type=str, default='Citationv1', help='source domain data') +parser.add_argument('--target', type=str, default='DBLPv7', help='target domain data') +parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs') +parser.add_argument('--patience', type=int, default=100, help='patience for early stopping') +parser.add_argument('--num_layers', type=int, default=2, help='number of gnn layers') +parser.add_argument('--gnn', type=str, default='gcn', help='different types of gnns') +parser.add_argument('--use_bn', type=bool, default=False, help='whether to use batch normalization') + +args = parser.parse_args() + +if args.source in {'DBLPv7', 'ACMv9', 'Citationv1'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.source) + source_dataset = CitationDataset(path, args.source) +elif args.source in {'S10', 'M10', 'E10'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.source) + source_dataset = EllipticDataset(path, args.source) +elif args.source in {'DE', 'EN', 'FR'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.source) + source_dataset = TwitchDataset(path, args.source) + +if args.target in {'DBLPv7', 'ACMv9', 'Citationv1'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.target) + target_dataset = CitationDataset(path, args.target) +elif args.target in {'S10', 'M10', 'E10'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.target) + target_dataset = EllipticDataset(path, args.target) +elif args.target in {'DE', 'EN', 'FR'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.target) + target_dataset = TwitchDataset(path, args.target) + +target_data = target_dataset[0] +data = source_dataset[0] + +args.num_classes = len(np.unique(data.y.numpy())) +args.num_features = data.x.size(1) + +print(args) + +model = Model(args).to(args.device) +data = data.to(args.device) +optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + +def train_source(): + min_loss = 1e10 + patience_cnt = 0 + val_loss_values = [] + best_epoch = 0 + + t = time.time() + for epoch in range(args.epochs): + model.train() + optimizer.zero_grad() + correct = 0 + output = model(data.x, data.edge_index) + train_loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask]) + train_loss.backward() + optimizer.step() + pred = output[data.train_mask].max(dim=1)[1] + correct = pred.eq(data.y[data.train_mask]).sum().item() + train_acc = correct * 1.0 / (data.train_mask).sum().item() + + val_acc, val_loss = compute_test(data.val_mask, model, data) + + print('Epoch: {:04d}'.format(epoch + 1), 'train_loss: {:.6f}'.format(train_loss), + 'train_acc: {:.6f}'.format(train_acc), 'loss_val: {:.6f}'.format(val_loss), + 'acc_val: {:.6f}'.format(val_acc), 'time: {:.6f}s'.format(time.time() - t)) + + val_loss_values.append(val_loss) + torch.save(model.state_dict(), '{}.pth'.format(epoch)) + + if val_loss_values[-1] < min_loss: + min_loss = val_loss_values[-1] + best_epoch = epoch + patience_cnt = 0 + else: + patience_cnt += 1 + + if patience_cnt == args.patience: + break + + files = glob.glob('*.pth') + for f in files: + epoch_nb = int(f.split('.')[0]) + if epoch_nb < best_epoch: + os.remove(f) + + files = glob.glob('*.pth') + for f in files: + epoch_nb = 
int(f.split('.')[0]) + if epoch_nb > best_epoch: + os.remove(f) + print('Optimization Finished! Total time elapsed: {:.6f}'.format(time.time() - t)) + + return best_epoch + + +if __name__ == '__main__': + if os.path.exists('model.pth'): + os.remove('model.pth') + # Model training + best_model = train_source() + # Restore best model for test set + model.load_state_dict(torch.load('{}.pth'.format(best_model))) + test_acc, test_loss = compute_test(data.test_mask, model, data) + print('Source {} test set results, loss = {:.6f}, accuracy = {:.6f}'.format(args.source, test_loss, test_acc)) + + target_data = target_data.to(args.device) + test_acc, test_loss = evaluate(target_data.x, target_data.edge_index, target_data.edge_weight, target_data.y, model) + print('Target {} test results, loss = {:.6f}, accuracy = {:.6f}'.format(args.target, test_loss, test_acc)) + + # Save model for target domain adaptation + torch.save(model.state_dict(), 'model.pth') + os.remove('{}.pth'.format(best_model)) diff --git a/train_target.py b/train_target.py new file mode 100644 index 0000000..6111ede --- /dev/null +++ b/train_target.py @@ -0,0 +1,312 @@ +import argparse +import glob +import os +import time + +import torch +import torch.nn.functional as F +from model import * +from utils import * +from layer import * +from datasets import * +import numpy as np +from torch_geometric.transforms import Constant +from torch.nn.parameter import Parameter +from torch_geometric.utils import dropout_adj +from tqdm import tqdm +import random + +from torch import Tensor + +parser = argparse.ArgumentParser() + +parser.add_argument('--seed', type=int, default=200, help='random seed') +parser.add_argument('--lr', type=float, default=0.0001, help='learning rate') +parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay') +parser.add_argument('--nhid', type=int, default=128, help='hidden size') +parser.add_argument('--dropout_ratio', type=float, default=0.1, help='dropout ratio') +parser.add_argument('--device', type=str, default='cuda:2', help='specify cuda devices') +parser.add_argument('--source', type=str, default='Citationv1', help='source domain data') +parser.add_argument('--target', type=str, default='DBLPv7', help='target domain data') +parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs') +parser.add_argument('--momentum', type=float, default=0.9, help='momentum') +parser.add_argument('--tau', type=float, default=0.2, help='tau') +parser.add_argument('--lamb', type=float, default=0.2, help='trade-off parameter lambda') +parser.add_argument('--num_layers', type=int, default=2, help='number of gnn layers') +parser.add_argument('--gnn', type=str, default='gcn', help='different types of gnns') +parser.add_argument('--use_bn', type=bool, default=False, help='whether to use batch normalization') +parser.add_argument('--make_undirected', type=bool, default=True, help='treat the graph as undirected') + +parser.add_argument('--ratio', type=float, default=0.2, help='structure perturbation budget') +parser.add_argument('--loop_adj', type=int, default=1, help='inner loop for adjacency update') +parser.add_argument('--loop_feat', type=int, default=2, help='inner loop for feature update') +parser.add_argument('--loop_model', type=int, default=3, help='inner loop for model update') +parser.add_argument('--debug', type=int, default=1, help='whether to output intermediate results') +parser.add_argument("--K", type=int, default=5, help='number of k-nearest neighbors') + +args = parser.parse_args() + + +if 
args.target in {'DBLPv7', 'ACMv9', 'Citationv1'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.target) + target_dataset = CitationDataset(path, args.target) +elif args.target in {'S10', 'M10', 'E10'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.target) + target_dataset = EllipticDataset(path, args.target) +elif args.target in {'DE', 'EN', 'FR'}: + path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.target) + target_dataset = TwitchDataset(path, args.target) + +data = target_dataset[0] + +args.num_classes = len(np.unique(data.y.numpy())) +args.num_features = data.x.size(1) + +print(args) + +model = Model(args).to(args.device) +data = data.to(args.device) + +neighprop = NeighborPropagate() + +model.load_state_dict(torch.load('model.pth')) + +delta_feat = Parameter(torch.FloatTensor(data.x.size(0), data.x.size(1)).to(args.device)) +delta_feat.data.fill_(1e-7) +optimizer_feat = torch.optim.Adam([delta_feat], lr=0.0001, weight_decay=0.0001) + +modified_edge_index = data.edge_index.clone() +modified_edge_index = modified_edge_index[:, modified_edge_index[0] < modified_edge_index[1]] +row, col = modified_edge_index[0], modified_edge_index[1] +edge_index_id = (2 * data.x.size(0) - row - 1) * row // 2 + col - row - 1 +edge_index_id = edge_index_id.long() +modified_edge_index = linear_to_triu_idx(data.x.size(0), edge_index_id) +perturbed_edge_weight = torch.full_like(edge_index_id, 1e-7, dtype=torch.float32, requires_grad=True).to(args.device) + +optimizer_adj = torch.optim.Adam([perturbed_edge_weight], lr=0.0001, weight_decay=0.0001) + +n_perturbations = int(args.ratio * data.edge_index.shape[1] // 2) + +n = data.x.size(0) + +def train_target(target_data, perturbed_edge_weight): + optimizer_model = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + t = time.time() + edge_index = target_data.edge_index + edge_weight = torch.ones(edge_index.shape[1]).to(args.device) + feat = target_data.x + + mem_fea = torch.rand(target_data.x.size(0), args.nhid).to(args.device) + mem_cls = torch.ones(target_data.x.size(0), args.num_classes).to(args.device) / args.num_classes + + for it in tqdm(range(args.epochs//(args.loop_feat+args.loop_adj))): + for loop_model in range(args.loop_model): + for k,v in model.named_parameters(): + v.requires_grad = True + model.train() + feat = feat.detach() + edge_weight = edge_weight.detach() + + optimizer_model.zero_grad() + feat_output = model.feat_bottleneck(feat, edge_index, edge_weight) + cls_output = model.feat_classifier(feat_output) + + onehot = torch.nn.functional.one_hot(cls_output.argmax(1), num_classes=args.num_classes).float() + proto = (torch.mm(mem_fea.t(), onehot) / (onehot.sum(dim=0) + 1e-8)).t() + + prob = neighprop(mem_cls, edge_index) + weight, pred = torch.max(prob, dim=1) + cl, weight_ = instance_proto_alignment(feat_output, proto, pred) + ce = F.cross_entropy(cls_output, pred, reduction='none') + loss_local = torch.sum(weight_ * ce) / (torch.sum(weight_).item()) + loss = loss_local * (1 - args.lamb) + cl * args.lamb + + loss.backward() + optimizer_model.step() + print('Model: ' + str(loss.item())) + + model.eval() + with torch.no_grad(): + feat_output = model.feat_bottleneck(feat, edge_index, edge_weight) + cls_output = model.feat_classifier(feat_output) + softmax_out = F.softmax(cls_output, dim=1) + outputs_target = softmax_out**2 / ((softmax_out**2).sum(dim=0)) + + mem_cls = (1.0 - args.momentum) * mem_cls + args.momentum * 
outputs_target.clone() + mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * feat_output.clone() + + for k,v in model.named_parameters(): + v.requires_grad = False + + perturbed_edge_weight = perturbed_edge_weight.detach() + for loop_feat in range(args.loop_feat): + optimizer_feat.zero_grad() + delta_feat.requires_grad = True + loss = test_time_loss(model, target_data.x + delta_feat, edge_index, mem_fea, mem_cls, edge_weight) + loss.backward() + optimizer_feat.step() + print('Feat: ' + str(loss.item())) + + new_feat = (data.x + delta_feat).detach() + for loop_adj in range(args.loop_adj): + perturbed_edge_weight.requires_grad = True + edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected) + loss = test_time_loss(model, new_feat, edge_index, mem_fea, mem_cls, edge_weight) + print('Adj: ' + str(loss.item())) + + gradient = grad_with_checkpoint(loss, perturbed_edge_weight)[0] + + with torch.no_grad(): + update_edge_weights(gradient) + perturbed_edge_weight = project(n_perturbations, perturbed_edge_weight, 1e-7) + + if args.loop_adj != 0: + edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected) + edge_weight = edge_weight.detach() + + if args.loop_feat != 0: + feat = (target_data.x + delta_feat).detach() + + edge_index, edge_weight = sample_final_edges(n_perturbations, perturbed_edge_weight, target_data, modified_edge_index, mem_fea, mem_cls) + + test_acc, _ = evaluate(target_data.x + delta_feat, edge_index, edge_weight, target_data.y, model) + print('acc : ' + str(test_acc)) + print('Optimization Finished!\n') + + +def instance_proto_alignment(feat, center, pred): + feat_norm = F.normalize(feat, dim=1) + center_norm = F.normalize(center, dim=1) + sim = torch.matmul(feat_norm, center_norm.t()) + + num_nodes = feat.size(0) + weight = sim[range(num_nodes), pred] + sim = torch.exp(sim / args.tau) + pos_sim = sim[range(num_nodes), pred] + + sim_feat = torch.matmul(feat_norm, feat_norm.t()) + sim_feat = torch.exp(sim_feat / args.tau) + ident = sim_feat[range(num_nodes), range(num_nodes)] + + logit = pos_sim / (sim.sum(dim=1) - pos_sim + sim_feat.sum(dim=1) - ident + 1e-8) + loss = - torch.log(logit + 1e-8).mean() + + return loss, weight + + +def update_edge_weights(gradient): + optimizer_adj.zero_grad() + perturbed_edge_weight.grad = gradient + optimizer_adj.step() + perturbed_edge_weight.data[perturbed_edge_weight < 1e-7] = 1e-7 + + +def test_time_loss(model, feat, edge_index, mem_fea, mem_cls, edge_weight=None): + model.eval() + feat_output = model.feat_bottleneck(feat, edge_index, edge_weight) + cls_output = model.feat_classifier(feat_output) + softmax_out = F.softmax(cls_output, dim=1) + _, predict = torch.max(softmax_out, 1) + mean_ent = Entropy(softmax_out) + est_p = (mean_ent n_perturbations: + n_samples = sampled_edges.sum() + if args.debug ==2: + print(f'{i}-th sampling: too many samples {n_samples}') + continue + + perturbed_edge_weight = sampled_edges + + edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected) + with torch.no_grad(): + loss = test_time_loss(model, feat, edge_index, mem_fea, mem_cls, edge_weight) + + # Save best sample + if best_loss > loss: + best_loss = loss + print('best_loss:', best_loss.item()) + best_edges = perturbed_edge_weight.clone().cpu() + + # Recover best sample + 
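+ # best_edges is the Bernoulli sample that achieved the lowest test_time_loss above; + # copying it back discretises perturbed_edge_weight to {0, 1} before the final + # adjacency is rebuilt below.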
perturbed_edge_weight.data.copy_(best_edges.to(args.device)) + + edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected) + edge_mask = edge_weight == 1 + make_undirected = args.make_undirected + + allowed_perturbations = 2 * n_perturbations if make_undirected else n_perturbations + edges_after_attack = edge_mask.sum() + clean_edges = edge_index.shape[1] + assert (edges_after_attack >= clean_edges - allowed_perturbations + and edges_after_attack <= clean_edges + allowed_perturbations), \ + f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} perturbations' + + return edge_index[:, edge_mask], edge_weight[edge_mask] + +if __name__ == '__main__': + train_target(data, perturbed_edge_weight) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..f6f2add --- /dev/null +++ b/utils.py @@ -0,0 +1,132 @@ +import os.path as osp +import torch +import torch.nn.functional as F +from torch import Tensor +from torch_sparse import coalesce + +import warnings +warnings.filterwarnings('ignore', category=DeprecationWarning) + + +def grad_with_checkpoint(outputs, inputs): + inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs) + for input in inputs: + if not input.is_leaf: + input.retain_grad() + torch.autograd.backward(outputs) + + grad_outputs = [] + for input in inputs: + grad_outputs.append(input.grad.clone()) + input.grad.zero_() + return grad_outputs + + +def linear_to_triu_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor: + row_idx = ( + n + - 2 + - torch.floor(torch.sqrt(-8 * lin_idx.double() + 4 * n * (n - 1) - 7) / 2.0 - 0.5) + ).long() + col_idx = ( + lin_idx + + row_idx + + 1 - n * (n - 1) // 2 + + (n - row_idx) * ((n - row_idx) - 1) // 2 + ) + return torch.stack((row_idx, col_idx)) + + +def bisection(edge_weights, a, b, n_perturbations, epsilon=1e-5, iter_max=1e5): + def func(x): + return torch.clamp(edge_weights - x, 0, 1).sum() - n_perturbations + + miu = a + for i in range(int(iter_max)): + miu = (a + b) / 2 + # Check if middle point is root + if (func(miu) == 0.0): + break + # Decide the side to repeat the steps + if (func(miu) * func(a) < 0): + b = miu + else: + a = miu + if ((b - a) <= epsilon): + break + return miu + + +def project(n_perturbations, values, eps, inplace=False): + if not inplace: + values = values.clone() + + if torch.clamp(values, 0, 1).sum() > n_perturbations: + left = (values - 1).min() + right = values.max() + miu = bisection(values, left, right, n_perturbations) + values.data.copy_(torch.clamp( + values - miu, min=eps, max=1 - eps + )) + else: + values.data.copy_(torch.clamp(values, min=eps, max=1 - eps)) + + return values + + +def get_modified_adj(modified_edge_index, perturbed_edge_weight, n, device, edge_index, edge_weight, make_undirected=False): + if make_undirected: + modified_edge_index, modified_edge_weight = to_symmetric(modified_edge_index, perturbed_edge_weight, n) + else: + modified_edge_index, modified_edge_weight = modified_edge_index, perturbed_edge_weight + edge_index = torch.cat((edge_index.to(device), modified_edge_index), dim=-1) + edge_weight = torch.cat((edge_weight.to(device), modified_edge_weight)) + edge_index, edge_weight = coalesce(edge_index, edge_weight, m=n, n=n, op='sum') + + # Allow removal of edges + edge_weight[edge_weight > 1] = 2 - edge_weight[edge_weight > 1] + return edge_index, edge_weight + + +def to_symmetric(edge_index, edge_weight, n, op='mean'): 
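+ # Mirror each (row, col) edge as (col, row) and coalesce duplicates with `op`, + # so the upper-triangular perturbation edges act on the graph as undirected.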
symmetric_edge_index = torch.cat( + (edge_index, edge_index.flip(0)), dim=-1 + ) + + symmetric_edge_weight = edge_weight.repeat(2) + + symmetric_edge_index, symmetric_edge_weight = coalesce( + symmetric_edge_index, + symmetric_edge_weight, + m=n, + n=n, + op=op + ) + return symmetric_edge_index, symmetric_edge_weight + + +def compute_test(mask, model, data): + model.eval() + output = model(data.x, data.edge_index) + loss = F.nll_loss(output[mask], data.y[mask]) + pred = output[mask].max(dim=1)[1] + correct = pred.eq(data.y[mask]).sum().item() + acc = correct * 1.0 / (mask.sum().item()) + + return acc, loss + + +def evaluate(x, edge_index, edge_weight, y, model): + model.eval() + output = model(x, edge_index, edge_weight) + loss = F.nll_loss(output, y) + pred = output.max(dim=1)[1] + correct = pred.eq(y).sum().item() + acc = correct * 1.0 / len(y) + + return acc, loss + +def Entropy(input_): + entropy = -input_ * torch.log(input_ + 1e-8) + entropy = torch.sum(entropy, dim=1) + return entropy
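+# Note: Entropy expects row-stochastic input such as a softmax output; the 1e-8 +# offset guards against log(0). For example, Entropy(F.softmax(logits, dim=1)) +# returns one entropy value per node.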