diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py index caaa4006302d..3b9f64e1220f 100644 --- a/python/mxnet/image/detection.py +++ b/python/mxnet/image/detection.py @@ -803,7 +803,7 @@ def check_label_shape(self, label_shape): raise ValueError(msg) def draw_next(self, color=None, thickness=2, mean=None, std=None, clip=True, - waitKey=None, window_name='draw_next'): + waitKey=None, window_name='draw_next', id2labels=None): """Display next image with bounding boxes drawn. Parameters @@ -822,6 +822,8 @@ def draw_next(self, color=None, thickness=2, mean=None, std=None, clip=True, Hold the window for waitKey milliseconds if set, skip ploting if None window_name : str Plot window name if waitKey is set. + id2labels : dict + Mapping of labels id to labels name. Returns ------- @@ -889,6 +891,17 @@ def draw_next(self, color=None, thickness=2, mean=None, std=None, clip=True, y2 = int(label[i, 4] * height) bc = np.random.rand(3) * 255 if not color else color cv2.rectangle(image, (x1, y1), (x2, y2), bc, thickness) + if id2labels is not None: + cls_id = int(label[i, 0]) + if cls_id in id2labels: + cls_name = id2labels[cls_id] + text = "{:s}".format(cls_name) + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + text_height = cv2.getTextSize(text, font, font_scale, 2)[0][1] + tc = (255, 255, 255) + tpos = (x1 + 5, y1 + text_height + 5) + cv2.putText(image, text, tpos, font, font_scale, tc, 2) if waitKey is not None: cv2.imshow(window_name, image) cv2.waitKey(waitKey)