Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
EdisonLeeeee committed Jan 27, 2023
1 parent cb7002f commit 216fa11
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 37 deletions.
16 changes: 8 additions & 8 deletions test/nn/conv/test_gin_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def test_gin_conv():
'))')
out = conv(x1, edge_index)
assert out.size() == (4, 32)
assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
assert conv(x1, adj.t()).tolist() == out.tolist()
assert conv(x1, adj2.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, Size) -> Tensor'
Expand All @@ -46,11 +46,11 @@ def test_gin_conv():
out2 = conv((x1, None), edge_index, (4, 2))
assert out1.size() == (2, 32)
assert out2.size() == (2, 32)
assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
assert conv((x1, x2), adj2.t()).tolist() == out1.tolist()
assert conv((x1, None), adj2.t()).tolist() == out2.tolist()
assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

if is_full_test():
t = '(OptPairTensor, Tensor, Size) -> Tensor'
Expand Down
27 changes: 15 additions & 12 deletions test/nn/conv/test_graph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ def test_graph_conv():
assert conv.__repr__() == 'GraphConv(8, 32)'
out11 = conv(x1, edge_index)
assert out11.size() == (4, 32)
assert conv(x1, edge_index, size=(4, 4)).tolist() == out11.tolist()
assert conv(x1, adj1.t()).tolist() == out11.tolist()
assert conv(x1, adj3.t()).tolist() == out11.tolist()
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out11, atol=1e-6)
assert torch.allclose(conv(x1, adj1.t()), out11, atol=1e-6)
assert torch.allclose(conv(x1, adj3.t()), out11, atol=1e-6)

out12 = conv(x1, edge_index, value)
assert out12.size() == (4, 32)
assert conv(x1, edge_index, value, size=(4, 4)).tolist() == out12.tolist()
assert conv(x1, adj2.t()).tolist() == out12.tolist()
assert conv(x1, adj4.t()).tolist() == out12.tolist()
assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out12,
atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out12, atol=1e-6)
assert torch.allclose(conv(x1, adj4.t()), out12, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, OptTensor, Size) -> Tensor'
Expand Down Expand Up @@ -58,12 +59,14 @@ def test_graph_conv():
assert out22.size() == (2, 32)
assert out23.size() == (2, 32)
assert out24.size() == (2, 32)
assert conv((x1, x2), edge_index, size=(4, 2)).tolist() == out21.tolist()
assert conv((x1, x2), edge_index, value, (4, 2)).tolist() == out22.tolist()
assert conv((x1, x2), adj1.t()).tolist() == out21.tolist()
assert conv((x1, x2), adj2.t()).tolist() == out22.tolist()
assert conv((x1, x2), adj3.t()).tolist() == out21.tolist()
assert conv((x1, x2), adj4.t()).tolist() == out22.tolist()
assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out21,
atol=1e-6)
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out22,
atol=1e-6)
assert torch.allclose(conv((x1, x2), adj1.t()), out21, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj2.t()), out22, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj3.t()), out21, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj4.t()), out22, atol=1e-6)

if is_full_test():
t = '(OptPairTensor, Tensor, OptTensor, Size) -> Tensor'
Expand Down
35 changes: 18 additions & 17 deletions test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@ def test_sage_conv(project, aggr):
assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})'
out = conv(x1, edge_index)
assert out.size() == (4, 32)
assert conv(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
assert conv(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
if aggr == 'sum':
assert conv(x1, adj2.t()).tolist() == out.tolist()
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, edge_index).tolist() == out.tolist()
assert jit(x1, edge_index, size=(4, 4)).tolist() == out.tolist()
assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out, atol=1e-6)

t = '(Tensor, SparseTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit(x1, adj.t()).tolist() == out.tolist()
assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)

adj = adj.sparse_resize((4, 2))
adj2 = adj.to_torch_sparse_coo_tensor()
Expand All @@ -43,25 +43,26 @@ def test_sage_conv(project, aggr):
out2 = conv((x1, None), edge_index, (4, 2))
assert out1.size() == (2, 32)
assert out2.size() == (2, 32)
assert conv((x1, x2), edge_index, (4, 2)).tolist() == out1.tolist()
assert conv((x1, x2), adj.t()).tolist() == out1.tolist()
assert conv((x1, None), adj.t()).tolist() == out2.tolist()
assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
if aggr == 'sum':
assert conv((x1, x2), adj2.t()).tolist() == out1.tolist()
assert conv((x1, None), adj2.t()).tolist() == out2.tolist()
assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

if is_full_test():
t = '(OptPairTensor, Tensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), edge_index).tolist() == out1.tolist()
assert jit((x1, x2), edge_index, size=(4, 2)).tolist() == out1.tolist()
assert jit((x1, None), edge_index,
size=(4, 2)).tolist() == out2.tolist()
assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6)
assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1,
atol=1e-6)
assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2,
atol=1e-6)

t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert jit((x1, x2), adj.t()).tolist() == out1.tolist()
assert jit((x1, None), adj.t()).tolist() == out2.tolist()
assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)


def test_lstm_aggr_sage_conv():
Expand Down

0 comments on commit 216fa11

Please sign in to comment.