Skip to content

Commit

Permalink
feat: add Highway2Vec (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
Calychas authored Mar 25, 2023
1 parent 8ebda07 commit bc80bec
Show file tree
Hide file tree
Showing 16 changed files with 571 additions and 21 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- GTFS Loader from gtfs2vec paper
- GTFS2Vec Model from gtfs2vec paper
- GTFS2Vec Embedder using gtfs2vec model
- Highway2Vec Model from highway2vec paper
- Highway2Vec Embedder using highway2vec model
- OSMOnlineLoader
- OSMPbfLoader
- OSMWayLoader
Expand All @@ -29,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- IntersectionJoiner incorrectly returned feature columns when `return_geom=False` ([#208](https://github.com/srai-lab/srai/issues/208))

### Security

## [0.0.1] - 2022-11-23
Expand Down
158 changes: 158 additions & 0 deletions examples/embedders/highway2vec_embedder.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Highway2Vec Embedder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import geopandas as gpd\n",
"from IPython.display import display"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get an area to embed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.utils import geocode_to_region_gdf\n",
"\n",
"area_gdf = geocode_to_region_gdf(\"Wrocław, PL\")\n",
"area_gdf.plot()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regionize the area with a regionizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.regionizers import H3Regionizer\n",
"\n",
"regionizer = H3Regionizer(9)\n",
"regions_gdf = regionizer.transform(area_gdf)\n",
"print(len(regions_gdf))\n",
"display(regions_gdf.head(3))\n",
"regions_gdf.plot()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download a road infrastructure for the area"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.loaders import OSMWayLoader\n",
"from srai.loaders.osm_way_loader import NetworkType\n",
"\n",
"loader = OSMWayLoader(NetworkType.DRIVE)\n",
"nodes_gdf, edges_gdf = loader.load(area_gdf)\n",
"\n",
"display(nodes_gdf.head(3))\n",
"display(edges_gdf.head(3))\n",
"\n",
"ax = edges_gdf.plot(linewidth=1, figsize=(12, 7))\n",
"nodes_gdf.plot(ax=ax, markersize=3, color=\"red\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Find out which edges correspond to which regions "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.joiners import IntersectionJoiner\n",
"\n",
"joiner = IntersectionJoiner()\n",
"joint_gdf = joiner.transform(regions_gdf, edges_gdf, return_geom=False)\n",
"joint_gdf"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get the embeddings for regions based on the road infrastructure"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.embedders import Highway2VecEmbedder\n",
"from pytorch_lightning import seed_everything\n",
"\n",
"seed_everything(42)\n",
"\n",
"embedder = Highway2VecEmbedder()\n",
"embeddings = embedder.fit_transform(regions_gdf, edges_gdf, joint_gdf)\n",
"embeddings"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion examples/loaders/osm_way_loader.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
"metadata": {},
"outputs": [],
"source": [
"osmwl = OSMWayLoader(NetworkType.BIKE)\n",
"osmwl = OSMWayLoader(NetworkType.BIKE, metadata=True)\n",
"gdf_nodes, gdf_edges = osmwl.load(gdf_place)\n",
"ax = gdf_edges.plot(linewidth=1, figsize=(12, 7))\n",
"gdf_nodes.plot(ax=ax, markersize=3, color=\"red\")"
Expand Down
2 changes: 2 additions & 0 deletions srai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@

REGIONS_INDEX = "region_id"
FEATURES_INDEX = "feature_id"

GEOMETRY_COLUMN = "geometry"
3 changes: 2 additions & 1 deletion srai/embedders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from ._base import Embedder
from .count_embedder import CountEmbedder
from .gtfs2vec_embedder import GTFS2VecEmbedder
from .highway2vec import Highway2VecEmbedder

__all__ = ["Embedder", "CountEmbedder", "GTFS2VecEmbedder"]
__all__ = ["Embedder", "CountEmbedder", "GTFS2VecEmbedder", "Highway2VecEmbedder"]
1 change: 0 additions & 1 deletion srai/embedders/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def transform(
pd.DataFrame: Embedding and geometry index for each region in regions_gdf.
Raises:
ValueError: If features_gdf is empty and self.expected_output_features is not set.
ValueError: If any of the gdfs index names is None.
ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
ValueError: If index levels in gdfs don't overlap correctly.
Expand Down
6 changes: 6 additions & 0 deletions srai/embedders/highway2vec/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Highway2Vec."""

from .embedder import Highway2VecEmbedder
from .model import Highway2VecModel

__all__ = ["Highway2VecEmbedder", "Highway2VecModel"]
151 changes: 151 additions & 0 deletions srai/embedders/highway2vec/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
Highway2Vec embedder.
This module contains the embedder from the `highway2vec` paper [1].
References:
[1] https://doi.org/10.1145/3557918.3565865
"""
from typing import Any, Dict, Optional

import geopandas as gpd
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from srai.embedders import Embedder
from srai.exceptions import ModelNotFitException

from .model import Highway2VecModel


class Highway2VecEmbedder(Embedder):
"""Highway2Vec Embedder."""

def __init__(self, hidden_size: int = 64, embedding_size: int = 30) -> None:
"""
Init Highway2Vec Embedder.
Args:
hidden_size (int, optional): Hidden size in encoder and decoder. Defaults to 64.
embedding_size (int, optional): Embedding size. Defaults to 30.
"""
self._model: Optional[Highway2VecModel] = None
self._hidden_size = hidden_size
self._embedding_size = embedding_size
self._is_fitted = False

def transform(
self,
regions_gdf: gpd.GeoDataFrame,
features_gdf: gpd.GeoDataFrame,
joint_gdf: gpd.GeoDataFrame,
) -> pd.DataFrame: # pragma: no cover
"""
Embed regions using features.
Args:
regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.
Returns:
pd.DataFrame: Embedding and geometry index for each region in regions_gdf.
Raises:
ValueError: If any of the gdfs index names is None.
ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
ValueError: If index levels in gdfs don't overlap correctly.
"""
self._validate_indexes(regions_gdf, features_gdf, joint_gdf)
self._check_is_fitted()
features_df = self._remove_geometry_if_present(features_gdf)

self._model.eval() # type: ignore
embeddings = self._model(torch.Tensor(features_df.values)).detach().numpy() # type: ignore
embeddings_df = pd.DataFrame(embeddings, index=features_df.index)
embeddings_joint = joint_gdf.join(embeddings_df)
embeddings_aggregated = embeddings_joint.groupby(level=[0]).mean()

return embeddings_aggregated

def fit(
self,
regions_gdf: gpd.GeoDataFrame,
features_gdf: gpd.GeoDataFrame,
joint_gdf: gpd.GeoDataFrame,
trainer_kwargs: Optional[Dict[str, Any]] = None,
dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Fit the model to the data.
Args:
regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.
trainer_kwargs (Optional[Dict[str, Any]], optional): Trainer kwargs. Defaults to None.
dataloader_kwargs (Optional[Dict[str, Any]], optional): Dataloader kwargs.
Defaults to None.
Raises:
ValueError: If any of the gdfs index names is None.
ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
ValueError: If index levels in gdfs don't overlap correctly.
"""
self._validate_indexes(regions_gdf, features_gdf, joint_gdf)
features_df = self._remove_geometry_if_present(features_gdf)

num_features = len(features_df.columns)
self._model = Highway2VecModel(
n_features=num_features, n_hidden=self._hidden_size, n_embed=self._embedding_size
)

dataloader_kwargs = dataloader_kwargs or {}
if "batch_size" not in dataloader_kwargs:
dataloader_kwargs["batch_size"] = 128

dataloader = DataLoader(torch.Tensor(features_df.values), **dataloader_kwargs)

trainer_kwargs = trainer_kwargs or {}
if "max_epochs" not in trainer_kwargs:
trainer_kwargs["max_epochs"] = 10

trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(self._model, dataloader)
self._is_fitted = True

def fit_transform(
self,
regions_gdf: gpd.GeoDataFrame,
features_gdf: gpd.GeoDataFrame,
joint_gdf: gpd.GeoDataFrame,
trainer_kwargs: Optional[Dict[str, Any]] = None,
dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
"""
Fit the model to the data and return the embeddings.
Args:
regions_gdf (gpd.GeoDataFrame): Region indexes and geometries.
features_gdf (gpd.GeoDataFrame): Feature indexes, geometries and feature values.
joint_gdf (gpd.GeoDataFrame): Joiner result with region-feature multi-index.
trainer_kwargs (Optional[Dict[str, Any]], optional): Trainer kwargs. Defaults to None.
dataloader_kwargs (Optional[Dict[str, Any]], optional): Dataloader kwargs.
Defaults to None.
Returns:
pd.DataFrame: Region embeddings.
Raises:
ValueError: If any of the gdfs index names is None.
ValueError: If joint_gdf.index is not of type pd.MultiIndex or doesn't have 2 levels.
ValueError: If index levels in gdfs don't overlap correctly.
"""
self.fit(regions_gdf, features_gdf, joint_gdf, trainer_kwargs, dataloader_kwargs)
return self.transform(regions_gdf, features_gdf, joint_gdf)

def _check_is_fitted(self) -> None:
if not self._is_fitted or self._model is None:
raise ModelNotFitException("Model not fitted. Call fit() or fit_transform() first.")
Loading

0 comments on commit bc80bec

Please sign in to comment.