Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
cr committed Jan 14, 2019
0 parents commit 37d6e90
Show file tree
Hide file tree
Showing 6 changed files with 537 additions and 0 deletions.
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SlowFastNetworks
PyTorch implementation of ["SlowFast Networks for Video Recognition"](https://arxiv.org/abs/1812.03982).
## Train
1. Dataset should be orgnized as:
dataset(e.g. UCF-101)
├── train/training
│   ├── ApplyEyeMakeup
│   ├── ApplyLipstick
│   ├── ...
└── validation
│   ├── ApplyEyeMakeup
│   ├── ApplyLipstick
│   ├── ...

2. Modify the params in config.py and `mode` of `train_dataloader` or `val_dataloader` in train.py.

## Requirements
python 3
PyTorch >= 0.4.1
tensorboardX
OpenCV

## Code Reference:
[1] https://github.com/Guocode/SlowFast-Networks/
[2] https://github.com/jfzhang95/pytorch-video-recognition
[3] https://github.com/irhumshafkat/R2Plus1D-PyTorch
20 changes: 20 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
params = dict()

params['num_classes'] = 101

params['dataset'] = '/disk/dataset/UCF-101'

params['epoch_num'] = 40
params['batch_size'] = 16
params['step'] = 10
params['num_workers'] = 4
params['learning_rate'] = 1e-2
params['momentum'] = 0.9
params['weight_decay'] = 1e-5
params['display'] = 10
params['pretrained'] = None
params['gpu'] = [0]
params['log'] = 'log'
params['save_path'] = 'UCF101'
params['clip_len'] = 64
params['frame_sample_rate'] = 1
Empty file added lib/__init__.py
Empty file.
132 changes: 132 additions & 0 deletions lib/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
from pathlib import Path

import cv2
import numpy as np
from torch.utils.data import DataLoader, Dataset


class VideoDataset(Dataset):

def __init__(self, directory, mode='train', clip_len=8, frame_sample_rate=1):
folder = Path(directory)/mode # get the directory of the specified split
print("folder",folder)
self.clip_len = clip_len

self.resize_height = 128
self.resize_width = 171
self.crop_size = 112
self.frame_sample_rate = frame_sample_rate
self.mode = mode


self.fnames, labels = [], []
for label in sorted(os.listdir(folder)):
for fname in os.listdir(os.path.join(folder, label)):
self.fnames.append(os.path.join(folder, label, fname))
labels.append(label)
# prepare a mapping between the label names (strings) and indices (ints)
self.label2index = {label:index for index, label in enumerate(sorted(set(labels)))}
# convert the list of label names into an array of label indices
self.label_array = np.array([self.label2index[label] for label in labels], dtype=int)

with open('labels.txt', 'w') as f:
for id, label in enumerate(sorted(self.label2index)):
f.writelines(str(id + 1) + ' ' + label + '\n')

def __getitem__(self, index):
# loading and preprocessing. TODO move them to transform classes
buffer = self.loadvideo(self.fnames[index])

while buffer.shape[0]<self.clip_len+2 :
index = np.random.randint(self.__len__())
buffer= self.loadvideo(self.fnames[index])

if self.mode == 'train' or self.mode == 'training':
buffer = self.randomflip(buffer)
buffer = self.crop(buffer, self.clip_len, self.crop_size)
buffer = self.normalize(buffer)
buffer = self.to_tensor(buffer)

return buffer, self.label_array[index]

def to_tensor(self, buffer):
# convert from [D, H, W, C] format to [C, D, H, W] (what PyTorch uses)
# D = Depth (in this case, time), H = Height, W = Width, C = Channels
return buffer.transpose((3, 0, 1, 2))

def loadvideo(self, fname):
# initialize a VideoCapture object to read video data into a numpy array
capture = cv2.VideoCapture(fname)
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# create a buffer. Must have dtype float, so it gets converted to a FloatTensor by Pytorch later
frame_count_sample = frame_count // self.frame_sample_rate
buffer = np.empty((frame_count_sample, self.resize_height, self.resize_width, 3), np.dtype('float32'))

count = 0
retaining = True
sample_count = 0

# read in each frame, one at a time into the numpy buffer array
while (count < frame_count and retaining):
retaining, frame = capture.read()
if retaining is False:
break
if count%self.frame_sample_rate == 0 and sample_count < frame_count_sample:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# will resize frames if not already final size

if (frame_height != self.resize_height) or (frame_width != self.resize_width):
frame = cv2.resize(frame, (self.resize_width, self.resize_height))
buffer[sample_count] = frame
sample_count = sample_count + 1
count += 1
capture.release()
return buffer

def crop(self, buffer, clip_len, crop_size):
# randomly select time index for temporal jittering
time_index = np.random.randint(buffer.shape[0] - clip_len)
# Randomly select start indices in order to crop the video
height_index = np.random.randint(buffer.shape[1] - crop_size)
width_index = np.random.randint(buffer.shape[2] - crop_size)

# crop and jitter the video using indexing. The spatial crop is performed on
# the entire array, so each frame is cropped in the same location. The temporal
# jitter takes place via the selection of consecutive frames
buffer = buffer[time_index:time_index + clip_len,
height_index:height_index + crop_size,
width_index:width_index + crop_size, :]

return buffer

def normalize(self, buffer):
# Normalize the buffer
# buffer = (buffer - 128)/128.0
for i, frame in enumerate(buffer):
frame = (frame - np.array([[[128.0, 128.0, 128.0]]]))/128.0
buffer[i] = frame
return buffer

def randomflip(self, buffer):
"""Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
if np.random.random() < 0.5:
for i, frame in enumerate(buffer):
frame = cv2.flip(buffer[i], flipCode=1)
buffer[i] = cv2.flip(frame, flipCode=1)

return buffer

def __len__(self):
return len(self.fnames)


if __name__ == '__main__':

#datapath='/home/cr/workspace/disk/data/VideoRecognition/UCF-101'
train_dataloader = \
DataLoader( VideoDataset(datapathli, mode='train'), batch_size=10, shuffle=True, num_workers=0)
for step, (buffer, label) in enumerate(train_dataloader):
print("label: ", label)
217 changes: 217 additions & 0 deletions lib/slowfastnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

__all__ = ['resnet50', 'resnet101','resnet152', 'resnet200']



class Bottleneck(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None, head_conv=1):
super(Bottleneck, self).__init__()
if head_conv == 1:
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes)
elif head_conv == 3:
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3, 1, 1), bias=False, padding=(1, 0, 0))
self.bn1 = nn.BatchNorm3d(planes)
else:
raise ValueError("Unsupported head_conv!")
self.conv2 = nn.Conv3d(
planes, planes, kernel_size=(1, 3, 3), stride=(1,stride,stride), padding=(0, 1, 1), bias=False)
self.bn2 = nn.BatchNorm3d(planes)
self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm3d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
residual = x

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
out = self.bn3(out)

if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)

return out


class SlowFast(nn.Module):
def __init__(self, block=Bottleneck, layers=[3, 4, 6, 3], class_num=10, dropout=0.5 ):
super(SlowFast, self).__init__()

self.fast_inplanes = 8
self.fast_conv1 = nn.Conv3d(3, 8, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False)
self.fast_bn1 = nn.BatchNorm3d(8)
self.fast_relu = nn.ReLU(inplace=True)
self.fast_maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
self.fast_res2 = self._make_layer_fast(block, 8, layers[0], head_conv=3)
self.fast_res3 = self._make_layer_fast(
block, 16, layers[1], stride=2, head_conv=3)
self.fast_res4 = self._make_layer_fast(
block, 32, layers[2], stride=2, head_conv=3)
self.fast_res5 = self._make_layer_fast(
block, 64, layers[3], stride=2, head_conv=3)

self.lateral_p1 = nn.Conv3d(8, 8*2, kernel_size=(5, 1, 1), stride=(8, 1 ,1), bias=False, padding=(2, 0, 0))
self.lateral_res2 = nn.Conv3d(32,32*2, kernel_size=(5, 1, 1), stride=(8, 1 ,1), bias=False, padding=(2, 0, 0))
self.lateral_res3 = nn.Conv3d(64,64*2, kernel_size=(5, 1, 1), stride=(8, 1 ,1), bias=False, padding=(2, 0, 0))
self.lateral_res4 = nn.Conv3d(128,128*2, kernel_size=(5, 1, 1), stride=(8, 1 ,1), bias=False, padding=(2, 0, 0))

self.slow_inplanes = 64+64//8*2
self.slow_conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
self.slow_bn1 = nn.BatchNorm3d(64)
self.slow_relu = nn.ReLU(inplace=True)
self.slow_maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
self.slow_res2 = self._make_layer_slow(block, 64, layers[0], head_conv=1)
self.slow_res3 = self._make_layer_slow(
block, 128, layers[1], stride=2, head_conv=1)
self.slow_res4 = self._make_layer_slow(
block, 256, layers[2], stride=2, head_conv=1)
self.slow_res5 = self._make_layer_slow(
block, 512, layers[3], stride=2, head_conv=1)
self.dp = nn.Dropout(dropout)
self.fc = nn.Linear(self.fast_inplanes+2048, class_num, bias=False)
def forward(self, input):
fast, lateral = self.FastPath(input[:, :, ::2, :, :])
slow = self.SlowPath(input[:, :, ::16, :, :], lateral)
x = torch.cat([slow, fast], dim=1)
x = self.dp(x)
x = self.fc(x)
return x



def SlowPath(self, input, lateral):
x = self.slow_conv1(input)
x = self.slow_bn1(x)
x = self.slow_relu(x)
x = self.slow_maxpool(x)
x = torch.cat([x, lateral[0]],dim=1)
x = self.slow_res2(x)
x = torch.cat([x, lateral[1]],dim=1)
x = self.slow_res3(x)
x = torch.cat([x, lateral[2]],dim=1)
x = self.slow_res4(x)
x = torch.cat([x, lateral[3]],dim=1)
x = self.slow_res5(x)
x = nn.AdaptiveAvgPool3d(1)(x)
x = x.view(-1, x.size(1))
return x

def FastPath(self, input):
lateral = []
x = self.fast_conv1(input)
x = self.fast_bn1(x)
x = self.fast_relu(x)
pool1 = self.fast_maxpool(x)
lateral_p = self.lateral_p1(pool1)
lateral.append(lateral_p)

res2 = self.fast_res2(pool1)
lateral_res2 = self.lateral_res2(res2)
lateral.append(lateral_res2)

res3 = self.fast_res3(res2)
lateral_res3 = self.lateral_res3(res3)
lateral.append(lateral_res3)

res4 = self.fast_res4(res3)
lateral_res4 = self.lateral_res4(res4)
lateral.append(lateral_res4)

res5 = self.fast_res5(res4)
x = nn.AdaptiveAvgPool3d(1)(res5)
x = x.view(-1, x.size(1))

return x, lateral

def _make_layer_fast(self, block, planes, blocks, stride=1, head_conv=1):
downsample = None
if stride != 1 or self.fast_inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv3d(
self.fast_inplanes,
planes * block.expansion,
kernel_size=1,
stride=(1,stride,stride),
bias=False), nn.BatchNorm3d(planes * block.expansion))

layers = []
layers.append(block(self.fast_inplanes, planes, stride, downsample, head_conv=head_conv))
self.fast_inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.fast_inplanes, planes, head_conv=head_conv))
return nn.Sequential(*layers)

def _make_layer_slow(self, block, planes, blocks, stride=1, head_conv=1):
downsample = None
if stride != 1 or self.slow_inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv3d(
self.slow_inplanes,
planes * block.expansion,
kernel_size=1,
stride=(1,stride,stride),
bias=False), nn.BatchNorm3d(planes * block.expansion))

layers = []
layers.append(block(self.slow_inplanes, planes, stride, downsample, head_conv=head_conv))
self.slow_inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.slow_inplanes, planes, head_conv=head_conv))

self.slow_inplanes = planes * block.expansion + planes * block.expansion//8*2
return nn.Sequential(*layers)




def resnet50(**kwargs):
"""Constructs a ResNet-50 model.
"""
model = SlowFast(Bottleneck, [3, 4, 6, 3], **kwargs)
return model


def resnet101(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = SlowFast(Bottleneck, [3, 4, 23, 3], **kwargs)
return model


def resnet152(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = SlowFast(Bottleneck, [3, 8, 36, 3], **kwargs)
return model


def resnet200(**kwargs):
"""Constructs a ResNet-101 model.
"""
model = SlowFast(Bottleneck, [3, 24, 36, 3], **kwargs)
return model

if __name__ == "__main__":
num_classes = 101
input_tensor = torch.autograd.Variable(torch.rand(1, 3, 64, 224, 224))
model = resnet50(class_num=num_classes)
output = model(input_tensor)
print(output.size())
Loading

0 comments on commit 37d6e90

Please sign in to comment.