Commit d721f51, 1 parent 8862a11. Showing 10 changed files with 1,455 additions and 1 deletion.
@@ -1 +1,31 @@
# GraphCTA
Collaborate to Adapt: Source-Free Graph Domain Adaptation via Bi-directional Adaptation (WWW 2024)

![](https://github.com/cszhangzhen/GraphCTA/blob/main/fig/model.png)

This is a PyTorch implementation of the GraphCTA algorithm, which addresses the domain adaptation problem without accessing the labelled source graph. It performs model adaptation and graph adaptation collaboratively through a series of procedures: (1) conduct model adaptation based on each node's neighborhood predictions in the target graph, considering both local and global information; (2) perform graph adaptation by updating the graph structure and node attributes via neighborhood contrastive learning; and (3) the updated graph serves as input to the subsequent iteration of model adaptation, thereby establishing a collaborative loop between model adaptation and graph adaptation.

## Requirements
* python3.8
* pytorch==2.0.0
* torch-scatter==2.1.1+pt20cu118
* torch-sparse==0.6.17+pt20cu118
* torch-cluster==1.6.1+pt20cu118
* torch-geometric==2.3.1
* numpy==1.24.3
* scipy==1.10.1
* tqdm==4.65.0

## Datasets
Datasets used in the paper are all publicly available. You can find [Elliptic](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set), [Twitch](https://github.com/benedekrozemberczki/datasets#twitch-social-networks) and [Citation](https://github.com/yuntaodu/ASN/tree/main/data) via the links.

## Quick Start
Execute the following command for source model pre-training:
```
python train_source.py
```
Then, execute the following command for adaptation:
```
python train_target.py
```
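Both scripts read their configuration from command-line flags (see the argparse blocks in `train_source.py` and `train_target.py`). For example, a run adapting a GCN from ACMv9 to DBLPv7 on the first GPU might look like:
```
python train_source.py --source ACMv9 --target DBLPv7 --gnn gcn --device cuda:0
python train_target.py --source ACMv9 --target DBLPv7 --gnn gcn --device cuda:0
```
Note that `train_source.py` writes the pre-trained weights to `model.pth`, which `train_target.py` then loads, so both commands should be run from the same directory.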
@@ -0,0 +1,271 @@
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array
import torch.nn.functional as F

import scipy
import pickle as pkl
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


class CitationDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["docs.txt", "edgelist.txt", "labels.txt"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
        edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

        docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
        f = open(docs_path, 'rb')
        content_list = []
        for line in f.readlines():
            line = str(line, encoding="utf-8")
            content_list.append(line.split(","))
        x = np.array(content_list, dtype=float)
        x = torch.from_numpy(x).to(torch.float)

        label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
        f = open(label_path, 'rb')
        content_list = []
        for line in f.readlines():
            line = str(line, encoding="utf-8")
            line = line.replace("\r", "").replace("\n", "")
            content_list.append(line)
        y = np.array(content_list, dtype=int)
        y = torch.from_numpy(y).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # random 80/10/10 split into train/val/test node masks
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])


class EllipticDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(EllipticDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [".pkl"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
        result = pkl.load(open(path, 'rb'))
        A, label, features = result
        label = label + 1
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        features = np.array(features)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.tensor(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # random 80/10/10 split into train/val/test node masks
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])


class TwitchDataset(InMemoryDataset):
    def __init__(self,
                 root,
                 name,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        self.name = name
        self.root = root
        super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["edges.csv", "features.json", "target.csv"]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def load_twitch(self, lang):
        assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset'
        filepath = self.raw_dir
        label = []
        node_ids = []
        src = []
        targ = []
        uniq_ids = set()
        with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                node_id = int(row[5])
                # handle FR case of non-unique rows
                if node_id not in uniq_ids:
                    uniq_ids.add(node_id)
                    label.append(int(row[2] == "True"))
                    node_ids.append(int(row[5]))

        node_ids = np.array(node_ids, dtype=np.int32)

        with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
            reader = csv.reader(f)
            next(reader)
            for row in reader:
                src.append(int(row[0]))
                targ.append(int(row[1]))

        with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
            j = json.load(f)

        src = np.array(src)
        targ = np.array(targ)
        label = np.array(label)

        inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
        reorder_node_ids = np.zeros_like(node_ids)
        for i in range(label.shape[0]):
            reorder_node_ids[i] = inv_node_ids[i]

        n = label.shape[0]
        A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n, n))
        # multi-hot feature matrix built from the sparse feature-id lists in the json
        features = np.zeros((n, 3170))
        for node, feats in j.items():
            if int(node) >= n:
                continue
            features[int(node), np.array(feats, dtype=int)] = 1
        label = label[reorder_node_ids]

        return A, label, features

    def process(self):
        A, label, features = self.load_twitch(self.name)
        edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
        features = np.array(features)
        x = torch.from_numpy(features).to(torch.float)
        y = torch.from_numpy(label).to(torch.int64)

        data = Data(edge_index=edge_index, x=x, y=y)

        # random 80/10/10 split into train/val/test node masks
        random_node_indices = np.random.permutation(y.shape[0])
        training_size = int(len(random_node_indices) * 0.8)
        val_size = int(len(random_node_indices) * 0.1)
        train_node_indices = random_node_indices[:training_size]
        val_node_indices = random_node_indices[training_size:training_size + val_size]
        test_node_indices = random_node_indices[training_size + val_size:]

        train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        train_masks[train_node_indices] = 1
        val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        val_masks[val_node_indices] = 1
        test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
        test_masks[test_node_indices] = 1

        data.train_mask = train_masks
        data.val_mask = val_masks
        data.test_mask = test_masks

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])

        torch.save((data, slices), self.processed_paths[0])
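A minimal usage sketch for these dataset classes, mirroring how the training scripts load them. It assumes this file is saved as `datasets.py` (as the `from datasets import *` in the scripts suggests) and that the raw files have been downloaded into the `raw/` subdirectory of the dataset root, e.g. `data/Citation/ACMv9/raw/ACMv9_docs.txt`:
```
# Hypothetical usage sketch; the path layout follows train_source.py.
import os.path as osp
from datasets import CitationDataset

path = osp.join('data', 'Citation', 'ACMv9')
dataset = CitationDataset(path, 'ACMv9')  # processes raw files into data.pt on first use
data = dataset[0]                         # a single torch_geometric.data.Data graph

print(data.num_nodes, data.num_edges)
print(data.train_mask.sum().item())       # about 80% of nodes, from the random split above
```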
@@ -0,0 +1,40 @@
import torch
import torch.nn.functional as F
from layer import *


class Model(torch.nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        self.args = args

        # select the GNN backbone
        if args.gnn == 'gcn':
            self.gnn = GCN(args)
        elif args.gnn == 'sage':
            self.gnn = SAGE(args)
        elif args.gnn == 'gat':
            self.gnn = GAT(args)
        elif args.gnn == 'gin':
            self.gnn = GIN(args)
        else:
            assert args.gnn in ('gcn', 'sage', 'gat', 'gin'), 'Invalid gnn'

        self.reset_parameters()

    def reset_parameters(self):
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        x = self.feat_bottleneck(x, edge_index, edge_weight)
        x = self.feat_classifier(x)

        return F.log_softmax(x, dim=1)

    def feat_bottleneck(self, x, edge_index, edge_weight=None):
        # hidden node representations from the backbone
        x = self.gnn.feat_bottleneck(x, edge_index, edge_weight)
        return x

    def feat_classifier(self, x):
        # class logits from the backbone's classifier head
        x = self.gnn.feat_classifier(x)

        return x
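A quick smoke test for `Model`. This is only a sketch: `layer.py` is not part of this diff, so the exact `args` fields its `GCN` consumes are an assumption, chosen to match the flags defined in `train_source.py`:
```
# Hypothetical smoke test; assumes layer.GCN reads the argparse fields below.
import argparse
import torch
from model import Model

args = argparse.Namespace(gnn='gcn', num_features=16, nhid=128, num_classes=5,
                          num_layers=2, dropout_ratio=0.1, use_bn=False)
model = Model(args)

x = torch.randn(10, 16)                     # 10 nodes, 16 features
edge_index = torch.randint(0, 10, (2, 40))  # 40 random directed edges
out = model(x, edge_index)                  # log-probabilities, shape (10, 5)
print(out.shape)
```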
@@ -0,0 +1,133 @@
import argparse
import glob
import os
import time

import torch
import torch.nn.functional as F
from model import *
from utils import *
from datasets import *
import numpy as np

parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=200, help='random seed')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--nhid', type=int, default=128, help='hidden size')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='dropout ratio')
parser.add_argument('--device', type=str, default='cuda:2', help='specify cuda devices')
parser.add_argument('--source', type=str, default='Citationv1', help='source domain data')
parser.add_argument('--target', type=str, default='DBLPv7', help='target domain data')
parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--patience', type=int, default=100, help='patience for early stopping')
parser.add_argument('--num_layers', type=int, default=2, help='number of gnn layers')
parser.add_argument('--gnn', type=str, default='gcn', help='different types of gnns')
parser.add_argument('--use_bn', type=bool, default=False, help='whether to use batchnorm')

args = parser.parse_args()

if args.source in {'DBLPv7', 'ACMv9', 'Citationv1'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.source)
    source_dataset = CitationDataset(path, args.source)
elif args.source in {'S10', 'M10', 'E10'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.source)
    source_dataset = EllipticDataset(path, args.source)
elif args.source in {'DE', 'EN', 'FR'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.source)
    source_dataset = TwitchDataset(path, args.source)

if args.target in {'DBLPv7', 'ACMv9', 'Citationv1'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.target)
    target_dataset = CitationDataset(path, args.target)
elif args.target in {'S10', 'M10', 'E10'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.target)
    target_dataset = EllipticDataset(path, args.target)
elif args.target in {'DE', 'EN', 'FR'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.target)
    target_dataset = TwitchDataset(path, args.target)

target_data = target_dataset[0]
data = source_dataset[0]

args.num_classes = len(np.unique(data.y.numpy()))
args.num_features = data.x.size(1)

print(args)

model = Model(args).to(args.device)
data = data.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


def train_source():
    min_loss = 1e10
    patience_cnt = 0
    val_loss_values = []
    best_epoch = 0

    t = time.time()
    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        output = model(data.x, data.edge_index)
        train_loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        train_loss.backward()
        optimizer.step()
        pred = output[data.train_mask].max(dim=1)[1]
        correct = pred.eq(data.y[data.train_mask]).sum().item()
        train_acc = correct * 1.0 / (data.train_mask).sum().item()

        val_acc, val_loss = compute_test(data.val_mask, model, data)

        print('Epoch: {:04d}'.format(epoch + 1), 'train_loss: {:.6f}'.format(train_loss),
              'train_acc: {:.6f}'.format(train_acc), 'loss_val: {:.6f}'.format(val_loss),
              'acc_val: {:.6f}'.format(val_acc), 'time: {:.6f}s'.format(time.time() - t))

        val_loss_values.append(val_loss)
        torch.save(model.state_dict(), '{}.pth'.format(epoch))

        # early stopping on validation loss
        if val_loss_values[-1] < min_loss:
            min_loss = val_loss_values[-1]
            best_epoch = epoch
            patience_cnt = 0
        else:
            patience_cnt += 1

        if patience_cnt == args.patience:
            break

    # keep only the checkpoint of the best epoch
    files = glob.glob('*.pth')
    for f in files:
        epoch_nb = int(f.split('.')[0])
        if epoch_nb != best_epoch:
            os.remove(f)
    print('Optimization Finished! Total time elapsed: {:.6f}'.format(time.time() - t))

    return best_epoch


if __name__ == '__main__':
    if os.path.exists('model.pth'):
        os.remove('model.pth')
    # Model training
    best_model = train_source()
    # Restore best model for test set
    model.load_state_dict(torch.load('{}.pth'.format(best_model)))
    test_acc, test_loss = compute_test(data.test_mask, model, data)
    print('Source {} test set results, loss = {:.6f}, accuracy = {:.6f}'.format(args.source, test_loss, test_acc))

    target_data = target_data.to(args.device)
    test_acc, test_loss = evaluate(target_data.x, target_data.edge_index, target_data.edge_weight, target_data.y, model)
    print('Target {} test results, loss = {:.6f}, accuracy = {:.6f}'.format(args.target, test_loss, test_acc))

    # Save model for target domain adaptation
    torch.save(model.state_dict(), 'model.pth')
    os.remove('{}.pth'.format(best_model))
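Once `model.pth` exists, the saved source model can be re-evaluated on the target graph without retraining. A sketch, reusing the module-level `model`, `target_data` and `args` defined above (e.g. appended at the end of the `__main__` block, with the same `--source`/`--target` flags so `num_features`/`num_classes` match):
```
# Hypothetical re-evaluation sketch, not part of the original script.
model.load_state_dict(torch.load('model.pth'))
target_data = target_data.to(args.device)
acc, loss = evaluate(target_data.x, target_data.edge_index, None,
                     target_data.y, model)  # edge_weight=None means unweighted edges
print('accuracy = {:.6f}'.format(acc))
```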
@@ -0,0 +1,312 @@
import argparse
import glob
import os
import time

import torch
import torch.nn.functional as F
from model import *
from utils import *
from layer import *
from datasets import *
import numpy as np
from torch_geometric.transforms import Constant
from torch.nn.parameter import Parameter
from torch_geometric.utils import dropout_adj
from tqdm import tqdm
import random

from torch import Tensor

parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=200, help='random seed')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--nhid', type=int, default=128, help='hidden size')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='dropout ratio')
parser.add_argument('--device', type=str, default='cuda:2', help='specify cuda devices')
parser.add_argument('--source', type=str, default='Citationv1', help='source domain data')
parser.add_argument('--target', type=str, default='DBLPv7', help='target domain data')
parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--tau', type=float, default=0.2, help='tau')
parser.add_argument('--lamb', type=float, default=0.2, help='trade-off parameter lambda')
parser.add_argument('--num_layers', type=int, default=2, help='number of gnn layers')
parser.add_argument('--gnn', type=str, default='gcn', help='different types of gnns')
parser.add_argument('--use_bn', type=bool, default=False, help='whether to use batchnorm')
parser.add_argument('--make_undirected', type=bool, default=True, help='directed graph or not')

parser.add_argument('--ratio', type=float, default=0.2, help='structure perturbation budget')
parser.add_argument('--loop_adj', type=int, default=1, help='inner loop for adjacency update')
parser.add_argument('--loop_feat', type=int, default=2, help='inner loop for feature update')
parser.add_argument('--loop_model', type=int, default=3, help='inner loop for model update')
parser.add_argument('--debug', type=int, default=1, help='whether to output intermediate results')
parser.add_argument("--K", type=int, default=5, help='number of k-nearest neighbors')

args = parser.parse_args()


if args.target in {'DBLPv7', 'ACMv9', 'Citationv1'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.target)
    target_dataset = CitationDataset(path, args.target)
elif args.target in {'S10', 'M10', 'E10'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.target)
    target_dataset = EllipticDataset(path, args.target)
elif args.target in {'DE', 'EN', 'FR'}:
    path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.target)
    target_dataset = TwitchDataset(path, args.target)

data = target_dataset[0]

args.num_classes = len(np.unique(data.y.numpy()))
args.num_features = data.x.size(1)

print(args)

model = Model(args).to(args.device)
data = data.to(args.device)

neighprop = NeighborPropagate()

# start from the pre-trained source model saved by train_source.py
model.load_state_dict(torch.load('model.pth'))

# learnable perturbation of node attributes
delta_feat = Parameter(torch.FloatTensor(data.x.size(0), data.x.size(1)).to(args.device))
delta_feat.data.fill_(1e-7)
optimizer_feat = torch.optim.Adam([delta_feat], lr=0.0001, weight_decay=0.0001)

# parameterize candidate edges by their upper-triangular linear indices
modified_edge_index = data.edge_index.clone()
modified_edge_index = modified_edge_index[:, modified_edge_index[0] < modified_edge_index[1]]
row, col = modified_edge_index[0], modified_edge_index[1]
edge_index_id = (2 * data.x.size(0) - row - 1) * row // 2 + col - row - 1
edge_index_id = edge_index_id.long()
modified_edge_index = linear_to_triu_idx(data.x.size(0), edge_index_id)
perturbed_edge_weight = torch.full_like(edge_index_id, 1e-7, dtype=torch.float32, requires_grad=True).to(args.device)

optimizer_adj = torch.optim.Adam([perturbed_edge_weight], lr=0.0001, weight_decay=0.0001)

# structure perturbation budget (number of undirected edge flips)
n_perturbations = int(args.ratio * data.edge_index.shape[1] // 2)

n = data.x.size(0)

def train_target(target_data, perturbed_edge_weight):
    optimizer_model = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    t = time.time()
    edge_index = target_data.edge_index
    edge_weight = torch.ones(edge_index.shape[1]).to(args.device)
    feat = target_data.x

    # memory banks for node features and class predictions
    mem_fea = torch.rand(target_data.x.size(0), args.nhid).to(args.device)
    mem_cls = torch.ones(target_data.x.size(0), args.num_classes).to(args.device) / args.num_classes

    for it in tqdm(range(args.epochs // (args.loop_feat + args.loop_adj))):
        # (1) model adaptation
        for loop_model in range(args.loop_model):
            for k, v in model.named_parameters():
                v.requires_grad = True
            model.train()
            feat = feat.detach()
            edge_weight = edge_weight.detach()

            optimizer_model.zero_grad()
            feat_output = model.feat_bottleneck(feat, edge_index, edge_weight)
            cls_output = model.feat_classifier(feat_output)

            onehot = torch.nn.functional.one_hot(cls_output.argmax(1), num_classes=args.num_classes).float()
            proto = (torch.mm(mem_fea.t(), onehot) / (onehot.sum(dim=0) + 1e-8)).t()

            prob = neighprop(mem_cls, edge_index)
            weight, pred = torch.max(prob, dim=1)
            cl, weight_ = instance_proto_alignment(feat_output, proto, pred)
            ce = F.cross_entropy(cls_output, pred, reduction='none')
            loss_local = torch.sum(weight_ * ce) / (torch.sum(weight_).item())
            loss = loss_local * (1 - args.lamb) + cl * args.lamb

            loss.backward()
            optimizer_model.step()
            print('Model: ' + str(loss.item()))

            # momentum update of the memory banks
            model.eval()
            with torch.no_grad():
                feat_output = model.feat_bottleneck(feat, edge_index, edge_weight)
                cls_output = model.feat_classifier(feat_output)
                softmax_out = F.softmax(cls_output, dim=1)
                outputs_target = softmax_out**2 / ((softmax_out**2).sum(dim=0))

                mem_cls = (1.0 - args.momentum) * mem_cls + args.momentum * outputs_target.clone()
                mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * feat_output.clone()

        # freeze the model for graph adaptation
        for k, v in model.named_parameters():
            v.requires_grad = False

        # (2) graph adaptation: node attributes
        perturbed_edge_weight = perturbed_edge_weight.detach()
        for loop_feat in range(args.loop_feat):
            optimizer_feat.zero_grad()
            delta_feat.requires_grad = True
            loss = test_time_loss(model, target_data.x + delta_feat, edge_index, mem_fea, mem_cls, edge_weight)
            loss.backward()
            optimizer_feat.step()
            print('Feat: ' + str(loss.item()))

        # (2) graph adaptation: structure
        new_feat = (data.x + delta_feat).detach()
        for loop_adj in range(args.loop_adj):
            perturbed_edge_weight.requires_grad = True
            edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
            loss = test_time_loss(model, new_feat, edge_index, mem_fea, mem_cls, edge_weight)
            print('Adj: ' + str(loss.item()))

            gradient = grad_with_checkpoint(loss, perturbed_edge_weight)[0]

            with torch.no_grad():
                update_edge_weights(gradient)
                perturbed_edge_weight = project(n_perturbations, perturbed_edge_weight, 1e-7)

        # (3) the updated graph feeds the next round of model adaptation
        if args.loop_adj != 0:
            edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
            edge_weight = edge_weight.detach()

        if args.loop_feat != 0:
            feat = (target_data.x + delta_feat).detach()

    edge_index, edge_weight = sample_final_edges(n_perturbations, perturbed_edge_weight, target_data, modified_edge_index, mem_fea, mem_cls)

    test_acc, _ = evaluate(target_data.x + delta_feat, edge_index, edge_weight, target_data.y, model)
    print('acc : ' + str(test_acc))
    print('Optimization Finished!\n')


def instance_proto_alignment(feat, center, pred):
    feat_norm = F.normalize(feat, dim=1)
    center_norm = F.normalize(center, dim=1)
    sim = torch.matmul(feat_norm, center_norm.t())

    num_nodes = feat.size(0)
    weight = sim[range(num_nodes), pred]
    sim = torch.exp(sim / args.tau)
    pos_sim = sim[range(num_nodes), pred]

    sim_feat = torch.matmul(feat_norm, feat_norm.t())
    sim_feat = torch.exp(sim_feat / args.tau)
    ident = sim_feat[range(num_nodes), range(num_nodes)]

    logit = pos_sim / (sim.sum(dim=1) - pos_sim + sim_feat.sum(dim=1) - ident + 1e-8)
    loss = - torch.log(logit + 1e-8).mean()

    return loss, weight


def update_edge_weights(gradient):
    optimizer_adj.zero_grad()
    perturbed_edge_weight.grad = gradient
    optimizer_adj.step()
    perturbed_edge_weight.data[perturbed_edge_weight < 1e-7] = 1e-7


def test_time_loss(model, feat, edge_index, mem_fea, mem_cls, edge_weight=None):
    model.eval()
    feat_output = model.feat_bottleneck(feat, edge_index, edge_weight)
    cls_output = model.feat_classifier(feat_output)
    softmax_out = F.softmax(cls_output, dim=1)
    _, predict = torch.max(softmax_out, 1)
    mean_ent = Entropy(softmax_out)
    est_p = (mean_ent < mean_ent.mean()).sum().item() / mean_ent.size(0)
    value = mean_ent

    predict = predict.cpu().numpy()
    train_idx = np.zeros(predict.shape)

    # per class, select the most confident (lowest-entropy) predictions as pseudo-labels
    cls_k = args.num_classes
    for c in range(cls_k):
        c_idx = np.where(predict == c)
        c_idx = c_idx[0]
        c_value = value[c_idx]

        _, idx_ = torch.sort(c_value)
        c_num = len(idx_)
        c_num_s = int(c_num * est_p / 5)

        for ei in range(0, c_num_s):
            ee = c_idx[idx_[ei]]
            train_idx[ee] = 1

    train_idx = np.array(train_idx, dtype=bool)
    pred_label = predict[train_idx]
    pseudo_label = torch.from_numpy(pred_label).to(args.device)

    pred_output = cls_output[train_idx]
    loss = F.cross_entropy(pred_output, pseudo_label)

    # pull each node towards its K nearest neighbors in the feature memory bank
    distance = feat_output @ mem_fea.T
    _, idx_near = torch.topk(distance, dim=-1, largest=True, k=args.K + 1)
    idx_near = idx_near[:, 1:]  # batch x K

    mem_near = mem_fea[idx_near]  # batch x K x d
    feat_output_un = feat_output.unsqueeze(1).expand(-1, args.K, -1)  # batch x K x d
    loss -= torch.mean((feat_output_un * mem_near).sum(-1).sum(1) / args.K) * 0.1

    # push away memory entries whose predicted class disagrees
    _, pred_mem = torch.max(mem_cls, dim=1)
    _, pred = torch.max(softmax_out, dim=1)
    idx = pred.unsqueeze(-1) == pred_mem
    neg_num = torch.sum(~idx, dim=1)
    dis = (distance * ~idx).sum(1) / neg_num
    loss += dis.mean() * 0.1

    return loss


@torch.no_grad()
def sample_final_edges(n_perturbations, perturbed_edge_weight, data, modified_edge_index, mem_fea, mem_cls):
    best_loss = float('Inf')
    perturbed_edge_weight = perturbed_edge_weight.detach()
    # TODO: potentially convert to assert
    perturbed_edge_weight[perturbed_edge_weight <= 1e-7] = 0

    feat = data.x.to(args.device)
    edge_index = data.edge_index.to(args.device)
    edge_weight = torch.ones(edge_index.shape[1]).to(args.device)

    for i in range(20):
        if best_loss == float('Inf'):
            # In the first iteration employ the top-k heuristic instead of sampling
            sampled_edges = torch.zeros_like(perturbed_edge_weight).to(args.device)
            sampled_edges[torch.topk(perturbed_edge_weight, n_perturbations).indices] = 1
        else:
            sampled_edges = torch.bernoulli(perturbed_edge_weight).float()

        if sampled_edges.sum() > n_perturbations:
            n_samples = sampled_edges.sum()
            if args.debug == 2:
                print(f'{i}-th sampling: too many samples {n_samples}')
            continue

        perturbed_edge_weight = sampled_edges

        edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
        with torch.no_grad():
            loss = test_time_loss(model, feat, edge_index, mem_fea, mem_cls, edge_weight)

        # Save best sample
        if best_loss > loss:
            best_loss = loss
            print('best_loss:', best_loss.item())
            best_edges = perturbed_edge_weight.clone().cpu()

    # Recover best sample
    perturbed_edge_weight.data.copy_(best_edges.to(args.device))

    edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
    edge_mask = edge_weight == 1
    make_undirected = args.make_undirected

    allowed_perturbations = 2 * n_perturbations if make_undirected else n_perturbations
    edges_after_attack = edge_mask.sum()
    clean_edges = edge_index.shape[1]
    assert (edges_after_attack >= clean_edges - allowed_perturbations
            and edges_after_attack <= clean_edges + allowed_perturbations), \
        f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} perturbations'

    return edge_index[:, edge_mask], edge_weight[edge_mask]

if __name__ == '__main__':
    train_target(data, perturbed_edge_weight)
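With the default flags, the outer collaborative loop and the structure budget work out as follows. A small sanity-check sketch of the arithmetic above (the graph size is a hypothetical example, not from the datasets):
```
# Sanity check of the loop schedule and edge budget under the defaults above.
epochs, loop_feat, loop_adj, loop_model = 1000, 2, 1, 3
outer_iters = epochs // (loop_feat + loop_adj)
print(outer_iters)                # 333 outer iterations
print(outer_iters * loop_model)   # 999 model-adaptation steps in total

ratio, num_directed_edges = 0.2, 10_000  # hypothetical target graph size
n_perturbations = int(ratio * num_directed_edges // 2)
print(n_perturbations)            # 1000 undirected edge flips allowed
```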
@@ -0,0 +1,132 @@
import os.path as osp
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import coalesce

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


def grad_with_checkpoint(outputs, inputs):
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    for input in inputs:
        if not input.is_leaf:
            input.retain_grad()
    torch.autograd.backward(outputs)

    grad_outputs = []
    for input in inputs:
        grad_outputs.append(input.grad.clone())
        input.grad.zero_()
    return grad_outputs


def linear_to_triu_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor:
    # invert a linear upper-triangular index back into (row, col) pairs
    row_idx = (
        n
        - 2
        - torch.floor(torch.sqrt(-8 * lin_idx.double() + 4 * n * (n - 1) - 7) / 2.0 - 0.5)
    ).long()
    col_idx = (
        lin_idx
        + row_idx
        + 1 - n * (n - 1) // 2
        + (n - row_idx) * ((n - row_idx) - 1) // 2
    )
    return torch.stack((row_idx, col_idx))


def bisection(edge_weights, a, b, n_perturbations, epsilon=1e-5, iter_max=1e5):
    def func(x):
        return torch.clamp(edge_weights - x, 0, 1).sum() - n_perturbations

    miu = a
    for i in range(int(iter_max)):
        miu = (a + b) / 2
        # Check if the middle point is a root
        if (func(miu) == 0.0):
            break
        # Decide the side on which to repeat the steps
        if (func(miu) * func(a) < 0):
            b = miu
        else:
            a = miu
        if ((b - a) <= epsilon):
            break
    return miu


def project(n_perturbations, values, eps, inplace=False):
    if not inplace:
        values = values.clone()

    if torch.clamp(values, 0, 1).sum() > n_perturbations:
        # shift all weights by a common offset so the clamped sum meets the budget
        left = (values - 1).min()
        right = values.max()
        miu = bisection(values, left, right, n_perturbations)
        values.data.copy_(torch.clamp(
            values - miu, min=eps, max=1 - eps
        ))
    else:
        values.data.copy_(torch.clamp(values, min=eps, max=1 - eps))

    return values


def get_modified_adj(modified_edge_index, perturbed_edge_weight, n, device, edge_index, edge_weight, make_undirected=False):
    if make_undirected:
        modified_edge_index, modified_edge_weight = to_symmetric(modified_edge_index, perturbed_edge_weight, n)
    else:
        modified_edge_index, modified_edge_weight = modified_edge_index, perturbed_edge_weight
    edge_index = torch.cat((edge_index.to(device), modified_edge_index), dim=-1)
    edge_weight = torch.cat((edge_weight.to(device), modified_edge_weight))
    edge_index, edge_weight = coalesce(edge_index, edge_weight, m=n, n=n, op='sum')

    # Allow removal of edges: weights above 1 flip existing edges towards 0
    edge_weight[edge_weight > 1] = 2 - edge_weight[edge_weight > 1]
    return edge_index, edge_weight


def to_symmetric(edge_index, edge_weight, n, op='mean'):
    symmetric_edge_index = torch.cat(
        (edge_index, edge_index.flip(0)), dim=-1
    )

    symmetric_edge_weight = edge_weight.repeat(2)

    symmetric_edge_index, symmetric_edge_weight = coalesce(
        symmetric_edge_index,
        symmetric_edge_weight,
        m=n,
        n=n,
        op=op
    )
    return symmetric_edge_index, symmetric_edge_weight


def compute_test(mask, model, data):
    model.eval()
    output = model(data.x, data.edge_index)
    loss = F.nll_loss(output[mask], data.y[mask])
    pred = output[mask].max(dim=1)[1]
    correct = pred.eq(data.y[mask]).sum().item()
    acc = correct * 1.0 / (mask.sum().item())

    return acc, loss


def evaluate(x, edge_index, edge_weight, y, model):
    model.eval()
    output = model(x, edge_index, edge_weight)
    loss = F.nll_loss(output, y)
    pred = output.max(dim=1)[1]
    correct = pred.eq(y).sum().item()
    acc = correct * 1.0 / len(y)

    return acc, loss


def Entropy(input_):
    # per-row entropy of a batch of probability distributions
    entropy = -input_ * torch.log(input_ + 1e-8)
    entropy = torch.sum(entropy, dim=1)
    return entropy
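A toy check of `project` (a sketch, assuming this file is saved as `utils.py`, as the wildcard imports in the training scripts suggest). Starting from weights whose clamped sum exceeds the budget, the bisection shifts all entries by a common offset so the sum lands on the budget, while each entry stays inside `[eps, 1 - eps]`:
```
# Toy demonstration of the budget projection used for the perturbed edge weights.
import torch
from utils import project

weights = torch.tensor([0.9, 0.8, 0.7, 0.6])  # clamped sum 3.0 > budget of 2
projected = project(n_perturbations=2, values=weights, eps=1e-7)
print(projected)                # each entry shifted down by roughly 0.25
print(projected.sum().item())   # approximately 2, the allowed budget
```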