-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_ood_vit.py
540 lines (425 loc) · 19.9 KB
/
test_ood_vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
import sys
from utils import log
import resnetv2
import torch
import torchvision as tv
import time
import torchvision.models as models
import numpy as np
import torch.nn as nn
from utils.test_utils import arg_parser, get_measures
import os
import math
import timm
from sklearn.linear_model import LogisticRegressionCV
from torch.autograd import Variable
from utils.mahalanobis_lib import get_Mahalanobis_score
def vit_intermediate_forward(model, x, layer_index):
    """
    Run a torchvision-style Vision Transformer and collect intermediate features.

    The output of every third encoder block is recorded (blocks 3, 6, 9, 12 for
    a 12-block ViT), which is what the Mahalanobis score consumes.

    NOTE(review): the positional embedding and encoder dropout are now applied
    before the blocks, as the model's own encoder does. Mahalanobis statistics
    that were tuned with a forward lacking them should be regenerated.

    Args:
        model (nn.Module): The ViT, either bare or wrapped in a container that
            exposes it as ``model.module`` (e.g. ``nn.DataParallel``).
        x (torch.Tensor): The input image batch.
        layer_index (int or str or None):
            'all' -> return ``(logits, out_list)``.
            None  -> return the final logits only.
            int   -> return ``out_list[layer_index]`` (0..3 for a 12-block ViT;
                     Python-style negative indices are accepted).

    Returns:
        torch.Tensor or tuple: The requested output(s).

    Raises:
        ValueError: If ``layer_index`` is not one of the values above.
    """
    # Accept both a bare model and a DataParallel-style wrapper.
    net = getattr(model, 'module', model)
    out = net._process_input(x)
    n = out.shape[0]
    # Expand the class token to the full batch and prepend it to the patch tokens.
    batch_class_token = net.class_token.expand(n, -1, -1)
    out = torch.cat([batch_class_token, out], dim=1)
    # Positional embedding + dropout precede the transformer blocks.
    out = net.encoder.dropout(out + net.encoder.pos_embedding)
    out_list = []
    for i, blk in enumerate(net.encoder.layers):
        out = blk(out)
        if (i + 1) % 3 == 0:
            out_list.append(out)
    out = net.encoder.ln(out)
    # Classify from the class-token position.
    out = net.heads(out[:, 0])
    if layer_index == 'all':
        return out, out_list
    if layer_index is None:
        return out
    try:
        return out_list[layer_index]
    except (IndexError, TypeError):
        raise ValueError(
            "layer_index must be 'all', None or an int index into {} intermediate "
            "outputs, got {!r}".format(len(out_list), layer_index)) from None
def _l2normalize(v, eps=1e-10):
return v / (torch.norm(v,dim=2,keepdim=True) + eps)
# Power Iteration for acceleration
def power_iteration_plus(A, iter=50):
    """
    Estimate the dominant rank-1 term sigma_1 * u_1 v_1^T of each matrix in a batch.

    Args:
        A (torch.Tensor): Batch of matrices; ``A[b]`` has shape
            ``(A.size(1), A.size(2))``.
        iter (int): Number of alternating power-iteration steps.

    Returns:
        torch.Tensor: Same shape as ``A``; the rank-1 approximation of each
        matrix, suitable for subtracting from ``A``.
    """
    batch, rows, cols = A.size(0), A.size(1), A.size(2)
    # One random start vector per side, shared by every matrix in the batch;
    # drawn as float32 on the CPU, then cast to A's dtype/device.
    left = torch.FloatTensor(1, rows).normal_(0, 1).view(1, 1, rows).repeat(batch, 1, 1).to(A)
    right = torch.FloatTensor(cols, 1).normal_(0, 1).view(1, cols, 1).repeat(batch, 1, 1).to(A)
    for _ in range(iter):
        right = _l2normalize(left.bmm(A)).transpose(1, 2)   # (batch, cols, 1)
        left = _l2normalize(A.bmm(right).transpose(1, 2))   # (batch, 1, rows)
    sigma = left.bmm(A).bmm(right)                          # (batch, 1, 1)
    outer = left.transpose(1, 2).bmm(right.transpose(1, 2))  # (batch, rows, cols)
    return sigma * outer
def text_save(filename, data):
    """
    Append each item of *data* to *filename*, one line per item.

    Each item is converted with ``str()`` and the characters ``[``, ``]``,
    ``'`` and ``,`` are stripped, so e.g. ``[1, 2]`` is written as ``1 2``.

    Args:
        filename (str): Path of the text file; opened in append mode.
        data: Sequence of items to write.
    """
    # Delete-only translation table: equivalent to the chained str.replace calls.
    strip_table = str.maketrans('', '', "[]',")
    # ``with`` guarantees the file is closed even if a write fails.
    with open(filename, 'a') as file:
        for item in data:
            file.write(str(item).translate(strip_table) + '\n')
def make_id_ood(args, logger):
    """
    Build the in-distribution and out-of-distribution evaluation datasets and loaders.

    Both folders are read with ``torchvision.datasets.ImageFolder`` using the same
    transform: resize to 224x224, convert to tensor, normalize each channel with
    mean 0.5 / std 0.5.

    Args:
        args: Parsed CLI namespace; reads ``in_datadir``, ``out_datadir``,
            ``batch`` and ``workers``.
        logger: Logger exposing ``info``.

    Returns:
        tuple: ``(in_set, out_set, in_loader, out_loader)``. The loaders do not
        shuffle and keep the final partial batch.
    """
    crop = 224  # The resolution is fixed for transformers
    val_tx = tv.transforms.Compose([
        tv.transforms.Resize((crop, crop)),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    in_set = tv.datasets.ImageFolder(args.in_datadir, val_tx)
    out_set = tv.datasets.ImageFolder(args.out_datadir, val_tx)
    logger.info(f"Using an in-distribution set with {len(in_set)} images.")
    logger.info(f"Using an out-of-distribution set with {len(out_set)} images.")

    def _eval_loader(dataset):
        # Evaluation loader: fixed order, every sample kept.
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batch, shuffle=False,
            num_workers=args.workers, pin_memory=True, drop_last=False)

    return in_set, out_set, _eval_loader(in_set), _eval_loader(out_set)
# MSP Score
def iterate_data_msp(data_loader, model):
    """
    Maximum-softmax-probability score for every sample in *data_loader*.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: Callable mapping the CUDA input batch to logits.

    Returns:
        np.ndarray: One confidence per sample, in loader order.
    """
    scores = []
    for inputs, _labels in data_loader:
        with torch.no_grad():
            probs = torch.softmax(model(inputs.cuda()), dim=-1)
            top_prob = probs.max(dim=-1).values
        scores.extend(top_prob.data.cpu().numpy())
    return np.array(scores)
# ODIN Score
def iterate_data_odin(data_loader, model, epsilon, temper, logger):
    """
    ODIN score: temperature-scaled max softmax of a slightly perturbed input.

    For each batch, the model's own argmax prediction is used as the label; the
    cross-entropy of the temperature-scaled logits is differentiated w.r.t. the
    input, and the input is moved by ``-epsilon * sign(grad)`` (zeros treated as
    +1). The score is the maximum softmax probability of the perturbed input at
    temperature ``temper``.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: Differentiable callable mapping the CUDA input batch to logits.
        epsilon (float): Perturbation magnitude.
        temper (float): Softmax temperature.
        logger: Logger exposing ``info``; progress is logged every 100 batches.

    Returns:
        np.ndarray: One confidence per sample, in loader order.
    """
    confs = []
    for b, (x, _) in enumerate(data_loader):
        # Fresh leaf tensor so the gradient w.r.t. the input can be taken.
        x = x.cuda().detach().requires_grad_(True)
        outputs = model(x)
        # Pseudo-labels from the unscaled prediction, kept on the logits' device.
        labels = outputs.detach().argmax(dim=1)
        loss = torch.nn.functional.cross_entropy(outputs / temper, labels)
        # Only the input gradient is needed; autograd.grad avoids accumulating
        # .grad on every model parameter batch after batch.
        (x_grad,) = torch.autograd.grad(loss, x)
        # Normalizing the gradient to binary in {-1, +1}
        gradient = (torch.ge(x_grad, 0).float() - 0.5) * 2
        with torch.no_grad():
            # Adding small perturbations to images
            temp_inputs = torch.add(x, gradient, alpha=-epsilon)
            outputs = model(temp_inputs) / temper
        # Calculating the confidence after adding perturbations (stable softmax)
        nn_outputs = outputs.cpu().numpy()
        nn_outputs = nn_outputs - np.max(nn_outputs, axis=1, keepdims=True)
        nn_outputs = np.exp(nn_outputs) / np.sum(np.exp(nn_outputs), axis=1, keepdims=True)
        confs.extend(np.max(nn_outputs, axis=1))
        if b % 100 == 0:
            logger.info('{} batches processed'.format(b))
    return np.array(confs)
# Energy Score
def iterate_data_energy(data_loader, model, temper):
    """
    Energy score ``temper * logsumexp(logits / temper)`` for every sample.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: Callable mapping the CUDA input batch to logits.
        temper: Energy temperature.

    Returns:
        np.ndarray: One score per sample, in loader order.
    """
    energies = []
    with torch.no_grad():
        for inputs, _labels in data_loader:
            logits = model(inputs.cuda())
            batch_energy = temper * torch.logsumexp(logits / temper, dim=1)
            energies.extend(batch_energy.data.cpu().numpy())
    return np.array(energies)
# Mahalanobis Score
def iterate_data_mahalanobis(data_loader, model, num_classes, sample_mean, precision,
                             num_output, magnitude, regressor, logger):
    """
    Mahalanobis score for every sample in *data_loader*.

    Per-layer Mahalanobis features come from ``get_Mahalanobis_score``; the fitted
    ``regressor`` turns them into a probability and the score is the NEGATED
    probability of class 1 (presumably the OOD class of the tuning fit — confirm).

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model, num_classes, sample_mean, precision, num_output, magnitude:
            Passed straight through to ``get_Mahalanobis_score``.
        regressor: Fitted estimator exposing ``predict_proba``.
        logger: Logger exposing ``info``; progress is logged every 10 batches.

    Returns:
        np.ndarray: One score per sample, in loader order.
    """
    scores = []
    for batch_idx, (inputs, _labels) in enumerate(data_loader):
        if not batch_idx % 10:
            logger.info('{} batches processed'.format(batch_idx))
        features = get_Mahalanobis_score(inputs.cuda(), model, num_classes, sample_mean,
                                         precision, num_output, magnitude)
        class1_prob = regressor.predict_proba(features)[:, 1]
        scores.extend(-class1_prob)
    return np.array(scores)
# GradNorm Score
def iterate_data_gradnorm(data_loader, model, temperature, num_classes):
    """
    GradNorm score: L1 norm of the classification-head weight gradient.

    The loss is the batch mean of ``-sum_c log_softmax(logits / temperature)_c``
    (an all-ones target over ``num_classes``). One score is produced per batch,
    which is why ``main`` forces a batch size of 1 for this method.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: Network exposing ``heads.head.weight`` (torchvision ViT layout).
        temperature (float): Logit temperature.
        num_classes (int): Width of the all-ones target; must match the logits.

    Returns:
        np.ndarray: One score per batch, in loader order.
    """
    log_softmax = torch.nn.LogSoftmax(dim=-1).cuda()
    scores = []
    for step, (images, _labels) in enumerate(data_loader):
        if step % 10000 == 0:
            print('{} batches processed'.format(step))
        inputs = Variable(images.cuda(), requires_grad=True)
        # Clear gradients from the previous batch before this backward pass.
        model.zero_grad()
        scaled_logits = model(inputs) / temperature
        targets = torch.ones((inputs.shape[0], num_classes)).cuda()
        per_sample = torch.sum(-targets * log_softmax(scaled_logits), dim=-1)
        torch.mean(per_sample).backward()
        head_grad = model.heads.head.weight.grad.data
        scores.append(torch.sum(torch.abs(head_grad)).cpu().numpy())
    return np.array(scores)
# Our proposed RankFeat Score
def iterate_data_rankfeat(data_loader, model, temperature, svd_block_index=11):
    """
    RankFeat score: energy of the logits after removing the rank-1 component of
    the intermediate token features.

    The first ``svd_block_index`` encoder blocks are run; the dominant singular
    component of each sample's token-feature matrix is subtracted; the remaining
    blocks, the final LayerNorm and the head (on the class token) follow. The
    score is ``temperature * logsumexp(logits / temperature)``.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: torchvision-style ViT (``_process_input``, ``class_token``,
            ``encoder.{pos_embedding, dropout, layers, ln}``, ``heads``).
        temperature (float): Energy temperature.
        svd_block_index (int): Number of blocks run before the rank-1 removal.
            The default 11 matches the previously hard-coded value.

    Returns:
        np.ndarray: One score per sample, in loader order.
    """
    confs = []
    num_blocks = len(model.encoder.layers)
    for b, (x, _) in enumerate(data_loader):
        if b % 100 == 0:
            print('{} batches processed'.format(b))
        x = x.cuda()
        with torch.no_grad():
            # Patch embedding, class token, positional embedding, dropout.
            x = model._process_input(x)
            n = x.shape[0]
            batch_class_token = model.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            x = x + model.encoder.pos_embedding
            x = model.encoder.dropout(x)
            for i in range(svd_block_index):
                x = model.encoder.layers[i](x)
            feat = x
            # Only the leading singular triplet is used, so the reduced SVD is
            # enough and avoids materialising the full square U / V^H.
            u, s, vh = torch.linalg.svd(feat, full_matrices=False)
            feat = feat - s[:, 0:1].unsqueeze(2) * u[:, :, 0:1].bmm(vh[:, 0:1, :])
            # To use power iteration instead, replace the two lines above with:
            # feat = feat - power_iteration_plus(feat, iter=20)
            for i in range(svd_block_index, num_blocks):
                feat = model.encoder.layers[i](feat)
            feat = model.encoder.ln(feat)
            logits = model.heads(feat[:, 0])
            conf = temperature * torch.logsumexp(logits / temperature, dim=1)
        confs.extend(conf.cpu().numpy())
    return np.array(confs)
# Our proposed RankFeat+RankWeight Score
def iterate_data_rankfeatweight(data_loader, model, temperature, svd_block_index=11):
    """
    RankFeat score computed with RankWeight applied to encoder block 10.

    The rank-1 component of ``model.encoder.layers[10].mlp[3].weight`` (estimated
    with ``power_iteration_plus``) is removed only for the duration of this call;
    the original weight is restored on exit, even on error. The ID and OOD
    loaders therefore each see exactly one pruned component.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: torchvision-style ViT (see ``iterate_data_rankfeat``).
        temperature (float): Energy temperature.
        svd_block_index (int): Number of blocks run before the rank-1 feature
            removal (default 11, the previously hard-coded value).

    Returns:
        np.ndarray: One score per sample, in loader order.
    """
    proj = model.encoder.layers[10].mlp[3]
    # A new tensor is assigned below, so this reference keeps the original values.
    original_weight = proj.weight.data
    weight_sub = power_iteration_plus(original_weight.unsqueeze(0), iter=100)
    proj.weight.data = original_weight - weight_sub[0]
    confs = []
    num_blocks = len(model.encoder.layers)
    try:
        for b, (x, _) in enumerate(data_loader):
            if b % 100 == 0:
                print('{} batches processed'.format(b))
            x = x.cuda()
            with torch.no_grad():
                # Patch embedding, class token, positional embedding, dropout.
                x = model._process_input(x)
                n = x.shape[0]
                batch_class_token = model.class_token.expand(n, -1, -1)
                x = torch.cat([batch_class_token, x], dim=1)
                x = x + model.encoder.pos_embedding
                x = model.encoder.dropout(x)
                for i in range(svd_block_index):
                    x = model.encoder.layers[i](x)
                feat = x
                # Reduced SVD: only the leading singular triplet is used.
                u, s, vh = torch.linalg.svd(feat, full_matrices=False)
                feat = feat - s[:, 0:1].unsqueeze(2) * u[:, :, 0:1].bmm(vh[:, 0:1, :])
                # To use power iteration instead, replace the two lines above with:
                # feat = feat - power_iteration_plus(feat, iter=20)
                for i in range(svd_block_index, num_blocks):
                    feat = model.encoder.layers[i](feat)
                feat = model.encoder.ln(feat)
                logits = model.heads(feat[:, 0])
                conf = temperature * torch.logsumexp(logits / temperature, dim=1)
            confs.extend(conf.cpu().numpy())
    finally:
        # Undo RankWeight so a second call does not prune a second component.
        proj.weight.data = original_weight
    return np.array(confs)
def iterate_data_rankweight(data_loader, model, temperature):
    """
    RankWeight score: energy of the logits with the rank-1 component of
    ``model.encoder.layers[10].mlp[3].weight`` removed.

    The weight is pruned (via ``power_iteration_plus``) only for the duration of
    this call and restored on exit, even on error, so repeated calls each prune
    exactly one component.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: torchvision-style ViT whose forward maps the CUDA batch to logits.
        temperature (float): Energy temperature.

    Returns:
        np.ndarray: One score per sample, in loader order.
    """
    proj = model.encoder.layers[10].mlp[3]
    # A new tensor is assigned below, so this reference keeps the original values.
    original_weight = proj.weight.data
    weight_sub = power_iteration_plus(original_weight.unsqueeze(0), iter=100)
    proj.weight.data = original_weight - weight_sub[0]
    confs = []
    try:
        for b, (x, _) in enumerate(data_loader):
            if b % 100 == 0:
                print('{} batches processed'.format(b))
            with torch.no_grad():
                logits = model(x.cuda())
                conf = temperature * torch.logsumexp(logits / temperature, dim=1)
            confs.extend(conf.cpu().numpy())
    finally:
        proj.weight.data = original_weight
    return np.array(confs)
def iterate_data_react(data_loader, model, temperature, threshold=0.483):
    """
    ReAct-style energy score for a torchvision-style ViT.

    The encoder output (after its final LayerNorm) is clipped from above at
    ``threshold`` before the classification head, and the score is
    ``temperature * logsumexp(logits / temperature)``.

    Args:
        data_loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        model: ViT exposing ``_process_input``, ``class_token``, ``encoder``, ``heads``.
        temperature (float): Energy temperature.
        threshold (float): Upper clipping value. The default 0.483 is described
            as the 90th percentile of activations (see ``compute_threshold``).

    Returns:
        np.ndarray: One score per sample, in loader order.
    """
    confs = []
    for b, (x, _) in enumerate(data_loader):
        if b % 100 == 0:
            print('{} batches processed'.format(b))
        # Scoring only: no autograd graph is needed.
        with torch.no_grad():
            inputs = model._process_input(x.cuda())
            n = inputs.shape[0]
            # Expand the class token to the full batch
            batch_class_token = model.class_token.expand(n, -1, -1)
            inputs = torch.cat([batch_class_token, inputs], dim=1)
            feat = model.encoder(inputs)
            # Clipping is element-wise, so clipping only the class token (the
            # head's input) yields the same logits as clipping every token.
            cls_feat = torch.clip(feat[:, 0], max=threshold)
            logits = model.heads(cls_feat)
            conf = temperature * torch.logsumexp(logits / temperature, dim=1)
        confs.extend(conf.cpu().numpy())
    return np.array(confs)
def compute_threshold(data_loader, model, percentile=90):
    """
    Compute the threshold of activation values for React.

    Each sample's encoder output is averaged over the token axis; all resulting
    values are pooled, their 60th and 90th percentiles are printed, and the
    ``percentile``-th percentile is returned.

    NOTE(review): the statistic uses token-averaged features, whereas
    ``iterate_data_react`` clips the class-token feature — confirm this is intended.

    Args:
        data_loader (torch.utils.data.DataLoader): Yields ``(inputs, labels)``
            batches; labels are ignored.
        model: ViT exposing ``_process_input``, ``class_token`` and ``encoder``;
            it is switched to eval mode.
        percentile (float): Percentile to return (90 matches the React default).

    Returns:
        float: The computed threshold value.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    activation_list = []
    with torch.no_grad():
        for x, _ in data_loader:
            x = model._process_input(x.cuda())
            n = x.shape[0]
            # Expand the class token to the full batch
            batch_class_token = model.class_token.expand(n, -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)
            feat = model.encoder(x)
            # Mean over axis 1 (tokens) per sample.
            activation_list.append(
                feat.cpu().numpy().reshape(feat.shape[0], feat.shape[1], -1).mean(1))
    if not activation_list:
        raise ValueError('compute_threshold: data_loader yielded no batches')
    activations = np.concatenate(activation_list, axis=0).flatten()
    print(np.percentile(activations, 60))
    print(np.percentile(activations, 90))
    return float(np.percentile(activations, percentile))
def _mahalanobis_scorer(model, logger, args, num_classes):
    """Load the tuned Mahalanobis parameters and return a callable that scores one loader."""
    # NOTE(review): allow_pickle=True unpickles the file, which can execute code;
    # only point --mahalanobis_param_path at trusted files.
    sample_mean, precision, lr_weights, lr_bias, magnitude = np.load(
        os.path.join(args.mahalanobis_param_path, 'results.npy'), allow_pickle=True)
    sample_mean = [s.cuda() for s in sample_mean]
    precision = [p.cuda() for p in precision]
    # Fit on a toy problem only to obtain a fitted estimator, then overwrite its
    # coefficients with the tuned ones.
    regressor = LogisticRegressionCV(cv=2).fit(
        [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]], [0, 0, 1, 1])
    regressor.coef_ = lr_weights
    regressor.intercept_ = lr_bias
    # Probe with a dummy batch just to count the intermediate feature maps;
    # no gradients are needed for that.
    temp_x = torch.rand(2, 3, 224, 224).cuda()
    with torch.no_grad():
        temp_list = vit_intermediate_forward(model=model, x=temp_x, layer_index='all')[1]
    num_output = len(temp_list)

    def score(loader):
        return iterate_data_mahalanobis(loader, model, num_classes, sample_mean, precision,
                                        num_output, magnitude, regressor, logger)

    return score


def run_eval(model, in_loader, out_loader, logger, args, num_classes):
    """
    Score the ID and OOD loaders with ``args.score`` and log AUROC/AUPR/FPR95.

    Args:
        model: Network to evaluate; switched to eval mode here.
        in_loader: In-distribution data loader.
        out_loader: Out-of-distribution data loader.
        logger: Project logger exposing ``info`` and ``flush``.
        args: Parsed CLI namespace; ``args.score`` selects the method and the
            matching ``temperature_*`` / ``epsilon_odin`` /
            ``mahalanobis_param_path`` options are read.
        num_classes (int): Number of in-distribution classes.

    Raises:
        ValueError: If ``args.score`` is not a known method.
    """
    # switch to evaluate mode
    model.eval()
    logger.info("Running test...")
    logger.flush()
    # Score name -> callable scoring one loader. The lambdas read ``args`` at
    # call time, i.e. when each loader is scored.
    scorers = {
        'MSP': lambda loader: iterate_data_msp(loader, model),
        'ODIN': lambda loader: iterate_data_odin(
            loader, model, args.epsilon_odin, args.temperature_odin, logger),
        'Energy': lambda loader: iterate_data_energy(loader, model, args.temperature_energy),
        'GradNorm': lambda loader: iterate_data_gradnorm(
            loader, model, args.temperature_gradnorm, num_classes),
        'RankFeat': lambda loader: iterate_data_rankfeat(loader, model, args.temperature_rankfeat),
        'RankFeatWeight': lambda loader: iterate_data_rankfeatweight(
            loader, model, args.temperature_rankfeat),
        'React': lambda loader: iterate_data_react(loader, model, args.temperature_react),
    }
    if args.score == 'Mahalanobis':
        # Parameters are loaded and the model probed before any data is scored.
        scorer = _mahalanobis_scorer(model, logger, args, num_classes)
    elif args.score in scorers:
        scorer = scorers[args.score]
    else:
        raise ValueError("Unknown score type {}".format(args.score))
    logger.info("Processing in-distribution data...")
    in_scores = scorer(in_loader)
    logger.info("Processing out-of-distribution data...")
    out_scores = scorer(out_loader)
    in_examples = in_scores.reshape((-1, 1))
    out_examples = out_scores.reshape((-1, 1))
    auroc, aupr_in, aupr_out, fpr95 = get_measures(in_examples, out_examples)
    logger.info('============Results for {}============'.format(args.score))
    logger.info('AUROC: {}'.format(auroc))
    logger.info('AUPR (In): {}'.format(aupr_in))
    logger.info('AUPR (Out): {}'.format(aupr_out))
    logger.info('FPR95: {}'.format(fpr95))
    logger.flush()
def main(args):
    """
    Entry point: set up logging and data, load torchvision's ViT-B/16
    (IMAGENET1K_V1 weights) and run the selected OOD evaluation.

    Args:
        args: Parsed CLI namespace (see the ``__main__`` block and ``arg_parser``).
    """
    # Has to be set before CUDA is first initialised to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_id)
    logger = log.setup_logger(args)
    torch.backends.cudnn.benchmark = True
    if args.score == 'GradNorm':
        # iterate_data_gradnorm yields one score per batch, so batches must hold
        # a single sample.
        args.batch = 1
    in_set, out_set, in_loader, out_loader = make_id_ood(args, logger)
    logger.info(f"Loading model from {args.model_path}")
    # NOTE(review): args.model_path is only logged; the weights always come from
    # torchvision's IMAGENET1K_V1 checkpoint.
    model = tv.models.vit_b_16(weights='IMAGENET1K_V1')
    if args.score == 'Mahalanobis':
        # The Mahalanobis scoring path expects the DataParallel wrapper
        # (it accesses the network through ``model.module``).
        model = torch.nn.DataParallel(model)
    model = model.cuda()
    start_time = time.time()
    run_eval(model, in_loader, out_loader, logger, args, num_classes=len(in_set.classes))
    end_time = time.time()
    logger.info("Total running time: {}".format(end_time - start_time))
if __name__ == "__main__":
    parser = arg_parser()
    parser.add_argument("--in_datadir", help="Path to the in-distribution data folder.")
    parser.add_argument("--out_datadir", help="Path to the out-of-distribution data folder.")
    parser.add_argument('--score', choices=['MSP', 'ODIN', 'Energy', 'Mahalanobis', 'GradNorm', 'RankFeat', 'RankFeatWeight', 'React'], default='RankFeatWeight')
    # arguments for ODIN
    # Temperatures are parsed as float (like the other temperature options) so
    # fractional values are accepted; integer inputs give the same value.
    parser.add_argument('--temperature_odin', default=1000, type=float,
                        help='temperature scaling for odin')
    parser.add_argument('--epsilon_odin', default=0.0, type=float,
                        help='perturbation magnitude for odin')
    # arguments for Energy
    parser.add_argument('--temperature_energy', default=1, type=float,
                        help='temperature scaling for energy')
    # arguments for Mahalanobis
    parser.add_argument('--mahalanobis_param_path', default='checkpoints/finetune/tune_mahalanobis',
                        help='path to tuned mahalanobis parameters')
    # arguments for GradNorm
    parser.add_argument('--temperature_gradnorm', default=1, type=float,
                        help='temperature scaling for GradNorm')
    # arguments for RankFeat
    parser.add_argument('--temperature_rankfeat', default=1, type=float,
                        help='temperature scaling for RankFeat')
    # arguments for ReAct
    parser.add_argument('--temperature_react', default=1, type=float,
                        help='temperature scaling for React')
    # arguments for CUDA device index
    parser.add_argument('--cuda_id', default=0, type=int,
                        help='cuda index for the test')
    main(parser.parse_args())