diff --git a/example_usage2.py b/example_usage2.py index 7c9006c..9a32478 100644 --- a/example_usage2.py +++ b/example_usage2.py @@ -6,28 +6,27 @@ import antialiased_cnns -for force in [False, True]: - model = antialiased_cnns.resnet18(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.resnet34(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.resnet50(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.resnet101(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.resnet152(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.wide_resnet50_2(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.wide_resnet101_2(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.resnext50_32x4d(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.resnext101_32x8d(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.alexnet(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg11(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg11_bn(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg13(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg13_bn(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg16(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg16_bn(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg19(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.vgg19_bn(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.densenet121(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.densenet169(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.densenet201(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.densenet161(pretrained=True, _force_nonfinetuned=force) - model = antialiased_cnns.mobilenet_v2(pretrained=True, _force_nonfinetuned=force) +model = antialiased_cnns.resnet18(pretrained=True) +model = antialiased_cnns.resnet34(pretrained=True) +model = antialiased_cnns.resnet50(pretrained=True) +model = antialiased_cnns.resnet101(pretrained=True) +model = antialiased_cnns.resnet152(pretrained=True) +model = antialiased_cnns.wide_resnet50_2(pretrained=True) +model = antialiased_cnns.wide_resnet101_2(pretrained=True) +model = antialiased_cnns.resnext50_32x4d(pretrained=True) +model = antialiased_cnns.resnext101_32x8d(pretrained=True) +model = antialiased_cnns.alexnet(pretrained=True) +model = antialiased_cnns.vgg11(pretrained=True) +model = antialiased_cnns.vgg11_bn(pretrained=True) +model = antialiased_cnns.vgg13(pretrained=True) +model = antialiased_cnns.vgg13_bn(pretrained=True) +model = antialiased_cnns.vgg16(pretrained=True) +model = antialiased_cnns.vgg16_bn(pretrained=True) +model = antialiased_cnns.vgg19(pretrained=True) +model = antialiased_cnns.vgg19_bn(pretrained=True) +model = antialiased_cnns.densenet121(pretrained=True) +model = antialiased_cnns.densenet169(pretrained=True) +model = antialiased_cnns.densenet201(pretrained=True) +model = antialiased_cnns.densenet161(pretrained=True) +model = antialiased_cnns.mobilenet_v2(pretrained=True) diff --git a/main.py b/main.py index f00f679..45c3449 100644 --- a/main.py +++ b/main.py @@ -99,6 +99,8 @@ metavar='N', help='print frequency (default: 10)') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') +parser.add_argument('--force_nonfinetuned', dest='force_nonfinetuned', action='store_true', + help='if pretrained, load the model that is pretrained from scratch (if available)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', @@ -211,7 +213,9 @@ def main_worker(gpu, ngpus_per_node, args): # create model print("=> creating model '{}'".format(args.arch)) if(args.arch.split('_')[-1][:-1]=='lpf'): # antialiased model - model = antialiased_cnns.__dict__[args.arch[:-5]](pretrained=args.pretrained, filter_size=int(args.arch[-1])) + model = antialiased_cnns.__dict__[args.arch[:-5]](pretrained=args.pretrained, + filter_size=int(args.arch[-1], + _force_nonfinetuned=args.force_nonfinetuned)) else: # baseline model model = models.__dict__[args.arch](pretrained=args.pretrained)