Skip to content

Commit

Permalink
chore: merge branch '200-add-basic-plotting-functions-to-the-library'…
Browse files Browse the repository at this point in the history
… of https://github.com/srai-lab/srai into 200-add-basic-plotting-functions-to-the-library
  • Loading branch information
RaczeQ committed Mar 29, 2023
2 parents 1d32346 + be312ab commit a2a1eea
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- H3Neighbourhood
- AdjacencyNeighbourhood
- (CI) Changelog Enforcer
- Utility plotting module base on Folium and Plotly
- Utility plotting module based on Folium and Plotly

### Changed

Expand Down
1 change: 1 addition & 0 deletions examples/embedders/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ Examples illustrating the usage of every Joiner.

- [CountEmbedder](count_embedder.ipynb)
- [GTFS2VecEmbedder](gtfs2vec_embedder.ipynb)
- [Highway2VecEmbedder](highway2vec_embedder.ipynb)
25 changes: 18 additions & 7 deletions srai/plotting/folium_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def plot_regions(
Args:
regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot.
tiles_style (str, optional): Map style background. Defaults to "OpenStreetMap".
tiles_style (str, optional): Map style background. For more styles, look at tiles param at
https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html.
Defaults to "OpenStreetMap".
height (Union[str, float], optional): Height of the plot. Defaults to "100%".
width (Union[str, float], optional): Width of the plot. Defaults to "100%".
colormap (Union[str, List[str]], optional): Colormap to apply to the regions.
Expand All @@ -49,8 +51,7 @@ def plot_regions(
Returns:
folium.Map: Generated map.
"""
regions_gdf_copy = regions_gdf.copy()
return regions_gdf_copy.reset_index().explore(
return regions_gdf.reset_index().explore(
column=REGIONS_INDEX,
tooltip=REGIONS_INDEX,
tiles=tiles_style,
Expand Down Expand Up @@ -82,7 +83,9 @@ def plot_numeric_data(
embedding_df (Union[pd.DataFrame, gpd.GeoDataFrame]): Region indexes and numerical data
to plot.
data_column (str): Name of the column used to colour the regions.
tiles_style (str, optional): Map style background. Defaults to "OpenStreetMap".
tiles_style (str, optional): Map style background. For more styles, look at tiles param at
https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html.
Defaults to "OpenStreetMap".
height (Union[str, float], optional): Height of the plot. Defaults to "100%".
width (Union[str, float], optional): Width of the plot. Defaults to "100%".
colormap (Union[str, List[str]], optional): Colormap to apply to the regions.
Expand Down Expand Up @@ -133,7 +136,9 @@ def plot_neighbours(
regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot.
region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted.
neighbours_ids (Set[IndexType]): List of neighbours to highlight.
tiles_style (str, optional): Map style background. Defaults to "OpenStreetMap".
tiles_style (str, optional): Map style background. For more styles, look at tiles param at
https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html.
Defaults to "OpenStreetMap".
height (Union[str, float], optional): Height of the plot. Defaults to "100%".
width (Union[str, float], optional): Width of the plot. Defaults to "100%".
map (folium.Map, optional): Existing map instance on which to draw the plot.
Expand Down Expand Up @@ -171,6 +176,7 @@ def plot_all_neighbourhood(
regions_gdf: gpd.GeoDataFrame,
region_id: IndexType,
neighbourhood: Neighbourhood[IndexType],
neighbourhood_max_distance: int = 100,
tiles_style: str = "OpenStreetMap",
height: Union[str, float] = "100%",
width: Union[str, float] = "100%",
Expand All @@ -185,7 +191,12 @@ def plot_all_neighbourhood(
region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted.
neighbourhood (Neighbourhood[IndexType]): `Neighbourhood` class required for finding
neighbours.
tiles_style (str, optional): Map style background. Defaults to "OpenStreetMap".
neighbourhood_max_distance (int, optional): Max distance for rendering neighbourhoods.
Neighbours farther away won't be coloured, and will be left as "other" regions.
Defaults to 100.
tiles_style (str, optional): Map style background. For more styles, look at tiles param at
https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html.
Defaults to "OpenStreetMap".
height (Union[str, float], optional): Height of the plot. Defaults to "100%".
width (Union[str, float], optional): Width of the plot. Defaults to "100%".
colormap (Union[str, List[str]], optional): Colormap to apply to the neighbourhoods.
Expand All @@ -207,7 +218,7 @@ def plot_all_neighbourhood(
neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection(
regions_gdf.index
)
while neighbours_ids:
while neighbours_ids and distance <= neighbourhood_max_distance:
regions_gdf_copy.loc[list(neighbours_ids), "region"] = distance
distance += 1
neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection(
Expand Down
6 changes: 5 additions & 1 deletion srai/plotting/plotly_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def plot_all_neighbourhood(
regions_gdf: gpd.GeoDataFrame,
region_id: IndexType,
neighbourhood: Neighbourhood[IndexType],
neighbourhood_max_distance: int = 100,
return_plot: bool = False,
mapbox_style: str = "open-street-map",
mapbox_accesstoken: Optional[str] = None,
Expand All @@ -165,6 +166,9 @@ def plot_all_neighbourhood(
region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted.
neighbourhood (Neighbourhood[IndexType]): `Neighbourhood` class required for finding
neighbours.
neighbourhood_max_distance (int, optional): Max distance for rendering neighbourhoods.
Neighbours farther away won't be coloured, and will be left as "other" regions.
Defaults to 100.
return_plot (bool, optional): Flag whether to return the Figure object or not.
If `True`, the plot won't be displayed automatically. Defaults to False.
mapbox_style (str, optional): Map style background. Defaults to "open-street-map".
Expand All @@ -190,7 +194,7 @@ def plot_all_neighbourhood(
neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection(
regions_gdf.index
)
while neighbours_ids:
while neighbours_ids and distance <= neighbourhood_max_distance:
regions_gdf_copy.loc[list(neighbours_ids), "region"] = distance
distance += 1
neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection(
Expand Down

0 comments on commit a2a1eea

Please sign in to comment.