Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 21, 2023
1 parent 37e1372 commit a18c401
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 233 deletions.
107 changes: 47 additions & 60 deletions test/nn/test_to_hetero_module.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,59 @@
import pytest # noqa
import torch

from torch_geometric.data import HeteroData
from torch_geometric.nn.conv import GCNConv
from torch_geometric.nn.conv import SAGEConv
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.to_hetero_module import ToHeteroModule
from torch_geometric.nn.to_hetero_module import (
ToHeteroLinear,
ToHeteroMessagePassing,
)

heterodata = HeteroData()
heterodata['v0'].x = torch.randn(20, 10)
heterodata['v1'].x = torch.randn(20, 10)
heterodata[('v0', 'r0',
'v0')].edge_index = torch.randint(high=20,
size=(2, 30)).to(torch.long)
heterodata[('v0', 'r2',
'v1')].edge_index = torch.randint(high=20,
size=(2, 30)).to(torch.long)
heterodata[('v1', 'r3',
'v0')].edge_index = torch.randint(high=20,
size=(2, 30)).to(torch.long)
heterodata[('v1', 'r4', 'v1')] = torch.randn(2, 50)

@pytest.mark.parametrize('LinearCls', [torch.nn.Linear, Linear])
def test_to_hetero_linear(LinearCls):
x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)}
x = torch.cat([x_dict['1'], x_dict['2']], dim=0)
type_vec = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])

def test_to_hetero_linear():
lin = Linear(10, 5)
heterolin = ToHeteroModule(lin, heterodata.metadata())
# test dict input
x_dict = heterodata.collect('x')
out = heterolin(x_dict)
assert out['v0'].shape == (20, 5)
assert out['v1'].shape == (20, 5)
# test fused input
x = torch.cat([x_j for x_j in x_dict.values()])
node_type = torch.cat([(j * torch.ones(x_j.shape[0])).long()
for j, x_j in enumerate(x_dict.values())])
out = heterolin(x=x, node_type=node_type)
assert out.shape == (40, 5)
module = ToHeteroLinear(LinearCls(16, 32), list(x_dict.keys()))

out_dict = module(x_dict)
assert len(out_dict) == 2
assert out_dict['1'].size() == (5, 32)
assert out_dict['2'].size() == (4, 32)

def test_to_hetero_gcn():
gcnconv = GCNConv(10, 5)
rgcnconv = ToHeteroModule(gcnconv, heterodata.metadata())
# test dict input
x_dict = heterodata.collect('x')
e_idx_dict = heterodata.collect('edge_index')
out = rgcnconv(x_dict, edge_index=e_idx_dict)
assert out['v0'].shape == (20, 5)
assert out['v1'].shape == (20, 5)
out = module(x, type_vec)
assert out.size() == (9, 32)

x = torch.cat(list(x_dict.values()), dim=0)
assert torch.allclose(out_dict['1'], out[0:5])
assert torch.allclose(out_dict['2'], out[5:9])

num_node_dict = heterodata.collect('num_nodes')
increment_dict = {}
ctr = 0
for node_type in num_node_dict:
increment_dict[node_type] = ctr
ctr += num_node_dict[node_type]

etypes_list = []
for i, e_type in enumerate(e_idx_dict.keys()):
src_type, dst_type = e_type[0], e_type[-1]
if torch.numel(e_idx_dict[e_type]) != 0:
e_idx_dict[e_type][
0, :] = e_idx_dict[e_type][0, :] + increment_dict[src_type]
e_idx_dict[e_type][
1, :] = e_idx_dict[e_type][1, :] + increment_dict[dst_type]
etypes_list.append(torch.ones(e_idx_dict[e_type].shape[-1]) * i)
edge_type = torch.cat(etypes_list).to(torch.long)
edge_index = torch.cat(list(e_idx_dict.values()), dim=1)
# test fused input
out = rgcnconv(x, edge_index=edge_index, edge_type=edge_type)
assert out.shape == (40, 5)
def test_to_hetero_message_passing():
x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)}
x = torch.cat([x_dict['1'], x_dict['2']], dim=0)
node_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])

edge_index_dict = {
('1', 'to', '2'): torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 3]]),
('2', 'to', '1'): torch.tensor([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]]),
}
edge_index = torch.tensor([
[0, 1, 2, 3, 4, 5, 5, 6, 7, 8],
[5, 5, 6, 7, 8, 0, 1, 2, 3, 4],
])
edge_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

module = ToHeteroMessagePassing(SAGEConv(16, 32), list(x_dict.keys()),
list(edge_index_dict.keys()))

out_dict = module(x_dict, edge_index_dict)
assert len(out_dict) == 2
assert out_dict['1'].size() == (5, 32)
assert out_dict['2'].size() == (4, 32)

out = module(x, edge_index, node_type, edge_type)
assert out.size() == (9, 32)

assert torch.allclose(out_dict['1'], out[0:5])
assert torch.allclose(out_dict['2'], out[5:9])
Loading

0 comments on commit a18c401

Please sign in to comment.