Init submit
cszhangzhen committed Feb 16, 2024
1 parent 8862a11 commit d721f51
Showing 10 changed files with 1,455 additions and 1 deletion.
Binary file added .DS_Store
Binary file not shown.
32 changes: 31 additions & 1 deletion README.md
@@ -1 +1,31 @@
# GraphCTA
Collaborate to Adapt: Source-Free Graph Domain Adaptation via Bi-directional Adaptation (WWW 2024)

![](https://github.com/cszhangzhen/GraphCTA/blob/main/fig/model.png)

This is a PyTorch implementation of the GraphCTA algorithm, which addresses graph domain adaptation without access to the labelled source graph. GraphCTA performs model adaptation and graph adaptation collaboratively: (1) it adapts the model using each node's neighborhood predictions in the target graph, taking both local and global information into account; (2) it adapts the graph by updating the graph structure and node attributes via neighborhood contrastive learning; and (3) the updated graph then serves as input to the next round of model adaptation, establishing a collaborative loop between model adaptation and graph adaptation.
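
At a high level, these three procedures alternate inside one outer loop. Below is a minimal, runnable sketch of that control flow; `model_adaptation` and `graph_adaptation` are placeholder stubs for illustration only, not functions from this repository (the real update rules live in `train_target.py`):

```
import torch

# Placeholder stubs that only illustrate the control flow.
def model_adaptation(model, feat, edge_index):
    return model  # stands in for one pseudo-label / memory-bank model update

def graph_adaptation(feat, edge_index):
    return feat, edge_index  # stands in for one contrastive feature/edge update

model = torch.nn.Linear(8, 3)                # stand-in for the GNN
feat = torch.randn(10, 8)                    # target node attributes
edge_index = torch.randint(0, 10, (2, 40))   # target edges (COO format)

for it in range(5):                          # outer collaborative loop
    for _ in range(3):                       # (1) model adaptation steps
        model = model_adaptation(model, feat, edge_index)
    for _ in range(2):                       # (2) graph adaptation steps
        feat, edge_index = graph_adaptation(feat, edge_index)
    # (3) the adapted graph feeds the next round of model adaptation
```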


## Requirements
* python3.8
* pytorch==2.0.0
* torch-scatter==2.1.1+pt20cu118
* torch-sparse==0.6.17+pt20cu118
* torch-cluster==1.6.1+pt20cu118
* torch-geometric==2.3.1
* numpy==1.24.3
* scipy==1.10.1
* tqdm==4.65.0
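
A possible install sequence, assuming a CUDA 11.8 build of PyTorch 2.0 (matching the `+pt20cu118` wheels above; adjust the index URLs for other CUDA versions):
```
pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu118
pip install torch-scatter==2.1.1 torch-sparse==0.6.17 torch-cluster==1.6.1 -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
pip install torch-geometric==2.3.1 numpy==1.24.3 scipy==1.10.1 tqdm==4.65.0
```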

## Datasets
All datasets used in the paper are publicly available: [Elliptic](https://www.kaggle.com/datasets/ellipticco/elliptic-data-set), [Twitch](https://github.com/benedekrozemberczki/datasets#twitch-social-networks) and [Citation](https://github.com/yuntaodu/ASN/tree/main/data).
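
Each loader in `datasets.py` subclasses PyTorch Geometric's `InMemoryDataset`, so raw files are expected under `<root>/raw/` and a processed `data.pt` is cached under `<root>/processed/`. A minimal loading sketch (the `data/Citation/DBLPv7` path follows the convention in `train_source.py`; place the raw files, e.g. `DBLPv7_docs.txt`, `DBLPv7_edgelist.txt` and `DBLPv7_labels.txt`, under `data/Citation/DBLPv7/raw/` first):
```
from datasets import CitationDataset

dataset = CitationDataset('data/Citation/DBLPv7', 'DBLPv7')
data = dataset[0]  # one graph with x, edge_index, y and train/val/test masks
```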

## Quick Start
Execute the following command to pre-train the source model:
```
python train_source.py
```
Then execute the following command for target-domain adaptation:
```
python train_target.py
```
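
Both scripts accept the same domain flags, so a run for a different transfer pair, e.g. ACMv9 -> DBLPv7 with a GCN backbone, looks like:
```
python train_source.py --source ACMv9 --target DBLPv7 --gnn gcn
python train_target.py --source ACMv9 --target DBLPv7 --gnn gcn
```
Any of the other argparse options defined in the scripts (e.g. `--lr`, `--nhid`, `--device`) can be overridden the same way.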
Binary file added data/Citation.zip
Binary file not shown.
271 changes: 271 additions & 0 deletions datasets.py
@@ -0,0 +1,271 @@
import os.path as osp
import torch
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.io import read_txt_array
import torch.nn.functional as F

import scipy
import pickle as pkl
import csv
import json

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


class CitationDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(CitationDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ['{}_docs.txt'.format(self.name), '{}_edgelist.txt'.format(self.name), '{}_labels.txt'.format(self.name)]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
edge_path = osp.join(self.raw_dir, '{}_edgelist.txt'.format(self.name))
edge_index = read_txt_array(edge_path, sep=',', dtype=torch.long).t()

docs_path = osp.join(self.raw_dir, '{}_docs.txt'.format(self.name))
f = open(docs_path, 'rb')
content_list = []
for line in f.readlines():
line = str(line, encoding="utf-8")
content_list.append(line.split(","))
x = np.array(content_list, dtype=float)
x = torch.from_numpy(x).to(torch.float)

label_path = osp.join(self.raw_dir, '{}_labels.txt'.format(self.name))
f = open(label_path, 'rb')
content_list = []
for line in f.readlines():
line = str(line, encoding="utf-8")
line = line.replace("\r", "").replace("\n", "")
content_list.append(line)
y = np.array(content_list, dtype=int)
y = torch.from_numpy(y).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate(data_list)

torch.save((data, slices), self.processed_paths[0])


class EllipticDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(EllipticDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ['{}.pkl'.format(self.name)]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def process(self):
path = osp.join(self.raw_dir, '{}.pkl'.format(self.name))
result = pkl.load(open(path, 'rb'))
A, label, features = result
label = label + 1
edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
features = np.array(features)
x = torch.from_numpy(features).to(torch.float)
y = torch.tensor(label).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate(data_list)

torch.save((data, slices), self.processed_paths[0])


class TwitchDataset(InMemoryDataset):
def __init__(self,
root,
name,
transform=None,
pre_transform=None,
pre_filter=None):
self.name = name
self.root = root
super(TwitchDataset, self).__init__(root, transform, pre_transform, pre_filter)

self.data, self.slices = torch.load(self.processed_paths[0])

@property
def raw_file_names(self):
return ['musae_{}_edges.csv'.format(self.name), 'musae_{}_features.json'.format(self.name), 'musae_{}_target.csv'.format(self.name)]

@property
def processed_file_names(self):
return ['data.pt']

def download(self):
pass

def load_twitch(self, lang):
assert lang in ('DE', 'EN', 'FR'), 'Invalid dataset'
filepath = self.raw_dir
label = []
node_ids = []
src = []
targ = []
uniq_ids = set()
with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
reader = csv.reader(f)
next(reader)
for row in reader:
node_id = int(row[5])
# handle FR case of non-unique rows
if node_id not in uniq_ids:
uniq_ids.add(node_id)
label.append(int(row[2]=="True"))
node_ids.append(int(row[5]))

node_ids = np.array(node_ids, dtype=np.int32)

with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
reader = csv.reader(f)
next(reader)
for row in reader:
src.append(int(row[0]))
targ.append(int(row[1]))

with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
j = json.load(f)

src = np.array(src)
targ = np.array(targ)
label = np.array(label)

inv_node_ids = {node_id:idx for (idx, node_id) in enumerate(node_ids)}
reorder_node_ids = np.zeros_like(node_ids)
for i in range(label.shape[0]):
reorder_node_ids[i] = inv_node_ids[i]

n = label.shape[0]
A = scipy.sparse.csr_matrix((np.ones(len(src)), (np.array(src), np.array(targ))), shape=(n,n))
features = np.zeros((n,3170))
for node, feats in j.items():
if int(node) >= n:
continue
features[int(node), np.array(feats, dtype=int)] = 1
new_label = label[reorder_node_ids]
label = new_label

return A, label, features

def process(self):
A, label, features = self.load_twitch(self.name)
edge_index = torch.tensor(np.array(A.nonzero()), dtype=torch.long)
features = np.array(features)
x = torch.from_numpy(features).to(torch.float)
y = torch.from_numpy(label).to(torch.int64)

data_list = []
data = Data(edge_index=edge_index, x=x, y=y)

random_node_indices = np.random.permutation(y.shape[0])
training_size = int(len(random_node_indices) * 0.8)
val_size = int(len(random_node_indices) * 0.1)
train_node_indices = random_node_indices[:training_size]
val_node_indices = random_node_indices[training_size:training_size + val_size]
test_node_indices = random_node_indices[training_size + val_size:]

train_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
train_masks[train_node_indices] = 1
val_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
val_masks[val_node_indices] = 1
test_masks = torch.zeros([y.shape[0]], dtype=torch.bool)
test_masks[test_node_indices] = 1

data.train_mask = train_masks
data.val_mask = val_masks
data.test_mask = test_masks

if self.pre_transform is not None:
data = self.pre_transform(data)

data_list.append(data)

data, slices = self.collate(data_list)

torch.save((data, slices), self.processed_paths[0])
Binary file added fig/model.png
Binary file not shown.
536 changes: 536 additions & 0 deletions layer.py

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions model.py
@@ -0,0 +1,40 @@
import torch
import torch.nn.functional as F
from layer import *


class Model(torch.nn.Module):
def __init__(self, args):
super(Model, self).__init__()
self.args = args

if args.gnn == 'gcn':
self.gnn = GCN(args)
elif args.gnn == 'sage':
self.gnn = SAGE(args)
elif args.gnn == 'gat':
self.gnn = GAT(args)
elif args.gnn == 'gin':
self.gnn = GIN(args)
else:
raise ValueError('Invalid gnn: {}'.format(args.gnn))

self.reset_parameters()

def reset_parameters(self):
self.gnn.reset_parameters()

def forward(self, x, edge_index, edge_weight=None):
x = self.feat_bottleneck(x, edge_index, edge_weight)
x = self.feat_classifier(x)

return F.log_softmax(x, dim=1)

def feat_bottleneck(self, x, edge_index, edge_weight=None):
x = self.gnn.feat_bottleneck(x, edge_index, edge_weight)
return x

def feat_classifier(self, x):
x = self.gnn.feat_classifier(x)

return x
133 changes: 133 additions & 0 deletions train_source.py
@@ -0,0 +1,133 @@
import argparse
import glob
import os
import time

import torch
import torch.nn.functional as F
from model import *
from utils import *
from datasets import *
import numpy as np

parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=200, help='random seed')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--nhid', type=int, default=128, help='hidden size')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='dropout ratio')
parser.add_argument('--device', type=str, default='cuda:2', help='specify cuda devices')
parser.add_argument('--source', type=str, default='Citationv1', help='source domain data')
parser.add_argument('--target', type=str, default='DBLPv7', help='target domain data')
parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--patience', type=int, default=100, help='patience for early stopping')
parser.add_argument('--num_layers', type=int, default=2, help='number of gnn layers')
parser.add_argument('--gnn', type=str, default='gcn', help='different types of gnns')
parser.add_argument('--use_bn', type=bool, default=False, help='whether to use batch normalization')

args = parser.parse_args()

if args.source in {'DBLPv7', 'ACMv9', 'Citationv1'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.source)
source_dataset = CitationDataset(path, args.source)
elif args.source in {'S10', 'M10', 'E10'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.source)
source_dataset = EllipticDataset(path, args.source)
elif args.source in {'DE', 'EN', 'FR'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.source)
source_dataset = TwitchDataset(path, args.source)

if args.target in {'DBLPv7', 'ACMv9', 'Citationv1'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.target)
target_dataset = CitationDataset(path, args.target)
elif args.target in {'S10', 'M10', 'E10'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.target)
target_dataset = EllipticDataset(path, args.target)
elif args.target in {'DE', 'EN', 'FR'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.target)
target_dataset = TwitchDataset(path, args.target)

target_data = target_dataset[0]
data = source_dataset[0]

args.num_classes = len(np.unique(data.y.numpy()))
args.num_features = data.x.size(1)

print(args)

model = Model(args).to(args.device)
data = data.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

def train_source():
min_loss = 1e10
patience_cnt = 0
val_loss_values = []
best_epoch = 0

t = time.time()
for epoch in range(args.epochs):
model.train()
optimizer.zero_grad()
correct = 0
output = model(data.x, data.edge_index)
train_loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
train_loss.backward()
optimizer.step()
pred = output[data.train_mask].max(dim=1)[1]
correct = pred.eq(data.y[data.train_mask]).sum().item()
train_acc = correct * 1.0 / (data.train_mask).sum().item()

val_acc, val_loss = compute_test(data.val_mask, model, data)

print('Epoch: {:04d}'.format(epoch + 1), 'train_loss: {:.6f}'.format(train_loss),
'train_acc: {:.6f}'.format(train_acc), 'loss_val: {:.6f}'.format(val_loss),
'acc_val: {:.6f}'.format(val_acc), 'time: {:.6f}s'.format(time.time() - t))

val_loss_values.append(val_loss)
torch.save(model.state_dict(), '{}.pth'.format(epoch))

if val_loss_values[-1] < min_loss:
min_loss = val_loss_values[-1]
best_epoch = epoch
patience_cnt = 0
else:
patience_cnt += 1

if patience_cnt == args.patience:
break

files = glob.glob('*.pth')
for f in files:
epoch_nb = int(f.split('.')[0])
if epoch_nb < best_epoch:
os.remove(f)

files = glob.glob('*.pth')
for f in files:
epoch_nb = int(f.split('.')[0])
if epoch_nb > best_epoch:
os.remove(f)
print('Optimization Finished! Total time elapsed: {:.6f}'.format(time.time() - t))

return best_epoch


if __name__ == '__main__':
if os.path.exists('model.pth'):
os.remove('model.pth')
# Model training
best_model = train_source()
# Restore best model for test set
model.load_state_dict(torch.load('{}.pth'.format(best_model)))
test_acc, test_loss = compute_test(data.test_mask, model, data)
print('Source {} test set results, loss = {:.6f}, accuracy = {:.6f}'.format(args.source, test_loss, test_acc))

target_data = target_data.to(args.device)
test_acc, test_loss = evaluate(target_data.x, target_data.edge_index, target_data.edge_weight, target_data.y, model)
print('Target {} test results, loss = {:.6f}, accuracy = {:.6f}'.format(args.target, test_loss, test_acc))

# Save model for target domain adaptation
torch.save(model.state_dict(), 'model.pth')
os.remove('{}.pth'.format(best_model))
312 changes: 312 additions & 0 deletions train_target.py
@@ -0,0 +1,312 @@
import argparse
import glob
import os
import time

import torch
import torch.nn.functional as F
from model import *
from utils import *
from layer import *
from datasets import *
import numpy as np
from torch_geometric.transforms import Constant
from torch.nn.parameter import Parameter
from torch_geometric.utils import dropout_adj
from tqdm import tqdm
import random

from torch import Tensor

parser = argparse.ArgumentParser()

parser.add_argument('--seed', type=int, default=200, help='random seed')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay')
parser.add_argument('--nhid', type=int, default=128, help='hidden size')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='dropout ratio')
parser.add_argument('--device', type=str, default='cuda:2', help='specify cuda devices')
parser.add_argument('--source', type=str, default='Citationv1', help='source domain data')
parser.add_argument('--target', type=str, default='DBLPv7', help='target domain data')
parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--tau', type=float, default=0.2, help='tau')
parser.add_argument('--lamb', type=float, default=0.2, help='trade-off parameter lambda')
parser.add_argument('--num_layers', type=int, default=2, help='number of gnn layers')
parser.add_argument('--gnn', type=str, default='gcn', help='different types of gnns')
parser.add_argument('--use_bn', type=bool, default=False, help='whether to use batch normalization')
parser.add_argument('--make_undirected', type=bool, default=True, help='whether to make the graph undirected')

parser.add_argument('--ratio', type=float, default=0.2, help='structure perturbation budget')
parser.add_argument('--loop_adj', type=int, default=1, help='inner loop for adjacent update')
parser.add_argument('--loop_feat', type=int, default=2, help='inner loop for feature update')
parser.add_argument('--loop_model', type=int, default=3, help='inner loop for model update')
parser.add_argument('--debug', type=int, default=1, help='whether output intermediate results')
parser.add_argument("--K", type=int, default=5, help='number of k-nearest neighbors')

args = parser.parse_args()


if args.target in {'DBLPv7', 'ACMv9', 'Citationv1'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Citation', args.target)
target_dataset = CitationDataset(path, args.target)
elif args.target in {'S10', 'M10', 'E10'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Elliptic', args.target)
target_dataset = EllipticDataset(path, args.target)
elif args.target in {'DE', 'EN', 'FR'}:
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data/Twitch', args.target)
target_dataset = TwitchDataset(path, args.target)

data = target_dataset[0]

args.num_classes = len(np.unique(data.y.numpy()))
args.num_features = data.x.size(1)

print(args)

model = Model(args).to(args.device)
data = data.to(args.device)

neighprop = NeighborPropagate()

model.load_state_dict(torch.load('model.pth'))

delta_feat = Parameter(torch.FloatTensor(data.x.size(0), data.x.size(1)).to(args.device))
delta_feat.data.fill_(1e-7)
optimizer_feat = torch.optim.Adam([delta_feat], lr=0.0001, weight_decay=0.0001)

modified_edge_index = data.edge_index.clone()
modified_edge_index = modified_edge_index[:, modified_edge_index[0] < modified_edge_index[1]]
row, col = modified_edge_index[0], modified_edge_index[1]
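# Encode each undirected edge (row < col) as a linear index into the strictly
# upper-triangular part of the n x n adjacency; linear_to_triu_idx() in utils.py
# is the closed-form inverse of this mapping.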
edge_index_id = (2 * data.x.size(0) - row - 1) * row // 2 + col - row - 1
edge_index_id = edge_index_id.long()
modified_edge_index = linear_to_triu_idx(data.x.size(0), edge_index_id)
perturbed_edge_weight = torch.full_like(edge_index_id, 1e-7, dtype=torch.float32, requires_grad=True).to(args.device)

optimizer_adj = torch.optim.Adam([perturbed_edge_weight], lr=0.0001, weight_decay=0.0001)

n_perturbations = int(args.ratio * data.edge_index.shape[1] // 2)

n = data.x.size(0)

def train_target(target_data, perturbed_edge_weight):
optimizer_model = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
t = time.time()
edge_index = target_data.edge_index
edge_weight = torch.ones(edge_index.shape[1]).to(args.device)
feat = target_data.x

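# Memory banks over all target nodes: mem_fea caches running feature vectors
# (random init) and mem_cls running class probabilities (uniform init); both
# receive momentum updates after every model-adaptation step below.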
mem_fea = torch.rand(target_data.x.size(0), args.nhid).to(args.device)
mem_cls = torch.ones(target_data.x.size(0), args.num_classes).to(args.device) / args.num_classes

for it in tqdm(range(args.epochs//(args.loop_feat+args.loop_adj))):
for loop_model in range(args.loop_model):
for k,v in model.named_parameters():
v.requires_grad = True
model.train()
feat = feat.detach()
edge_weight = edge_weight.detach()

optimizer_model.zero_grad()
feat_output = model.feat_bottleneck(feat, edge_index, edge_weight)
cls_output = model.feat_classifier(feat_output)

onehot = torch.nn.functional.one_hot(cls_output.argmax(1), num_classes=args.num_classes).float()
proto = (torch.mm(mem_fea.t(), onehot) / (onehot.sum(dim=0) + 1e-8)).t()

prob = neighprop(mem_cls, edge_index)
weight, pred = torch.max(prob, dim=1)
cl, weight_ = instance_proto_alignment(feat_output, proto, pred)
ce = F.cross_entropy(cls_output, pred, reduction='none')
loss_local = torch.sum(weight_ * ce) / (torch.sum(weight_).item())
loss = loss_local * (1 - args.lamb) + cl * args.lamb

loss.backward()
optimizer_model.step()
print('Model: ' + str(loss.item()))

model.eval()
with torch.no_grad():
feat_output = model.feat_bottleneck(feat, edge_index, edge_weight)
cls_output = model.feat_classifier(feat_output)
softmax_out = F.softmax(cls_output, dim=1)
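# Sharpen the predictions (square the probabilities, renormalize per class)
# before the momentum update of the memory banks.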
outputs_target = softmax_out**2 / ((softmax_out**2).sum(dim=0))

mem_cls = (1.0 - args.momentum) * mem_cls + args.momentum * outputs_target.clone()
mem_fea = (1.0 - args.momentum) * mem_fea + args.momentum * feat_output.clone()

for k,v in model.named_parameters():
v.requires_grad = False

perturbed_edge_weight = perturbed_edge_weight.detach()
for loop_feat in range(args.loop_feat):
optimizer_feat.zero_grad()
delta_feat.requires_grad = True
loss = test_time_loss(model, target_data.x + delta_feat, edge_index, mem_fea, mem_cls, edge_weight)
loss.backward()
optimizer_feat.step()
print('Feat: ' + str(loss.item()))

new_feat = (data.x + delta_feat).detach()
for loop_adj in range(args.loop_adj):
perturbed_edge_weight.requires_grad = True
edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
loss = test_time_loss(model, new_feat, edge_index, mem_fea, mem_cls, edge_weight)
print('Adj: ' + str(loss.item()))

gradient = grad_with_checkpoint(loss, perturbed_edge_weight)[0]

with torch.no_grad():
update_edge_weights(gradient)
perturbed_edge_weight = project(n_perturbations, perturbed_edge_weight, 1e-7)

if args.loop_adj != 0:
edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
edge_weight = edge_weight.detach()

if args.loop_feat != 0:
feat = (target_data.x + delta_feat).detach()

edge_index, edge_weight = sample_final_edges(n_perturbations, perturbed_edge_weight, target_data, modified_edge_index, mem_fea, mem_cls)

test_acc, _ = evaluate(target_data.x + delta_feat, edge_index, edge_weight, target_data.y, model)
print('acc : ' + str(test_acc))
print('Optimization Finished!\n')


def instance_proto_alignment(feat, center, pred):
feat_norm = F.normalize(feat, dim=1)
center_norm = F.normalize(center, dim=1)
sim = torch.matmul(feat_norm, center_norm.t())

num_nodes = feat.size(0)
weight = sim[range(num_nodes), pred]
sim = torch.exp(sim / args.tau)
pos_sim = sim[range(num_nodes), pred]

sim_feat = torch.matmul(feat_norm, feat_norm.t())
sim_feat = torch.exp(sim_feat / args.tau)
ident = sim_feat[range(num_nodes), range(num_nodes)]

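# InfoNCE-style contrast: the positive pair is (node, its assigned prototype);
# negatives are the remaining prototypes plus all other node instances.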
logit = pos_sim / (sim.sum(dim=1) - pos_sim + sim_feat.sum(dim=1) - ident + 1e-8)
loss = - torch.log(logit + 1e-8).mean()

return loss, weight


def update_edge_weights(gradient):
optimizer_adj.zero_grad()
perturbed_edge_weight.grad = gradient
optimizer_adj.step()
perturbed_edge_weight.data[perturbed_edge_weight < 1e-7] = 1e-7


def test_time_loss(model, feat, edge_index, mem_fea, mem_cls, edge_weight=None):
model.eval()
feat_output = model.feat_bottleneck(feat, edge_index, edge_weight)
cls_output = model.feat_classifier(feat_output)
softmax_out = F.softmax(cls_output, dim=1)
_, predict = torch.max(softmax_out, 1)
mean_ent = Entropy(softmax_out)
est_p = (mean_ent<mean_ent.mean()).sum().item() / mean_ent.size(0)
value = mean_ent

predict = predict.cpu().numpy()
train_idx = np.zeros(predict.shape)

cls_k = args.num_classes
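# Per class, keep only the lowest-entropy (most confident) predictions as
# pseudo-labels; est_p scales how many are selected.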
for c in range(cls_k):
c_idx = np.where(predict==c)
c_idx = c_idx[0]
c_value = value[c_idx]

_, idx_ = torch.sort(c_value)
c_num = len(idx_)
c_num_s = int(c_num * est_p / 5)

for ei in range(0, c_num_s):
ee = c_idx[idx_[ei]]
train_idx[ee] = 1

train_idx = np.array(train_idx, dtype=bool)
pred_label = predict[train_idx]
pseudo_label = torch.from_numpy(pred_label).to(args.device)

pred_output = cls_output[train_idx]
loss = F.cross_entropy(pred_output, pseudo_label)

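# Pull each node toward its K nearest neighbors in the feature memory bank
# (term subtracted below) and push it away from memory entries predicted as a
# different class (term added afterwards).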
distance = feat_output @ mem_fea.T
_, idx_near = torch.topk(distance, dim=-1, largest=True, k=args.K + 1)
idx_near = idx_near[:, 1:] # batch x K

mem_near = mem_fea[idx_near] # batch x K x d
feat_output_un = feat_output.unsqueeze(1).expand(-1, args.K, -1) # batch x K x d
loss -= torch.mean((feat_output_un * mem_near).sum(-1).sum(1)/args.K) * 0.1

_, pred_mem = torch.max(mem_cls, dim=1)
_, pred = torch.max(softmax_out, dim=1)
idx = pred.unsqueeze(-1) == pred_mem
neg_num = torch.sum(~idx, dim=1)
dis = (distance * ~idx).sum(1)/neg_num
loss += dis.mean() * 0.1

return loss


@torch.no_grad()
def sample_final_edges(n_perturbations, perturbed_edge_weight, data, modified_edge_index, mem_fea, mem_cls):
best_loss = float('Inf')
perturbed_edge_weight = perturbed_edge_weight.detach()
# TODO: potentially convert to assert
perturbed_edge_weight[perturbed_edge_weight <= 1e-7] = 0

# _, feat, labels = self.data.edge_index, self.data.x, self.data.y
feat = data.x.to(args.device)
edge_index = data.edge_index.to(args.device)
edge_weight = torch.ones(edge_index.shape[1]).to(args.device)
# self.edge_index = data.graph['edge_index'].to(self.device)

for i in range(20):
if best_loss == float('Inf'):
# In first iteration employ top k heuristic instead of sampling
sampled_edges = torch.zeros_like(perturbed_edge_weight).to(args.device)
sampled_edges[torch.topk(perturbed_edge_weight, n_perturbations).indices] = 1
else:
sampled_edges = torch.bernoulli(perturbed_edge_weight).float()

if sampled_edges.sum() > n_perturbations:
n_samples = sampled_edges.sum()
if args.debug ==2:
print(f'{i}-th sampling: too many samples {n_samples}')
continue

perturbed_edge_weight = sampled_edges

edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
with torch.no_grad():
loss = test_time_loss(model, feat, edge_index, mem_fea, mem_cls, edge_weight)

# Save best sample
if best_loss > loss:
best_loss = loss
print('best_loss:', best_loss.item())
best_edges = perturbed_edge_weight.clone().cpu()

# Recover best sample
perturbed_edge_weight.data.copy_(best_edges.to(args.device))

edge_index, edge_weight = get_modified_adj(modified_edge_index, perturbed_edge_weight, n, args.device, edge_index, edge_weight, args.make_undirected)
edge_mask = edge_weight == 1
make_undirected = args.make_undirected

allowed_perturbations = 2 * n_perturbations if make_undirected else n_perturbations
edges_after_attack = edge_mask.sum()
clean_edges = edge_index.shape[1]
assert (edges_after_attack >= clean_edges - allowed_perturbations
and edges_after_attack <= clean_edges + allowed_perturbations), \
f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} perturbations'

return edge_index[:, edge_mask], edge_weight[edge_mask]

if __name__ == '__main__':
train_target(data, perturbed_edge_weight)
132 changes: 132 additions & 0 deletions utils.py
@@ -0,0 +1,132 @@
import os.path as osp
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_sparse import coalesce

import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)


def grad_with_checkpoint(outputs, inputs):
inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
for input in inputs:
if not input.is_leaf:
input.retain_grad()
torch.autograd.backward(outputs)

grad_outputs = []
for input in inputs:
grad_outputs.append(input.grad.clone())
input.grad.zero_()
return grad_outputs


def linear_to_triu_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor:
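# Recover the (row, col) position in the strictly upper-triangular n x n matrix
# from a linear index; closed-form inverse of the encoding in train_target.py.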
row_idx = (
n
- 2
- torch.floor(torch.sqrt(-8 * lin_idx.double() + 4 * n * (n - 1) - 7) / 2.0 - 0.5)
).long()
col_idx = (
lin_idx
+ row_idx
+ 1 - n * (n - 1) // 2
+ (n - row_idx) * ((n - row_idx) - 1) // 2
)
return torch.stack((row_idx, col_idx))


def bisection(edge_weights, a, b, n_perturbations, epsilon=1e-5, iter_max=1e5):
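# Bisection search on [a, b] for the shift miu such that
# clamp(edge_weights - miu, 0, 1) sums to n_perturbations; used by project()
# to enforce the edge-perturbation budget.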
def func(x):
return torch.clamp(edge_weights - x, 0, 1).sum() - n_perturbations

miu = a
for i in range(int(iter_max)):
miu = (a + b) / 2
# Check if middle point is root
if (func(miu) == 0.0):
break
# Decide the side to repeat the steps
if (func(miu) * func(a) < 0):
b = miu
else:
a = miu
if ((b - a) <= epsilon):
break
return miu


def project(n_perturbations, values, eps, inplace=False):
if not inplace:
values = values.clone()

if torch.clamp(values, 0, 1).sum() > n_perturbations:
left = (values - 1).min()
right = values.max()
miu = bisection(values, left, right, n_perturbations)
values.data.copy_(torch.clamp(
values - miu, min=eps, max=1 - eps
))
else:
values.data.copy_(torch.clamp(values, min=eps, max=1 - eps))

return values


def get_modified_adj(modified_edge_index, perturbed_edge_weight, n, device, edge_index, edge_weight, make_undirected=False):
if make_undirected:
modified_edge_index, modified_edge_weight = to_symmetric(modified_edge_index, perturbed_edge_weight, n)
else:
modified_edge_index, modified_edge_weight = modified_edge_index, perturbed_edge_weight
edge_index = torch.cat((edge_index.to(device), modified_edge_index), dim=-1)
edge_weight = torch.cat((edge_weight.to(device), modified_edge_weight))
edge_index, edge_weight = coalesce(edge_index, edge_weight, m=n, n=n, op='sum')

# Allow removal of edges
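# coalesce(op='sum') yields weight 1 + p where a perturbation overlaps an
# existing edge; mapping w > 1 to 2 - w turns it into 1 - p, so increasing p
# effectively removes that edge.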
edge_weight[edge_weight > 1] = 2 - edge_weight[edge_weight > 1]
return edge_index, edge_weight


def to_symmetric(edge_index, edge_weight, n, op='mean'):
symmetric_edge_index = torch.cat(
(edge_index, edge_index.flip(0)), dim=-1
)

symmetric_edge_weight = edge_weight.repeat(2)

symmetric_edge_index, symmetric_edge_weight = coalesce(
symmetric_edge_index,
symmetric_edge_weight,
m=n,
n=n,
op=op
)
return symmetric_edge_index, symmetric_edge_weight


def compute_test(mask, model, data):
model.eval()
output = model(data.x, data.edge_index)
loss = F.nll_loss(output[mask], data.y[mask])
pred = output[mask].max(dim=1)[1]
correct = pred.eq(data.y[mask]).sum().item()
acc = correct * 1.0 / (mask.sum().item())

return acc, loss


def evaluate(x, edge_index, edge_weight, y, model):
model.eval()
output = model(x, edge_index, edge_weight)
loss = F.nll_loss(output, y)
pred = output.max(dim=1)[1]
correct = pred.eq(y).sum().item()
acc = correct * 1.0 / len(y)

return acc, loss

def Entropy(input_):
entropy = -input_ * torch.log(input_ + 1e-8)
entropy = torch.sum(entropy, dim=1)
return entropy
