Skip to content

Commit

Permalink
fix infer postprocess init (#442)
Browse files Browse the repository at this point in the history
  • Loading branch information
liangxhao authored Jun 26, 2023
1 parent b6605aa commit 2b95dc0
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(
box_type: str = "quad",
rescale_fields=["polys"],
):
super().__init__(box_type, rescale_fields)
super().__init__(rescale_fields=rescale_fields, box_type=box_type)
self.fourier_degree = fourier_degree
self.num_reconstr_points = num_reconstr_points
self.scales = scales
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
rescale_fields=["polys"],
**kwargs
):
super().__init__(box_type, rescale_fields)
super().__init__(rescale_fields=rescale_fields, box_type=box_type)

self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
Expand Down
9 changes: 5 additions & 4 deletions deploy/py_infer/src/utils/visual_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def vis_bbox_text(image, box_list, text_list, font_path):
:param font_path: path to font file
:return: image with box and text
"""
_font_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../docs/fonts/simfang.ttf"))
if not os.path.isfile(_font_path):
raise ValueError(f"font_path must be a file, but got {font_path}.")
font_path = _font_path
if font_path is None:
_font_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../docs/fonts/simfang.ttf"))
if os.path.isfile(_font_path):
font_path = _font_path

image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
h, w = image.height, image.width
img_left = image.copy()
Expand Down
20 changes: 20 additions & 0 deletions tests/ut/test_infer_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import sys

import pytest

py_infer_path = "deploy/py_infer"
sys.path.insert(0, py_infer_path)

from src.data_process import build_postprocess

configs_list = [
"configs/det/dbnet/db_r50_icdar15.yaml",
"configs/rec/crnn/crnn_icdar15.yaml",
"deploy/py_infer/src/configs/det/ppocr/det_r50_vd_sast_icdar15.yaml",
"deploy/py_infer/src/configs/rec/mmocr/nrtr_resnet31-1by8-1by4_6e_st_mj.yaml",
]


@pytest.mark.parametrize("config_file", configs_list)
def test_build_postprocess(config_file):
build_postprocess(config_file)

0 comments on commit 2b95dc0

Please sign in to comment.