@@ -22,11 +22,6 @@
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    choices=model_names,
-                    help='model architecture: ' +
-                    ' | '.join(model_names) +
-                    ' (default: resnet18)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',
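Note: the hunk above drops the fixed --arch selector, and the hunk below replaces it with a variable-length --size flag describing the hidden layers. With nargs='*', argparse returns a list when values are supplied on the command line but passes the tuple default through unchanged, which is why main() normalises with tuple(args.size) right after parsing. A minimal standalone sketch of that behaviour (flag values hypothetical, not from the commit):

import argparse

p = argparse.ArgumentParser()
# nargs='*' collects zero or more ints into a list; when the flag is
# omitted, the tuple default is returned exactly as given.
p.add_argument('--size', type=int, nargs='*',
               default=(3, 32, 64, 128, 256, 256, 256))

args = p.parse_args(['--size', '3', '64', '128'])
size = tuple(args.size)  # list -> tuple, mirroring main()
print(size)              # (3, 64, 128)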
@@ -49,99 +44,97 @@
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
+parser.add_argument('--size', type=int, default=(3, 32, 64, 128, 256, 256, 256), nargs='*',
+                    help='number and size of hidden layers', metavar='S')
 
 best_prec1 = 0
 
 
 def main():
     global args, best_prec1
     args = parser.parse_args()
+    args.size = tuple(args.size)
 
     # create model
-    if args.pretrained:
-        print("=> using pre-trained model '{}'".format(args.arch))
-        model = models.__dict__[args.arch](pretrained=True)
-    else:
-        print("=> creating model '{}'".format(args.arch))
-        model = models.__dict__[args.arch]()
-
-    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
-        model.features = torch.nn.DataParallel(model.features)
-        model.cuda()
-    else:
-        model = torch.nn.DataParallel(model).cuda()
+    from model.Model02 import Model02 as Model
 
-    # define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
+    class Capsule(nn.Module):
 
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
+        def __init__(self):
+            super().__init__()
+            nb_of_classes = 33  # 970 (vid) or 35 (vid obj) or 33 (imgs)
+            self.inner_model = Model(args.size + (nb_of_classes,), (256, 256))
 
-    # optionally resume from a checkpoint
-    if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume)
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+        def forward(self, x):
+            (_, _), (_, video_index) = self.inner_model(x, None)
+            return video_index
+
+    model = Capsule()
+
+    model = torch.nn.DataParallel(model).cuda()
 
     cudnn.benchmark = True
 
     # Data loading code
     traindir = os.path.join(args.data, 'train')
     valdir = os.path.join(args.data, 'val')
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+#   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+#                                    std=[0.229, 0.224, 0.225])
 
-    train_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(traindir, transforms.Compose([
-            transforms.RandomSizedCrop(224),
-            transforms.RandomHorizontalFlip(),
+    train_data = datasets.ImageFolder(traindir, transforms.Compose([
+        transforms.CenterCrop(256),
             transforms.ToTensor(),
-            normalize,
-        ])),
+    ]))
+    train_loader = torch.utils.data.DataLoader(
+        train_data,
         batch_size=args.batch_size, shuffle=True,
-        num_workers=args.workers, pin_memory=True)
+        num_workers=args.workers, pin_memory=True
+    )
 
+    val_data = datasets.ImageFolder(valdir, transforms.Compose([transforms.CenterCrop(256), transforms.ToTensor(), ]))
     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
-            transforms.Scale(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            normalize,
-        ])),
+        val_data,
         batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        num_workers=args.workers, pin_memory=True
+    )
+
+    # define loss function (criterion) and optimizer
+    class_count = [0] * len(train_data.classes)
+    for i in train_data.imgs: class_count[i[1]] += 1
+    train_crit_weight = torch.Tensor(class_count)
+    train_crit_weight.div_(train_crit_weight.mean()).pow_(-1)
+    train_criterion = nn.CrossEntropyLoss(train_crit_weight).cuda()
+
+    class_count = [0] * len(val_data.classes)
+    for i in val_data.imgs: class_count[i[1]] += 1
+    val_crit_weight = torch.Tensor(class_count)
+    val_crit_weight.div_(val_crit_weight.mean()).pow_(-1)
+    val_criterion = nn.CrossEntropyLoss(val_crit_weight).cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
 
     if args.evaluate:
-        validate(val_loader, model, criterion)
+        validate(val_loader, model, val_criterion)
         return
 
     for epoch in range(args.start_epoch, args.epochs):
         adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch
-        train(train_loader, model, criterion, optimizer, epoch)
+        train(train_loader, model, train_criterion, optimizer, epoch)
 
         # evaluate on validation set
-        prec1 = validate(val_loader, model, criterion)
+        prec1 = validate(val_loader, model, val_criterion)
 
         # remember best prec@1 and save checkpoint
         is_best = prec1 > best_prec1
         best_prec1 = max(prec1, best_prec1)
         save_checkpoint({
             'epoch': epoch + 1,
-            'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
-            'optimizer' : optimizer.state_dict(),
         }, is_best)
 
 
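Note: the Capsule wrapper added above exists so torch.nn.DataParallel can scatter inputs and gather one plain output tensor; Model02 is repo-local (its constructor signature and nested-tuple return shape are inferred from the call site), and forward keeps only the class-index logits. A self-contained analogue with a stand-in inner module (hypothetical, for illustration only):

import torch
import torch.nn as nn

class Inner(nn.Module):
    # Stand-in for Model02: takes (input, state) and returns nested
    # tuples shaped like ((_, _), (_, logits)).
    def __init__(self, nb_of_classes=33):
        super().__init__()
        self.fc = nn.Linear(8, nb_of_classes)

    def forward(self, x, state):
        return (None, None), (None, self.fc(x))

class Capsule(nn.Module):
    # Unwrap the nested output so DataParallel sees a single tensor.
    def __init__(self):
        super().__init__()
        self.inner_model = Inner()

    def forward(self, x):
        (_, _), (_, video_index) = self.inner_model(x, None)
        return video_index

model = Capsule()
print(model(torch.randn(4, 8)).shape)  # torch.Size([4, 33])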
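Note: the train_criterion and val_criterion added above weight the cross-entropy by inverse class frequency, so under-represented classes contribute proportionally more to the loss; ImageFolder exposes its (path, class_index) pairs as .imgs, which is what the counting loops iterate over. The same arithmetic on toy counts (values illustrative, not from the dataset):

import torch
import torch.nn as nn

class_count = [5, 3, 2]          # images per class in a toy dataset

# Same two in-place ops as the commit: divide by the mean count, then
# invert, giving weight = mean_count / class_count for each class.
weight = torch.Tensor(class_count)
weight.div_(weight.mean()).pow_(-1)
print(weight)                    # tensor([0.6667, 1.1111, 1.6667])

criterion = nn.CrossEntropyLoss(weight)  # first positional arg is weight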