# main.py (forked from hazimehh/robust-models-transfer)
import argparse
import os
import cox.store
import numpy as np
import torch as ch
from cox import utils
from robustness import datasets, defaults, model_utils, train
from robustness.tools import helpers
from torch import nn
from torchvision import models
from utils import constants as cs
from utils import fine_tunify, transfer_datasets

parser = argparse.ArgumentParser(description='Transfer learning via pretrained ImageNet models',
                                 conflict_handler='resolve')
parser = defaults.add_args_to_parser(defaults.CONFIG_ARGS, parser)
parser = defaults.add_args_to_parser(defaults.MODEL_LOADER_ARGS, parser)
parser = defaults.add_args_to_parser(defaults.TRAINING_ARGS, parser)
parser = defaults.add_args_to_parser(defaults.PGD_ARGS, parser)

# Custom arguments
parser.add_argument('--dataset', type=str, default='cifar',
                    help='Dataset (overrides the one in robustness.defaults)')
parser.add_argument('--model-path', type=str, default='',
                    help='Path to the checkpoint of the source (pretrained) model')
parser.add_argument('--resume', action='store_true',
                    help='Whether to resume training (overrides the one in robustness.defaults)')
parser.add_argument('--pytorch-pretrained', action='store_true',
                    help='If set, loads a PyTorch-pretrained model.')
parser.add_argument('--cifar10-cifar10', action='store_true',
                    help='CIFAR-10 to CIFAR-10 transfer')
parser.add_argument('--subset', type=int, default=None,
                    help='Number of training examples to use from the dataset')
parser.add_argument('--no-tqdm', type=int, default=1,
                    choices=[0, 1], help='Do not use tqdm.')
parser.add_argument('--no-replace-last-layer', action='store_true',
                    help='Whether to avoid replacing the last layer')
parser.add_argument('--freeze-level', type=int, default=-1,
                    help='Up to which layer to freeze in the pretrained model (assumes a ResNet architecture)')
parser.add_argument('--additional-hidden', type=int, default=0,
                    help='How many hidden layers to add on top of the pretrained network, before the classification layer')
parser.add_argument('--per-class-accuracy', action='store_true',
                    help='Report the per-class accuracy. '
                         'Can be used only with pets, caltech101, caltech256, aircraft, and flowers.')
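
# Example invocation (illustrative only; the paths and hyperparameter values
# below are placeholders, not from the original repo):
#
#   python main.py --dataset cifar10 --data /path/to/cifar \
#       --arch resnet50 --model-path /path/to/robust_imagenet_resnet50.pt \
#       --out-dir /path/to/logs --exp-name cifar10-transfer \
#       --adv-train 0 --epochs 150 --lr 0.01 --batch-size 64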


def main(args, store):
    '''Given arguments and a cox store, trains a model. See the
    argparse setup in this file for argument options.
    '''
    ds, train_loader, validation_loader = get_dataset_and_loaders(args)

    if args.per_class_accuracy:
        assert args.dataset in ['pets', 'caltech101', 'caltech256', 'flowers', 'aircraft'], \
            f'Per-class accuracy not supported for the {args.dataset} dataset.'

        # VERY IMPORTANT
        # We report the per-class accuracy using the validation set
        # distribution, so ignore the training accuracy (it can go beyond 100
        # and does not capture anything meaningful); only look at the
        # validation accuracy.
        args.custom_accuracy = get_per_class_accuracy(args, validation_loader)

    model, checkpoint = get_model(args, ds)

    if args.eval_only:
        return train.eval_model(args, model, validation_loader, store=store)

    update_params = freeze_model(model, freeze_level=args.freeze_level)

    print(f"Dataset: {args.dataset} | Model: {args.arch}")
    train.train_model(args, model, (train_loader, validation_loader), store=store,
                      checkpoint=checkpoint, update_params=update_params)


def get_per_class_accuracy(args, loader):
    '''Returns the custom per_class_accuracy function. When using this custom
    function, look only at the validation accuracy; ignore the training set
    accuracy.
    '''
    def _get_class_weights(args, loader):
        '''Returns the class weights implied by the class distribution of the
        given dataset.
        '''
        if args.dataset in ['pets', 'flowers']:
            targets = loader.dataset.targets
        elif args.dataset in ['caltech101', 'caltech256']:
            targets = np.array([loader.dataset.ds.dataset.y[idx]
                                for idx in loader.dataset.ds.indices])
        elif args.dataset == 'aircraft':
            targets = [s[1] for s in loader.dataset.samples]

        counts = np.unique(targets, return_counts=True)[1]
        class_weights = counts.sum() / (counts * len(counts))
        return ch.Tensor(class_weights)
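
    # Worked example of the weighting (illustrative numbers, not from the
    # original repo): with class counts [10, 30, 60], the weights are
    # 100 / ([10, 30, 60] * 3) = [3.33, 1.11, 0.56], so each class contributes
    # equally to the weighted accuracy below, which then approximates the mean
    # of the per-class accuracies on an imbalanced validation set.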
    class_weights = _get_class_weights(args, loader)

    def custom_acc(logits, labels):
        '''Returns the top-1 accuracy, weighted by the class distribution.
        This is important when evaluating an unbalanced dataset.
        '''
        batch_size = labels.size(0)
        maxk = min(5, logits.shape[-1])
        prec1, _ = helpers.accuracy(
            logits, labels, topk=(1, maxk), exact=True)

        normal_prec1 = prec1.sum(0, keepdim=True).mul_(100 / batch_size)
        weighted_prec1 = prec1 * class_weights[labels.cpu()].cuda()
        weighted_prec1 = weighted_prec1.sum(
            0, keepdim=True).mul_(100 / batch_size)

        return weighted_prec1.item(), normal_prec1.item()

    return custom_acc


def get_dataset_and_loaders(args):
    '''Given arguments, returns a dataset object and the train and validation
    loaders.
    '''
    if args.dataset in ['imagenet', 'stylized_imagenet']:
        ds = datasets.ImageNet(args.data)
        train_loader, validation_loader = ds.make_loaders(
            only_val=args.eval_only, batch_size=args.batch_size, workers=8)
    elif args.cifar10_cifar10:
        ds = datasets.CIFAR('/tmp')
        train_loader, validation_loader = ds.make_loaders(
            only_val=args.eval_only, batch_size=args.batch_size, workers=8)
    else:
        ds, (train_loader, validation_loader) = transfer_datasets.make_loaders(
            args.dataset, args.batch_size, 8, args.subset)
        if type(ds) == int:
            # make_loaders returned only a number of classes; wrap it in a
            # CIFAR dataset object with that class count and identity
            # normalization so the rest of the pipeline can use it.
            new_ds = datasets.CIFAR("/tmp")
            new_ds.num_classes = ds
            new_ds.mean = ch.tensor([0., 0., 0.])
            new_ds.std = ch.tensor([1., 1., 1.])
            ds = new_ds
    return ds, train_loader, validation_loader


def resume_finetuning_from_checkpoint(args, ds, finetuned_model_path):
    '''Given arguments, a dataset object, and a fine-tuned model path, returns
    a model with loaded weights along with the checkpoint needed to resume
    training.
    '''
    print('[Resuming finetuning from a checkpoint...]')
    if args.dataset in list(transfer_datasets.DS_TO_FUNC.keys()) and not args.cifar10_cifar10:
        model, _ = model_utils.make_and_restore_model(
            arch=pytorch_models[args.arch](
                args.pytorch_pretrained) if args.arch in pytorch_models.keys() else args.arch,
            dataset=datasets.ImageNet(''), add_custom_forward=args.arch in pytorch_models.keys())
        while hasattr(model, 'model'):
            model = model.model
        model = fine_tunify.ft(
            args.arch, model, ds.num_classes, args.additional_hidden)
        model, checkpoint = model_utils.make_and_restore_model(arch=model, dataset=ds, resume_path=finetuned_model_path,
                                                               add_custom_forward=args.additional_hidden > 0 or args.arch in pytorch_models.keys())
    else:
        model, checkpoint = model_utils.make_and_restore_model(
            arch=args.arch, dataset=ds, resume_path=finetuned_model_path)
    return model, checkpoint


def get_model(args, ds):
    '''Given arguments and a dataset object, returns an ImageNet model (with
    the last layer changed to fit the target dataset) and a checkpoint. The
    checkpoint is None if not resuming training.
    '''
    finetuned_model_path = os.path.join(
        args.out_dir, args.exp_name, 'checkpoint.pt.latest')
    if args.resume and os.path.isfile(finetuned_model_path):
        model, checkpoint = resume_finetuning_from_checkpoint(
            args, ds, finetuned_model_path)
    else:
        if args.dataset in list(transfer_datasets.DS_TO_FUNC.keys()) and not args.cifar10_cifar10:
            model, _ = model_utils.make_and_restore_model(
                arch=pytorch_models[args.arch](
                    args.pytorch_pretrained) if args.arch in pytorch_models.keys() else args.arch,
                dataset=datasets.ImageNet(''), resume_path=args.model_path, pytorch_pretrained=args.pytorch_pretrained,
                add_custom_forward=args.arch in pytorch_models.keys())
            checkpoint = None
        else:
            model, _ = model_utils.make_and_restore_model(arch=args.arch, dataset=ds,
                                                          resume_path=args.model_path, pytorch_pretrained=args.pytorch_pretrained)
            checkpoint = None

        if not args.no_replace_last_layer and not args.eval_only:
            print(f'[Replacing the last layer with {args.additional_hidden} '
                  f'hidden layers and 1 classification layer that fits the {args.dataset} dataset.]')
            while hasattr(model, 'model'):
                model = model.model
            model = fine_tunify.ft(
                args.arch, model, ds.num_classes, args.additional_hidden)
            model, checkpoint = model_utils.make_and_restore_model(arch=model, dataset=ds,
                                                                   add_custom_forward=args.additional_hidden > 0 or args.arch in pytorch_models.keys())
        else:
            print('[NOT replacing the last layer]')
    return model, checkpoint


def freeze_model(model, freeze_level):
    '''Freezes up to freeze_level layers of the model (assumes a ResNet
    architecture).
    '''
    update_params = None
    if freeze_level != -1:
        # assumes a ResNet architecture
        assert len([name for name, _ in list(model.named_parameters())
                    if f"layer{freeze_level}" in name]), "unknown freeze level (only {1,2,3,4} for ResNets)"
        update_params = []
        freeze = True
        for name, param in model.named_parameters():
            print(name, param.size())
            if not freeze and f'layer{freeze_level}' not in name:
                print(f"[Appending the params of {name} to the update list]")
                update_params.append(param)
            else:
                param.requires_grad = False
            if freeze and f'layer{freeze_level}' in name:
                # once the freeze level is reached, stop freezing from here on
                freeze = False
    return update_params
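
# Example of freeze_model's behavior (illustrative; assumes a torchvision-style
# ResNet): with freeze_level=3, the parameters of conv1, bn1, layer1, layer2,
# and layer3 get requires_grad=False, while layer4 and the classification head
# are collected in update_params and remain trainable.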


def args_preprocess(args):
    '''Fills the args object with reasonable defaults and performs a sanity
    check to make sure no args are missing.
    '''
    if args.adv_train and eval(args.eps) == 0:
        print('[Switching to standard training since eps = 0]')
        args.adv_train = 0

    if args.pytorch_pretrained:
        assert not args.model_path, 'You can specify either pytorch_pretrained or model_path, not both.'

    # CIFAR-10 to CIFAR-10 assertions
    if args.cifar10_cifar10:
        assert args.dataset == 'cifar10'

    if args.data != '':
        cs.CALTECH101_PATH = cs.CALTECH256_PATH = cs.PETS_PATH = cs.CARS_PATH = args.data
        cs.FGVC_PATH = cs.FLOWERS_PATH = cs.DTD_PATH = cs.SUN_PATH = cs.FOOD_PATH = cs.BIRDS_PATH = args.data

    ALL_DS = list(transfer_datasets.DS_TO_FUNC.keys()) + \
        ['imagenet', 'breeds_living_9', 'stylized_imagenet']
    assert args.dataset in ALL_DS

    # Important for automatic job retries on the cluster in case of preemptions. Avoid uuids.
    assert args.exp_name is not None

    # Preprocess args
    args = defaults.check_and_fill_args(args, defaults.CONFIG_ARGS, None)
    if not args.eval_only:
        args = defaults.check_and_fill_args(args, defaults.TRAINING_ARGS, None)
    if args.adv_train or args.adv_eval:
        args = defaults.check_and_fill_args(args, defaults.PGD_ARGS, None)
    args = defaults.check_and_fill_args(args, defaults.MODEL_LOADER_ARGS, None)

    return args


if __name__ == "__main__":
    args = parser.parse_args()
    args = args_preprocess(args)

    pytorch_models = {
        'alexnet': models.alexnet,
        'vgg16': models.vgg16,
        'vgg16_bn': models.vgg16_bn,
        'squeezenet': models.squeezenet1_0,
        'densenet': models.densenet161,
        'shufflenet': models.shufflenet_v2_x1_0,
        'mobilenet': models.mobilenet_v2,
        'resnext50_32x4d': models.resnext50_32x4d,
        'mnasnet': models.mnasnet1_0,
    }

    # Create store and log the args
    store = cox.store.Store(args.out_dir, args.exp_name)
    if 'metadata' not in store.keys:
        args_dict = args.__dict__
        schema = cox.store.schema_from_dict(args_dict)
        store.add_table('metadata', schema)
        store['metadata'].append_row(args_dict)
    else:
        print('[Found existing metadata in store. Skipping this part.]')

    main(args, store)