Skip to content

Commit

Permalink
Rewrite tests to not depend on currently broken dblp dataset (#5250)
Browse files Browse the repository at this point in the history
* Rewrite tests to not depend on currently broken dblp dataset

* update changelog

* Update test/data/test_lightning_datamodule.py

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* Update CHANGELOG.md

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* add skip for broken test

* add mark in skip

* add mark in skip

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
Padarn and rusty1s authored Aug 20, 2022
1 parent 39c7e88 commit 3d5f855
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))
- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))
### Changed
- Changed tests relying on `dblp` datasets to instead use synthetic data. ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250))
### Removed

## [2.1.0] - 2022-08-17
Expand Down
23 changes: 19 additions & 4 deletions test/data/test_lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F

from torch_geometric.data import (
HeteroData,
LightningDataset,
LightningLinkData,
LightningNodeData,
Expand All @@ -18,6 +19,12 @@
LightningModule = torch.nn.Module


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.long)
return torch.stack([row, col], dim=0)


class LinearGraphModule(LightningModule):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
Expand Down Expand Up @@ -273,10 +280,18 @@ def test_lightning_hetero_node_data(get_dataset):
@withCUDA
@onlyFullTest
@withPackage('pytorch_lightning')
def test_lightning_hetero_link_data(get_dataset):
# TODO: Add more datasets.
dataset = get_dataset(name='DBLP')
data = dataset[0]
def test_lightning_hetero_link_data():
torch.manual_seed(12345)

data = HeteroData()

data['paper'].x = torch.arange(10)
data['author'].x = torch.arange(10)
data['term'].x = torch.arange(10)

data['paper', 'author'].edge_index = get_edge_index(10, 10, 10)
data['author', 'paper'].edge_index = get_edge_index(10, 10, 10)
data['paper', 'term'].edge_index = get_edge_index(10, 10, 10)

datamodule = LightningLinkData(
data,
Expand Down
5 changes: 4 additions & 1 deletion test/loader/test_hgt_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import torch
from torch_sparse import SparseTensor

Expand Down Expand Up @@ -56,8 +57,9 @@ def test_hgt_loader():
for batch in loader:
assert isinstance(batch, HeteroData)

# Test node type selection:
# Test node and types:
assert set(batch.node_types) == {'paper', 'author'}
assert set(batch.edge_types) == set(data.edge_types)

assert len(batch['paper']) == 2
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
Expand Down Expand Up @@ -177,6 +179,7 @@ def forward(self, x, edge_index, edge_weight):
assert torch.allclose(out1, out2, atol=1e-6)


@pytest.mark.skip("'dblp' dataset is broken")
@withPackage('torch_sparse>=0.6.15')
def test_hgt_loader_on_dblp(get_dataset):
data = get_dataset(name='dblp')[0]
Expand Down

0 comments on commit 3d5f855

Please sign in to comment.