Commit: fix infer bug

Thinksky5124 committed Jan 14, 2022
1 parent 0afe024 commit 89e02e9
Showing 12 changed files with 349 additions and 81 deletions.
200 changes: 200 additions & 0 deletions applications/BaseballAction/transform_segmentation_label.py
@@ -0,0 +1,200 @@
'''
Author: Thyssen Wen
Date: 2022-01-10 15:57:12
LastEditors: Thyssen Wen
LastEditTime: 2022-01-11 10:15:13
Description: file content
FilePath: \TAS\transform_label.py
'''
import json
import numpy as np
import argparse
import os

from tqdm import tqdm


def generate_mapping_list_txt(action_dict, out_path):
    # write one "id action_name" pair per line
    out_txt_file_path = os.path.join(out_path, "mapping.txt")
    with open(out_txt_file_path, "w", encoding='utf-8') as f:
        for key, action_name in action_dict.items():
            f.write(str(key) + " " + action_name + "\n")
        # add an explicit background class named "None"
        f.write(str(len(action_dict)) + " None" + "\n")


def segmentation_convert_localization_label(prefix_data_path, out_path,
action_dict, fps):
label_path = os.path.join(prefix_data_path, "train")
label_txt_name_list = os.listdir(label_path)

labels_dict = {}
labels_dict["fps"] = fps
labels_list = []
for label_name in tqdm(label_txt_name_list, desc='label convert:'):
label_dict = {}
label_dict["url"] = label_name.split(".")[0] + ".mp4"
label_txt_path = os.path.join(prefix_data_path, "train", label_name)

with open(label_txt_path, "r", encoding='utf-8') as f:
gt = f.read().split("\n")[:-1]
label_dict["total_frames"] = len(gt)

        boundary_index_list = [0]
        for index in range(1, len(gt)):
            # a boundary is where the label differs from the previous frame
            if gt[index] != gt[index - 1]:
                boundary_index_list.append(index)
        # close the last segment so it is not dropped
        boundary_index_list.append(len(gt))

        actions_list = []
        for index in range(len(boundary_index_list) - 1):
            action_name = gt[boundary_index_list[index]]
            # skip background segments
            if action_name != "None":
                start_sec = float(boundary_index_list[index]) / float(fps)
                end_sec = float(boundary_index_list[index + 1]) / float(fps)
                action_id = list(action_dict.keys())[list(
                    action_dict.values()).index(action_name)]
                label_action_dict = {}
                # single-element lists keep the format consistent with
                # generate_action_dict, which indexes label_names[0]
                label_action_dict["label_names"] = [action_name]
                label_action_dict["start_id"] = start_sec
                label_action_dict["end_id"] = end_sec
                label_action_dict["label_ids"] = [action_id]
                actions_list.append(label_action_dict)

label_dict["actions"] = actions_list
labels_list.append(label_dict)
labels_dict["gts"] = labels_list
    output_path = os.path.join(out_path, "output.json")
    with open(output_path, "w", encoding='utf-8') as f:
        f.write(json.dumps(labels_dict, indent=4))


def generate_action_dict(label):
action_dict = {}
for gt in label["gts"]:
for action in gt["actions"]:
label_id = action["label_ids"][0]
label_name = action["label_names"][0]
action_dict[label_id] = label_name

return action_dict


def load_action_dict(data_path):
mapping_txt_path = os.path.join(data_path, "mapping.txt")
with open(mapping_txt_path, "r", encoding='utf-8') as f:
actions = f.read().split("\n")[:-1]

class2id_map = dict()
for a in actions:
class2id_map[a.split()[1]] = int(a.split()[0])

return class2id_map


def localization_convert_segmentation_label(label, prefix_data_path, out_path):
    path = os.path.join(out_path, "groundTruth")
    if not os.path.exists(path):
        os.makedirs(path)
        print(path + ' created successfully')
    else:
        print(path + ' already exists')

fps = float(label["fps"])
video_list = []
for gt in tqdm(label["gts"], desc='label convert:'):
video_name = gt["url"].split(".")[0]
data_path = os.path.join(prefix_data_path, video_name + ".pkl")
video_list.append(video_name + ".txt")
        feature = np.load(data_path, allow_pickle=True)["image_feature"]

        num_feature = feature.shape[0]
        seg_label = ["None"] * num_feature
        for action in gt["actions"]:
            start_id = action["start_id"]
            end_id = action["end_id"]

            # label_names is a single-element list, so list repetition
            # below yields one label string per frame
            label_name = action["label_names"]

            start_index = int(np.floor(start_id * fps))
            end_index = int(np.floor(end_id * fps)) + 1

            if end_index <= num_feature:
                seg_label[start_index:end_index] = label_name * (end_index -
                                                                 start_index)
            elif start_index < num_feature:
                # clamp segments that run past the last feature frame
                seg_label[start_index:] = label_name * (num_feature -
                                                        start_index)

        out_txt_file_path = os.path.join(out_path, "groundTruth",
                                         video_name + ".txt")
        with open(out_txt_file_path, "w", encoding='utf-8') as f:
            f.write("\n".join(seg_label) + "\n")
    out_txt_file_path = os.path.join(out_path, "train_list.txt")
    with open(out_txt_file_path, "w", encoding='utf-8') as f:
        f.write("\n".join(video_list) + "\n")


def main():
args = get_arguments()

    if args.mode == "segmentation":
        with open(args.label_path, 'r', encoding='utf-8') as json_file:
            label = json.load(json_file)
        action_dict = generate_action_dict(label)
        generate_mapping_list_txt(action_dict, args.out_path)
        localization_convert_segmentation_label(label, args.data_path,
                                                args.out_path)

    elif args.mode == "localization":
        action_dict = load_action_dict(args.data_path)
        segmentation_convert_localization_label(args.data_path,
                                                args.out_path,
                                                action_dict,
                                                fps=25.0)

    else:
        raise NotImplementedError


def get_arguments():
    """
    Parse all the arguments from the command line interface and
    return the parsed arguments.
    """

parser = argparse.ArgumentParser(
description="convert segmentation and localization label")
parser.add_argument("label_path", type=str, help="path of a label file")
parser.add_argument(
"data_path",
type=str,
help="path of video feature or segmentation label txt.",
)
parser.add_argument(
"out_path",
type=str,
help="path of output file.",
)
    parser.add_argument(
        "--mode",
        type=str,
        default="segmentation",
        help="conversion target: produce 'segmentation' or 'localization' labels.",
    )

return parser.parse_args()


if __name__ == "__main__":
main()
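
A hedged usage sketch for the script above, inferred from its argparse definitions and converter functions; the file and directory names are placeholders, not files shipped with the repo:

```python
# Convert a localization JSON into per-frame segmentation labels
# (writes mapping.txt, groundTruth/*.txt and train_list.txt to out_path):
#   python transform_segmentation_label.py label.json ./features ./output \
#       --mode segmentation
#
# Convert per-frame segmentation labels back into a localization JSON
# (reads mapping.txt and train/*.txt from data_path; label_path is unused):
#   python transform_segmentation_label.py unused.json ./data ./output \
#       --mode localization
#
# Minimal localization label dict with the fields the converters read/write;
# the action name "swing" and all numbers are illustrative only:
example_label = {
    "fps": 25.0,
    "gts": [{
        "url": "video_1.mp4",
        "total_frames": 500,
        "actions": [{
            "label_names": ["swing"],  # hypothetical action name
            "start_id": 1.2,           # start time in seconds
            "end_id": 3.6,             # end time in seconds
            "label_ids": [0],
        }],
    }],
}
```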
1 change: 1 addition & 0 deletions configs/segmentation/asrf/asrf_50salads.yaml
@@ -98,6 +98,7 @@ INFERENCE:
actions_map_file_path: "./data/50salads/mapping.txt"
postprocessing_method: "refinement_with_boundary"
boundary_threshold: 0.5
feature_path: "./data/50salads/features"


model_name: "ASRF"
1 change: 1 addition & 0 deletions configs/segmentation/asrf/asrf_GTEA.yaml
@@ -99,6 +99,7 @@ INFERENCE:
actions_map_file_path: "./data/gtea/mapping.txt"
postprocessing_method: "refinement_with_boundary"
boundary_threshold: 0.5
feature_path: "./data/gtea/features"


model_name: "ASRF"
1 change: 1 addition & 0 deletions configs/segmentation/asrf/asrf_breakfast.yaml
@@ -99,6 +99,7 @@ INFERENCE:
actions_map_file_path: "./data/breakfast/mapping.txt"
postprocessing_method: "refinement_with_boundary"
boundary_threshold: 0.5
feature_path: "./data/breakfast/features"


model_name: "ASRF"
1 change: 1 addition & 0 deletions configs/segmentation/ms_tcn/ms_tcn_50salads.yaml
@@ -76,6 +76,7 @@ INFERENCE:
name: 'MSTCN_Inference_helper'
num_channels: 2048
actions_map_file_path: "./data/50salads/mapping.txt"
feature_path: "./data/50salads/features"


model_name: "MSTCN"
1 change: 1 addition & 0 deletions configs/segmentation/ms_tcn/ms_tcn_GTEA.yaml
@@ -77,6 +77,7 @@ INFERENCE:
name: 'MSTCN_Inference_helper'
num_channels: 2048
actions_map_file_path: "./data/gtea/mapping.txt"
feature_path: "./data/gtea/features"


model_name: "MSTCN"
1 change: 1 addition & 0 deletions configs/segmentation/ms_tcn/ms_tcn_breakfast.yaml
@@ -77,6 +77,7 @@ INFERENCE:
name: 'MSTCN_Inference_helper'
num_channels: 2048
actions_map_file_path: "./data/breakfast/mapping.txt"
feature_path: "./data/breakfast/features"


model_name: "MSTCN"
18 changes: 18 additions & 0 deletions docs/zh-CN/model_zoo/segmentation/Temporal_action_segmentation.md
@@ -0,0 +1,18 @@
[English](../../../en/model_zoo/segmentation/Trmporal_action_segmentation.md) | Simplified Chinese

This repository provides performance and accuracy comparisons of classic and popular temporal action segmentation models.

| Model | Metrics | Value | Flops(M) | Params(M) | test time(ms) bs=1 | test time(ms) bs=2 | inference time(ms) bs=1 | inference time(ms) bs=2 |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| MS-TCN | F1@0.5 | 38.8% | 791.360 | 0.8 | 170 | - | 10.68 | - |
| ASRF | F1@0.5 | 55.7% | 1,283.328 | 1.3 | 190 | - | 16.34 | - |

* Model: the specific name of the model, e.g. PP-TSM
* Metrics: the metric used when testing the model; the dataset used is **breakfast**
* Value: the value of the metric, generally kept to two decimal places
* Flops(M): floating-point operations for one forward pass, computable with the PaddleVideo/tools/summary.py script (minor modifications may be needed for some models), kept to one decimal place, measured with an **input tensor of shape (1, 2048, 1000)** (a sketch follows this list)
* Params(M): the number of model parameters, computed by the same script together with Flops, kept to one decimal place
* test time(ms) bs=1: time per sample when testing with the Python script at batch size 1, kept to two decimal places; the test dataset is **breakfast**
* test time(ms) bs=2: time per sample when testing with the Python script at batch size 2, kept to two decimal places. Temporal action segmentation models are generally fully convolutional networks, so the batch size for training, testing, and inference is 1; the test dataset is **breakfast**
* inference time(ms) bs=1: time per sample when running the exported inference model on GPU (V100 by default) at batch size 1, kept to two decimal places; the inference dataset is **breakfast**
* inference time(ms) bs=2: time per sample when running the exported inference model on GPU (V100 by default) at batch size 2, kept to two decimal places. Temporal action segmentation models are generally fully convolutional networks, so the batch size for training, testing, and inference is 1; the inference dataset is **breakfast**
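
As a hedged illustration of how the Flops and Params columns might be produced, the sketch below mirrors the `paddle.summary` and `paddleslim.analysis.flops` calls in tools/summary.py from this commit; the stand-in network is illustrative only, and the real numbers come from the repo's own models:

```python
import paddle
import paddle.nn as nn
import paddleslim

# Stand-in network, NOT a PaddleVideo model: it merely accepts the
# (1, 1, 2048, 1000) feature tensor used for the measurements above
# (batch dim, a dummy dim, then 2048 channels x 1000 frames).
model = nn.Sequential(
    nn.Flatten(start_axis=1, stop_axis=2),  # (1, 1, 2048, 1000) -> (1, 2048, 1000)
    nn.Conv1D(2048, 64, kernel_size=1),     # toy single-stage temporal head
)

# Same calls as tools/summary.py in this commit:
params_info = paddle.summary(model, (1, 1, 2048, 1000))  # parameter count
print(params_info)

flops_info = paddleslim.analysis.flops(model, [1, 1, 2048, 1000])  # FLOPs
print(flops_info)
```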
9 changes: 9 additions & 0 deletions paddlevideo/tasks/test.py
@@ -13,6 +13,7 @@
# limitations under the License.

import paddle
import time
from paddlevideo.utils import get_logger, load

from ..loader.builder import build_dataloader, build_dataset
@@ -76,6 +77,7 @@ def test_model(cfg, weights, parallel=True):
if cfg.MODEL.framework == "FastRCNN":
Metric.set_dataset_info(dataset.info, len(dataset))

    warmup_num = 20
    clock = time.time()  # fallback start, in case the loader has fewer than warmup_num batches
for batch_id, data in enumerate(data_loader):
if cfg.model_name in [
'CFBI'
@@ -84,4 +86,11 @@
else:
outputs = model(data, mode='test')
Metric.update(batch_id, data, outputs)
        if batch_id == warmup_num:
            clock = time.time()
    test_cost = time.time() - clock
    test_num = len(data_loader) - warmup_num
    Metric.accumulate()
    print(
        f"#Test examples={test_num}, time cost={test_cost:.2f}s, avg cost={test_cost / test_num:.2f}s"
    )
21 changes: 3 additions & 18 deletions tools/predict.py
@@ -162,22 +162,7 @@ def main():

# Post process output
InferenceHelper.postprocess(outputs)
elif model_name in ['MSTCN', 'ASRF']:
for file in files:
inputs = InferenceHelper.preprocess(file)
outputs = []
for input in inputs:
# Run inference
for i in range(len(input_tensor_list)):
input_tensor_list[i].copy_from_cpu(input)
predictor.run()
output = []
for j in range(len(output_tensor_list)):
output.append(output_tensor_list[j].copy_to_cpu())
outputs.append(output)

# Post process output
InferenceHelper.postprocess(outputs)

elif model_name == 'AVA_SlowFast_FastRcnn':
for file in files: # for videos
inputs = InferenceHelper.preprocess(file)
@@ -202,7 +187,7 @@
InferenceHelper.postprocess(outputs)
else:
if args.enable_benchmark:
test_video_num = 300
test_video_num = 50
num_warmup = 10

# instantiate auto log
@@ -257,7 +242,7 @@ def main():
if args.enable_benchmark:
autolog.times.stamp()

InferenceHelper.postprocess(batched_outputs,
InferenceHelper.postprocess([batched_outputs],
not args.enable_benchmark)

# get post process time cost
8 changes: 5 additions & 3 deletions tools/summary.py
@@ -53,7 +53,7 @@ def _trim(cfg, args):
"""
model_name = cfg.model_name
cfg = cfg.MODEL
cfg.backbone.pretrained = ""
# cfg.backbone.pretrained = ""

if 'num_seg' in cfg.backbone:
cfg.backbone.num_seg = args.num_seg
@@ -69,11 +69,13 @@ def main():
img_size = args.img_size
num_seg = args.num_seg
#NOTE: only support tsm now, will refine soon
params_info = paddle.summary(model, (1, 1, num_seg, 3, img_size, img_size))
# params_info = paddle.summary(model, (1, 1, num_seg, 3, img_size, img_size))
params_info = paddle.summary(model, (1, 1, 2048, 1000))
print(params_info)

if args.FLOPs:
flops_info = paddleslim.analysis.flops(model, [1, 1, num_seg, 3, img_size, img_size])
# flops_info = paddleslim.analysis.flops(model, [1, 1, num_seg, 3, img_size, img_size])
flops_info = paddleslim.analysis.flops(model, [1, 1, 2048, 1000])
print(flops_info)

