Skip to content

Commit

Permalink
Switch imshow_det_bboxes visualization backend from opencv to matplot…
Browse files Browse the repository at this point in the history
…lib (#4389)

* FPN deprecated warning

* FPN deprecated warning

* Replace imshow_det_bboxes visualization backend

* Add bbox_vis unit tests

* Encapsulate color_val_matplotlib function

* Add fun input parameters

* Deprecate block

* Add mask display in image

* Putting the text inner left corner the bbox

* Add a filling color for text regions.

* Update color docs

* Fix color docs

* Update color docs and Fix default param
  • Loading branch information
hhaAndroid authored Jan 13, 2021
1 parent 9b2c208 commit 308f0d7
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 38 deletions.
26 changes: 17 additions & 9 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -164,7 +163,8 @@ def show_result_pyplot(model,
score_thr=0.3,
fig_size=(15, 10),
title='result',
block=True):
block=True,
wait_time=0):
"""Visualize the detection results on the image.
Args:
Expand All @@ -175,13 +175,21 @@ def show_result_pyplot(model,
score_thr (float): The threshold to visualize the bboxes and masks.
fig_size (tuple): Figure size of the pyplot figure.
title (str): Title of the pyplot figure.
block (bool): Whether to block GUI.
block (bool): Whether to block GUI. Default: True
wait_time (float): Value of waitKey param.
Default: 0.
"""
warnings.warn('"block" will be deprecated in v2.9.0,'
'Please use "wait_time"')
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.title(title)
plt.tight_layout()
plt.show(block=block)
model.show_result(
img,
result,
score_thr=score_thr,
show=True,
wait_time=wait_time,
fig_size=fig_size,
win_name=title,
bbox_color=(72, 101, 241),
text_color=(72, 101, 241))
3 changes: 3 additions & 0 deletions mmdet/core/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .image import color_val_matplotlib, imshow_det_bboxes

__all__ = ['imshow_det_bboxes', 'color_val_matplotlib']
169 changes: 169 additions & 0 deletions mmdet/core/visualization/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import os.path as osp
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon


def color_val_matplotlib(color):
"""Convert various input in BGR order to normalized RGB matplotlib color
tuples,
Args:
color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
Returns:
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
"""
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)


def imshow_det_bboxes(img,
bboxes,
labels,
segms=None,
class_names=None,
score_thr=0,
bbox_color='green',
text_color='green',
mask_color=None,
thickness=2,
font_scale=0.5,
font_size=13,
win_name='',
fig_size=(15, 10),
show=True,
wait_time=0,
out_file=None):
"""Draw bboxes and class labels (with scores) on an image.
Args:
img (str or ndarray): The image to be displayed.
bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
(n, 5).
labels (ndarray): Labels of bboxes.
segms (ndarray or None): Masks, shaped (n,h,w) or None
class_names (list[str]): Names of each classes.
score_thr (float): Minimum score of bboxes to be shown. Default: 0
bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
The tuple of color should be in BGR order. Default: 'green'
text_color (str or tuple(int) or :obj:`Color`):Color of texts.
The tuple of color should be in BGR order. Default: 'green'
mask_color (None or str or tuple(int) or :obj:`Color`):
Color of masks. The tuple of color should be in BGR order.
Default: None
thickness (int): Thickness of lines. Default: 2
font_scale (float): Font scales of texts. Default: 0.5
font_size (int): Font size of texts. Default: 13
show (bool): Whether to show the image. Default: True
win_name (str): The window name. Default: ''
fig_size (tuple): Figure size of the pyplot figure. Default: (15, 10)
wait_time (float): Value of waitKey param. Default: 0.
out_file (str or None): The filename to write the image. Default: None
Returns:
ndarray: The image with bboxes drawn on it.
"""
warnings.warn('"font_scale" will be deprecated in v2.9.0,'
'Please use "font_size"')
assert bboxes.ndim == 2, \
f' bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
assert labels.ndim == 1, \
f' labels ndim should be 1, but its ndim is {labels.ndim}.'
assert bboxes.shape[0] == labels.shape[0], \
'bboxes.shape[0] and labels.shape[0] should have the same length.'
assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5,\
f' bboxes.shape[1] should be 4 or 5, but its {bboxes.shape[1]}.'
img = mmcv.imread(img).copy()

if score_thr > 0:
assert bboxes.shape[1] == 5
scores = bboxes[:, -1]
inds = scores > score_thr
bboxes = bboxes[inds, :]
labels = labels[inds]
if segms is not None:
segms = segms[inds, ...]

mask_colors = []
if labels.shape[0] > 0:
if mask_color is None:
# random color
np.random.seed(42)
mask_colors = [
np.random.randint(0, 256, (1, 3), dtype=np.uint8)
for _ in range(max(labels) + 1)
]
else:
# specify color
mask_colors = [
np.array(mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
] * (
max(labels) + 1)

bbox_color = color_val_matplotlib(bbox_color)
text_color = color_val_matplotlib(text_color)

img = mmcv.bgr2rgb(img)
img = np.ascontiguousarray(img)

plt.figure(figsize=fig_size)
plt.title(win_name)
plt.axis('off')
ax = plt.gca()

polygons = []
color = []
for i, (bbox, label) in enumerate(zip(bboxes, labels)):
bbox_int = bbox.astype(np.int32)
poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
[bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
np_poly = np.array(poly).reshape((4, 2))
polygons.append(Polygon(np_poly))
color.append(bbox_color)
label_text = class_names[
label] if class_names is not None else f'class {label}'
if len(bbox) > 4:
label_text += f'|{bbox[-1]:.02f}'
ax.text(
bbox_int[0],
bbox_int[1],
f'{label_text}',
bbox={
'facecolor': 'black',
'alpha': 0.8,
'pad': 0.7,
'edgecolor': 'none'
},
color=text_color,
fontsize=font_size,
verticalalignment='top',
horizontalalignment='left')
if segms is not None:
color_mask = mask_colors[labels[i]]
mask = segms[i].astype(bool)
img[mask] = img[mask] * 0.5 + color_mask * 0.5

plt.imshow(img)

p = PatchCollection(
polygons, facecolor='none', edgecolors=color, linewidths=thickness)
ax.add_collection(p)

if out_file is not None:
dir_name = osp.abspath(osp.dirname(out_file))
mmcv.mkdir_or_exist(dir_name)
plt.savefig(out_file)
if show:
if wait_time == 0:
plt.show()
else:
plt.show(block=False)
plt.pause(wait_time)
plt.close()
return mmcv.rgb2bgr(img)
55 changes: 31 additions & 24 deletions mmdet/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmcv.runner import auto_fp16
from mmcv.utils import print_log

from mmdet.core.visualization import imshow_det_bboxes
from mmdet.utils import get_root_logger


Expand Down Expand Up @@ -270,11 +271,14 @@ def show_result(self,
img,
result,
score_thr=0.3,
bbox_color='green',
text_color='green',
thickness=1,
bbox_color=(72, 101, 241),
text_color=(72, 101, 241),
mask_color=None,
thickness=2,
font_scale=0.5,
font_size=13,
win_name='',
fig_size=(15, 10),
show=False,
wait_time=0,
out_file=None):
Expand All @@ -286,12 +290,20 @@ def show_result(self,
bbox_result or (bbox_result, segm_result).
score_thr (float, optional): Minimum score of bboxes to be shown.
Default: 0.3.
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
text_color (str or tuple or :obj:`Color`): Color of texts.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
The tuple of color should be in BGR order. Default: 'green'
text_color (str or tuple(int) or :obj:`Color`):Color of texts.
The tuple of color should be in BGR order. Default: 'green'
mask_color (None or str or tuple(int) or :obj:`Color`):
Color of masks. The tuple of color should be in BGR order.
Default: None
thickness (int): Thickness of lines. Default: 2
font_scale (float): Font scales of texts. Default: 0.5
font_size (int): Font size of texts. Default: 13
win_name (str): The window name. Default: ''
fig_size (tuple): Figure size of the pyplot figure.
Default: (15, 10)
wait_time (float): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
Expand All @@ -316,37 +328,32 @@ def show_result(self,
]
labels = np.concatenate(labels)
# draw segmentation masks
segms = None
if segm_result is not None and len(labels) > 0: # non empty
segms = mmcv.concat_list(segm_result)
inds = np.where(bboxes[:, -1] > score_thr)[0]
np.random.seed(42)
color_masks = [
np.random.randint(0, 256, (1, 3), dtype=np.uint8)
for _ in range(max(labels) + 1)
]
for i in inds:
i = int(i)
color_mask = color_masks[labels[i]]
sg = segms[i]
if isinstance(sg, torch.Tensor):
sg = sg.detach().cpu().numpy()
mask = sg.astype(bool)
img[mask] = img[mask] * 0.5 + color_mask * 0.5
if isinstance(segms[0], torch.Tensor):
segms = torch.stack(segms, dim=0).detach().cpu().numpy()
else:
segms = np.stack(segms, axis=0)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
# draw bounding boxes
mmcv.imshow_det_bboxes(
imshow_det_bboxes(
img,
bboxes,
labels,
segms,
class_names=self.CLASSES,
score_thr=score_thr,
bbox_color=bbox_color,
text_color=text_color,
mask_color=mask_color,
thickness=thickness,
font_scale=font_scale,
font_size=font_size,
win_name=win_name,
fig_size=fig_size,
show=show,
wait_time=wait_time,
out_file=out_file)
Expand Down
61 changes: 61 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile

import mmcv
import numpy as np
import pytest
import torch

from mmdet.core import visualization as vis


def test_color():
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
assert vis.color_val_matplotlib(np.zeros(3, dtype=np.int)) == (0., 0., 0.)
# forbid white color
with pytest.raises(TypeError):
vis.color_val_matplotlib([255, 255, 255])
# forbid float
with pytest.raises(TypeError):
vis.color_val_matplotlib(1.0)
# overflowed
with pytest.raises(AssertionError):
vis.color_val_matplotlib((0, 0, 500))


def test_imshow_det_bboxes():
tmp_filename = osp.join(tempfile.gettempdir(), 'det_bboxes_image',
'image.jpg')
image = np.ones((10, 10, 3), np.uint8)
bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
label = np.array([0, 1])
vis.imshow_det_bboxes(
image, bbox, label, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)

# test shaped (0,)
image = np.ones((10, 10, 3), np.uint8)
bbox = np.ones((0, 4))
label = np.ones((0, ))
vis.imshow_det_bboxes(
image, bbox, label, out_file=tmp_filename, show=False)

# test mask
image = np.ones((10, 10, 3), np.uint8)
bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
label = np.array([0, 1])
segms = np.random.random((2, 10, 10)) > 0.5
segms = np.array(segms, np.int32)
vis.imshow_det_bboxes(
image, bbox, label, segms, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)

# test tensor mask type error
with pytest.raises(AttributeError):
segms = torch.tensor(segms)
vis.imshow_det_bboxes(
image, bbox, label, segms, out_file=tmp_filename, show=False)
Loading

0 comments on commit 308f0d7

Please sign in to comment.