Skip to content

Commit 4a36249

Browse files
author
Keivan Rezaei
committed
update certificates to get version as well
1 parent 73b98c9 commit 4a36249

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

dpa_roe_certify.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def dp(g1, g2):
3232
parser = argparse.ArgumentParser(description='Certification')
3333
parser.add_argument('--evaluations', type=str, help='name of evaluations directory')
3434
parser.add_argument('--num_classes', type=int, default=10, help='Number of classes')
35+
parser.add_argument('--version', required=True, type=int, help='version of base classifiers')
3536

3637
args = parser.parse_args()
3738
if not os.path.exists('./certs'):
@@ -41,7 +42,7 @@ def dp(g1, g2):
4142
# print(device)
4243

4344

44-
filein = torch.load('evaluations/'+args.evaluations + '.pth', map_location=torch.device(device))
45+
filein = torch.load('evaluations/'+args.evaluations + '_v' + str(args.version) + '.pth', map_location=torch.device(device))
4546

4647
labels = filein['labels']
4748
scores = filein['scores']
@@ -158,7 +159,7 @@ def dp(g1, g2):
158159
certs = cert_dpa
159160
torchidx = idx_dpa
160161
certs[torchidx != labels] = -1
161-
torch.save(certs,'./certs/dpa_v2_'+args.evaluations+'.pth')
162+
torch.save(certs,'./certs/v_dpa_'+args.evaluations+ '_v' + str(args.version) + '.pth')
162163
a = certs.cpu().sort()[0].numpy()
163164

164165
dpa_accs = np.array([(i <= a).sum() for i in np.arange(np.amax(a)+1)])/num_of_samples
@@ -172,7 +173,8 @@ def dp(g1, g2):
172173
certs = cert_dpa_roe
173174
torchidx = idx_dpa_roe
174175
certs[torchidx != labels] = -1
175-
torch.save(certs,'./certs/dpa_roe_'+args.evaluations+'.pth')
176+
torch.save(certs,'./certs/v_dpa_roe_'+args.evaluations+ '_v' + str(args.version) + '.pth')
177+
176178
a = certs.cpu().sort()[0].numpy()
177179

178180
roe_dpa_accs = np.array([(i <= a).sum() for i in np.arange(np.amax(a)+1)])/num_of_samples

fa_certify.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
2020
parser.add_argument('--k', default = 50, type=int, help='number of partitions')
2121
parser.add_argument('--d', default = 1, type=int, help='number of partitions that each model is trained on')
22-
22+
parser.add_argument('--version', required=True, type=int, help='version of base classifiers')
2323

2424
args = parser.parse_args()
2525

@@ -34,7 +34,7 @@
3434

3535
device = 'cpu'
3636

37-
filein = torch.load('evaluations/'+args.evaluations + '.pth', map_location=torch.device(device))
37+
filein = torch.load('evaluations/'+args.evaluations + '_v' + str(args.version) + '.pth', map_location=torch.device(device))
3838
labels = filein['labels']
3939
scores = filein['scores']
4040

@@ -97,7 +97,7 @@
9797

9898
base_acc = 100 * (max_classes == labels.unsqueeze(1)).sum().item() / (max_classes.shape[0] * max_classes.shape[1])
9999
print('Base classifier accuracy: ' + str(base_acc))
100-
torch.save(certs,'./certs/fa_'+args.evaluations+'.pth')
100+
torch.save(certs,'./certs/v_fa_'+args.evaluations+ '_v' + str(args.version) + '.pth')
101101
a = certs.cpu().sort()[0].numpy()
102102
accs = numpy.array([(i <= a).sum() for i in numpy.arange(numpy.amax(a)+1)])/predictions.shape[0]
103103
print('Smoothed classifier accuracy: ' + str(accs[0] * 100.) + '%')

fa_roe_certify.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def get_sample_cert(gap, gap_reducers):
3838
parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
3939
parser.add_argument('--k', default = 50, type=int, help='number of partitions')
4040
parser.add_argument('--d', default = 1, type=int, help='number of partitions that each model is trained on')
41+
parser.add_argument('--version', required=True, type=int, help='version of base classifiers')
4142

4243

4344
args = parser.parse_args()
@@ -53,7 +54,7 @@ def get_sample_cert(gap, gap_reducers):
5354

5455
device = 'cpu'
5556

56-
filein = torch.load('evaluations/'+args.evaluations + '.pth', map_location=torch.device(device))
57+
filein = torch.load('evaluations/'+args.evaluations + '_v' + str(args.version) + '.pth', map_location=torch.device(device))
5758
labels = filein['labels']
5859
scores = filein['scores']
5960

@@ -203,7 +204,7 @@ def get_sample_cert(gap, gap_reducers):
203204

204205
base_acc = 100 * (max_classes[:, :, 0] == labels.unsqueeze(1)).sum().item() / (num_of_samples * num_of_models)
205206
print('Base classifier accuracy: ' + str(base_acc))
206-
torch.save(certs,'./certs/fa_roe_'+args.evaluations+'.pth')
207+
torch.save(certs,'./certs/v_fa_roe_'+args.evaluations+ '_v' + str(args.version) + '.pth')
207208
a = certs.cpu().sort()[0].numpy()
208209
accs = np.array([(i <= a).sum() for i in np.arange(np.amax(a)+1)])/predictions.shape[0]
209210
print('Smoothed classifier accuracy: ' + str(accs[0] * 100.) + '%')

0 commit comments

Comments
 (0)