Skip to content

Commit

Permalink
fix: fix bugs in layoutelements (#393)
Browse files Browse the repository at this point in the history
This PR fixes two bugs:

- fix a type casting issue when subtracting a int array with a float.
This popped up when testing with `unstructured`, and some sources of
element coordinates are of `int` type. This PR adds a new unit test case
for `int` coord type with the grouping function
- fix element class id 0 becomes None bug: this happens when dumping
`LayoutElements` as a list of `LayoutElement`. When an element class id
is 0 the logic on main would treat it as no existing and use `None` as
the type.
  • Loading branch information
badGarnet authored Oct 16, 2024
1 parent 4431fe5 commit c53866b
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.7.41

* fix: fix incorrect type casting with higher versions of `numpy` when substracting a `float` from an `int` array
* fix: fix a bug where class id 0 becomes class type `None` when calling `LayoutElements.as_list()`

## 0.7.40

* fix: store probabilities with `float` data type instead of `int`
Expand Down
15 changes: 14 additions & 1 deletion test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,10 @@ def test_minimal_containing_rect():
assert rect2.is_in(big_rect)


def test_partition_groups_from_regions(mock_embedded_text_regions):
@pytest.mark.parametrize("coord_type", [int, float])
def test_partition_groups_from_regions(mock_embedded_text_regions, coord_type):
words = TextRegions.from_list(mock_embedded_text_regions)
words.element_coords = words.element_coords.astype(coord_type)
groups = partition_groups_from_regions(words)
assert len(groups) == 1
text = "".join(groups[-1].texts)
Expand Down Expand Up @@ -421,3 +423,14 @@ def test_clean_layoutelements_for_class(
elements = clean_layoutelements_for_class(elements, element_class=class_to_filter)
np.testing.assert_array_equal(elements.element_coords, expected_coords)
np.testing.assert_array_equal(elements.element_class_ids, expected_ids)


def test_layoutelements_to_list_and_back(test_layoutelements):
back = LayoutElements.from_list(test_layoutelements.as_list())
np.testing.assert_array_equal(test_layoutelements.element_coords, back.element_coords)
np.testing.assert_array_equal(test_layoutelements.texts, back.texts)
assert all(np.isnan(back.element_probs))
assert [
test_layoutelements.element_class_id_map[idx]
for idx in test_layoutelements.element_class_ids
] == [back.element_class_id_map[idx] for idx in back.element_class_ids]
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.40" # pragma: no cover
__version__ = "0.7.41" # pragma: no cover
2 changes: 1 addition & 1 deletion unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def as_list(self):
]

@classmethod
def from_list(cls, regions: list[TextRegion]):
def from_list(cls, regions: list):
"""create TextRegions from a list of TextRegion objects; the objects must have the same
source"""
coords, texts = [], []
Expand Down
60 changes: 54 additions & 6 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,34 @@
class LayoutElements(TextRegions):
element_probs: np.ndarray = field(default_factory=lambda: np.array([]))
element_class_ids: np.ndarray = field(default_factory=lambda: np.array([]))
element_class_id_map: dict[int, str] | None = None
element_class_id_map: dict[int, str] = field(default_factory=dict)

def __post_init__(self):
if self.element_probs is not None:
self.element_probs = self.element_probs.astype(float)
element_size = self.element_coords.shape[0]
for attr in ("element_probs", "element_class_ids", "texts"):
if getattr(self, attr).size == 0 and element_size:
setattr(self, attr, np.array([None] * element_size))

self.element_probs = self.element_probs.astype(float)

def __eq__(self, other: object) -> bool:
if not isinstance(other, LayoutElements):
return NotImplemented

mask = ~np.isnan(self.element_probs)
other_mask = ~np.isnan(other.element_probs)
return (
np.array_equal(self.element_coords, other.element_coords)
and np.array_equal(self.texts, other.texts)
and np.array_equal(mask, other_mask)
and np.array_equal(self.element_probs[mask], other.element_probs[mask])
and (
[self.element_class_id_map[idx] for idx in self.element_class_ids]
== [other.element_class_id_map[idx] for idx in other.element_class_ids]
)
and self.source == other.source
)

def slice(self, indices) -> LayoutElements:
"""slice and return only selected indices"""
return LayoutElements(
Expand Down Expand Up @@ -85,10 +103,10 @@ def as_list(self):
text=text,
type=(
self.element_class_id_map[class_id]
if class_id and self.element_class_id_map
if class_id is not None and self.element_class_id_map
else None
),
prob=prob,
prob=None if np.isnan(prob) else prob,
source=self.source,
)
for (x1, y1, x2, y2), text, prob, class_id in zip(
Expand All @@ -99,6 +117,36 @@ def as_list(self):
)
]

@classmethod
def from_list(cls, elements: list):
"""create LayoutElements from a list of LayoutElement objects; the objects must have the
same source"""
len_ele = len(elements)
coords = np.empty((len_ele, 4), dtype=float)
# text and probs can be Nones so use lists first then convert into array to avoid them being
# filled as nan
texts = []
class_probs = []
class_types = np.empty((len_ele,), dtype="object")

for i, element in enumerate(elements):
coords[i] = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
texts.append(element.text)
class_probs.append(element.prob)
class_types[i] = element.type or "None"

unique_ids, class_ids = np.unique(class_types, return_inverse=True)
unique_ids[unique_ids == "None"] = None

return cls(
element_coords=coords,
texts=np.array(texts),
element_probs=np.array(class_probs),
element_class_ids=class_ids,
element_class_id_map=dict(zip(range(len(unique_ids)), unique_ids)),
source=elements[0].source,
)


@dataclass
class LayoutElement(TextRegion):
Expand Down Expand Up @@ -315,7 +363,7 @@ def partition_groups_from_regions(regions: TextRegions) -> List[TextRegions]:
regions, each list corresponding with a group"""
if len(regions) == 0:
return []
padded_coords = regions.element_coords.copy()
padded_coords = regions.element_coords.copy().astype(float)
v_pad = (regions.y2 - regions.y1) * inference_config.ELEMENTS_V_PADDING_COEF
h_pad = (regions.x2 - regions.x1) * inference_config.ELEMENTS_H_PADDING_COEF
padded_coords[:, 0] -= h_pad
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/models/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def image_processing(
sorted_dets = dets[order]

return LayoutElements(
element_coords=sorted_dets[:, :4],
element_coords=sorted_dets[:, :4].astype(float),
element_probs=sorted_dets[:, 4].astype(float),
element_class_ids=sorted_dets[:, 5].astype(int),
element_class_id_map=self.layout_classes,
Expand Down

0 comments on commit c53866b

Please sign in to comment.