Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference API for multiple object tracking #51

Merged
merged 4 commits into from
Jan 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions demo/demo_mot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import os.path as osp
import tempfile
from argparse import ArgumentParser

import mmcv

from mmtrack.apis import inference_mot, init_model


def main():
    """Run the multiple object tracking demo from the command line.

    The input (``-i/--input``) is either a folder of frames, processed in
    sorted file-name order, or a video file read with ``mmcv.VideoReader``.
    Results are shown on the fly (``--show``) and/or saved (``-o/--output``).
    When the output path ends with ``.mp4`` the rendered frames are first
    written into a temporary folder and then assembled into a video;
    otherwise the output path is treated as a folder of rendered frames.

    Raises:
        ValueError: If an ``.mp4`` output is requested for a folder input
            without ``--fps`` (there is no source video to take it from).
    """
    parser = ArgumentParser()
    parser.add_argument('config', help='config file')
    # Required: the script cannot do anything without an input, and a
    # missing value would otherwise only fail later inside ``osp.isdir``.
    parser.add_argument(
        '-i', '--input', required=True, help='input video file or folder')
    parser.add_argument(
        '-o', '--output', help='output video file (mp4 format) or folder')
    parser.add_argument('--checkpoint', help='checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='device used for inference')
    parser.add_argument(
        '--show',
        action='store_true',
        help='whether show the results on the fly')
    parser.add_argument(
        '--backend',
        choices=['cv2', 'plt'],
        default='cv2',
        help='the backend to visualize the results')
    # Parse the FPS up front so an invalid value is rejected before the
    # (expensive) inference loop instead of at the final ``int(fps)``.
    parser.add_argument('--fps', type=int, help='FPS of the output video')
    args = parser.parse_args()
    if not (args.output or args.show):
        parser.error('at least one of -o/--output and --show is required')
    if args.fps is not None and args.fps <= 0:
        parser.error('--fps must be a positive integer')

    # load images
    in_video = not osp.isdir(args.input)
    if in_video:
        imgs = mmcv.VideoReader(args.input)
    else:
        imgs = sorted(os.listdir(args.input))

    # define output; every name used later gets a default so that running
    # with ``--show`` only (no output) does not hit unbound variables
    out_video = False
    out_path = None
    temp_dir = None
    fps = None
    if args.output:
        out_video = args.output.endswith('.mp4')
        if out_video:
            if (not in_video) and args.fps is None:
                raise ValueError('Please set the FPS for the output video.')
            fps = args.fps if args.fps is not None else imgs.fps
            temp_dir = tempfile.TemporaryDirectory()
            out_path = temp_dir.name
            # Only create the parent folder when the path actually has one;
            # a bare 'out.mp4' must not become a directory of that name.
            parent_dir = osp.dirname(args.output)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
        else:
            out_path = args.output
            os.makedirs(out_path, exist_ok=True)

    try:
        # build the model from a config file and a checkpoint file
        model = init_model(args.config, args.checkpoint, device=args.device)

        prog_bar = mmcv.ProgressBar(len(imgs))
        # test and show/save the images
        for i, img in enumerate(imgs):
            if isinstance(img, str):
                img = osp.join(args.input, img)
            result = inference_mot(model, img, frame_id=i)
            result = result['track_results']
            if out_path is None:
                out_file = None
            elif in_video or out_video:
                # frames get index-based names so they sort in frame order
                out_file = osp.join(out_path, f'{i:06d}.jpg')
            else:
                # folder in, folder out: keep the original file name
                out_file = osp.join(out_path, osp.basename(img))
            model.show_result(
                img,
                result,
                show=args.show,
                out_file=out_file,
                backend=args.backend)
            prog_bar.update()

        if out_video:
            print(f'making the output video at {args.output} '
                  f'with a FPS of {fps}')
            mmcv.frames2video(out_path, args.output, fps=int(fps))
    finally:
        # remove the temporary frames even if inference failed midway
        if temp_dir is not None:
            temp_dir.cleanup()


# Only run the demo when executed as a script, not when imported.
if __name__ == '__main__':
    main()
6 changes: 3 additions & 3 deletions demo/sot_demo.py → demo/demo_sot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import cv2

from mmtrack.apis import init_sot_model, sot_inference
from mmtrack.apis import inference_sot, init_model


def main():
Expand All @@ -30,7 +30,7 @@ def main():
args = parser.parse_args()

# build the model from a config file and a checkpoint file
model = init_sot_model(args.config, args.checkpoint, device=args.device)
model = init_model(args.config, args.checkpoint, device=args.device)

cap = cv2.VideoCapture(args.video)

Expand Down Expand Up @@ -61,7 +61,7 @@ def main():
init_bbox[3] += init_bbox[1]

# test a single image
result = sot_inference(model, frame, init_bbox, frame_id)
result = inference_sot(model, frame, init_bbox, frame_id)

track_bbox = result['bbox']
cv2.rectangle(
Expand Down
6 changes: 3 additions & 3 deletions demo/vid_demo.py → demo/demo_vid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import cv2

from mmtrack.apis import init_vid_model, vid_inference
from mmtrack.apis import inference_vid, init_model


def main():
Expand All @@ -28,7 +28,7 @@ def main():
args = parser.parse_args()

# build the model from a config file and a checkpoint file
model = init_vid_model(args.config, args.checkpoint, device=args.device)
model = init_model(args.config, args.checkpoint, device=args.device)

cap = cv2.VideoCapture(args.video)

Expand All @@ -53,7 +53,7 @@ def main():
break

# test a single image
result = vid_inference(model, frame, frame_id)
result = inference_vid(model, frame, frame_id)
vis_frame = model.show_result(
frame, result, score_thr=args.score_thr, show=False)

Expand Down
7 changes: 3 additions & 4 deletions mmtrack/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .sot_inference import init_sot_model, sot_inference
from .inference import inference_mot, inference_sot, inference_vid, init_model
from .test import multi_gpu_test, single_gpu_test
from .train import train_model
from .vid_inference import init_vid_model, vid_inference

__all__ = [
'multi_gpu_test', 'single_gpu_test', 'train_model', 'init_sot_model',
'sot_inference', 'init_vid_model', 'vid_inference'
'init_model', 'multi_gpu_test', 'single_gpu_test', 'train_model',
'inference_mot', 'inference_sot', 'inference_vid'
]
116 changes: 97 additions & 19 deletions mmtrack/apis/vid_inference.py → mmtrack/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,26 @@
import warnings

import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from mmdet.core import get_classes
from mmdet.datasets.pipelines import Compose

from mmtrack.models import build_model


def init_vid_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
"""Initialize a video object detector from config file.
def init_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
"""Initialize a model from config file.

Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights. Defaults to None.
device (str, optional): The device where the model is put on. Defaults
to 'cuda:0'
cfg_options (dict, optional): Options to override some settings in the
used config. Defaults to None.
checkpoint (str, optional): Checkpoint path. Default as None.
cfg_options (dict, optional): Options to override some settings in
the used config. Default to None.

Returns:
nn.Module: The constructed video object detector.
nn.Module: The constructed detector.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
Expand All @@ -35,26 +29,110 @@ def init_vid_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
f'but got {type(config)}')
if cfg_options is not None:
config.merge_from_dict(cfg_options)
config.model.pretrains = None
config.model.detector.pretrained = None
if 'detector' in config.model:
config.model.detector.pretrained = None
model = build_model(config.model)
if checkpoint is not None:
map_loc = 'cpu' if device == 'cpu' else None
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
if 'CLASSES' in checkpoint['meta']:
model.CLASSES = checkpoint['meta']['CLASSES']
if not hasattr(model, 'CLASSES'):
if hasattr(model.detector, 'CLASSES'):
model.CLASSES = model.detector.CLASSES
else:
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use ImageNet VID classes by default.')
model.CLASSES = get_classes('imagenet_vid')
raise KeyError('The classes must be defined.')
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model


def vid_inference(model,
def inference_mot(model, img, frame_id):
    """Inference image(s) with the mot model.

    The test pipeline stored in ``model.cfg.data.test.pipeline`` is used to
    preprocess the input. For an already loaded image the first pipeline
    step is switched to ``LoadImageFromWebcam``; this is done on a private
    copy so ``model.cfg`` is left untouched and later calls with an image
    path still use the configured loader.

    Args:
        model (nn.Module): The loaded mot model. It must expose the config
            it was built from as ``model.cfg``.
        img (str | ndarray): Either image name or loaded image.
        frame_id (int): frame id.

    Returns:
        dict[str : ndarray]: The tracking results.

    Raises:
        AssertionError: If the model is on CPU and contains a ``RoIPool``
            module.
    """
    # Local import keeps this function self-contained.
    import copy

    first_param = next(model.parameters())
    device = first_param.device  # model device
    pipeline_cfg = model.cfg.data.test.pipeline
    # prepare data
    if isinstance(img, np.ndarray):
        # directly add img
        data = dict(img=img, img_info=dict(frame_id=frame_id), img_prefix=None)
        # A shallow config copy would share the nested pipeline entries, so
        # deep-copy the pipeline before changing its loading step.
        pipeline_cfg = copy.deepcopy(pipeline_cfg)
        # set loading pipeline type
        pipeline_cfg[0]['type'] = 'LoadImageFromWebcam'
    else:
        # add information into dict
        data = dict(
            img_info=dict(filename=img, frame_id=frame_id), img_prefix=None)
    # build the data pipeline
    test_pipeline = Compose(pipeline_cfg)
    data = test_pipeline(data)
    data = collate([data], samples_per_gpu=1)
    if first_param.is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device])[0]
    else:
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'
        # just get the actual data from DataContainer
        data['img_metas'] = data['img_metas'][0].data
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result


def inference_sot(model, image, init_bbox, frame_id):
    """Run the single object tracker on one frame.

    The frame and the initial box are converted to ``float32`` and passed
    through the test pipeline from ``model.cfg``, skipping its first two
    steps (per the original note, these are expected to be
    "LoadImageFromFile" and "LoadAnnotations").

    Args:
        model (nn.Module): The loaded tracker, carrying its config as
            ``model.cfg``.
        image (ndarray): Loaded images.
        init_bbox (ndarray): The target needs to be tracked.
        frame_id (int): frame id.

    Returns:
        dict[str : ndarray]: The tracking results.
    """
    pipeline_steps = model.cfg.data.test.pipeline[2:]
    model_param = next(model.parameters())

    sample = dict(
        img=image.astype(np.float32),
        gt_bboxes=np.array(init_bbox).astype(np.float32),
        img_info=dict(frame_id=frame_id))
    preprocess = Compose(pipeline_steps)
    batch = collate([preprocess(sample)], samples_per_gpu=1)

    if model_param.is_cuda:
        # move the batch onto the GPU that holds the model
        batch = scatter(batch, [model_param.device])[0]
    else:
        has_roi_pool = any(
            isinstance(module, RoIPool) for module in model.modules())
        assert not has_roi_pool, (
            'CPU inference with RoIPool is not supported currently.')
        # unwrap the image metas from their DataContainer
        batch['img_metas'] = batch['img_metas'][0].data

    # forward pass without gradient tracking
    with torch.no_grad():
        return model(return_loss=False, rescale=True, **batch)


def inference_vid(model,
image,
frame_id,
ref_img_sampler=dict(frame_stride=10, num_left_ref_imgs=10)):
Expand Down
83 changes: 0 additions & 83 deletions mmtrack/apis/sot_inference.py

This file was deleted.

3 changes: 2 additions & 1 deletion mmtrack/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .image import crop_image
from .visualization import imshow_tracks

__all__ = ['crop_image']
__all__ = ['crop_image', 'imshow_tracks']
Loading