
Commit c08905e

Add image pre-training from PyTorch's examples
1 parent 2cf9be8 commit c08905e

File tree

image-pretraining/README.md
image-pretraining/main.py

2 files changed: +341 -0 lines

image-pretraining/README.md

Lines changed: 52 additions & 0 deletions
# Image pre-training

Find the original code in [PyTorch's examples](https://github.com/pytorch/examples/tree/master/imagenet).
This adaptation trains the discriminative branch of CortexNet for TempoNet.

## Training

To train a model, run `main.py` with the desired model architecture and the path to the ImageNet dataset:

```bash
python main.py -a resnet18 [imagenet-folder with train and val folders]
```

The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG:

```bash
python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]
```
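
For reference, the schedule above is the step decay applied by `adjust_learning_rate` in `main.py`. A minimal sketch of the learning rates it produces:

```python
# lr = initial_lr * 0.1 ** (epoch // 30), mirroring adjust_learning_rate below
initial_lr = 0.1
for epoch in (0, 29, 30, 59, 60, 89):
    lr = initial_lr * (0.1 ** (epoch // 30))
    print('epoch {:2d}: lr = {:g}'.format(epoch, lr))
# epochs 0-29 train at 0.1, 30-59 at 0.01, 60-89 at 0.001
```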

## Usage

```
usage: main.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N]
               [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N]
               [--resume PATH] [-e] [--pretrained]
               DIR

PyTorch ImageNet Training

positional arguments:
  DIR                   path to dataset

optional arguments:
  -h, --help            show this help message and exit
  --arch ARCH, -a ARCH  model architecture: alexnet | resnet | resnet101 |
                        resnet152 | resnet18 | resnet34 | resnet50 | vgg |
                        vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn
                        | vgg19 | vgg19_bn (default: resnet18)
  -j N, --workers N     number of data loading workers (default: 4)
  --epochs N            number of total epochs to run
  --start-epoch N       manual epoch number (useful on restarts)
  -b N, --batch-size N  mini-batch size (default: 256)
  --lr LR, --learning-rate LR
                        initial learning rate
  --momentum M          momentum
  --weight-decay W, --wd W
                        weight decay (default: 1e-4)
  --print-freq N, -p N  print frequency (default: 10)
  --resume PATH         path to latest checkpoint (default: none)
  -e, --evaluate        evaluate model on validation set
  --pretrained          use pre-trained model
```
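
Training writes `checkpoint.pth.tar` after every epoch and copies the best top-1 snapshot to `model_best.pth.tar` (see `save_checkpoint` in `main.py` below). A minimal sketch of loading a checkpoint back, e.g. to reuse the weights as TempoNet's discriminative branch; the dictionary keys match those saved by `main.py`:

```python
import torch
import torchvision.models as models

# keys saved by main.py: epoch, arch, state_dict, best_prec1, optimizer
checkpoint = torch.load('model_best.pth.tar')
model = models.__dict__[checkpoint['arch']]()
# main.py wraps ResNet-style models in DataParallel, so the saved state_dict
# keys carry a 'module.' prefix; rebuild the same wrapper before loading
# (AlexNet/VGG wrap only model.features instead)
model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(checkpoint['state_dict'])
print("epoch {}, best prec@1 {:.3f}".format(checkpoint['epoch'], checkpoint['best_prec1']))
```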

image-pretraining/main.py

Lines changed: 289 additions & 0 deletions
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

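# enumerate every lowercase, callable constructor in torchvision.models;
# the exact list (alexnet, resnet18, vgg16_bn, ...) depends on the installed
# torchvision version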
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')

best_prec1 = 0


def main():
    global args, best_prec1
    args = parser.parse_args()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

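    # AlexNet and VGG keep their large fully-connected classifier on a single
    # GPU and data-parallelize only the convolutional features, presumably
    # because replicating those classifier weights across devices costs more
    # than it saves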
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

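    # let cuDNN benchmark and cache the fastest convolution algorithms;
    # worthwhile here because every input is a fixed 224x224 crop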
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
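    # per-channel RGB mean/std of ImageNet, the standard statistics used by
    # torchvision's pre-trained models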
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

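    # RandomSizedCrop and Scale are this torchvision era's names for what later
    # releases call RandomResizedCrop and Resize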
    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomSizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Scale(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

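        # async=True overlaps the host-to-device copy with compute; it relies
        # on pin_memory=True in the loaders (later PyTorch versions renamed
        # the argument non_blocking)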
        target = target.cuda(async=True)
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.cuda(async=True)
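        # volatile=True disables autograd bookkeeping for inference,
        # this era's equivalent of the later torch.no_grad() context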
        input_var = torch.autograd.Variable(input, volatile=True)
        target_var = torch.autograd.Variable(target, volatile=True)

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], input.size(0))
        top1.update(prec1[0], input.size(0))
        top5.update(prec5[0], input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

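# e.g. topk=(1, 5) on a batch of N samples: pred holds each sample's five
# highest-scoring class indices, correct marks where the target appears among
# them, and each correct_k is rescaled to a percentage of the batch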
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()
