From 8049901326cb5b12c9ba26e43faca85c5d9472b9 Mon Sep 17 00:00:00 2001 From: Yao You Date: Wed, 11 Sep 2024 16:52:11 -0500 Subject: [PATCH] refactor page layout elements --- .../models/test_yolox.py | 13 +++--- unstructured_inference/inference/elements.py | 11 +++++ unstructured_inference/inference/layout.py | 10 ++++- .../inference/layoutelement.py | 45 +++++++++++++++---- .../models/unstructuredmodel.py | 21 ++++----- 5 files changed, 70 insertions(+), 30 deletions(-) diff --git a/test_unstructured_inference/models/test_yolox.py b/test_unstructured_inference/models/test_yolox.py index 9e6f19ac..f108d449 100644 --- a/test_unstructured_inference/models/test_yolox.py +++ b/test_unstructured_inference/models/test_yolox.py @@ -15,14 +15,15 @@ def test_layout_yolox_local_parsing_image(): assert len(document_layout.pages) == 1 # NOTE(benjamin) The example sent to the test contains 13 detections types_known = ["Text", "Section-header", "Page-header"] - known_regions = [e for e in document_layout.pages[0].elements if e.type in types_known] + elements = document_layout.pages[0].elements + known_regions = [ + e for e in elements.element_class_ids if elements.element_class_id_map[e] in types_known + ] assert len(known_regions) == 13 - assert hasattr( - document_layout.pages[0].elements[0], - "prob", - ) # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities + # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities + assert hasattr(elements, "element_probs") assert isinstance( - document_layout.pages[0].elements[0].prob, + elements.element_probs[0], float, ) # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index 72373532..fbc4000d 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -212,6 +212,17 @@ class TextRegions: texts: np.array | None = None 
source: Source | None = None + def __post_init__(self): + if self.texts is None: + self.texts = np.array([None] * self.element_coords.shape[0]) + + def slice(self, indices) -> TextRegions: + return TextRegions( + element_coords=self.element_coords[indices], + texts=self.texts[indices], + source=self.source, + ) + def as_list(self): if self.texts is None: return [ diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index ad313826..2e1cb669 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -2,6 +2,7 @@ import os import tempfile +from functools import cached_property from pathlib import PurePath from typing import Any, BinaryIO, Collection, List, Optional, Union, cast @@ -147,10 +148,15 @@ def __init__( self.detection_model = detection_model self.element_extraction_model = element_extraction_model self.elements: Collection[LayoutElement] = [] + self.elements_array: LayoutElements | None = None # NOTE(alan): Dropped LocationlessLayoutElement that was created for chipper - chipper has # locations now and if we need to support LayoutElements without bounding boxes we can make # the bbox property optional + @cached_property + def elements(self) -> list[LayoutElement]: + return self.elements_array.as_list() + def __str__(self) -> str: return "\n\n".join([str(element) for element in self.elements]) @@ -173,7 +179,7 @@ def get_elements_using_image_extraction( def get_elements_with_detection_model( self, inplace: bool = True, - ) -> Optional[List[LayoutElement]]: + ) -> LayoutElements | None: """Uses specified model to detect the elements on the page.""" if self.detection_model is None: model = get_model() @@ -191,7 +197,7 @@ def get_elements_with_detection_model( ) if inplace: - self.elements = inferred_layout + self.elements_array = inferred_layout return None return inferred_layout diff --git a/unstructured_inference/inference/layoutelement.py 
b/unstructured_inference/inference/layoutelement.py index 54f2e386..9589dbe0 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Collection, List, Optional +from typing import Collection, Iterable, List, Optional import numpy as np from layoutparser.elements.layout import TextBlock @@ -34,6 +34,40 @@ class LayoutElements(TextRegions): element_class_ids: np.ndarray | None = None element_class_id_map: dict[int, str] | None = None + def __post_init__(self): + for attr in ("element_probs", "element_class_ids", "texts"): + if getattr(self, attr) is None: + setattr(self, attr, np.array([None] * self.element_coords.shape[0])) + + def slice(self, indices) -> LayoutElements: + return LayoutElements( + element_coords=self.element_coords[indices], + texts=self.texts[indices], + source=self.source, + element_probs=self.element_probs[indices], + element_class_ids=self.element_class_ids[indices], + element_class_id_map=self.element_class_id_map, + ) + + @classmethod + def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements: + coords, texts, probs, class_ids = [], [], [], [] + class_id_map = {} + for group in groups: + coords.append(group.element_coords) + texts.append(group.texts) + probs.append(group.element_probs) + class_ids.append(group.element_class_ids) + class_id_map.update(group.element_class_id_map) + return cls( + element_coords=np.concatenate(coords), + texts=np.concatenate(texts), + element_probs=np.concatenate(probs), + element_class_ids=np.concatenate(class_ids), + element_class_id_map=class_id_map, + source=group.source, + ) + def as_list(self) -> list[LayoutElement]: """for backward compatibility""" len_elements = self.element_coords.shape[0] @@ -290,14 +324,7 @@ def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]: group_count, group_nums = 
connected_components(intersection_mtx) groups: List[TextRegions] = [] for group in range(group_count): - indices = np.where(group_nums == group)[0] - groups.append( - TextRegions( - element_coords=regions.element_coords[indices], - texts=regions.texts[indices], - source=regions.source, - ), - ) + groups.append(regions.slice(np.where(group_nums == group)[0])) return groups diff --git a/unstructured_inference/models/unstructuredmodel.py b/unstructured_inference/models/unstructuredmodel.py index 5dafe086..15198fb2 100644 --- a/unstructured_inference/models/unstructuredmodel.py +++ b/unstructured_inference/models/unstructuredmodel.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, List, cast +from typing import Any, List import numpy as np from PIL.Image import Image @@ -12,17 +12,13 @@ intersections, ) from unstructured_inference.inference.layoutelement import ( + LayoutElement, + LayoutElements, clean_layoutelements, partition_groups_from_regions, separate, ) -if TYPE_CHECKING: - from unstructured_inference.inference.layoutelement import ( - LayoutElement, - LayoutElements, - ) - class UnstructuredModel(ABC): """Wrapper class for the various models used by unstructured.""" @@ -181,15 +177,14 @@ def deduplicate_detected_elements( if len(elements) <= 1: return elements - cleaned_elements: LayoutElements = [] + cleaned_elements = [] # TODO: Delete nested elements with low or None probability # TODO: Keep most confident # TODO: Better to grow horizontally than vertically? 
- groups_tmp = partition_groups_from_regions(elements) - groups = cast(List[LayoutElements], groups_tmp) - for elements in groups: - cleaned_elements.extend(clean_layoutelements(elements)) - return cleaned_elements + groups = partition_groups_from_regions(elements) + for group in groups: + cleaned_elements.append(clean_layoutelements(group)) + return LayoutElements.concatenate(cleaned_elements) class UnstructuredElementExtractionModel(UnstructuredModel):