Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Subgraph Methods #6613

Merged
merged 13 commits into from
Feb 7, 2023
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 7, 2023
commit 245b8b4b5b53c3bd4b3f5a6f62c10e656e77281b
29 changes: 13 additions & 16 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,18 @@ def test_hetero_data_subgraph():
assert out.node_types == ['paper', 'author', 'conf']

assert len(out['paper']) == 3
assert torch.allclose(
out['paper'].x,
data['paper'].x[subset_sorted['paper']]
)
assert torch.allclose(out['paper'].x,
data['paper'].x[subset_sorted['paper']])
assert out['paper'].name == 'paper'
assert out['paper'].num_nodes == 4

assert len(out['author']) == 2
assert torch.allclose(
out['author'].x,
data['author'].x[subset_sorted['author']]
)
assert torch.allclose(out['author'].x,
data['author'].x[subset_sorted['author']])
assert out['author'].num_nodes == 2

assert len(out['conf']) == 2
assert torch.allclose(
out['conf'].x,
data['conf'].x[subset_sorted['conf']]
)
assert torch.allclose(out['conf'].x, data['conf'].x[subset_sorted['conf']])
assert out['conf'].num_nodes == 2

# construct correct edge index manually
Expand All @@ -242,7 +235,8 @@ def test_hetero_data_subgraph():

edge_mask_paper_paper = paper_mask[edge_index_paper_paper[0]]
edge_mask_paper_paper &= paper_mask[edge_index_paper_paper[1]]
sub_edge_index_paper_paper = edge_index_paper_paper[:, edge_mask_paper_paper]
sub_edge_index_paper_paper = edge_index_paper_paper[:,
edge_mask_paper_paper]
sub_edge_index_paper_paper[0] = paper_node_map[
sub_edge_index_paper_paper[0]]
sub_edge_index_paper_paper[1] = paper_node_map[
Expand All @@ -251,23 +245,26 @@ def test_hetero_data_subgraph():

edge_mask_paper_author = paper_mask[edge_index_paper_author[0]]
edge_mask_paper_author &= author_mask[edge_index_paper_author[1]]
sub_edge_index_paper_author = edge_index_paper_author[:, edge_mask_paper_author]
sub_edge_index_paper_author = edge_index_paper_author[:,
edge_mask_paper_author]
sub_edge_index_paper_author[0] = paper_node_map[
sub_edge_index_paper_author[0]]
sub_edge_index_paper_author[1] = author_node_map[
sub_edge_index_paper_author[1]]

edge_mask_author_paper = author_mask[edge_index_author_paper[0]]
edge_mask_author_paper &= paper_mask[edge_index_author_paper[1]]
sub_edge_index_author_paper = edge_index_author_paper[:, edge_mask_author_paper]
sub_edge_index_author_paper = edge_index_author_paper[:,
edge_mask_author_paper]
sub_edge_index_author_paper[0] = author_node_map[
sub_edge_index_author_paper[0]]
sub_edge_index_author_paper[1] = paper_node_map[
sub_edge_index_author_paper[1]]

edge_mask_paper_conf = paper_mask[edge_index_paper_conference[0]]
edge_mask_paper_conf &= conf_mask[edge_index_paper_conference[1]]
sub_edge_index_paper_conf = edge_index_paper_conference[:, edge_mask_paper_conf]
sub_edge_index_paper_conf = edge_index_paper_conference[:,
edge_mask_paper_conf]
sub_edge_index_paper_conf[0] = paper_node_map[sub_edge_index_paper_conf[0]]
sub_edge_index_paper_conf[1] = conf_node_map[sub_edge_index_paper_conf[1]]

Expand Down