Skip to content

Commit

Permalink
Release TimeSformer model
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Jun 28, 2021
1 parent dbedc5f commit 220cd07
Show file tree
Hide file tree
Showing 15 changed files with 1,094 additions and 40 deletions.
143 changes: 143 additions & 0 deletions configs/recognition/timesformer/timesformer_k400_videos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
MODEL: #MODEL field
framework: "Recognizer2D" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/' .
backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/' .
name: "VisionTransformer" #Mandatory, The name of backbone.
pretrained: "data/ViT_base_patch16_224_pretrained.pdparams" #Optional, pretrained model path.
img_size: 224
patch_size: 16
in_channels: 3
embed_dim: 768
depth: 12
num_heads: 12
mlp_ratio: 4
qkv_bias: True
epsilon: 1e-6
seg_num: 8
attention_type: 'divided_space_time'
head:
name: "TimeSformerHead" #Mandatory, indicate the type of head, associate to the 'paddlevideo/modeling/heads'
num_classes: 400 #Optional, the number of classes to be classified.
in_channels: 768 #input channel of the extracted feature.
std: 0.01 #std value in params initialization


DATASET: #DATASET field
batch_size: 1 #Mandatory, batch size
num_workers: 4 #Mandatory, XXX the number of subprocess on each GPU.
train:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
file_path: "/workspace/huangjun12/PaddleProject/PaddleVideo/ppTSM_ACC/Distill/Stage3/E2.r101.Dense/data/k400/train.list" #Mandatory, train data index file path
valid:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
file_path: "/workspace/huangjun12/PaddleProject/PaddleVideo/ppTSM_ACC/Distill/Stage3/E2.r101.Dense/data/k400/val.list" #Mandatory, valid data index file path
test:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevideo/loader/dataset'
file_path: "/workspace/huangjun12/PaddleProject/PaddleVideo/ppTSM_ACC/Distill/Stage3/E2.r101.Dense/data/k400/val.list" #Mandatory, test data index file path

PIPELINE: #PIPELINE field TODO.....
train: #Mandatory, indicate the pipeline to deal with the training data, associate to the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
backend: 'pyav'
mode: 'train'
num_seg: 8
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: False
linspace_sample: True
transform: #Mandatory, image transform operator.
- JitterScale:
min_size: 256
max_size: 320
- RandomCrop:
target_size: 224
- RandomFlip:
- Image2Array:
data_format: 'cthw'
- Normalization:
mean: [0.45, 0.45, 0.45]
std: [0.225, 0.225, 0.225]
tensor_shape: [3, 1, 1, 1]

valid: #Mandatory, indicate the pipeline to deal with the validation data. associate to the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
backend: 'pyav'
mode: 'valid'
num_seg: 8
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: False
linspace_sample: True
transform:
- JitterScale:
min_size: 256
max_size: 320
- RandomCrop:
target_size: 224
- RandomFlip:
- Image2Array:
data_format: 'cthw'
- Normalization:
mean: [0.45, 0.45, 0.45]
std: [0.225, 0.225, 0.225]
tensor_shape: [3, 1, 1, 1]
test:
decode:
name: "VideoDecoder"
backend: 'pyav'
mode: 'test'
num_seg: 8
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
linspace_sample: True
transform:
- JitterScale:
min_size: 224
max_size: 224
- UniformCrop:
target_size: 224
- Image2Array:
data_format: 'cthw'
- Normalization:
mean: [0.45, 0.45, 0.45]
std: [0.225, 0.225, 0.225]
tensor_shape: [3, 1, 1, 1]

OPTIMIZER: #OPTIMIZER field
name: 'Momentum' #Mandatory, the type of optimizer, associate to the 'paddlevideo/solver/'
momentum: 0.9
learning_rate: #Mandatory, the type of learning rate scheduler, associate to the 'paddlevideo/solver/'
learning_rate: 0.005 # 8 cards * 4 batch size
name: 'MultiStepDecay'
milestones: [11, 14]
gamma: 0.1
weight_decay:
name: 'L2'
value: 0.0001
use_nesterov: True

METRIC:
name: 'UniformCropMetric'

GRADIENT_ACCUMULATION:
global_batch_size: 64 # Specify the sum of batches to be calculated by all GPUs
num_gpus: 8 # Number of GPUs

# INFERENCE:
# name: 'ppTSM_Inference_helper'
# num_seg: 8
# target_size: 224

model_name: "TimeSformer"
log_interval: 20 #Optional, the interval of logger, default:10
save_interval: 3
epochs: 15 #Mandatory, total epoch
log_level: "INFO" #Optional, the logger level. default: "INFO"
3 changes: 2 additions & 1 deletion paddlevideo/loader/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .augmentations import (Scale, RandomCrop, CenterCrop, RandomFlip,
Image2Array, Normalization, JitterScale, MultiCrop,
PackOutput, TenCrop)
PackOutput, TenCrop, UniformCrop)

from .compose import Compose
from .decode import VideoDecoder, FrameDecoder
Expand All @@ -40,6 +40,7 @@
'MultiCrop',
'PackOutput',
'TenCrop',
'UniformCrop',
'DecodeSampler',
'LoadFeat',
'GetMatchMap',
Expand Down
65 changes: 59 additions & 6 deletions paddlevideo/loader/pipelines/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,17 @@ def __call__(self, results):
if self.fixed_ratio:
oh = int(self.short_size * 4.0 / 3.0)
else:
oh = int(round(h * self.short_size / w)) if self.do_round else int(h * self.short_size / w)
oh = int(round(h * self.short_size /
w)) if self.do_round else int(
h * self.short_size / w)
else:
oh = self.short_size
if self.fixed_ratio:
ow = int(self.short_size * 4.0 / 3.0)
else:
ow = int(round(w * self.short_size / h)) if self.do_round else int(w * self.short_size / h)
ow = int(round(w * self.short_size /
h)) if self.do_round else int(
w * self.short_size / h)
if self.backend == 'pillow':
resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
else:
Expand Down Expand Up @@ -322,8 +326,12 @@ class Image2Array(object):
Args:
transpose: whether to transpose or not, default True, False for slowfast.
"""
def __init__(self, transpose=True):
def __init__(self, transpose=True, data_format='tchw'):
assert data_format in [
'tchw', 'cthw'
], f"Target format must in ['tchw', 'cthw'], but got {data_format}"
self.transpose = transpose
self.data_format = data_format

def __call__(self, results):
"""
Expand All @@ -337,7 +345,10 @@ def __call__(self, results):
imgs = results['imgs']
np_imgs = (np.stack(imgs)).astype('float32')
if self.transpose:
np_imgs = np_imgs.transpose(0, 3, 1, 2) # nchw
if self.data_format == 'tchw':
np_imgs = np_imgs.transpose(0, 3, 1, 2) # tchw
else:
np_imgs = np_imgs.transpose(3, 0, 1, 2) # cthw
results['imgs'] = np_imgs
return results

Expand Down Expand Up @@ -617,8 +628,9 @@ def __call__(self, results):
img_crops = list()
for x_offset, y_offset in offsets:
crop = [
img.crop((x_offset, y_offset, x_offset + crop_w,
y_offset + crop_h)) for img in imgs
img.crop(
(x_offset, y_offset, x_offset + crop_w, y_offset + crop_h))
for img in imgs
]
crop_fliped = [
timg.transpose(Image.FLIP_LEFT_RIGHT) for timg in crop
Expand All @@ -628,3 +640,44 @@ def __call__(self, results):

results['imgs'] = img_crops
return results


@PIPELINES.register()
class UniformCrop:
    """
    Perform uniform spatial sampling on the images.
    Three crops are taken along the longer side of the frames (at the start,
    the middle and the end of that side). Every frame is cropped at each of
    the three positions; no flipping is applied.
    Args:
        target_size(int | tuple[int]): (w, h) of target size for crop. A
            single value is used for both the width and the height.
    """
    def __init__(self, target_size):
        if isinstance(target_size, (tuple, list)):
            # The docstring advertises a (w, h) pair; wrapping it again as
            # (target_size, target_size) would break the unpacking in
            # __call__, so store the pair as-is.
            if len(target_size) != 2:
                raise ValueError(
                    f"target_size must be an int or a (w, h) pair, "
                    f"but got {target_size}")
            self.target_size = tuple(target_size)
        else:
            self.target_size = (target_size, target_size)

    def __call__(self, results):
        """
        Crop every image at three uniformly spaced positions.
        Args:
            results(dict): must contain 'imgs', a non-empty list of image
                objects exposing ``size`` as (w, h) and ``crop(box)``
                (PIL-style images are presumably expected -- confirm
                against the decoder backend).
        Returns:
            dict: the same ``results`` object, with 'imgs' replaced by a list
                of 3 * len(imgs) crops ordered position by position:
                [I0_first, ..., IT_first, I0_mid, ..., IT_mid,
                 I0_last, ..., IT_last].
        """
        imgs = results['imgs']
        # All frames are assumed to share the size of the first one.
        img_w, img_h = imgs[0].size
        crop_w, crop_h = self.target_size
        if img_h > img_w:
            # Portrait frames: slide the window vertically (top/middle/bottom).
            offsets = [
                (0, 0),
                (0, (img_h - crop_h + 1) // 2),  # ceil
                (0, img_h - crop_h),
            ]
        else:
            # Landscape or square frames: slide horizontally
            # (left/middle/right).
            offsets = [
                (0, 0),
                ((img_w - crop_w + 1) // 2, 0),  # ceil
                (img_w - crop_w, 0),
            ]
        img_crops = []
        for x_offset, y_offset in offsets:
            box = (x_offset, y_offset, x_offset + crop_w, y_offset + crop_h)
            img_crops.extend(img.crop(box) for img in imgs)
        results['imgs'] = img_crops
        return results
Loading

0 comments on commit 220cd07

Please sign in to comment.