Skip to content

增加返回json结果的参数 #153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 176 additions & 95 deletions wired_table_rec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,16 @@
import logging
import time
import traceback
from dataclasses import dataclass, asdict
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union, Dict, Any
from typing import List, Optional, Tuple, Union, Dict, Any
import numpy as np
import cv2

from wired_table_rec.table_structure_cycle_center_net import TSRCycleCenterNet
from wired_table_rec.table_structure_unet import TSRUnet
from wired_table_rec.utils.download_model import DownloadModel
from wired_table_rec.table_line_rec import TableLineRecognition
from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus
from .table_recover import TableRecover
from .utils.utils import InputType, LoadImage
from wired_table_rec.utils.utils_table_recover import (
from .utils import InputType, LoadImage
from .utils_table_recover import (
match_ocr_cell,
plot_html_table,
box_4_2_poly_to_box_4_1,
Expand All @@ -27,73 +24,54 @@
gather_ocr_list_by_row,
)


class ModelType(Enum):
CYCLE_CENTER_NET = "cycle_center_net"
UNET = "unet"


ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
KEY_TO_MODEL_URL = {
ModelType.CYCLE_CENTER_NET.value: f"{ROOT_URL}/cycle_center_net.onnx",
ModelType.UNET.value: f"{ROOT_URL}/unet.onnx",
}


@dataclass
class WiredTableInput:
model_type: Optional[str] = ModelType.UNET.value
model_path: Union[str, Path, None, Dict[str, str]] = None
use_cuda: bool = False
device: str = "cpu"


@dataclass
class WiredTableOutput:
pred_html: Optional[str] = None
cell_bboxes: Optional[np.ndarray] = None
logic_points: Optional[np.ndarray] = None
elapse: Optional[float] = None
cur_dir = Path(__file__).resolve().parent
default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx"
default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx"


class WiredTableRecognition:
def __init__(self, config: WiredTableInput):
self.model_type = config.model_type
if self.model_type not in KEY_TO_MODEL_URL:
model_list = ",".join(KEY_TO_MODEL_URL)
raise ValueError(
f"{self.model_type} is not supported. The currently supported models are {model_list}."
)

config.model_path = self.get_model_path(config.model_type, config.model_path)
if self.model_type == ModelType.CYCLE_CENTER_NET.value:
self.table_structure = TSRCycleCenterNet(asdict(config))
else:
self.table_structure = TSRUnet(asdict(config))

def __init__(self, table_model_path: Union[str, Path] = None, version="v2"):
self.load_img = LoadImage()
if version == "v2":
model_path = table_model_path if table_model_path else default_model_path_v2
self.table_line_rec = TableLineRecognitionPlus(str(model_path))
else:
model_path = table_model_path if table_model_path else default_model_path
self.table_line_rec = TableLineRecognition(str(model_path))

self.table_recover = TableRecover()

try:
self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError:
self.ocr = None

def __call__(
self,
img: InputType,
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
**kwargs,
) -> WiredTableOutput:
) -> Tuple[str, float, Any, Any, Any]:
if self.ocr is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
)

s = time.perf_counter()
rec_again = True
need_ocr = True
col_threshold = 15
row_threshold = 10
if kwargs:
rec_again = kwargs.get("rec_again", True)
need_ocr = kwargs.get("need_ocr", True)
col_threshold = kwargs.get("col_threshold", 15)
row_threshold = kwargs.get("row_threshold", 10)
img = self.load_img(img)
polygons, rotated_polygons = self.table_structure(img, **kwargs)
polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
if polygons is None:
logging.warning("polygons is None.")
return WiredTableOutput("", None, None, 0.0)
return "", 0.0, None, None, None

try:
table_res, logi_points = self.table_recover(
Expand All @@ -108,34 +86,52 @@ def __call__(
sorted_polygons, idx_list = sorted_ocr_boxes(
[box_4_2_poly_to_box_4_1(box) for box in polygons]
)
return WiredTableOutput(
return (
"",
time.perf_counter() - s,
sorted_polygons,
logi_points[idx_list],
time.perf_counter() - s,
[],
)
if ocr_result is None and need_ocr:
ocr_result, _ = self.ocr(img)
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
# 如果有识别框没有ocr结果,直接进行rec补充
cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
t_rec_ocr_list_dict = self.transform_res(cell_box_det_map, polygons, logi_points)
# 第一行或者第一列为空时,调整代码
#adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict)
adjust_dict = self.process_ocr_result(t_rec_ocr_list_dict)
# 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list_dict)
# cell_box_map =
logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list]
cell_box_det_map = {
i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]]
for i, t_box_ocr in enumerate(t_rec_ocr_list)
}
pred_html = plot_html_table(logi_points, cell_box_det_map)
polygons = np.array(polygons).reshape(-1, 8)
logi_points = np.array(logi_points)
elapse = time.perf_counter() - s
table_str = plot_html_table(logi_points, cell_box_det_map)
ocr_boxes_res = [
box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result
]
sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res)
sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons]
sorted_logi_points = logi_points
table_elapse = time.perf_counter() - s

except Exception:
logging.warning(traceback.format_exc())
return WiredTableOutput("", None, None, 0.0)
return WiredTableOutput(pred_html, polygons, logi_points, elapse)
return "", 0.0, None, None, None
return (
table_str,
table_elapse,
sorted_polygons,
sorted_logi_points,
sorted_ocr_boxes_res,
adjust_dict

)

def transform_res(
self,
Expand Down Expand Up @@ -166,6 +162,102 @@ def transform_res(
res.append(dict_res)
return res

def process_ocr_result(self, ocr_result):
# 删除第一行的字典,并调整其余字典的行数
first_row_empty = [entry for entry in ocr_result if
entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0 and entry['t_ocr_res'][0][
1] == '']

if len(first_row_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0]):
# 如果第一行的所有单元格都为空,删除第一行
ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][0] != 0 or entry['t_logic_box'][1] != 0]
# 调整剩余字典的行数
for entry in ocr_result:
entry['t_logic_box'][0] -= 1
entry['t_logic_box'][1] -= 1

# 删除第一列的字典,并调整其余字典的列数
first_col_empty = [entry for entry in ocr_result if
entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0 and entry['t_ocr_res'][0][
1] == '']

if len(first_col_empty) == len(
[entry for entry in ocr_result if entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0]):
# 如果第一列的所有单元格都为空,删除第一列
ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][2] != 0 or entry['t_logic_box'][3] != 0]
# 调整剩余字典的列数
for entry in ocr_result:
entry['t_logic_box'][2] -= 1
entry['t_logic_box'][3] -= 1

return ocr_result

def adjust_table_cells(self, t_rec_ocr_list_dict):
"""
调整表格单元格,去掉第一行和/或第一列的单元格,
并更新剩余单元格的行列起始和结束位置。

参数:
t_rec_ocr_list_dict (list): 原始表格单元格识别结果,格式为
[
{
"t_box": [xmin, ymin, xmax, ymax],
"t_logic_box": [row_start, row_end, col_start, col_end],
"t_ocr_res": [[box, text], ...]
},
...
]

返回:
list: 调整后的表格单元格识别结果,格式与输入相同。
"""
# 新的结果列表
adjusted_result = []

# 记录是否第一行和第一列的单元格已被删除
remove_first_row = False
remove_first_col = False

# 检查并移除第一行
if all(cell and not cell[1] for cell in t_rec_ocr_list_dict[0].get("t_ocr_res", [])):
remove_first_row = True

# 检查并移除第一列
if all(row.get("t_ocr_res") and not row["t_ocr_res"][0][1] for row in t_rec_ocr_list_dict):
remove_first_col = True

# 遍历原始结果进行调整
for i, row in enumerate(t_rec_ocr_list_dict):
adjusted_row = []

# 如果是第一行并且需要删除,跳过这行
if remove_first_row and i == 0:
continue

for j, cell in enumerate(row.get("t_ocr_res", [])):
# 如果是第一列并且需要删除,跳过这一列
if remove_first_col and j == 0:
continue

# 更新当前单元格的逻辑位置
adjusted_cell = {
"t_box": row.get("t_box"),
"t_logic_box": [
row["t_logic_box"][0] - 1 if i > 0 else row["t_logic_box"][0],
row["t_logic_box"][1] - 1 if i > 0 else row["t_logic_box"][1],
row["t_logic_box"][2] - 1 if j > 0 else row["t_logic_box"][2],
row["t_logic_box"][3] - 1 if j > 0 else row["t_logic_box"][3]
],
"t_ocr_res": cell
}
adjusted_row.append(adjusted_cell)

if adjusted_row:
adjusted_result.append(adjusted_row)

return adjusted_result

def sort_and_gather_ocr_res(self, res):
for i, dict_res in enumerate(res):
_, sorted_idx = sorted_ocr_boxes(
Expand All @@ -177,19 +269,30 @@ def sort_and_gather_ocr_res(self, res):
)
return res

def fill_blank_rec(
def re_rec(
self,
img: np.ndarray,
sorted_polygons: np.ndarray,
cell_box_map: Dict[int, List[str]],
rec_again=True,
) -> Dict[int, List[Any]]:
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
for i in range(sorted_polygons.shape[0]):
if cell_box_map.get(i):
continue
if not rec_again:
box = sorted_polygons[i]
cell_box_map[i] = [[box, "", 1]]
continue
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
pad_img = cv2.copyMakeBorder(
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
)
rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True)
box = sorted_polygons[i]
cell_box_map[i] = [[box, "", 1]]
continue
text = [rec[0] for rec in rec_res]
scores = [rec[1] for rec in rec_res]
cell_box_map[i] = [[box, "".join(text), min(scores)]]
return cell_box_map

def re_rec_high_precise(
Expand Down Expand Up @@ -222,46 +325,24 @@ def re_rec_high_precise(
]
return cell_box_map

@staticmethod
def get_model_path(
model_type: str, model_path: Union[str, Path, None]
) -> Union[str, Dict[str, str]]:
if model_path is not None:
return model_path

model_url = KEY_TO_MODEL_URL.get(model_type, None)
if isinstance(model_url, str):
model_path = DownloadModel.download(model_url)
return model_path

if isinstance(model_url, dict):
model_paths = {}
for k, url in model_url.items():
model_paths[k] = DownloadModel.download(
url, save_model_name=f"{model_type}_{Path(url).name}"
)
return model_paths

raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-img", "--img_path", type=str, required=True)
args = parser.parse_args()

try:
ocr_engine = importlib.import_module("rapidocr").RapidOCR()
ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr by pip install rapidocr."
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
) from exc
input_args = WiredTableInput()
table_rec = WiredTableRecognition(input_args)

table_rec = WiredTableRecognition()
ocr_result, _ = ocr_engine(args.img_path)
table_results = table_rec(args.img_path, ocr_result)
print(table_results.pred_html)
print(f"cost: {table_results.elapse:.5f}")
table_str, elapse = table_rec(args.img_path, ocr_result)
print(table_str)
print(f"cost: {elapse:.5f}")


if __name__ == "__main__":
Expand Down