# utils.py
import os
import re

import numpy as np
import torch
from torchmetrics import AveragePrecision, AUROC, MeanMetric
from torch_geometric.utils import to_scipy_sparse_matrix, scatter
from torch_scatter import scatter_sum

# os.environ["TOKENIZERS_PARALLELISM"] = "false"


def sentence_accuracy(func, output, batch):
    # Flatten the logits to (num_tokens, vocab_size) and the answer ids to
    # (num_tokens,) so a token-level metric can be applied directly.
    pred = output.logits.reshape(-1, output.logits.size()[-1])
    target = output.answer_id.view(-1)
    return func(pred, target)


def extract_numbers(text):
    """Extract all numbers from a given text and return them as a list of floats."""
    numbers = re.findall(r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?", text)
    if len(numbers) == 0:
        # Default value when the text contains no parsable number.
        return [0.0]
    return [float(num) for num in numbers]
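
# For example, extract_numbers("loss fell to -3.2e-1 after 10 steps") returns [-0.32, 10.0].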


def sentence_base(func, output, batch):
    # Look up the reference answer strings selected by label_map and compare
    # them against the generated text.
    pred_text = output.pred_text
    answer = batch.label[batch.label_map.cpu().numpy()].tolist()
    return func(pred_text, answer)


def sentence_mae(func, output, batch):
    # Parse the numbers out of the generated and reference texts and compare their
    # means, so free-form generations can be scored with a regression metric.
    pred_text = output.pred_text
    answer = output.answer
    pred_values = [np.mean(extract_numbers(pred)) for pred in pred_text]
    true_values = [np.mean(extract_numbers(true)) for true in answer]
    return func(torch.tensor(pred_values), torch.tensor(true_values))


def sentence_perplexity(func, output, batch):
    # Add a leading batch dimension so a perplexity-style metric sees
    # (batch, seq, vocab) logits and (batch, seq) targets.
    pred = output.logits.unsqueeze(0)
    target = output.answer_id.unsqueeze(0)
    return func(pred, target)


def auc_word(func, output, batch):
    # Score binary yes/no answers: keep the first answer position of each pair and
    # softmax over the two answer-word logits. The hard-coded ids 1939 and 3869
    # appear to be the "no"/"yes" token ids of the underlying tokenizer.
    pred = output.logits.reshape(-1, 2, output.logits.size()[-1])[:, 0]
    pos = pred[:, [1939, 3869]]
    pos = torch.softmax(pos, dim=-1)[:, -1]
    target = output.answer_id.view(-1, 2)[:, 0] == 3869
    return func(pos, target.to(torch.long))


class NormalizedLossFactory:
    def __init__(self, batch_size, seq_len):
        self.denom = batch_size * seq_len

    def __call__(self, func, output, batch):
        sentence_size = len(batch.node_map)
        pred = output.logits.reshape(-1, output.logits.size()[-1])
        target = output.answer_id.view(-1)
        # Rescale the token-averaged loss by the number of supervised (non -100)
        # tokens, relative to the full batch_size * seq_len budget and the number
        # of sentences in the batch.
        numer = target.ne(-100).sum()
        original_loss = func(pred, target)
        norm_loss = original_loss * numer / (self.denom * sentence_size)
        return norm_loss


def normalized_loss_factory(batch_size, seq_len):
    # Functional counterpart of NormalizedLossFactory.
    denom = batch_size * seq_len

    def normalized_loss(func, output, batch):
        sentence_size = len(batch.node_map)
        pred = output.logits.reshape(-1, output.logits.size()[-1])
        target = output.answer_id.view(-1)
        numer = target.ne(-100).sum()
        original_loss = func(pred, target)
        norm_loss = original_loss * numer / (denom * sentence_size)
        return norm_loss

    return normalized_loss
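

# A minimal usage sketch (hypothetical shapes, assuming an `output` exposing
# `.logits` and `.answer_id` and a `batch` with `.node_map`, as the wrappers
# above expect):
#     loss_fn = normalized_loss_factory(batch_size=32, seq_len=512)
#     loss = loss_fn(torch.nn.functional.cross_entropy, output, batch)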


def auc_func(func, output, batch):
    # Cast labels to the prediction's dtype/device before flattening both for an
    # AUROC-style metric.
    pred = output
    label = batch.y.to(pred)
    return func(pred.view(-1), label.view(-1))


def scatter_reg_func(func, output, batch):
    # Compare representations only against the targets of nodes flagged positive
    # in bin_labels.
    repr_rep = output.repr
    target = output.target[batch.bin_labels[batch.true_nodes_mask].to(torch.bool)]
    return func(repr_rep, target)


class MatCLFunc:
    def __init__(self, temp=0.1, hard_mine_ratio=0.8, sim="mae"):
        self.temp = temp
        if sim == "mae":
            self.loss = torch.nn.L1Loss()
        elif sim == "mse":
            self.loss = torch.nn.MSELoss()
        else:
            raise NotImplementedError("unknown fidelity loss")
        # Note: hard_mine_ratio is stored but currently unused.
        self.hard_mine_ratio = hard_mine_ratio

    def __call__(self, output, batch):
        n_classes = batch.num_classes
        repr_rep = output.repr.repeat_interleave(n_classes, dim=0)
        fidelity_loss = self.loss(output.repr,
                                  output.target[batch.bin_labels[batch.true_nodes_mask].to(torch.bool)].to(torch.float))
        # Contrastive term over per-class similarities (temperature-scaled MAE
        # distance); note that sim_loss is computed but not added to the returned loss.
        sim = torch.abs(repr_rep - output.target).mean(dim=(-1, -2)) / self.temp
        sim = torch.exp(sim)
        class_ind = torch.arange(len(n_classes), device=n_classes.device).repeat_interleave(n_classes, dim=0)
        sim_loss = -torch.log(
            sim[batch.bin_labels[batch.true_nodes_mask].to(torch.bool)] / scatter_sum(sim, class_ind, dim=0)).mean()
        return fidelity_loss


def cl_wrap_func(proc_func, temp=0.1):
    # Wrap proc_func so it receives temperature-scaled cosine similarities between
    # each representation and its candidate targets instead of raw model outputs.
    def wrap_func(func, output, batch):
        n_classes = batch.num_classes[0]
        repr_rep = output.repr.repeat_interleave(n_classes, dim=0)
        sim = torch.nn.functional.cosine_similarity(repr_rep.view(repr_rep.size()[0], -1),
                                                    output.target.view(output.target.size()[0], -1), dim=-1) / temp
        sim = sim.view(-1, 1)
        return proc_func(func, sim, batch)

    return wrap_func


class SimAnyAuc(torch.nn.Module):
    def __init__(self, sim_metric="mse"):
        super().__init__()
        if sim_metric == "mse":
            self.loss = torch.nn.MSELoss()
        elif sim_metric == "mae":
            self.loss = torch.nn.L1Loss()
        else:
            raise NotImplementedError("unknown fidelity measure")
        self.metric = MeanMetric()

    def __call__(self, *args, **kwargs):
        self.update(*args, **kwargs)

    def update(self, output, batch):
        # Track the running mean reconstruction error between representations and
        # the targets of positive nodes.
        loss = self.loss(output.repr, output.target[batch.bin_labels[batch.true_nodes_mask].to(torch.bool)])
        return self.metric(loss)

    def compute(self):
        return self.metric.compute()

    def reset(self):
        self.metric.reset()


def mean_func(func, output, batch):
    return func(output, batch)


class MultiApr(torch.nn.Module):
    def __init__(self, num_labels=1):
        super().__init__()
        self.metrics = torch.nn.ModuleList([AveragePrecision(task="binary") for _ in range(num_labels)])

    def update(self, preds, targets):
        for i, met in enumerate(self.metrics):
            pred = preds[:, i]
            target = targets[:, i]
            # NaN targets mark missing labels; `target == target` is False for NaN.
            valid_idx = target == target
            met.update(pred[valid_idx], target[valid_idx].to(torch.long))

    def compute(self):
        full_val = []
        for met in self.metrics:
            try:
                res = met.compute()
                # Skip labels whose metric is NaN (e.g. no valid samples seen).
                if res == res:
                    full_val.append(res)
            except Exception:
                pass
        return torch.tensor(full_val).mean()

    def reset(self):
        for met in self.metrics:
            met.reset()


class MultiAuc(torch.nn.Module):
    def __init__(self, num_labels=1):
        super().__init__()
        self.metrics = torch.nn.ModuleList([AUROC(task="binary") for _ in range(num_labels)])

    def update(self, preds, targets):
        for i, met in enumerate(self.metrics):
            pred = preds[:, i]
            target = targets[:, i]
            # NaN targets mark missing labels; `target == target` is False for NaN.
            valid_idx = target == target
            met.update(pred[valid_idx], target[valid_idx].to(torch.long))

    def compute(self):
        full_val = []
        for met in self.metrics:
            try:
                res = met.compute()
                # Skip labels whose metric is NaN (e.g. no valid samples seen).
                if res == res:
                    full_val.append(res)
            except Exception:
                pass
        return torch.tensor(full_val).mean()

    def reset(self):
        for met in self.metrics:
            met.reset()
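

# Usage sketch (hypothetical shapes): `preds` and `targets` are
# (num_samples, num_labels) float tensors, with NaNs in `targets` marking
# entries to skip:
#     metric = MultiAuc(num_labels=12)
#     metric.update(preds, targets)
#     mean_auroc = metric.compute()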


def scipy_rwpe(data, walk_length):
    # Random-walk positional encoding: row-normalize the adjacency into a random-walk
    # transition matrix and record its diagonal (return probabilities) at each step.
    row, col = data.edge_index
    N = data.num_nodes
    value = data.edge_weight
    if value is None:
        value = torch.ones(data.num_edges, device=row.device)
    value = scatter(value, row, dim_size=N, reduce="sum").clamp(min=1)[row]
    value = 1.0 / value
    adj = to_scipy_sparse_matrix(data.edge_index, edge_attr=value, num_nodes=data.num_nodes)
    out = adj
    pe_list = [out.diagonal()]
    for _ in range(walk_length - 1):
        out = out @ adj
        pe_list.append(out.diagonal())
    pe = torch.tensor(np.stack(pe_list, axis=-1))
    return pe
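

# Usage sketch for a PyG `Data` object (hypothetical variable names):
#     pe = scipy_rwpe(data, walk_length=16)   # (num_nodes, walk_length) return probabilities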


def get_available_devices():
    r"""Get IDs of all available GPUs.

    Returns:
        device (torch.device): Main device (GPU 0 or CPU).
        gpu_ids (list): IDs of all available GPUs.
    """
    gpu_ids = []
    if torch.cuda.is_available():
        gpu_ids += [gpu_id for gpu_id in range(torch.cuda.device_count())]
        device = torch.device(f'cuda:{gpu_ids[0]}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
    return device, gpu_ids
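

# Typical usage:
#     device, gpu_ids = get_available_devices()
#     model = model.to(device)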


def get_label_texts(labels):
    # Build one positive and one negative prompt per assay; the `[:-41]` slice
    # presumably trims a fixed-length suffix from the stored assay description.
    label_texts = [None] * (len(labels) * 2)
    for entry in labels:
        label_texts[labels[entry][0]] = (
            "The molecule is effective to the following assay. " + labels[entry][1][0][:-41])
        label_texts[labels[entry][0] + len(labels)] = (
            "The molecule is not effective to the following assay. " + labels[entry][1][0][:-41])
    return label_texts


def set_mask(data, name, index, dtype=torch.bool):
    mask = torch.zeros(data.num_nodes, dtype=dtype)
    mask[index] = True
    setattr(data, name, mask)
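

# Usage sketch (hypothetical index tensor):
#     set_mask(data, "train_mask", torch.tensor([0, 2, 5]))   # data.train_mask is a bool mask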