Commit aad1571

Add image pre-training code and docs

1 parent c08905e commit aad1571

4 files changed: +66 -78 lines

README.md

Lines changed: 8 additions & 6 deletions

@@ -8,14 +8,15 @@ Check the [project website](https://engineering.purdue.edu/elab/CortexNet/) for
 The project consists of the following folders and files:
 
 - [`data/`](data): contains *Bash* scripts and a *Python* class definition inherent to video data loading;
+- [`image-pretraining/`](image-pretraining/): hosts the code for pre-training TempoNet's discriminative branch;
 - [`model/`](model): stores several network architectures, including [*PredNet*](https://coxlab.github.io/prednet/), an additive feedback *Model01*, and a modulatory feedback *Model02* ([*CortexNet*](https://engineering.purdue.edu/elab/CortexNet/));
 - [`notebook/`](notebook): collection of *Jupyter Notebook*s for data exploration and results visualisation (best viewed with [this](https://userstyles.org/styles/98208/jupyter-notebook-dark-originally-from-ipython) and [this](https://userstyles.org/styles/37035/github-dark) dark styles);
 - [`utils/`](utils): scripts for
   - (current or former) training error plotting,
   - experiments `diff`,
   - multi-node synchronisation,
   - generative predictions visualisation,
   - network architecture graphing;
 - `results@`: link to the location where experimental results will be saved within 3-digit folders;
 - [`new_experiment.sh*`](new_experiment.sh): creates a new experiment folder, updates `last@`, prints a memo about last used settings;
 - `last@`: symbolic link pointing to a new results sub-directory created by `new_experiment.sh`;

@@ -68,7 +69,7 @@ Therefore, type `CUDA_VISIBLE_DEVICES=n` just before `python ...` in the following
 + Use [`data/resize_and_split.sh`](data/resize_and_split.sh) to prepare your (video) data for training.
   It resizes videos present in folders of folders (*i.e.* a directory of classes) and may split them into training and validation sets.
   It may also skip short videos and trim longer ones.
-  Check [`data/README.md`](data/README.md) for more details.
+  Check [`data/README.md`](data/README.md#matchnet-mode) for more details.
 + Run the [`main.py`](main.py) script to start training.
   Use `-h` to print the command line interface (CLI) arguments help.

@@ -79,16 +80,17 @@ python -u main.py --mode MatchNet <CLI arguments> | tee last/train.log
 ## Train *TempoNet*
 
 + Download *e-VDS35* (*e.g.* `e-VDS35-May17.tar`) from [here](https://engineering.purdue.edu/elab/eVDS/).
++ Pre-train the forward branch (see [`image-pretraining/`](image-pretraining)) on an image data set (*e.g.* `33-image-set.tar` from [here](https://engineering.purdue.edu/elab/eVDS/)).
 + Use [`data/resize_and_sample.sh`](data/resize_and_sample.sh) to prepare your (video) data for training.
   It resizes videos present in folders of folders (*i.e.* a directory of classes) and samples them.
   Videos are then distributed across the training and validation sets.
   It may also skip short videos and trim longer ones.
-  Check [`data/README.md`](data/README.md) for more details.
+  Check [`data/README.md`](data/README.md#temponet-mode) for more details.
 + Run the [`main.py`](main.py) script to start training.
   Use `-h` to print the CLI arguments help.
 
 ```bash
-python -u main.py --mode MatchNet <CLI arguments> | tee last/train.log
+python -u main.py --mode TempoNet --pre-trained <path> <CLI args> | tee last/train.log
 ```
 
 ## GPU selection
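Taken together, the README changes in this commit describe a two-stage recipe. Below is a consolidated sketch, assembled only from the commands shown in this commit: values in angle brackets are placeholders, the checkpoint path is the one hinted at in the `main.py` hunk at the bottom, and the GPU pinning follows the `CUDA_VISIBLE_DEVICES` note above.

```bash
# Hypothetical end-to-end run of the TempoNet recipe described above.
# 1) Pre-train the discriminative branch on an image data set:
(cd image-pretraining && python main.py <image data path> | tee train.log)
# 2) Train TempoNet from the resulting checkpoint on GPU 0:
CUDA_VISIBLE_DEVICES=0 python -u main.py --mode TempoNet \
    --pre-trained image-pretraining/model02D-33IS/model_best.pth.tar \
    <CLI args> | tee last/train.log
```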

image-pretraining/README.md

Lines changed: 8 additions & 15 deletions

@@ -1,28 +1,24 @@
 # Image pre-training
 
-Find the original code at [PyTorch's example](https://github.com/pytorch/examples/tree/master/imagenet).
+Find the original code at the [PyTorch ImageNet example](https://github.com/pytorch/examples/tree/master/imagenet).
 This adaptation trains the discriminative branch of CortexNet for TempoNet.
 
 ## Training
 
-To train a model, run `main.py` with the desired model architecture and the path to the ImageNet dataset:
+To train the discriminative branch of CortexNet, run `main.py` with the path to an image data set:
 
 ```bash
-python main.py -a resnet18 [imagenet-folder with train and val folders]
+python main.py <image data path> | tee train.log
 ```
 
-The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG:
-
-```bash
-python main.py -a alexnet --lr 0.01 [imagenet-folder with train and val folders]
-```
+The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs.
 
 ## Usage
 
 ```
-usage: main.py [-h] [--arch ARCH] [-j N] [--epochs N] [--start-epoch N] [-b N]
-               [--lr LR] [--momentum M] [--weight-decay W] [--print-freq N]
-               [--resume PATH] [-e] [--pretrained]
+usage: main.py [-h] [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR]
+               [--momentum M] [--weight-decay W] [--print-freq N]
+               [--resume PATH] [-e] [--pretrained] [--size [S [S ...]]]
                DIR
 
 PyTorch ImageNet Training

@@ -32,10 +28,6 @@ positional arguments:
 
 optional arguments:
   -h, --help            show this help message and exit
-  --arch ARCH, -a ARCH  model architecture: alexnet | resnet | resnet101 |
-                        resnet152 | resnet18 | resnet34 | resnet50 | vgg |
-                        vgg11 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn
-                        | vgg19 | vgg19_bn (default: resnet18)
   -j N, --workers N     number of data loading workers (default: 4)
   --epochs N            number of total epochs to run
   --start-epoch N       manual epoch number (useful on restarts)

@@ -49,4 +41,5 @@ optional arguments:
   --resume PATH         path to latest checkpoint (default: none)
   -e, --evaluate        evaluate model on validation set
   --pretrained          use pre-trained model
+  --size [S [S ...]]    number and size of hidden layers
 ```
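The commit drops the `--arch` selector and adds `--size`. A minimal, self-contained sketch of how the new flag parses, mirroring the `add_argument` call in the `main.py` hunk below (the sample values are made up):

```python
import argparse

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', help='path to dataset')
# nargs='*' consumes every following token, so "--size 3 32 64" -> [3, 32, 64]
parser.add_argument('--size', type=int, default=(3, 32, 64, 128, 256, 256, 256),
                    nargs='*', metavar='S', help='number and size of hidden layers')

args = parser.parse_args(['some/dir', '--size', '3', '32', '64', '128'])
print(tuple(args.size))  # (3, 32, 64, 128) -- main() later converts it to a tuple
```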

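The learning-rate sentence kept in the README above matches the usual schedule of the upstream example. A sketch of what `adjust_learning_rate` presumably does (an assumption: the function is called in the training loop below but its body is not part of this commit):

```python
def adjust_learning_rate(optimizer, epoch, base_lr=0.1):
    """Decay the learning rate by 10x every 30 epochs (0.1, 0.01, 0.001, ...)."""
    lr = base_lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr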
image-pretraining/main.py

Lines changed: 49 additions & 56 deletions

@@ -22,11 +22,6 @@
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    choices=model_names,
-                    help='model architecture: ' +
-                         ' | '.join(model_names) +
-                         ' (default: resnet18)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',

@@ -49,99 +44,97 @@
                     help='evaluate model on validation set')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
+parser.add_argument('--size', type=int, default=(3, 32, 64, 128, 256, 256, 256), nargs='*',
+                    help='number and size of hidden layers', metavar='S')
 
 best_prec1 = 0
 
 
 def main():
     global args, best_prec1
     args = parser.parse_args()
+    args.size = tuple(args.size)
 
     # create model
-    if args.pretrained:
-        print("=> using pre-trained model '{}'".format(args.arch))
-        model = models.__dict__[args.arch](pretrained=True)
-    else:
-        print("=> creating model '{}'".format(args.arch))
-        model = models.__dict__[args.arch]()
-
-    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
-        model.features = torch.nn.DataParallel(model.features)
-        model.cuda()
-    else:
-        model = torch.nn.DataParallel(model).cuda()
-
-    # define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
-
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
-
-    # optionally resume from a checkpoint
-    if args.resume:
-        if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
-            checkpoint = torch.load(args.resume)
-            args.start_epoch = checkpoint['epoch']
-            best_prec1 = checkpoint['best_prec1']
-            model.load_state_dict(checkpoint['state_dict'])
-            optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
-        else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+    from model.Model02 import Model02 as Model
+
+    class Capsule(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            nb_of_classes = 33  # 970 (vid) or 35 (vid obj) or 33 (imgs)
+            self.inner_model = Model(args.size + (nb_of_classes,), (256, 256))
+
+        def forward(self, x):
+            (_, _), (_, video_index) = self.inner_model(x, None)
+            return video_index
+
+    model = Capsule()
+
+    model = torch.nn.DataParallel(model).cuda()
 
     cudnn.benchmark = True
 
     # Data loading code
     traindir = os.path.join(args.data, 'train')
     valdir = os.path.join(args.data, 'val')
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+    #                                  std=[0.229, 0.224, 0.225])
 
-    train_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(traindir, transforms.Compose([
-            transforms.RandomSizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            normalize,
-        ])),
+    train_data = datasets.ImageFolder(traindir, transforms.Compose([
+        transforms.CenterCrop(256),
+        transforms.ToTensor(),
+    ]))
+    train_loader = torch.utils.data.DataLoader(
+        train_data,
         batch_size=args.batch_size, shuffle=True,
-        num_workers=args.workers, pin_memory=True)
+        num_workers=args.workers, pin_memory=True
+    )
 
+    val_data = datasets.ImageFolder(valdir, transforms.Compose([transforms.CenterCrop(256), transforms.ToTensor()]))
     val_loader = torch.utils.data.DataLoader(
-        datasets.ImageFolder(valdir, transforms.Compose([
-            transforms.Scale(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            normalize,
-        ])),
+        val_data,
         batch_size=args.batch_size, shuffle=False,
-        num_workers=args.workers, pin_memory=True)
+        num_workers=args.workers, pin_memory=True
+    )
+
+    # define loss function (criterion) and optimizer
+    class_count = [0] * len(train_data.classes)
+    for i in train_data.imgs: class_count[i[1]] += 1
+    train_crit_weight = torch.Tensor(class_count)
+    train_crit_weight.div_(train_crit_weight.mean()).pow_(-1)
+    train_criterion = nn.CrossEntropyLoss(train_crit_weight).cuda()
+
+    class_count = [0] * len(val_data.classes)
+    for i in val_data.imgs: class_count[i[1]] += 1
+    val_crit_weight = torch.Tensor(class_count)
+    val_crit_weight.div_(val_crit_weight.mean()).pow_(-1)
+    val_criterion = nn.CrossEntropyLoss(val_crit_weight).cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
 
     if args.evaluate:
-        validate(val_loader, model, criterion)
+        validate(val_loader, model, val_criterion)
         return
 
     for epoch in range(args.start_epoch, args.epochs):
         adjust_learning_rate(optimizer, epoch)
 
         # train for one epoch
-        train(train_loader, model, criterion, optimizer, epoch)
+        train(train_loader, model, train_criterion, optimizer, epoch)
 
         # evaluate on validation set
-        prec1 = validate(val_loader, model, criterion)
+        prec1 = validate(val_loader, model, val_criterion)
 
         # remember best prec@1 and save checkpoint
         is_best = prec1 > best_prec1
         best_prec1 = max(prec1, best_prec1)
         save_checkpoint({
             'epoch': epoch + 1,
-            'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
-            'optimizer' : optimizer.state_dict(),
         }, is_best)

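The new loss set-up above weights each class by the inverse of its mean-normalised frequency, so under-represented classes contribute more to the loss. A self-contained sketch of that rule with made-up counts:

```python
import torch
import torch.nn as nn

# Images per class, e.g. tallied from an ImageFolder's .imgs list (made-up numbers).
class_count = torch.Tensor([120, 80, 200])
# Normalise by the mean count, then invert: rare classes get weights > 1.
weight = class_count.div(class_count.mean()).pow(-1)
criterion = nn.CrossEntropyLoss(weight)
```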
main.py

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ def main():
 
     if args.pre_trained:
         print('Load pre-trained weights')
-        # args.pre_trained = 'model/model02D-33IS/model_best.pth.tar'
+        # args.pre_trained = 'image-pretraining/model02D-33IS/model_best.pth.tar'
         dict_33 = torch.load(args.pre_trained)['state_dict']
 
         def load_state_dict(new_model, state_dict):
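The helper's body is cut off by the hunk boundary. One common way to load a partially matching checkpoint (an assumption about intent, not necessarily the author's code) is to copy only the parameters whose names and shapes line up:

```python
def load_state_dict(new_model, state_dict):
    # Hypothetical sketch: copy only entries that exist in the new model
    # with matching shapes, leaving everything else at its initial value.
    own_state = new_model.state_dict()
    for name, param in state_dict.items():
        if name in own_state and own_state[name].shape == param.shape:
            own_state[name].copy_(param)
```

With a helper like this, `load_state_dict(new_model, dict_33)` would initialise every layer whose name and shape survive from the 33-class pre-training and leave the rest untouched.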
