Skip to content

Commit

Permalink
fix: fix missing source after clean layoutelements
Browse files Browse the repository at this point in the history
  • Loading branch information
badGarnet committed Oct 18, 2024
1 parent c53866b commit cbc9ba0
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.7.42

* fix: fix missing source after cleaning layout elements

## 0.7.41

* fix: fix incorrect type casting with higher versions of `numpy` when substracting a `float` from an `int` array
Expand Down
2 changes: 2 additions & 0 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_layoutelements():
element_coords=coords,
element_class_ids=element_class_ids,
element_class_id_map=class_map,
source="yolox",
)


Expand Down Expand Up @@ -345,6 +346,7 @@ def test_clean_layoutelements(test_layoutelements):
elements[1].bbox.x2,
elements[1].bbox.x2,
) == (2, 2, 3, 3)
assert elements[0].source == elements[1].source == "yolox"


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.41" # pragma: no cover
__version__ = "0.7.42" # pragma: no cover
11 changes: 8 additions & 3 deletions unstructured_inference/inference/layoutelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ def slice(self, indices) -> LayoutElements:
@classmethod
def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
"""concatenate a sequence of LayoutElements in order as one LayoutElements"""
coords, texts, probs, class_ids = [], [], [], []
coords, texts, probs, class_ids, sources = [], [], [], [], []
class_id_map = {}
for group in groups:
coords.append(group.element_coords)
texts.append(group.texts)
probs.append(group.element_probs)
class_ids.append(group.element_class_ids)
if group.source:
sources.append(group.source)
if group.element_class_id_map:
class_id_map.update(group.element_class_id_map)
return cls(
Expand All @@ -89,7 +91,7 @@ def concatenate(cls, groups: Iterable[LayoutElements]) -> LayoutElements:
element_probs=np.concatenate(probs),
element_class_ids=np.concatenate(class_ids),
element_class_id_map=class_id_map,
source=group.source,
source=sources[0],
)

def as_list(self):
Expand Down Expand Up @@ -442,7 +444,10 @@ def clean_layoutelements(elements: LayoutElements, subregion_threshold: float =
final_coords = sorted_coords[mask]
sorted_by_y1 = np.argsort(final_coords[:, 1])

final_attrs: dict[str, Any] = {"element_class_id_map": elements.element_class_id_map}
final_attrs: dict[str, Any] = {
"element_class_id_map": elements.element_class_id_map,
"source": elements.source,
}
for attr in ("element_class_ids", "element_probs", "texts"):
if (original_attr := getattr(elements, attr)) is None:
continue
Expand Down

0 comments on commit cbc9ba0

Please sign in to comment.