Skip to content

Commit

Permalink
feat: add gtfs2vec (#169)
Browse files Browse the repository at this point in the history
* chore: move index validation to BaseEmbedder, add empty GTFS2VecEmbedder class

* chore: add autoencoder for gtfs2vec

* test: switch macOS tests to python 3.10

* fix: remove python 3.11 from tests for PyTorch compatibility

* chore: add logic for features to regions aggregation in GTFS2VecEmbedder

* chore: add model training and embedding

* docs: add example for the embedder on dummy data

* fix: update NVIDIA libs to skip in licensecheck

* docs: update examples readme and CHANGELOG.md

* chore: allow to skip embedding features in gtfs2vec

* chore: extract column prefixes in gtfs2vec to library constants

* test: add tests to gtfs2vec embedder

* docs: add example with skip_embedding=True

* chore: remove unused param from GTFS2VecModel
  • Loading branch information
piotrgramacki authored Feb 22, 2023
1 parent 050bda4 commit 77ee8a8
Show file tree
Hide file tree
Showing 18 changed files with 1,552 additions and 111 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10"]
include:
- os: macos-latest
python-version: "3.11"
python-version: "3.10"
env:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python-version }}
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,6 @@ requirements.txt

# osmnx
cache/

# pytorch lightning
lightning_logs/
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- GTFS Loader from gtfs2vec paper
- GTFS2Vec Model from gtfs2vec paper
- GTFS2Vec Embedder using gtfs2vec model
- OSMTagLoader

### Changed
Expand Down
6 changes: 6 additions & 0 deletions examples/embedders/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Embedders

Examples illustrating the usage of every Joiner.

- [CountEmbedder](count_embedder.ipynb)
- [GTFS2VecEmbedder](gtfs2vec_embedder.ipynb)
166 changes: 166 additions & 0 deletions examples/embedders/gtfs2vec_embedder.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from srai.embedders import GTFS2VecEmbedder\n",
"import pandas as pd\n",
"from shapely.geometry import Polygon\n",
"import geopandas as gpd\n",
"from pytorch_lightning import seed_everything"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example on artificial data"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define features and regions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"features_gdf = gpd.GeoDataFrame(\n",
" {\n",
" \"trip_count_at_6\": [1, 0, 0],\n",
" \"trip_count_at_7\": [1, 1, 0],\n",
" \"trip_count_at_8\": [0, 0, 1],\n",
" \"directions_at_6\": [\n",
" {\"A\", \"A1\"},\n",
" {\"B\", \"B1\"},\n",
" {\"C\"},\n",
" ],\n",
" },\n",
" geometry=gpd.points_from_xy([1, 2, 5], [1, 2, 2]),\n",
" index=[1, 2, 3],\n",
")\n",
"features_gdf.index.name = \"stop_id\"\n",
"features_gdf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"regions_gdf = gpd.GeoDataFrame(\n",
" {\n",
" \"region_id\": [\"ff1\", \"ff2\", \"ff3\"],\n",
" },\n",
" geometry=[\n",
" Polygon([(0, 0), (0, 3), (3, 3), (3, 0)]),\n",
" Polygon([(4, 0), (4, 3), (7, 3), (7, 0)]),\n",
" Polygon([(8, 0), (8, 3), (11, 3), (11, 0)]),\n",
" ],\n",
").set_index(\"region_id\")\n",
"regions_gdf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ax = regions_gdf.plot()\n",
"features_gdf.plot(ax=ax, color=\"red\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"joint_gdf = gpd.GeoDataFrame()\n",
"joint_gdf.index = pd.MultiIndex.from_tuples(\n",
" [(\"ff1\", 1), (\"ff1\", 2), (\"ff2\", 3)],\n",
" names=[\"region_id\", \"stop_id\"],\n",
")\n",
"joint_gdf"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get features without embedding them"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embedder = GTFS2VecEmbedder(skip_autoencoder=True)\n",
"res = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
"res"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit and train the embedder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"seed_everything(42)\n",
"embedder = GTFS2VecEmbedder(hidden_size=2, embedding_size=4)\n",
"embedder.fit(regions_gdf, features_gdf, joint_gdf)\n",
"res = embedder.transform(regions_gdf, features_gdf, joint_gdf)\n",
"res"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.14"
},
"vscode": {
"interpreter": {
"hash": "f39c7279c85c8be5d827e53eddb5011e966102d239fe8b81ca4bd9f0123eda8f"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 77ee8a8

Please sign in to comment.