Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/layoutparser/elements/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from typing import List, Union, Dict, Dict, Any, Optional
from collections.abc import MutableSequence
from collections.abc import MutableSequence, Iterable
from copy import copy

import pandas as pd
Expand Down Expand Up @@ -47,6 +47,21 @@ class Layout(MutableSequence):
"""

def __init__(self, blocks: Optional[List] = None, *, page_data: Dict = None):

if not (
(blocks is None)
or (isinstance(blocks, Iterable) and blocks.__class__.__name__ != "Layout")
):

if blocks.__class__.__name__ == "Layout":
error_msg = f"Please check the input: it should be lp.Layout([layout]) instead of lp.Layout(layout)"
else:
error_msg = f"Blocks should be a list of layout elements or empty (None), instead got {blocks}.\n"
raise ValueError(error_msg)

if isinstance(blocks, tuple):
blocks = list(blocks) # <- more robust handling for tuple-like inputs

self._blocks = blocks if blocks is not None else []
self.page_data = page_data or {}

Expand Down Expand Up @@ -76,10 +91,7 @@ def __repr__(self):

def __eq__(self, other):
if isinstance(other, Layout):
return (
all((a, b) for a, b in zip(self, other))
and self.page_data == other.page_data
)
return self._blocks == other._blocks and self.page_data == other.page_data
else:
return False

Expand Down
11 changes: 7 additions & 4 deletions src/layoutparser/io/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def load_dict(data: Union[Dict, List[Dict]]) -> Union[BaseLayoutElement, Layout]
if isinstance(data, dict):
if "page_data" in data:
# It is a layout instance
return Layout(load_dict(data["blocks"]), page_data=data["page_data"])
return Layout(load_dict(data["blocks"])._blocks, page_data=data["page_data"])
else:

if data["block_type"] not in BASECOORD_ELEMENT_NAMEMAP:
Expand Down Expand Up @@ -140,7 +140,10 @@ def load_dataframe(df: pd.DataFrame, block_type: str = None) -> Layout:
else:
df["block_type"] = block_type

if "id" not in df.columns:
df["id"] = df.index

print((df.columns), TextBlock._features, any(col in TextBlock._features for col in df.columns))
if any(col in TextBlock._features for col in df.columns):
# Automatically setting index for textblock
if "id" not in df.columns:
df["id"] = df.index

return load_dict(df.apply(lambda x: x.dropna().to_dict(), axis=1).to_list())
3 changes: 2 additions & 1 deletion src/layoutparser/io/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ def extract_words_for_page(
)

page_tokens = load_dataframe(
df.rename(
df.reset_index().rename(
columns={
"x0": "x_1",
"x1": "x_2",
"top": "y_1",
"bottom": "y_2",
"index": "id",
"fontname": "type", # also loading fontname as "type"
}
),
Expand Down
57 changes: 43 additions & 14 deletions tests/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@
import numpy as np
import pandas as pd

from layoutparser.elements import (
Interval,
Rectangle,
Quadrilateral,
TextBlock,
Layout
)
from layoutparser.elements import Interval, Rectangle, Quadrilateral, TextBlock, Layout
from layoutparser.elements.errors import InvalidShapeError, NotSupportedShapeError


Expand Down Expand Up @@ -242,12 +236,25 @@ def test_layout():
r = Rectangle(3, 3, 5, 6)
t = TextBlock(i, id=1, type=2, text="12")

# Test Initializations
l = Layout([i, q, r])
l = Layout((i,q))
Layout([l])
with pytest.raises(ValueError):
Layout(l)

# Test tuple-like inputs
l = Layout((i, q, r))
assert l._blocks == [i, q, r]
l.append(i)

# Test apply functions
l = Layout([i, q, r])
l.get_texts()
l.condition_on(i)
l.relative_to(q)
l.filter_by(t)
l.is_in(r)
assert l.filter_by(t) == Layout([i])
assert l.condition_on(i) == Layout([block.condition_on(i) for block in [i, q, r]])
assert l.relative_to(q) == Layout([block.relative_to(q) for block in [i, q, r]])
assert l.is_in(r) == Layout([block.is_in(r) for block in [i, q, r]])
assert l.get_homogeneous_blocks() == [i.to_quadrilateral(), q, r.to_quadrilateral()]

i2 = TextBlock(i, id=1, type=2, text="12")
Expand Down Expand Up @@ -286,17 +293,39 @@ def test_layout():
l + l2

# Test sort
## When sorting inplace, it should return None
l = Layout([i])
assert l.sort(key=lambda x: x.coordinates[1], inplace=True) is None

## Make sure only sorting inplace works
l = Layout([i, i.shift(2)])
l.sort(key=lambda x: x.coordinates[1], reverse=True)
assert l != Layout([i.shift(2), i])
l.sort(key=lambda x: x.coordinates[1], reverse=True, inplace=True)
assert l == Layout([i.shift(2), i])

l = Layout([q, r, i], page_data={"width": 200, "height": 400})
assert l.sort(key=lambda x: x.coordinates[0], inplace=False) == Layout(
assert l.sort(key=lambda x: x.coordinates[0]) == Layout(
[i, q, r], page_data={"width": 200, "height": 400}
)

l = Layout([q, t])
assert l.sort(key=lambda x: x.coordinates[0], inplace=False) == Layout([q, t])
assert l.sort(key=lambda x: x.coordinates[0]) == Layout([t, q])


def test_layout_comp():
a = Layout([Rectangle(1, 2, 3, 4)])
b = Layout([Rectangle(1, 2, 3, 4)])

assert a == b

a.append(Rectangle(1, 2, 3, 5))
assert a != b
b.append(Rectangle(1, 2, 3, 5))
assert a == b

a = Layout([TextBlock(Rectangle(1, 2, 3, 4))])
assert a != b


def test_shape_operations():
Expand Down Expand Up @@ -428,4 +457,4 @@ def test_dict():

l2 = Layout([i2, r2, q2])
l2_dict = {"page_data": {}, "blocks": [i_dict, r_dict, q_dict]}
assert l2.to_dict() == l2_dict
assert l2.to_dict() == l2_dict
2 changes: 1 addition & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_csv():
_l.page_data = {"width": 200, "height": 200}
assert _l == l

i2 = TextBlock(i, "")
i2 = i # <- Allow mixmode loading
r2 = TextBlock(r, id=24)
q2 = TextBlock(q, text="test", parent=45)
l2 = Layout([i2, r2, q2])
Expand Down