diff --git a/.github/workflows/ci-dev.yml b/.github/workflows/ci-dev.yml index c9e48f83..566bbb8d 100644 --- a/.github/workflows/ci-dev.yml +++ b/.github/workflows/ci-dev.yml @@ -57,7 +57,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: "3.10" - name: Install required osmium dependencies run: sudo apt-get install build-essential cmake libboost-dev libexpat1-dev zlib1g-dev libbz2-dev - uses: actions/cache@v3 @@ -72,8 +72,6 @@ jobs: run: pdm export --no-default -G docs -G visualization -f requirements -o requirements.txt - name: Install dependencies run: pip install --no-deps -r requirements.txt - - name: Install geovoronoi dependency - run: pip install geovoronoi==0.4.0 - name: Install nbconvert dependency run: pip install jupyter nbconvert nbformat - name: Install srai diff --git a/.github/workflows/ci-prod.yml b/.github/workflows/ci-prod.yml index e0640ed7..0b369672 100644 --- a/.github/workflows/ci-prod.yml +++ b/.github/workflows/ci-prod.yml @@ -43,7 +43,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: "3.10" - name: Install required osmium dependencies run: sudo apt-get install build-essential cmake libboost-dev libexpat1-dev zlib1g-dev libbz2-dev - uses: actions/cache@v3 @@ -58,8 +58,6 @@ jobs: run: pdm export --no-default -G docs -G visualization -f requirements -o requirements.txt - name: Install dependencies run: pip install --no-deps -r requirements.txt - - name: Install geovoronoi dependency - run: pip install geovoronoi==0.4.0 - name: Install nbconvert dependency run: pip install jupyter nbconvert nbformat - name: Install srai diff --git a/CHANGELOG.md b/CHANGELOG.md index 8243e99a..5086cfaa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - H3Neighbourhood - AdjacencyNeighbourhood - (CI) Changelog Enforcer +- Utility plotting module based on Folium and Plotly ### Changed diff --git a/codecov.yml b/codecov.yml index 3725b8cc..60b71346 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,10 +4,13 @@ coverage: default: target: auto informational: true - threshold: 1% # the leniency in hitting the target + threshold: 1% # the leniency in hitting the target patch: default: informational: true comment: require_changes: true + +ignore: + - "srai/plotting" # disable coverage for plotting module diff --git a/examples/embedders/count_embedder.ipynb b/examples/embedders/count_embedder.ipynb index 39f05420..8a5d00c7 100644 --- a/examples/embedders/count_embedder.ipynb +++ b/examples/embedders/count_embedder.ipynb @@ -8,11 +8,12 @@ "source": [ "from shapely import geometry\n", "import geopandas as gpd\n", - "from srai.constants import WGS84_CRS\n", + "from srai.constants import WGS84_CRS, REGIONS_INDEX\n", "from srai.loaders.osm_loaders import OSMOnlineLoader\n", "from srai.regionizers import H3Regionizer\n", "from srai.joiners import IntersectionJoiner\n", - "from srai.embedders import CountEmbedder" + "from srai.embedders import CountEmbedder\n", + "from srai.plotting.folium_wrapper import plot_regions, plot_numeric_data" ] }, { @@ -75,8 +76,8 @@ "source": [ "regionizer = H3Regionizer(resolution=8, buffer=True)\n", "regions_gdf = regionizer.transform(bbox_gdf)\n", - "ax = bbox_gdf.plot()\n", - "regions_gdf.plot(ax=ax, color=\"red\", alpha=0.5)" + "folium_map = bbox_gdf.explore(tiles=\"CartoDB positron\")\n", + "plot_regions(regions_gdf, map=folium_map)" ] }, { @@ -108,11 +109,8 @@ "metadata": {}, "outputs": [], "source": [ - "ax = regions_gdf.plot()\n", - "features_gdf.plot(\n", - " ax=ax,\n", - " color=\"red\",\n", - ")" + "folium_map = plot_regions(regions_gdf, tiles_style=\"CartoDB positron\", colormap=[\"lightgray\"])\n", + "features_gdf.explore(m=folium_map)" ] }, { @@ -140,7 +138,10 @@ "metadata": {}, "outputs": [], "source": [ - "joint_gdf.plot()" + "from plotly.express import colors\n", + "\n", + "folium_map = plot_regions(regions_gdf, tiles_style=\"CartoDB positron\", colormap=[\"rgba(0,0,0,0)\"])\n", + "joint_gdf.reset_index().explore(m=folium_map, column=REGIONS_INDEX, cmap=colors.qualitative.Bold)" ] }, { @@ -181,6 +182,17 @@ "embedding_expected_features" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_numeric_data(\n", + " regions_gdf, embedding_expected_features, \"amenity_pub\", tiles_style=\"CartoDB positron\"\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -208,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/embedders/gtfs2vec_embedder.ipynb b/examples/embedders/gtfs2vec_embedder.ipynb index eb9fb865..458932e8 100644 --- a/examples/embedders/gtfs2vec_embedder.ipynb +++ b/examples/embedders/gtfs2vec_embedder.ipynb @@ -6,11 +6,13 @@ "metadata": {}, "outputs": [], "source": [ - "from srai.embedders import GTFS2VecEmbedder\n", "import pandas as pd\n", - "from shapely.geometry import Polygon\n", "import geopandas as gpd\n", - "from pytorch_lightning import seed_everything" + "from shapely.geometry import Polygon\n", + "from pytorch_lightning import seed_everything\n", + "\n", + "from srai.embedders import GTFS2VecEmbedder\n", + "from srai.constants import REGIONS_INDEX" ] }, { @@ -47,9 +49,8 @@ " ],\n", " },\n", " geometry=gpd.points_from_xy([1, 2, 5], [1, 2, 2]),\n", - " index=[1, 2, 3],\n", + " index=pd.Index(name=\"stop_id\", data=[1, 2, 3]),\n", ")\n", - "features_gdf.index.name = \"stop_id\"\n", "features_gdf" ] }, @@ -60,18 +61,23 @@ "outputs": [], "source": [ "regions_gdf = gpd.GeoDataFrame(\n", - " {\n", - " \"region_id\": [\"ff1\", \"ff2\", \"ff3\"],\n", - " },\n", " geometry=[\n", " Polygon([(0, 0), (0, 3), (3, 3), (3, 0)]),\n", " Polygon([(4, 0), (4, 3), (7, 3), (7, 0)]),\n", " Polygon([(8, 0), (8, 3), (11, 3), (11, 0)]),\n", " ],\n", - ").set_index(\"region_id\")\n", + " index=pd.Index(name=REGIONS_INDEX, data=[\"ff1\", \"ff2\", \"ff3\"]),\n", + ")\n", "regions_gdf" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -91,7 +97,7 @@ "joint_gdf = gpd.GeoDataFrame()\n", "joint_gdf.index = pd.MultiIndex.from_tuples(\n", " [(\"ff1\", 1), (\"ff1\", 2), (\"ff2\", 3)],\n", - " names=[\"region_id\", \"stop_id\"],\n", + " names=[REGIONS_INDEX, \"stop_id\"],\n", ")\n", "joint_gdf" ] @@ -153,7 +159,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.14" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/joiners/intersection_joiner.ipynb b/examples/joiners/intersection_joiner.ipynb index fa250a53..0cd9f341 100644 --- a/examples/joiners/intersection_joiner.ipynb +++ b/examples/joiners/intersection_joiner.ipynb @@ -8,7 +8,9 @@ "source": [ "import geopandas as gpd\n", "from shapely import geometry\n", - "import matplotlib.pyplot as plt" + "import matplotlib.pyplot as plt\n", + "from srai.constants import WGS84_CRS, REGIONS_INDEX, FEATURES_INDEX\n", + "from srai.plotting.folium_wrapper import plot_regions" ] }, { @@ -32,7 +34,8 @@ " geometry.Polygon([(-2, -1), (-2, -2), (-1, -2), (-1, -1)]),\n", " geometry.Polygon([(-2, 0.5), (-2, -0.5), (-1, -0.5), (-1, 0.5)]),\n", " ],\n", - " crs=\"epsg:4326\",\n", + " crs=WGS84_CRS,\n", + " index=gpd.pd.Index(name=REGIONS_INDEX, data=[1, 2, 3, 4]),\n", ")\n", "\n", "features = gpd.GeoDataFrame(\n", @@ -42,15 +45,41 @@ " geometry.Point((0, 0)),\n", " geometry.Point((-0.5, -0.5)),\n", " ],\n", - " crs=\"epsg:4326\",\n", - ")\n", - "\n", - "print(regions)\n", - "print(features)\n", - "\n", - "ax = regions.plot()\n", - "features.plot(ax=ax, color=\"red\", alpha=0.5)\n", - "plt.show()" + " crs=WGS84_CRS,\n", + " index=gpd.pd.Index(name=FEATURES_INDEX, data=[1, 2, 3, 4]),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "regions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "features" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "folium_map = plot_regions(regions, colormap=[\"royalblue\"])\n", + "features.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"red\", opacity=0.8, fillColor=\"red\", fillOpacity=0.5),\n", + " marker_kwds=dict(radius=3),\n", + ")" ] }, { @@ -72,13 +101,26 @@ "joiner = IntersectionJoiner()\n", "joint = joiner.transform(regions, features)\n", "\n", - "print(joint)\n", - "\n", - "ax = regions.plot(alpha=0.3)\n", - "ax = features.plot(ax=ax, color=\"red\", alpha=0.3)\n", - "joint.plot(ax=ax, color=\"green\", alpha=0.5)\n", - "\n", - "plt.show()" + "joint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "folium_map = plot_regions(regions, colormap=[\"royalblue\"])\n", + "features.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"red\", opacity=0.5, fillColor=\"red\", fillOpacity=0.5),\n", + " marker_kwds=dict(radius=3),\n", + ")\n", + "joint.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"yellow\", opacity=1.0, fillColor=\"yellow\", fillOpacity=1.0),\n", + " marker_kwds=dict(radius=3),\n", + ")" ] } ], @@ -98,7 +140,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.14" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/loaders/geoparquet_loader.ipynb b/examples/loaders/geoparquet_loader.ipynb index 415002e9..e0f0a3cd 100644 --- a/examples/loaders/geoparquet_loader.ipynb +++ b/examples/loaders/geoparquet_loader.ipynb @@ -6,12 +6,11 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", - "\n", "import geopandas as gpd\n", "from shapely.geometry import box\n", "\n", - "from srai.loaders import GeoparquetLoader" + "from srai.loaders import GeoparquetLoader\n", + "from srai.constants import WGS84_CRS" ] }, { @@ -46,7 +45,7 @@ "metadata": {}, "outputs": [], "source": [ - "base_gdf.plot()" + "base_gdf.explore()" ] }, { @@ -79,7 +78,7 @@ "source": [ "# Create Texas bounding box\n", "bbox = box(minx=-106.645646, maxx=-93.508292, miny=25.837377, maxy=36.500704)\n", - "bbox_gdf = gpd.GeoDataFrame({\"geometry\": [bbox]}, crs=\"EPSG:4326\")\n", + "bbox_gdf = gpd.GeoDataFrame({\"geometry\": [bbox]}, crs=WGS84_CRS)\n", "\n", "cut_gdf = gpql.load(file_path=\"example_files/example.parquet\", area=bbox_gdf)\n", "cut_gdf" @@ -91,7 +90,7 @@ "metadata": {}, "outputs": [], "source": [ - "cut_gdf.plot()" + "cut_gdf.explore()" ] } ], @@ -111,7 +110,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.14" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/loaders/gtfs_loader.ipynb b/examples/loaders/gtfs_loader.ipynb index bb24235e..1d4cac08 100644 --- a/examples/loaders/gtfs_loader.ipynb +++ b/examples/loaders/gtfs_loader.ipynb @@ -18,7 +18,6 @@ "from srai.loaders import GTFSLoader\n", "import gtfs_kit as gk\n", "import geopandas as gpd\n", - "import numpy as np\n", "from shapely.geometry import Point\n", "from srai.constants import WGS84_CRS\n", "from srai.utils import download_file" @@ -74,7 +73,7 @@ " crs=WGS84_CRS,\n", ")\n", "\n", - "stops_gdf.plot(markersize=1)" + "stops_gdf.explore(tiles=\"CartoDB positron\")" ] }, { @@ -96,6 +95,15 @@ "\n", "print(trips_gdf.columns)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trips_gdf" + ] } ], "metadata": { @@ -114,7 +122,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.14" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/loaders/osm_online_loader.ipynb b/examples/loaders/osm_online_loader.ipynb index 5eb9f178..6a608ee1 100644 --- a/examples/loaders/osm_online_loader.ipynb +++ b/examples/loaders/osm_online_loader.ipynb @@ -18,6 +18,7 @@ "from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER\n", "from srai.loaders.osm_loaders import OSMOnlineLoader\n", "from srai.utils import geocode_to_region_gdf\n", + "from srai.plotting.folium_wrapper import plot_regions\n", "from functional import seq" ] }, @@ -125,8 +126,8 @@ "metadata": {}, "outputs": [], "source": [ - "ax = wroclaw_gdf.plot(color=\"blue\", alpha=0.3, figsize=(8, 8))\n", - "parks_gdf.plot(ax=ax, color=\"green\")" + "folium_map = plot_regions(wroclaw_gdf, colormap=[\"lightgray\"], tiles_style=\"CartoDB positron\")\n", + "parks_gdf.explore(m=folium_map, color=\"forestgreen\")" ] }, { @@ -163,8 +164,12 @@ "metadata": {}, "outputs": [], "source": [ - "ax = barcelona_gdf.plot(color=\"green\", alpha=0.3, figsize=(7, 7))\n", - "barcelona_objects_gdf.query(\"amenity.notna()\").plot(ax=ax, color=\"red\", markersize=1)" + "folium_map = plot_regions(barcelona_gdf, colormap=[\"lightgray\"], tiles_style=\"CartoDB positron\")\n", + "barcelona_objects_gdf.query(\"amenity.notna()\").explore(\n", + " m=folium_map,\n", + " color=\"orangered\",\n", + " marker_kwds=dict(radius=1),\n", + ")" ] } ], diff --git a/examples/loaders/osm_way_loader.ipynb b/examples/loaders/osm_way_loader.ipynb index 573e831d..309ed209 100644 --- a/examples/loaders/osm_way_loader.ipynb +++ b/examples/loaders/osm_way_loader.ipynb @@ -15,13 +15,12 @@ "outputs": [], "source": [ "import geopandas as gpd\n", - "import osmnx as ox\n", - "import pandas as pd\n", - "from keplergl import KeplerGl\n", "import shapely.geometry as shpg\n", "\n", "from srai.loaders.osm_way_loader import NetworkType, OSMWayLoader\n", - "from srai.constants import WGS84_CRS" + "from srai.constants import WGS84_CRS, REGIONS_INDEX\n", + "from srai.plotting.folium_wrapper import plot_regions\n", + "from srai.utils import geocode_to_region_gdf" ] }, { @@ -55,8 +54,12 @@ " (17.0994473, 51.1083722),\n", " ]\n", ")\n", - "gdf_place = gpd.GeoDataFrame({\"geometry\": [polygon1, polygon2]}, crs=WGS84_CRS)\n", - "gdf_place.plot()" + "gdf_place = gpd.GeoDataFrame(\n", + " {\"geometry\": [polygon1, polygon2]},\n", + " crs=WGS84_CRS,\n", + " index=gpd.pd.Index(name=REGIONS_INDEX, data=[1, 2]),\n", + ")\n", + "plot_regions(gdf_place)" ] }, { @@ -67,8 +70,10 @@ "source": [ "osmwl = OSMWayLoader(NetworkType.BIKE, metadata=True)\n", "gdf_nodes, gdf_edges = osmwl.load(gdf_place)\n", - "ax = gdf_edges.plot(linewidth=1, figsize=(12, 7))\n", - "gdf_nodes.plot(ax=ax, markersize=3, color=\"red\")" + "\n", + "folium_map = plot_regions(gdf_place, colormap=[\"lightgray\"], tiles_style=\"CartoDB positron\")\n", + "gdf_edges.explore(m=folium_map)\n", + "gdf_nodes.explore(m=folium_map, color=\"orangered\")" ] }, { @@ -103,8 +108,8 @@ "metadata": {}, "outputs": [], "source": [ - "gdf_place = ox.geocode_to_gdf(\"Wroclaw, Poland\")\n", - "gdf_place.plot()" + "gdf_place = geocode_to_region_gdf(\"Wroclaw, Poland\")\n", + "plot_regions(gdf_place)" ] }, { @@ -114,9 +119,18 @@ "outputs": [], "source": [ "osmwl = OSMWayLoader(NetworkType.DRIVE)\n", - "gdf_nodes, gdf_edges = osmwl.load(gdf_place)\n", - "ax = gdf_edges.plot(linewidth=1, figsize=(12, 7))\n", - "gdf_nodes.plot(ax=ax, markersize=3, color=\"red\")" + "gdf_nodes, gdf_edges = osmwl.load(gdf_place)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "folium_map = plot_regions(gdf_place, colormap=[\"rgba(0,0,0,0)\"], tiles_style=\"CartoDB positron\")\n", + "gdf_edges.explore(m=folium_map)\n", + "gdf_nodes.explore(m=folium_map, color=\"orangered\")" ] }, { @@ -154,7 +168,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/neighbourhoods/README.md b/examples/neighbourhoods/README.md new file mode 100644 index 00000000..00253fee --- /dev/null +++ b/examples/neighbourhoods/README.md @@ -0,0 +1,6 @@ +# Neighbourhoods + +Examples illustrating the usage of every Neighbourhood. + +- [AdjacencyNeighbourhood](adjacency_neighbourhood.ipynb) +- [H3Neighbourhood](h3_neighbourhood.ipynb) diff --git a/examples/neighbourhoods/adjacency_neighbourhood.ipynb b/examples/neighbourhoods/adjacency_neighbourhood.ipynb new file mode 100644 index 00000000..5051f616 --- /dev/null +++ b/examples/neighbourhoods/adjacency_neighbourhood.ipynb @@ -0,0 +1,366 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from shapely.geometry import Point\n", + "import geopandas as gpd\n", + "\n", + "from srai.constants import WGS84_CRS\n", + "from srai.neighbourhoods import AdjacencyNeighbourhood\n", + "from srai.regionizers import AdministrativeBoundaryRegionizer, VoronoiRegionizer\n", + "from srai.utils.geocode import geocode_to_region_gdf\n", + "from srai.plotting.folium_wrapper import plot_regions, plot_neighbours, plot_all_neighbourhood" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adjacency Neighbourhood\n", + "It can generate neighbourhoods for all geodataframes with touching geometries." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Real boundaries example - Italy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "it_gdf = geocode_to_region_gdf(query=[\"R365331\"], by_osmid=True)\n", + "plot_regions(it_gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "regionizer = AdministrativeBoundaryRegionizer(admin_level=4)\n", + "it_regions_gdf = regionizer.transform(it_gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_regions(it_regions_gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "neighbourhood = AdjacencyNeighbourhood(it_regions_gdf)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Nearest neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = \"Lazio\"\n", + "neighbours = neighbourhood.get_neighbours(region_id)\n", + "neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_neighbours(it_regions_gdf, region_id, neighbours)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Neighbours at a distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = \"Basilicata\"\n", + "neighbours = neighbourhood.get_neighbours_at_distance(region_id, 2)\n", + "neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_neighbours(it_regions_gdf, region_id, neighbours)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Regions without neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = \"Sardinia\"\n", + "neighbours = neighbourhood.get_neighbours(region_id)\n", + "neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_neighbours(it_regions_gdf, region_id, neighbours)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plotting all neighbourhood" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = \"Campania\"\n", + "plot_all_neighbourhood(it_regions_gdf, region_id, neighbourhood)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Voronoi example - Australia" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "au_gdf = geocode_to_region_gdf(query=[\"R80500\"], by_osmid=True)\n", + "plot_regions(au_gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from shapely.geometry import Point\n", + "import geopandas as gpd\n", + "from srai.constants import WGS84_CRS\n", + "\n", + "\n", + "def generate_random_points(shape, n_points=500):\n", + " minx, miny, maxx, maxy = shape.bounds\n", + " pts = []\n", + "\n", + " while len(pts) < 4:\n", + " randx = np.random.uniform(minx, maxx, n_points)\n", + " randy = np.random.uniform(miny, maxy, n_points)\n", + " coords = np.vstack((randx, randy)).T\n", + "\n", + " # use only the points inside the geographic area\n", + " pts = [p for p in list(map(Point, coords)) if p.within(shape)]\n", + "\n", + " del coords # not used any more\n", + "\n", + " return pts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pts = generate_random_points(au_gdf.geometry[0])\n", + "\n", + "au_seeds_gdf = gpd.GeoDataFrame(\n", + " {\"geometry\": pts},\n", + " index=list(range(len(pts))),\n", + " crs=WGS84_CRS,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vr = VoronoiRegionizer(seeds=au_seeds_gdf)\n", + "au_result_gdf = vr.transform(gdf=au_gdf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "folium_map = plot_regions(au_result_gdf, tiles_style=\"CartoDB positron\")\n", + "au_seeds_gdf.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"#444\", opacity=1, fillColor=\"#f2f2f2\", fillOpacity=1),\n", + " marker_kwds=dict(radius=3),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "neighbourhood = AdjacencyNeighbourhood(regions_gdf=au_result_gdf)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Nearest neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = 0\n", + "neighbours = neighbourhood.get_neighbours(region_id)\n", + "neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_neighbours(au_result_gdf, region_id, neighbours)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Neighbours at a distance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = 0\n", + "neighbours = neighbourhood.get_neighbours_at_distance(region_id, 3)\n", + "neighbours" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_neighbours(au_result_gdf, region_id, neighbours)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Plotting all neighbourhood" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "region_id = 0\n", + "plot_all_neighbourhood(au_result_gdf, region_id, neighbourhood)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/neighbourhoods/h3_neighbourhood.ipynb b/examples/neighbourhoods/h3_neighbourhood.ipynb index 274ca4ca..41e5b65b 100644 --- a/examples/neighbourhoods/h3_neighbourhood.ipynb +++ b/examples/neighbourhoods/h3_neighbourhood.ipynb @@ -8,28 +8,8 @@ "source": [ "from srai.neighbourhoods import H3Neighbourhood\n", "from srai.regionizers import H3Regionizer\n", - "from srai.utils import geocode_to_region_gdf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import geopandas as gpd\n", - "from typing import Set, Optional\n", - "\n", - "\n", - "def show_neighbours(\n", - " regions_gdf: gpd.GeoDataFrame, region_id: str, neighbours_ids: Optional[Set[str]] = None\n", - "):\n", - " regions = regions_gdf.copy()\n", - " regions[\"type\"] = \"none\"\n", - " regions.loc[region_id, \"type\"] = \"region\"\n", - " if neighbours_ids is not None:\n", - " regions.loc[list(neighbours_ids), \"type\"] = \"neighbour\"\n", - " return regions.explore(tiles=\"CartoDB positron\", column=\"type\")" + "from srai.utils.geocode import geocode_to_region_gdf\n", + "from srai.plotting.folium_wrapper import plot_neighbours, plot_all_neighbourhood" ] }, { @@ -67,7 +47,7 @@ "neighbourhood_with_regions = H3Neighbourhood(regions_gdf)\n", "region_id = \"881e204089fffff\"\n", "neighbours_ids = neighbourhood_with_regions.get_neighbours(region_id)\n", - "show_neighbours(regions_gdf, region_id, neighbours_ids)" + "plot_neighbours(regions_gdf, region_id, neighbours_ids)" ] }, { @@ -85,7 +65,7 @@ "outputs": [], "source": [ "neighbours_ids = neighbourhood_with_regions.get_neighbours_at_distance(region_id, 3)\n", - "show_neighbours(regions_gdf, region_id, neighbours_ids)" + "plot_neighbours(regions_gdf, region_id, neighbours_ids)" ] }, { @@ -103,7 +83,24 @@ "outputs": [], "source": [ "neighbours_ids = neighbourhood_with_regions.get_neighbours_up_to_distance(region_id, 3)\n", - "show_neighbours(regions_gdf, region_id, neighbours_ids)" + "plot_neighbours(regions_gdf, region_id, neighbours_ids)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Full neighbourhood" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_all_neighbourhood(regions_gdf, region_id, neighbourhood_with_regions)" ] }, { @@ -139,7 +136,7 @@ "metadata": {}, "outputs": [], "source": [ - "show_neighbours(regions_gdf, edge_region_id, neighbours_ids)" + "plot_neighbours(regions_gdf, edge_region_id, neighbours_ids)" ] }, { @@ -193,7 +190,16 @@ "metadata": {}, "outputs": [], "source": [ - "show_neighbours(regions_gdf, edge_region_id, neighbours_ids)" + "plot_neighbours(regions_gdf, edge_region_id, neighbours_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plot_all_neighbourhood(regions_gdf, edge_region_id, neighbourhood_with_regions)" ] } ], @@ -213,7 +219,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/examples/regionizers/administrative_boundary_regionizer.ipynb b/examples/regionizers/administrative_boundary_regionizer.ipynb index dc920304..03f8fffb 100644 --- a/examples/regionizers/administrative_boundary_regionizer.ipynb +++ b/examples/regionizers/administrative_boundary_regionizer.ipynb @@ -7,11 +7,12 @@ "outputs": [], "source": [ "import geopandas as gpd\n", - "from osmnx import geocode_to_gdf\n", "import plotly.express as px\n", "from shapely.geometry import Point, box\n", "\n", - "from srai.regionizers import AdministrativeBoundaryRegionizer" + "from srai.regionizers import AdministrativeBoundaryRegionizer\n", + "from srai.plotting.folium_wrapper import plot_regions\n", + "from srai.utils import geocode_to_region_gdf" ] }, { @@ -31,8 +32,8 @@ "metadata": {}, "outputs": [], "source": [ - "wroclaw_gdf = geocode_to_gdf(query=[\"R451516\"], by_osmid=True)\n", - "wroclaw_gdf.plot()" + "wroclaw_gdf = geocode_to_region_gdf(query=[\"R451516\"], by_osmid=True)\n", + "plot_regions(wroclaw_gdf)" ] }, { @@ -60,30 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.choropleth_mapbox(\n", - " wro_result_gdf,\n", - " geojson=wro_result_gdf,\n", - " color=wro_result_gdf.index,\n", - " locations=wro_result_gdf.index,\n", - " center={\"lat\": 51.125, \"lon\": 16.99},\n", - " mapbox_style=\"carto-positron\",\n", - " zoom=10.5,\n", - ")\n", - "fig.update_layout(margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choroplethmapbox\"))\n", - "fig.update_traces(showlegend=False)\n", - "minx, miny, maxx, maxy = wroclaw_gdf.geometry[0].bounds\n", - "fig.update_geos(\n", - " projection_type=\"equirectangular\",\n", - " lataxis_range=[miny - 0.1, maxy + 0.1],\n", - " lonaxis_range=[minx - 0.1, maxx + 0.1],\n", - " showlakes=False,\n", - " showcountries=False,\n", - " showframe=False,\n", - " resolution=50,\n", - ")\n", - "fig.update_layout(height=600, width=800, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "plot_regions(wro_result_gdf)" ] }, { @@ -103,9 +81,7 @@ "metadata": {}, "outputs": [], "source": [ - "madagascar_bbox = box(\n", - " minx=43.2541870461, miny=-25.6014344215, maxx=50.4765368996, maxy=-12.0405567359\n", - ")\n", + "madagascar_bbox = box(minx=43.21418, miny=-25.61143, maxx=50.48704, maxy=-11.951126)\n", "madagascar_bbox_gdf = gpd.GeoDataFrame({\"geometry\": [madagascar_bbox]}, crs=\"EPSG:4326\")" ] }, @@ -134,30 +110,7 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.choropleth(\n", - " madagascar_result_gdf,\n", - " geojson=madagascar_result_gdf.geometry,\n", - " locations=madagascar_result_gdf.index,\n", - " color=madagascar_result_gdf.index,\n", - " color_continuous_scale=px.colors.sequential.Viridis,\n", - ")\n", - "\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choropleth\"))\n", - "fig.update_layout(coloraxis_showscale=False)\n", - "fig.update_traces(showlegend=False)\n", - "minx, miny, maxx, maxy = madagascar_bbox.bounds\n", - "fig.update_geos(\n", - " projection_type=\"equirectangular\",\n", - " lataxis_range=[miny - 0.1, maxy + 0.1],\n", - " lonaxis_range=[minx - 0.1, maxx + 0.1],\n", - " visible=False,\n", - " showlakes=False,\n", - " showcountries=False,\n", - " showframe=False,\n", - " resolution=50,\n", - ")\n", - "fig.update_layout(height=800, width=420, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "plot_regions(madagascar_result_gdf)" ] }, { @@ -214,29 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "fig = px.choropleth(\n", - " eu_result_gdf,\n", - " geojson=eu_result_gdf.geometry,\n", - " locations=eu_result_gdf.index,\n", - " color=eu_result_gdf.index,\n", - " color_continuous_scale=px.colors.sequential.Viridis,\n", - ")\n", - "\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choropleth\"))\n", - "fig.update_layout(coloraxis_showscale=False)\n", - "fig.update_traces(showlegend=False)\n", - "minx, miny, maxx, maxy = eu_bbox.bounds\n", - "fig.update_geos(\n", - " projection_type=\"equirectangular\",\n", - " lataxis_range=[miny - 1, maxy + 1],\n", - " lonaxis_range=[minx - 1, maxx + 1],\n", - " showlakes=False,\n", - " showcountries=False,\n", - " showframe=False,\n", - " resolution=50,\n", - ")\n", - "fig.update_layout(height=800, width=1000, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "plot_regions(eu_result_gdf)" ] }, { @@ -368,24 +299,12 @@ "metadata": {}, "outputs": [], "source": [ - "# Paris\n", - "fig = px.choropleth_mapbox(\n", - " paris_districts_result,\n", - " geojson=paris_districts_result,\n", - " color=paris_districts_result.index,\n", - " locations=paris_districts_result.index,\n", - " center={\"lat\": 48.858, \"lon\": 2.353},\n", - " mapbox_style=\"carto-positron\",\n", - " zoom=10.8,\n", - ")\n", - "fig2 = px.scatter_mapbox(stations_gdf, lat=stations_gdf.geometry.y, lon=stations_gdf.geometry.x)\n", - "fig.add_trace(fig2.data[0])\n", - "fig.update_layout(margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.update_traces(showlegend=False)\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choroplethmapbox\"))\n", - "fig.update_traces(marker_color=\"black\", marker_size=5, selector=dict(type=\"scattermapbox\"))\n", - "fig.update_layout(height=800, width=800, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "folium_map = plot_regions(paris_districts_result, tiles_style=\"CartoDB positron\")\n", + "stations_gdf.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"#444\", opacity=1, fillColor=\"#f2f2f2\", fillOpacity=1),\n", + " marker_kwds=dict(radius=1),\n", + ")" ] } ], diff --git a/examples/regionizers/h3_regionizer.ipynb b/examples/regionizers/h3_regionizer.ipynb index d099b6f1..07116641 100644 --- a/examples/regionizers/h3_regionizer.ipynb +++ b/examples/regionizers/h3_regionizer.ipynb @@ -14,10 +14,11 @@ "outputs": [], "source": [ "import geopandas as gpd\n", - "import matplotlib.pyplot as plt\n", "from shapely import geometry\n", "\n", - "from srai.regionizers import H3Regionizer" + "from srai.regionizers import H3Regionizer\n", + "from srai.plotting.folium_wrapper import plot_regions\n", + "from srai.constants import WGS84_CRS" ] }, { @@ -54,9 +55,9 @@ " ),\n", " geometry.Polygon(shell=[(-0.25, 0), (0.25, 0), (0, 0.2)]),\n", " ],\n", - " crs=\"epsg:4326\",\n", + " crs=WGS84_CRS,\n", ")\n", - "gdf.plot()" + "gdf.explore()" ] }, { @@ -73,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "resolution = 4" + "resolution = 5" ] }, { @@ -99,9 +100,8 @@ "regionizer = H3Regionizer(resolution, buffer=False)\n", "gdf_h3 = regionizer.transform(gdf)\n", "\n", - "ax = gdf.plot()\n", - "gdf_h3.plot(ax=ax, color=\"red\", alpha=0.5)\n", - "plt.show()" + "folium_map = gdf.explore()\n", + "plot_regions(gdf_h3, colormap=[\"red\"], map=folium_map)" ] }, { @@ -127,9 +127,8 @@ "regionizer_buffered = H3Regionizer(resolution)\n", "gdf_h3_buffered = regionizer_buffered.transform(gdf)\n", "\n", - "ax = gdf.plot()\n", - "gdf_h3_buffered.plot(ax=ax, color=\"red\", alpha=0.5)\n", - "plt.show()" + "folium_map = gdf.explore()\n", + "plot_regions(gdf_h3_buffered, colormap=[\"red\"], map=folium_map)" ] } ], @@ -149,7 +148,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/regionizers/s2_regionizer.ipynb b/examples/regionizers/s2_regionizer.ipynb index 4150e682..9a5cb0af 100644 --- a/examples/regionizers/s2_regionizer.ipynb +++ b/examples/regionizers/s2_regionizer.ipynb @@ -14,10 +14,11 @@ "outputs": [], "source": [ "import geopandas as gpd\n", - "import matplotlib.pyplot as plt\n", "from shapely import geometry\n", "\n", - "from srai.regionizers import S2Regionizer" + "from srai.regionizers import S2Regionizer\n", + "from srai.plotting.folium_wrapper import plot_regions\n", + "from srai.constants import WGS84_CRS" ] }, { @@ -54,9 +55,9 @@ " ),\n", " geometry.Polygon(shell=[(-0.25, 0), (0.25, 0), (0, 0.2)]),\n", " ],\n", - " crs=\"epsg:4326\",\n", + " crs=WGS84_CRS,\n", ")\n", - "gdf.plot()" + "gdf.explore()" ] }, { @@ -73,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "resolution = 11" + "resolution = 10" ] }, { @@ -94,9 +95,8 @@ "regionizer = S2Regionizer(resolution, buffer=False)\n", "gdf_s2 = regionizer.transform(gdf)\n", "\n", - "ax = gdf.plot()\n", - "gdf_s2.plot(ax=ax, color=\"red\", alpha=0.5)\n", - "plt.show()" + "folium_map = gdf.explore()\n", + "plot_regions(gdf_s2, colormap=[\"red\"], map=folium_map)" ] }, { @@ -115,11 +115,10 @@ "outputs": [], "source": [ "regionizer = S2Regionizer(resolution, buffer=True)\n", - "gdf_s2 = regionizer.transform(gdf)\n", + "gdf_s2_buffered = regionizer.transform(gdf)\n", "\n", - "ax = gdf.plot()\n", - "gdf_s2.plot(ax=ax, color=\"red\", alpha=0.5)\n", - "plt.show()" + "folium_map = gdf.explore()\n", + "plot_regions(gdf_s2_buffered, colormap=[\"red\"], map=folium_map)" ] } ], @@ -139,7 +138,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.14" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/regionizers/voronoi_regionizer.ipynb b/examples/regionizers/voronoi_regionizer.ipynb index ffee83f8..18d288f6 100644 --- a/examples/regionizers/voronoi_regionizer.ipynb +++ b/examples/regionizers/voronoi_regionizer.ipynb @@ -8,12 +8,13 @@ "source": [ "import geopandas as gpd\n", "import numpy as np\n", - "from osmnx import geocode_to_gdf\n", "import plotly.express as px\n", - "from shapely.geometry import MultiPolygon, Point, Polygon\n", + "from shapely.geometry import Point\n", "\n", "from srai.regionizers import VoronoiRegionizer\n", - "from srai.constants import WGS84_CRS" + "from srai.constants import WGS84_CRS\n", + "from srai.plotting.folium_wrapper import plot_regions\n", + "from srai.utils import geocode_to_region_gdf" ] }, { @@ -84,6 +85,14 @@ "result_gdf" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Globe view" + ] + }, { "cell_type": "code", "execution_count": null, @@ -112,6 +121,28 @@ "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2D OSM View" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "folium_map = plot_regions(result_gdf)\n", + "seeds_gdf.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"#444\", opacity=1, fillColor=\"#f2f2f2\", fillOpacity=1),\n", + " marker_kwds=dict(radius=3),\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -127,10 +158,9 @@ "metadata": {}, "outputs": [], "source": [ - "uk_gdf = geocode_to_gdf(query=[\"R62149\"], by_osmid=True)\n", + "uk_gdf = geocode_to_region_gdf(query=[\"R62149\"], by_osmid=True)\n", "\n", - "uk_gdf = uk_gdf.to_crs(epsg=4326) # convert to wgs84\n", - "uk_gdf_shape = uk_gdf.iloc[0].geometry # get the Polygon" + "uk_shape = uk_gdf.iloc[0].geometry # get the Polygon" ] }, { @@ -171,7 +201,7 @@ "metadata": {}, "outputs": [], "source": [ - "pts = generate_random_points(uk_gdf_shape)\n", + "pts = generate_random_points(uk_shape)\n", "\n", "uk_seeds_gdf = gpd.GeoDataFrame(\n", " {\"geometry\": pts},\n", @@ -180,13 +210,26 @@ ")" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Random points on a map" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "uk_seeds_gdf.plot()" + "folium_map = plot_regions(uk_gdf, tiles_style=\"CartoDB positron\")\n", + "uk_seeds_gdf.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"#444\", opacity=1, fillColor=\"#f2f2f2\", fillOpacity=1),\n", + " marker_kwds=dict(radius=3),\n", + ")" ] }, { @@ -216,35 +259,26 @@ "uk_result_gdf.head()" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Generated regions on a map" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "fig = px.choropleth(\n", - " uk_result_gdf,\n", - " geojson=uk_result_gdf.geometry,\n", - " locations=uk_result_gdf.index,\n", - " color=uk_result_gdf.index,\n", - " color_continuous_scale=px.colors.qualitative.Plotly,\n", - ")\n", - "fig2 = px.scatter_geo(uk_seeds_gdf, lat=uk_seeds_gdf.geometry.y, lon=uk_seeds_gdf.geometry.x)\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choropleth\"))\n", - "fig.add_trace(fig2.data[0])\n", - "fig.update_traces(marker_color=\"black\", marker_size=6, selector=dict(type=\"scattergeo\"))\n", - "fig.update_layout(coloraxis_showscale=False)\n", - "minx, miny, maxx, maxy = uk_gdf_shape.bounds\n", - "fig.update_geos(\n", - " projection_type=\"natural earth\",\n", - " lataxis_range=[miny - 1, maxy + 1],\n", - " lonaxis_range=[minx - 1, maxx + 1],\n", - " resolution=50,\n", - " showframe=False,\n", - " showlakes=False,\n", - ")\n", - "fig.update_layout(height=800, width=675, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "folium_map = plot_regions(uk_result_gdf, tiles_style=\"CartoDB positron\")\n", + "uk_seeds_gdf.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"#444\", opacity=1, fillColor=\"#f2f2f2\", fillOpacity=1),\n", + " marker_kwds=dict(radius=3),\n", + ")" ] }, { @@ -314,37 +348,36 @@ "rail_result_gdf = vr_rail.transform()" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Germany view" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# Germany\n", - "fig = px.choropleth(\n", - " rail_result_gdf,\n", - " geojson=rail_result_gdf.geometry,\n", - " locations=rail_result_gdf.index,\n", - " color=rail_result_gdf.index,\n", - " color_continuous_scale=px.colors.sequential.Viridis,\n", + "folium_map = plot_regions(rail_result_gdf, tiles_style=\"CartoDB positron\")\n", + "stations_gdf.explore(\n", + " m=folium_map,\n", + " style_kwds=dict(color=\"#444\", opacity=1, fillColor=\"#f2f2f2\", fillOpacity=1),\n", + " marker_kwds=dict(radius=1),\n", ")\n", - "fig2 = px.scatter_geo(stations_gdf, lat=stations_gdf.geometry.y, lon=stations_gdf.geometry.x)\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choropleth\"))\n", - "fig.add_trace(fig2.data[0])\n", - "fig.update_traces(marker_color=\"white\", marker_size=2, selector=dict(type=\"scattergeo\"))\n", - "fig.update_layout(coloraxis_showscale=False)\n", - "fig.update_geos(\n", - " projection_type=\"orthographic\",\n", - " projection_rotation_lon=10,\n", - " projection_rotation_lat=51,\n", - " projection_scale=11,\n", - " resolution=50,\n", - " showlakes=False,\n", - " showcountries=True,\n", - " showframe=False,\n", - ")\n", - "fig.update_layout(height=800, width=800, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "folium_map.fit_bounds([(54.98310, 5.98865), (47.30248, 15.01699)])\n", + "folium_map" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Berlin view" ] }, { @@ -354,26 +387,12 @@ "outputs": [], "source": [ "# Berlin\n", - "fig = px.choropleth_mapbox(\n", - " rail_result_gdf,\n", - " geojson=rail_result_gdf,\n", - " color=rail_result_gdf.index,\n", - " locations=rail_result_gdf.index,\n", - " center={\"lat\": 52.51637, \"lon\": 13.40665},\n", - " mapbox_style=\"open-street-map\",\n", - " zoom=11,\n", - ")\n", - "fig2 = px.scatter_mapbox(stations_gdf, lat=stations_gdf.geometry.y, lon=stations_gdf.geometry.x)\n", - "fig.add_trace(fig2.data[0])\n", - "fig.update_layout(margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.update_layout(coloraxis_showscale=False)\n", - "fig.update_traces(marker={\"opacity\": 0.6}, selector=dict(type=\"choroplethmapbox\"))\n", - "fig.update_traces(marker_color=\"white\", marker_size=5, selector=dict(type=\"scattermapbox\"))\n", - "fig.update_layout(height=800, width=800, margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n", - "fig.show(renderer=\"png\") # replace with fig.show() to allow interactivity" + "folium_map.fit_bounds([(52.51637 + 0.1, 13.40665 + 0.1), (52.51637 - 0.1, 13.40665 - 0.1)])\n", + "folium_map" ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -381,7 +400,11 @@ "\n", "Showing the difference between working on the sphere and projected 2D plane.\n", "\n", - "Uses `geovoronoi` package as an example." + "Uses `shapely.voronoi_polygons` function as an example.\n", + "\n", + "`shapely.voronoi_diagram` function allows for a quick division of the Earth using list of seeds on a projected 2d plane.\n", + "This results in straight lines with angles distorted and polygons differences\n", + "might be substantial during comparisons or any area calculations." ] }, { @@ -390,16 +413,7 @@ "metadata": {}, "outputs": [], "source": [ - "\"\"\"\n", - "Geovoronoi package allows for a quick division of the Earth using list of seeds on a projected 2d plane.\n", - "This results in straight lines with angles distorted and polygons differences\n", - "might be substantial during comparisons or any area calculations.\n", - "\"\"\"\n", - "# geovoronoi package isn't used in this library, but is optional and can be installed using\n", - "# pip install geovoronoi\n", - "from geovoronoi import voronoi_regions_from_coords\n", - "\n", - "from shapely.geometry.polygon import orient\n", + "from shapely.ops import voronoi_diagram\n", "from plotly.subplots import make_subplots" ] }, @@ -409,9 +423,8 @@ "metadata": {}, "outputs": [], "source": [ - "pl_gdf = geocode_to_gdf(query=[\"R49715\"], by_osmid=True)\n", + "pl_gdf = geocode_to_region_gdf(query=[\"R49715\"], by_osmid=True)\n", "\n", - "pl_gdf = pl_gdf.to_crs(epsg=4326) # convert to wgs84\n", "pl_gdf_shape = pl_gdf.iloc[0].geometry # get the Polygon" ] }, @@ -436,8 +449,8 @@ "metadata": {}, "outputs": [], "source": [ - "region_polys, region_pts, unassigned_pts = voronoi_regions_from_coords(\n", - " pts, pl_gdf_shape, return_unassigned_points=True, per_geom=False\n", + "region_polygons = list(\n", + " voronoi_diagram(pl_seeds_gdf.geometry.unary_union, envelope=pl_gdf_shape).normalize().geoms\n", ")" ] }, @@ -447,18 +460,11 @@ "metadata": {}, "outputs": [], "source": [ - "def orient_geom(geom):\n", - " if type(geom) == Polygon:\n", - " return orient(geom, sign=-1)\n", - " elif type(geom) == MultiPolygon:\n", - " return MultiPolygon([orient(g, sign=-1) for g in geom.geoms])\n", - "\n", - "\n", "pl_regions_2d_gdf = gpd.GeoDataFrame(\n", - " {\"geometry\": [orient_geom(geom) for geom in region_polys.values()]},\n", - " index=list(range(len(region_polys))),\n", + " {\"geometry\": [polygon for polygon in region_polygons]},\n", + " index=list(range(len(region_polygons))),\n", " crs=WGS84_CRS,\n", - ")" + ").clip(pl_gdf_shape)" ] }, { diff --git a/pdm.lock b/pdm.lock index d611dda4..2c9d9b37 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2220,7 +2220,7 @@ summary = "Backport of pathlib-compatible object wrapper for zip files" [metadata] lock_version = "4.1" -content_hash = "sha256:7b94cdee7c97e4eacbb931cdb1449e10c0b47589dccd92045ad81350a607378f" +content_hash = "sha256:962e57d64c978b93586dde39aafde37e2c322e5d5764b2af1be7b5d5fa101ef4" [metadata.files] "aiohttp 3.8.4" = [ diff --git a/pyproject.toml b/pyproject.toml index ecb7c54e..a4c969a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,9 @@ osm = ["osmium>=3.4.1", "osmnx>=1.2.2", "overpass>=0.7"] voronoi = ["pymap3d>=2.9.1", "haversine>=2.7.0", "spherical-geometry>=1.2.23"] # pdm add -G gtfs gtfs = ["gtfs-kit>=5.0.0"] -all = ["srai[osm,voronoi,gtfs]"] +# pdm add -G plotting +plotting = ["folium>=0.14.0", "mapclassify>=2.5.0", "plotly>=5.10.0", "kaleido==0.2.1"] +all = ["srai[osm,voronoi,gtfs,plotting]"] [build-system] requires = ["pdm-pep517>=1.0.0"] @@ -84,10 +86,6 @@ test = [ visualization = [ "keplergl>=0.3.2", "matplotlib>=3.6.1", - "plotly>=5.10.0", - "kaleido==0.2.1", - "folium>=0.14.0", - "mapclassify>=2.5.0", ] # pdm add -dG docs docs = [ diff --git a/srai/embedders/_base.py b/srai/embedders/_base.py index 6d149b76..1a2c596f 100644 --- a/srai/embedders/_base.py +++ b/srai/embedders/_base.py @@ -5,6 +5,8 @@ import geopandas as gpd import pandas as pd +from srai.constants import GEOMETRY_COLUMN + class Embedder(abc.ABC): """Abstract class for embedders.""" @@ -69,6 +71,6 @@ def _validate_indexes( ) def _remove_geometry_if_present(self, data: gpd.GeoDataFrame) -> pd.DataFrame: - if "geometry" in data.columns: - data = data.drop(columns="geometry") + if GEOMETRY_COLUMN in data.columns: + data = data.drop(columns=GEOMETRY_COLUMN) return pd.DataFrame(data) diff --git a/srai/loaders/geoparquet_loader.py b/srai/loaders/geoparquet_loader.py index 297c6117..0661a696 100644 --- a/srai/loaders/geoparquet_loader.py +++ b/srai/loaders/geoparquet_loader.py @@ -9,7 +9,7 @@ import geopandas as gpd -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, WGS84_CRS class GeoparquetLoader: @@ -48,8 +48,8 @@ def load( Returns: gpd.GeoDataFrame: Loaded geoparquet file as a GeoDataFrame. """ - if columns and "geometry" not in columns: - columns.append("geometry") + if columns and GEOMETRY_COLUMN not in columns: + columns.append(GEOMETRY_COLUMN) gdf = gpd.read_parquet(path=file_path, columns=columns) diff --git a/srai/loaders/gtfs_loader.py b/srai/loaders/gtfs_loader.py index e2c60bc2..1b669def 100644 --- a/srai/loaders/gtfs_loader.py +++ b/srai/loaders/gtfs_loader.py @@ -17,7 +17,7 @@ import pandas as pd from shapely.geometry import Point -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, WGS84_CRS from srai.utils._optional import import_optional_dependencies if TYPE_CHECKING: # pragma: no cover @@ -69,13 +69,13 @@ def load( directions_df = self._load_directions(feed) stops_df = feed.stops[["stop_id", "stop_lat", "stop_lon"]].set_index("stop_id") - stops_df["geometry"] = stops_df.apply( + stops_df[GEOMETRY_COLUMN] = stops_df.apply( lambda row: Point([row["stop_lon"], row["stop_lat"]]), axis=1 ) result_gdf = gpd.GeoDataFrame( - trips_df.merge(stops_df["geometry"], how="inner", on="stop_id"), - geometry="geometry", + trips_df.merge(stops_df[GEOMETRY_COLUMN], how="inner", on="stop_id"), + geometry=GEOMETRY_COLUMN, crs=WGS84_CRS, ) diff --git a/srai/loaders/osm_loaders/osm_online_loader.py b/srai/loaders/osm_loaders/osm_online_loader.py index c961ac77..24460fde 100644 --- a/srai/loaders/osm_loaders/osm_online_loader.py +++ b/srai/loaders/osm_loaders/osm_online_loader.py @@ -11,7 +11,7 @@ from functional import seq from tqdm import tqdm -from srai.constants import FEATURES_INDEX, WGS84_CRS +from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS from srai.loaders.osm_loaders.filters.osm_tags_type import osm_tags_type from srai.utils._optional import import_optional_dependencies @@ -84,12 +84,12 @@ def load( results = [] - pbar = tqdm(product(area_wgs84["geometry"], _tags), total=total_queries) + pbar = tqdm(product(area_wgs84[GEOMETRY_COLUMN], _tags), total=total_queries) for polygon, (key, value) in pbar: pbar.set_description(self._get_pbar_desc(key, value, desc_max_len)) geometries = ox.geometries_from_polygon(polygon, {key: value}) if not geometries.empty: - results.append(geometries[["geometry", key]]) + results.append(geometries[[GEOMETRY_COLUMN, key]]) result_gdf = self._group_gdfs(results).set_crs(WGS84_CRS) diff --git a/srai/neighbourhoods/_base.py b/srai/neighbourhoods/_base.py index 66b4011a..b7cfa6ae 100644 --- a/srai/neighbourhoods/_base.py +++ b/srai/neighbourhoods/_base.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from queue import Queue -from typing import Generic, Set, Tuple, TypeVar +from typing import Dict, Generic, Set, Tuple, TypeVar from functional import seq @@ -80,15 +80,16 @@ def get_neighbours_at_distance(self, index: IndexType, distance: int) -> Set[Ind def _get_neighbours_with_distances( self, index: IndexType, distance: int ) -> Set[Tuple[IndexType, int]]: - visited_indexes = {} + visited_indexes: Dict[IndexType, int] = {} to_visit: Queue[Tuple[IndexType, int]] = Queue() to_visit.put((index, 0)) - while not to_visit.empty(): current_index, current_distance = to_visit.get() - visited_indexes[current_index] = current_distance + visited_indexes[current_index] = min( + current_distance, visited_indexes.get(current_index, distance) + ) if current_distance < distance: current_neighbours = self.get_neighbours(current_index) for neighbour in current_neighbours: diff --git a/srai/neighbourhoods/adjacency_neighbourhood.py b/srai/neighbourhoods/adjacency_neighbourhood.py index 778e7a37..529c67e8 100644 --- a/srai/neighbourhoods/adjacency_neighbourhood.py +++ b/srai/neighbourhoods/adjacency_neighbourhood.py @@ -8,6 +8,7 @@ import geopandas as gpd +from srai.constants import GEOMETRY_COLUMN from srai.neighbourhoods import Neighbourhood @@ -32,7 +33,7 @@ def __init__(self, regions_gdf: gpd.GeoDataFrame) -> None: Raises: AttributeError: If regions_gdf doesn't have geometry column. """ - if "geometry" not in regions_gdf.columns: + if GEOMETRY_COLUMN not in regions_gdf.columns: raise AttributeError("Regions must have a geometry column.") self.regions_gdf = regions_gdf self.lookup: Dict[Hashable, Set[Hashable]] = {} @@ -75,7 +76,9 @@ def _get_adjacent_neighbours(self, index: Hashable) -> Set[Hashable]: 1. https://shapely.readthedocs.io/en/stable/reference/shapely.touches.html """ current_region = self.regions_gdf.loc[index] - neighbours = self.regions_gdf[self.regions_gdf.geometry.touches(current_region["geometry"])] + neighbours = self.regions_gdf[ + self.regions_gdf.geometry.touches(current_region[GEOMETRY_COLUMN]) + ] return set(neighbours.index) def _index_incorrect(self, index: Hashable) -> bool: diff --git a/srai/plotting/__init__.py b/srai/plotting/__init__.py new file mode 100644 index 00000000..0917005f --- /dev/null +++ b/srai/plotting/__init__.py @@ -0,0 +1 @@ +"""Plotting module.""" diff --git a/srai/plotting/folium_wrapper.py b/srai/plotting/folium_wrapper.py new file mode 100644 index 00000000..717549e3 --- /dev/null +++ b/srai/plotting/folium_wrapper.py @@ -0,0 +1,268 @@ +""" +Folium wrapper. + +This module contains functions for quick plotting of analysed gdfs using Geopandas `explore()` +function. +""" +from itertools import cycle, islice +from typing import List, Optional, Set, Union + +from srai.utils._optional import import_optional_dependencies + +import_optional_dependencies(dependency_group="plotting", modules=["folium", "plotly"]) + +# flake8: noqa E402 + +import branca.colormap as cm +import folium +import geopandas as gpd +import numpy as np +import pandas as pd +import plotly.express as px + +from srai.constants import REGIONS_INDEX +from srai.neighbourhoods import Neighbourhood +from srai.neighbourhoods._base import IndexType + + +def plot_regions( + regions_gdf: gpd.GeoDataFrame, + tiles_style: str = "OpenStreetMap", + height: Union[str, float] = "100%", + width: Union[str, float] = "100%", + colormap: Union[str, List[str]] = px.colors.qualitative.Bold, + map: Optional[folium.Map] = None, +) -> folium.Map: + """ + Plot regions shapes using Folium library. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + tiles_style (str, optional): Map style background. For more styles, look at tiles param at + https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html. + Defaults to "OpenStreetMap". + height (Union[str, float], optional): Height of the plot. Defaults to "100%". + width (Union[str, float], optional): Width of the plot. Defaults to "100%". + colormap (Union[str, List[str]], optional): Colormap to apply to the regions. + Defaults to `px.colors.qualitative.Bold` from plotly library. + map (folium.Map, optional): Existing map instance on which to draw the plot. + Defaults to None. + + Returns: + folium.Map: Generated map. + """ + return regions_gdf.reset_index().explore( + column=REGIONS_INDEX, + tooltip=REGIONS_INDEX, + tiles=tiles_style, + height=height, + width=width, + legend=False, + cmap=colormap, + categorical=True, + style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.5), + m=map, + ) + + +def plot_numeric_data( + regions_gdf: gpd.GeoDataFrame, + embedding_df: Union[pd.DataFrame, gpd.GeoDataFrame], + data_column: str, + tiles_style: str = "OpenStreetMap", + height: Union[str, float] = "100%", + width: Union[str, float] = "100%", + colormap: Union[str, List[str]] = px.colors.sequential.Sunsetdark, + map: Optional[folium.Map] = None, +) -> folium.Map: + """ + Plot numerical data within regions shapes using Folium library. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + embedding_df (Union[pd.DataFrame, gpd.GeoDataFrame]): Region indexes and numerical data + to plot. + data_column (str): Name of the column used to colour the regions. + tiles_style (str, optional): Map style background. For more styles, look at tiles param at + https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html. + Defaults to "OpenStreetMap". + height (Union[str, float], optional): Height of the plot. Defaults to "100%". + width (Union[str, float], optional): Width of the plot. Defaults to "100%". + colormap (Union[str, List[str]], optional): Colormap to apply to the regions. + Defaults to px.colors.sequential.Sunsetdark. + map (folium.Map, optional): Existing map instance on which to draw the plot. + Defaults to None. + + Returns: + folium.Map: Generated map. + """ + regions_gdf_copy = regions_gdf.copy() + regions_gdf_copy = regions_gdf_copy.merge(embedding_df, on=REGIONS_INDEX) + + if not isinstance(colormap, str): + colormap = _generate_linear_colormap( + colormap, + min_value=regions_gdf_copy[data_column].min(), + max_value=regions_gdf_copy[data_column].max(), + ) + + return regions_gdf_copy.reset_index().explore( + column=data_column, + tooltip=[REGIONS_INDEX, data_column], + tiles=tiles_style, + height=height, + width=width, + legend=True, + cmap=colormap, + categorical=False, + style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.8), + m=map, + ) + + +def plot_neighbours( + regions_gdf: gpd.GeoDataFrame, + region_id: IndexType, + neighbours_ids: Set[IndexType], + tiles_style: str = "OpenStreetMap", + height: Union[str, float] = "100%", + width: Union[str, float] = "100%", + map: Optional[folium.Map] = None, +) -> folium.Map: + """ + Plot neighbours on a map using Folium library. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted. + neighbours_ids (Set[IndexType]): List of neighbours to highlight. + tiles_style (str, optional): Map style background. For more styles, look at tiles param at + https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html. + Defaults to "OpenStreetMap". + height (Union[str, float], optional): Height of the plot. Defaults to "100%". + width (Union[str, float], optional): Width of the plot. Defaults to "100%". + map (folium.Map, optional): Existing map instance on which to draw the plot. + Defaults to None. + + Returns: + folium.Map: Generated map. + """ + if region_id not in regions_gdf.index: + raise AttributeError(f"{region_id!r} doesn't exist in provided regions_gdf.") + + regions_gdf_copy = regions_gdf.copy() + regions_gdf_copy["region"] = "other" + regions_gdf_copy.loc[region_id, "region"] = "selected" + regions_gdf_copy.loc[neighbours_ids, "region"] = "neighbour" + return regions_gdf_copy.reset_index().explore( + column="region", + tooltip=REGIONS_INDEX, + tiles=tiles_style, + height=height, + width=width, + cmap=[ + "rgb(242, 242, 242)", + px.colors.sequential.Sunsetdark[-1], + px.colors.sequential.Sunsetdark[2], + ], + categorical=True, + categories=["selected", "neighbour", "other"], + style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.8), + m=map, + ) + + +def plot_all_neighbourhood( + regions_gdf: gpd.GeoDataFrame, + region_id: IndexType, + neighbourhood: Neighbourhood[IndexType], + neighbourhood_max_distance: int = 100, + tiles_style: str = "OpenStreetMap", + height: Union[str, float] = "100%", + width: Union[str, float] = "100%", + colormap: Union[str, List[str]] = px.colors.sequential.Agsunset_r, + map: Optional[folium.Map] = None, +) -> folium.Map: + """ + Plot full neighbourhood on a map using Folium library. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted. + neighbourhood (Neighbourhood[IndexType]): `Neighbourhood` class required for finding + neighbours. + neighbourhood_max_distance (int, optional): Max distance for rendering neighbourhoods. + Neighbours farther away won't be coloured, and will be left as "other" regions. + Defaults to 100. + tiles_style (str, optional): Map style background. For more styles, look at tiles param at + https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html. + Defaults to "OpenStreetMap". + height (Union[str, float], optional): Height of the plot. Defaults to "100%". + width (Union[str, float], optional): Width of the plot. Defaults to "100%". + colormap (Union[str, List[str]], optional): Colormap to apply to the neighbourhoods. + Defaults to `px.colors.sequential.Agsunset_r` from plotly library. + map (folium.Map, optional): Existing map instance on which to draw the plot. + Defaults to None. + + Returns: + folium.Map: Generated map. + """ + if region_id not in regions_gdf.index: + raise AttributeError(f"{region_id!r} doesn't exist in provided regions_gdf.") + + regions_gdf_copy = regions_gdf.copy() + regions_gdf_copy["region"] = "other" + regions_gdf_copy.loc[region_id, "region"] = "selected" + + distance = 1 + neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection( + regions_gdf.index + ) + while neighbours_ids and distance <= neighbourhood_max_distance: + regions_gdf_copy.loc[list(neighbours_ids), "region"] = distance + distance += 1 + neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection( + regions_gdf.index + ) + + if not isinstance(colormap, str): + colormap = _generate_colormap( + distance, colormap=_resample_plotly_colormap(colormap, min(distance, 10)) + ) + + return regions_gdf_copy.reset_index().explore( + column="region", + tooltip=[REGIONS_INDEX, "region"], + tiles=tiles_style, + height=height, + width=width, + cmap=colormap, + categorical=True, + categories=["selected", *list(range(distance))[1:], "other"], + style_kwds=dict(color="#444", opacity=0.5, fillOpacity=0.8), + legend=distance <= 11, + m=map, + ) + + +def _resample_plotly_colormap(colormap: List[str], steps: int) -> List[str]: + resampled_colormap: List[str] = px.colors.sample_colorscale( + colormap, np.linspace(0, 1, num=steps) + ) + return resampled_colormap + + +def _generate_colormap( + distance: int, + colormap: List[str], + selected_color: str = "rgb(242, 242, 242)", + other_color: str = "rgb(153, 153, 153)", +) -> List[str]: + return [selected_color, *islice(cycle(colormap), None, distance - 1), other_color] + + +def _generate_linear_colormap( + colormap: List[str], min_value: float, max_value: float +) -> cm.LinearColormap: + values, _ = px.colors.convert_colors_to_same_type(colormap, colortype="tuple") + return cm.LinearColormap(values, vmin=min_value, vmax=max_value) diff --git a/srai/plotting/plotly_wrapper.py b/srai/plotting/plotly_wrapper.py new file mode 100644 index 00000000..7cc73c79 --- /dev/null +++ b/srai/plotting/plotly_wrapper.py @@ -0,0 +1,357 @@ +""" +Plotly wrapper. + +This module contains functions for quick plotting of analysed gdfs using Plotly library. +""" +from typing import Any, Dict, List, Optional, Set + +from srai.utils._optional import import_optional_dependencies + +import_optional_dependencies(dependency_group="plotting", modules=["plotly"]) + +# flake8: noqa E402 + +import geopandas as gpd +import numpy as np +import plotly.express as px +import plotly.graph_objs as go +from shapely.geometry import Point + +from srai.constants import REGIONS_INDEX, WGS84_CRS +from srai.neighbourhoods import Neighbourhood +from srai.neighbourhoods._base import IndexType + + +def plot_regions( + regions_gdf: gpd.GeoDataFrame, + return_plot: bool = False, + mapbox_style: str = "open-street-map", + mapbox_accesstoken: Optional[str] = None, + renderer: Optional[str] = None, + zoom: Optional[float] = None, + height: Optional[float] = None, + width: Optional[float] = None, +) -> Optional[go.Figure]: + """ + Plot regions shapes using Plotly library. + + For more info about parameters, check https://plotly.com/python/mapbox-layers/. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + return_plot (bool, optional): Flag whether to return the Figure object or not. + If `True`, the plot won't be displayed automatically. Defaults to False. + mapbox_style (str, optional): Map style background. Defaults to "open-street-map". + mapbox_accesstoken (str, optional): Access token required for mapbox based map backgrounds. + Defaults to None. + renderer (str, optional): Name of renderer used for displaying the figure. + For all descriptions, look here: https://plotly.com/python/renderers/. + Defaults to None. + zoom (float, optional): Map zoom. If not filled, will be approximated based on + the bounding box of regions. Defaults to None. + height (float, optional): Height of the plot. Defaults to None. + width (float, optional): Width of the plot. Defaults to None. + + Returns: + Optional[go.Figure]: Figure of the plot. Will be returned if `return_plot` is set to `True`. + """ + regions_gdf_copy = regions_gdf.copy() + regions_gdf_copy[REGIONS_INDEX] = regions_gdf_copy.index + return _plot_regions( + regions_gdf=regions_gdf_copy, + hover_column_name=REGIONS_INDEX, + color_feature_column=None, + hover_data=[], + show_legend=False, + return_plot=return_plot, + mapbox_style=mapbox_style, + mapbox_accesstoken=mapbox_accesstoken, + renderer=renderer, + zoom=zoom, + height=height, + width=width, + color_discrete_sequence=px.colors.qualitative.Safe, + opacity=0.4, + traces_kwargs=dict(marker_line_width=2), + ) + + +def plot_neighbours( + regions_gdf: gpd.GeoDataFrame, + region_id: IndexType, + neighbours_ids: Set[IndexType], + return_plot: bool = False, + mapbox_style: str = "open-street-map", + mapbox_accesstoken: Optional[str] = None, + renderer: Optional[str] = None, + zoom: Optional[float] = None, + height: Optional[float] = None, + width: Optional[float] = None, +) -> Optional[go.Figure]: + """ + Plot neighbours on a map using Plotly library. + + For more info about parameters, check https://plotly.com/python/mapbox-layers/. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted. + neighbours_ids (Set[IndexType]): List of neighbours to highlight. + return_plot (bool, optional): Flag whether to return the Figure object or not. + If `True`, the plot won't be displayed automatically. Defaults to False. + mapbox_style (str, optional): Map style background. Defaults to "open-street-map". + mapbox_accesstoken (str, optional): Access token required for mapbox based map backgrounds. + Defaults to None. + renderer (str, optional): Name of renderer used for displaying the figure. + For all descriptions, look here: https://plotly.com/python/renderers/. + Defaults to None. + zoom (float, optional): Map zoom. If not filled, will be approximated based on + the bounding box of regions. Defaults to None. + height (float, optional): Height of the plot. Defaults to None. + width (float, optional): Width of the plot. Defaults to None. + + Returns: + Optional[go.Figure]: Figure of the plot. Will be returned if `return_plot` is set to `True`. + """ + regions_gdf_copy = regions_gdf.copy() + regions_gdf_copy[REGIONS_INDEX] = regions_gdf_copy.index + regions_gdf_copy["region"] = "other" + regions_gdf_copy.loc[region_id, "region"] = "selected" + regions_gdf_copy.loc[neighbours_ids, "region"] = "neighbour" + return _plot_regions( + regions_gdf=regions_gdf_copy, + hover_column_name=REGIONS_INDEX, + color_feature_column="region", + hover_data=[], + show_legend=True, + return_plot=return_plot, + mapbox_style=mapbox_style, + mapbox_accesstoken=mapbox_accesstoken, + renderer=renderer, + zoom=zoom, + height=height, + width=width, + layout_kwargs=dict( + legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01, traceorder="normal"), + ), + category_orders={"region": ["selected", "neighbour", "other"]}, + color_discrete_sequence=[ + px.colors.qualitative.Plotly[1], + px.colors.qualitative.Plotly[2], + px.colors.qualitative.Plotly[0], + ], + ) + + +def plot_all_neighbourhood( + regions_gdf: gpd.GeoDataFrame, + region_id: IndexType, + neighbourhood: Neighbourhood[IndexType], + neighbourhood_max_distance: int = 100, + return_plot: bool = False, + mapbox_style: str = "open-street-map", + mapbox_accesstoken: Optional[str] = None, + renderer: Optional[str] = None, + zoom: Optional[float] = None, + height: Optional[float] = None, + width: Optional[float] = None, +) -> Optional[go.Figure]: + """ + Plot full neighbourhood on a map using Plotly library. + + For more info about parameters, check https://plotly.com/python/mapbox-layers/. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + region_id (IndexType): Center `region_id` around which the neighbourhood should be plotted. + neighbourhood (Neighbourhood[IndexType]): `Neighbourhood` class required for finding + neighbours. + neighbourhood_max_distance (int, optional): Max distance for rendering neighbourhoods. + Neighbours farther away won't be coloured, and will be left as "other" regions. + Defaults to 100. + return_plot (bool, optional): Flag whether to return the Figure object or not. + If `True`, the plot won't be displayed automatically. Defaults to False. + mapbox_style (str, optional): Map style background. Defaults to "open-street-map". + mapbox_accesstoken (str, optional): Access token required for mapbox based map backgrounds. + Defaults to None. + renderer (str, optional): Name of renderer used for displaying the figure. + For all descriptions, look here: https://plotly.com/python/renderers/. + Defaults to None. + zoom (float, optional): Map zoom. If not filled, will be approximated based on + the bounding box of regions. Defaults to None. + height (float, optional): Height of the plot. Defaults to None. + width (float, optional): Width of the plot. Defaults to None. + + Returns: + Optional[go.Figure]: Figure of the plot. Will be returned if `return_plot` is set to `True`. + """ + regions_gdf_copy = regions_gdf.copy() + regions_gdf_copy[REGIONS_INDEX] = regions_gdf_copy.index + regions_gdf_copy["region"] = "other" + regions_gdf_copy.loc[region_id, "region"] = "selected" + + distance = 1 + neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection( + regions_gdf.index + ) + while neighbours_ids and distance <= neighbourhood_max_distance: + regions_gdf_copy.loc[list(neighbours_ids), "region"] = distance + distance += 1 + neighbours_ids = neighbourhood.get_neighbours_at_distance(region_id, distance).intersection( + regions_gdf.index + ) + + return _plot_regions( + regions_gdf=regions_gdf_copy, + hover_column_name=REGIONS_INDEX, + color_feature_column="region", + hover_data=[], + show_legend=True, + return_plot=return_plot, + mapbox_style=mapbox_style, + mapbox_accesstoken=mapbox_accesstoken, + renderer=renderer, + zoom=zoom, + height=height, + width=width, + layout_kwargs=dict( + legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01, traceorder="normal"), + ), + category_orders={"region": ["selected", *range(distance), "other"]}, + color_discrete_sequence=px.colors.cyclical.Edge, + ) + + +def _plot_regions( + regions_gdf: gpd.GeoDataFrame, + hover_column_name: str, + hover_data: List[str], + color_feature_column: Optional[str] = None, + show_legend: bool = False, + return_plot: bool = False, + mapbox_style: str = "open-street-map", + mapbox_accesstoken: Optional[str] = None, + opacity: float = 0.6, + renderer: Optional[str] = None, + zoom: Optional[float] = None, + height: Optional[float] = None, + width: Optional[float] = None, + layout_kwargs: Optional[Dict[str, Any]] = None, + traces_kwargs: Optional[Dict[str, Any]] = None, + **choropleth_mapbox_kwargs: Any, +) -> Optional[go.Figure]: + """ + Plot regions shapes using Plotly library. + + Uses `choroplethmapbox` function. + For more info about parameters, check https://plotly.com/python/mapbox-layers/. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries to plot. + hover_column_name (str): Column name used for hover popup title. + hover_data (List[str]): List of column names displayed additionally on hover. + color_feature_column (str, optional): Column name used for colouring the plot. + Can be `None` to disable colouring. + show_legend (bool, optional): Flag whether to show the legend or not. Defaults to False. + return_plot (bool, optional): Flag whether to return the Figure object or not. + If `True`, the plot won't be displayed automatically. Defaults to False. + mapbox_style (str, optional): Map style background. Defaults to "open-street-map". + mapbox_accesstoken (str, optional): Access token required for mapbox based map backgrounds. + Defaults to None. + opacity (float, optional): Markers opacity. Defaults to 0.6. + renderer (str, optional): Name of renderer used for displaying the figure. + For all descriptions, look here: https://plotly.com/python/renderers/. + Defaults to None. + zoom (float, optional): Map zoom. If not filled, will be approximated based on + the bounding box of regions. Defaults to None. + height (float, optional): Height of the plot. Defaults to None. + width (float, optional): Width of the plot. Defaults to None. + layout_kwargs (Dict[str, Any], optional): Additional parameters passed to + the `update_layout` function. Defaults to None. + traces_kwargs (Dict[str, Any], optional): Additional parameters passed to + the `update_traces` function. Defaults to None. + **choropleth_mapbox_kwargs: Additional parameters that can be passed to + the `choropleth_mapbox` constructor. + + Returns: + Optional[go.Figure]: Figure of the plot. Will be returned if `return_plot` is set to `True`. + """ + center_point = _calculate_map_centroid(regions_gdf) + if not zoom: + zoom = _calculate_mapbox_zoom(regions_gdf) + + fig = px.choropleth_mapbox( + regions_gdf, + geojson=regions_gdf, + color=color_feature_column, + hover_name=hover_column_name, + hover_data=hover_data, + locations=REGIONS_INDEX, + center={"lon": center_point.x, "lat": center_point.y}, + zoom=zoom, + **choropleth_mapbox_kwargs, + ) + + update_layout_dict = dict( + height=height, + width=width, + margin=dict(r=0, t=0, l=0, b=0), + mapbox_style=mapbox_style, + mapbox_accesstoken=mapbox_accesstoken, + ) + fig.update_layout(**update_layout_dict) + if layout_kwargs: + fig.update_layout(**layout_kwargs) + + update_traces_dict = dict(marker_opacity=opacity, showlegend=show_legend) + fig.update_traces(**update_traces_dict) + if traces_kwargs: + fig.update_traces(**traces_kwargs) + + fig.update_coloraxes(showscale=show_legend) + + if return_plot: + return fig + else: + fig.show(renderer=renderer) + return None + + +def _calculate_map_centroid(regions_gdf: gpd.GeoDataFrame) -> Point: + """ + Calculate regions centroid using Equal Area Cylindrical projection [1]. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries. + + Returns: + Point: Center point in WGS84 units. + + References: + 1. https://proj.org/operations/projections/cea.html + """ + center_point = regions_gdf.to_crs("+proj=cea").dissolve().centroid.to_crs(WGS84_CRS)[0] + return center_point + + +# Inspired by: +# https://stackoverflow.com/a/65043576/7766101 +def _calculate_mapbox_zoom( + regions_gdf: gpd.GeoDataFrame, +) -> float: + """ + Calculate approximate zoom for a plotly figure. + + Currently Plotly doesn't implement auto-fit feature for mapbox plots. + + Args: + regions_gdf (gpd.GeoDataFrame): Region indexes and geometries. + + Returns: + float: zoom level for a mapbox plot. + """ + + minx, miny, maxx, maxy = regions_gdf.geometry.total_bounds + max_bound = max(abs(maxx - minx), abs(maxy - miny)) * 111 + zoom = float(12.5 - np.log(max_bound)) + return zoom diff --git a/srai/regionizers/administrative_boundary_regionizer.py b/srai/regionizers/administrative_boundary_regionizer.py index f5685f6b..d1bd667c 100644 --- a/srai/regionizers/administrative_boundary_regionizer.py +++ b/srai/regionizers/administrative_boundary_regionizer.py @@ -4,7 +4,7 @@ This module contains administrative boundary regionizer implementation. """ import time -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import geopandas as gpd import topojson as tp @@ -14,7 +14,7 @@ from shapely.validation import make_valid from tqdm import tqdm -from srai.constants import REGIONS_INDEX, WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, REGIONS_INDEX, WGS84_CRS from srai.regionizers import Regionizer from srai.utils import flatten_geometry_series from srai.utils._optional import import_optional_dependencies @@ -36,6 +36,8 @@ class AdministrativeBoundaryRegionizer(Regionizer): 1. https://wiki.openstreetmap.org/wiki/Key:admin_level """ + EMPTY_REGION_NAME = "EMPTY" + def __init__( self, admin_level: int, @@ -43,6 +45,7 @@ def __init__( return_empty_region: bool = False, prioritize_english_name: bool = True, toposimplify: Union[bool, float] = True, + remove_artefact_regions: bool = True, ) -> None: """ Init AdministrativeBoundaryRegionizer. @@ -50,7 +53,7 @@ def __init__( Args: admin_level (int): OpenStreetMap admin_level value. See [1] for detailed description of available values. - clip_regions (bool, optional): Whether to to clip regions using a provided mask. + clip_regions (bool, optional): Whether to clip regions using a provided mask. Turning it off can an be useful when trying to load regions using list a of points. Defaults to True. return_empty_region (bool, optional): Whether to return an empty region to fill @@ -61,6 +64,12 @@ def __init__( geometries size or not. Value is passed to `topojson` library for topology-aware simplification. Since provided values are treated like degrees, values between 1e-4 and 1.0 are recommended. Defaults to True (which results in value equal 1e-4). + remove_artefact_regions (bool, optional): Whether to remove small regions barely + intersecting queried area. Turning it off can sometimes load unnecessary boundaries + that touch on the edge. It removes regions that intersect with an area smaller + than 1% of total self. Takes into consideration if provided query GeoDataFrame + contains points and then skips calculating area when intersects any point. + Defaults to True. Raises: ValueError: If admin_level is outside available range (1-11). See [2] for list of @@ -83,6 +92,7 @@ def __init__( self.prioritize_english_name = prioritize_english_name self.clip_regions = clip_regions self.return_empty_region = return_empty_region + self.remove_artefact_regions = remove_artefact_regions if isinstance(toposimplify, (int, float)) and toposimplify > 0: self.toposimplify = toposimplify @@ -133,16 +143,52 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: regions_dicts = self._generate_regions_from_all_geometries(gdf_wgs84) + if not regions_dicts: + import warnings + + warnings.warn( + ( + "Couldn't find any administrative boundaries with" + f" `admin_level`={self.admin_level}." + ), + RuntimeWarning, + stacklevel=2, + ) + + return self._get_empty_geodataframe(gdf_wgs84) + regions_gdf = gpd.GeoDataFrame(data=regions_dicts, crs=WGS84_CRS).set_index(REGIONS_INDEX) regions_gdf = self._toposimplify_gdf(regions_gdf) + if self.remove_artefact_regions: + points_collection: Optional[BaseGeometry] = gdf_wgs84[ + gdf_wgs84.geom_type == "Point" + ].geometry.unary_union + clipping_polygon_area: Optional[BaseGeometry] = gdf_wgs84[ + gdf_wgs84.geom_type != "Point" + ].geometry.unary_union + + regions_to_keep = [ + region_id + for region_id, row in regions_gdf.iterrows() + if self._check_intersects_with_points(row["geometry"], points_collection) + or self._calculate_intersection_area_fraction( + row["geometry"], clipping_polygon_area + ) + > 0.01 + ] + regions_gdf = regions_gdf.loc[regions_to_keep] + if self.clip_regions: regions_gdf = regions_gdf.clip(mask=gdf_wgs84, keep_geom_type=False) if self.return_empty_region: empty_region = self._generate_empty_region(mask=gdf_wgs84, regions_gdf=regions_gdf) if not empty_region.is_empty: - regions_gdf.loc["EMPTY", "geometry"] = empty_region + regions_gdf.loc[ + AdministrativeBoundaryRegionizer.EMPTY_REGION_NAME, GEOMETRY_COLUMN + ] = empty_region + return regions_gdf def _generate_regions_from_all_geometries( @@ -156,8 +202,8 @@ def _generate_regions_from_all_geometries( with tqdm(desc="Loading boundaries") as pbar: for geometry in all_geometries: - unary_geometry = unary_union([r["geometry"] for r in generated_regions]) - if not geometry.within(unary_geometry): + unary_geometry = unary_union([r[GEOMETRY_COLUMN] for r in generated_regions]) + if not geometry.covered_by(unary_geometry): query = self._generate_query_for_single_geometry(geometry) boundaries_list = self._query_overpass(query) for boundary in boundaries_list: @@ -243,8 +289,8 @@ def _parse_overpass_element(self, element: Dict[str, Any]) -> Dict[str, Any]: region_id = str(element["id"]) return { - "geometry": self._get_boundary_geometry(element["id"]), - "region_id": region_id, + GEOMETRY_COLUMN: self._get_boundary_geometry(element["id"]), + REGIONS_INDEX: region_id, } def _get_boundary_geometry(self, relation_id: str) -> BaseGeometry: @@ -278,6 +324,46 @@ def _generate_empty_region( self, mask: gpd.GeoDataFrame, regions_gdf: gpd.GeoDataFrame ) -> BaseGeometry: """Generate a region filling the space between regions and full clipping mask.""" - joined_mask = unary_union(mask.geometry) - joined_geometry = unary_union(regions_gdf.geometry) + joined_mask = mask.geometry.unary_union + joined_geometry = regions_gdf.geometry.unary_union return joined_mask.difference(joined_geometry) + + def _get_empty_geodataframe(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Get empty GeoDataFrame when zero boundaries have been found.""" + if self.return_empty_region: + regions_gdf = gpd.GeoDataFrame( + data={ + GEOMETRY_COLUMN: [gdf.geometry.unary_union], + REGIONS_INDEX: [AdministrativeBoundaryRegionizer.EMPTY_REGION_NAME], + }, + crs=WGS84_CRS, + ).set_index(REGIONS_INDEX) + else: + regions_gdf = gpd.GeoDataFrame( + data={ + GEOMETRY_COLUMN: [], + REGIONS_INDEX: [], + }, + crs=WGS84_CRS, + ) + return regions_gdf + + def _check_intersects_with_points( + self, region_geometry: BaseGeometry, points_collection: Optional[BaseGeometry] + ) -> bool: + """Check if region intersects with any point in query regions.""" + return ( + points_collection is not None + and not points_collection.is_empty + and region_geometry.intersects(points_collection) + ) + + def _calculate_intersection_area_fraction( + self, region_geometry: BaseGeometry, clipping_polygon_area: Optional[BaseGeometry] + ) -> float: + """Calculate intersection area fraction to check if it's big enough.""" + if clipping_polygon_area is None or clipping_polygon_area.is_empty: + return 0 + full_area = float(region_geometry.area) + clip_area = float(region_geometry.intersection(clipping_polygon_area).area) + return clip_area / full_area diff --git a/srai/regionizers/h3_regionizer.py b/srai/regionizers/h3_regionizer.py index ca6f3f5f..2d6de954 100644 --- a/srai/regionizers/h3_regionizer.py +++ b/srai/regionizers/h3_regionizer.py @@ -19,7 +19,7 @@ from functional import seq from shapely import geometry -from srai.constants import REGIONS_INDEX, WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, REGIONS_INDEX, WGS84_CRS from srai.regionizers import Regionizer from srai.utils import buffer_geometry @@ -75,7 +75,7 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: gdf_buffered = self._buffer(gdf_exploded) if self.buffer else gdf_exploded h3_indexes = ( - seq(gdf_buffered["geometry"]) + seq(gdf_buffered[GEOMETRY_COLUMN]) .map(self._polygon_shapely_to_h3) .flat_map(lambda polygon: h3.polygon_to_cells(polygon, self.resolution)) .distinct() @@ -86,7 +86,9 @@ def transform(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: # there may be too many cells because of too big buffer if self.buffer: - gdf_h3_clipped = gdf_h3.sjoin(gdf_exploded[["geometry"]]).drop(columns="index_right") + gdf_h3_clipped = gdf_h3.sjoin(gdf_exploded[[GEOMETRY_COLUMN]]).drop( + columns="index_right" + ) gdf_h3_clipped = gdf_h3_clipped[~gdf_h3_clipped.index.duplicated(keep="first")] else: gdf_h3_clipped = gdf_h3 diff --git a/srai/regionizers/s2_regionizer.py b/srai/regionizers/s2_regionizer.py index 8ce55458..67158601 100644 --- a/srai/regionizers/s2_regionizer.py +++ b/srai/regionizers/s2_regionizer.py @@ -18,7 +18,7 @@ from s2 import s2 from shapely.geometry import Polygon -from srai.constants import REGIONS_INDEX, WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, REGIONS_INDEX, WGS84_CRS from srai.regionizers import Regionizer @@ -83,7 +83,7 @@ def _fill_with_s2_cells(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: geo_json = json.loads(gdf.to_json()) cells = ( seq(geo_json["features"]) - .flat_map(lambda f: self._geojson_to_cells(f["geometry"], self.resolution)) + .flat_map(lambda f: self._geojson_to_cells(f[GEOMETRY_COLUMN], self.resolution)) .to_dict() ) cells_gdf = gpd.GeoDataFrame( @@ -97,6 +97,6 @@ def _fill_with_s2_cells(self, gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame: def _geojson_to_cells(self, geo_json: Dict[str, Any], res: int) -> Sequence: raw_cells = s2.polyfill(geo_json, res, with_id=True, geo_json_conformant=True) - cells: Sequence = seq(raw_cells).map(lambda c: (c["id"], Polygon(c["geometry"]))) + cells: Sequence = seq(raw_cells).map(lambda c: (c["id"], Polygon(c[GEOMETRY_COLUMN]))) return cells diff --git a/srai/regionizers/voronoi_regionizer.py b/srai/regionizers/voronoi_regionizer.py index c54917ef..f2d6248a 100644 --- a/srai/regionizers/voronoi_regionizer.py +++ b/srai/regionizers/voronoi_regionizer.py @@ -9,7 +9,7 @@ import geopandas as gpd from shapely.geometry import Point, box -from srai.constants import REGIONS_INDEX, WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, REGIONS_INDEX, WGS84_CRS from srai.regionizers import Regionizer from srai.utils._optional import import_optional_dependencies @@ -107,7 +107,7 @@ def transform(self, gdf: Optional[gpd.GeoDataFrame] = None) -> gpd.GeoDataFrame: if gdf is None: gdf = gpd.GeoDataFrame( - {"geometry": [box(minx=-180, maxx=180, miny=-90, maxy=90)]}, crs=WGS84_CRS + {GEOMETRY_COLUMN: [box(minx=-180, maxx=180, miny=-90, maxy=90)]}, crs=WGS84_CRS ) gdf_wgs84 = gdf.to_crs(crs=WGS84_CRS) @@ -118,7 +118,7 @@ def transform(self, gdf: Optional[gpd.GeoDataFrame] = None) -> gpd.GeoDataFrame: multiprocessing_activation_threshold=self.multiprocessing_activation_threshold, ) regions_gdf = gpd.GeoDataFrame( - data={"geometry": generated_regions}, index=self.region_ids, crs=WGS84_CRS + data={GEOMETRY_COLUMN: generated_regions}, index=self.region_ids, crs=WGS84_CRS ) regions_gdf.index.rename(REGIONS_INDEX, inplace=True) clipped_regions_gdf = regions_gdf.clip(mask=gdf_wgs84, keep_geom_type=False) @@ -126,7 +126,9 @@ def transform(self, gdf: Optional[gpd.GeoDataFrame] = None) -> gpd.GeoDataFrame: def _get_duplicated_seeds_ids(self) -> List[Hashable]: """Return all seeds ids that overlap with another using quick sjoin operation.""" - gdf = gpd.GeoDataFrame(data={"geometry": self.seeds}, index=self.region_ids, crs=WGS84_CRS) + gdf = gpd.GeoDataFrame( + data={GEOMETRY_COLUMN: self.seeds}, index=self.region_ids, crs=WGS84_CRS + ) duplicated_seeds = gdf.sjoin(gdf).index.value_counts().loc[lambda x: x > 1] duplicated_seeds_ids: List[Hashable] = duplicated_seeds.index.to_list() return duplicated_seeds_ids diff --git a/tests/embedders/conftest.py b/tests/embedders/conftest.py index 850b5307..4945a1bf 100644 --- a/tests/embedders/conftest.py +++ b/tests/embedders/conftest.py @@ -5,7 +5,7 @@ import pytest from shapely import geometry -from srai.constants import WGS84_CRS +from srai.constants import FEATURES_INDEX, REGIONS_INDEX, WGS84_CRS @pytest.fixture # type: ignore @@ -31,20 +31,21 @@ def gdf_empty() -> gpd.GeoDataFrame: @pytest.fixture # type: ignore def gdf_regions_empty() -> gpd.GeoDataFrame: """Get empty GeoDataFrame with region_id as index name.""" - return gpd.GeoDataFrame(index=pd.Index([], name="region_id"), geometry=[]) + return gpd.GeoDataFrame(index=pd.Index([], name=REGIONS_INDEX), geometry=[]) @pytest.fixture # type: ignore def gdf_features_empty() -> gpd.GeoDataFrame: """Get empty GeoDataFrame with feature_id as index name.""" - return gpd.GeoDataFrame(index=pd.Index([], name="feature_id"), geometry=[]) + return gpd.GeoDataFrame(index=pd.Index([], name=FEATURES_INDEX), geometry=[]) @pytest.fixture # type: ignore def gdf_joint_empty() -> gpd.GeoDataFrame: """Get empty GeoDataFrame with MultiIndex and [region_id, feature_id] as index names.""" return gpd.GeoDataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=["region_id", "feature_id"]), geometry=[] + index=pd.MultiIndex.from_arrays([[], []], names=[REGIONS_INDEX, FEATURES_INDEX]), + geometry=[], ) @@ -106,7 +107,7 @@ def gdf_regions() -> gpd.GeoDataFrame: ), ], index=pd.Index( - data=["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], name="region_id" + data=["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], name=REGIONS_INDEX ), crs=WGS84_CRS, ) @@ -159,7 +160,7 @@ def gdf_features() -> gpd.GeoDataFrame: ], index=pd.Index( data=["way/312457804", "way/1533817161", "way/312457812", "way/312457834"], - name="feature_id", + name=FEATURES_INDEX, ), crs=WGS84_CRS, ) @@ -215,7 +216,7 @@ def gdf_joint() -> gpd.GeoDataFrame: ["891e2040d4bffff", "891e2040897ffff", "891e2040897ffff", "891e2040d5bffff"], ["way/312457804", "way/1533817161", "way/312457834", "way/312457812"], ], - names=["region_id", "feature_id"], + names=[REGIONS_INDEX, FEATURES_INDEX], ), crs=WGS84_CRS, ) diff --git a/tests/embedders/test_count_embedder.py b/tests/embedders/test_count_embedder.py index 1d6e0c48..2ebd6ca7 100644 --- a/tests/embedders/test_count_embedder.py +++ b/tests/embedders/test_count_embedder.py @@ -7,6 +7,7 @@ import pytest from pandas.testing import assert_frame_equal, assert_index_equal +from srai.constants import REGIONS_INDEX from srai.embedders import CountEmbedder @@ -15,13 +16,13 @@ def expected_embedding_df() -> pd.DataFrame: """Get expected CountEmbedder output for the default case.""" expected_df = pd.DataFrame( { - "region_id": ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], "leisure_adult_gaming_centre": [0, 0, 1], "leisure_playground": [0, 1, 0], "amenity_pub": [1, 0, 1], }, ) - expected_df.set_index("region_id", inplace=True) + expected_df.set_index(REGIONS_INDEX, inplace=True) return expected_df @@ -38,13 +39,13 @@ def specified_features_expected_embedding_df() -> pd.DataFrame: """Get expected CountEmbedder output for the case with specified features.""" expected_df = pd.DataFrame( { - "region_id": ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], + REGIONS_INDEX: ["891e2040897ffff", "891e2040d4bffff", "891e2040d5bffff"], "amenity_parking": [0, 0, 0], "leisure_park": [0, 0, 0], "amenity_pub": [1, 0, 1], }, ) - expected_df.set_index("region_id", inplace=True) + expected_df.set_index(REGIONS_INDEX, inplace=True) return expected_df diff --git a/tests/embedders/test_gtfs2vec_embedder.py b/tests/embedders/test_gtfs2vec_embedder.py index 0fe9b275..c55b6901 100644 --- a/tests/embedders/test_gtfs2vec_embedder.py +++ b/tests/embedders/test_gtfs2vec_embedder.py @@ -10,6 +10,7 @@ from pytorch_lightning import seed_everything from shapely.geometry import Polygon +from srai.constants import REGIONS_INDEX from srai.embedders import GTFS2VecEmbedder from srai.exceptions import ModelNotFitException @@ -40,14 +41,14 @@ def gtfs2vec_regions() -> gpd.GeoDataFrame: """Get GTFS2Vec regions GeoDataFrame.""" regions_gdf = gpd.GeoDataFrame( { - "region_id": ["ff1", "ff2", "ff3"], + REGIONS_INDEX: ["ff1", "ff2", "ff3"], }, geometry=[ Polygon([(0, 0), (0, 3), (3, 3), (3, 0)]), Polygon([(4, 0), (4, 3), (7, 3), (7, 0)]), Polygon([(8, 0), (8, 3), (11, 3), (11, 0)]), ], - ).set_index("region_id") + ).set_index(REGIONS_INDEX) return regions_gdf @@ -57,7 +58,7 @@ def gtfs2vec_joint() -> gpd.GeoDataFrame: joint_gdf = gpd.GeoDataFrame() joint_gdf.index = pd.MultiIndex.from_tuples( [("ff1", 1), ("ff1", 2), ("ff2", 3)], - names=["region_id", "stop_id"], + names=[REGIONS_INDEX, "stop_id"], ) return joint_gdf @@ -71,9 +72,9 @@ def features_not_embedded() -> pd.DataFrame: "trips_at_7": [1.0, 0.0, 0.0], "trips_at_8": [0.0, 0.5, 0.0], "directions_at_6": [1.0, 0.25, 0.0], - "region_id": ["ff1", "ff2", "ff3"], + REGIONS_INDEX: ["ff1", "ff2", "ff3"], }, - ).set_index("region_id") + ).set_index(REGIONS_INDEX) @pytest.fixture # type: ignore @@ -89,7 +90,7 @@ def features_embedded() -> pd.DataFrame: dtype=np.float32, ) features = pd.DataFrame(embeddings.T) - features.index = pd.Index(["ff1", "ff2", "ff3"], name="region_id") + features.index = pd.Index(["ff1", "ff2", "ff3"], name=REGIONS_INDEX) features.columns = pd.RangeIndex(0, 4, 1) return features diff --git a/tests/joiners/conftest.py b/tests/joiners/conftest.py index f59d8cdd..845e6e1c 100644 --- a/tests/joiners/conftest.py +++ b/tests/joiners/conftest.py @@ -5,7 +5,7 @@ import pytest from shapely import geometry -from srai.constants import WGS84_CRS +from srai.constants import FEATURES_INDEX, REGIONS_INDEX, WGS84_CRS @pytest.fixture # type: ignore @@ -55,5 +55,5 @@ def features_gdf() -> gpd.GeoDataFrame: def joint_multiindex() -> pd.MultiIndex: """Get MultiIndex for joint GeoDataFrame.""" return pd.MultiIndex.from_tuples( - [(0, 2), (0, 3), (1, 2), (0, 0), (3, 0), (2, 1)], names=["region_id", "feature_id"] + [(0, 2), (0, 3), (1, 2), (0, 0), (3, 0), (2, 1)], names=[REGIONS_INDEX, FEATURES_INDEX] ) diff --git a/tests/joiners/test_intersection_joiner.py b/tests/joiners/test_intersection_joiner.py index ff490903..85aef481 100644 --- a/tests/joiners/test_intersection_joiner.py +++ b/tests/joiners/test_intersection_joiner.py @@ -5,6 +5,7 @@ import pandas as pd import pytest +from srai.constants import GEOMETRY_COLUMN from srai.joiners import IntersectionJoiner ut = TestCase() @@ -62,5 +63,5 @@ def test_correct_multiindex_intersection_joiner_without_geom( ut.assertEqual(joint.index.names, joint_multiindex.names) ut.assertCountEqual(joint.index, joint_multiindex) - ut.assertNotIn("geometry", joint.columns) + ut.assertNotIn(GEOMETRY_COLUMN, joint.columns) ut.assertIs(len(joint.columns), 0) diff --git a/tests/loaders/osm_loaders/conftest.py b/tests/loaders/osm_loaders/conftest.py index a82baa57..5c90b6e1 100644 --- a/tests/loaders/osm_loaders/conftest.py +++ b/tests/loaders/osm_loaders/conftest.py @@ -4,7 +4,7 @@ import pytest from shapely.geometry import Point, Polygon -from srai.constants import WGS84_CRS +from srai.constants import FEATURES_INDEX, GEOMETRY_COLUMN, WGS84_CRS @pytest.fixture # type: ignore @@ -22,7 +22,7 @@ def area_with_no_objects_gdf() -> gpd.GeoDataFrame: @pytest.fixture # type: ignore def empty_result_gdf() -> gpd.GeoDataFrame: """Get empty OSMOnlineLoader result gdf.""" - result_index = pd.Index(data=[], name="feature_id", dtype="object") + result_index = pd.Index(data=[], name=FEATURES_INDEX, dtype="object") return gpd.GeoDataFrame(index=result_index, crs=WGS84_CRS, geometry=[]) @@ -30,7 +30,7 @@ def empty_result_gdf() -> gpd.GeoDataFrame: def single_polygon_area_gdf() -> gpd.GeoDataFrame: """Get an example area gdf with with one polygon.""" polygon_1 = Polygon([(0, 0), (0, 1), (1, 1), (1, 0)]) - gdf = gpd.GeoDataFrame({"geometry": [polygon_1]}, crs=WGS84_CRS) + gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [polygon_1]}, crs=WGS84_CRS) return gdf @@ -39,7 +39,7 @@ def two_polygons_area_gdf() -> gpd.GeoDataFrame: """Get an example area gdf with with two polygons.""" polygon_1 = Polygon([(0, 0), (0, 1), (1, 1), (1, 0)]) polygon_2 = Polygon([(1, 1), (2, 2), (2, 1), (1, 0)]) - gdf = gpd.GeoDataFrame({"geometry": [polygon_1, polygon_2]}, crs=WGS84_CRS) + gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [polygon_1, polygon_2]}, crs=WGS84_CRS) return gdf @@ -89,7 +89,7 @@ def expected_result_single_polygon() -> gpd.GeoDataFrame: data=[ "node/1", ], - name="feature_id", + name=FEATURES_INDEX, dtype="object", ), crs=WGS84_CRS, @@ -109,7 +109,7 @@ def expected_result_gdf_simple() -> gpd.GeoDataFrame: "node/1", "node/2", ], - name="feature_id", + name=FEATURES_INDEX, dtype="object", ), crs=WGS84_CRS, @@ -131,7 +131,7 @@ def expected_result_gdf_complex() -> gpd.GeoDataFrame: "node/2", "way/3", ], - name="feature_id", + name=FEATURES_INDEX, dtype="object", ), crs=WGS84_CRS, diff --git a/tests/loaders/test_geoparquet_loader.py b/tests/loaders/test_geoparquet_loader.py index ca421b18..90b198f9 100644 --- a/tests/loaders/test_geoparquet_loader.py +++ b/tests/loaders/test_geoparquet_loader.py @@ -5,11 +5,11 @@ import pytest from shapely.geometry import box -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, WGS84_CRS from srai.loaders import GeoparquetLoader bbox = box(minx=-180, maxx=180, miny=-90, maxy=90) -bbox_gdf = gpd.GeoDataFrame({"geometry": [bbox]}) +bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [bbox]}) def test_wrong_path_error() -> None: @@ -59,7 +59,7 @@ def test_setting_index() -> None: def test_clipping() -> None: """Test if properly clips the data.""" bbox = box(minx=-106.645646, maxx=-93.508292, miny=25.837377, maxy=36.500704) - bbox_gdf = gpd.GeoDataFrame({"geometry": [bbox]}, crs=WGS84_CRS) + bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [bbox]}, crs=WGS84_CRS) gdf = GeoparquetLoader().load( file_path=Path(__file__).parent / "test_files" / "example.parquet", area=bbox_gdf ) diff --git a/tests/loaders/test_gtfs_loader.py b/tests/loaders/test_gtfs_loader.py index f9a64c6a..45432a3b 100644 --- a/tests/loaders/test_gtfs_loader.py +++ b/tests/loaders/test_gtfs_loader.py @@ -7,6 +7,7 @@ import pytest from pytest_mock import MockerFixture +from srai.constants import GEOMETRY_COLUMN from srai.loaders import GTFSLoader from srai.loaders.gtfs_loader import GTFS2VEC_DIRECTIONS_PREFIX, GTFS2VEC_TRIPS_PREFIX @@ -63,7 +64,7 @@ def test_gtfs_loader(feed: Any, mocker: MockerFixture, gtfs_validation_ok: pd.Da f"{GTFS2VEC_TRIPS_PREFIX}13", f"{GTFS2VEC_DIRECTIONS_PREFIX}12", f"{GTFS2VEC_DIRECTIONS_PREFIX}13", - "geometry", + GEOMETRY_COLUMN, ], ) diff --git a/tests/miscellaneous/test_optional_dependencies.py b/tests/miscellaneous/test_optional_dependencies.py index a57bce3e..188bc4e2 100644 --- a/tests/miscellaneous/test_optional_dependencies.py +++ b/tests/miscellaneous/test_optional_dependencies.py @@ -3,9 +3,11 @@ from contextlib import nullcontext as does_not_raise from typing import Any, List +import geopandas as gpd import pytest +from shapely.geometry import box -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, REGIONS_INDEX, WGS84_CRS from srai.utils._optional import ImportErrorHandle, import_optional_dependency @@ -42,6 +44,10 @@ def no_optional_dependencies(monkeypatch): "haversine", "spherical_geometry", "gtfs_kit", + "folium", + "mapclassify", + "plotly", + "kaleido", ] for package in optional_packages: sys.modules.pop(package, None) @@ -59,7 +65,7 @@ def _test_voronoi_regionizer() -> None: seeds_gdf = gpd.GeoDataFrame( { - "geometry": [ + GEOMETRY_COLUMN: [ Point(17.014997869227177, 51.09919872601259), Point(16.935542631959215, 51.09380600286582), Point(16.900425, 51.1162552343), @@ -72,15 +78,12 @@ def _test_voronoi_regionizer() -> None: vr = VoronoiRegionizer(seeds=seeds_gdf) vr.transform( gdf=gpd.GeoDataFrame( - {"geometry": [box(minx=-180, maxx=180, miny=-90, maxy=90)]}, crs=WGS84_CRS + {GEOMETRY_COLUMN: [box(minx=-180, maxx=180, miny=-90, maxy=90)]}, crs=WGS84_CRS ) ) def _test_administrative_boundary_regionizer() -> None: - import geopandas as gpd - from shapely.geometry import box - from srai.regionizers.administrative_boundary_regionizer import ( AdministrativeBoundaryRegionizer, ) @@ -91,18 +94,49 @@ def _test_administrative_boundary_regionizer() -> None: maxx=88.50230949587835, maxy=34.846427760404225, ) - asia_bbox_gdf = gpd.GeoDataFrame({"geometry": [asia_bbox]}, crs=WGS84_CRS) + asia_bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [asia_bbox]}, crs=WGS84_CRS) abr = AdministrativeBoundaryRegionizer( admin_level=2, return_empty_region=True, toposimplify=0.001 ) abr.transform(gdf=asia_bbox_gdf) +def _test_plotting_folium_module() -> None: + from srai.plotting import folium_wrapper + + folium_wrapper.plot_regions(_get_regions_gdf()) + + +def _test_plotting_plotly_module() -> None: + from srai.plotting import plotly_wrapper + + plotly_wrapper.plot_regions(_get_regions_gdf()) + + +def _get_regions_gdf() -> gpd.GeoDataFrame: + return gpd.GeoDataFrame( + data={ + GEOMETRY_COLUMN: [ + box( + minx=0, + miny=0, + maxx=1, + maxy=1, + ) + ], + REGIONS_INDEX: [1], + }, + crs=WGS84_CRS, + ) + + @pytest.mark.parametrize( # type: ignore "test_fn", [ (_test_voronoi_regionizer), (_test_administrative_boundary_regionizer), + (_test_plotting_folium_module), + (_test_plotting_plotly_module), ], ) def test_optional_available(test_fn): @@ -116,6 +150,8 @@ def test_optional_available(test_fn): [ (_test_voronoi_regionizer), (_test_administrative_boundary_regionizer), + (_test_plotting_folium_module), + (_test_plotting_plotly_module), ], ) def test_optional_missing(test_fn): diff --git a/tests/neighbourhoods/test_neighbourhood.py b/tests/neighbourhoods/test_neighbourhood.py index 600bf703..bc77a191 100644 --- a/tests/neighbourhoods/test_neighbourhood.py +++ b/tests/neighbourhoods/test_neighbourhood.py @@ -59,6 +59,31 @@ def grid_3_by_3_neighbourhood() -> Dict[int, Set[int]]: } +@pytest.fixture # type: ignore +def grid_3_by_3_irrregular_neighbourhood() -> Dict[int, Set[int]]: + """ + Get irregular grid neighbourhood. + + This dict represents a simple 3 by 3 grid-like neighbourhood. The tiles are numbered from 1 to + 4, from left to right, top to bottom. The tiles are considered neighbours if they are adjacent + by edge, not by vertex. Tiles are irregular, not single squares 1 by 1. Visually it looks like + this: + + [[1, 1, 2], + [1, 1, 2], + [3, 4, 4]] + + Returns: + Dict[int, Set[int]]: A dict representing 3 by 3 grid neighbourhood. + """ + return { + 1: {2, 3, 4}, + 2: {1, 4}, + 3: {1, 4}, + 4: {1, 2, 3}, + } + + @pytest.mark.parametrize( # type: ignore "index, distance, expected", [ @@ -115,3 +140,63 @@ def test_get_neighbours_up_to_distance( neighbourhood = LookupNeighbourhood(grid_3_by_3_neighbourhood) neighbours = neighbourhood.get_neighbours_up_to_distance(index, distance) assert neighbours == expected + + +@pytest.mark.parametrize( # type: ignore + "index, distance, expected", + [ + (1, -2, set()), + (1, -1, set()), + (1, 0, set()), + (1, 1, {2, 3, 4}), + (1, 2, set()), + (2, 1, {1, 4}), + (2, 2, {3}), + (2, 3, set()), + (3, 1, {1, 4}), + (3, 2, {2}), + (3, 3, set()), + (4, 1, {1, 2, 3}), + (4, 2, set()), + ], +) +def test_get_neighbours_at_distance_irregular( + index: str, + distance: int, + expected: Set[str], + grid_3_by_3_irrregular_neighbourhood: Dict[str, Set[str]], +) -> None: + """Test neighbours at distance.""" + neighbourhood = LookupNeighbourhood(grid_3_by_3_irrregular_neighbourhood) + neighbours = neighbourhood.get_neighbours_at_distance(index, distance) + assert neighbours == expected + + +@pytest.mark.parametrize( # type: ignore + "index, distance, expected", + [ + (1, -2, set()), + (1, -1, set()), + (1, 0, set()), + (1, 1, {2, 3, 4}), + (1, 2, {2, 3, 4}), + (2, 1, {1, 4}), + (2, 2, {1, 3, 4}), + (2, 3, {1, 3, 4}), + (3, 1, {1, 4}), + (3, 2, {1, 2, 4}), + (3, 3, {1, 2, 4}), + (4, 1, {1, 2, 3}), + (4, 2, {1, 2, 3}), + ], +) +def test_get_neighbours_up_to_distance_irregular( + index: str, + distance: int, + expected: Set[str], + grid_3_by_3_irrregular_neighbourhood: Dict[str, Set[str]], +) -> None: + """Test neighbours up to a distance.""" + neighbourhood = LookupNeighbourhood(grid_3_by_3_irrregular_neighbourhood) + neighbours = neighbourhood.get_neighbours_up_to_distance(index, distance) + assert neighbours == expected diff --git a/tests/regionizers/conftest.py b/tests/regionizers/conftest.py index 874fa7b2..0fba78e5 100644 --- a/tests/regionizers/conftest.py +++ b/tests/regionizers/conftest.py @@ -4,7 +4,7 @@ import pytest from shapely import geometry -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, WGS84_CRS @pytest.fixture # type: ignore @@ -101,7 +101,7 @@ def gdf_earth_poles() -> gpd.GeoDataFrame: """Get GeoDataFrame with 6 Earth poles.""" return gpd.GeoDataFrame( { - "geometry": [ + GEOMETRY_COLUMN: [ geometry.Point(0, 0), geometry.Point(90, 0), geometry.Point(180, 0), @@ -132,4 +132,4 @@ def earth_bbox() -> geometry.Polygon: @pytest.fixture # type: ignore def gdf_earth_bbox(earth_bbox) -> gpd.GeoDataFrame: """Get full bounding box GeoDataFrame.""" - return gpd.GeoDataFrame({"geometry": [earth_bbox]}, crs=WGS84_CRS) + return gpd.GeoDataFrame({GEOMETRY_COLUMN: [earth_bbox]}, crs=WGS84_CRS) diff --git a/tests/regionizers/test_administrative_boundary_regionizer.py b/tests/regionizers/test_administrative_boundary_regionizer.py index 2fc0c655..6f0480ba 100644 --- a/tests/regionizers/test_administrative_boundary_regionizer.py +++ b/tests/regionizers/test_administrative_boundary_regionizer.py @@ -8,12 +8,12 @@ from pytest_mock import MockerFixture from shapely.geometry import Point, box -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, WGS84_CRS from srai.regionizers import AdministrativeBoundaryRegionizer from srai.utils import merge_disjointed_gdf_geometries bbox = box(minx=-180, maxx=180, miny=-90, maxy=90) -bbox_gdf = gpd.GeoDataFrame({"geometry": [bbox]}, crs=WGS84_CRS) +bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [bbox]}, crs=WGS84_CRS) @pytest.mark.parametrize( # type: ignore @@ -64,7 +64,7 @@ def mock_overpass_api(mocker: MockerFixture) -> None: mocker.patch.object(API, "get", return_value={"elements": [{"type": "relation", "id": 2137}]}) geocoded_gdf = gpd.GeoDataFrame( - {"geometry": [box(minx=0, miny=0, maxx=1, maxy=1)]}, crs=WGS84_CRS + {GEOMETRY_COLUMN: [box(minx=0, miny=0, maxx=1, maxy=1)]}, crs=WGS84_CRS ) mocker.patch("osmnx.geocode_to_gdf", return_value=geocoded_gdf) @@ -84,7 +84,7 @@ def test_empty_region_full_bounding_box(toposimplify: Union[bool, float], reques """Test checks if empty region fills required bounding box.""" request.getfixturevalue("mock_overpass_api") request_bbox = box(minx=0, miny=0, maxx=2, maxy=2) - request_bbox_gdf = gpd.GeoDataFrame({"geometry": [request_bbox]}, crs=WGS84_CRS) + request_bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [request_bbox]}, crs=WGS84_CRS) abr = AdministrativeBoundaryRegionizer( admin_level=4, return_empty_region=True, toposimplify=toposimplify ) @@ -108,7 +108,7 @@ def test_no_empty_region_full_bounding_box(toposimplify: Union[bool, float], req """Test checks if no empty region is generated when not needed.""" request.getfixturevalue("mock_overpass_api") request_bbox = box(minx=0, miny=0, maxx=1, maxy=1) - request_bbox_gdf = gpd.GeoDataFrame({"geometry": [request_bbox]}, crs=WGS84_CRS) + request_bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [request_bbox]}, crs=WGS84_CRS) abr = AdministrativeBoundaryRegionizer( admin_level=2, return_empty_region=True, toposimplify=toposimplify ) @@ -131,14 +131,14 @@ def test_no_empty_region_full_bounding_box(toposimplify: Union[bool, float], req def test_points_in_result(toposimplify: Union[bool, float], request: Any) -> None: """Test checks case when points are in a requested region.""" request.getfixturevalue("mock_overpass_api") - request_gdf = gpd.GeoDataFrame({"geometry": [Point(0.5, 0.5)]}, crs=WGS84_CRS) + request_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [Point(0.5, 0.5)]}, crs=WGS84_CRS) abr = AdministrativeBoundaryRegionizer( admin_level=2, return_empty_region=False, clip_regions=False, toposimplify=toposimplify ) result_gdf = abr.transform(gdf=request_gdf) - assert request_gdf.geometry[0].within(result_gdf.geometry[0]) + assert request_gdf.geometry[0].covered_by(result_gdf.geometry[0]) @pytest.mark.parametrize( # type: ignore @@ -157,7 +157,7 @@ def test_toposimplify_on_real_data(toposimplify: Union[float, bool]) -> None: madagascar_bbox = box( minx=43.2541870461, miny=-25.6014344215, maxx=50.4765368996, maxy=-12.0405567359 ) - madagascar_bbox_gdf = gpd.GeoDataFrame({"geometry": [madagascar_bbox]}, crs=WGS84_CRS) + madagascar_bbox_gdf = gpd.GeoDataFrame({GEOMETRY_COLUMN: [madagascar_bbox]}, crs=WGS84_CRS) abr = AdministrativeBoundaryRegionizer( admin_level=4, return_empty_region=True, toposimplify=toposimplify @@ -166,3 +166,32 @@ def test_toposimplify_on_real_data(toposimplify: Union[float, bool]) -> None: assert ( merge_disjointed_gdf_geometries(madagascar_result_gdf).difference(madagascar_bbox).is_empty ) + + +@pytest.mark.parametrize( # type: ignore + "return_empty_region", + [ + (True), + (False), + ], +) +def test_regions_not_found_on_real_data(return_empty_region: bool) -> None: + """Test if warns when can't find any regions.""" + null_island_region = box(minx=0, miny=0, maxx=0.1, maxy=0.1) + null_island_region_gdf = gpd.GeoDataFrame( + {GEOMETRY_COLUMN: [null_island_region]}, crs=WGS84_CRS + ) + + abr = AdministrativeBoundaryRegionizer(admin_level=10, return_empty_region=return_empty_region) + + with pytest.warns(RuntimeWarning): + madagascar_result_gdf = abr.transform(gdf=null_island_region_gdf) + + if return_empty_region: + assert ( + merge_disjointed_gdf_geometries(madagascar_result_gdf) + .difference(null_island_region) + .is_empty + ) + else: + assert merge_disjointed_gdf_geometries(madagascar_result_gdf).is_empty diff --git a/tests/regionizers/test_h3_regionizer.py b/tests/regionizers/test_h3_regionizer.py index 24a838fe..6ea63667 100644 --- a/tests/regionizers/test_h3_regionizer.py +++ b/tests/regionizers/test_h3_regionizer.py @@ -5,6 +5,7 @@ import pytest +from srai.constants import GEOMETRY_COLUMN from srai.regionizers import H3Regionizer if TYPE_CHECKING: @@ -64,4 +65,4 @@ def test_transform( gdf_h3 = H3Regionizer(resolution, buffer=buffer).transform(gdf) ut.assertCountEqual(first=gdf_h3.index.to_list(), second=h3_indexes) - assert "geometry" in gdf_h3 + assert GEOMETRY_COLUMN in gdf_h3 diff --git a/tests/regionizers/test_s2_regionizer.py b/tests/regionizers/test_s2_regionizer.py index 72131928..90bb6891 100644 --- a/tests/regionizers/test_s2_regionizer.py +++ b/tests/regionizers/test_s2_regionizer.py @@ -5,6 +5,7 @@ import pytest +from srai.constants import GEOMETRY_COLUMN from srai.regionizers import S2Regionizer if TYPE_CHECKING: @@ -57,4 +58,4 @@ def test_transform( gdf_s2 = S2Regionizer(resolution).transform(gdf) ut.assertCountEqual(first=gdf_s2.index.to_list(), second=s2_indexes) - assert "geometry" in gdf_s2 + assert GEOMETRY_COLUMN in gdf_s2 diff --git a/tests/regionizers/test_voronoi_regionizer.py b/tests/regionizers/test_voronoi_regionizer.py index 1b5ad591..0d24622d 100644 --- a/tests/regionizers/test_voronoi_regionizer.py +++ b/tests/regionizers/test_voronoi_regionizer.py @@ -6,7 +6,7 @@ import pytest from shapely.geometry import Point, Polygon -from srai.constants import WGS84_CRS +from srai.constants import GEOMETRY_COLUMN, WGS84_CRS from srai.regionizers import VoronoiRegionizer from srai.regionizers._spherical_voronoi import generate_voronoi_regions from srai.utils import merge_disjointed_gdf_geometries @@ -31,7 +31,7 @@ def test_duplicate_seeds_value_error() -> None: """Test checks if duplicate points are disallowed.""" with pytest.raises(ValueError): seeds_gdf = gpd.GeoDataFrame( - {"geometry": [Point(0, 0), Point(0, 0), Point(-1, -1), Point(2, 2)]}, + {GEOMETRY_COLUMN: [Point(0, 0), Point(0, 0), Point(-1, -1), Point(2, 2)]}, index=[1, 2, 3, 4], crs=WGS84_CRS, ) @@ -42,7 +42,7 @@ def test_single_seed_region() -> None: """Test checks if single seed is disallowed.""" with pytest.raises(ValueError): seeds_gdf = gpd.GeoDataFrame( - {"geometry": [Point(0, 0)]}, + {GEOMETRY_COLUMN: [Point(0, 0)]}, index=[1], crs=WGS84_CRS, ) @@ -74,10 +74,10 @@ def test_big_number_of_seeds_regions(gdf_earth_bbox: gpd.GeoDataFrame, earth_bbo randy = rng.uniform(miny, maxy, number_of_points) coords = np.vstack((randx, randy)).T - pts = [p for p in list(map(Point, coords)) if p.within(earth_bbox)] + pts = [p for p in list(map(Point, coords)) if p.covered_by(earth_bbox)] random_points_gdf = gpd.GeoDataFrame( - {"geometry": pts}, + {GEOMETRY_COLUMN: pts}, index=list(range(len(pts))), crs=WGS84_CRS, ) @@ -92,7 +92,7 @@ def test_four_close_seed_region(gdf_earth_bbox: gpd.GeoDataFrame, earth_bbox: Po """Test checks if four close seeds are properly evaluated.""" seeds_gdf = gpd.GeoDataFrame( { - "geometry": [ + GEOMETRY_COLUMN: [ Point(17.014997869227177, 51.09919872601259), Point(16.935542631959215, 51.09380600286582), Point(16.900425, 51.1162552343),