Skip to content

Commit

Permalink
Fix LinkNeighborLoader in case src_node_type = dst_node_type (#4439)
Browse files Browse the repository at this point in the history
* fix hetro sampler when start = end node type

* merge

* remove edge_attr

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
Padarn and rusty1s authored Apr 9, 2022
1 parent 3f0019f commit 478db99
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
36 changes: 31 additions & 5 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ def test_homogeneous_link_neighbor_loader(directed):
assert batch.edge_attr.min() >= 0
assert batch.edge_attr.max() < 500

# Assert positive samples were present in the original graph:
# Assert positive samples are present in the original graph:
edge_index = unique_edge_pairs(batch.edge_index)
edge_label_index = batch.edge_label_index[:, batch.edge_label == 1]
edge_label_index = unique_edge_pairs(edge_label_index)
assert len(edge_index | edge_label_index) == len(edge_index)

# Assert negative samples were not present in the original graph:
# Assert negative samples are not present in the original graph:
edge_index = unique_edge_pairs(batch.edge_index)
edge_label_index = batch.edge_label_index[:, batch.edge_label == 0]
edge_label_index = unique_edge_pairs(edge_label_index)
Expand Down Expand Up @@ -89,12 +89,38 @@ def test_heterogeneous_link_neighbor_loader(directed):

for batch in loader:
assert isinstance(batch, HeteroData)
print(batch)

assert len(batch) == 4

# Assert positive samples were present in the original graph:
# Assert positive samples are present in the original graph:
edge_index = unique_edge_pairs(batch['paper', 'author'].edge_index)
edge_label_index = batch['paper', 'author'].edge_label_index
edge_label_index = unique_edge_pairs(edge_label_index)
assert len(edge_index | edge_label_index) == len(edge_index)


@pytest.mark.parametrize('directed', [True, False])
def test_heterogeneous_link_neighbor_loader_loop(directed):
torch.manual_seed(12345)

data = HeteroData()

data['paper'].x = torch.arange(100)
data['author'].x = torch.arange(100, 300)

data['paper', 'paper'].edge_index = get_edge_index(100, 100, 500)
data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)

loader = LinkNeighborLoader(data, num_neighbors=[-1] * 2,
edge_label_index=('paper', 'to', 'paper'),
batch_size=20, directed=directed)

for batch in loader:
assert batch['paper'].x.size(0) <= 100
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100

# Assert positive samples are present in the original graph:
edge_index = unique_edge_pairs(batch['paper', 'paper'].edge_index)
edge_label_index = batch['paper', 'paper'].edge_label_index
edge_label_index = unique_edge_pairs(edge_label_index)
assert len(edge_index | edge_label_index) == len(edge_index)
33 changes: 21 additions & 12 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __call__(self, query: List[Tuple[Tensor]]):

query_nodes = edge_label_index.view(-1)
query_nodes, reverse = query_nodes.unique(return_inverse=True)
edge_label_index = reverse.view(2, -1)

node, row, col, edge = sample_fn(
self.colptr,
Expand All @@ -35,33 +36,41 @@ def __call__(self, query: List[Tuple[Tensor]]):
self.directed,
)

return node, row, col, edge, reverse.view(2, -1), edge_label
return node, row, col, edge, edge_label_index, edge_label

elif issubclass(self.data_cls, HeteroData):
sample_fn = torch.ops.torch_sparse.hetero_neighbor_sample

query_src = edge_label_index[0]
query_src, reverse_src = query_src.unique(return_inverse=True)

query_dst = edge_label_index[1]
query_dst, reverse_dst = query_dst.unique(return_inverse=True)
if self.input_type[0] != self.input_type[-1]:
query_src = edge_label_index[0]
query_src, reverse_src = query_src.unique(return_inverse=True)
query_dst = edge_label_index[1]
query_dst, reverse_dst = query_dst.unique(return_inverse=True)
edge_label_index = torch.stack([reverse_src, reverse_dst], 0)
query_node_dict = {
self.input_type[0]: query_src,
self.input_type[-1]: query_dst,
}
else: # Merge both source and destination node indices:
query_nodes = edge_label_index.view(-1)
query_nodes, reverse = query_nodes.unique(return_inverse=True)
edge_label_index = reverse.view(2, -1)
query_node_dict = {self.input_type[0]: query_nodes}

node_dict, row_dict, col_dict, edge_dict = sample_fn(
self.node_types,
self.edge_types,
self.colptr_dict,
self.row_dict,
{
self.input_type[0]: query_src,
self.input_type[-1]: query_dst,
},
query_node_dict,
self.num_neighbors,
self.num_hops,
self.replace,
self.directed,
)
return (node_dict, row_dict, col_dict, edge_dict,
torch.stack([reverse_src, reverse_dst], dim=0), edge_label)

return (node_dict, row_dict, col_dict, edge_dict, edge_label_index,
edge_label)


class LinkNeighborLoader(torch.utils.data.DataLoader):
Expand Down

0 comments on commit 478db99

Please sign in to comment.