Skip to content

Commit

Permalink
refactor page layout elements
Browse files Browse the repository at this point in the history
  • Loading branch information
badGarnet committed Sep 11, 2024
1 parent 39caa99 commit 8049901
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 30 deletions.
13 changes: 7 additions & 6 deletions test_unstructured_inference/models/test_yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ def test_layout_yolox_local_parsing_image():
assert len(document_layout.pages) == 1
# NOTE(benjamin) The example sent to the test contains 13 detections
types_known = ["Text", "Section-header", "Page-header"]
known_regions = [e for e in document_layout.pages[0].elements if e.type in types_known]
elements = document_layout.pages[0].elements
known_regions = [
e for e in elements.element_class_ids if elements.element_class_id_map[e] in types_known
]
assert len(known_regions) == 13
assert hasattr(
document_layout.pages[0].elements[0],
"prob",
) # NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities
# NOTE(pravin) New Assertion to Make Sure LayoutElement has probabilities
assert hasattr(elements, "element_probs")
assert isinstance(
document_layout.pages[0].elements[0].prob,
elements.element_probs[0],
float,
) # NOTE(pravin) New Assertion to Make Sure Populated Probability is Float

Expand Down
11 changes: 11 additions & 0 deletions unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ class TextRegions:
texts: np.array | None = None
source: Source | None = None

def __post_init__(self):
    """Default ``texts`` to one ``None`` entry per row of ``element_coords``.

    Nothing is changed when ``texts`` was already provided.
    """
    if self.texts is not None:
        return
    row_count = self.element_coords.shape[0]
    self.texts = np.array([None] * row_count)

def slice(self, indices) -> TextRegions:
    """Return a new ``TextRegions`` holding only the entries selected by ``indices``.

    Args:
        indices: anything usable as an index into ``element_coords`` and
            ``texts`` (the same value is applied to both).

    Returns:
        A new ``TextRegions`` built from the indexed ``element_coords`` and
        ``texts``; ``source`` is carried over unchanged.
    """
    return TextRegions(
        # The constructor field is ``element_coords`` (plural); the previous
        # ``element_coord=`` keyword did not match it.
        element_coords=self.element_coords[indices],
        texts=self.texts[indices],
        source=self.source,
    )

def as_list(self):
if self.texts is None:
return [
Expand Down
10 changes: 8 additions & 2 deletions unstructured_inference/inference/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import tempfile
from functools import cached_property
from pathlib import PurePath
from typing import Any, BinaryIO, Collection, List, Optional, Union, cast

Expand Down Expand Up @@ -147,10 +148,15 @@ def __init__(
self.detection_model = detection_model
self.element_extraction_model = element_extraction_model
self.elements: Collection[LayoutElement] = []
self.elements_array: LayoutElements | None = None
# NOTE(alan): Dropped LocationlessLayoutElement that was created for chipper - chipper has
# locations now and if we need to support LayoutElements without bounding boxes we can make
# the bbox property optional

@cached_property
def elements(self) -> list[LayoutElement]:
    """Return ``elements_array`` converted to a list via its ``as_list()`` method.

    The result is computed on first access and then cached on the instance.
    """
    # NOTE(review): ``elements_array`` is initialised to None in ``__init__``, so accessing
    # this before it has been populated would fail on ``None.as_list()`` - confirm callers
    # always populate it first.
    # NOTE(review): ``__init__`` also assigns ``self.elements``; verify that assignment does
    # not shadow this cached property, and that the cached list is not expected to follow
    # later reassignments of ``elements_array``.
    layout_array = self.elements_array
    return layout_array.as_list()

def __str__(self) -> str:
    """Render every element with ``str`` and join the results with blank lines."""
    rendered = [str(item) for item in self.elements]
    separator = "\n\n"
    return separator.join(rendered)

Expand All @@ -173,7 +179,7 @@ def get_elements_using_image_extraction(
def get_elements_with_detection_model(
self,
inplace: bool = True,
) -> Optional[List[LayoutElement]]:
) -> LayoutElements | None:
"""Uses specified model to detect the elements on the page."""
if self.detection_model is None:
model = get_model()
Expand All @@ -191,7 +197,7 @@ def get_elements_with_detection_model(
)

if inplace:
self.elements = inferred_layout
self.elements_array = inferred_layout
return None

return inferred_layout
Expand Down
45 changes: 36 additions & 9 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Collection, List, Optional
from typing import Collection, Iterable, List, Optional

import numpy as np
from layoutparser.elements.layout import TextBlock
Expand Down Expand Up @@ -34,6 +34,40 @@ class LayoutElements(TextRegions):
element_class_ids: np.ndarray | None = None
element_class_id_map: dict[int, str] | None = None

def __post_init__(self):
    """Fill unset per-element arrays with placeholder arrays of ``None``.

    ``element_probs``, ``element_class_ids`` and ``texts`` are each replaced, when they
    are ``None``, by an array holding one ``None`` per row of ``element_coords``.
    """
    optional_arrays = ("element_probs", "element_class_ids", "texts")
    for name in optional_arrays:
        if getattr(self, name) is not None:
            continue
        # The row count is only read when a placeholder is actually needed.
        placeholder = np.array([None] * self.element_coords.shape[0])
        setattr(self, name, placeholder)

def slice(self, indices) -> LayoutElements:
    """Return a new ``LayoutElements`` holding only the entries selected by ``indices``.

    The same ``indices`` value is applied to ``element_coords``, ``texts``,
    ``element_probs`` and ``element_class_ids``; ``source`` and
    ``element_class_id_map`` are passed through unchanged.
    """
    coords_subset = self.element_coords[indices]
    texts_subset = self.texts[indices]
    probs_subset = self.element_probs[indices]
    class_ids_subset = self.element_class_ids[indices]
    return LayoutElements(
        element_coords=coords_subset,
        texts=texts_subset,
        source=self.source,
        element_probs=probs_subset,
        element_class_ids=class_ids_subset,
        element_class_id_map=self.element_class_id_map,
    )

@classmethod
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
    """Combine several ``LayoutElements`` into one, preserving group order.

    Args:
        groups: the ``LayoutElements`` instances to join; consumed exactly once, so a
            generator is acceptable.

    Returns:
        A new instance whose ``element_coords``, ``texts``, ``element_probs`` and
        ``element_class_ids`` are the ``np.concatenate`` of the corresponding arrays
        of each group. ``element_class_id_map`` is the merge of every group's map
        (later groups win on key conflicts) and ``source`` is taken from the last
        group.

    Raises:
        ValueError: if ``groups`` is empty.
    """
    # Materialise once so the emptiness check and the "last group" lookup are explicit
    # rather than relying on the loop variable surviving the loop.
    groups = list(groups)
    if not groups:
        raise ValueError("need at least one LayoutElements to concatenate")

    coords, texts, probs, class_ids = [], [], [], []
    class_id_map = {}
    for group in groups:
        coords.append(group.element_coords)
        texts.append(group.texts)
        probs.append(group.element_probs)
        class_ids.append(group.element_class_ids)
        # ``element_class_id_map`` defaults to None; ``dict.update(None)`` raises
        # TypeError, so groups without a map are skipped.
        if group.element_class_id_map is not None:
            class_id_map.update(group.element_class_id_map)
    return cls(
        element_coords=np.concatenate(coords),
        texts=np.concatenate(texts),
        element_probs=np.concatenate(probs),
        element_class_ids=np.concatenate(class_ids),
        element_class_id_map=class_id_map,
        source=groups[-1].source,
    )

def as_list(self) -> list[LayoutElement]:
"""for backward compatibility"""
len_elements = self.element_coords.shape[0]
Expand Down Expand Up @@ -290,14 +324,7 @@ def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]:
group_count, group_nums = connected_components(intersection_mtx)
groups: List[TextRegions] = []
for group in range(group_count):
indices = np.where(group_nums == group)[0]
groups.append(
TextRegions(
element_coords=regions.element_coords[indices],
texts=regions.texts[indices],
source=regions.source,
),
)
groups.append(regions.slice(np.where(group_nums == group)[0]))

return groups

Expand Down
21 changes: 8 additions & 13 deletions unstructured_inference/models/unstructuredmodel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, cast
from typing import Any, List

import numpy as np
from PIL.Image import Image
Expand All @@ -12,17 +12,13 @@
intersections,
)
from unstructured_inference.inference.layoutelement import (
LayoutElement,
LayoutElements,
clean_layoutelements,
partition_groups_from_regions,
separate,
)

if TYPE_CHECKING:
from unstructured_inference.inference.layoutelement import (
LayoutElement,
LayoutElements,
)


class UnstructuredModel(ABC):
"""Wrapper class for the various models used by unstructured."""
Expand Down Expand Up @@ -181,15 +177,14 @@ def deduplicate_detected_elements(
if len(elements) <= 1:
return elements

cleaned_elements: LayoutElements = []
cleaned_elements = []
# TODO: Delete nested elements with low or None probability
# TODO: Keep most confident
# TODO: Better to grow horizontally than vertically?
groups_tmp = partition_groups_from_regions(elements)
groups = cast(List[LayoutElements], groups_tmp)
for elements in groups:
cleaned_elements.extend(clean_layoutelements(elements))
return cleaned_elements
groups = partition_groups_from_regions(elements)
for group in groups:
cleaned_elements.append(clean_layoutelements(group))
return LayoutElements.concatenate(cleaned_elements)


class UnstructuredElementExtractionModel(UnstructuredModel):
Expand Down

0 comments on commit 8049901

Please sign in to comment.