-
Notifications
You must be signed in to change notification settings - Fork 5
/
base_attack.py
119 lines (95 loc) · 3.27 KB
/
base_attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os.path as osp
import numpy as np
import scipy.sparse as sp
import torch
from torch.nn.modules.module import Module
import utils
class BaseAttack(Module):
    """Abstract base class for target attack classes.

    Parameters
    ----------
    model :
        model to attack. When not None, its ``nclass``, ``nfeat`` and
        ``hidden_sizes`` attributes are copied onto the attack instance.
    nnodes : int
        number of nodes in the input graph
    attack_structure : bool
        whether to attack graph structure
    attack_features : bool
        whether to attack node features
    device: str
        'cpu' or 'cuda'
    """

    def __init__(self, model, nnodes, attack_structure=True, attack_features=False, device='cpu'):
        super().__init__()
        self.surrogate = model
        self.nnodes = nnodes
        self.attack_structure = attack_structure
        self.attack_features = attack_features
        self.device = device
        # Start empty; presumably filled in by a subclass's ``attack``.
        # ``save_adj`` / ``save_features`` refuse to run while these are None.
        self.modified_adj = None
        self.modified_features = None
        if model is not None:
            self.nclass = model.nclass
            self.nfeat = model.nfeat
            self.hidden_sizes = model.hidden_sizes

    def attack(self, ori_adj, n_perturbations, **kwargs):
        """Generate attacks on the input graph.

        The base implementation does nothing; subclasses override it.

        Parameters
        ----------
        ori_adj : scipy.sparse.csr_matrix
            Original (unperturbed) adjacency matrix.
        n_perturbations : int
            Number of edge removals/additions.

        Returns
        -------
        None.
        """
        pass

    def check_adj(self, adj):
        """Check that an adjacency matrix is symmetric and unweighted.

        Parameters
        ----------
        adj : scipy.sparse matrix or array-like
            Adjacency matrix to validate. Dense inputs are accepted and
            converted to CSR for the check (the caller's object is not
            modified).

        Raises
        ------
        AssertionError
            If ``adj`` is not symmetric, its maximum is not 1, its minimum
            is not 0, or any stored entry is something other than 0 or 1.
        """
        csr = sp.csr_matrix(adj)
        # Builtin ``abs`` dispatches to the sparse matrix's own ``__abs__``.
        assert abs(csr - csr.T).sum() == 0, "Input graph is not symmetric"
        assert csr.max() == 1, "Max value should be 1!"
        assert csr.min() == 0, "Min value should be 0!"
        # max == 1 and min == 0 alone still admit weights like 0.5; an
        # unweighted graph must have every stored entry exactly 0 or 1.
        assert np.isin(csr.data, (0, 1)).all(), "Entries should be 0 or 1!"

    def save_adj(self, root=r'/tmp/', name='mod_adj'):
        """Save the attacked adjacency matrix as ``<root>/<name>.npz``.

        Parameters
        ----------
        root :
            root directory where the variable should be saved
        name : str
            saved file name (``'.npz'`` is appended)

        Returns
        -------
        None.

        Raises
        ------
        AssertionError
            If ``modified_adj`` has not been set yet.
        """
        assert self.modified_adj is not None, \
            'modified_adj is None! Please perturb the graph first.'
        self._save_npz(self.modified_adj, root, name)

    def save_features(self, root=r'/tmp/', name='mod_features'):
        """Save the attacked node feature matrix as ``<root>/<name>.npz``.

        Parameters
        ----------
        root :
            root directory where the variable should be saved
        name : str
            saved file name (``'.npz'`` is appended)

        Returns
        -------
        None.

        Raises
        ------
        AssertionError
            If ``modified_features`` has not been set yet.
        """
        assert self.modified_features is not None, \
            'modified_features is None! Please perturb the graph first.'
        self._save_npz(self.modified_features, root, name)

    @staticmethod
    def _save_npz(matrix, root, name):
        """Write ``matrix`` to ``<root>/<name>.npz`` with ``scipy.sparse.save_npz``.

        Torch tensors (including subclasses such as ``nn.Parameter``) are
        detached and converted with the project's ``utils.to_scipy``.
        NOTE(review): that helper is assumed to return a scipy sparse matrix.
        Other non-sparse inputs, such as dense numpy arrays, are converted
        to CSR first, because ``save_npz`` only accepts sparse matrices.
        """
        if isinstance(matrix, torch.Tensor):
            # detach() so tensors that track gradients can be converted.
            matrix = utils.to_scipy(matrix.detach())
        elif not sp.issparse(matrix):
            matrix = sp.csr_matrix(matrix)
        sp.save_npz(osp.join(root, name + '.npz'), matrix)