Skip to content

Commit

Permalink
fix: flatten index resulting from OSMTagLoader.load() (#183)
Browse files Browse the repository at this point in the history
* fix: flatten index resulting from OSMTagLoader.load()

* chore: replace literal feature_id with constant
  • Loading branch information
simonusher authored Feb 22, 2023
1 parent 37e9e96 commit 39b75e3
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
13 changes: 11 additions & 2 deletions srai/loaders/osm_tag_loader/osm_tag_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functional import seq
from tqdm import tqdm

from srai.utils.constants import WGS84_CRS
from srai.utils.constants import FEATURES_INDEX, WGS84_CRS


class OSMTagLoader:
Expand Down Expand Up @@ -85,7 +85,7 @@ def load(

result_gdf = self._group_gdfs(results).set_crs(WGS84_CRS)

return result_gdf
return self._flatten_index(result_gdf)

def _flatten_tags(
self, tags: Dict[str, Union[List[str], str, bool]]
Expand Down Expand Up @@ -118,3 +118,12 @@ def _group_gdfs(self, gdfs: List[gpd.GeoDataFrame]) -> gpd.GeoDataFrame:
def _get_empty_result(self) -> gpd.GeoDataFrame:
    """
    Build the GeoDataFrame returned when there is nothing to report.

    The frame has no rows, an empty geometry column, the WGS84 CRS and an
    empty two-level MultiIndex whose level names come from
    ``self._RESULT_INDEX_NAMES``.

    Returns:
        gpd.GeoDataFrame: An empty, correctly indexed GeoDataFrame.
    """
    # One empty array per index level; two levels are hard-coded here.
    empty_levels = [[], []]
    empty_index = pd.MultiIndex.from_arrays(
        arrays=empty_levels,
        names=self._RESULT_INDEX_NAMES,
    )
    return gpd.GeoDataFrame(geometry=[], index=empty_index, crs=WGS84_CRS)

def _flatten_index(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    """
    Collapse the index levels of ``gdf`` into a single string index.

    The index is first moved into regular columns. For every row, the values
    of the columns listed in ``self._RESULT_INDEX_NAMES`` are converted to
    strings and joined with ``"/"``. The joined value is stored in a column
    named ``FEATURES_INDEX``, which then becomes the index, and the original
    level columns are dropped.

    Args:
        gdf (gpd.GeoDataFrame): Frame whose index levels are named as in
            ``self._RESULT_INDEX_NAMES``.

    Returns:
        gpd.GeoDataFrame: The same rows, indexed by the joined string id.
    """
    level_columns = self._RESULT_INDEX_NAMES
    flat_gdf = gdf.reset_index()

    def _join_levels(row) -> str:
        # Stringify each level value so non-string ids can be joined.
        return "/".join(str(level) for level in row)

    joined_ids = flat_gdf[level_columns].apply(_join_levels, axis=1)
    # Force a string dtype even when the frame (and so the result) is empty.
    flat_gdf[FEATURES_INDEX] = joined_ids.astype(str)

    flat_gdf = flat_gdf.set_index(FEATURES_INDEX)
    return flat_gdf.drop(columns=level_columns)
35 changes: 19 additions & 16 deletions tests/loaders/osm_tag_loader/test_osm_tag_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def area_with_no_objects_gdf() -> gpd.GeoDataFrame:
@pytest.fixture # type: ignore
def empty_result_gdf() -> gpd.GeoDataFrame:
"""Get empty OSMTagLoader result gdf."""
result_index = pd.MultiIndex.from_arrays(arrays=[[], []], names=["element_type", "osmid"])
result_index = pd.Index(data=[], name="feature_id", dtype="object")
return gpd.GeoDataFrame(index=result_index, crs=WGS84_CRS, geometry=[])


Expand Down Expand Up @@ -120,12 +120,12 @@ def expected_result_single_polygon() -> gpd.GeoDataFrame:
"amenity": ["restaurant"],
},
geometry=[Point(0, 0)],
index=pd.MultiIndex.from_arrays(
arrays=[
["node"],
[1],
index=pd.Index(
data=[
"node/1",
],
names=["element_type", "osmid"],
name="feature_id",
dtype="object",
),
crs=WGS84_CRS,
)
Expand All @@ -139,12 +139,13 @@ def expected_result_gdf_simple() -> gpd.GeoDataFrame:
"amenity": ["restaurant", "restaurant"],
},
geometry=[Point(0, 0), Point(1, 1)],
index=pd.MultiIndex.from_arrays(
arrays=[
["node", "node"],
[1, 2],
index=pd.Index(
data=[
"node/1",
"node/2",
],
names=["element_type", "osmid"],
name="feature_id",
dtype="object",
),
crs=WGS84_CRS,
)
Expand All @@ -159,12 +160,14 @@ def expected_result_gdf_complex() -> gpd.GeoDataFrame:
"building": ["commercial", None, "retail"],
},
geometry=[Point(0, 0), Point(1, 1), Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])],
index=pd.MultiIndex.from_arrays(
arrays=[
["node", "node", "way"],
[1, 2, 3],
index=pd.Index(
data=[
"node/1",
"node/2",
"way/3",
],
names=["element_type", "osmid"],
name="feature_id",
dtype="object",
),
crs=WGS84_CRS,
)
Expand Down

0 comments on commit 39b75e3

Please sign in to comment.