From c53866bbd8cc3f0bb2879a6ad910f8d5ef79728e Mon Sep 17 00:00:00 2001 From: Yao You Date: Wed, 16 Oct 2024 13:16:06 -0500 Subject: [PATCH] fix: fix bugs in layoutelements (#393) This PR fixes two bugs: - fix a type casting issue when subtracting a int array with a float. This popped up when testing with `unstructured`, and some sources of element coordinates are of `int` type. This PR adds a new unit test case for `int` coord type with the grouping function - fix element class id 0 becomes None bug: this happens when dumping `LayoutElements` as a list of `LayoutElement`. When an element class id is 0 the logic on main would treat it as no existing and use `None` as the type. --- CHANGELOG.md | 5 ++ test_unstructured_inference/test_elements.py | 15 ++++- unstructured_inference/__version__.py | 2 +- unstructured_inference/inference/elements.py | 2 +- .../inference/layoutelement.py | 60 +++++++++++++++++-- unstructured_inference/models/yolox.py | 2 +- 6 files changed, 76 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2526031e..38ea33db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.7.41 + +* fix: fix incorrect type casting with higher versions of `numpy` when substracting a `float` from an `int` array +* fix: fix a bug where class id 0 becomes class type `None` when calling `LayoutElements.as_list()` + ## 0.7.40 * fix: store probabilities with `float` data type instead of `int` diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py index b99a55b1..6627e205 100644 --- a/test_unstructured_inference/test_elements.py +++ b/test_unstructured_inference/test_elements.py @@ -143,8 +143,10 @@ def test_minimal_containing_rect(): assert rect2.is_in(big_rect) -def test_partition_groups_from_regions(mock_embedded_text_regions): +@pytest.mark.parametrize("coord_type", [int, float]) +def test_partition_groups_from_regions(mock_embedded_text_regions, coord_type): words = TextRegions.from_list(mock_embedded_text_regions) + words.element_coords = words.element_coords.astype(coord_type) groups = partition_groups_from_regions(words) assert len(groups) == 1 text = "".join(groups[-1].texts) @@ -421,3 +423,14 @@ def test_clean_layoutelements_for_class( elements = clean_layoutelements_for_class(elements, element_class=class_to_filter) np.testing.assert_array_equal(elements.element_coords, expected_coords) np.testing.assert_array_equal(elements.element_class_ids, expected_ids) + + +def test_layoutelements_to_list_and_back(test_layoutelements): + back = LayoutElements.from_list(test_layoutelements.as_list()) + np.testing.assert_array_equal(test_layoutelements.element_coords, back.element_coords) + np.testing.assert_array_equal(test_layoutelements.texts, back.texts) + assert all(np.isnan(back.element_probs)) + assert [ + test_layoutelements.element_class_id_map[idx] + for idx in test_layoutelements.element_class_ids + ] == [back.element_class_id_map[idx] for idx in back.element_class_ids] diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 53ea3558..bb85a995 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.40" # pragma: no cover +__version__ = "0.7.41" # pragma: no cover diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index 8e6596af..939ea0cc 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -237,7 +237,7 @@ def as_list(self): ] @classmethod - def from_list(cls, regions: list[TextRegion]): + def from_list(cls, regions: list): """create TextRegions from a list of TextRegion objects; the objects must have the same source""" coords, texts = [], [] diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index 9341ab2d..1d20d498 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -32,16 +32,34 @@ class LayoutElements(TextRegions): element_probs: np.ndarray = field(default_factory=lambda: np.array([])) element_class_ids: np.ndarray = field(default_factory=lambda: np.array([])) - element_class_id_map: dict[int, str] | None = None + element_class_id_map: dict[int, str] = field(default_factory=dict) def __post_init__(self): - if self.element_probs is not None: - self.element_probs = self.element_probs.astype(float) element_size = self.element_coords.shape[0] for attr in ("element_probs", "element_class_ids", "texts"): if getattr(self, attr).size == 0 and element_size: setattr(self, attr, np.array([None] * element_size)) + self.element_probs = self.element_probs.astype(float) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LayoutElements): + return NotImplemented + + mask = ~np.isnan(self.element_probs) + other_mask = ~np.isnan(other.element_probs) + return ( + np.array_equal(self.element_coords, other.element_coords) + and np.array_equal(self.texts, other.texts) + and np.array_equal(mask, other_mask) + and np.array_equal(self.element_probs[mask], other.element_probs[mask]) + and ( + [self.element_class_id_map[idx] for idx in self.element_class_ids] + == [other.element_class_id_map[idx] for idx in other.element_class_ids] + ) + and self.source == other.source + ) + def slice(self, indices) -> LayoutElements: """slice and return only selected indices""" return LayoutElements( @@ -85,10 +103,10 @@ def as_list(self): text=text, type=( self.element_class_id_map[class_id] - if class_id and self.element_class_id_map + if class_id is not None and self.element_class_id_map else None ), - prob=prob, + prob=None if np.isnan(prob) else prob, source=self.source, ) for (x1, y1, x2, y2), text, prob, class_id in zip( @@ -99,6 +117,36 @@ def as_list(self): ) ] + @classmethod + def from_list(cls, elements: list): + """create LayoutElements from a list of LayoutElement objects; the objects must have the + same source""" + len_ele = len(elements) + coords = np.empty((len_ele, 4), dtype=float) + # text and probs can be Nones so use lists first then convert into array to avoid them being + # filled as nan + texts = [] + class_probs = [] + class_types = np.empty((len_ele,), dtype="object") + + for i, element in enumerate(elements): + coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2] + texts.append(element.text) + class_probs.append(element.prob) + class_types[i] = element.type or "None" + + unique_ids, class_ids = np.unique(class_types, return_inverse=True) + unique_ids[unique_ids == "None"] = None + + return cls( + element_coords=coords, + texts=np.array(texts), + element_probs=np.array(class_probs), + element_class_ids=class_ids, + element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)), + source=elements[0].source, + ) + @dataclass class LayoutElement(TextRegion): @@ -315,7 +363,7 @@ def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]: regions, each list corresponding with a group""" if len(regions) == 0: return [] - padded_coords = regions.element_coords.copy() + padded_coords = regions.element_coords.copy().astype(float) v_pad = (regions.y2 - regions.y1) * inference_config.ELEMENTS_V_PADDING_COEF h_pad = (regions.x2 - regions.x1) * inference_config.ELEMENTS_H_PADDING_COEF padded_coords[:, 0] -= h_pad diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index 031ac2b2..8e57843d 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -136,7 +136,7 @@ def image_processing( sorted_dets = dets[order] return LayoutElements( - element_coords=sorted_dets[:, :4], + element_coords=sorted_dets[:, :4].astype(float), element_probs=sorted_dets[:, 4].astype(float), element_class_ids=sorted_dets[:, 5].astype(int), element_class_id_map=self.layout_classes,