From bd290d3aa78db4074f9226b99dea9382574eeced Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Mon, 22 Nov 2021 22:26:26 +0800
Subject: [PATCH] Release PPTSN, PPTSM, TSN, TSM base topic

---
 .../pptsm/pptsm_k400_videos_uniform.yaml      |   3 +
 .../recognition/pptsn/pptsn_k400_videos.yaml  |   3 +
 configs/recognition/tsm/tsm_k400_frames.yaml  |   6 +-
 configs/recognition/tsm/tsm_k400_videos.yaml  |   3 +
 main.py                                       |  30 ++++-
 paddlevideo/loader/pipelines/sample.py        |   5 +-
 paddlevideo/modeling/backbones/vit.py         |  96 ++++++------
 .../recognizers/recognizer_transformer.py     |  53 +++++--
 tools/export_model.py                         |   7 +-
 tools/predict.py                              | 132 +++++++++++-----
 tools/utils.py                                |  75 +++++----
 11 files changed, 268 insertions(+), 145 deletions(-)

diff --git a/configs/recognition/pptsm/pptsm_k400_videos_uniform.yaml b/configs/recognition/pptsm/pptsm_k400_videos_uniform.yaml
index 8e66182a7..0188edf59 100644
--- a/configs/recognition/pptsm/pptsm_k400_videos_uniform.yaml
+++ b/configs/recognition/pptsm/pptsm_k400_videos_uniform.yaml
@@ -18,12 +18,15 @@ DATASET: #DATASET field
   test_batch_size: 1
   train:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400" #Mandatory, train data root path
     file_path: "data/k400/train.list" #Mandatory, train data index file path
   valid:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400" #Mandatory, valid data root path
     file_path: "data/k400/val.list" #Mandatory, valid data index file path
   test:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400" #Mandatory, test data root path
     file_path: "data/k400/val.list" #Mandatory, valid data index file path
 
 PIPELINE: #PIPELINE field
diff --git a/configs/recognition/pptsn/pptsn_k400_videos.yaml b/configs/recognition/pptsn/pptsn_k400_videos.yaml
index 0122b5aa1..7a825524d 100644
--- a/configs/recognition/pptsn/pptsn_k400_videos.yaml
+++ b/configs/recognition/pptsn/pptsn_k400_videos.yaml
@@ -20,12 +20,15 @@ DATASET: #DATASET field
   num_workers: 4 #Mandatory, XXX the number of subprocess on each GPU.
   train:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400" #Mandatory, train data root path
     file_path: "data/k400/train.list" #Mandatory, train data index file path
   valid:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400" #Mandatory, valid data root path
     file_path: "data/k400/val.list" #Mandatory, valid data index file path
   test:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400" #Mandatory, test data root path
     file_path: "data/k400/val.list" #Mandatory, valid data index file path
 
 
diff --git a/configs/recognition/tsm/tsm_k400_frames.yaml b/configs/recognition/tsm/tsm_k400_frames.yaml
index cb3b83d7a..b2d38bf77 100644
--- a/configs/recognition/tsm/tsm_k400_frames.yaml
+++ b/configs/recognition/tsm/tsm_k400_frames.yaml
@@ -18,17 +18,17 @@ DATASET: #DATASET field
   num_workers: 4 #Mandatory, XXX the number of subprocess on each GPU.
   train:
     format: "FrameDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
-    data_prefix: "" #Mandatory, train data root path
+    data_prefix: "data/k400/rawframes" #Mandatory, train data root path
     file_path: "data/k400_frames/train.list" #Mandatory, train data index file path
     suffix: 'img_{:05}.jpg'
   valid:
     format: "FrameDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
-    data_prefix: "" #Mandatory, valid data root path
+    data_prefix: "data/k400/rawframes" #Mandatory, valid data root path
     file_path: "data/k400_frames/val.list" #Mandatory, valid data index file path
     suffix: 'img_{:05}.jpg'
   test:
     format: "FrameDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
-    data_prefix: "" #Mandatory, valid data root path
+    data_prefix: "data/k400/rawframes" #Mandatory, test data root path
     file_path: "data/k400_frames/val.list" #Mandatory, valid data index file path
     suffix: 'img_{:05}.jpg'
 
diff --git a/configs/recognition/tsm/tsm_k400_videos.yaml b/configs/recognition/tsm/tsm_k400_videos.yaml
index 94ae70f15..0e3253e56 100644
--- a/configs/recognition/tsm/tsm_k400_videos.yaml
+++ b/configs/recognition/tsm/tsm_k400_videos.yaml
@@ -18,12 +18,15 @@ DATASET: #DATASET field
   num_workers: 4 #Mandatory, XXX the number of subprocess on each GPU.
   train:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400/videos" #Mandatory, train data root path
     file_path: "data/k400/train.list" #Mandatory, train data index file path
   valid:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400/videos" #Mandatory, valid data root path
     file_path: "data/k400/val.list" #Mandatory, valid data index file path
   test:
     format: "VideoDataset" #Mandatory, indicate the type of dataset, associate to the 'paddlevidel/loader/dateset'
+    data_prefix: "data/k400/videos" #Mandatory, test data root path
     file_path: "data/k400/val.list" #Mandatory, valid data index file path
 
 
diff --git a/main.py b/main.py
index a1927e60a..daa97f201 100644
--- a/main.py
+++ b/main.py
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import paddle
 import argparse
-from paddlevideo.utils import get_config
-from paddlevideo.tasks import train_model, train_model_multigrid, test_model, train_dali
-from paddlevideo.utils import get_dist_info
+import random
+
+import numpy as np
+import paddle
+
+from paddlevideo.tasks import (test_model, train_dali, train_model,
+                               train_model_multigrid)
+from paddlevideo.utils import get_config, get_dist_info
 
 
 def parse_args():
@@ -53,6 +57,11 @@ def parse_args():
         '--validate',
         action='store_true',
         help='whether to evaluate the checkpoint during training')
+    parser.add_argument(
+        '--seed',
+        type=int,
+        default=None,
+        help='fix all random seeds when the program is running')
     parser.add_argument(
         '-p',
         '--profiler_options',
@@ -69,6 +78,15 @@ def main():
     args = parse_args()
     cfg = get_config(args.config, overrides=args.override)
 
+    # set seed if specified
+    seed = args.seed
+    if seed is not None:
+        assert isinstance(
+            seed, int), f"seed must be an integer when specified, but got {seed}"
+        paddle.seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
     _, world_size = get_dist_info()
     parallel = world_size != 1
     if parallel:
@@ -79,7 +97,9 @@ def main():
     elif args.train_dali:
         train_dali(cfg, weights=args.weights, parallel=parallel)
     elif args.multigrid:
-        train_model_multigrid(cfg, world_size=world_size, validate=args.validate)
+        train_model_multigrid(cfg,
+                              world_size=world_size,
+                              validate=args.validate)
     else:
         train_model(cfg,
                     weights=args.weights,
diff --git a/paddlevideo/loader/pipelines/sample.py b/paddlevideo/loader/pipelines/sample.py
index 7fc7e64fe..c289d51da 100644
--- a/paddlevideo/loader/pipelines/sample.py
+++ b/paddlevideo/loader/pipelines/sample.py
@@ -101,10 +101,11 @@ def __call__(self, results):
         frames_idx = []
         if self.linspace_sample:
             if 'start_idx' in results and 'end_idx' in results:
-                offsets = np.linspace(results['start_idx'], results['end_idx'], self.num_seg)
+                offsets = np.linspace(results['start_idx'], results['end_idx'],
+                                      self.num_seg)
             else:
                 offsets = np.linspace(0, frames_len - 1, self.num_seg)
-            offsets = np.clip(offsets, 0, frames_len - 1).astype(np.long)
+            offsets = np.clip(offsets, 0, frames_len - 1).astype(np.int64)
             if results['format'] == 'video':
                 frames_idx = list(offsets)
                 frames_idx = [x % frames_len for x in frames_idx]
diff --git a/paddlevideo/modeling/backbones/vit.py b/paddlevideo/modeling/backbones/vit.py
index 7aaaf5a2f..06de0a353 100644
--- a/paddlevideo/modeling/backbones/vit.py
+++ b/paddlevideo/modeling/backbones/vit.py
@@ -17,16 +17,15 @@
 import numpy as np
 import paddle
 import paddle.nn as nn
-from paddle.nn.initializer import TruncatedNormal, Constant, Normal
 import paddle.nn.functional as F
-from ..registry import BACKBONES
+from paddle.nn.initializer import Constant
+
 from ...utils import load_ckpt
+from ..registry import BACKBONES
 from ..weight_init import trunc_normal_
-
 
 __all__ = ['VisionTransformer']
-
 
 zeros_ = Constant(value=0.)
 ones_ = Constant(value=1.)
@@ -77,7 +76,7 @@ def __init__(self,
                  hidden_features=None,
                  out_features=None,
                  act_layer=nn.GELU,
-                 drop=0.):
+                 drop=0.0):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
@@ -101,8 +100,8 @@ def __init__(self,
                  num_heads=8,
                  qkv_bias=False,
                  qk_scale=None,
-                 attn_drop=0.,
-                 proj_drop=0.):
+                 attn_drop=0.0,
+                 proj_drop=0.0):
         super().__init__()
         self.num_heads = num_heads
@@ -151,7 +150,7 @@ def __init__(self,
         if isinstance(norm_layer, str):
             self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
         elif isinstance(norm_layer, Callable):
-            self.norm1 = norm_layer(dim)
+            self.norm1 = norm_layer(dim, epsilon=epsilon)
         else:
             raise TypeError(
                 "The norm_layer must be str or paddle.nn.layer.Layer class")
@@ -185,7 +184,7 @@ def __init__(self,
         if isinstance(norm_layer, str):
             self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
         elif isinstance(norm_layer, Callable):
-            self.norm2 = norm_layer(dim)
+            self.norm2 = norm_layer(dim, epsilon=epsilon)
         else:
             raise TypeError(
                 "The norm_layer must be str or paddle.nn.layer.Layer class")
@@ -205,14 +204,14 @@ def forward(self, x, B, T, W):
         elif self.attention_type == 'divided_space_time':
             ########## Temporal ##########
             xt = x[:, 1:, :]
-            _b, _h, _w, _t, _m = B, H, W, T, xt.shape[-1]
-            xt = xt.reshape([_b * _h * _w if _b > 0 else -1, _t, _m])
+            _, _, _, _t, _m = B, H, W, T, xt.shape[-1]
+            xt = xt.reshape([-1, _t, _m])
 
             res_temporal = self.drop_path(
                 self.temporal_attn(self.temporal_norm1(xt)))
 
-            _b, _h, _w, _t, _m = B, H, W, T, res_temporal.shape[-1]
-            res_temporal = res_temporal.reshape([_b, _h * _w * _t, _m])
+            _, _h, _w, _t, _m = B, H, W, T, res_temporal.shape[-1]
+            res_temporal = res_temporal.reshape([-1, _h * _w * _t, _m])
             res_temporal = self.temporal_fc(res_temporal)
             xt = x[:, 1:, :] + res_temporal
 
@@ -221,26 +220,26 @@ def forward(self, x, B, T, W):
             init_cls_token = x[:, 0, :].unsqueeze(1)
             cls_token = init_cls_token.tile((1, T, 1))
             _b, _t, _m = cls_token.shape
-            cls_token = cls_token.reshape([_b * _t, _m]).unsqueeze(1)
+            cls_token = cls_token.reshape([-1, _m]).unsqueeze(1)
 
             xs = xt
-            _b, _h, _w, _t, _m = B, H, W, T, xs.shape[-1]
-            xs = xs.reshape([_b, _h, _w, _t, _m]).transpose(
-                (0, 3, 1, 2, 4)).reshape([_b * _t if _b > 0 else -1, _h * _w, _m])
+            _, _h, _w, _t, _m = B, H, W, T, xs.shape[-1]
+            xs = xs.reshape([-1, _h, _w, _t, _m]).transpose(
+                (0, 3, 1, 2, 4)).reshape([-1, _h * _w, _m])
             xs = paddle.concat((cls_token, xs), axis=1)
             res_spatial = self.drop_path(self.attn(self.norm1(xs)))
 
             # Taking care of CLS token
             cls_token = res_spatial[:, 0, :]
-            _b, _t, _m = B, T, cls_token.shape[-1]
-            cls_token = cls_token.reshape([_b, _t, _m])
+            _, _t, _m = B, T, cls_token.shape[-1]
+            cls_token = cls_token.reshape([-1, _t, _m])
             # averaging for every frame
             cls_token = paddle.mean(cls_token, axis=1, keepdim=True)
 
             res_spatial = res_spatial[:, 1:, :]
-            _b, _t, _h, _w, _m = B, T, H, W, res_spatial.shape[-1]
-            res_spatial = res_spatial.reshape([_b, _t, _h, _w, _m]).transpose(
-                (0, 2, 3, 1, 4)).reshape([_b, _h * _w * _t, _m])
+            _, _t, _h, _w, _m = B, T, H, W, res_spatial.shape[-1]
+            res_spatial = res_spatial.reshape([-1, _t, _h, _w, _m]).transpose(
+                (0, 2, 3, 1, 4)).reshape([-1, _h * _w * _t, _m])
             res = res_spatial
             x = xt
@@ -282,7 +281,7 @@ def forward(self, x):
         assert H == self.img_size[0] and W == self.img_size[1], \
             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         x = x.transpose((0, 2, 1, 3, 4))
-        x = x.reshape([B * T if B > 0 else -1, C, H, W])
+        x = x.reshape([-1, C, H, W])
         x = self.proj(x)
         W = x.shape[-1]
         x = x.flatten(2).transpose((0, 2, 1))
@@ -316,7 +315,6 @@ def __init__(self,
         self.pretrained = pretrained
         self.seg_num = seg_num
         self.attention_type = attention_type
-
         self.num_features = self.embed_dim = embed_dim
 
         self.patch_embed = PatchEmbed(img_size=img_size,
@@ -375,32 +373,30 @@ def init_weights(self):
                     zeros_(m.temporal_fc.weight)
                     zeros_(m.temporal_fc.bias)
                 i += 1
-
         """Second, if provide pretrained ckpt, load it"""
         if isinstance(
                 self.pretrained, str
         ) and self.pretrained.strip() != "":  # load pretrained weights
-            load_ckpt(self, self.pretrained, num_patches=self.patch_embed.num_patches,
-                      seg_num=self.seg_num, attention_type=self.attention_type)
-        elif self.pretrained is None or self.pretrained.strip() == "":
-            pass
-        else:
-            raise NotImplementedError
+            load_ckpt(self,
+                      self.pretrained,
+                      num_patches=self.patch_embed.num_patches,
+                      seg_num=self.seg_num,
+                      attention_type=self.attention_type)
 
     def _init_fn(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight)
-            if isinstance(m, nn.Linear) and m.bias is not None:
+            if m.bias is not None:
                 zeros_(m.bias)
         elif isinstance(m, nn.LayerNorm):
             ones_(m.weight)
             zeros_(m.bias)
 
     def forward_features(self, x):
-        B = x.shape[0]
-        # B = paddle.shape(x)[0]
-        x, T, W = self.patch_embed(x)
-        cls_tokens = self.cls_token.expand((x.shape[0] if B > 0 else 3 * T, -1, -1))
+        # B = x.shape[0]
+        B = paddle.shape(x)[0]
+        x, T, W = self.patch_embed(x)  # [BT,nH*nW,F]
+        cls_tokens = self.cls_token.expand((B * T, -1, -1))  # [1,1,F]->[BT,1,F]
         x = paddle.concat((cls_tokens, x), axis=1)
         pos_interp = (x.shape[1] != self.pos_embed.shape[1])
         if pos_interp:
@@ -426,14 +422,14 @@ def forward_features(self, x):
 
         # Time Embeddings
         if self.attention_type != 'space_only':
-            # cls_tokens = x[:B, 0, :].unsqueeze(1)
-            cls_tokens = x[:B, 0, :].unsqueeze(1) if B > 0 else x.split(T)[0].index_select(paddle.to_tensor([0]), axis=1)
+            cls_tokens = x[:B, 0, :].unsqueeze(1) if B > 0 else x.split(
+                T)[0].index_select(paddle.to_tensor([0]), axis=1)
             x = x[:, 1:]
-            _bt, _n, _m = x.shape
-            _b = B
-            _t = _bt // _b if _b != -1 else T
-            x = x.reshape([_b, _t, _n, _m]).transpose(
-                (0, 2, 1, 3)).reshape([_b * _n if _b > 0 else -1, _t, _m])
+            _, _n, _m = x.shape
+            # _b = B
+            _t = T
+            x = x.reshape([-1, _t, _n, _m]).transpose(
+                (0, 2, 1, 3)).reshape([-1, _t, _m])
             # Resizing time embeddings in case they don't match
             time_interp = (T != self.time_embed.shape[1])
             if time_interp:  # T' != T
@@ -447,9 +443,9 @@ def forward_features(self, x):
                 x = x + self.time_embed
             x = self.time_drop(x)
-            _bn, _t, _m = x.shape
-            _b = B
-            x = x.reshape([_b, _n * _t, _m] if _n > 0 else [_b, W * W * T, _m])
+            _, _t, _m = x.shape
+            # _b = B
+            x = x.reshape([-1, W * W * T, _m])
             x = paddle.concat((cls_tokens, x), axis=1)
 
         # Attention blocks
@@ -458,14 +454,14 @@ def forward_features(self, x):
 
         # Predictions for space-only baseline
         if self.attention_type == 'space_only':
-            _bt, _n, _m = x.shape
-            _b = B
+            _, _n, _m = x.shape
+            # _b = B
             _t = T
-            x = x.reshape([_b, _t, _n, _m])
+            x = x.reshape([-1, _t, _n, _m])
             x = paddle.mean(x, 1)  # averaging predictions for every frame
 
         x = self.norm(x)
-        return x[:, 0]  # [B, 1, embed_dim]
+        return x[:, 0]  # [B, embed_dim]
 
     def forward(self, x):
         x = self.forward_features(x)
diff --git a/paddlevideo/modeling/framework/recognizers/recognizer_transformer.py b/paddlevideo/modeling/framework/recognizers/recognizer_transformer.py
index ba30753b7..26ca73282 100644
--- a/paddlevideo/modeling/framework/recognizers/recognizer_transformer.py
+++ b/paddlevideo/modeling/framework/recognizers/recognizer_transformer.py
@@ -10,11 +10,13 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
-from ...registry import RECOGNIZERS
-from .base import BaseRecognizer
 import paddle
+import paddle.nn.functional as F
 from paddlevideo.utils import get_logger
 
+from ...registry import RECOGNIZERS
+from .base import BaseRecognizer
+
 logger = get_logger("paddlevideo")
 
@@ -52,21 +54,42 @@ def val_step(self, data_batch):
         return loss_metrics
 
     def test_step(self, data_batch):
         """Define how the model is going to test, from input to output."""
-        # NOTE: (shipping) when testing, the net won't call head.loss, we deal with the test processing in /paddlevideo/metrics
-        clips_list = paddle.split(
-            data_batch[0], num_or_sections=3,
-            axis=2)  # [N, 3, T, H, W], [N, 3, T, H, W], [N, 3, T, H, W]
-        cls_score = [
-            self.forward_net(imgs) for imgs in clips_list
-        ]  # [N, C], [N, C], [N, C]
-        cls_score = paddle.add_n(cls_score)  # [N, C] in [0,1]
+        imgs = data_batch[0]
+        num_views = imgs.shape[2] // self.backbone.seg_num
+        cls_score = []
+        for i in range(num_views):
+            view = imgs[:, :, i * self.backbone.seg_num:(i + 1) *
+                        self.backbone.seg_num]
+            cls_score.append(self.forward_net(view))
+        cls_score = self.average_view(cls_score)
         return cls_score
 
     def infer_step(self, data_batch):
-        """Define how the model is going to test, from input to output."""
+        """Define how the model is going to infer, from input to output."""
         imgs = data_batch[0]
-        cls_score = self.forward_net(imgs)
+        num_views = imgs.shape[2] // self.backbone.seg_num
+        cls_score = []
+        for i in range(num_views):
+            view = imgs[:, :, i * self.backbone.seg_num:(i + 1) *
+                        self.backbone.seg_num]
+            cls_score.append(self.forward_net(view))
+        cls_score = self.average_view(cls_score)
         return cls_score
+
+    def average_view(self, cls_score, average_type='score'):
+        """Combine the scores of different views
+
+        Args:
+            cls_score (list): Scores of multiple views
+            average_type (str, optional): Average calculation method. Defaults to 'score'.
+        """
+        assert average_type in ['score', 'prob'], \
+            f"Currently only the average of 'score' or 'prob' is supported, but got {average_type}"
+        if average_type == 'score':
+            return paddle.add_n(cls_score) / len(cls_score)
+        elif average_type == 'prob':
+            return paddle.add_n([F.softmax(score)
+                                 for score in cls_score]) / len(cls_score)
+        else:
+            raise NotImplementedError
diff --git a/tools/export_model.py b/tools/export_model.py
index 3e8c17463..cf411c876 100644
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -77,9 +77,10 @@ def get_input_spec(cfg, model_name):
         ]]
     elif model_name in ['TimeSformer']:
         input_spec = [[
-            InputSpec(
-                shape=[None, 3, cfg.num_seg, cfg.target_size, cfg.target_size],
-                dtype='float32'),
+            InputSpec(shape=[
+                None, 3, cfg.num_seg * 3, cfg.target_size, cfg.target_size
+            ],
+                      dtype='float32'),
         ]]
     elif model_name in ['AttentionLSTM']:
         input_spec = [[
diff --git a/tools/predict.py b/tools/predict.py
index 485ef8b1e..64a23f177 100644
--- a/tools/predict.py
+++ b/tools/predict.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 
 import argparse
-import numpy as np
+import os
 import time
 
+import numpy as np
+from paddle import inference
+from paddle.inference import Config, create_predictor
+
 from utils import build_inference_helper
-from paddle.inference import Config
-from paddle.inference import create_predictor
 from paddlevideo.utils import get_config
 
 
@@ -40,12 +42,13 @@ def str2bool(v):
     # params for predict
     parser.add_argument("-b", "--batch_size", type=int, default=1)
     parser.add_argument("--use_gpu", type=str2bool, default=True)
-    parser.add_argument("--use_fp16", type=str2bool, default=False)
+    parser.add_argument("--precision", type=str, default="fp32")
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--gpu_mem", type=int, default=8000)
     parser.add_argument("--enable_benchmark", type=str2bool, default=False)
     parser.add_argument("--enable_mkldnn", type=bool, default=False)
+    parser.add_argument("--cpu_threads", type=int, default=None)
     # parser.add_argument("--hubserving", type=str2bool, default=False)  #TODO
 
     return parser.parse_args()
@@ -58,42 +61,46 @@ def create_paddle_predictor(args):
         config.enable_use_gpu(args.gpu_mem, 0)
     else:
         config.disable_gpu()
+        if args.cpu_threads is not None:
+            config.set_cpu_math_library_num_threads(args.cpu_threads)
         if args.enable_mkldnn:
             # cache 10 different shapes for mkldnn to avoid memory leak
             config.set_mkldnn_cache_capacity(10)
             config.enable_mkldnn()
+            if args.precision == "fp16":
+                config.enable_mkldnn_bfloat16()
 
     #config.disable_glog_info()
     config.switch_ir_optim(args.ir_optim)  # default true
     if args.use_tensorrt:
-        config.enable_tensorrt_engine(
-            precision_mode=Config.Precision.Half
-            if args.use_fp16 else Config.Precision.Float32,
-            max_batch_size=args.batch_size)
+        # choose precision
+        if args.precision == "fp16":
+            precision = inference.PrecisionType.Half
+        elif args.precision == "int8":
+            precision = inference.PrecisionType.Int8
+        else:
+            precision = inference.PrecisionType.Float32
+
+        config.enable_tensorrt_engine(precision_mode=precision,
+                                      max_batch_size=args.batch_size)
 
     config.enable_memory_optim()
     # use zero copy
     config.switch_use_feed_fetch_ops(False)
     predictor = create_predictor(config)
 
-    return predictor
+    return config, predictor
 
 
 def main():
     args = parse_args()
     cfg = get_config(args.config, show=False)
+    model_name = cfg.model_name
     print(f"Inference model({model_name})...")
     InferenceHelper = build_inference_helper(cfg.INFERENCE)
 
-    if args.enable_benchmark:
-        assert args.use_gpu is True
-
-    # HALF precission predict only work when using tensorrt
-    if args.use_fp16 is True:
-        assert args.use_tensorrt is True
-
-    predictor = create_paddle_predictor(args)
+    inference_config, predictor = create_paddle_predictor(args)
 
     # get input_tensor and output_tensor
     input_names = predictor.get_input_names()
@@ -105,10 +112,8 @@ def main():
     for item in output_names:
         output_tensor_list.append(predictor.get_output_handle(item))
 
-    test_num = 500
-    test_time = 0.0
     if not args.enable_benchmark:
-        # Prepare input
+        # Pre process input
        inputs = InferenceHelper.preprocess(args.input_file)
 
         # Run inference
@@ -121,15 +126,68 @@ def main():
 
         # Post process output
         InferenceHelper.postprocess(output)
-    else:  # benchmark only for ppTSM
-        for i in range(0, test_num + 10):
-            inputs = []
-            inputs.append(
-                np.random.rand(args.batch_size, 8, 3, 224,
-                               224).astype(np.float32))
+
+        # fp_message = "FP16" if args.use_fp16 else "FP32"
+        # trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
+        # print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
+        #     model_name, trt_msg, fp_message, args.batch_size,
+        #     1000 * test_time / test_num))
+    else:
+        test_num = 500
+        test_time = 0.0
+        log_interval = 20
+        num_warmup = 10
+
+        # instantiate auto log
+        import auto_log
+        pid = os.getpid()
+        autolog = auto_log.AutoLogger(
+            model_name=cfg.model_name,
+            model_precision=args.precision,
+            batch_size=args.batch_size,
+            data_shape="dynamic",
+            save_path="./output/auto_log.log",
+            inference_config=inference_config,
+            pids=pid,
+            process_name=None,
+            gpu_ids=0,
+            time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
+            warmup=num_warmup)
+
+        for i in range(0, test_num + num_warmup):
+            if (i + 1) % log_interval == 0 or (i + 1) == test_num + num_warmup:
+                print(f"Benchmark process {i + 1}/{test_num + num_warmup}")
+            input_list = []
+            start_time = time.time()
+            # auto log start
+            if args.enable_benchmark:
+                autolog.times.start()
+
+            # Pre process input
+            batched_inputs_list = []
+            batch_count = 0
+            while batch_count < args.batch_size:
+                inputs = InferenceHelper.preprocess(args.input_file)
+                batched_inputs_list.append(inputs)
+                if 'tsm' in cfg.model_name.lower():
+                    batch_count += (inputs[0].shape[1] * 1)  # centercrop
+                elif 'tsn' in cfg.model_name.lower():
+                    batch_count += (inputs[0].shape[1] * 10)  # tencrop
+                elif 'timesformer' in cfg.model_name.lower():
+                    batch_count += (inputs[0].shape[2] * 3)  # threecrop
+                else:
+                    batch_count += inputs[0].shape[0]
+
+            batched_inputs = np.concatenate(batched_inputs_list, axis=0)
+            input_list.extend(batched_inputs)
+
+            # get pre process time cost
+            if args.enable_benchmark:
+                autolog.times.stamp()
+
             for j in range(len(input_tensor_list)):
-                input_tensor_list[j].copy_from_cpu(inputs[j])
+                input_tensor_list[j].copy_from_cpu(input_list[j])
 
             predictor.run()
 
@@ -137,15 +195,23 @@ def main():
             for j in range(len(output_tensor_list)):
                 output.append(output_tensor_list[j].copy_to_cpu())
 
+            # get inference process time cost
+            if args.enable_benchmark:
+                autolog.times.stamp()
+
+            InferenceHelper.postprocess(output, False)
+
+            # get post process time cost
+            if args.enable_benchmark:
+                autolog.times.end(stamp=True)
+
             if i >= 10:
                 test_time += time.time() - start_time
-            #time.sleep(0.01) # sleep for T4 GPU
+            # time.sleep(0.01)  # sleep for T4 GPU
 
-        fp_message = "FP16" if args.use_fp16 else "FP32"
-        trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
-        print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
-            model_name, trt_msg, fp_message, args.batch_size,
-            1000 * test_time / test_num))
+        # report benchmark log if enabled
+        if args.enable_benchmark:
+            autolog.report()
 
 
 if __name__ == "__main__":
diff --git a/tools/utils.py b/tools/utils.py
index 75c5a9f73..5922698be 100644
--- a/tools/utils.py
+++ b/tools/utils.py
@@ -35,7 +35,6 @@
 from paddlevideo.metrics.bmn_metric import boundary_choose, soft_nms
 from paddlevideo.utils import Registry, build
 
-
 INFERENCE = Registry('inference')
 
 
@@ -141,7 +140,7 @@ def preprocess(self, input_file):
         res = np.expand_dims(results['imgs'], axis=0).copy()
         return [res]
 
-    def postprocess(self, output):
+    def postprocess(self, output, print_output=True):
         """
         output: list
         """
@@ -150,9 +149,10 @@ def postprocess(self, output):
         classes = np.argpartition(output, -self.top_k)[-self.top_k:]
         classes = classes[np.argsort(-output[classes])]
         scores = output[classes]
-        print("Current video file: {0}".format(self.input_file))
-        print("\ttop-1 class: {0}".format(classes[0]))
-        print("\ttop-1 score: {0}".format(scores[0]))
+        if print_output:
+            print("Current video file: {0}".format(self.input_file))
+            print("\ttop-1 class: {0}".format(classes[0]))
+            print("\ttop-1 score: {0}".format(scores[0]))
{0}".format(scores[0])) + if print_output: + print("Current video file: {0}".format(self.input_file)) + print("\ttop-1 class: {0}".format(classes[0])) + print("\ttop-1 score: {0}".format(scores[0])) @INFERENCE.register() @@ -200,7 +200,7 @@ def preprocess(self, input_file): res = np.expand_dims(results['imgs'], axis=0).copy() return [res] - def postprocess(self, output): + def postprocess(self, output, print_output=True): """ output: list """ @@ -215,9 +215,10 @@ def postprocess(self, output): classes = np.argpartition(output, -self.top_k)[-self.top_k:] classes = classes[np.argsort(-output[classes])] scores = output[classes] - print("Current video file: {0}".format(self.input_file)) - print("\ttop-1 class: {0}".format(classes[0])) - print("\ttop-1 score: {0}".format(scores[0])) + if print_output: + print("Current video file: {0}".format(self.input_file)) + print("\ttop-1 class: {0}".format(classes[0])) + print("\ttop-1 score: {0}".format(scores[0])) @INFERENCE.register() @@ -252,7 +253,7 @@ def postprocess(self, outputs): pred_bm, pred_start, pred_end = outputs self._gen_props(pred_bm, pred_start[0], pred_end[0]) - def _gen_props(self, pred_bm, pred_start, pred_end): + def _gen_props(self, pred_bm, pred_start, pred_end, print_output=True): snippet_xmins = [1.0 / self.tscale * i for i in range(self.tscale)] snippet_xmaxs = [ 1.0 / self.tscale * i for i in range(1, self.tscale + 1) @@ -294,9 +295,10 @@ def _gen_props(self, pred_bm, pred_start, pred_end): result_dict[self.feat_path] = proposal_list # print top-5 predictions - print("BMN Inference results of {0} :".format(self.feat_path)) - for pred in proposal_list[:5]: - print(pred) + if print_output: + print("BMN Inference results of {0} :".format(self.feat_path)) + for pred in proposal_list[:5]: + print(pred) # save result outfile = open( @@ -344,10 +346,11 @@ def preprocess(self, input_file): for op in ops: results = op(results) + # [N,C,Tx3,H,W] res = np.expand_dims(results['imgs'], axis=0).copy() return [res] - def postprocess(self, output): + def postprocess(self, output, print_output=True): """ output: list """ @@ -362,9 +365,10 @@ def postprocess(self, output): classes = np.argpartition(output, -self.top_k)[-self.top_k:] classes = classes[np.argsort(-output[classes])] scores = output[classes] - print("Current video file: {0}".format(self.input_file)) - print("\ttop-1 class: {0}".format(classes[0])) - print("\ttop-1 score: {0}".format(scores[0])) + if print_output: + print("Current video file: {0}".format(self.input_file)) + print("\ttop-1 class: {0}".format(classes[0])) + print("\ttop-1 score: {0}".format(scores[0])) @INFERENCE.register() @@ -480,13 +484,14 @@ def postprocess(self, output): @INFERENCE.register() class AttentionLSTM_Inference_helper(): - def __init__(self, - num_classes, #Optional, the number of classes to be classified. - feature_num, - feature_dims, - embedding_size, - lstm_size, - top_k=1): + def __init__( + self, + num_classes, #Optional, the number of classes to be classified. 
+            feature_num,
+            feature_dims,
+            embedding_size,
+            lstm_size,
+            top_k=1):
         self.num_classes = num_classes
         self.feature_num = feature_num
         self.feature_dims = feature_dims
@@ -503,20 +508,21 @@ def preprocess(self, input_file):
         assert os.path.isfile(input_file) is not None, "{0} not exists".format(
             input_file)
         results = {'filename': input_file}
-        ops = [
-            FeatureDecoder(num_classes=self.num_classes, has_label=False)
-        ]
+        ops = [FeatureDecoder(num_classes=self.num_classes, has_label=False)]
         for op in ops:
             results = op(results)
 
         res = []
         for modality in ['rgb', 'audio']:
-            res.append(np.expand_dims(results[f'{modality}_data'], axis=0).copy())
-            res.append(np.expand_dims(results[f'{modality}_len'], axis=0).copy())
-            res.append(np.expand_dims(results[f'{modality}_mask'], axis=0).copy())
+            res.append(
+                np.expand_dims(results[f'{modality}_data'], axis=0).copy())
+            res.append(
+                np.expand_dims(results[f'{modality}_len'], axis=0).copy())
+            res.append(
+                np.expand_dims(results[f'{modality}_mask'], axis=0).copy())
         return res
 
-    def postprocess(self, output):
+    def postprocess(self, output, print_output=True):
         """
         output: list
         """
@@ -531,6 +537,7 @@ def postprocess(self, output):
         classes = np.argpartition(output, -self.top_k)[-self.top_k:]
         classes = classes[np.argsort(-output[classes])]
         scores = output[classes]
-        print("Current video file: {0}".format(self.input_file))
-        print("\ttop-1 class: {0}".format(classes[0]))
-        print("\ttop-1 score: {0}".format(scores[0]))
+        if print_output:
+            print("Current video file: {0}".format(self.input_file))
+            print("\ttop-1 class: {0}".format(classes[0]))
+            print("\ttop-1 score: {0}".format(scores[0]))
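
Reviewer note (trailing note, not part of the diff): the `average_view` helper added in recognizer_transformer.py is the core behavioral change for multi-view testing, so a minimal self-contained sketch of the same fusion logic is given below. The 3 views, batch size 2, and 400-class head are illustrative assumptions only, not values taken from this patch:

    import paddle
    import paddle.nn.functional as F

    def average_view(cls_score, average_type='score'):
        # Fuse per-view predictions: mean of raw logits ('score'),
        # or mean of per-view softmax probabilities ('prob').
        if average_type == 'score':
            return paddle.add_n(cls_score) / len(cls_score)
        elif average_type == 'prob':
            return paddle.add_n([F.softmax(s) for s in cls_score]) / len(cls_score)
        raise NotImplementedError(average_type)

    # e.g. 3 temporal views of a batch of 2 clips, 400 classes
    views = [paddle.rand([2, 400]) for _ in range(3)]
    fused = average_view(views, average_type='prob')
    print(fused.shape)  # [2, 400]

Averaging logits keeps ranking identical to plain summation while keeping scores on a per-view scale; averaging probabilities instead normalizes each view before fusion, which matters when views have different confidence magnitudes.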