Skip to content

Commit

Permalink
fix: h3 regionizer very slow with buffer=True and large number of regions (#194)
Browse files Browse the repository at this point in the history

* fix: replace drop_duplicates() with dropping duplicated indices in H3Regionizer

* chore: add h3 regionizer test case to cover unbuffered case

* fix(pre-commit.ci): auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Szymon Woźniak <szymon.wozniak@brand24.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 9, 2023
1 parent ef3672c commit 4da1270
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
10 changes: 5 additions & 5 deletions srai/regionizers/h3_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
gdf_h3 = self._gdf_from_h3_indexes(h3_indexes)

# there may be too many cells because of too big buffer
gdf_h3_clipped = (
gdf_h3.sjoin(gdf_exploded[["geometry"]]).drop(columns="index_right").drop_duplicates()
if self.buffer
else gdf_h3
)
if self.buffer:
gdf_h3_clipped = gdf_h3.sjoin(gdf_exploded[["geometry"]]).drop(columns="index_right")
gdf_h3_clipped = gdf_h3_clipped[~gdf_h3_clipped.index.duplicated(keep="first")]
else:
gdf_h3_clipped = gdf_h3

gdf_h3_clipped.index.name = REGIONS_INDEX

Expand Down
26 changes: 18 additions & 8 deletions tests/regionizers/test_h3_regionizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,39 @@ def expected_h3_indexes() -> List[str]:
]


@pytest.fixture  # type: ignore
def expected_unbuffered_h3_indexes() -> List[str]:
    """Return the H3 indexes expected when the regionizer runs without a buffer."""
    unbuffered_indexes = ["83754efffffffff"]
    return unbuffered_indexes


@pytest.mark.parametrize( # type: ignore
"gdf_fixture,expected_h3_indexes_fixture,resolution,expectation",
"gdf_fixture,expected_h3_indexes_fixture,resolution,buffer,expectation",
[
("gdf_polygons", "expected_h3_indexes", H3_RESOLUTION, does_not_raise()),
("gdf_multipolygon", "expected_h3_indexes", H3_RESOLUTION, does_not_raise()),
("gdf_empty", "expected_h3_indexes", H3_RESOLUTION, pytest.raises(AttributeError)),
("gdf_polygons", "expected_h3_indexes", -1, pytest.raises(ValueError)),
("gdf_polygons", "expected_h3_indexes", 16, pytest.raises(ValueError)),
("gdf_no_crs", "expected_h3_indexes", H3_RESOLUTION, pytest.raises(ValueError)),
("gdf_polygons", "expected_h3_indexes", H3_RESOLUTION, True, does_not_raise()),
("gdf_polygons", "expected_unbuffered_h3_indexes", H3_RESOLUTION, False, does_not_raise()),
("gdf_multipolygon", "expected_h3_indexes", H3_RESOLUTION, True, does_not_raise()),
("gdf_empty", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(AttributeError)),
("gdf_polygons", "expected_h3_indexes", -1, True, pytest.raises(ValueError)),
("gdf_polygons", "expected_h3_indexes", 16, True, pytest.raises(ValueError)),
("gdf_no_crs", "expected_h3_indexes", H3_RESOLUTION, True, pytest.raises(ValueError)),
],
)
def test_transform(
gdf_fixture: str,
expected_h3_indexes_fixture: str,
resolution: int,
buffer: bool,
expectation: Any,
request: Any,
) -> None:
"""Test transform of H3Regionizer."""
gdf: gpd.GeoDataFrame = request.getfixturevalue(gdf_fixture)
h3_indexes: List[str] = request.getfixturevalue(expected_h3_indexes_fixture)
with expectation:
gdf_h3 = H3Regionizer(resolution).transform(gdf)
gdf_h3 = H3Regionizer(resolution, buffer=buffer).transform(gdf)

ut.assertCountEqual(first=gdf_h3.index.to_list(), second=h3_indexes)
assert "geometry" in gdf_h3

0 comments on commit 4da1270

Please sign in to comment.