-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
571 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Highway2Vec Embedder" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import geopandas as gpd\n", | ||
"from IPython.display import display" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get an area to embed" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from srai.utils import geocode_to_region_gdf\n", | ||
"\n", | ||
"area_gdf = geocode_to_region_gdf(\"Wrocław, PL\")\n", | ||
"area_gdf.plot()" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Regionize the area with a regionizer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from srai.regionizers import H3Regionizer\n", | ||
"\n", | ||
"regionizer = H3Regionizer(9)\n", | ||
"regions_gdf = regionizer.transform(area_gdf)\n", | ||
"print(len(regions_gdf))\n", | ||
"display(regions_gdf.head(3))\n", | ||
"regions_gdf.plot()" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Download a road infrastructure for the area" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from srai.loaders import OSMWayLoader\n", | ||
"from srai.loaders.osm_way_loader import NetworkType\n", | ||
"\n", | ||
"loader = OSMWayLoader(NetworkType.DRIVE)\n", | ||
"nodes_gdf, edges_gdf = loader.load(area_gdf)\n", | ||
"\n", | ||
"display(nodes_gdf.head(3))\n", | ||
"display(edges_gdf.head(3))\n", | ||
"\n", | ||
"ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))\n", | ||
"nodes_gdf.plot(ax=ax, markersize=3, color=\"red\")" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Find out which edges correspond to which regions " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from srai.joiners import IntersectionJoiner\n", | ||
"\n", | ||
"joiner = IntersectionJoiner()\n", | ||
"joint_gdf = joiner.transform(regions_gdf, edges_gdf, return_geom=False)\n", | ||
"joint_gdf" | ||
] | ||
}, | ||
{ | ||
"attachments": {}, | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Get the embeddings for regions based on the road infrastructure" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from srai.embedders import Highway2VecEmbedder\n", | ||
"from pytorch_lightning import seed_everything\n", | ||
"\n", | ||
"seed_everything(42)\n", | ||
"\n", | ||
"embedder = Highway2VecEmbedder()\n", | ||
"embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)\n", | ||
"embeddings" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,5 @@ | |
|
||
REGIONS_INDEX = "region_id" | ||
FEATURES_INDEX = "feature_id" | ||
|
||
GEOMETRY_COLUMN = "geometry" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Highway2Vec.""" | ||
|
||
from .embedder import Highway2VecEmbedder | ||
from .model import Highway2VecModel | ||
|
||
__all__ = ["Highway2VecEmbedder", "Highway2VecModel"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
""" | ||
Highway2Vec embedder. | ||
This module contains the embedder from the `highway2vec` paper [1]. | ||
References: | ||
[1] https://doi.org/10.1145/3557918.3565865 | ||
""" | ||
from typing import Any, Dict, Optional | ||
|
||
import geopandas as gpd | ||
import pandas as pd | ||
import pytorch_lightning as pl | ||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
from srai.embedders import Embedder | ||
from srai.exceptions import ModelNotFitException | ||
|
||
from .model import Highway2VecModel | ||
|
||
|
||
class Highway2VecEmbedder(Embedder):
    """
    Highway2Vec Embedder.

    Trains a `Highway2VecModel` on the rows of the features frame and builds one embedding
    per region by averaging the embeddings of the features joined to that region.
    """

    def __init__(self, hidden_size: int = 64, embedding_size: int = 30) -> None:
        """
        Init Highway2Vec Embedder.

        Args:
            hidden_size (int, optional): Hidden size in encoder and decoder. Defaults to 64.
            embedding_size (int, optional): Embedding size. Defaults to 30.
        """
        # The model is created lazily in fit(), once the number of feature columns is known.
        self._model: Optional[Highway2VecModel] = None
        self._hidden_size = hidden_size
        self._embedding_size = embedding_size
        self._is_fitted = False

    def transform(
        self,
        regions_gdf: gpd.GeoDataFrame,
        features_gdf: gpd.GeoDataFrame,
        joint_gdf: gpd.GeoDataFrame,
    ) -> pd.DataFrame:  # pragma: no cover
        """
        Embed regions using features.

        Every feature row is passed through the fitted model, the resulting vectors are joined
        onto `joint_gdf` and then averaged over the first level of its index.

        Args:
            regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
            features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
            joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.

        Returns:
            pd.DataFrame: Mean feature embedding for each value of the first index level
                of joint_gdf (presumably the region index - confirm against the joiner).

        Raises:
            ValueError: If any of the gdfs index names is None.
            ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
            ValueError: If index levels in gdfs don't overlap correctly.
            ModelNotFitException: If the model has not been fitted yet.
        """
        self._validate_indexes(regions_gdf, features_gdf, joint_gdf)
        self._check_is_fitted()
        features_df = self._remove_geometry_if_present(features_gdf)

        self._model.eval()  # type: ignore
        # Inference only: disabling autograd avoids building a graph that would be thrown away.
        with torch.no_grad():
            embeddings = self._model(torch.Tensor(features_df.values)).numpy()  # type: ignore

        embeddings_df = pd.DataFrame(embeddings, index=features_df.index)
        embeddings_joint = joint_gdf.join(embeddings_df)
        embeddings_aggregated = embeddings_joint.groupby(level=[0]).mean()

        return embeddings_aggregated

    def fit(
        self,
        regions_gdf: gpd.GeoDataFrame,
        features_gdf: gpd.GeoDataFrame,
        joint_gdf: gpd.GeoDataFrame,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Fit the model to the data.

        A fresh model is created on every call, so calling fit() again discards the
        previously trained one.

        Args:
            regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
            features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
            joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.
            trainer_kwargs (Optional[Dict[str, Any]], optional): Trainer kwargs.
                `max_epochs` defaults to 10 when not given. The passed dict is not modified.
                Defaults to None.
            dataloader_kwargs (Optional[Dict[str, Any]], optional): Dataloader kwargs.
                `batch_size` defaults to 128 when not given. The passed dict is not modified.
                Defaults to None.

        Raises:
            ValueError: If any of the gdfs index names is None.
            ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
            ValueError: If index levels in gdfs don't overlap correctly.
        """
        self._validate_indexes(regions_gdf, features_gdf, joint_gdf)
        features_df = self._remove_geometry_if_present(features_gdf)

        num_features = len(features_df.columns)
        self._model = Highway2VecModel(
            n_features=num_features, n_hidden=self._hidden_size, n_embed=self._embedding_size
        )

        # Merge defaults into new dicts so the caller's dictionaries are never mutated;
        # caller-supplied values take precedence over the defaults.
        dataloader_kwargs = {"batch_size": 128, **(dataloader_kwargs or {})}
        trainer_kwargs = {"max_epochs": 10, **(trainer_kwargs or {})}

        dataloader = DataLoader(torch.Tensor(features_df.values), **dataloader_kwargs)

        trainer = pl.Trainer(**trainer_kwargs)
        trainer.fit(self._model, dataloader)
        self._is_fitted = True

    def fit_transform(
        self,
        regions_gdf: gpd.GeoDataFrame,
        features_gdf: gpd.GeoDataFrame,
        joint_gdf: gpd.GeoDataFrame,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        """
        Fit the model to the data and return the embeddings.

        Args:
            regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
            features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
            joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.
            trainer_kwargs (Optional[Dict[str, Any]], optional): Trainer kwargs. Defaults to None.
            dataloader_kwargs (Optional[Dict[str, Any]], optional): Dataloader kwargs.
                Defaults to None.

        Returns:
            pd.DataFrame: Region embeddings.

        Raises:
            ValueError: If any of the gdfs index names is None.
            ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
            ValueError: If index levels in gdfs don't overlap correctly.
        """
        self.fit(regions_gdf, features_gdf, joint_gdf, trainer_kwargs, dataloader_kwargs)
        return self.transform(regions_gdf, features_gdf, joint_gdf)

    def _check_is_fitted(self) -> None:
        """
        Ensure a trained model is available.

        Raises:
            ModelNotFitException: If fit() has not completed or no model exists.
        """
        if not self._is_fitted or self._model is None:
            raise ModelNotFitException("Model not fitted. Call fit() or fit_transform() first.")
Oops, something went wrong.