Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite tests to not depend on currently broken dblp dataset #5250

Merged
merged 8 commits into from
Aug 20, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Rewrite tests to not depend on currently broken dblp dataset
  • Loading branch information
Padarn committed Aug 20, 2022
commit 2034f85c77397dc3978803e2b357ebe18ce07422
21 changes: 18 additions & 3 deletions test/data/test_lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F

from torch_geometric.data import (
HeteroData,
LightningDataset,
LightningLinkData,
LightningNodeData,
Expand All @@ -18,6 +19,12 @@
LightningModule = torch.nn.Module


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


class LinearGraphModule(LightningModule):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
Expand Down Expand Up @@ -274,9 +281,17 @@ def test_lightning_hetero_node_data(get_dataset):
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data(get_dataset):
# TODO: Add more datasets.
dataset = get_dataset(name='DBLP')
data = dataset[0]
torch.manual_seed(12345)

data = HeteroData()

data['paper'].x = torch.arange(10)
data['author'].x = torch.arange(10)
data['term'].x = torch.arange(10)

data['paper', 'author'].edge_index = get_edge_index(10, 10, 10)
data['author', 'paper'].edge_index = get_edge_index(10, 10, 10)
data['paper', 'term'].edge_index = get_edge_index(10, 10, 10)

datamodule = LightningLinkData(
data,
Expand Down
14 changes: 2 additions & 12 deletions test/loader/test_hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch_geometric.data import HeteroData
from torch_geometric.loader import HGTLoader
from torch_geometric.nn import GraphConv, to_hetero
from torch_geometric.testing import withPackage
from torch_geometric.utils import k_hop_subgraph


Expand Down Expand Up @@ -56,8 +55,9 @@ def test_hgt_loader():
for batch in loader:
assert isinstance(batch, HeteroData)

# Test node type selection:
# Test node and types:
assert set(batch.node_types) == {'paper', 'author'}
assert set(batch.edge_types) == set(data.edge_types)

assert len(batch['paper']) == 2
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
Expand Down Expand Up @@ -175,13 +175,3 @@ def forward(self, x, edge_index, edge_weight):
out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
hetero_batch.edge_weight_dict)['paper'][:batch_size]
assert torch.allclose(out1, out2, atol=1e-6)


@withPackage('torch_sparse>=0.6.15')
def test_hgt_loader_on_dblp(get_dataset):
data = get_dataset(name='dblp')[0]
loader = HGTLoader(data, num_samples=[10, 10],
input_nodes=('author', data['author'].train_mask))

for batch in loader:
assert set(batch.edge_types) == set(data.edge_types)