[Fix] Refactor demo (#8423)
* add enable auto_scale_lr in train.py

* add auto_scale_lr setting in default_runtime

* move auto_scale_lr to schedule

* replace warning with error

* update cascade rcnn metafile

* show = False when args.out_file is not None

* show = False when args.out_file is not None

* refactor webcam_demo

* fix typos

* refactor video_demo

* refactor video_gpuaccel_demo

* fix typos

* fix typos

* add todo

* add todo

* update inference_demo

* update get_started

* update faq

* remove redundant code

* revert inference_demo

* fix bug in webcam_demo

* revert docs

* add todo

* fix bug in webcam_demo

* add 2x8 in config file name for lad

* remove comments

* add batch_size

* fix typos

* add test_pipeline as arg in inference_detector

* build test_pipeline outside inference_detector

* add comments that dataset_meta is from checkpoint

* add assert that async inference is not supported yet
chhluo authored and ZwwWayne committed Aug 17, 2022
1 parent 4dd2c3d commit 5f54d7e
Showing 11 changed files with 141 additions and 55 deletions.
1 change: 0 additions & 1 deletion configs/faster_rcnn/metafile.yml
@@ -276,7 +276,6 @@ Models:
Dataset: COCO
Metrics:
box AP: 37.9
# re-release
Weights: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_iou_1x_coco/faster_rcnn_r50_fpn_iou_1x_coco_20200506_095954-938e81f0.pth

- Name: faster_rcnn_r50_fpn_giou_1x_coco
13 changes: 7 additions & 6 deletions configs/lad/README.md
@@ -20,17 +20,18 @@ Distillation.

### PAA with LAD

| Teacher | Student | Training schedule | AP (val) | Config | Download |
| :-----: | :-----: | :---------------: | :------: | :---------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| -- | R-50 | 1x | 40.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/paa/paa_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.log.json) |
| -- | R-101 | 1x | 42.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/paa/paa_r101_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.log.json) |
| R-101 | R-50 | 1x | 41.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/lad/lad_r50_paa_r101_fpn_coco_1x.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r50_paa_r101_fpn_coco_1x/lad_r50_paa_r101_fpn_coco_1x_20220708_124246-74c76ff0.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r50_paa_r101_fpn_coco_1x/lad_r50_paa_r101_fpn_coco_1x_20220708_124246.log.json) |
| R-50 | R-101 | 1x | 43.2 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/lad/lad_r101_paa_r50_fpn_coco_1x.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r101_paa_r50_fpn_coco_1x/lad_r101_paa_r50_fpn_coco_1x_20220708_124357-9407ac54.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r101_paa_r50_fpn_coco_1x/lad_r101_paa_r50_fpn_coco_1x_20220708_124357.log.json) |
| Teacher | Student | Training schedule | AP (val) | Config | Download |
| :-----: | :-----: | :---------------: | :------: | :-------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| -- | R-50 | 1x | 40.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/paa/paa_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.log.json) |
| -- | R-101 | 1x | 42.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/paa/paa_r101_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.log.json) |
| R-101 | R-50 | 1x | 41.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/lad/lad_r50_paa_r101_fpn_2x8_coco_1x.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r50_paa_r101_fpn_coco_1x/lad_r50_paa_r101_fpn_coco_1x_20220708_124246-74c76ff0.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r50_paa_r101_fpn_coco_1x/lad_r50_paa_r101_fpn_coco_1x_20220708_124246.log.json) |
| R-50 | R-101 | 1x | 43.2 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/lad/lad_r101_paa_r50_fpn_2x8_coco_1x.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r101_paa_r50_fpn_coco_1x/lad_r101_paa_r50_fpn_coco_1x_20220708_124357-9407ac54.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r101_paa_r50_fpn_coco_1x/lad_r101_paa_r50_fpn_coco_1x_20220708_124357.log.json) |

## Note

- Meaning of Config name: lad_r50(student model)\_paa(based on PAA)\_r101(teacher model)\_fpn(neck)\_2x8(2 GPUs × 8 samples per GPU)\_coco(dataset)\_1x(12 epochs).py
- Results may fluctuate by about 0.2 mAP.
- 2 GPUs are used, 8 samples per GPU.

## Citation

File renamed without changes.
File renamed without changes.
8 changes: 4 additions & 4 deletions configs/lad/metafile.yml
@@ -19,9 +19,9 @@ Collections:
Version: v2.19.0

Models:
- Name: lad_r101_paa_r50_fpn_coco_1x
- Name: lad_r101_paa_r50_fpn_2x8_coco_1x
In Collection: Label Assignment Distillation
Config: configs/lad/lad_r101_paa_r50_fpn_coco_1x.py
Config: configs/lad/lad_r101_paa_r50_fpn_2x8_coco_1x.py
Metadata:
Training Memory (GB): 12.4
Epochs: 12
@@ -31,9 +31,9 @@ Models:
Metrics:
box AP: 43.2
Weights: https://download.openmmlab.com/mmdetection/v2.0/lad/lad_r101_paa_r50_fpn_coco_1x/lad_r101_paa_r50_fpn_coco_1x_20220708_124357-9407ac54.pth
- Name: lad_r50_paa_r101_fpn_coco_1x
- Name: lad_r50_paa_r101_fpn_2x8_coco_1x
In Collection: Label Assignment Distillation
Config: configs/lad/lad_r50_paa_r101_fpn_coco_1x.py
Config: configs/lad/lad_r50_paa_r101_fpn_2x8_coco_1x.py
Metadata:
Training Memory (GB): 8.9
Epochs: 12
1 change: 1 addition & 0 deletions demo/create_result_gif.py
@@ -14,6 +14,7 @@
imageio = None


# TODO verify after refactoring analyze_results.py
def parse_args():
parser = argparse.ArgumentParser(description='Create GIF for demo')
parser.add_argument(
7 changes: 5 additions & 2 deletions demo/image_demo.py
@@ -44,6 +44,8 @@ def main(args):

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector
visualizer.dataset_meta = model.dataset_meta

# test a single image
@@ -56,7 +58,7 @@
'result',
img,
pred_sample=result,
show=True,
show=args.out_file is None,
wait_time=0,
out_file=args.out_file,
pred_score_thr=args.score_thr)
@@ -80,14 +82,15 @@ async def async_main(args):
'result',
img,
pred_sample=result[0],
show=True,
show=args.out_file is None,
wait_time=0,
out_file=args.out_file,
pred_score_thr=args.score_thr)


if __name__ == '__main__':
args = parse_args()
assert not args.async_test, 'async inference is not supported yet.'
if args.async_test:
asyncio.run(async_main(args))
else:
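All the refactored demos now share one visualization pattern: build the visualizer from the model config, hand it the `dataset_meta` recovered from the checkpoint, then draw with `add_datasample` and read the frame back with `get_image`. A condensed sketch of that pattern, assuming an MMDetection 3.x environment (the config and checkpoint paths are placeholders):

```python
# Condensed sketch of the shared visualizer pattern; the config and
# checkpoint paths are placeholders, not files from this commit.
import mmcv

from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.utils import register_all_modules

register_all_modules()
model = init_detector('path/to/config.py', 'path/to/checkpoint.pth',
                      device='cuda:0')

visualizer = VISUALIZERS.build(model.cfg.visualizer)
# dataset_meta (class names, palette) comes from the checkpoint via
# init_detector
visualizer.dataset_meta = model.dataset_meta

img = mmcv.imread('demo/demo.jpg', channel_order='rgb')
result = inference_detector(model, img)
visualizer.add_datasample(
    'result', img, pred_sample=result, show=False, pred_score_thr=0.3)
drawn = visualizer.get_image()  # ndarray, ready for cv2 or a video writer
```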
28 changes: 26 additions & 2 deletions demo/video_demo.py
@@ -3,8 +3,11 @@

import cv2
import mmcv
from mmcv.transforms import Compose

from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.utils import register_all_modules


def parse_args():
@@ -33,8 +36,22 @@ def main():
('Please specify at least one operation (save/show the '
'video) with the argument "--out" or "--show"')

# register all modules in mmdet into the registries
register_all_modules()

# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)

# build test pipeline
model.cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector
visualizer.dataset_meta = model.dataset_meta

video_reader = mmcv.VideoReader(args.video)
video_writer = None
if args.out:
@@ -44,8 +61,15 @@
(video_reader.width, video_reader.height))

for frame in mmcv.track_iter_progress(video_reader):
result = inference_detector(model, frame)
frame = model.show_result(frame, result, score_thr=args.score_thr)
result = inference_detector(model, frame, test_pipeline=test_pipeline)
visualizer.add_datasample(
name='video',
image=frame,
pred_sample=result,
show=False,
pred_score_thr=args.score_thr)
frame = visualizer.get_image()

if args.show:
cv2.namedWindow('video', 0)
mmcv.imshow(frame, 'video', args.wait_time)
89 changes: 60 additions & 29 deletions demo/video_gpuaccel_demo.py
@@ -5,10 +5,14 @@
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.transforms import Compose
from torchvision.transforms import functional as F

from mmdet.apis import init_detector
from mmdet.registry import VISUALIZERS
from mmdet.structures import DetDataSample
from mmdet.utils import register_all_modules
from mmdet.utils.typing import Tuple

try:
import ffmpegcv
@@ -40,25 +44,33 @@ def parse_args():
return args


def prefetch_img_metas(cfg, ori_wh):
def prefetch_batch_input_shape(model: nn.Module,
                               ori_wh: Tuple[int, int]) -> Tuple[int, int]:
cfg = model.cfg
w, h = ori_wh
cfg.data.test.pipeline[0].type = 'LoadImageFromNDArray'
test_pipeline = Compose(cfg.data.test.pipeline)
data = {'img': np.zeros((h, w, 3), dtype=np.uint8)}
cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
data = {'img': np.zeros((h, w, 3), dtype=np.uint8), 'img_id': 0}
data = test_pipeline(data)
img_metas = data['img_metas'][0].data
return img_metas


def process_img(frame_resize, img_metas, device):
assert frame_resize.shape == img_metas['pad_shape']
frame_cuda = torch.from_numpy(frame_resize).to(device).float()
frame_cuda = frame_cuda.permute(2, 0, 1) # HWC to CHW
mean = torch.from_numpy(img_metas['img_norm_cfg']['mean']).to(device)
std = torch.from_numpy(img_metas['img_norm_cfg']['std']).to(device)
frame_cuda = F.normalize(frame_cuda, mean=mean, std=std, inplace=True)
frame_cuda = frame_cuda[None, :, :, :] # NCHW
data = {'img': [frame_cuda], 'img_metas': [[img_metas]]}
_, data_sample = model.data_preprocessor([data], False)
batch_input_shape = data_sample[0].batch_input_shape
return batch_input_shape


def pack_data(frame_resize: np.ndarray, batch_input_shape: Tuple[int, int],
ori_shape: Tuple[int, int]) -> dict:
assert frame_resize.shape[:2] == batch_input_shape
data_sample = DetDataSample()
data_sample.set_metainfo({
'img_shape':
batch_input_shape,
'ori_shape':
ori_shape,
'scale_factor': (batch_input_shape[0] / ori_shape[0],
batch_input_shape[1] / ori_shape[1])
})
frame_resize = torch.from_numpy(frame_resize).permute((2, 0, 1))
data = {'inputs': frame_resize, 'data_sample': data_sample}
return data


@@ -68,33 +80,52 @@ def main():
('Please specify at least one operation (save/show the '
'video) with the argument "--out" or "--show"')

# register all modules in mmdet into the registries
register_all_modules()

model = init_detector(args.config, args.checkpoint, device=args.device)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector
visualizer.dataset_meta = model.dataset_meta

if args.nvdecode:
VideoCapture = ffmpegcv.VideoCaptureNV
else:
VideoCapture = ffmpegcv.VideoCapture
video_origin = VideoCapture(args.video)
img_metas = prefetch_img_metas(model.cfg,
(video_origin.width, video_origin.height))
resize_wh = img_metas['pad_shape'][1::-1]

batch_input_shape = prefetch_batch_input_shape(
model, (video_origin.width, video_origin.height))
ori_shape = (video_origin.height, video_origin.width)
resize_wh = batch_input_shape[::-1]
video_resize = VideoCapture(
args.video,
resize=resize_wh,
resize_keepratio=True,
resize_keepratioalign='topleft',
pix_fmt='rgb24')
resize_keepratioalign='topleft')

video_writer = None
if args.out:
video_writer = ffmpegcv.VideoWriter(args.out, fps=video_origin.fps)

with torch.no_grad():
for frame_resize, frame_origin in zip(
mmcv.track_iter_progress(video_resize), video_origin):
data = process_img(frame_resize, img_metas, args.device)
result = model(return_loss=False, rescale=True, **data)[0]
frame_mask = model.show_result(
frame_origin, result, score_thr=args.score_thr)
for i, (frame_resize, frame_origin) in enumerate(
zip(mmcv.track_iter_progress(video_resize), video_origin)):
data = pack_data(frame_resize, batch_input_shape, ori_shape)
result = model.test_step([data])[0]

visualizer.add_datasample(
name='video',
image=frame_origin,
pred_sample=result,
show=False,
pred_score_thr=args.score_thr)

frame_mask = visualizer.get_image()

if args.show:
cv2.namedWindow('video', 0)
mmcv.imshow(frame_mask, 'video', args.wait_time)
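For orientation: the GPU-accelerated demo now skips the per-frame test pipeline entirely, since ffmpegcv already resizes frames while decoding. Each frame only has to be packed into the `inputs`/`data_sample` dict that `model.test_step` consumes, which is what `pack_data` does above. A minimal sketch of that contract, with illustrative shapes (`model` as returned by `init_detector`):

```python
# Minimal sketch of the data contract behind pack_data: model.test_step
# takes a dict with a CHW tensor under 'inputs' and a DetDataSample
# carrying shape metainfo. All shapes below are illustrative.
import numpy as np
import torch

from mmdet.structures import DetDataSample

batch_input_shape = (608, 800)  # (h, w) after the model's padding
ori_shape = (480, 640)          # original (h, w) of the video

frame_resize = np.zeros((*batch_input_shape, 3), dtype=np.uint8)
data_sample = DetDataSample()
data_sample.set_metainfo({
    'img_shape': batch_input_shape,
    'ori_shape': ori_shape,
    'scale_factor': (batch_input_shape[0] / ori_shape[0],
                     batch_input_shape[1] / ori_shape[1])
})
data = {
    'inputs': torch.from_numpy(frame_resize).permute((2, 0, 1)),  # HWC -> CHW
    'data_sample': data_sample
}
# result = model.test_step([data])[0]  # with `model` from init_detector
```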
29 changes: 25 additions & 4 deletions demo/webcam_demo.py
@@ -2,9 +2,12 @@
import argparse

import cv2
import mmcv
import torch

from mmdet.apis import inference_detector, init_detector
from mmdet.registry import VISUALIZERS
from mmdet.utils import register_all_modules


def parse_args():
@@ -24,24 +27,42 @@
def main():
args = parse_args()

device = torch.device(args.device)
# register all modules in mmdet into the registries
register_all_modules()

# build the model from a config file and a checkpoint file
device = torch.device(args.device)
model = init_detector(args.config, args.checkpoint, device=device)

# init visualizer
visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # the dataset_meta is loaded from the checkpoint and
    # then passed to the model in init_detector
visualizer.dataset_meta = model.dataset_meta

camera = cv2.VideoCapture(args.camera_id)

print('Press "Esc", "q" or "Q" to exit.')
while True:
ret_val, img = camera.read()
result = inference_detector(model, img)

img = mmcv.imconvert(img, 'bgr', 'rgb')
visualizer.add_datasample(
name='result',
image=img,
pred_sample=result,
pred_score_thr=args.score_thr,
show=False)

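        # get_image() returns the drawn RGB frame; swap channels for imshow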
img = visualizer.get_image()
img = mmcv.imconvert(img, 'bgr', 'rgb')
cv2.imshow('result', img)

ch = cv2.waitKey(1)
if ch == 27 or ch == ord('q') or ch == ord('Q'):
break

model.show_result(
img, result, score_thr=args.score_thr, wait_time=1, show=True)


if __name__ == '__main__':
main()
20 changes: 13 additions & 7 deletions mmdet/apis/inference.py
@@ -84,14 +84,18 @@ def init_detector(
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]


def inference_detector(model: nn.Module,
imgs: ImagesType) -> Union[DetDataSample, SampleList]:
def inference_detector(
model: nn.Module,
imgs: ImagesType,
test_pipeline: Optional[Compose] = None
) -> Union[DetDataSample, SampleList]:
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str, ndarray, Sequence[str/ndarray]):
Either image files or loaded images.
test_pipeline (:obj:`Compose`): Test pipeline.
Returns:
:obj:`DetDataSample` or list[:obj:`DetDataSample`]:
@@ -107,12 +111,14 @@ def inference_detector(model: nn.Module,

cfg = model.cfg

if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy()
# set loading pipeline type
cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
if test_pipeline is None:
if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy()
# set loading pipeline type
cfg.test_dataloader.dataset.pipeline[
0].type = 'LoadImageFromNDArray'

test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)
test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)

data = []
for img in imgs:
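The effect of the new argument: callers such as video_demo.py can now build the test pipeline once and reuse it on every call, instead of letting inference_detector rebuild it each time. A hedged usage sketch (paths are placeholders):

```python
# Usage sketch for the new test_pipeline argument: build the pipeline once
# and reuse it across calls, e.g. in a video loop. Paths are placeholders.
import numpy as np
from mmcv.transforms import Compose

from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules

register_all_modules()
model = init_detector('path/to/config.py', 'path/to/checkpoint.pth',
                      device='cuda:0')

cfg = model.cfg.copy()
# frames arrive as ndarrays, so swap out the file-loading transform
cfg.test_dataloader.dataset.pipeline[0].type = 'LoadImageFromNDArray'
test_pipeline = Compose(cfg.test_dataloader.dataset.pipeline)

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a video frame
result = inference_detector(model, frame, test_pipeline=test_pipeline)
```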
