From 74a5e178b4e61fb5700003507b38e5084b92a522 Mon Sep 17 00:00:00 2001 From: Xiaomeng Zhao Date: Wed, 14 Aug 2024 18:43:23 +0800 Subject: [PATCH] =?UTF-8?q?fix=20&=20refactor=20&=20docs=EF=BC=9Aupdate=20?= =?UTF-8?q?ocr=20logic=20and=20installation=20guides=20(#88)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(extract_pdf): When converting a PDF to a list of images, do not perform a BGR channel conversion upfront. * feat(self_modify): refine text and formula detection box updating logic Update the logic for merging and refining detection boxes in self_modify module. Replace hardcoded checks with dynamic calculations for determining overlapping regions, resulting in more accurate detection box merging when formulae are identified within texts. * fix(pdf_extract): optimize batch size and worker count for DataLoader Reduce the batch size from 128 to 64 and set the number of workers to 0 in the DataLoaderto improve stability and performance on systems with limited resources. refactor(pdf_extract): refactor ocr and table recognition logicRefactor the ocr and table recognition logic to enhance readability and maintainability.This includes the adjustment of formula recognition coordinates relative to the cropped image and streamlining the process for handling OCR results and table recognition. * refactor(pdf_extract): optimize image processing and table recognition - Rename loop variable 'idx' to 'pdf_idx' for clarity.- Adjust image pasting and coordinate handling during OCR processing.- Add comments for improved code understanding.- Ensure proper rendering of images during PDF visualization. - Refactor logging and utility imports in self_modify module. The changes include improvements to image processing routines, better variable naming, and streamlined table recognition logic. Also, the visualization process has been tweaked to handle images more accurately. Additionally, redundant logging and utility importshave been cleaned up in the self_modify module to declutter the codebase. * refactor(pdf_extract): remove hardcoded paste values in crop_img function The crop_img function now accepts `crop_paste_x` and `crop_paste_y` as parameters instead of using hardcoded values. This change makes the function more flexible andeasier to adjust for different use cases. * fix(extract_pdf): prevent overscaling of large images Adjust the condition to prevent images from being enlarged beyond a width or height of 9000 pixels, ensuring large images do not become overly large when processed. This change avoids unnecessary resource consumption and potential performance issues when handling scaled images. * docs: update installation guides and requirements - Update the installation guides for macOS and Windows with new commands and simplified dependency installation. - Add new installation guide for Linux. - Modify requirements for CPU and GPU environments, including updates to `unimernet`, `matplotlib`, and `paddlepaddle`. - Provide precompiled wheels for `detectron2` in the installation process. * docs(windows_en): update config guidance for windows * Update func description in self_modify.py * change parameter name in pdf_extract.py, update padding size in ocr * update some instructions in Install_in_Windows_en.md * update some instructions in Install_in_Windows_zh_cn.md * Update README.md * Update README-zh_CN.md --------- Co-authored-by: Fan Wu <34300920+wufan-tb@users.noreply.github.com> --- README-zh_CN.md | 12 +- README.md | 14 +- ...etectron2-0.6-cp310-cp310-linux_x86_64.whl | Bin 0 -> 902088 bytes docs/Install_in_Windows_en.md | 26 ++-- docs/Install_in_Windows_zh_cn.md | 26 ++-- docs/Install_in_macOS_en.md | 3 +- docs/Install_in_macOS_zh_cn.md | 3 +- modules/extract_pdf.py | 15 +- modules/self_modify.py | 135 +++++++++++------ pdf_extract.py | 136 ++++++++++++------ requirements+cpu.txt | 6 +- requirements-without-unimernet+cpu.txt | 6 - requirements.txt | 2 +- 13 files changed, 239 insertions(+), 145 deletions(-) create mode 100644 assets/whl/detectron2-0.6-cp310-cp310-linux_x86_64.whl delete mode 100644 requirements-without-unimernet+cpu.txt diff --git a/README-zh_CN.md b/README-zh_CN.md index 0e948c0..b719d9e 100644 --- a/README-zh_CN.md +++ b/README-zh_CN.md @@ -229,16 +229,10 @@ conda create -n pipeline python=3.10 pip install -r requirements.txt -pip install --extra-index-url https://miropsota.github.io/torch_packages_builder detectron2==0.6+pt2.3.1cu121 +pip install https://github.com/opendatalab/PDF-Extract-Kit/raw/main/assets/whl/detectron2-0.6-cp310-cp310-linux_x86_64.whl ``` -安装完环境后,可能会遇到一些版本冲突导致版本变更,如果遇到了版本相关的报错,可以尝试下面的命令重新安装指定版本的库。 - -```bash -pip install pillow==8.4.0 -``` - -除了版本冲突外,可能还会遇到torch无法调用的错误,可以先把下面的库卸载,然后重新安装cuda12和cudnn。 +安装完环境后,可能还会遇到torch无法调用的错误,可以先把下面的库卸载,然后重新安装cuda12和cudnn。 ```bash pip uninstall nvidia-cusparse-cu12 @@ -260,7 +254,7 @@ pip uninstall nvidia-cusparse-cu12 ## 运行提取脚本 ```bash -python pdf_extract.py --pdf data/pdfs/ocr_1.pdf +python pdf_extract.py --pdf assets/examples/example.pdf ``` 相关参数解释: diff --git a/README.md b/README.md index 3a4e508..4e4a289 100644 --- a/README.md +++ b/README.md @@ -214,23 +214,17 @@ The formula recognition we used is based on the weights downloaded from [UniMERN The table recognition we used is based on the weights downloaded from [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy), a solution that converts images of Table into LaTeX. Compared to the table recognition capability of PP-StructureV2, StructEqTable demonstrates stronger recognition performance, delivering good results even with complex tables, which may currently be best suited for data within research papers. There is also significant room for improvement in terms of speed, and we are continuously iterating and optimizing. Within a week, we will update the table recognition capability to [MinerU](https://github.com/opendatalab/MinerU). -## Installation Guide +## Installation Guide(Linux) ```bash conda create -n pipeline python=3.10 pip install -r requirements.txt -pip install --extra-index-url https://miropsota.github.io/torch_packages_builder detectron2==0.6+pt2.3.1cu121 +pip install https://github.com/opendatalab/PDF-Extract-Kit/raw/main/assets/whl/detectron2-0.6-cp310-cp310-linux_x86_64.whl ``` -After installation, you may encounter some version conflicts leading to version changes. If you encounter version-related errors, you can try the following commands to reinstall specific versions of the libraries. - -```bash -pip install pillow==8.4.0 -``` - -In addition to version conflicts, you may also encounter errors where torch cannot be invoked. First, uninstall the following library and then reinstall cuda12 and cudnn. +After installation, you may also encounter errors where torch cannot be invoked. First, uninstall the following library and then reinstall cuda12 and cudnn. ```bash pip uninstall nvidia-cusparse-cu12 @@ -255,7 +249,7 @@ If you intend to experience this project on Google Colab, please 3000 pixels, don't enlarge the image - if pix.width > 3000 or pix.height > 3000: - pix = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) - image = Image.frombytes('RGB', (pix.width, pix.height), pix.samples) + # If the width or height exceeds 9000 after scaling, do not scale further. + if pm.width > 9000 or pm.height > 9000: + pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) - # images.append(image) - images.append(np.array(image)[:,:,::-1]) + img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples) + images.append(np.array(img)) return images diff --git a/modules/self_modify.py b/modules/self_modify.py index 1f830b8..29bb99e 100644 --- a/modules/self_modify.py +++ b/modules/self_modify.py @@ -76,49 +76,100 @@ def sorted_boxes(dt_boxes): return _boxes -def formula_in_text(mf_bbox, text_bbox): - x1, y1, x2, y2 = mf_bbox - x3, y3 = text_bbox[0] - x4, y4 = text_bbox[2] - left_box, right_box = None, None - same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2 - if not same_line: - return False, left_box, right_box - else: - drop_origin = False - left_x = x1 - 1 - right_x = x2 + 1 - if x3 < x1 and x2 < x4: - drop_origin = True - left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32') - right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32') - if x3 < x1 and x1 <= x4 <= x2: - drop_origin = True - left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32') - if x1 <= x3 <= x2 and x2 < x4: - drop_origin = True - right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32') - if x1 <= x3 < x4 <= x2: - drop_origin = True - return drop_origin, left_box, right_box - - -def update_det_boxes(dt_boxes, mfdetrec_res): - new_dt_boxes = dt_boxes - for mf_box in mfdetrec_res: - flag, left_box, right_box = False, None, None - for idx, text_box in enumerate(new_dt_boxes): - ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box) - if ret: - new_dt_boxes.pop(idx) - if left_box is not None: - new_dt_boxes.append(left_box) - if right_box is not None: - new_dt_boxes.append(right_box) - break - +def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8): + """Check if two bounding boxes overlap on the y-axis, and if the height of the overlapping region exceeds 80% of the height of the shorter bounding box.""" + _, y0_1, _, y1_1 = bbox1 + _, y0_2, _, y1_2 = bbox2 + + overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2)) + height1, height2 = y1_1 - y0_1, y1_2 - y0_2 + max_height = max(height1, height2) + min_height = min(height1, height2) + + return (overlap / min_height) > overlap_ratio_threshold + + +def bbox_to_points(bbox): + """ change bbox(shape: N * 4) to polygon(shape: N * 8) """ + x0, y0, x1, y1 = bbox + return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32') + + +def points_to_bbox(points): + """ change polygon(shape: N * 8) to bbox(shape: N * 4) """ + x0, y0 = points[0] + x1, _ = points[1] + _, y1 = points[2] + return [x0, y0, x1, y1] + + +def merge_intervals(intervals): + # Sort the intervals based on the start value + intervals.sort(key=lambda x: x[0]) + + merged = [] + for interval in intervals: + # If the list of merged intervals is empty or if the current + # interval does not overlap with the previous, simply append it. + if not merged or merged[-1][1] < interval[0]: + merged.append(interval) + else: + # Otherwise, there is overlap, so we merge the current and previous intervals. + merged[-1][1] = max(merged[-1][1], interval[1]) + + return merged + + +def remove_intervals(original, masks): + # Merge all mask intervals + merged_masks = merge_intervals(masks) + + result = [] + original_start, original_end = original + + for mask in merged_masks: + mask_start, mask_end = mask + + # If the mask starts after the original range, ignore it + if mask_start > original_end: + continue + + # If the mask ends before the original range starts, ignore it + if mask_end < original_start: + continue + + # Remove the masked part from the original range + if original_start < mask_start: + result.append([original_start, mask_start - 1]) + + original_start = max(mask_end + 1, original_start) + + # Add the remaining part of the original range, if any + if original_start <= original_end: + result.append([original_start, original_end]) + + return result + + +def update_det_boxes(dt_boxes, mfd_res): + new_dt_boxes = [] + for text_box in dt_boxes: + text_bbox = points_to_bbox(text_box) + masks_list = [] + for mf_box in mfd_res: + mf_bbox = mf_box['bbox'] + if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox): + masks_list.append([mf_bbox[0], mf_bbox[2]]) + text_x_range = [text_bbox[0], text_bbox[2]] + text_remove_mask_range = remove_intervals(text_x_range, masks_list) + temp_dt_box = [] + for text_remove_mask in text_remove_mask_range: + temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]])) + if len(temp_dt_box) > 0: + new_dt_boxes.extend(temp_dt_box) return new_dt_boxes + class ModifiedPaddleOCR(PaddleOCR): def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)): """ @@ -257,4 +308,4 @@ def __call__(self, img, cls=True, mfd_res=None): filter_rec_res.append(rec_result) end = time.time() time_dict['all'] = end - start - return filter_boxes, filter_rec_res, time_dict \ No newline at end of file + return filter_boxes, filter_rec_res, time_dict diff --git a/pdf_extract.py b/pdf_extract.py index 97610e1..5f68a30 100644 --- a/pdf_extract.py +++ b/pdf_extract.py @@ -111,7 +111,7 @@ def __getitem__(self, idx): else: all_pdfs = [args.pdf] print("total files:", len(all_pdfs)) - for idx, single_pdf in enumerate(all_pdfs): + for pdf_idx, single_pdf in enumerate(all_pdfs): try: img_list = load_pdf_fitz(single_pdf, dpi=dpi) except: @@ -119,7 +119,7 @@ def __getitem__(self, idx): print("unexpected pdf file:", single_pdf) if img_list is None: continue - print("pdf index:", idx, "pages:", len(img_list)) + print("pdf index:", pdf_idx, "pages:", len(img_list)) # layout detection and formula detection doc_layout_result = [] latex_filling_list = [] @@ -151,7 +151,7 @@ def __getitem__(self, idx): # Formula recognition, collect all formula images in whole pdf file, then batch infer them. a = time.time() dataset = MathDataset(mf_image_list, transform=mfr_transform) - dataloader = DataLoader(dataset, batch_size=128, num_workers=32) + dataloader = DataLoader(dataset, batch_size=64, num_workers=0) mfr_res = [] for imgs in dataloader: imgs = imgs.to(device) @@ -161,51 +161,103 @@ def __getitem__(self, idx): res['latex'] = latex_rm_whitespace(latex) b = time.time() print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2)) + + def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0): + crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1]) + crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5]) + # Create a white background with an additional width and height of 50 + crop_new_width = crop_xmax - crop_xmin + padding_x * 2 + crop_new_height = crop_ymax - crop_ymin + padding_y * 2 + return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white') + + # Crop image + crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax) + cropped_img = input_pil_img.crop(crop_box) + return_image.paste(cropped_img, (padding_x, padding_y)) + return_list = [padding_x, padding_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height] + return return_image, return_list # ocr and table recognition for idx, image in enumerate(img_list): - pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) - single_page_res = doc_layout_result[idx]['layout_dets'] + + layout_res = doc_layout_result[idx]['layout_dets'] + pil_img = Image.fromarray(image) + + ocr_res_list = [] + table_res_list = [] single_page_mfdetrec_res = [] - for res in single_page_res: + + for res in layout_res: if int(res['category_id']) in [13, 14]: - xmin, ymin = int(res['poly'][0]), int(res['poly'][1]) - xmax, ymax = int(res['poly'][4]), int(res['poly'][5]) single_page_mfdetrec_res.append({ - "bbox": [xmin, ymin, xmax, ymax], + "bbox": [int(res['poly'][0]), int(res['poly'][1]), + int(res['poly'][4]), int(res['poly'][5])], }) - for res in single_page_res: - if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: #categories that need to do ocr - xmin, ymin = int(res['poly'][0]), int(res['poly'][1]) - xmax, ymax = int(res['poly'][4]), int(res['poly'][5]) - crop_box = [xmin, ymin, xmax, ymax] - cropped_img = Image.new('RGB', pil_img.size, 'white') - cropped_img.paste(pil_img.crop(crop_box), crop_box) - cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR) - ocr_res = ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0] - if ocr_res: - for box_ocr_res in ocr_res: - p1, p2, p3, p4 = box_ocr_res[0] - text, score = box_ocr_res[1] - doc_layout_result[idx]['layout_dets'].append({ - 'category_id': 15, - 'poly': p1 + p2 + p3 + p4, - 'score': round(score, 2), - 'text': text, - }) - elif int(res['category_id']) == 5: # do table recognition - xmin, ymin = int(res['poly'][0]), int(res['poly'][1]) - xmax, ymax = int(res['poly'][4]), int(res['poly'][5]) - crop_box = [xmin, ymin, xmax, ymax] - cropped_img = pil_img.convert("RGB").crop(crop_box) - start = time.time() - with torch.no_grad(): - output = tr_model(cropped_img) - end = time.time() - if (end-start) > model_configs['model_args']['table_max_time']: - res["timeout"] = True - res["latex"] = output[0] + elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]: + ocr_res_list.append(res) + elif int(res['category_id']) in [5]: + table_res_list.append(res) + + ocr_start = time.time() + # Process each area that requires OCR processing + for res in ocr_res_list: + new_image, useful_list = crop_img(res, pil_img, padding_x=25, padding_y=25) + paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list + # Adjust the coordinates of the formula area + adjusted_mfdetrec_res = [] + for mf_res in single_page_mfdetrec_res: + mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"] + # Adjust the coordinates of the formula area to the coordinates relative to the cropping area + x0 = mf_xmin - xmin + paste_x + y0 = mf_ymin - ymin + paste_y + x1 = mf_xmax - xmin + paste_x + y1 = mf_ymax - ymin + paste_y + # Filter formula blocks outside the graph + if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]): + continue + else: + adjusted_mfdetrec_res.append({ + "bbox": [x0, y0, x1, y1], + }) + + # OCR recognition + new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) + ocr_res = ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] + + # Integration results + if ocr_res: + for box_ocr_res in ocr_res: + p1, p2, p3, p4 = box_ocr_res[0] + text, score = box_ocr_res[1] + + # Convert the coordinates back to the original coordinate system + p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin] + p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin] + p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin] + p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin] + + layout_res.append({ + 'category_id': 15, + 'poly': p1 + p2 + p3 + p4, + 'score': round(score, 2), + 'text': text, + }) + + ocr_cost = round(time.time() - ocr_start, 2) + print(f"ocr cost: {ocr_cost}") + + table_start = time.time() + for res in table_res_list: + new_image, _ = crop_img(res, pil_img) + single_table_start = time.time() + with torch.no_grad(): + output = tr_model(new_image) + if (time.time() - single_table_start) > model_configs['model_args']['table_max_time']: + res["timeout"] = True + res["latex"] = output[0] + table_cost = round(time.time() - table_start, 2) + print(f"table cost: {table_cost}") output_dir = args.output os.makedirs(output_dir, exist_ok=True) @@ -223,7 +275,7 @@ def __getitem__(self, idx): vis_pdf_result = [] for idx, image in enumerate(img_list): single_page_res = doc_layout_result[idx]['layout_dets'] - vis_img = Image.new('RGB', Image.fromarray(image).size, 'white') if args.render else Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + vis_img = Image.new('RGB', Image.fromarray(image).size, 'white') if args.render else Image.fromarray(image) draw = ImageDraw.Draw(vis_img) for res in single_page_res: label = int(res['category_id']) @@ -265,4 +317,4 @@ def __getitem__(self, idx): now = datetime.datetime.now(tz) end = time.time() print(now.strftime('%Y-%m-%d %H:%M:%S')) - print('Finished! time cost:', int(end-start), 's') \ No newline at end of file + print('Finished! time cost:', int(end-start), 's') diff --git a/requirements+cpu.txt b/requirements+cpu.txt index 56b122e..02cf66e 100644 --- a/requirements+cpu.txt +++ b/requirements+cpu.txt @@ -1,7 +1,7 @@ -unimernet -matplotlib +unimernet>=0.1.6 +matplotlib<=3.9.0 PyMuPDF ultralytics -paddlepaddle +paddlepaddle==2.6.1 paddleocr==2.7.3 struct-eqtable==0.1.0 \ No newline at end of file diff --git a/requirements-without-unimernet+cpu.txt b/requirements-without-unimernet+cpu.txt deleted file mode 100644 index e3fbf19..0000000 --- a/requirements-without-unimernet+cpu.txt +++ /dev/null @@ -1,6 +0,0 @@ -matplotlib -PyMuPDF -ultralytics -paddlepaddle -paddleocr==2.7.3 -struct-eqtable==0.1.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b6a98a5..f1ec32f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -unimernet +unimernet>=0.1.6 matplotlib PyMuPDF ultralytics