From 42eebd3ce684c05031dc36ca757ed439c0dd06d8 Mon Sep 17 00:00:00 2001
From: Nathan Van Gheem
Date: Thu, 17 Oct 2024 13:40:17 -0400
Subject: [PATCH] Remove chipper model (#395)

---
 CHANGELOG.md                                  |    4 +
 .../models/test_chippermodel.py               |  448 -------
 unstructured_inference/__version__.py         |    2 +-
 unstructured_inference/constants.py           |   12 -
 .../inference/layoutelement.py                |    3 -
 unstructured_inference/models/base.py         |    9 +-
 unstructured_inference/models/chipper.py      | 1152 -----------------
 7 files changed, 6 insertions(+), 1624 deletions(-)
 delete mode 100644 test_unstructured_inference/models/test_chippermodel.py
 delete mode 100644 unstructured_inference/models/chipper.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38ea33db..f9b3f0c5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+## 0.7.42
+
+* Remove chipper model
+
 ## 0.7.41
 
 * fix: fix incorrect type casting with higher versions of `numpy` when substracting a `float` from an `int` array
diff --git a/test_unstructured_inference/models/test_chippermodel.py b/test_unstructured_inference/models/test_chippermodel.py
deleted file mode 100644
index c68aa6bc..00000000
--- a/test_unstructured_inference/models/test_chippermodel.py
+++ /dev/null
@@ -1,448 +0,0 @@
-from unittest import mock
-
-import pytest
-import torch
-from PIL import Image
-from unstructured_inference.inference.layoutelement import LayoutElement
-
-from unstructured_inference.models import chipper
-from unstructured_inference.models.base import get_model
-
-
-def test_initialize():
-    with mock.patch.object(
-        chipper.DonutProcessor,
-        "from_pretrained",
-    ) as mock_donut_processor, mock.patch.object(
-        chipper,
-        "NoRepeatNGramLogitsProcessor",
-    ) as mock_logits_processor, mock.patch.object(
-        chipper.VisionEncoderDecoderModel,
-        "from_pretrained",
-    ) as mock_vision_encoder_decoder_model:
-        model = chipper.UnstructuredChipperModel()
-        model.initialize("", "", "", "", "", "", "", "", "")
-        mock_donut_processor.assert_called_once()
-        mock_logits_processor.assert_called_once()
-        mock_vision_encoder_decoder_model.assert_called_once()
-
-
-class MockToList:
-    def tolist(self):
-        return [[5, 4, 3, 2, 1]]
-
-
-class MockModel:
-    def generate(*args, **kwargs):
-        return {"cross_attentions": mock.MagicMock(), "sequences": [[5, 4, 3, 2, 1]]}
-        return MockToList()
-
-
-def mock_initialize(self, *arg, **kwargs):
-    self.model = MockModel()
-    self.model.encoder = mock.MagicMock()
-    self.stopping_criteria = mock.MagicMock()
-    self.processor = mock.MagicMock()
-    self.logits_processor = mock.MagicMock()
-    self.input_ids = mock.MagicMock()
-    self.device = "cpu"
-    self.max_length = 1200
-
-
-def test_predict_tokens():
-    with mock.patch.object(chipper.UnstructuredChipperModel, "initialize", mock_initialize):
-        model = chipper.UnstructuredChipperModel()
-        model.initialize()
-        with open("sample-docs/loremipsum.png", "rb") as fp:
-            im = Image.open(fp)
-            tokens, _ = model.predict_tokens(im)
-        assert tokens == [5, 4, 3, 2, 1]
-
-
-@pytest.mark.parametrize(
-    ("decoded_str", "expected_classes", "expected_parent_ids"),
-    [
-        (
-            "Hi buddy!There is some text here.",
-            ["Title", "Text"],
-            [None, None],
-        ),
-        (
-            "Hi buddy!There is some text here.",
-            ["Title", "Text"],
-            [None, None],
-        ),
-        (
-            "Hi buddy!",
-            ["List", "List-item"],
-            [None, 0],
-        ),
-    ],
-)
-def test_postprocess(decoded_str, expected_classes, expected_parent_ids):
-    model = chipper.UnstructuredChipperModel()
-    pre_trained_model = "unstructuredio/ved-fine-tuning"
-    model.initialize(
-        pre_trained_model,
-        prompt="",
swap_head=False, - swap_head_hidden_layer_size=128, - start_token_prefix=" 1There is some text here.", - ["Misc", "Text"], - ), - ( - "Text here.Final one", - ["List", "List-item", "List", "List-item"], - ), - ], -) -def test_postprocess_bbox(decoded_str, expected_classes): - model = chipper.UnstructuredChipperModel() - pre_trained_model = "unstructuredio/ved-fine-tuning" - model.initialize( - pre_trained_model, - prompt="", - swap_head=False, - swap_head_hidden_layer_size=128, - start_token_prefix=" 0 - - -def test_largest_margin_edge(): - model = get_model("chipper") - img = Image.open("sample-docs/easy_table.jpg") - output = model.largest_margin(image=img, input_bbox=[0, 1, 0, 0], transpose=False) - - assert output is None - - output = model.largest_margin(img, [1, 1, 1, 1], False) - - assert output is None - - output = model.largest_margin(img, [2, 1, 3, 10], True) - - assert output == (0, 0, 0) - - -def test_deduplicate_detected_elements(): - model = get_model("chipper") - img = Image.open("sample-docs/easy_table.jpg") - elements = model(img) - - output = model.deduplicate_detected_elements(elements) - - assert len(output) == 2 - - -def test_norepeatnGramlogitsprocessor_exception(): - with pytest.raises(ValueError): - chipper.NoRepeatNGramLogitsProcessor(ngram_size="") - - -def test_run_chipper_v3(): - model = get_model("chipper") - img = Image.open("sample-docs/easy_table.jpg") - elements = model(img) - tables = [el for el in elements if el.type == "Table"] - assert all(table.text_as_html.startswith("") for table in tables) - assert all("
" not in table.text for table in tables) - - -@pytest.mark.parametrize( - ("bbox", "output"), - [ - ( - [0, 0, 0, 0], - None, - ), - ( - [0, 1, 1, -1], - None, - ), - ], -) -def test_largest_margin(bbox, output): - model = get_model("chipper") - img = Image.open("sample-docs/easy_table.jpg") - assert model.largest_margin(img, bbox) is output - - -@pytest.mark.parametrize( - ("bbox", "output"), - [ - ( - [0, 1, 0, -1], - [0, 1, 0, -1], - ), - ( - [0, 1, 1, -1], - [0, 1, 1, -1], - ), - ( - [20, 10, 30, 40], - [20, 10, 30, 40], - ), - ], -) -def test_reduce_bbox_overlap(bbox, output): - model = get_model("chipper") - img = Image.open("sample-docs/easy_table.jpg") - assert model.reduce_bbox_overlap(img, bbox) == output - - -@pytest.mark.parametrize( - ("bbox", "output"), - [ - ( - [20, 10, 30, 40], - [20, 10, 30, 40], - ), - ], -) -def test_reduce_bbox_no_overlap(bbox, output): - model = get_model("chipper") - img = Image.open("sample-docs/easy_table.jpg") - assert model.reduce_bbox_no_overlap(img, bbox) == output - - -@pytest.mark.parametrize( - ("bbox1", "bbox2", "output"), - [ - ( - [0, 50, 20, 80], - [10, 10, 30, 30], - ( - "horizontal", - [10, 10, 30, 30], - [0, 50, 20, 80], - [0, 50, 20, 80], - [10, 10, 30, 30], - None, - ), - ), - ( - [10, 10, 30, 30], - [40, 10, 60, 30], - ( - "vertical", - [40, 10, 60, 30], - [10, 10, 30, 30], - [10, 10, 30, 30], - [40, 10, 60, 30], - None, - ), - ), - ( - [10, 80, 30, 100], - [40, 10, 60, 30], - ( - "none", - [40, 10, 60, 30], - [10, 80, 30, 100], - [10, 80, 30, 100], - [40, 10, 60, 30], - None, - ), - ), - ( - [40, 10, 60, 30], - [10, 10, 30, 30], - ( - "vertical", - [10, 10, 30, 30], - [40, 10, 60, 30], - [10, 10, 30, 30], - [40, 10, 60, 30], - None, - ), - ), - ], -) -def test_check_overlap(bbox1, bbox2, output): - model = get_model("chipper") - - assert model.check_overlap(bbox1, bbox2) == output - - -def test_format_table_elements(): - table_html = "
Cell 1Cell 2Cell 3"
-    texts = [
-        "Text",
-        " - List element",
-        table_html,
-        None,
-    ]
-    elements = [LayoutElement(bbox=mock.MagicMock(), text=text) for text in texts]
-    formatted_elements = chipper.UnstructuredChipperModel.format_table_elements(elements)
-    text_attributes = [fe.text for fe in formatted_elements]
-    text_as_html_attributes = [
-        fe.text_as_html if hasattr(fe, "text_as_html") else None for fe in formatted_elements
-    ]
-    assert text_attributes == [
-        "Text",
-        " - List element",
-        "Cell 1Cell 2Cell 3",
-        None,
-    ]
-    assert text_as_html_attributes == [None, None, table_html, None]
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index bb85a995..6c64f4d1 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.7.41"  # pragma: no cover
+__version__ = "0.7.42"  # pragma: no cover
diff --git a/unstructured_inference/constants.py b/unstructured_inference/constants.py
index 173e37b4..de20fd3f 100644
--- a/unstructured_inference/constants.py
+++ b/unstructured_inference/constants.py
@@ -5,21 +5,9 @@ class Source(Enum):
     YOLOX = "yolox"
     DETECTRON2_ONNX = "detectron2_onnx"
     DETECTRON2_LP = "detectron2_lp"
-    CHIPPER = "chipper"
-    CHIPPERV1 = "chipperv1"
-    CHIPPERV2 = "chipperv2"
-    CHIPPERV3 = "chipperv3"
     MERGED = "merged"
 
 
-CHIPPER_VERSIONS = (
-    Source.CHIPPER,
-    Source.CHIPPERV1,
-    Source.CHIPPERV2,
-    Source.CHIPPERV3,
-)
-
-
 class ElementType:
     PARAGRAPH = "Paragraph"
     IMAGE = "Image"
diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py
index 1d20d498..49482171 100644
--- a/unstructured_inference/inference/layoutelement.py
+++ b/unstructured_inference/inference/layoutelement.py
@@ -10,7 +10,6 @@
 
 from unstructured_inference.config import inference_config
 from unstructured_inference.constants import (
-    CHIPPER_VERSIONS,
     FULL_PAGE_REGION_THRESHOLD,
     ElementType,
     Source,
@@ -222,8 +221,6 @@ def merge_inferred_layout_with_extracted_layout(
                 continue
         region_matched = False
         for inferred_region in inferred_layout:
-            if inferred_region.source in CHIPPER_VERSIONS:
-                continue
 
             if inferred_region.bbox.intersects(extracted_region.bbox):
                 same_bbox = region_bounding_boxes_are_almost_the_same(
diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py
index 92785ae3..8a5b13de 100644
--- a/unstructured_inference/models/base.py
+++ b/unstructured_inference/models/base.py
@@ -4,8 +4,6 @@
 import os
 from typing import Dict, Optional, Tuple, Type
 
-from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
-from unstructured_inference.models.chipper import UnstructuredChipperModel
 from unstructured_inference.models.detectron2onnx import (
     MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES,
 )
@@ -28,12 +26,7 @@ def get_default_model_mappings() -> Tuple[
     return {
         **{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
         **{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
-        **{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
-    }, {
-        **DETECTRON2_ONNX_MODEL_TYPES,
-        **YOLOX_MODEL_TYPES,
-        **CHIPPER_MODEL_TYPES,
-    }
+    }, {**DETECTRON2_ONNX_MODEL_TYPES, **YOLOX_MODEL_TYPES}
 
 
 model_class_map, model_config_map = get_default_model_mappings()
diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py
deleted file mode 100644
index 7fcf0f5b..00000000
--- a/unstructured_inference/models/chipper.py
+++ /dev/null
@@ -1,1152 +0,0 @@
-import copy -import os -import platform -from contextlib import nullcontext -from typing import ContextManager, Dict, List, Optional, Sequence, TextIO, Tuple, Union - -import cv2 -import numpy as np -import torch -import transformers -from cv2.typing import MatLike -from PIL.Image import Image -from transformers import DonutProcessor, VisionEncoderDecoderModel -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.stopping_criteria import StoppingCriteria - -from unstructured_inference.constants import CHIPPER_VERSIONS, Source -from unstructured_inference.inference.elements import Rectangle -from unstructured_inference.inference.layoutelement import LayoutElement -from unstructured_inference.logger import logger -from unstructured_inference.models.unstructuredmodel import ( - UnstructuredElementExtractionModel, -) -from unstructured_inference.utils import LazyDict, download_if_needed_and_get_local_path, strip_tags - -MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = { - "chipperv1": { - "pre_trained_model_repo": "unstructuredio/ved-fine-tuning", - "swap_head": False, - "start_token_prefix": "", - "max_length": 1200, - "heatmap_h": 52, - "heatmap_w": 39, - "source": Source.CHIPPERV1, - }, - "chipperv2": { - "pre_trained_model_repo": "unstructuredio/chipper-fast-fine-tuning", - "swap_head": True, - "swap_head_hidden_layer_size": 128, - "start_token_prefix": "", - "max_length": 1536, - "heatmap_h": 40, - "heatmap_w": 30, - "source": Source.CHIPPERV2, - }, - "chipperv3": { - "pre_trained_model_repo": "unstructuredio/chipper-v3", - "swap_head": True, - "swap_head_hidden_layer_size": 128, - "start_token_prefix": "", - "max_length": 1536, - "heatmap_h": 40, - "heatmap_w": 30, - "source": Source.CHIPPER, - }, -} - -MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv3"] - - -class UnstructuredChipperModel(UnstructuredElementExtractionModel): - def initialize( - self, - pre_trained_model_repo: str, - start_token_prefix: str, - prompt: str, - max_length: int, - heatmap_h: int, - heatmap_w: int, - source: Source, - swap_head: bool = False, - swap_head_hidden_layer_size: int = 0, - no_repeat_ngram_size: int = 10, - auth_token: Optional[str] = os.environ.get("UNSTRUCTURED_HF_TOKEN"), - device: Optional[str] = None, - ): - """Load the model for inference.""" - if device is None: - if torch.cuda.is_available(): - self.device = "cuda" - else: - self.device = "cpu" - else: - self.device = device - self.max_length = max_length - self.heatmap_h = heatmap_h - self.heatmap_w = heatmap_w - self.source = source - self.processor = DonutProcessor.from_pretrained( - pre_trained_model_repo, - token=auth_token, - ) - self.tokenizer = self.processor.tokenizer - self.logits_processor = [ - NoRepeatNGramLogitsProcessor( - no_repeat_ngram_size, - get_table_token_ids(self.processor), - ), - ] - - self.stopping_criteria = [ - NGramRepetitonStoppingCriteria( - repetition_window=30, - skip_tokens=get_table_token_ids(self.processor), - ), - ] - - self.model = VisionEncoderDecoderModel.from_pretrained( - pre_trained_model_repo, - ignore_mismatched_sizes=True, - token=auth_token, - ) - if swap_head: - lm_head_file = download_if_needed_and_get_local_path( - path_or_repo=pre_trained_model_repo, - filename="lm_head.pth", - token=auth_token, - ) - rank = swap_head_hidden_layer_size - self.model.decoder.lm_head = torch.nn.Sequential( - torch.nn.Linear( - self.model.decoder.lm_head.weight.shape[1], - rank, - bias=False, - ), - torch.nn.Linear(rank, rank, bias=False), - torch.nn.Linear( - rank, - 
self.model.decoder.lm_head.weight.shape[0], - bias=True, - ), - ) - self.model.decoder.lm_head.load_state_dict(torch.load(lm_head_file)) - else: - if swap_head_hidden_layer_size is not None: - logger.warning( - f"swap_head is False but recieved value {swap_head_hidden_layer_size} for " - "swap_head_hidden_layer_size, which will be ignored.", - ) - - self.input_ids = self.processor.tokenizer( - prompt, - add_special_tokens=False, - return_tensors="pt", - ).input_ids.to(self.device) - - self.start_tokens = [ - v - for k, v in self.processor.tokenizer.get_added_vocab().items() - if k.startswith(start_token_prefix) and v not in self.input_ids - ] - self.end_tokens = [ - v for k, v in self.processor.tokenizer.get_added_vocab().items() if k.startswith(" List[LayoutElement]: - """Do inference using the wrapped model.""" - tokens, decoder_cross_attentions = self.predict_tokens(image) - elements = self.format_table_elements( - self.postprocess(image, tokens, decoder_cross_attentions), - ) - return elements - - @staticmethod - def format_table_elements(elements: List[LayoutElement]) -> List[LayoutElement]: - """Makes chipper table element return the same as other layout models. - - 1. If `text` attribute is an html (has html tags in it), copies the `text` - attribute to `text_as_html` attribute. - 2. Strips html tags from the `text` attribute. - """ - for element in elements: - text = strip_tags(element.text) if element.text is not None else element.text - if text != element.text: - element.text_as_html = element.text # type: ignore[attr-defined] - element.text = text - return elements - - def predict_tokens( - self, - image: Image, - ) -> Tuple[List[int], Sequence[Sequence[torch.Tensor]]]: - """Predict tokens from image.""" - transformers.set_seed(42) - - with torch.no_grad(): - amp: Union[TextIO, ContextManager[None]] = ( - torch.cuda.amp.autocast() - if self.device == "cuda" - else (torch.cpu.amp.autocast() if platform.machine() == "x86_64" else nullcontext()) - ) - with amp: - encoder_outputs = self.model.encoder( - self.processor( - image, - return_tensors="pt", - ).pixel_values.to(self.device), - ) - - outputs = self.model.generate( - encoder_outputs=encoder_outputs, - input_ids=self.input_ids, - max_length=self.max_length, - no_repeat_ngram_size=0, - num_beams=1, - return_dict_in_generate=True, - output_attentions=True, - output_scores=True, - output_hidden_states=False, - stopping_criteria=self.stopping_criteria, - ) - - if ( - len(outputs["sequences"][0]) < self.max_length - and outputs["sequences"][0][-1] != self.processor.tokenizer.eos_token_id - ): - outputs = self.model.generate( - encoder_outputs=encoder_outputs, - input_ids=self.input_ids, - max_length=self.max_length, - logits_processor=self.logits_processor, - do_sample=True, - no_repeat_ngram_size=0, - num_beams=3, - return_dict_in_generate=True, - output_attentions=True, - output_scores=True, - output_hidden_states=False, - ) - - offset = len(self.input_ids) - - if "beam_indices" in outputs: - decoder_cross_attentions = [[torch.Tensor(0)]] * offset - - for token_id in range(0, len(outputs["beam_indices"][0])): - if outputs["beam_indices"][0][token_id] == -1: - break - - token_attentions = [] - - for decoder_layer_id in range( - len(outputs["cross_attentions"][token_id]), - ): - token_attentions.append( - outputs["cross_attentions"][token_id][decoder_layer_id][ - outputs["beam_indices"][0][token_id] - ].unsqueeze(0), - ) - - decoder_cross_attentions.append(token_attentions) - else: - decoder_cross_attentions = [[torch.Tensor(0)]] * 
offset + list( - outputs["cross_attentions"], - ) - - tokens = outputs["sequences"][0] - - return tokens, decoder_cross_attentions - - def update_parent_bbox(self, element): - """Update parents bboxes in a recursive way""" - # Check if children, if so update parent bbox - if element.parent is not None: - parent = element.parent - # parent has no box, create one - if parent.bbox is None: - parent.bbox = copy.copy(element.bbox) - else: - # adjust parent box - parent.bbox.x1 = min(parent.bbox.x1, element.bbox.x1) - parent.bbox.y1 = min(parent.bbox.y1, element.bbox.y1) - parent.bbox.x2 = max(parent.bbox.x2, element.bbox.x2) - parent.bbox.y2 = max(parent.bbox.y2, element.bbox.y2) - - self.update_parent_bbox(element.parent) - - def postprocess( - self, - image: Image, - output_ids: List[int], - decoder_cross_attentions: Sequence[Sequence[torch.Tensor]], - ) -> List[LayoutElement]: - """Process tokens into layout elements.""" - elements: List[LayoutElement] = [] - parents: List[LayoutElement] = [] - start = end = -1 - - x_offset, y_offset, ratio = self.image_padding( - image.size, - ( - self.processor.image_processor.size["width"], - self.processor.image_processor.size["height"], - ), - ) - - # Get bboxes - skip first token - bos - for i in range(1, len(output_ids)): - # Finish bounding box generation - if output_ids[i] in self.tokens_stop: - break - if output_ids[i] in self.start_tokens: - # Create the element - stype = self.tokenizer.decode(output_ids[i]) - # Identify parent - parent = parents[-1] if len(parents) > 0 else None - # Add to parent list - element = LayoutElement( - parent=parent, - type=stype[3:-1], - text="", - bbox=None, # type: ignore - source=self.source, - ) - - parents.append(element) - elements.append(element) - start = -1 - elif output_ids[i] in self.end_tokens: - if start != -1 and start <= end and len(parents) > 0: - slicing_end = end + 1 - string = self.tokenizer.decode(output_ids[start:slicing_end]) - - element = parents.pop(-1) - - element.text = string - bbox_coords = self.get_bounding_box( - decoder_cross_attentions=decoder_cross_attentions, - tkn_indexes=list(range(start, end + 1)), - final_w=self.processor.image_processor.size["width"], - final_h=self.processor.image_processor.size["height"], - heatmap_w=self.heatmap_w, - heatmap_h=self.heatmap_h, - ) - - bbox_coords = self.adjust_bbox( - bbox_coords, - x_offset, - y_offset, - ratio, - image.size, - ) - - element.bbox = Rectangle(*bbox_coords) - - self.update_parent_bbox(element) - - start = -1 - else: - if start == -1: - start = i - - end = i - - # If exited before eos is achieved - if start != -1 and start <= end and len(parents) > 0: - slicing_end = end + 1 - string = self.tokenizer.decode(output_ids[start:slicing_end]) - - element = parents.pop(-1) - element.text = string - - bbox_coords = self.get_bounding_box( - decoder_cross_attentions=decoder_cross_attentions, - tkn_indexes=list(range(start, end + 1)), - final_w=self.processor.image_processor.size["width"], - final_h=self.processor.image_processor.size["height"], - heatmap_w=self.heatmap_w, - heatmap_h=self.heatmap_h, - ) - - bbox_coords = self.adjust_bbox( - bbox_coords, - x_offset, - y_offset, - ratio, - image.size, - ) - - element.bbox = Rectangle(*bbox_coords) - - self.update_parent_bbox(element) - - # Reduce bounding boxes - for element in elements: - self.reduce_element_bbox(image, elements, element) - - # Solve overlaps - self.resolve_bbox_overlaps(image, elements) - - return elements - - def deduplicate_detected_elements( - self, - elements: 
List[LayoutElement], - min_text_size: int = 15, - ) -> List[LayoutElement]: - """For chipper, remove elements from other sources.""" - return [el for el in elements if el.source in CHIPPER_VERSIONS] - - def adjust_bbox(self, bbox, x_offset, y_offset, ratio, target_size): - """Translate bbox by (x_offset, y_offset) and shrink by ratio.""" - return [ - max((bbox[0] - x_offset) / ratio, 0), - max((bbox[1] - y_offset) / ratio, 0), - min((bbox[2] - x_offset) / ratio, target_size[0]), - min((bbox[3] - y_offset) / ratio, target_size[1]), - ] - - def get_bounding_box( - self, - decoder_cross_attentions: Sequence[Sequence[torch.Tensor]], - tkn_indexes: List[int], - final_h: int = 1280, - final_w: int = 960, - heatmap_h: int = 40, - heatmap_w: int = 30, - discard_ratio: float = 0.99, - ) -> List[int]: - """ - decoder_cross_attention: tuple(tuple(torch.FloatTensor)) - Tuple (one element for each generated token) of tuples (one element for - each layer of the decoder) of `torch.FloatTensor` of shape - `(batch_size, num_heads, generated_length, sequence_length)` - """ - agg_heatmap = np.zeros([final_h, final_w], dtype=np.uint8) - - for tidx in tkn_indexes: - hmaps = torch.stack(list(decoder_cross_attentions[tidx]), dim=0) - # shape [4, 1, 16, 1, 1200] -> [1, 4, 16, 1200] - hmaps = hmaps.permute(1, 3, 0, 2, 4).squeeze(0) - # shape [4, 1, 16, 1, 1200]->[4, 16, 1200] - hmaps = hmaps[-1] - # change shape [4, 16, 1200]->[4, 16, 40, 30] assuming (heatmap_h, heatmap_w) = (40, 30) - hmaps = hmaps.view(4, 16, heatmap_h, heatmap_w) - hmaps = torch.where(hmaps > 0.65, hmaps, 0.0) - # fusing 16 decoder attention heads i.e. [4, 16, 40, 30]-> [4, 40, 30] - hmaps = torch.max(hmaps, dim=1)[0] - # hmaps = torch.mean(hmaps, dim=1) - # fusing 4 decoder layers from BART i.e. [4, 40, 30]-> [4, 40, 30] - hmap = torch.max(hmaps, dim=0)[0] - # hmap = torch.mean(hmaps, dim=0) - - # dropping discard ratio activations - flat = hmap.view(heatmap_h * heatmap_w) - _, indices = flat.topk(int(flat.size(-1) * discard_ratio), largest=False) - flat[indices] = 0 - - hmap = flat.view(heatmap_h, heatmap_w) - - hmap = hmap.unsqueeze(dim=-1).to(torch.float32).cpu().numpy() # type: ignore - hmap = (hmap * 255.0).astype(np.uint8) # type:ignore - # (40, 30, 1) uint8 - # fuse heatmaps for different tokens by taking the max - agg_heatmap = np.max( - np.asarray( - [ - agg_heatmap, - cv2.resize( # type: ignore - hmap, - (final_w, final_h), - interpolation=cv2.INTER_LINEAR_EXACT, # cv2.INTER_CUBIC - ), - ], - ), - axis=0, - ).astype(np.uint8) - - # threshold to remove small attention pockets - thres_heatmap = cv2.threshold( - agg_heatmap, - 200, - 255, - cv2.THRESH_BINARY + cv2.THRESH_OTSU, - )[1] - - # kernel = np.ones((1, 20), np.uint8) - # thres_heatmap = cv2.dilate(thres_heatmap, kernel, iterations=2) - - """ - kernel = np.ones((5, 1), np.uint8) - thres_heatmap = cv2.erode(thres_heatmap, kernel, iterations=1) - """ - # Find contours - contours = cv2.findContours( - thres_heatmap, - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE, - ) - contours_selection = contours[0] if len(contours) == 2 else contours[1] - - bboxes = [cv2.boundingRect(ctr) for ctr in contours_selection] - - if len(bboxes) > 1: - kernel = np.ones((1, 50), np.uint8) - thres_heatmap = cv2.dilate(thres_heatmap, kernel, iterations=1) - - contours = cv2.findContours( - thres_heatmap, - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE, - ) - contours_selection = contours[0] if len(contours) == 2 else contours[1] - - bboxes = [cv2.boundingRect(ctr) for ctr in contours_selection] - - try: - # 
return box with max area - x, y, w, h = max(bboxes, key=lambda box: box[2] * box[3]) - - return [x, y, x + w, y + h] - except ValueError: - return [0, 0, 1, 1] - - def reduce_element_bbox( - self, - image: Image, - elements: List[LayoutElement], - element: LayoutElement, - ): - """ - Given a LayoutElement element, reduce the size of the bounding box, - depending on existing elements - """ - if element.bbox: - bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2] - - if not self.element_overlap(elements, element): - element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox)) - else: - element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox)) - - def bbox_overlap( - self, - bboxa: List[float], - bboxb: List[float], - ) -> bool: - """ - Check if bounding boxes bboxa and bboxb overlap - """ - x1_a, y1_a, x2_a, y2_a = bboxa - x1_b, y1_b, x2_b, y2_b = bboxb - - return bool(x1_a <= x2_b and x1_b <= x2_a and y1_a <= y2_b and y1_b <= y2_a) - - def element_overlap( - self, - elements: List[LayoutElement], - element: LayoutElement, - ) -> bool: - """ - Check if an element overlaps with existing elements - """ - bboxb = [ - element.bbox.x1, - element.bbox.y1, - element.bbox.x2, - max(element.bbox.y1, element.bbox.y2), - ] - - for check_element in elements: - if check_element == element: - continue - - if self.bbox_overlap( - [ - check_element.bbox.x1, - check_element.bbox.y1, - check_element.bbox.x2, - max(check_element.bbox.y1, check_element.bbox.y2), - ], - bboxb, - ): - return True - - return False - - def remove_horizontal_lines( - self, - img_dst: MatLike, - ) -> MatLike: - """ - Remove horizontal lines in an image - """ - gray_dst = cv2.cvtColor(img_dst, cv2.COLOR_BGR2GRAY) - edges = cv2.Canny(gray_dst, 50, 150, apertureSize=5) - horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 1)) - - # Using morph open to get lines inside the drawing - opening = cv2.morphologyEx(edges, cv2.MORPH_OPEN, horizontal_kernel) - cnts = cv2.findContours(opening, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - cnts_selection = cnts[0] if len(cnts) == 2 else cnts[1] - mask = np.zeros(gray_dst.shape, np.uint8) - for c in cnts_selection: - cv2.drawContours(mask, [c], -1, (255, 255, 255), 6) - - return cv2.inpaint(img_dst, mask, 3, cv2.INPAINT_TELEA) - - def reduce_bbox_no_overlap( - self, - image: Image, - input_bbox: List[float], - ) -> List[float]: - """ - If an element does not overlap with other elements, remove any empty space around it - """ - input_bbox = [int(b) for b in input_bbox] - - if ( - (input_bbox[2] * input_bbox[3] <= 0) - or (input_bbox[2] < input_bbox[0]) - or (input_bbox[3] < input_bbox[1]) - ): - return input_bbox - - nimage = np.array(image.crop(input_bbox)) # type: ignore - - nimage = self.remove_horizontal_lines(nimage) - - # Convert the image to grayscale - gray = cv2.bitwise_not(cv2.cvtColor(nimage, cv2.COLOR_BGR2GRAY)) - - # Apply Canny edge detection - edges = cv2.Canny(gray, 50, 150) - - # Find the contours in the edge image - contours, _ = cv2.findContours( - edges, - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE, - ) - - try: - largest_contour = np.vstack(contours) - x, y, w, h = cv2.boundingRect(largest_contour) - - except ValueError: - return input_bbox - - return [ - input_bbox[0] + x, - input_bbox[1] + y, - input_bbox[0] + x + w, - input_bbox[1] + y + h, - ] - - def reduce_bbox_overlap( - self, - image: Image, - input_bbox: List[float], - ) -> List[float]: - """ - If an element does overlap with other elements, reduce bouding box by 
selecting the largest - bbox after blurring existing text - """ - input_bbox = [int(b) for b in input_bbox] - - if ( - (input_bbox[2] * input_bbox[3] <= 0) - or (input_bbox[2] < input_bbox[0]) - or (input_bbox[3] < input_bbox[1]) - ): - return input_bbox - - nimage = np.array(image.crop(input_bbox)) # type: ignore - - nimage = self.remove_horizontal_lines(nimage) - - center_h = nimage.shape[0] / 2 - center_w = nimage.shape[1] / 2 - - gray = cv2.bitwise_not(cv2.cvtColor(nimage, cv2.COLOR_BGR2GRAY)) - edges = cv2.Canny(gray, 50, 150, apertureSize=3) - - nim = cv2.threshold(edges, 1, 255, cv2.THRESH_BINARY)[1] - binary_mask = cv2.threshold(nim, 127, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1] - - if center_h > center_w: - kernel = np.ones((1, 80), np.uint8) - binary_mask = cv2.dilate(binary_mask, kernel, iterations=1) - - kernel = np.ones((40, 1), np.uint8) - binary_mask = cv2.dilate(binary_mask, kernel, iterations=1) - else: - # kernel = np.ones((8, 1), np.uint8) - # binary_mask = cv2.erode(binary_mask, kernel, iterations=1) - - kernel = np.ones((1, 80), np.uint8) - binary_mask = cv2.dilate(binary_mask, kernel, iterations=1) - - kernel = np.ones((30, 1), np.uint8) - binary_mask = cv2.dilate(binary_mask, kernel, iterations=1) - - contours = cv2.findContours( - binary_mask, - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_NONE, - ) - - contours_selection = contours[0] if len(contours) == 2 else contours[1] - - bboxes = [cv2.boundingRect(ctr) for ctr in contours_selection] - - nbboxes = [ - bbox - for bbox in bboxes - if bbox[2] * bbox[3] > 1 - and (bbox[0] < center_w and bbox[0] + bbox[2] > center_w) - and (bbox[1] < center_h and bbox[1] + bbox[3] > center_h) - ] - - if len(nbboxes) == 0: - nbboxes = [bbox for bbox in bboxes if bbox[2] * bbox[3] > 1] - - if len(nbboxes) == 0: - return input_bbox - - x, y, w, h = max(nbboxes, key=lambda box: box[2] * box[3]) - - return [ - input_bbox[0] + x, - input_bbox[1] + y, - input_bbox[0] + x + w, - input_bbox[1] + y + h, - ] - - def separation_margins( - self, - mapping: MatLike, - ) -> Optional[List[Tuple[int, int, int]]]: - """ - Find intervals with no content - """ - current_start = None - - margins = [] - - for i, value in enumerate(mapping): - if value[0] == 0: - if current_start is None: - current_start = i - else: - if current_start is not None: - if current_start > 0: - margins.append((current_start, i, i - current_start)) - current_start = None - - if current_start is not None: - margins.append((current_start, i, i - current_start)) - - return margins - - def largest_margin( - self, - image: Image, - input_bbox: Tuple[float, float, float, float], - transpose: bool = False, - ) -> Optional[Tuple[int, int, int]]: - """ - Find the largest region with no text - """ - if ( - (input_bbox[2] * input_bbox[3] <= 0) - or (input_bbox[2] < input_bbox[0]) - or (input_bbox[3] < input_bbox[1]) - ): - return None - - nimage = np.array(image.crop(input_bbox)) # type: ignore - - if nimage.shape[0] * nimage.shape[1] == 0: - return None - - if transpose: - nimage = np.swapaxes(nimage, 0, 1) - - nimage = self.remove_horizontal_lines(nimage) - - gray = cv2.bitwise_not(cv2.cvtColor(nimage, cv2.COLOR_BGR2GRAY)) - - edges = cv2.Canny(gray, 50, 150, apertureSize=3) - - mapping = cv2.reduce(edges, 1, cv2.REDUCE_MAX) - - margins = self.separation_margins(mapping) - - if margins is None or len(margins) == 0: - return None - - largest_margin = sorted(margins, key=lambda x: x[2], reverse=True)[0] - - return largest_margin - - def check_overlap( - self, - box1: List[float], - box2: 
List[float], - ) -> Tuple[ - str, - List[float], - List[float], - List[float], - List[float], - Optional[Tuple[float, float, float, float]], - ]: - """ - Check the overlap between two bounding boxes, return properties of the overlap and - the bbox of the overlapped region - """ - # Get the coordinates of the two boxes - x1, y1, x11, y11 = box1 - x2, y2, x21, y21 = box2 - - w1 = x11 - x1 - h1 = y11 - y1 - w2 = x21 - x2 - h2 = y21 - y2 - - # Check for horizontal overlap - horizontal_overlap = bool(x1 <= x2 + w2 and x1 + w1 >= x2) - - # Check for vertical overlap - vertical_overlap = bool(y1 <= y2 + h2 and y1 + h1 >= y2) - - # Check for both horizontal and vertical overlap - if horizontal_overlap and vertical_overlap: - overlap_type = "both" - intersection_x1 = max(x1, x2) - intersection_y1 = max(y1, y2) - intersection_x2 = min(x1 + w1, x2 + w2) - intersection_y2 = min(y1 + h1, y2 + h2) - overlapping_bbox = ( - intersection_x1, - intersection_y1, - intersection_x2, - intersection_y2, - ) - elif horizontal_overlap and not vertical_overlap: - overlap_type = "horizontal" - overlapping_bbox = None - elif not horizontal_overlap and vertical_overlap: - overlap_type = "vertical" - overlapping_bbox = None - else: - overlap_type = "none" - overlapping_bbox = None - - # Check which box is on top and/or left - if y1 < y2: - top_box = box1 - bottom_box = box2 - else: - top_box = box2 - bottom_box = box1 - - if x1 < x2: - left_box = box1 - right_box = box2 - else: - left_box = box2 - right_box = box1 - - return overlap_type, top_box, bottom_box, left_box, right_box, overlapping_bbox - - def resolve_bbox_overlaps( - self, - image: Image, - elements: List[LayoutElement], - ): - """ - Resolve overlapping bounding boxes - """ - for element in elements: - if element.parent is not None: - continue - - ebbox1 = element.bbox - if ebbox1 is None: - continue - bbox1 = [ebbox1.x1, ebbox1.y1, ebbox1.x2, max(ebbox1.y1, ebbox1.y2)] - - for celement in elements: - if element == celement or celement.parent is not None: - continue - - ebbox2 = celement.bbox - bbox2 = [ebbox2.x1, ebbox2.y1, ebbox2.x2, max(ebbox2.y1, ebbox2.y2)] - - if self.bbox_overlap(bbox1, bbox2): - check = self.check_overlap(bbox1, bbox2) - - # For resolution, we should be sure that the overlap in the other dimension - # is large - if ( - check[-1] - and check[-1][0] > 0 - # and (check[0] == "vertical" or check[0] == "both") - and (bbox1[2] - bbox1[0]) / check[-1][0] > 0.9 - and (bbox2[2] - bbox2[0]) / check[-1][0] > 0.9 - ): - margin = self.largest_margin(image, check[-1]) - - if margin: - # Check with box is on top - if bbox1 == check[1]: - bbox1[3] -= margin[0] - bbox2[1] += margin[1] - else: - bbox2[3] -= margin[0] - bbox1[1] += margin[1] - - element.bbox = Rectangle(*bbox1) - celement.bbox = Rectangle(*bbox2) - - check = self.check_overlap(bbox1, bbox2) - - # We need to identify cases with horizontal alignment after - # vertical resolution. This is commented out for now. 
- """ - # For resolution, we should be sure that the overlap in the other dimension - # is large - if ( - check[-1] - and check[-1][0] > 0 - and (bbox1[3] - bbox1[1]) / check[-1][1] > 0.9 - and (bbox2[3] - bbox2[1]) / check[-1][1] > 0.9 - ): - margin = self.largest_margin( - image, - check[-1], - transpose=True, - ) - if margin: - # Check with box is on top - if bbox1 == check[3]: - bbox1[2] -= margin[0] - bbox2[0] += margin[1] - else: - bbox2[2] -= margin[0] - bbox1[0] += margin[1] - - element.bbox = Rectangle(*bbox1) - celement.bbox = Rectangle(*bbox2) - """ - - def image_padding(self, input_size, target_size): - """ - Resize an image to a defined size, preserving the aspect ratio, - and pad with a background color to maintain aspect ratio. - - Args: - input_size (tuple): Size of the input in the format (width, height). - target_size (tuple): Desired target size in the format (width, height). - """ - # Calculate the aspect ratio of the input and target sizes - input_width, input_height = input_size # 612, 792 - target_width, target_height = target_size # 960, 1280 - aspect_ratio_input = input_width / input_height # 0.773 - aspect_ratio_target = target_width / target_height # 0.75 - - # Determine the size of the image when resized to fit within the - # target size while preserving the aspect ratio - if aspect_ratio_input > aspect_ratio_target: - # Resize the image to fit the target width and calculate the new height - new_width = target_width - new_height = int(new_width / aspect_ratio_input) - else: - # Resize the image to fit the target height and calculate the new width - new_height = target_height - new_width = int(new_height * aspect_ratio_input) - - # Calculate the position to paste the resized image to center it within the target size - x_offset = (target_width - new_width) // 2 - y_offset = (target_height - new_height) // 2 - - return x_offset, y_offset, new_width / input_width - - -# Inspired on -# https://github.com/huggingface/transformers/blob/8e3980a290acc6d2f8ea76dba111b9ef0ef00309/src/transformers/generation/logits_process.py#L706 -class NoRepeatNGramLogitsProcessor(LogitsProcessor): - def __init__(self, ngram_size: int, skip_tokens: Optional[Sequence[int]] = None): - if not isinstance(ngram_size, int) or ngram_size <= 0: - raise ValueError( - f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}", - ) - self.ngram_size = ngram_size - self.skip_tokens = skip_tokens - - def __call__( - self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - [What are input IDs?](../glossary#input-ids) - scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): - Prediction scores of a language modeling head. These can be logits - for each vocabulary when not using beam search or log softmax for - each vocabulary token when using beam search - - Return: - `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The - processed prediction scores. 
- - """ - num_batch_hypotheses = scores.shape[0] - cur_len = input_ids.shape[-1] - return _no_repeat_ngram_logits( - input_ids, - cur_len, - scores, - batch_size=num_batch_hypotheses, - no_repeat_ngram_size=self.ngram_size, - skip_tokens=self.skip_tokens, - ) - - -class NGramRepetitonStoppingCriteria(StoppingCriteria): - def __init__(self, repetition_window: int, skip_tokens: set = set()): - self.repetition_window = repetition_window - self.skip_tokens = skip_tokens - - def __call__( - self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, - **kwargs, - ) -> bool: - """ - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] - and [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): - Prediction scores of a language modeling head. These can be scores for each - vocabulary token before SoftMax or scores for each vocabulary token after SoftMax. - kwargs (`Dict[str, Any]`, *optional*): - Additional stopping criteria specific kwargs. - - Return: - `bool`. `False` indicates we should continue, `True` indicates we should stop. - - """ - num_batch_hypotheses = input_ids.shape[0] - cur_len = input_ids.shape[-1] - - for banned_tokens in _calc_banned_tokens( - input_ids, - num_batch_hypotheses, - self.repetition_window, - cur_len, - ): - for token in banned_tokens: - if token not in self.skip_tokens: - return True - - return False - - -def _no_repeat_ngram_logits( - input_ids: torch.LongTensor, - cur_len: int, - logits: torch.FloatTensor, - batch_size: int = 1, - no_repeat_ngram_size: int = 0, - skip_tokens: Optional[Sequence[int]] = None, -) -> torch.FloatTensor: - if no_repeat_ngram_size > 0: - # calculate a list of banned tokens to prevent repetitively generating the same ngrams - # from fairseq: - # https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_tokens = _calc_banned_tokens( - input_ids, - batch_size, - no_repeat_ngram_size, - cur_len, - ) - for batch_idx in range(batch_size): - if skip_tokens is not None: - logits[ - batch_idx, - [token for token in banned_tokens[batch_idx] if int(token) not in skip_tokens], - ] = -float("inf") - else: - logits[batch_idx, banned_tokens[batch_idx]] = -float("inf") - - return logits - - -def _calc_banned_tokens( - prev_input_ids: torch.LongTensor, - num_hypos: int, - no_repeat_ngram_size: int, - cur_len: int, -) -> List[Tuple[int, ...]]: - # Copied from fairseq for no_repeat_ngram in beam_search""" - if cur_len + 1 < no_repeat_ngram_size: - # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet - return [() for _ in range(num_hypos)] - generated_ngrams: List[Dict[Tuple[int, ...], List[int]]] = [{} for _ in range(num_hypos)] - for idx in range(num_hypos): - gen_tokens = prev_input_ids[idx].tolist() - generated_ngram = generated_ngrams[idx] - for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]): - prev_ngram_tuple = tuple(ngram[:-1]) - generated_ngram[prev_ngram_tuple] = generated_ngram.get( - prev_ngram_tuple, - [], - ) + [ - ngram[-1], - ] - - def _get_generated_ngrams(hypo_idx): - # Before decoding the next token, prevent decoding of ngrams that have already appeared - start_idx = cur_len + 1 - no_repeat_ngram_size - ngram_idx = 
tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
-
-        return generated_ngrams[hypo_idx].get(ngram_idx, [])
-
-    banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
-    return banned_tokens
-
-
-def get_table_token_ids(processor):
-    """
-    Extracts the table tokens from the tokenizer of the processor
-
-    Args:
-        processor (DonutProcessor): processor used to pre-process the images and text.
-    """
-    skip_tokens = {
-        token_id
-        for token, token_id in processor.tokenizer.get_added_vocab().items()
-        if token.startswith("
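
For downstream callers, the loading path that remains after this patch is the one already used by the models still registered in base.py. The sketch below is illustrative only, not part of the patch: it assumes "yolox" is still a key in YOLOX_MODEL_TYPES and reuses the sample-docs/easy_table.jpg image referenced by the deleted tests.

from PIL import Image

from unstructured_inference.models.base import get_model

# "chipper", "chipperv1", "chipperv2", and "chipperv3" are no longer registered after
# this patch, so request one of the remaining detection models instead ("yolox" is
# assumed here to still be a registered name).
model = get_model("yolox")

# Same calling convention the removed chipper tests used: apply the model directly to
# a PIL image and get back a list of LayoutElement objects.
img = Image.open("sample-docs/easy_table.jpg")
elements = model(img)

for element in elements:
    print(element.type, element.bbox)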