diff --git a/README-zh_CN.md b/README-zh_CN.md
index 0e948c0..b719d9e 100644
--- a/README-zh_CN.md
+++ b/README-zh_CN.md
@@ -229,16 +229,10 @@
 conda create -n pipeline python=3.10
 pip install -r requirements.txt
-pip install --extra-index-url https://miropsota.github.io/torch_packages_builder detectron2==0.6+pt2.3.1cu121
+pip install https://github.com/opendatalab/PDF-Extract-Kit/raw/main/assets/whl/detectron2-0.6-cp310-cp310-linux_x86_64.whl
 ```
 
-After installation, you may encounter some version conflicts leading to version changes. If you encounter version-related errors, you can try the following commands to reinstall specific versions of the libraries.
-
-```bash
-pip install pillow==8.4.0
-```
-
-In addition to version conflicts, you may also encounter errors where torch cannot be invoked. First, uninstall the following library and then reinstall cuda12 and cudnn.
+After installation, you may also encounter errors where torch cannot be invoked. First, uninstall the following library and then reinstall cuda12 and cudnn.
 
 ```bash
 pip uninstall nvidia-cusparse-cu12
@@ -260,7 +254,7 @@ pip uninstall nvidia-cusparse-cu12
 ## Run the Extraction Script
 
 ```bash
-python pdf_extract.py --pdf data/pdfs/ocr_1.pdf
+python pdf_extract.py --pdf assets/examples/example.pdf
 ```
 
 Explanation of related parameters:
diff --git a/README.md b/README.md
index 3a4e508..4e4a289 100644
--- a/README.md
+++ b/README.md
@@ -214,23 +214,17 @@ The formula recognition we used is based on the weights downloaded from [UniMERN
 The table recognition we used is based on the weights downloaded from [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy), a solution that converts images of Table into LaTeX. Compared to the table recognition capability of PP-StructureV2, StructEqTable demonstrates stronger recognition performance, delivering good results even with complex tables, which may currently be best suited for data within research papers. There is also significant room for improvement in terms of speed, and we are continuously iterating and optimizing. Within a week, we will update the table recognition capability to [MinerU](https://github.com/opendatalab/MinerU).
 
-## Installation Guide
+## Installation Guide (Linux)
 
 ```bash
 conda create -n pipeline python=3.10
 pip install -r requirements.txt
-pip install --extra-index-url https://miropsota.github.io/torch_packages_builder detectron2==0.6+pt2.3.1cu121
+pip install https://github.com/opendatalab/PDF-Extract-Kit/raw/main/assets/whl/detectron2-0.6-cp310-cp310-linux_x86_64.whl
 ```
 
-After installation, you may encounter some version conflicts leading to version changes. If you encounter version-related errors, you can try the following commands to reinstall specific versions of the libraries.
-
-```bash
-pip install pillow==8.4.0
-```
-
-In addition to version conflicts, you may also encounter errors where torch cannot be invoked. First, uninstall the following library and then reinstall cuda12 and cudnn.
+After installation, you may also encounter errors where torch cannot be invoked. First, uninstall the following library and then reinstall cuda12 and cudnn.
 
 ```bash
 pip uninstall nvidia-cusparse-cu12
@@ -255,7 +249,7 @@ If you intend to experience this project on Google Colab, please
 ```bash
-python pdf_extract.py --pdf data/pdfs/ocr_1.pdf
+python pdf_extract.py --pdf assets/examples/example.pdf
 ```
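The hunk below touches `modules/extract_pdf.py`, the helper that rasterizes PDF pages before layout analysis. As a reading aid, here is a minimal sketch of `load_pdf_fitz` after the patch; the loop scaffolding around the changed lines is inferred from the hunk, so treat it as an illustration rather than the exact file contents:

```python
import fitz  # PyMuPDF
import numpy as np
from PIL import Image

def load_pdf_fitz(path, dpi=72):
    """Render every page of a PDF as an RGB numpy array at the requested dpi."""
    images = []
    with fitz.open(path) as doc:
        for page in doc:
            # Scale from the 72-dpi base resolution to the requested dpi.
            pm = page.get_pixmap(matrix=fitz.Matrix(dpi / 72, dpi / 72), alpha=False)

            # If the width or height exceeds 9000 after scaling, do not scale further.
            if pm.width > 9000 or pm.height > 9000:
                pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

            img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
            # The patch keeps RGB order; the old code flipped channels to BGR here.
            images.append(np.array(img))
    return images
```

Note that the patch also stops reversing the channel order to BGR; the visualization code in `pdf_extract.py` is updated in the same commit to consume RGB arrays directly.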
diff --git a/modules/extract_pdf.py b/modules/extract_pdf.py
--- a/modules/extract_pdf.py
+++ b/modules/extract_pdf.py
@@ ... @@
-        # if width or height > 3000 pixels, don't enlarge the image
-        if pix.width > 3000 or pix.height > 3000:
-            pix = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
-        image = Image.frombytes('RGB', (pix.width, pix.height), pix.samples)
+        # If the width or height exceeds 9000 after scaling, do not scale further.
+        if pm.width > 9000 or pm.height > 9000:
+            pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
-        # images.append(image)
-        images.append(np.array(image)[:,:,::-1])
+        img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+        images.append(np.array(img))
     return images
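Next, `modules/self_modify.py` drops the old `formula_in_text`/`update_det_boxes` pair in favor of interval arithmetic: for each text line, the x-spans of overlapping formulas are merged and then subtracted from the line's x-range. A self-contained restatement of that core idea, with a worked example (simplified for illustration, not the patched code itself):

```python
def merge_intervals(intervals):
    """Merge overlapping [start, end] intervals."""
    intervals.sort(key=lambda iv: iv[0])
    merged = []
    for iv in intervals:
        if not merged or merged[-1][1] < iv[0]:
            merged.append(iv)          # no overlap with the previous interval
        else:
            merged[-1][1] = max(merged[-1][1], iv[1])  # extend the previous one
    return merged

def remove_intervals(original, masks):
    """Subtract the merged mask intervals from one [start, end] range."""
    result = []
    start, end = original
    for m_start, m_end in merge_intervals(masks):
        if m_start > end or m_end < start:
            continue                   # mask does not touch the range
        if start < m_start:
            result.append([start, m_start - 1])
        start = max(m_end + 1, start)
    if start <= end:
        result.append([start, end])    # keep whatever survives on the right
    return result

# A text line spans x = 0..100; formulas cover x = 20..35 and x = 60..80.
print(remove_intervals([0, 100], [[20, 35], [60, 80]]))
# -> [[0, 19], [36, 59], [81, 100]]
```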
diff --git a/modules/self_modify.py b/modules/self_modify.py
index 1f830b8..29bb99e 100644
--- a/modules/self_modify.py
+++ b/modules/self_modify.py
@@ -76,49 +76,100 @@ def sorted_boxes(dt_boxes):
     return _boxes
 
 
-def formula_in_text(mf_bbox, text_bbox):
-    x1, y1, x2, y2 = mf_bbox
-    x3, y3 = text_bbox[0]
-    x4, y4 = text_bbox[2]
-    left_box, right_box = None, None
-    same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
-    if not same_line:
-        return False, left_box, right_box
-    else:
-        drop_origin = False
-        left_x = x1 - 1
-        right_x = x2 + 1
-        if x3 < x1 and x2 < x4:
-            drop_origin = True
-            left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
-            right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
-        if x3 < x1 and x1 <= x4 <= x2:
-            drop_origin = True
-            left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
-        if x1 <= x3 <= x2 and x2 < x4:
-            drop_origin = True
-            right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
-        if x1 <= x3 < x4 <= x2:
-            drop_origin = True
-        return drop_origin, left_box, right_box
-
-
-def update_det_boxes(dt_boxes, mfdetrec_res):
-    new_dt_boxes = dt_boxes
-    for mf_box in mfdetrec_res:
-        flag, left_box, right_box = False, None, None
-        for idx, text_box in enumerate(new_dt_boxes):
-            ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
-            if ret:
-                new_dt_boxes.pop(idx)
-                if left_box is not None:
-                    new_dt_boxes.append(left_box)
-                if right_box is not None:
-                    new_dt_boxes.append(right_box)
-                break
-
+def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
+    """Check if two bounding boxes overlap on the y-axis, and if the height of the overlapping region exceeds 80% of the height of the shorter bounding box."""
+    _, y0_1, _, y1_1 = bbox1
+    _, y0_2, _, y1_2 = bbox2
+
+    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
+    height1, height2 = y1_1 - y0_1, y1_2 - y0_2
+    max_height = max(height1, height2)
+    min_height = min(height1, height2)
+
+    return (overlap / min_height) > overlap_ratio_threshold
+
+
+def bbox_to_points(bbox):
+    """ change bbox(shape: N * 4) to polygon(shape: N * 8) """
+    x0, y0, x1, y1 = bbox
+    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
+
+
+def points_to_bbox(points):
+    """ change polygon(shape: N * 8) to bbox(shape: N * 4) """
+    x0, y0 = points[0]
+    x1, _ = points[1]
+    _, y1 = points[2]
+    return [x0, y0, x1, y1]
+
+
+def merge_intervals(intervals):
+    # Sort the intervals based on the start value
+    intervals.sort(key=lambda x: x[0])
+
+    merged = []
+    for interval in intervals:
+        # If the list of merged intervals is empty or if the current
+        # interval does not overlap with the previous, simply append it.
+        if not merged or merged[-1][1] < interval[0]:
+            merged.append(interval)
+        else:
+            # Otherwise, there is overlap, so we merge the current and previous intervals.
+            merged[-1][1] = max(merged[-1][1], interval[1])
+
+    return merged
+
+
+def remove_intervals(original, masks):
+    # Merge all mask intervals
+    merged_masks = merge_intervals(masks)
+
+    result = []
+    original_start, original_end = original
+
+    for mask in merged_masks:
+        mask_start, mask_end = mask
+
+        # If the mask starts after the original range, ignore it
+        if mask_start > original_end:
+            continue
+
+        # If the mask ends before the original range starts, ignore it
+        if mask_end < original_start:
+            continue
+
+        # Remove the masked part from the original range
+        if original_start < mask_start:
+            result.append([original_start, mask_start - 1])
+
+        original_start = max(mask_end + 1, original_start)
+
+    # Add the remaining part of the original range, if any
+    if original_start <= original_end:
+        result.append([original_start, original_end])
+
+    return result
+
+
+def update_det_boxes(dt_boxes, mfd_res):
+    new_dt_boxes = []
+    for text_box in dt_boxes:
+        text_bbox = points_to_bbox(text_box)
+        masks_list = []
+        for mf_box in mfd_res:
+            mf_bbox = mf_box['bbox']
+            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
+                masks_list.append([mf_bbox[0], mf_bbox[2]])
+        text_x_range = [text_bbox[0], text_bbox[2]]
+        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
+        temp_dt_box = []
+        for text_remove_mask in text_remove_mask_range:
+            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
+        if len(temp_dt_box) > 0:
+            new_dt_boxes.extend(temp_dt_box)
     return new_dt_boxes
 
+
 class ModifiedPaddleOCR(PaddleOCR):
     def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
         """
@@ -257,4 +308,4 @@ def __call__(self, img, cls=True, mfd_res=None):
             filter_rec_res.append(rec_result)
         end = time.time()
         time_dict['all'] = end - start
-        return filter_boxes, filter_rec_res, time_dict
\ No newline at end of file
+        return filter_boxes, filter_rec_res, time_dict
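Putting the new helpers together: `update_det_boxes` reduces each detected text polygon to a bbox, treats formulas whose y-overlap exceeds 80% of the shorter box as sitting on the same line, subtracts their x-spans, and converts the surviving x-ranges back into polygons. A small usage sketch (the import path assumes this repository's layout, and the boxes are made up):

```python
import numpy as np
from modules.self_modify import update_det_boxes

# One text-line polygon spanning x = 0..300 at y = 10..30 (clockwise corners).
text_box = np.array([[0, 10], [300, 10], [300, 30], [0, 30]], dtype='float32')

# One inline formula detected at x = 100..180 on the same line.
mfd_res = [{"bbox": [100, 12, 180, 28]}]

# The line is split into two polygons, x = 0..99 and x = 181..300,
# so OCR no longer reads through the formula region.
for poly in update_det_boxes([text_box], mfd_res):
    print(poly.tolist())
```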
diff --git a/pdf_extract.py b/pdf_extract.py
index 97610e1..5f68a30 100644
--- a/pdf_extract.py
+++ b/pdf_extract.py
@@ -111,7 +111,7 @@ def __getitem__(self, idx):
     else:
         all_pdfs = [args.pdf]
     print("total files:", len(all_pdfs))
-    for idx, single_pdf in enumerate(all_pdfs):
+    for pdf_idx, single_pdf in enumerate(all_pdfs):
         try:
             img_list = load_pdf_fitz(single_pdf, dpi=dpi)
         except:
@@ -119,7 +119,7 @@ def __getitem__(self, idx):
             print("unexpected pdf file:", single_pdf)
         if img_list is None:
             continue
-        print("pdf index:", idx, "pages:", len(img_list))
+        print("pdf index:", pdf_idx, "pages:", len(img_list))
         # layout detection and formula detection
         doc_layout_result = []
         latex_filling_list = []
@@ -151,7 +151,7 @@ def __getitem__(self, idx):
         # Formula recognition, collect all formula images in whole pdf file, then batch infer them.
         a = time.time()
         dataset = MathDataset(mf_image_list, transform=mfr_transform)
-        dataloader = DataLoader(dataset, batch_size=128, num_workers=32)
+        dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
         mfr_res = []
         for imgs in dataloader:
             imgs = imgs.to(device)
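Formula crops from the whole document are batched through the recognition model via a `DataLoader`. The patch halves the batch size and sets `num_workers=0`, keeping data loading in the main process; that trades some throughput for lower memory pressure and fewer multiprocessing surprises. If you need to tune this for your own hardware, a hedged sketch (the helper and its defaults are illustrative, not part of the patch):

```python
import os
from torch.utils.data import DataLoader

def make_mfr_dataloader(dataset, batch_size=64, num_workers=None):
    """Build the formula-recognition loader with a conservative worker count."""
    if num_workers is None:
        # Single-process loading is the safe default; enable a few workers
        # only where process spawning is cheap and well supported.
        num_workers = 0 if os.name == "nt" else min(4, os.cpu_count() or 1)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
```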
@@ -161,51 +161,103 @@ def __getitem__(self, idx):
             res['latex'] = latex_rm_whitespace(latex)
         b = time.time()
         print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2))
+
+        def crop_img(input_res, input_pil_img, padding_x=0, padding_y=0):
+            crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
+            crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
+            # Create a white background with an additional width and height of 50
+            crop_new_width = crop_xmax - crop_xmin + padding_x * 2
+            crop_new_height = crop_ymax - crop_ymin + padding_y * 2
+            return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
+
+            # Crop image
+            crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
+            cropped_img = input_pil_img.crop(crop_box)
+            return_image.paste(cropped_img, (padding_x, padding_y))
+            return_list = [padding_x, padding_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
+            return return_image, return_list
 
         # ocr and table recognition
         for idx, image in enumerate(img_list):
-            pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
-            single_page_res = doc_layout_result[idx]['layout_dets']
+
+            layout_res = doc_layout_result[idx]['layout_dets']
+            pil_img = Image.fromarray(image)
+
+            ocr_res_list = []
+            table_res_list = []
             single_page_mfdetrec_res = []
-            for res in single_page_res:
+
+            for res in layout_res:
                 if int(res['category_id']) in [13, 14]:
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
                     single_page_mfdetrec_res.append({
-                        "bbox": [xmin, ymin, xmax, ymax],
+                        "bbox": [int(res['poly'][0]), int(res['poly'][1]),
+                                 int(res['poly'][4]), int(res['poly'][5])],
                     })
-            for res in single_page_res:
-                if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: #categories that need to do ocr
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                    crop_box = [xmin, ymin, xmax, ymax]
-                    cropped_img = Image.new('RGB', pil_img.size, 'white')
-                    cropped_img.paste(pil_img.crop(crop_box), crop_box)
-                    cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
-                    ocr_res = ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
-                    if ocr_res:
-                        for box_ocr_res in ocr_res:
-                            p1, p2, p3, p4 = box_ocr_res[0]
-                            text, score = box_ocr_res[1]
-                            doc_layout_result[idx]['layout_dets'].append({
-                                'category_id': 15,
-                                'poly': p1 + p2 + p3 + p4,
-                                'score': round(score, 2),
-                                'text': text,
-                            })
-                elif int(res['category_id']) == 5: # do table recognition
-                    xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
-                    xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
-                    crop_box = [xmin, ymin, xmax, ymax]
-                    cropped_img = pil_img.convert("RGB").crop(crop_box)
-                    start = time.time()
-                    with torch.no_grad():
-                        output = tr_model(cropped_img)
-                    end = time.time()
-                    if (end-start) > model_configs['model_args']['table_max_time']:
-                        res["timeout"] = True
-                    res["latex"] = output[0]
+                elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
+                    ocr_res_list.append(res)
+                elif int(res['category_id']) in [5]:
+                    table_res_list.append(res)
+
+            ocr_start = time.time()
+            # Process each area that requires OCR processing
+            for res in ocr_res_list:
+                new_image, useful_list = crop_img(res, pil_img, padding_x=25, padding_y=25)
+                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+                # Adjust the coordinates of the formula area
+                adjusted_mfdetrec_res = []
+                for mf_res in single_page_mfdetrec_res:
+                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
+                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
+                    x0 = mf_xmin - xmin + paste_x
+                    y0 = mf_ymin - ymin + paste_y
+                    x1 = mf_xmax - xmin + paste_x
+                    y1 = mf_ymax - ymin + paste_y
+                    # Filter formula blocks outside the graph
+                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
+                        continue
+                    else:
+                        adjusted_mfdetrec_res.append({
+                            "bbox": [x0, y0, x1, y1],
+                        })
+
+                # OCR recognition
+                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                ocr_res = ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+
+                # Integration results
+                if ocr_res:
+                    for box_ocr_res in ocr_res:
+                        p1, p2, p3, p4 = box_ocr_res[0]
+                        text, score = box_ocr_res[1]
+
+                        # Convert the coordinates back to the original coordinate system
+                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
+                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
+                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
+                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
+
+                        layout_res.append({
+                            'category_id': 15,
+                            'poly': p1 + p2 + p3 + p4,
+                            'score': round(score, 2),
+                            'text': text,
+                        })
+
+            ocr_cost = round(time.time() - ocr_start, 2)
+            print(f"ocr cost: {ocr_cost}")
+
+            table_start = time.time()
+            for res in table_res_list:
+                new_image, _ = crop_img(res, pil_img)
+                single_table_start = time.time()
+                with torch.no_grad():
+                    output = tr_model(new_image)
+                if (time.time() - single_table_start) > model_configs['model_args']['table_max_time']:
+                    res["timeout"] = True
+                res["latex"] = output[0]
+            table_cost = round(time.time() - table_start, 2)
+            print(f"table cost: {table_cost}")
 
         output_dir = args.output
         os.makedirs(output_dir, exist_ok=True)
@@ -223,7 +275,7 @@ def __getitem__(self, idx):
         vis_pdf_result = []
         for idx, image in enumerate(img_list):
             single_page_res = doc_layout_result[idx]['layout_dets']
-            vis_img = Image.new('RGB', Image.fromarray(image).size, 'white') if args.render else Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
+            vis_img = Image.new('RGB', Image.fromarray(image).size, 'white') if args.render else Image.fromarray(image)
             draw = ImageDraw.Draw(vis_img)
             for res in single_page_res:
                 label = int(res['category_id'])
@@ -265,4 +317,4 @@ def __getitem__(self, idx):
         now = datetime.datetime.now(tz)
         end = time.time()
         print(now.strftime('%Y-%m-%d %H:%M:%S'))
-        print('Finished! time cost:', int(end-start), 's')
\ No newline at end of file
+        print('Finished! time cost:', int(end-start), 's')
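The rewritten OCR path above crops each layout region onto a white canvas with 25 px of padding, shifts formula boxes into the crop's coordinate frame before masking, and maps the resulting OCR polygons back to page coordinates. A quick standalone check of that round trip, with made-up numbers:

```python
# A layout region whose top-left corner sits at (xmin, ymin) = (200, 400),
# cropped with 25 px of padding on each side (paste_x = paste_y = 25).
paste_x = paste_y = 25
xmin, ymin = 200, 400

# A corner point reported by OCR inside the padded crop...
crop_point = (60, 90)

# ...maps back to page coordinates by undoing the paste offset and crop origin,
# mirroring `p1[0] - paste_x + xmin` in the hunk above.
page_x = crop_point[0] - paste_x + xmin  # 60 - 25 + 200 = 235
page_y = crop_point[1] - paste_y + ymin  # 90 - 25 + 400 = 465
print(page_x, page_y)  # 235 465
```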
diff --git a/requirements+cpu.txt b/requirements+cpu.txt
index 56b122e..02cf66e 100644
--- a/requirements+cpu.txt
+++ b/requirements+cpu.txt
@@ -1,7 +1,7 @@
-unimernet
-matplotlib
+unimernet>=0.1.6
+matplotlib<=3.9.0
 PyMuPDF
 ultralytics
-paddlepaddle
+paddlepaddle==2.6.1
 paddleocr==2.7.3
 struct-eqtable==0.1.0
\ No newline at end of file
diff --git a/requirements-without-unimernet+cpu.txt b/requirements-without-unimernet+cpu.txt
deleted file mode 100644
index e3fbf19..0000000
--- a/requirements-without-unimernet+cpu.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-matplotlib
-PyMuPDF
-ultralytics
-paddlepaddle
-paddleocr==2.7.3
-struct-eqtable==0.1.0
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index b6a98a5..f1ec32f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-unimernet
+unimernet>=0.1.6
 matplotlib
 PyMuPDF
 ultralytics
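Finally, the requirements files pin `unimernet>=0.1.6` across the board, plus `matplotlib<=3.9.0` and `paddlepaddle==2.6.1` in the CPU variant. After installing, the resolved versions can be sanity-checked with the standard library (package names as they appear in the requirements files):

```python
from importlib.metadata import PackageNotFoundError, version

for pkg in ("unimernet", "matplotlib", "paddlepaddle", "paddleocr", "struct-eqtable"):
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```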