Skip to content

Commit

Permalink
Support for custom HeteroData mini-batch class in remote backends (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 9, 2023
1 parent 19d5cbc commit 3ea1d81
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added support for custom `HeteroData` mini-batch class in remote backends ([#6377](https://github.com/pyg-team/pytorch_geometric/pull/6377))
- Added the `GNNFF` model ([#5866](https://github.com/pyg-team/pytorch_geometric/pull/5866))
- Added `MLPAggregation`, `SetTransformerAggregation`, `GRUAggregation`, and `DeepSetsAggregation` as adaptive readout functions ([#6301](https://github.com/pyg-team/pytorch_geometric/pull/6301), [#6336](https://github.com/pyg-team/pytorch_geometric/pull/6336), [#6338](https://github.com/pyg-team/pytorch_geometric/pull/6338))
- Added `Dataset.to_datapipe` for converting PyG datasets into a torchdata `DataPipe`([#6141](https://github.com/pyg-team/pytorch_geometric/pull/6141))
Expand Down
7 changes: 6 additions & 1 deletion torch_geometric/loader/link_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ class LinkLoader(torch.utils.data.DataLoader):
(2) it may slown down data loading,
(3) it requires operating on CPU tensors.
(default: :obj:`False`)
custom_cls (HeteroData, optional): A custom
:class:`~torch_geometric.data.HeteroData` class to return for
mini-batches in case of remote backends. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
Expand All @@ -120,6 +123,7 @@ def __init__(
transform: Optional[Callable] = None,
transform_sampler_output: Optional[Callable] = None,
filter_per_worker: bool = False,
custom_cls: Optional[HeteroData] = None,
**kwargs,
):
# Remove for PyTorch Lightning:
Expand All @@ -142,6 +146,7 @@ def __init__(
self.transform = transform
self.transform_sampler_output = transform_sampler_output
self.filter_per_worker = filter_per_worker
self.custom_cls = custom_cls

if (self.neg_sampling is not None and self.neg_sampling.is_binary()
and edge_label is not None and edge_label.min() == 0):
Expand Down Expand Up @@ -220,7 +225,7 @@ def filter_fn(
self.link_sampler.edge_permutation)
else: # Tuple[FeatureStore, GraphStore]
data = filter_custom_store(*self.data, out.node, out.row,
out.col, out.edge)
out.col, out.edge, self.custom_cls)

for key, batch in (out.batch or {}).items():
data[key].batch = batch
Expand Down
7 changes: 6 additions & 1 deletion torch_geometric/loader/node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class NodeLoader(torch.utils.data.DataLoader):
(2) it may slown down data loading,
(3) it requires operating on CPU tensors.
(default: :obj:`False`)
custom_cls (HeteroData, optional): A custom
:class:`~torch_geometric.data.HeteroData` class to return for
mini-batches in case of remote backends. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,
:obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.
Expand All @@ -81,6 +84,7 @@ def __init__(
transform: Optional[Callable] = None,
transform_sampler_output: Optional[Callable] = None,
filter_per_worker: bool = False,
custom_cls: Optional[HeteroData] = None,
**kwargs,
):
# Remove for PyTorch Lightning:
Expand All @@ -97,6 +101,7 @@ def __init__(
self.transform = transform
self.transform_sampler_output = transform_sampler_output
self.filter_per_worker = filter_per_worker
self.custom_cls = custom_cls

self.input_data = NodeSamplerInput(
input_id=None,
Expand Down Expand Up @@ -153,7 +158,7 @@ def filter_fn(
self.node_sampler.edge_permutation)
else: # Tuple[FeatureStore, GraphStore]
data = filter_custom_store(*self.data, out.node, out.row,
out.col, out.edge)
out.col, out.edge, self.custom_cls)

for key, batch in (out.batch or {}).items():
data[key].batch = batch
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,14 @@ def filter_custom_store(
row_dict: Dict[str, Tensor],
col_dict: Dict[str, Tensor],
edge_dict: Dict[str, Tensor],
custom_cls: Optional[HeteroData] = None,
) -> HeteroData:
r"""Constructs a `HeteroData` object from a feature store that only holds
nodes in `node` end edges in `edge` for each node and edge type,
respectively."""

# Construct a new `HeteroData` object:
data = HeteroData()
data = custom_cls() if custom_cls is not None else HeteroData()

# Filter edge storage:
# TODO support edge attributes
Expand Down

0 comments on commit 3ea1d81

Please sign in to comment.