Skip to content

Latest commit

 

History

History
108 lines (64 loc) · 4.59 KB

File metadata and controls

108 lines (64 loc) · 4.59 KB

简体中文 | English

TSN

简介

Temporal Segment Network (TSN) 是视频分类领域经典的基于2D-CNN的解决方案。该方法主要解决视频的长时间行为判断问题,通过稀疏采样视频帧的方式代替稠密采样,既能捕获视频全局信息,也能去除冗余,降低计算量。最终将每帧特征平均融合后得到视频的整体特征,并用于分类。本代码实现的模型为基于单路RGB图像的TSN网络结构,Backbone采用ResNet-50结构。

详细内容请参考ECCV 2016年论文Temporal Segment Networks: Towards Good Practices for Deep Action Recognition

数据准备

PaddleVide提供了在K400和UCF101两种数据集上训练TSN的训练脚本。

K400数据下载及准备请参考K400数据准备

UCF101数据下载及准备请参考UCF101数据准备

模型训练

  • 加载在ImageNet1000上训练好的ResNet50权重作为Backbone初始化参数,请下载此模型参数, 或是通过命令行下载
wget https://videotag.bj.bcebos.com/PaddleVideo/PretrainModel/ResNet50_pretrain.pdparams

并将路径添加到configs中backbone字段下

MODEL:
framework: "Recognizer2D"
    backbone:
        name: "ResNet"
        pretrained: 将路径填写到此处

或用-o 参数在run.sh或命令行中进行添加 -o MODEL.framework.backbone.pretrained="" -o 参数用法请参考conifg

  • 如若进行finetune或者模型测试等,请下载PaddleVideo的已发布模型modelcoming soon, 通过--weights指定权重存放路径。 --weights 参数用法请参考config

K400 video格式训练

K400 frames格式训练

UCF101 video格式训练

UCF101 frames格式训练

实现细节

数据处理: 模型读取Kinetics-400数据集中的mp4数据,每条数据抽取seg_num段,每段抽取1帧图像,对每帧图像做随机增强后,缩放至target_size

训练策略:

  • 采用Momentum优化算法训练,momentum值设定为0.9。
  • l2_decay权重衰减系数为1e-4。
  • 学习率在训练的总epoch数的15(1/3)和30(2/3)时分别做0.1倍的衰减。

参数初始化

TSN模型的卷积层和BN层参数采用Paddle默认的KaimingNormalConstant初始化方法。而实际,在真正训练过程中,指定了pretrained参数后会以保存的权重进行参数初始化。 源代码可参考TSN的参数初始化

Linear(FC)层的参数采用mean=0,std默认0.01的Normal初始化,关于Normal初始化方法可以参考初始化官方文档

模型测试

TSN采用CenterCrop的测试Mertics

python3 main.py --test --weights=""
  • 指定--weights参数为下载已发布模型进行模型测试。

当取如下参数时,在Kinetics400的validation数据集下评估精度如下:

seg_num target_size Top-1
3 224 0.66
7 224 0.67

模型推理

首先导出模型,这里加载默认路径为output/TSN下的参数。并将预测模型导出至inference下。

    python3.7 tools/export_model.py -c configs/recognition/tsn/tsn.yaml -p output/TSN/TSN_best.pdparams -o ./inference

之后,进行模型推理

python3 tools/predict.py -v data/example.avi --model_file "./inference/TSN.pdmodel" --params_file "./inference/TSN.pdiparams" --enable_benchmark=False --model="TSN"

更多关于预测部署功能介绍请参考[../../tutorials/deployment.md]

参考论文