Add **dataloader_kwargs for embedders #464

Draft · wants to merge 2 commits into base: main
29 changes: 26 additions & 3 deletions srai/embedders/geovex/embedder.py
@@ -91,6 +91,7 @@ def transform(
regions_gdf: gpd.GeoDataFrame,
features_gdf: gpd.GeoDataFrame,
joint_gdf: gpd.GeoDataFrame,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame:
"""
Create region embeddings.
@@ -116,6 +117,7 @@ def transform(
neighbourhood,
self._batch_size,
shuffle=False,
dataloader_kwargs=dataloader_kwargs,
)

return self._transform(dataset=self._dataset, dataloader=dataloader)
@@ -124,9 +126,15 @@ def _transform(
self,
dataset: HexagonalDataset[T],
dataloader: Optional[DataLoader] = None,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame:
dataloader_kwargs = dataloader_kwargs or {}
if "batch_size" not in dataloader_kwargs:
dataloader_kwargs["batch_size"] = self._batch_size
if "shuffle" not in dataloader_kwargs:
dataloader_kwargs["shuffle"] = False
if dataloader is None:
dataloader = DataLoader(dataset, batch_size=self._batch_size, shuffle=False)
dataloader = DataLoader(dataset, **dataloader_kwargs)

embeddings = [
self._model.encoder(batch).detach().numpy() # type: ignore
@@ -149,6 +157,7 @@ def fit(
neighbourhood: H3Neighbourhood,
learning_rate: float = 0.001,
trainer_kwargs: Optional[dict[str, Any]] = None,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Fit the model to the data.
@@ -167,7 +176,13 @@

trainer_kwargs = self._prepare_trainer_kwargs(trainer_kwargs)
counts_df, dataloader, dataset = self._prepare_dataset( # type: ignore
regions_gdf, features_gdf, joint_gdf, neighbourhood, self._batch_size, shuffle=True
regions_gdf,
features_gdf,
joint_gdf,
neighbourhood,
self._batch_size,
shuffle=True,
dataloader_kwargs=dataloader_kwargs,
)

self._prepare_model(counts_df, learning_rate)
@@ -196,14 +211,20 @@ def _prepare_dataset(
neighbourhood: H3Neighbourhood,
batch_size: Optional[int],
shuffle: bool = True,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[pd.DataFrame, DataLoader, HexagonalDataset[T]]:
counts_df = self._get_raw_counts(regions_gdf, features_gdf, joint_gdf)
dataset: HexagonalDataset[T] = HexagonalDataset(
counts_df,
neighbourhood,
neighbor_k_ring=self._r,
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
dataloader_kwargs = dataloader_kwargs or {}
if "batch_size" not in dataloader_kwargs:
dataloader_kwargs["batch_size"] = batch_size
if "shuffle" not in dataloader_kwargs:
dataloader_kwargs["shuffle"] = shuffle
dataloader = DataLoader(dataset, **dataloader_kwargs)
return counts_df, dataloader, dataset

def fit_transform(
@@ -214,6 +235,7 @@ def fit_transform(
neighbourhood: H3Neighbourhood,
learning_rate: float = 0.001,
trainer_kwargs: Optional[dict[str, Any]] = None,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame:
"""
Fit the model to the data and create region embeddings.
@@ -236,6 +258,7 @@
neighbourhood=neighbourhood,
learning_rate=learning_rate,
trainer_kwargs=trainer_kwargs,
dataloader_kwargs=dataloader_kwargs,
)
assert self._dataset is not None # for mypy
return self._transform(dataset=self._dataset)
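
The `_transform` and `_prepare_dataset` changes above apply the same pattern: caller-supplied keys win, while `batch_size` and `shuffle` fall back to the embedder's defaults only when absent. A small runnable sketch of that behaviour (the `merge_dataloader_kwargs` helper is hypothetical; the embedder inlines these checks):

```python
# Hypothetical helper mirroring the inline checks in the diff above:
# caller-supplied keys take precedence, missing keys get the defaults.
from typing import Any, Optional


def merge_dataloader_kwargs(
    dataloader_kwargs: Optional[dict[str, Any]],
    batch_size: Optional[int],
    shuffle: bool,
) -> dict[str, Any]:
    merged = dict(dataloader_kwargs or {})
    merged.setdefault("batch_size", batch_size)
    merged.setdefault("shuffle", shuffle)
    return merged


print(merge_dataloader_kwargs(None, 32, shuffle=False))
# {'batch_size': 32, 'shuffle': False}
print(merge_dataloader_kwargs({"batch_size": 8, "num_workers": 2}, 32, shuffle=False))
# {'batch_size': 8, 'num_workers': 2, 'shuffle': False}
```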
19 changes: 15 additions & 4 deletions srai/embedders/gtfs2vec/embedder.py
@@ -83,6 +83,7 @@ def fit(
regions_gdf: gpd.GeoDataFrame,
features_gdf: gpd.GeoDataFrame,
joint_gdf: gpd.GeoDataFrame,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Fit model to a given data.
@@ -101,13 +102,14 @@
features = self._prepare_features(regions_gdf, features_gdf, joint_gdf)

if not self._skip_autoencoder:
self._model = self._train_model_unsupervised(features)
self._model = self._train_model_unsupervised(features, dataloader_kwargs)

def fit_transform(
self,
regions_gdf: gpd.GeoDataFrame,
features_gdf: gpd.GeoDataFrame,
joint_gdf: gpd.GeoDataFrame,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame:
"""
Fit model and transform a given data.
@@ -131,7 +133,7 @@ def fit_transform(
if self._skip_autoencoder:
return features
else:
self._model = self._train_model_unsupervised(features)
self._model = self._train_model_unsupervised(features, dataloader_kwargs)
return self._embed(features)

def _maybe_get_model(self) -> GTFS2VecModel:
@@ -228,7 +230,9 @@ def _normalize_features(self, features: pd.DataFrame) -> pd.DataFrame:

return features

def _train_model_unsupervised(self, features: pd.DataFrame) -> GTFS2VecModel:
def _train_model_unsupervised(
self, features: pd.DataFrame, dataloader_kwargs: Optional[dict[str, Any]] = None
) -> GTFS2VecModel:
"""
Train model unsupervised.

@@ -244,7 +248,14 @@ def _train_model_unsupervised(self, features: pd.DataFrame) -> GTFS2VecModel:
n_embed=self._embedding_size,
)
X = features.to_numpy().astype(np.float32)
x_dataloader = DataLoader(X, batch_size=24, shuffle=True, num_workers=4)
dataloader_kwargs = dataloader_kwargs or {}
if "num_workers" not in dataloader_kwargs:
dataloader_kwargs["num_workers"] = 4
if "batch_size" not in dataloader_kwargs:
dataloader_kwargs["batch_size"] = 24
if "shuffle" not in dataloader_kwargs:
dataloader_kwargs["shuffle"] = True
x_dataloader = DataLoader(X, **dataloader_kwargs)
trainer = pl.Trainer(max_epochs=10)

trainer.fit(model, x_dataloader)
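
With this change, the previously hard-coded settings (`batch_size=24`, `shuffle=True`, `num_workers=4`) become fallbacks only. A runnable stand-alone illustration of that fallback behaviour (the array and the caller kwargs below are made up for the demo):

```python
# Stand-alone demo of the fallback logic above; X and caller_kwargs are
# made-up inputs, the resulting behaviour matches the diff.
import numpy as np
from torch.utils.data import DataLoader

X = np.random.rand(100, 8).astype(np.float32)  # stand-in for the feature matrix

caller_kwargs = {"batch_size": 64, "num_workers": 0}  # hypothetical user input
dataloader_kwargs = dict(caller_kwargs)
dataloader_kwargs.setdefault("num_workers", 4)  # old default, used only if unset
dataloader_kwargs.setdefault("batch_size", 24)  # old default, used only if unset
dataloader_kwargs.setdefault("shuffle", True)   # old default, used only if unset

x_dataloader = DataLoader(X, **dataloader_kwargs)
print(len(x_dataloader))  # 2 batches: 64 + 36 samples
```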
11 changes: 10 additions & 1 deletion srai/embedders/hex2vec/embedder.py
@@ -103,6 +103,7 @@ def fit(
batch_size: int = 32,
learning_rate: float = 0.001,
trainer_kwargs: Optional[dict[str, Any]] = None,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""
Fit the model to the data.
@@ -141,7 +142,13 @@
layer_sizes=[num_features, *self._encoder_sizes], learning_rate=learning_rate
)
dataset = NeighbourDataset(counts_df, neighbourhood, negative_sample_k_distance)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
dataloader_kwargs = dataloader_kwargs or {}
if "batch_size" not in dataloader_kwargs:
dataloader_kwargs["batch_size"] = batch_size
if "shuffle" not in dataloader_kwargs:
dataloader_kwargs["shuffle"] = True

dataloader = DataLoader(dataset, **dataloader_kwargs)

trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(self._model, dataloader)
@@ -157,6 +164,7 @@ def fit_transform(
batch_size: int = 32,
learning_rate: float = 0.001,
trainer_kwargs: Optional[dict[str, Any]] = None,
dataloader_kwargs: Optional[dict[str, Any]] = None,
) -> pd.DataFrame:
"""
Fit the model to the data and return the embeddings.
@@ -192,6 +200,7 @@
batch_size,
learning_rate,
trainer_kwargs,
dataloader_kwargs,
)
return self.transform(regions_gdf, features_gdf, joint_gdf)
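
From the caller's side, the new parameter is simply forwarded to `torch.utils.data.DataLoader`. A hedged usage sketch: `Hex2VecEmbedder`, `H3Neighbourhood`, the import paths, and the pre-built `regions_gdf`/`features_gdf`/`joint_gdf` frames are assumptions from the wider srai API, not part of this diff; only the `dataloader_kwargs` keyword itself is introduced here.

```python
# Hypothetical caller-side sketch; only dataloader_kwargs comes from this PR.
from srai.embedders import Hex2VecEmbedder       # assumed import path
from srai.neighbourhoods import H3Neighbourhood  # assumed import path

# regions_gdf, features_gdf and joint_gdf are assumed to be pre-built
# GeoDataFrames from the usual regionalizer / loader / joiner pipeline.
embedder = Hex2VecEmbedder()  # assumed default constructor
embeddings = embedder.fit_transform(
    regions_gdf,
    features_gdf,
    joint_gdf,
    neighbourhood=H3Neighbourhood(regions_gdf),
    batch_size=32,                         # used unless overridden below
    trainer_kwargs={"max_epochs": 5},
    dataloader_kwargs={"num_workers": 4},  # e.g. parallel data loading
)
```

Passing `"batch_size"` or `"shuffle"` inside `dataloader_kwargs` overrides the corresponding embedder defaults, as per the checks added in this diff.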
