Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temporal sampling support to NeighborLoader #4025

Merged
18 commits were merged on Apr 25, 2022
Merged
14 changes: 12 additions & 2 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def test_homogeneous_neighbor_loader(directed):


@pytest.mark.parametrize('directed', [True, False])
def test_heterogeneous_neighbor_loader(directed):
@pytest.mark.parametrize('temporal', [True, False])
def test_heterogeneous_neighbor_loader(directed, temporal):
torch.manual_seed(12345)

data = HeteroData()
Expand All @@ -88,9 +89,18 @@ def test_heterogeneous_neighbor_loader(directed):
value=torch.arange(2500),
)

# timestamps of nodes
if temporal:
node_time = {}
node_time['paper'] = torch.tensor(np.arange(100))
node_time['author'] = torch.tensor(np.arange(300))
else:
node_time = None

batch_size = 20
loader = NeighborLoader(data, num_neighbors=[10] * 2, input_nodes='paper',
batch_size=batch_size, directed=directed)
batch_size=batch_size, directed=directed,
node_time=node_time)
assert str(loader) == 'NeighborLoader()'
assert len(loader) == (100 + batch_size - 1) // batch_size

Expand Down
47 changes: 33 additions & 14 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch_geometric.loader.utils import (edge_type_to_str, filter_data,
filter_hetero_data, to_csc,
to_hetero_csc)
from torch_geometric.typing import EdgeType, InputNodes
from torch_geometric.typing import EdgeType, InputNodes, NodeType

NumNeighbors = Union[List[int], Dict[EdgeType, List[int]]]

Expand All @@ -22,11 +22,13 @@ def __init__(
replace: bool = False,
directed: bool = True,
input_node_type: Optional[str] = None,
node_time: Optional[Dict[NodeType, Tensor]] = None,
):
self.data_cls = data.__class__
self.num_neighbors = num_neighbors
self.replace = replace
self.directed = directed
self.node_time = node_time

if isinstance(data, Data):
# Convert the graph data into a suitable format for sampling.
Expand Down Expand Up @@ -75,18 +77,34 @@ def __call__(self, indices: List[int]):
return node, row, col, edge, index.numel()

elif issubclass(self.data_cls, HeteroData):
sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample
node_dict, row_dict, col_dict, edge_dict = sample_fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
{self.input_node_type: index},
self.num_neighbors,
self.num_hops,
self.replace,
self.directed,
)
if self.node_time is None:
sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample
node_dict, row_dict, col_dict, edge_dict = sample_fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
{self.input_node_type: index},
self.num_neighbors,
self.num_hops,
self.replace,
self.directed,
)
else:
sample_fn = \
torch.ops.torch_sparse.hetero_neighbor_temporal_sample
node_dict, row_dict, col_dict, edge_dict = sample_fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
{self.input_node_type: index},
self.num_neighbors,
self.node_time,
self.num_hops,
self.replace,
self.directed,
)
return node_dict, row_dict, col_dict, edge_dict, index.numel()


Expand Down Expand Up @@ -207,6 +225,7 @@ def __init__(
directed: bool = True,
transform: Callable = None,
neighbor_sampler: Optional[NeighborSampler] = None,
node_time: Optional[Dict[NodeType, Tensor]] = None,
**kwargs,
):
if 'dataset' in kwargs:
Expand All @@ -227,7 +246,7 @@ def __init__(
input_node_type = get_input_node_type(input_nodes)
self.neighbor_sampler = NeighborSampler(data, num_neighbors,
replace, directed,
input_node_type)
input_node_type, node_time)

return super().__init__(get_input_node_indices(self.data, input_nodes),
collate_fn=self.neighbor_sampler, **kwargs)
Expand Down