Skip to content

Commit

Permalink
Link level NeighborLoader (#4396)
Browse files Browse the repository at this point in the history
* link level loader wip

* add target output

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update torch_geometric/utils/mask.py

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Update torch_geometric/loader/link_neighbour_loader.py

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* link level loader wip

* add target output

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update torch_geometric/utils/mask.py

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Update torch_geometric/loader/link_neighbour_loader.py

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* fix from review

* fix import

* undo change

* undo change

* handle different edge_input type

* simplify formatting

* fix pre-commit

* remove american spelling

* sort

* format

* format

* add support hetro data

* data test

* add test file

* update tests

* remove asserts

* fix doc

* update

* update

* typo

* fix doc

* fix test

* update

* update doc

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Apr 8, 2022
1 parent 2017d47 commit 9bcb946
Show file tree
Hide file tree
Showing 9 changed files with 414 additions and 48 deletions.
100 changes: 100 additions & 0 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import LinkNeighborLoader


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


def unique_edge_pairs(edge_index):
return set(map(tuple, edge_index.t().tolist()))


@pytest.mark.parametrize('directed', [True, False])
def test_homogeneous_link_neighbor_loader(directed):
torch.manual_seed(12345)

pos_edge_index = get_edge_index(100, 50, 500)
neg_edge_index = get_edge_index(100, 50, 500)
neg_edge_index[1, :] += 50

edge_label_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
edge_label = torch.cat([torch.ones(500), torch.zeros(500)], dim=0)

data = Data()

data.edge_index = pos_edge_index
data.x = torch.arange(100)
data.edge_attr = torch.arange(500)

loader = LinkNeighborLoader(data, num_neighbors=[-1] * 2, batch_size=20,
edge_label_index=edge_label_index,
edge_label=edge_label, directed=directed,
shuffle=True)

assert str(loader) == 'LinkNeighborLoader()'
assert len(loader) == 1000 / 20

for batch in loader:
assert isinstance(batch, Data)

assert len(batch) == 5
assert batch.x.size(0) <= 100
assert batch.x.min() >= 0 and batch.x.max() < 100
assert batch.edge_index.min() >= 0
assert batch.edge_index.max() < batch.num_nodes
assert batch.edge_attr.min() >= 0
assert batch.edge_attr.max() < 500

# Assert positive samples were present in the original graph:
edge_index = unique_edge_pairs(batch.edge_index)
edge_label_index = batch.edge_label_index[:, batch.edge_label == 1]
edge_label_index = unique_edge_pairs(edge_label_index)
assert len(edge_index | edge_label_index) == len(edge_index)

# Assert negative samples were not present in the original graph:
edge_index = unique_edge_pairs(batch.edge_index)
edge_label_index = batch.edge_label_index[:, batch.edge_label == 0]
edge_label_index = unique_edge_pairs(edge_label_index)
assert len(edge_index & edge_label_index) == 0


@pytest.mark.parametrize('directed', [True, False])
def test_heterogeneous_link_neighbor_loader(directed):
torch.manual_seed(12345)

data = HeteroData()

data['paper'].x = torch.arange(100)
data['author'].x = torch.arange(100, 300)

data['paper', 'paper'].edge_index = get_edge_index(100, 100, 500)
data['paper', 'paper'].edge_attr = torch.arange(500)
data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
data['paper', 'author'].edge_attr = torch.arange(500, 1500)
data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)
data['author', 'paper'].edge_attr = torch.arange(1500, 2500)

loader = LinkNeighborLoader(data, num_neighbors=[-1] * 2,
edge_label_index=('paper', 'to', 'author'),
batch_size=20, directed=directed, shuffle=True)

assert str(loader) == 'LinkNeighborLoader()'
assert len(loader) == int(1000 / 20)

for batch in loader:
assert isinstance(batch, HeteroData)
print(batch)

assert len(batch) == 4

# Assert positive samples were present in the original graph:
edge_index = unique_edge_pairs(batch['paper', 'author'].edge_index)
edge_label_index = batch['paper', 'author'].edge_label_index
edge_label_index = unique_edge_pairs(edge_label_index)
assert len(edge_index | edge_label_index) == len(edge_index)
3 changes: 0 additions & 3 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def test_homogeneous_neighbor_loader(directed):

assert is_subset(batch.edge_index, data.edge_index, batch.x, batch.x)

# Test for isolated nodes (there shouldn't exist any):
assert data.edge_index.view(-1).unique().numel() == data.num_nodes


@pytest.mark.parametrize('directed', [True, False])
def test_heterogeneous_neighbor_loader(directed):
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch_geometric.loader.neighbor_loader import (
NeighborLoader,
NeighborSampler,
get_input_node_type,
get_input_nodes,
)
from torch_geometric.typing import InputNodes

Expand Down Expand Up @@ -272,7 +272,7 @@ def __init__(
num_neighbors=kwargs.get('num_neighbors', None),
replace=kwargs.get('replace', False),
directed=kwargs.get('directed', True),
input_node_type=get_input_node_type(input_train_nodes),
input_type=get_input_nodes(data, input_train_nodes)[0],
)
self.input_train_nodes = input_train_nodes
self.input_val_nodes = input_val_nodes
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/loader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .dataloader import DataLoader
from .neighbor_loader import NeighborLoader
from .link_neighbor_loader import LinkNeighborLoader
from .hgt_loader import HGTLoader
from .cluster import ClusterData, ClusterLoader
from .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler,
Expand All @@ -15,6 +16,7 @@
__all__ = classes = [
'DataLoader',
'NeighborLoader',
'LinkNeighborLoader',
'HGTLoader',
'ClusterData',
'ClusterLoader',
Expand Down
Loading

0 comments on commit 9bcb946

Please sign in to comment.