diff --git a/models/README.md b/models/README.md index 7ae4982..eeeb4d9 100644 --- a/models/README.md +++ b/models/README.md @@ -45,7 +45,7 @@ Put [model files]() here: ./ ├── Layout │ ├── config.json -│ └── weights.pth +│ └── model_final.pth ├── MFD │ └── weights.pt ├── MFR diff --git a/modules/layoutlmv3/layoutlmv3_base_inference.yaml b/modules/layoutlmv3/layoutlmv3_base_inference.yaml index 6f47cb9..3e0fd30 100644 --- a/modules/layoutlmv3/layoutlmv3_base_inference.yaml +++ b/modules/layoutlmv3/layoutlmv3_base_inference.yaml @@ -1,6 +1,6 @@ AUG: DETR: true -CACHE_DIR: /mnt/localdata/users/yupanhuang/cache/huggingface +CACHE_DIR: ~/cache/huggingface CUDNN_BENCHMARK: false DATALOADER: ASPECT_RATIO_GROUPING: true @@ -294,7 +294,7 @@ MODEL: POS_TYPE: abs WEIGHTS: OUTPUT_DIR: -SCIHUB_DATA_DIR_TRAIN: /mnt/petrelfs/share_data/zhaozhiyuan/publaynet/layout_scihub/train +SCIHUB_DATA_DIR_TRAIN: ~/publaynet/layout_scihub/train SEED: 42 SOLVER: AMP: diff --git a/pdf_extract.py b/pdf_extract.py index 3dccf1f..603a577 100644 --- a/pdf_extract.py +++ b/pdf_extract.py @@ -113,14 +113,13 @@ def __getitem__(self, idx): if img_list is None: continue print("pdf index:", idx, "pages:", len(img_list)) - # layout检测 + 公式检测 + # layout detection and formula detection doc_layout_result = [] latex_filling_list = [] mf_image_list = [] for idx, image in enumerate(img_list): img_H, img_W = image.shape[0], image.shape[1] layout_res = layout_model(image, ignore_catids=[]) - # 公式检测 mfd_res = mfd_model.predict(image, imgsz=img_size, conf=conf_thres, iou=iou_thres, verbose=True)[0] for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()): xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy] @@ -142,7 +141,7 @@ def __getitem__(self, idx): ) doc_layout_result.append(layout_res) - # 公式识别,因为识别速度较慢,为了提速,把单个pdf的所有公式裁剪完,一起批量做识别。 + # Formula recognition, collect all formula images in whole pdf file, then batch infer them. a = time.time() dataset = MathDataset(mf_image_list, transform=mfr_transform) dataloader = DataLoader(dataset, batch_size=128, num_workers=32) @@ -156,7 +155,7 @@ def __getitem__(self, idx): b = time.time() print("formula nums:", len(mf_image_list), "mfr time:", round(b-a, 2)) - # ocr识别 + # ocr for idx, image in enumerate(img_list): pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) single_page_res = doc_layout_result[idx]['layout_dets'] @@ -169,7 +168,7 @@ def __getitem__(self, idx): "bbox": [xmin, ymin, xmax, ymax], }) for res in single_page_res: - if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: #需要进行ocr的类别 + if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: #categories that need to do ocr xmin, ymin = int(res['poly'][0]), int(res['poly'][1]) xmax, ymax = int(res['poly'][4]), int(res['poly'][5]) crop_box = [xmin, ymin, xmax, ymax] @@ -208,19 +207,19 @@ def __getitem__(self, idx): draw = ImageDraw.Draw(vis_img) for res in single_page_res: label = int(res['category_id']) - if label > 15: # 筛选要可视化的类别 + if label > 15: # categories that do not need visualize continue label_name = id2names[label] x_min, y_min = int(res['poly'][0]), int(res['poly'][1]) x_max, y_max = int(res['poly'][4]), int(res['poly'][5]) if args.render and label in [13, 14, 15]: try: - if label in [13, 14]: # 渲染公式 + if label in [13, 14]: # render formula window_img = tex2pil(res['latex'])[0] else: - if True: # 渲染中文 + if True: # render chinese window_img = zhtext2pil(res['text']) - else: # 渲染英文 + else: # render english window_img = tex2pil([res['text']], tex_type="text")[0] ratio = min((x_max - x_min) / window_img.width, (y_max - y_min) / window_img.height) - 0.05 window_img = window_img.resize((int(window_img.width * ratio), int(window_img.height * ratio)))