Skip to content

Commit

Permalink
support tsp features; add visualization code
Browse files Browse the repository at this point in the history
  • Loading branch information
ttengwang committed Nov 18, 2021
1 parent c7a1dae commit 9789580
Show file tree
Hide file tree
Showing 85 changed files with 69,778 additions and 92 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ data/anet/features/c3d
data/anet/features/resnet_bn
data/yc2/features/resnet_bn
data/densevid_eval3
*.tmp

*.Ink
.idea/
Expand Down
103 changes: 66 additions & 37 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,33 +1,22 @@
# PDVC
Official implementation for End-to-End Dense Video Captioning with Parallel Decoding (ICCV 2021) [[arxiv](https://arxiv.org/abs/2108.07781)].
Official implementation for End-to-End Dense Video Captioning with Parallel Decoding (ICCV 2021). [[paper]](https://arxiv.org/abs/2108.07781) [[code]](https://github.com/ttengwang/PDVC)

![pdvc.jpg](pdvc.jpg)
This repo supports:
* two video captioning task: dense video captioning and video paragraph captioning
* two datasets: ActivityNet Captions and YouCook2
* video features containing C3D, TSN, and TSP.
* visualization of the generated captions of your own videos

Dense video captioning aims to generate multiple associated captions with their temporal locations from the video. Previous methods follow a sophisticated "localize-then-describe" scheme, which heavily relies on numerous hand-crafted components. In this paper, we proposed a simple yet effective framework for end-to-end dense video captioning with parallel decoding (PDVC), by formulating the dense caption generation as a set prediction task. In practice, through stacking a newly proposed event counter on the top of a transformer decoder, the PDVC precisely segments the video into a number of event pieces under the holistic understanding of the video content, which effectively increases the coherence and readability of predicted captions. Compared with prior arts, the PDVC has several appealing advantages: (1) Without relying on heuristic non-maximum suppression or a recurrent event sequence selection network to remove redundancy, PDVC directly produces an event set with an appropriate size; (2) In contrast to adopting the two-stage scheme, we feed the enhanced representations of event queries into the localization head and caption head in parallel, making these two sub-tasks deeply interrelated and mutually promoted through the optimization; (3) Without bells and whistles, extensive experiments on ActivityNet Captions and YouCook2 show that PDVC is capable of producing high-quality captioning results, surpassing the state-of-the-art two-stage methods when its localization accuracy is on par with them.
# Updates
- (2021.11.19) **add code for running PDVC on raw videos and visualize the generated captions (support Chinese and other non-English languages)**
- (2021.11.19) add pretrained models with TSP features. It achieves 9.06 METEOR(2021) and 5.84 SODA_c, a very competitive results on ActivityNet Captions without self-critical training.
- (2021.08.29) add TSN pretrained models and support YouCook2

# Introduction
PDVC is a simple yet effective framework for end-to-end dense video captioning with parallel decoding (PDVC), by formulating the dense caption generation as a set prediction task. Without bells and whistles, extensive experiments on ActivityNet Captions and YouCook2 show that PDVC is capable of producing high-quality captioning results, surpassing the state-of-the-art methods when its localization accuracy is on par with them.

# Performance
### Dense video captioning

| Model | Features | config_path | Url | Recall | Precision | BLEU4 | METEOR2018 | METEOR2021 | CIDEr | SODA_c |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| PDVC_light | C3D | cfgs/anet_c3d_pdvcl.yml | [Google Drive](https://drive.google.com/drive/folders/1JKOJrm5QMAkso-VJnzGnksIVqNYt8BSI?usp=sharing) | 55.30 | 58.42 | 1.55 | 7.13 | 7.66 | 24.80 | 5.23 |
| PDVC_light | TSN | cfgs/anet_tsn_pdvcl.yml | [Google Drive](https://drive.google.com/drive/folders/1hImJ7sXABzS-ycErruLFCE_pkWEHzFSV?usp=sharing) | 55.34 | 57.97 | 1.66 | 7.41 | 7.97 | 27.23 | 5.51 |
| PDVC | C3D | cfgs/anet_c3d_pdvc.yml | [Google Drive](https://drive.google.com/drive/folders/1I77miVvThdMenmprgozfRsXDVoc-9TxY?usp=sharing) | 55.20 | 57.36 | 1.82 | 7.48 | 8.09 | 28.16 | 5.47 |
| PDVC | TSN | cfgs/anet_tsn_pdvc.yml | [Google Drive](https://drive.google.com/drive/folders/1v2Xj0Qjt3Te_SgVyySKEofRaZsSw_rjs?usp=sharing) | 56.21 | 57.46 | 1.92 | 8.00 | 8.63 | 29.00 | 5.68 |

Notes:
* In the paper, we follow the most previous methods to use the [evaluation toolkit in ActivityNet Challenge 2018](https://github.com/ranjaykrishna/densevid_eval/tree/deba7d7e83012b218a4df888f6c971e21cfeea33). Note that the latest [evluation tookit](https://github.com/ranjaykrishna/densevid_eval/tree/9d4045aced3d827834a5d2da3c9f0692e3f33c1c) (METEOR2021) gives the same CIDEr/BLEU4 but a higher METEOR score.
* In the paper, we use an [old version of SODA_c implementation](https://github.com/fujiso/SODA/tree/22671b3570e088217139bcb1e4de7a3499c30294), while here we use an [updated version](https://github.com/fujiso/SODA/tree/9cb3e2c5a73c4e320a38c72f320b63bbef4aa798) for convenience.

### Video paragraph captioning
| Model | Features | config_path | BLEU4 | METEOR | CIDEr |
| ---- | ---- | ---- | ---- | ---- | ---- |
| PDVC | C3D | cfgs/anet_c3d_pdvc.yml | 9.67 | 14.74 | 16.43 |
| PDVC | TSN | cfgs/anet_tsn_pdvc.yml | 10.18 | 15.96 | 20.66 |

Notes:
* Paragraph-level scores are evaluated on the ActivityNet Entity ae-val set.
![pdvc.jpg](pdvc.jpg)

# Preparation
Environment: Linux, GCC>=5.4, CUDA >= 9.2, Python>=3.7, PyTorch>=1.5.1
Expand All @@ -52,18 +41,31 @@ cd data/anet/features
bash download_anet_c3d.sh
# bash download_anet_tsn.sh
# bash download_i3d_vggish_features.sh

# bash download_tsp_features.sh
```

4. Compile the deformable attention layer (requires GCC >= 5.4).
```bash
cd models/ops
cd pdvc/ops
sh make.sh
```

# Running PDVC on Your Own Videos
Download a pretrained model ([GoogleDrive](https://drive.google.com/drive/folders/1sX5wTk1hBgR_a5YUzpxFCrzwkZQXiIab?usp=sharing)) with [TSP](https://github.com/HumamAlwassel/TSP) features and put it into `./save`. Then run:
```
video_folder=visualization/videos # path to video folder (only support mp4 files)
output_folder=visualization/videos # path to save results
pdvc_model_path=save/anet_tsp_pdvc/model-best.pth # path of pretrained model
output_language=en # 'zh-cn' for chinese (simplied), for other language, find the abbreviation of your language at https://github.com/lushan88a/google_trans_new/blob/main/constant.py
bash test_and_visualize.sh $video_folder $output_folder $pdvc_model_path $pdvc_model_path $output_language # to generate new captioning
```
check the `$output_folder`, you will see a new video with embedded captions. Not that we geneate non-English Captions by translating the English captions by GoogleTranslater.

![demo.gif](visualization/xukun.gif)

# Usage
### Dense Video Captioning
1. PDVC with learnt proposal
### Dense Video Captioning Task
1. PDVC with learnt proposals
```
# Training
config_path=cfgs/anet_c3d_pdvc.yml
Expand All @@ -87,9 +89,9 @@ python eval.py --eval_folder ${eval_folder} --eval_transformer_input_type gt_pro
```


### Video Paragraph Captioning
### Video Paragraph Captioning Task

1. PDVC with learnt proposal
1. PDVC with learnt proposals
```bash
# Training
config_path=cfgs/anet_c3d_pdvc.yml
Expand All @@ -99,7 +101,7 @@ python train.py --cfg_path ${config_path} --criteria_for_best_ckpt pc --gpu_id $
eval_folder=anet_c3d_pdvc # specify the folder to be evaluated
python eval.py --eval_folder ${eval_folder} --eval_transformer_input_type queries --gpu_id ${GPU_ID}
```
2. PDVC with ground-truth proposal
2. PDVC with ground-truth proposals
```
# Training
config_path=cfgs/anet_c3d_pdvc.yml
Expand All @@ -110,18 +112,45 @@ eval_folder=anet_c3d_pdvc_gt
python eval.py --eval_folder ${eval_folder} --eval_transformer_input_type gt_proposals --gpu_id ${GPU_ID}
```

# TODO
- [x] add more pretrained models
- [x] support youcook2
# Performance
### Dense video captioning

| Model | Features | config_path | Url | Recall | Precision | BLEU4 | METEOR2018 | METEOR2021 | CIDEr | SODA_c |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| PDVC_light | C3D | cfgs/anet_c3d_pdvcl.yml | [Google Drive](https://drive.google.com/drive/folders/1JKOJrm5QMAkso-VJnzGnksIVqNYt8BSI?usp=sharing) | 55.30 | 58.42 | 1.55 | 7.13 | 7.66 | 24.80 | 5.23 |
| PDVC | C3D | cfgs/anet_c3d_pdvc.yml | [Google Drive](https://drive.google.com/drive/folders/1I77miVvThdMenmprgozfRsXDVoc-9TxY?usp=sharing) | 55.20 | 57.36 | 1.82 | 7.48 | 8.09 | 28.16 | 5.47 |
| PDVC_light | TSN | cfgs/anet_tsn_pdvcl.yml | [Google Drive](https://drive.google.com/drive/folders/1hImJ7sXABzS-ycErruLFCE_pkWEHzFSV?usp=sharing) | 55.34 | 57.97 | 1.66 | 7.41 | 7.97 | 27.23 | 5.51 |
| PDVC | TSN | cfgs/anet_tsn_pdvc.yml | [Google Drive](https://drive.google.com/drive/folders/1v2Xj0Qjt3Te_SgVyySKEofRaZsSw_rjs?usp=sharing) | 56.21 | 57.46 | 1.92 | 8.00 | 8.63 | 29.00 | 5.68 |
| PDVC_light | TSP | cfgs/anet_tsn_pdvcl.yml | [Google Drive](https://drive.google.com/drive/folders/1Ei8lnBs9Nn2SsFVd7WGe2iJERo46izv8?usp=sharing) | 55.24 | 57.78 | 1.77 | 7.94 | 8.55 | 28.25 | 5.95 |
| PDVC | TSP | cfgs/anet_tsn_pdvc.yml | [Google Drive](https://drive.google.com/drive/folders/1sX5wTk1hBgR_a5YUzpxFCrzwkZQXiIab?usp=sharing) | 55.22 | 57.17 | 2.21 | 8.42 | 9.06 | 30.35 | 5.84 |


Notes:
* In the paper, we follow the most previous methods to use the [evaluation toolkit in ActivityNet Challenge 2018](https://github.com/ranjaykrishna/densevid_eval/tree/deba7d7e83012b218a4df888f6c971e21cfeea33). Note that the latest [evluation tookit](https://github.com/ranjaykrishna/densevid_eval/tree/9d4045aced3d827834a5d2da3c9f0692e3f33c1c) (METEOR2021) gives the same CIDEr/BLEU4 but a higher METEOR score.
* In the paper, we use an [old version of SODA_c implementation](https://github.com/fujiso/SODA/tree/22671b3570e088217139bcb1e4de7a3499c30294), while here we use an [updated version](https://github.com/fujiso/SODA/tree/9cb3e2c5a73c4e320a38c72f320b63bbef4aa798) for convenience.

### Video paragraph captioning
| Model | Features | config_path | BLEU4 | METEOR | CIDEr |
| ---- | ---- | ---- | ---- | ---- | ---- |
| PDVC | C3D | cfgs/anet_c3d_pdvc.yml | 9.67 | 14.74 | 16.43 |
| PDVC | TSN | cfgs/anet_tsn_pdvc.yml | 10.18 | 15.96 | 20.66 |
| PDVC | TSP | cfgs/anet_tsn_pdvc.yml | 10.09 | 16.43 | 19.41 |
Notes:
* Paragraph-level scores are evaluated on the ActivityNet Entity ae-val set.




# Citation
If you find this repo helpful, please consider citing:
```
@article{wang2021end,
@inproceedings{wang2021end,
title={End-to-End Dense Video Captioning with Parallel Decoding},
author={Wang, Teng and Zhang, Ruimao and Lu, Zhichao and Zheng, Feng and Cheng, Ran and Luo, Ping},
journal={arXiv preprint},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={6847--6857},
year={2021}
}
```
```
@ARTICLE{wang2021echr,
Expand All @@ -138,5 +167,5 @@ If you find this repo helpful, please consider citing:
# Acknowledgement

The implementation of Deformable Transformer is mainly based on [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR).
The implementation of the captioning head is based on [ImageCaptioning.pyotrch](https://github.com/ruotianluo/ImageCaptioning.pytorch).
The implementation of the captioning head is based on [ImageCaptioning.pytorch](https://github.com/ruotianluo/ImageCaptioning.pytorch).
We thanks the authors for their efforts.
6 changes: 6 additions & 0 deletions cfgs/anet_tsp_pdvc.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
id: anet_tsp_pdvc
base_cfg_path: cfgs/anet_c3d_pdvc.yml
visual_feature_type: ['tsp']
visual_feature_folder: ['data/anet/features/tsp']
invalid_video_json: []
feature_dim: 512
6 changes: 6 additions & 0 deletions cfgs/anet_tsp_pdvcl.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
id: anet_tsp_pdvcl
base_cfg_path: cfgs/anet_c3d_pdvcl.yml
visual_feature_type: ['tsp']
visual_feature_folder: ['data/anet/features/tsp']
invalid_video_json: []
feature_dim: 512
5 changes: 5 additions & 0 deletions data/anet/features/download_tsp_features.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# TSP features from https://github.com/HumamAlwassel/TSP
# download the following files and reformat them into data/features/tsp/VIDEO_ID.npy where VIDEO_ID starts with 'v_'
wget https://github.com/HumamAlwassel/TSP/releases/download/activitynet_features/r2plus1d_34-tsp_on_activitynet-train_features.h5
wget https://github.com/HumamAlwassel/TSP/releases/download/activitynet_features/r2plus1d_34-tsp_on_activitynet-valid_features.h5
wget https://github.com/HumamAlwassel/TSP/releases/download/activitynet_features/r2plus1d_34-tsp_on_activitynet-test_features.h5
81 changes: 56 additions & 25 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,35 @@
# print(sys.path)

from eval_utils import evaluate
from models.pdvc import build
from pdvc.pdvc import build
from misc.utils import create_logger
from data.video_dataset import PropSeqDataset, collate_fn
from torch.utils.data import DataLoader

from os.path import basename
import pandas as pd

def create_fake_test_caption_file(metadata_csv_path):
out = {}
df = pd.read_csv(metadata_csv_path)
for i, row in df.iterrows():
out[basename(row['filename']).split('.')[0]] = {'duration': row['video-duration'], "timestamps": [[0, 0.5]], "sentences":["None"]}
fake_test_json = '.fake_test_json.tmp'
json.dump(out, open(fake_test_json, 'w'))
return fake_test_json

def main(opt):
folder_path = os.path.join(opt.eval_save_dir, opt.eval_folder)
infos_path = os.path.join(folder_path, 'info.json')
if opt.eval_mode == 'test':
if not os.path.exists(folder_path):
os.makedirs(folder_path)
logger = create_logger(folder_path, 'val.log')
if opt.eval_model_path:
model_path = opt.eval_model_path
infos_path = os.path.join('/'.join(opt.eval_model_path.split('/')[:-1]), 'info.json')
else:
model_path = os.path.join(folder_path, 'model-best.pth')
infos_path = os.path.join(folder_path, 'info.json')

logger.info(vars(opt))

with open(infos_path, 'rb') as f:
Expand All @@ -43,6 +62,11 @@ def main(opt):
if not torch.cuda.is_available():
opt.nthreads = 0
# Create the Data Loader instance

if opt.eval_mode == 'test':
opt.eval_caption_file = create_fake_test_caption_file(opt.test_video_meta_data_csv_path)
opt.visual_feature_folder = opt.test_video_feature_folder

val_dataset = PropSeqDataset(opt.eval_caption_file,
opt.visual_feature_folder,
opt.dict_file, False, opt.eval_proposal_type,
Expand All @@ -54,58 +78,65 @@ def main(opt):
model, criterion, postprocessors = build(opt)
model.translator = val_dataset.translator

if opt.eval_model_path:
model_path = opt.eval_model_path
else:
model_path = os.path.join(folder_path, 'model-best.pth')


while not os.path.exists(model_path):
raise AssertionError('File {} does not exist'.format(model_path))

logger.debug('Loading model from {}'.format(model_path))
loaded_pth = torch.load(model_path, map_location=opt.device)
loaded_pth = torch.load(model_path, map_location=opt.eval_device)
epoch = loaded_pth['epoch']

# loaded_pth = transfer(model, loaded_pth, model_path+'.transfer.pth')
model.load_state_dict(loaded_pth['model'], strict=True)
model.eval()

model.to(opt.device)
model.to(opt.eval_device)

out_json_path = os.path.join(folder_path, '{}_epoch{}_num{}_alpha{}.json'.format(
time.strftime("%Y-%m-%d-%H-%M-%S_", time.localtime()) + str(opt.id), epoch, len(loader.dataset), opt.ec_alpha))
logger.info('saving reults json to {}'.format(out_json_path))
caption_scores,eval_loss = evaluate(model, criterion, postprocessors, loader, out_json_path,
logger, alpha=opt.ec_alpha, dvc_eval_version=opt.eval_tool_version, device=opt.device, debug=False)
if opt.eval_mode == 'test':
out_json_path = os.path.join(folder_path, 'dvc_results.json')
evaluate(model, criterion, postprocessors, loader, out_json_path,
logger, alpha=opt.ec_alpha, dvc_eval_version=opt.eval_tool_version, device=opt.eval_device, debug=False, skip_lang_eval=True)

avg_eval_score = {key: np.array(value).mean() for key, value in caption_scores.items() if key !='tiou'}
avg_eval_score2 = {key: np.array(value).mean() * 4917 / len(loader.dataset) for key, value in caption_scores.items() if key != 'tiou'}

logger.info(
'\nValidation result based on all 4917 val videos:\n {}\n avg_score:\n{}'.format(
caption_scores.items(),
avg_eval_score))

logger.info(
'\nValidation result based on {} available val videos:\n avg_score:\n{}'.format(len(loader.dataset),
avg_eval_score2))
else:
out_json_path = os.path.join(folder_path, '{}_epoch{}_num{}_alpha{}.json'.format(
time.strftime("%Y-%m-%d-%H-%M-%S_", time.localtime()) + str(opt.id), epoch, len(loader.dataset),
opt.ec_alpha))
caption_scores, eval_loss = evaluate(model, criterion, postprocessors, loader, out_json_path,
logger, alpha=opt.ec_alpha, dvc_eval_version=opt.eval_tool_version, device=opt.eval_device, debug=False, skip_lang_eval=False)
avg_eval_score = {key: np.array(value).mean() for key, value in caption_scores.items() if key !='tiou'}
avg_eval_score2 = {key: np.array(value).mean() * 4917 / len(loader.dataset) for key, value in caption_scores.items() if key != 'tiou'}

logger.info(
'\nValidation result based on all 4917 val videos:\n {}\n avg_score:\n{}'.format(
caption_scores.items(),
avg_eval_score))

logger.info(
'\nValidation result based on {} available val videos:\n avg_score:\n{}'.format(len(loader.dataset),
avg_eval_score2))

logger.info('saving reults json to {}'.format(out_json_path))

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--eval_save_dir', type=str, default='save')
parser.add_argument('--eval_mode', type=str, default='eval', choices=['eval', 'test'])
parser.add_argument('--test_video_feature_folder', type=str, nargs='+', default=None)
parser.add_argument('--test_video_meta_data_csv_path', type=str, default=None)
parser.add_argument('--eval_folder', type=str, required=True)
parser.add_argument('--eval_model_path', type=str, default='')
parser.add_argument('--eval_tool_version', type=str, default='2018', choices=['2018', '2021'])
parser.add_argument('--eval_caption_file', type=str, default='data/anet/captiondata/val_1.json')
parser.add_argument('--eval_proposal_type', type=str, default='gt')
parser.add_argument('--eval_transformer_input_type', type=str, default='queries', choices=['gt_proposals', 'queries'])
parser.add_argument('--gpu_id', type=str, nargs='+', default=['0'])
parser.add_argument('--eval_device', type=str, default='cuda')
opt = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in opt.gpu_id])
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

if True:
torch.backends.cudnn.enabled = False
main(opt)
Loading

0 comments on commit 9789580

Please sign in to comment.