Skip to content

Commit

Permalink
add video support for tsm and pptsm
Browse files Browse the repository at this point in the history
  • Loading branch information
huangjun12 committed May 24, 2021
1 parent 8e3e132 commit 71329cf
Show file tree
Hide file tree
Showing 7 changed files with 402 additions and 30 deletions.
120 changes: 120 additions & 0 deletions configs/recognition/pptsm/pptsm_k400_videos_uniform.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
MODEL: #MODEL field
framework: "Recognizer2D" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/' .
backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/' .
name: "ResNetTweaksTSM" #Mandatory, The name of backbone.
pretrained: "data/ResNet50_vd_ssld_v2_pretrained.pdparams" #Optional, pretrained model path.
depth: 50 #Optional, the depth of backbone architecture.
head:
name: "ppTSMHead" #Mandatory, indicate the type of head, associate to the 'paddlevideo/modeling/heads'
num_classes: 400 #101 #Optional, the number of classes to be classified.
in_channels: 2048 #input channel of the extracted feature.
drop_ratio: 0.5 #the ratio of dropout
std: 0.01 #std value in params initialization
ls_eps: 0.1

DATASET: #DATASET field
batch_size: 16 #Mandatory, batch size
num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
test_batch_size: 1
train:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
file_path: "data/k400/train.list" #Mandatory, train data index file path
valid:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
file_path: "data/k400/val.list" #Mandatory, valid data index file path
test:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
file_path: "data/k400/val.list" #Mandatory, test data index file path

PIPELINE: #PIPELINE field
train: #Mandatory, indicate the pipeline to deal with the training data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: False
transform: #Mandatory, image transform operator
- Scale:
short_size: 256
- MultiScaleCrop:
target_size: 256
- RandomCrop:
target_size: 224
- RandomFlip:
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
valid: #Mandatory, indicate the pipeline to deal with the validation data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
transform:
- Scale:
short_size: 256
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
test: #Mandatory, indicate the pipeline to deal with the test data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
transform:
- Scale:
short_size: 256
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

OPTIMIZER: #OPTIMIZER field
name: 'Momentum'
momentum: 0.9
learning_rate:
iter_step: True
name: 'CustomWarmupCosineDecay'
max_epoch: 80
warmup_epochs: 10
warmup_start_lr: 0.005
cosine_base_lr: 0.01
weight_decay:
name: 'L2'
value: 1e-4
use_nesterov: True

MIX:
name: "Mixup"
alpha: 0.2

PRECISEBN:
preciseBN_interval: 5 # epoch interval to do preciseBN, default 1.
num_iters_preciseBN: 200 # how many batches used to do preciseBN, default 200.


METRIC:
name: 'CenterCropMetric'

INFERENCE:
name: 'ppTSM_Inference_helper'
num_seg: 8
target_size: 224

model_name: "ppTSM"
log_interval: 10 #Optional, the interval of logger, default:10
epochs: 80 #Mandatory, total epoch
log_level: "INFO" #Optional, the logger level. default: "INFO"
117 changes: 117 additions & 0 deletions configs/recognition/tsm/tsm_k400_videos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
MODEL: #MODEL field
framework: "Recognizer2D" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/' .
backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/' .
name: "ResNetTSM" #Mandatory, The name of backbone.
pretrained: "data/ResNet50_pretrain.pdparams" #Optional, pretrained model path.
num_seg: 8
depth: 50 #Optional, the depth of backbone architecture.
head:
name: "TSMHead" #Mandatory, indicate the type of head, associate to the 'paddlevideo/modeling/heads'
num_classes: 400 #Optional, the number of classes to be classified.
in_channels: 2048 #input channel of the extracted feature.
drop_ratio: 0.5 #the ratio of dropout
std: 0.001 #std value in params initialization


DATASET: #DATASET field
batch_size: 16 #Mandatory, batch size
num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
train:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
file_path: "data/k400/train.list" #Mandatory, train data index file path
valid:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
file_path: "data/k400/val.list" #Mandatory, valid data index file path
test:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
file_path: "data/k400/val.list" #Mandatory, test data index file path


PIPELINE: #PIPELINE field
train: #Mandatory, indicate the pipeline to deal with the training data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: False
select_left: True
transform: #Mandatory, image transform operator.
- MultiScaleCrop:
target_size: 224
allow_duplication: True
- RandomFlip:
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

valid: #Mandatory, indicate the pipeline to deal with the validation data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
select_left: True
transform:
- Scale:
short_size: 256
fixed_ratio: False
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

test:
decode:
name: "VideoDecoder"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
select_left: True
transform:
- Scale:
short_size: 256
fixed_ratio: False
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

OPTIMIZER: #OPTIMIZER field
name: 'Momentum' #Mandatory, the type of optimizer, associate to the 'paddlevideo/solver/'
momentum: 0.9
learning_rate: #Mandatory, the type of learning rate scheduler, associate to the 'paddlevideo/solver/'
name: 'PiecewiseDecay'
boundaries: [20, 40]
values: [0.02, 0.002, 0.0002] #8 cards * 16 batch size
weight_decay:
name: 'L2'
value: 0.0001
grad_clip:
name: 'ClipGradByGlobalNorm'
value: 20.0


METRIC:
name: 'CenterCropMetric'

INFERENCE:
name: 'ppTSM_Inference_helper'
num_seg: 8
target_size: 224

model_name: "TSM"
log_interval: 20 #Optional, the interval of logger, default:10
save_interval: 10
epochs: 50 #Mandatory, total epoch
log_level: "INFO" #Optional, the logger level. default: "INFO"
120 changes: 120 additions & 0 deletions configs/recognition/tsm/tsm_ucf101_videos.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
MODEL: #MODEL field
framework: "Recognizer2D" #Mandatory, indicate the type of network, associate to the 'paddlevideo/modeling/framework/' .
backbone: #Mandatory, indicate the type of backbone, associate to the 'paddlevideo/modeling/backbones/' .
name: "ResNetTSM" #Mandatory, The name of backbone.
pretrained: "data/TSM_k400.pdparams" #Optional, pretrained model path.
num_seg: 8
depth: 50 #Optional, the depth of backbone architecture.
head:
name: "TSMHead" #Mandatory, indicate the type of head, associate to the 'paddlevideo/modeling/heads'
num_classes: 101 #Optional, the number of classes to be classified.
in_channels: 2048 #input channel of the extracted feature.
drop_ratio: 0.8 #the ratio of dropout
std: 0.001 #std value in params initialization


DATASET: #DATASET field
batch_size: 16 #Mandatory, batch size
num_workers: 4 #Mandatory, the number of subprocesses on each GPU.
train:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
data_prefix: "" #Mandatory, train data root path
file_path: "data/ucf101/ucf101_train_split_1_videos.txt" #Mandatory, train data index file path
suffix: '.avi'
valid:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
data_prefix: "" #Mandatory, valid data root path
file_path: "data/ucf101/ucf101_val_split_1_videos.txt" #Mandatory, valid data index file path
suffix: '.avi'
test:
format: "VideoDataset" #Mandatory, indicate the type of dataset, associated with the 'paddlevideo/loader/dataset'
data_prefix: "" #Mandatory, test data root path
file_path: "data/ucf101/ucf101_val_split_1_videos.txt" #Mandatory, test data index file path
suffix: '.avi'


PIPELINE: #PIPELINE field
train: #Mandatory, indicate the pipeline to deal with the training data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
backend: "cv2"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: False
select_left: True
transform: #Mandatory, image transform operator.
- MultiScaleCrop:
target_size: 224
allow_duplication: True
- RandomFlip:
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

valid: #Mandatory, indicate the pipeline to deal with the validation data, associated with the 'paddlevideo/loader/pipelines/'
decode:
name: "VideoDecoder"
backend: "cv2"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
select_left: True
transform:
- Scale:
short_size: 256
fixed_ratio: False
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

test:
decode:
name: "VideoDecoder"
backend: "cv2"
sample:
name: "Sampler"
num_seg: 8
seg_len: 1
valid_mode: True
select_left: True
transform:
- Scale:
short_size: 256
fixed_ratio: False
- CenterCrop:
target_size: 224
- Image2Array:
- Normalization:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]


OPTIMIZER: #OPTIMIZER field
name: 'Momentum' #Mandatory, the type of optimizer, associate to the 'paddlevideo/solver/'
momentum: 0.9
learning_rate: #Mandatory, the type of learning rate scheduler, associate to the 'paddlevideo/solver/'
name: 'PiecewiseDecay'
boundaries: [10, 20]
values: [0.001, 0.0001, 0.00001] #4 cards * 16 batch size
grad_clip:
name: 'ClipGradByGlobalNorm'
value: 20.0


METRIC:
name: 'CenterCropMetric'


model_name: "TSM"
log_interval: 20 #Optional, the interval of logger, default:10
save_interval: 10
epochs: 25 #Mandatory, total epoch
log_level: "INFO" #Optional, the logger level. default: "INFO"
2 changes: 1 addition & 1 deletion paddlevideo/loader/dataset/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def prepare_test(self, idx):
results = copy.deepcopy(self.info[idx])
results = self.pipeline(results)
except Exception as e:
logger.info(e)
#logger.info(e)
if ir < self.num_retries - 1:
logger.info(
"Error when loading {}, have {} trys, will try again".
Expand Down
7 changes: 4 additions & 3 deletions paddlevideo/loader/dataset/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ class VideoDataset(BaseDataset):
pipeline(XXX): A sequence of data transforms.
**kwargs: Keyword arguments for ```BaseDataset```.
"""
def __init__(self, file_path, pipeline, num_retries=5, **kwargs):
def __init__(self, file_path, pipeline, num_retries=5, suffix='', **kwargs):
self.num_retries = num_retries
self.suffix = suffix
super().__init__(file_path, pipeline, **kwargs)

def load_file(self):
Expand All @@ -53,7 +54,7 @@ def load_file(self):
line_split = line.strip().split()
filename, labels = line_split
#TODO(hj): Required suffix format: may mp4/avi/wmv
filename = filename + '.avi'
filename = filename + self.suffix
if self.data_prefix is not None:
filename = osp.join(self.data_prefix, filename)
info.append(dict(filename=filename, labels=int(labels)))
Expand All @@ -67,7 +68,7 @@ def prepare_train(self, idx):
results = copy.deepcopy(self.info[idx])
results = self.pipeline(results)
except Exception as e:
logger.info(e)
#logger.info(e)
if ir < self.num_retries - 1:
logger.info(
"Error when loading {}, have {} trys, will try again".
Expand Down
Loading

0 comments on commit 71329cf

Please sign in to comment.