@@ -32,6 +32,7 @@ def dp(g1, g2):
32
32
parser = argparse .ArgumentParser (description = 'Certification' )
33
33
parser .add_argument ('--evaluations' , type = str , help = 'name of evaluations directory' )
34
34
parser .add_argument ('--num_classes' , type = int , default = 10 , help = 'Number of classes' )
35
+ parser .add_argument ('--version' , required = True , type = int , help = 'version of base classifiers' )
35
36
36
37
args = parser .parse_args ()
37
38
if not os .path .exists ('./certs' ):
@@ -41,7 +42,7 @@ def dp(g1, g2):
41
42
# print(device)
42
43
43
44
44
- filein = torch .load ('evaluations/' + args .evaluations + '.pth' , map_location = torch .device (device ))
45
+ filein = torch .load ('evaluations/' + args .evaluations + '_v' + str ( args . version ) + ' .pth' , map_location = torch .device (device ))
45
46
46
47
labels = filein ['labels' ]
47
48
scores = filein ['scores' ]
@@ -158,7 +159,7 @@ def dp(g1, g2):
158
159
certs = cert_dpa
159
160
torchidx = idx_dpa
160
161
certs [torchidx != labels ] = - 1
161
- torch .save (certs ,'./certs/dpa_v2_ ' + args .evaluations + '.pth' )
162
+ torch .save (certs ,'./certs/v_dpa_ ' + args .evaluations + '_v' + str ( args . version ) + '.pth' )
162
163
a = certs .cpu ().sort ()[0 ].numpy ()
163
164
164
165
dpa_accs = np .array ([(i <= a ).sum () for i in np .arange (np .amax (a )+ 1 )])/ num_of_samples
@@ -172,7 +173,8 @@ def dp(g1, g2):
172
173
certs = cert_dpa_roe
173
174
torchidx = idx_dpa_roe
174
175
certs [torchidx != labels ] = - 1
175
- torch .save (certs ,'./certs/dpa_roe_' + args .evaluations + '.pth' )
176
+ torch .save (certs ,'./certs/v_dpa_roe_' + args .evaluations + '_v' + str (args .version ) + '.pth' )
177
+
176
178
a = certs .cpu ().sort ()[0 ].numpy ()
177
179
178
180
roe_dpa_accs = np .array ([(i <= a ).sum () for i in np .arange (np .amax (a )+ 1 )])/ num_of_samples
0 commit comments