Skip to content

Support rectangle in draw_boxes and relocate draw_ser_results #755

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 107 additions & 14 deletions mindocr/utils/visualize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import math
import os
from typing import List, Union

import cv2
Expand All @@ -9,7 +10,16 @@

from ..data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

__all__ = ["show_img", "show_imgs", "draw_boxes", "draw_texts_with_boxes", "recover_image", "visualize"]
__all__ = [
"show_img",
"show_imgs",
"draw_boxes",
"draw_texts_with_boxes",
"recover_image",
"visualize",
"draw_ser_results",
"trans_poly_to_bbox",
]
_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -73,12 +83,22 @@ def draw_boxes(
thickness=1,
is_bgr_img=False,
color_random=False,
): # , to_rgb=False):
"""image can be str or np.array for image in 'BGR' colorm mode.
color: list for color of each box, or tuple for color of all boxes with the same color. in RGB order

draw_type="polygon",
):
"""
Draw boxes (polygons or rectangles) on the image.
Args:
image: The image to draw boxes on. It can be a path to the image or a numpy array.
bboxes: The list of boxes to draw.
For polygon, each box is a list of points [[x1, y1], [x2, y2], ...].
For rectangle, each box is a list of 4 integers [x1, y1, x2, y2].
color: The color of the boxes. Default is (255, 0, 0) in RGB order.
thickness: The thickness of the lines.
is_bgr_img: Whether the image is in BGR format.
color_random: Whether to use random color for each box.
draw_type: The type of the boxes to draw. It can be "polygon" or "rectangle".
return:
np.array, image draw with boxes in RGB color order
np.array: The image with boxes drawn in RGB color order.
"""
# load image and convert to BGR format
if isinstance(image, str):
Expand All @@ -88,14 +108,23 @@ def draw_boxes(
if not is_bgr_img:
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

for i, box in enumerate(bboxes):
for _, box in enumerate(bboxes):
box = box.astype(int)

if isinstance(color, tuple):
color_bgr = color[::-1]
if color_random:
color_bgr = np.random.randint(0, 255, 3, dtype=np.int32).tolist()
elif isinstance(color, tuple):
color_bgr = color[::-1] # Convert RGB to BGR
else:
color_bgr = (0, 0, 255) # Default color in BGR

if draw_type == "polygon":
cv2.polylines(image, [box], True, color_bgr, thickness)
elif draw_type == "rectangle":
x1, y1, x2, y2 = box
cv2.rectangle(image, (x1, y1), (x2, y2), color_bgr, thickness)
else:
color_bgr = np.randint(0, 255, 3, dtype=np.int32)
cv2.polylines(image, [box], True, color_bgr, thickness)
raise ValueError(f"Unsupported draw type: {draw_type}")

# convert to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
Expand All @@ -117,9 +146,16 @@ def draw_texts_with_boxes(
text_inside_box: bool = True,
):
"""
font_path: path to ttf font file. If None, use a "better than nothing" defalut font in PIL (font_size can't be set
in this case)
font_size: font size. if None, font_size will be computed automatically according to the box size and image size.
Draw texts with boxes on the image.
Args:
image: The image to draw boxes on. It can be a path to the image or a numpy array.
bboxes: The list of boxes to draw. each box is a list of points [[x1, y1], [x2, y2], ...].
texts: The list of texts to draw.
box_color: The color of the boxes. Default is (255, 0, 0) in RGB order.
thickness: The thickness of the lines.
text_color: The color of the texts. Default is (0, 0, 0) in RGB order.
font_path: The path to the font file. If None, use a "better than nothing" default font in PIL.
font_size: The font size. If None, the font size will be computed automatically by box size and image size.
"""
if hide_boxes:
if is_bgr_img:
Expand Down Expand Up @@ -237,3 +273,60 @@ def recover_image(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, is_
img = img.astype(np.uint8)

return img


def draw_ser_results(
image: Union[str, np.array], ocr_results: List[dict], font_path="docs/fonts/simfang.ttf", font_size=14
):
np.random.seed(2021)
color = (np.random.permutation(range(255)), np.random.permutation(range(255)), np.random.permutation(range(255)))
color_map = {idx: (color[0][idx], color[1][idx], color[2][idx]) for idx in range(1, 255)}

if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert("RGB")
else:
raise ValueError("Invalid image input. Must be a file path or numpy array.")

img_new = image.copy()
draw = ImageDraw.Draw(img_new)
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = f"{ocr_info['pred']}: {ocr_info['transcription']}"
bbox = ocr_info.get("bbox", trans_poly_to_bbox(ocr_info["points"]))
draw_box_txt(bbox, text, draw, font, color)

img_new = Image.blend(image, img_new, 0.7)
return np.array(img_new)


def trans_poly_to_bbox(poly: list):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]


def draw_box_txt(
bbox: list,
text: str,
draw: ImageDraw.Draw,
font: ImageFont.FreeTypeFont,
color: Union[tuple, str] = (255, 0, 0),
):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)

# draw ocr results
left, top, right, bottom = font.getbbox(text)
tw, th = right - left, bottom - top
start_y = max(0, bbox[0][1] - th)
draw.rectangle([(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + th)], fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
62 changes: 6 additions & 56 deletions tools/infer/text/predict_ser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import cv2
import numpy as np
from config import parse_args
from PIL import Image, ImageDraw, ImageFont
from postprocess import Postprocessor
from predict_system import TextSystem
from preprocess import Preprocessor
Expand All @@ -23,6 +22,7 @@

from mindocr import build_model # noqa
from mindocr.utils.logger import set_logger # noqa
from mindocr.utils.visualize import draw_ser_results

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))
Expand All @@ -37,7 +37,7 @@
logger = logging.getLogger("mindocr")


class SemanticEntityRecognition(object):
class SemanticEntityRecognition:
"""
Infer model for semantic entity recognition
"""
Expand Down Expand Up @@ -157,9 +157,10 @@ def _parse_annotation(self, data_line: str):
return img_name, annot_str

def get_from_file(self, label_file_list):
"""Load data list from label_file which contains infomation of image paths and annotations
"""
Load data list from label_file_list which contains information image path and annotation.
Args:
label_file: annotation file path(s)
label_file_list: annotation file path(s)
Returns:
data_list (List[dict]): A list of annotation dict, which contains keys: img_path, annot...
"""
Expand Down Expand Up @@ -255,7 +256,7 @@ def run_single(self, ocr_info_list):
"""
ser_res = []
# preprocess
for i, img in enumerate(ocr_info_list):
for _, img in enumerate(ocr_info_list):
data = self.preprocess(img)
input_ids = data["input_ids"]
bbox = data["bbox"]
Expand Down Expand Up @@ -310,57 +311,6 @@ def __call__(self, img_path, ocr_path=None):
return results_ser, time_report


def draw_ser_results(image, ocr_results, font_path="docs/fonts/simfang.ttf", font_size=14):
np.random.seed(2021)
color = (np.random.permutation(range(255)), np.random.permutation(range(255)), np.random.permutation(range(255)))
color_map = {idx: (color[0][idx], color[1][idx], color[2][idx]) for idx in range(1, 255)}
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif isinstance(image, str) and os.path.isfile(image):
image = Image.open(image).convert("RGB")
img_new = image.copy()
draw = ImageDraw.Draw(img_new)

font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
for ocr_info in ocr_results:
if ocr_info["pred_id"] not in color_map:
continue
color = color_map[ocr_info["pred_id"]]
text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])

if "bbox" in ocr_info:
# draw with ocr engine
bbox = ocr_info["bbox"]
else:
# draw with ocr groundtruth
bbox = trans_poly_to_bbox(ocr_info["points"])
draw_box_txt(bbox, text, draw, font, font_size, color)

img_new = Image.blend(image, img_new, 0.7)
return np.array(img_new)


def trans_poly_to_bbox(poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
return [x1, y1, x2, y2]


def draw_box_txt(bbox, text, draw, font, font_size, color):
# draw ocr results outline
bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
draw.rectangle(bbox, fill=color)

# draw ocr results
left, top, right, bottom = font.getbbox(text)
tw, th = right - left, bottom - top
start_y = max(0, bbox[0][1] - th)
draw.rectangle([(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + th)], fill=(0, 0, 255))
draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)


if __name__ == "__main__":
args = parse_args()
set_logger(name="mindocr")
Expand Down
Loading