-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathimportance.py
85 lines (66 loc) · 3.14 KB
/
importance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# training preparation functions and importance calculations
import os.path
import torch
import util
import numpy as np
import distance
from torch_geometric.data import Data, Batch, DataLoader
def prepare_and_get(graphs, gnn_model, original_graph_indices, alpha, theta, device1, device2, dataset_name):
    """Build the ``wargs`` dictionary consumed by ``call``.

    Selects the original graphs, loads the NeuroSED model through
    ``distance.load_neurosed`` using the checkpoint path
    ``data/<dataset_name>/neurosed/best_model.pt``, and precomputes the
    element counts of the original graphs with ``util.graph_element_counts``.

    Args:
        graphs: indexable graph collection, indexed here with a Python list
            of indices (presumably a torch_geometric dataset -- confirm).
        gnn_model: GNN model, stored under ``'gnn_model'``.
        original_graph_indices: indices of the original graphs; must provide
            ``.tolist()`` (e.g. a tensor or numpy array).
        alpha: stored unchanged under ``'alpha'``.
        theta: distance threshold, stored under ``'distance_threshold'``.
        device1: GNN device, stored under ``'gnn_device'``.
        device2: device passed to ``distance.load_neurosed``, stored under
            ``'neurosed_device'``.
        dataset_name: used to build the NeuroSED folder path.

    Returns:
        dict with keys ``gnn_model``, ``alpha``, ``neurosed_model``,
        ``original_graphs``, ``original_graphs_element_counts``,
        ``distance_threshold``, ``gnn_device`` and ``neurosed_device``.
    """
    original_graphs = graphs[original_graph_indices.tolist()]
    neurosed_folder = f'data/{dataset_name}/neurosed'
    # exist_ok=True avoids the check-then-create race when several runs
    # create the same folder concurrently.
    os.makedirs(neurosed_folder, exist_ok=True)
    neurosed_model_path = os.path.join(neurosed_folder, 'best_model.pt')
    neurosed_model = distance.load_neurosed(original_graphs, neurosed_model_path=neurosed_model_path, device=device2)
    original_graphs_elements_counts = util.graph_element_counts(original_graphs)
    return {
        'gnn_model': gnn_model,
        'alpha': alpha,
        'neurosed_model': neurosed_model,
        'original_graphs': original_graphs,
        'original_graphs_element_counts': original_graphs_elements_counts,
        'distance_threshold': theta,
        'gnn_device': device1,
        'neurosed_device': device2
    }
def _release_cuda_cache(device):
    """Release cached CUDA memory for *device*; a no-op for non-CUDA devices.

    For an indexed CUDA device (e.g. ``'cuda:1'``) the device is also made
    current, as the original inline ``torch.cuda.set_device`` call did.
    ``torch.cuda.set_device`` raises for CPU devices and for index-less
    ``'cuda'``, so those cases are guarded here.
    """
    device = torch.device(device)
    if device.type != 'cuda' or not torch.cuda.is_available():
        return
    if device.index is not None:
        torch.cuda.set_device(device)
    torch.cuda.empty_cache()
def call(graphs, wargs):
    """Score *graphs* with the GNN and estimate their coverage of the original graphs.

    Args:
        graphs: list of torch_geometric graphs to evaluate.
        wargs: dict as built by ``prepare_and_get``. Reads ``gnn_model``,
            ``gnn_device``, ``neurosed_model``,
            ``original_graphs_element_counts`` and ``distance_threshold``.
            The optional key ``gnn_batch_size`` (default 128) sets the
            mini-batch size used when whole-list inference raises.

    Returns:
        Tuple ``(scores, graph_embeddings, coverage_matrix)``:
        ``scores`` is an ``(N, 2)`` numpy array. Column 0 is the value
        returned by ``prediction`` (the exp of the model's column-1 output,
        presumably the class-1 probability). Column 1 is all ones.
        ``graph_embeddings`` is a numpy array of the GNN graph embeddings.
        ``coverage_matrix`` is the CPU tensor returned by
        ``neurosed_threshold_coverage_estimation``.
    """
    gnn_model = wargs['gnn_model']
    gnn_device = wargs['gnn_device']
    try:
        preds, graph_embeddings = prediction(gnn_model, Batch.from_data_list(graphs).to(gnn_device))
        preds = preds.cpu().numpy()
        graph_embeddings = graph_embeddings.cpu().numpy()
    except RuntimeError:
        # Whole-list inference failed (presumably CUDA OOM). Release cached
        # blocks before retrying in fixed-size mini-batches.
        _release_cuda_cache(gnn_device)
        loader = DataLoader(graphs, batch_size=wargs.get('gnn_batch_size', 128))
        preds, graph_embeddings = [], []
        for batch in loader:
            pred, graph_embedding = prediction(gnn_model, batch.to(gnn_device))
            preds.append(pred)
            graph_embeddings.append(graph_embedding)
        preds = torch.cat(preds).cpu().numpy()
        graph_embeddings = torch.cat(graph_embeddings).cpu().numpy()
    _release_cuda_cache(gnn_device)
    coverage = np.ones(shape=preds.shape)
    coverage_matrix = neurosed_threshold_coverage_estimation(wargs['neurosed_model'], graphs, wargs['original_graphs_element_counts'], wargs['distance_threshold'])
    coverage_matrix = coverage_matrix.cpu()
    _release_cuda_cache(wargs['neurosed_model'].device)
    # Transpose (2, N) -> (N, 2): one row per graph.
    return np.stack([preds, coverage]).T, graph_embeddings, coverage_matrix
@torch.no_grad()
def prediction(model, graphs):
    """Run *model* on a batched graph input without tracking gradients.

    *model* must return a triple ``(node_embeddings, graph_embeddings,
    outputs)``. The outputs are exponentiated; they are presumably
    log-probabilities, so the exp gives probabilities.

    Returns:
        ``(scores, graph_embeddings)``, where ``scores`` is a 1-D tensor
        holding the exponentiated column-1 output for each graph.
    """
    _node_embeddings, graph_embeddings, raw_outputs = model(graphs)
    exponentiated = raw_outputs.exp()
    # Keep column 1 as an (N, 1) slice, then reduce it to a 1-D tensor.
    column_one = exponentiated[:, [1]]
    scores = column_one.sum(dim=1)
    return scores, graph_embeddings
@torch.no_grad()
def neurosed_threshold_coverage_estimation(neurosed_model, dataset, original_graphs_element_counts, threshold):
    """Mark which (dataset graph, original graph) pairs are within a normalized distance threshold.

    The distances from ``neurosed_model.predict_outer_with_queries`` are
    divided by the summed element counts of each pair of graphs, as
    computed by ``util.graph_element_counts``. Pairs whose normalized
    distance is ``<= threshold`` are marked 1.0; all others are 0.0.

    Args:
        neurosed_model: object exposing
            ``predict_outer_with_queries(dataset, batch_size=...)``.
            Presumably it returns a ``(len(dataset), n_original)`` distance
            tensor against the graphs it was loaded with -- confirm.
        dataset: sequence of graphs to evaluate.
        original_graphs_element_counts: 1-D tensor of per-original-graph
            element counts.
        threshold: maximum normalized distance for a pair to count as covered.

    Returns:
        Float tensor of shape
        ``(len(dataset), len(original_graphs_element_counts))`` with 0/1
        entries, on the CPU.

    Raises:
        RuntimeError: if prediction still fails once the batch size can no
        longer be halved.
    """
    dataset_element_counts = util.graph_element_counts(dataset)
    batch_size = max(len(dataset), 1)
    while True:
        try:
            d = neurosed_model.predict_outer_with_queries(dataset, batch_size=batch_size).cpu()
            break
        except RuntimeError:
            # Presumably CUDA OOM: halve the batch and retry. Give up once
            # the batch cannot shrink further, instead of looping forever
            # at batch_size == 0.
            if batch_size <= 1:
                raise
            batch_size //= 2
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    # s[i, j] = count(dataset[i]) + count(original[j]). This equals the
    # original cartesian_prod(...).sum(dim=1).view(...) but skips the
    # (N*M, 2) intermediate.
    s = dataset_element_counts.view(len(dataset), 1) + original_graphs_element_counts.view(1, len(original_graphs_element_counts))
    d = d / s
    selected = d <= threshold
    return selected.float()