-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
cr
committed
Jan 14, 2019
0 parents
commit 37d6e90
Showing
6 changed files
with
537 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# SlowFastNetworks | ||
PyTorch implementation of ["SlowFast Networks for Video Recognition"](https://arxiv.org/abs/1812.03982). | ||
## Train | ||
1. The dataset should be organized as: | ||
dataset(e.g. UCF-101) | ||
├── train/training | ||
│ ├── ApplyEyeMakeup | ||
│ ├── ApplyLipstick | ||
│ ├── ... | ||
└── validation | ||
│ ├── ApplyEyeMakeup | ||
│ ├── ApplyLipstick | ||
│ ├── ... | ||
|
||
2. Modify the params in config.py and `mode` of `train_dataloader` or `val_dataloader` in train.py. | ||
|
||
## Requirements | ||
python 3 | ||
PyTorch >= 0.4.1 | ||
tensorboardX | ||
OpenCV | ||
|
||
## Code Reference: | ||
[1] https://github.com/Guocode/SlowFast-Networks/ | ||
[2] https://github.com/jfzhang95/pytorch-video-recognition | ||
[3] https://github.com/irhumshafkat/R2Plus1D-PyTorch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Central configuration dictionary. Every entry is a plain literal; how each
# value is used is decided by the scripts that import `params`.
params = {
    # number of target classes
    'num_classes': 101,

    # root directory of the dataset on disk
    'dataset': '/disk/dataset/UCF-101',

    # optimisation / schedule settings
    'epoch_num': 40,
    'batch_size': 16,
    'step': 10,
    'num_workers': 4,
    'learning_rate': 1e-2,
    'momentum': 0.9,
    'weight_decay': 1e-5,

    # logging, checkpointing and device selection
    'display': 10,
    'pretrained': None,
    'gpu': [0],
    'log': 'log',
    'save_path': 'UCF101',

    # clip sampling settings
    'clip_len': 64,
    'frame_sample_rate': 1,
}
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
import cv2 | ||
import numpy as np | ||
from torch.utils.data import DataLoader, Dataset | ||
|
||
|
||
class VideoDataset(Dataset):
    """Dataset of fixed-length video clips read from a class-per-folder tree.

    The expected layout is ``<directory>/<mode>/<label>/<video file>``. Every
    sub-directory of the split folder is treated as one class; class indices
    are assigned in sorted label order.

    Constructing the dataset also writes ``labels.txt`` in the current working
    directory, one ``"<1-based index> <label>"`` line per class.

    Args:
        directory: Root folder of the dataset.
        mode: Name of the split sub-folder. Random horizontal flipping is only
            applied when this is ``'train'`` or ``'training'``.
        clip_len: Number of consecutive (sampled) frames in a returned clip.
        frame_sample_rate: Keep every ``frame_sample_rate``-th decoded frame.
    """

    def __init__(self, directory, mode='train', clip_len=8, frame_sample_rate=1):
        folder = Path(directory) / mode  # get the directory of the specified split
        print("folder", folder)
        self.clip_len = clip_len

        # Frames are resized to (resize_height, resize_width) when loaded and
        # then randomly cropped to a crop_size x crop_size window.
        self.resize_height = 128
        self.resize_width = 171
        self.crop_size = 112
        self.frame_sample_rate = frame_sample_rate
        self.mode = mode

        self.fnames, labels = [], []
        for label in sorted(os.listdir(folder)):
            label_dir = os.path.join(folder, label)
            # Stray files next to the class folders (e.g. .DS_Store) used to
            # crash the inner listdir; only directories are classes.
            if not os.path.isdir(label_dir):
                continue
            for fname in os.listdir(label_dir):
                self.fnames.append(os.path.join(label_dir, fname))
                labels.append(label)

        # prepare a mapping between the label names (strings) and indices (ints)
        self.label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
        # convert the list of label names into an array of label indices
        self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)

        # Dump the class list with 1-based ids for external tools.
        with open('labels.txt', 'w') as f:
            f.writelines(
                str(index + 1) + ' ' + label + '\n'
                for index, label in enumerate(sorted(self.label2index)))

    def __getitem__(self, index):
        """Return ``(clip, label_index)`` for the video at ``index``.

        The clip is a float32 array laid out as [C, D, H, W]. If the selected
        video yields fewer than ``clip_len + 2`` frames, a different video is
        drawn at random and its label is returned instead.
        """
        # loading and preprocessing. TODO move them to transform classes
        buffer = self.loadvideo(self.fnames[index])

        # NOTE: this loops forever if no video in the dataset is long enough.
        while buffer.shape[0] < self.clip_len + 2:
            index = np.random.randint(self.__len__())
            buffer = self.loadvideo(self.fnames[index])

        if self.mode == 'train' or self.mode == 'training':
            buffer = self.randomflip(buffer)
        buffer = self.crop(buffer, self.clip_len, self.crop_size)
        buffer = self.normalize(buffer)
        buffer = self.to_tensor(buffer)

        return buffer, self.label_array[index]

    def to_tensor(self, buffer):
        """Reorder a clip from [D, H, W, C] to [C, D, H, W].

        D = Depth (in this case, time), H = Height, W = Width, C = Channels.
        """
        return buffer.transpose((3, 0, 1, 2))

    def loadvideo(self, fname):
        """Decode ``fname`` into a float32 array of shape [frames, H, W, 3] (RGB).

        Every ``frame_sample_rate``-th frame is kept and resized to
        ``(resize_height, resize_width)`` when the video has a different size.
        Only frames that were actually decoded are returned.
        """
        # initialize a VideoCapture object to read video data into a numpy array
        capture = cv2.VideoCapture(fname)
        try:
            frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # create a buffer. Must have dtype float, so it gets converted to a
            # FloatTensor by Pytorch later. Clamp at zero in case the backend
            # reports a negative frame count.
            frame_count_sample = max(frame_count // self.frame_sample_rate, 0)
            buffer = np.empty(
                (frame_count_sample, self.resize_height, self.resize_width, 3),
                np.dtype('float32'))

            count = 0
            sample_count = 0

            # read in each frame, one at a time into the numpy buffer array
            while count < frame_count:
                retaining, frame = capture.read()
                if not retaining:
                    break
                if count % self.frame_sample_rate == 0 and sample_count < frame_count_sample:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    # will resize frames if not already final size
                    if (frame_height != self.resize_height) or (frame_width != self.resize_width):
                        frame = cv2.resize(frame, (self.resize_width, self.resize_height))
                    buffer[sample_count] = frame
                    sample_count += 1
                count += 1
        finally:
            # release the decoder even if reading or conversion raised
            capture.release()

        # The reported frame count can exceed what actually decodes; drop the
        # never-written tail of the np.empty buffer instead of returning
        # uninitialised memory as frames.
        return buffer[:sample_count]

    def crop(self, buffer, clip_len, crop_size):
        """Cut a random ``clip_len`` x ``crop_size`` x ``crop_size`` window.

        Args:
            buffer: Array laid out as [D, H, W, C].
            clip_len: Number of consecutive frames to keep.
            crop_size: Side length of the square spatial crop.

        Returns:
            A view of ``buffer`` with shape [clip_len, crop_size, crop_size, C].
        """
        # randint's upper bound is exclusive, so "+ 1" makes the last valid
        # start offset reachable and lets an exact-fit buffer pass through.
        # randomly select time index for temporal jittering
        time_index = np.random.randint(buffer.shape[0] - clip_len + 1)
        # Randomly select start indices in order to crop the video
        height_index = np.random.randint(buffer.shape[1] - crop_size + 1)
        width_index = np.random.randint(buffer.shape[2] - crop_size + 1)

        # crop and jitter the video using indexing. The spatial crop is performed on
        # the entire array, so each frame is cropped in the same location. The temporal
        # jitter takes place via the selection of consecutive frames
        buffer = buffer[time_index:time_index + clip_len,
                        height_index:height_index + crop_size,
                        width_index:width_index + crop_size, :]

        return buffer

    def normalize(self, buffer):
        """Map pixel values in place via ``(x - 128) / 128`` and return ``buffer``."""
        # One vectorised expression replaces the per-frame Python loop; the
        # result is written back into the same array as before.
        buffer[...] = (buffer - 128.0) / 128.0
        return buffer

    def randomflip(self, buffer):
        """Horizontally flip the whole clip in place with a probability of 0.5."""
        if np.random.random() < 0.5:
            # Reverse the width axis of the [D, H, W, C] array once. The
            # previous code flipped every frame twice, which undid the flip.
            buffer[...] = buffer[:, :, ::-1].copy()

        return buffer

    def __len__(self):
        """Return the number of video files found for this split."""
        return len(self.fnames)
|
||
|
||
if __name__ == '__main__':
    # Manual smoke test: iterate the training split once and print the labels
    # of every batch. The loader previously referenced an undefined name
    # (`datapathli`); the dataset root is now defined explicitly.
    datapath = '/home/cr/workspace/disk/data/VideoRecognition/UCF-101'
    train_dataloader = DataLoader(
        VideoDataset(datapath, mode='train'),
        batch_size=10, shuffle=True, num_workers=0)
    for step, (buffer, label) in enumerate(train_dataloader):
        print("label: ", label)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.autograd import Variable | ||
|
||
# Names exported by `from <module> import *`: the model factory functions.
__all__ = ['resnet50', 'resnet101','resnet152', 'resnet200']
|
||
|
||
|
||
class Bottleneck(nn.Module):
    """Three-convolution 3D residual block with a 4x channel expansion.

    The first convolution is either a 1x1x1 kernel (``head_conv=1``) or a
    3x1x1 kernel with temporal padding 1 (``head_conv=3``). It is followed by
    a 1x3x3 convolution carrying the spatial stride and a 1x1x1 convolution
    that produces ``planes * 4`` channels.

    Args:
        inplanes: Number of input channels.
        planes: Number of channels of the two inner convolutions.
        stride: Spatial stride of the middle convolution.
        downsample: Optional module applied to the input before the residual
            addition.
        head_conv: 1 or 3, selecting the kernel of the first convolution.

    Raises:
        ValueError: If ``head_conv`` is neither 1 nor 3.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, head_conv=1):
        super(Bottleneck, self).__init__()

        # Pick the first convolution; batch norm is the same in both cases.
        if head_conv == 1:
            head = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        elif head_conv == 3:
            head = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1),
                             bias=False, padding=(1, 0, 0))
        else:
            raise ValueError("Unsupported head_conv!")
        self.conv1 = head
        self.bn1 = nn.BatchNorm3d(planes)

        # Spatial-only 3x3 convolution; the stride never touches the time axis.
        self.conv2 = nn.Conv3d(planes, planes,
                               kernel_size=(1, 3, 3),
                               stride=(1, stride, stride),
                               padding=(0, 1, 1),
                               bias=False)
        self.bn2 = nn.BatchNorm3d(planes)

        widened = planes * 4
        self.conv3 = nn.Conv3d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut and apply ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut

        return self.relu(y)
|
||
|
||
class SlowFast(nn.Module):
    """Two-pathway SlowFast video classifier.

    The fast pathway sees every 2nd input frame and uses narrow layers
    (8 stem channels). The slow pathway sees every 16th input frame and uses
    wide layers (64 stem channels). After the stem and after each of the first
    three residual stages, fast features pass through a lateral convolution
    with temporal stride 8 and are concatenated onto the slow features along
    the channel axis. Both pathways are globally average-pooled, concatenated,
    passed through dropout and classified by a bias-free linear layer.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute.
            The hard-coded lateral and classifier widths match an expansion
            of 4.
        layers: Four integers giving the number of blocks in each stage.
        class_num: Number of output classes.
        dropout: Dropout probability applied before the classifier.
    """

    def __init__(self, block=Bottleneck, layers=(3, 4, 6, 3), class_num=10, dropout=0.5):
        # `layers` defaults to a tuple rather than a list so the default
        # object cannot be mutated and shared between instances.
        super(SlowFast, self).__init__()

        # ---- fast pathway ----
        self.fast_inplanes = 8
        self.fast_conv1 = nn.Conv3d(3, 8, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False)
        self.fast_bn1 = nn.BatchNorm3d(8)
        self.fast_relu = nn.ReLU(inplace=True)
        self.fast_maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.fast_res2 = self._make_layer_fast(block, 8, layers[0], head_conv=3)
        self.fast_res3 = self._make_layer_fast(block, 16, layers[1], stride=2, head_conv=3)
        self.fast_res4 = self._make_layer_fast(block, 32, layers[2], stride=2, head_conv=3)
        self.fast_res5 = self._make_layer_fast(block, 64, layers[3], stride=2, head_conv=3)

        # ---- lateral connections (fast -> slow) ----
        # Each one doubles the channel count and strides by 8 along time.
        self.lateral_p1 = nn.Conv3d(8, 8 * 2, kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0))
        self.lateral_res2 = nn.Conv3d(32, 32 * 2, kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0))
        self.lateral_res3 = nn.Conv3d(64, 64 * 2, kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0))
        self.lateral_res4 = nn.Conv3d(128, 128 * 2, kernel_size=(5, 1, 1), stride=(8, 1, 1), bias=False, padding=(2, 0, 0))

        # ---- slow pathway ----
        # The first slow stage receives the 64 stem channels plus the 16
        # channels produced by lateral_p1.
        self.slow_inplanes = 64 + 64 // 8 * 2
        self.slow_conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
        self.slow_bn1 = nn.BatchNorm3d(64)
        self.slow_relu = nn.ReLU(inplace=True)
        self.slow_maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.slow_res2 = self._make_layer_slow(block, 64, layers[0], head_conv=1)
        self.slow_res3 = self._make_layer_slow(block, 128, layers[1], stride=2, head_conv=1)
        self.slow_res4 = self._make_layer_slow(block, 256, layers[2], stride=2, head_conv=1)
        self.slow_res5 = self._make_layer_slow(block, 512, layers[3], stride=2, head_conv=1)

        # ---- classifier head ----
        self.dp = nn.Dropout(dropout)
        # Input width = final fast width (tracked in fast_inplanes) + 2048
        # channels from the last slow stage.
        self.fc = nn.Linear(self.fast_inplanes + 2048, class_num, bias=False)

    def forward(self, input):
        """Classify a batch of clips.

        Args:
            input: Tensor indexed as [batch, channels, time, height, width];
                the time axis is subsampled by 2 (fast) and by 16 (slow).

        Returns:
            Tensor of shape [batch, class_num] with unnormalised class scores.
        """
        fast, lateral = self.FastPath(input[:, :, ::2, :, :])
        slow = self.SlowPath(input[:, :, ::16, :, :], lateral)
        x = torch.cat([slow, fast], dim=1)
        x = self.dp(x)
        x = self.fc(x)
        return x

    def SlowPath(self, input, lateral):
        """Run the slow pathway, fusing one lateral feature map before each stage.

        Args:
            input: Slow-rate frames, [batch, 3, time, height, width].
            lateral: The four lateral feature maps returned by ``FastPath``.

        Returns:
            Globally pooled slow features of shape [batch, channels].
        """
        x = self.slow_conv1(input)
        x = self.slow_bn1(x)
        x = self.slow_relu(x)
        x = self.slow_maxpool(x)
        x = torch.cat([x, lateral[0]], dim=1)
        x = self.slow_res2(x)
        x = torch.cat([x, lateral[1]], dim=1)
        x = self.slow_res3(x)
        x = torch.cat([x, lateral[2]], dim=1)
        x = self.slow_res4(x)
        x = torch.cat([x, lateral[3]], dim=1)
        x = self.slow_res5(x)
        # Functional pooling avoids building a new module on every call.
        x = F.adaptive_avg_pool3d(x, 1)
        x = x.view(-1, x.size(1))
        return x

    def FastPath(self, input):
        """Run the fast pathway and collect the lateral feature maps.

        Args:
            input: Fast-rate frames, [batch, 3, time, height, width].

        Returns:
            Tuple ``(features, lateral)``: globally pooled fast features of
            shape [batch, channels], and a list of four lateral tensors taken
            after the stem pooling and after stages res2, res3 and res4.
        """
        lateral = []
        x = self.fast_conv1(input)
        x = self.fast_bn1(x)
        x = self.fast_relu(x)
        pool1 = self.fast_maxpool(x)
        lateral.append(self.lateral_p1(pool1))

        res2 = self.fast_res2(pool1)
        lateral.append(self.lateral_res2(res2))

        res3 = self.fast_res3(res2)
        lateral.append(self.lateral_res3(res3))

        res4 = self.fast_res4(res3)
        lateral.append(self.lateral_res4(res4))

        res5 = self.fast_res5(res4)
        x = F.adaptive_avg_pool3d(res5, 1)
        x = x.view(-1, x.size(1))

        return x, lateral

    def _make_layer_fast(self, block, planes, blocks, stride=1, head_conv=1):
        """Build one fast-pathway stage of ``blocks`` residual blocks.

        Updates ``self.fast_inplanes`` to the stage's output width.
        """
        downsample = None
        # A projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        if stride != 1 or self.fast_inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(
                    self.fast_inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=(1, stride, stride),
                    bias=False),
                nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.fast_inplanes, planes, stride, downsample, head_conv=head_conv))
        self.fast_inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.fast_inplanes, planes, head_conv=head_conv))
        return nn.Sequential(*layers)

    def _make_layer_slow(self, block, planes, blocks, stride=1, head_conv=1):
        """Build one slow-pathway stage of ``blocks`` residual blocks.

        On return ``self.slow_inplanes`` holds the input width of the next
        stage: this stage's output width plus a quarter of it, which is the
        width of the lateral features concatenated before the next stage.
        """
        downsample = None
        # A projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        if stride != 1 or self.slow_inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(
                    self.slow_inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=(1, stride, stride),
                    bias=False),
                nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.slow_inplanes, planes, stride, downsample, head_conv=head_conv))
        # Blocks after the first take the un-fused stage width as input.
        self.slow_inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.slow_inplanes, planes, head_conv=head_conv))

        self.slow_inplanes = planes * block.expansion + planes * block.expansion // 8 * 2
        return nn.Sequential(*layers)
|
||
|
||
|
||
|
||
def resnet50(**kwargs):
    """Build a SlowFast model with ResNet-50 stage depths (3, 4, 6, 3).

    All keyword arguments are forwarded unchanged to ``SlowFast``.
    """
    return SlowFast(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||
|
||
def resnet101(**kwargs):
    """Build a SlowFast model with ResNet-101 stage depths (3, 4, 23, 3).

    All keyword arguments are forwarded unchanged to ``SlowFast``.
    """
    return SlowFast(Bottleneck, [3, 4, 23, 3], **kwargs)
|
||
|
||
def resnet152(**kwargs):
    """Constructs a ResNet-152 model.

    Uses stage depths (3, 8, 36, 3); all keyword arguments are forwarded
    unchanged to ``SlowFast``.
    """
    model = SlowFast(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
|
||
|
||
def resnet200(**kwargs):
    """Constructs a ResNet-200 model.

    Uses stage depths (3, 24, 36, 3); all keyword arguments are forwarded
    unchanged to ``SlowFast``.
    """
    model = SlowFast(Bottleneck, [3, 24, 36, 3], **kwargs)
    return model
|
||
if __name__ == "__main__":
    # Smoke test: push one random 64-frame 224x224 clip through the ResNet-50
    # variant and print the output size.
    num_classes = 101
    # Plain tensors replace the deprecated torch.autograd.Variable wrapper.
    input_tensor = torch.rand(1, 3, 64, 224, 224)
    model = resnet50(class_num=num_classes)
    # No gradients are needed for a shape check, so skip graph construction.
    with torch.no_grad():
        output = model(input_tensor)
    print(output.size())
Oops, something went wrong.