PyTorch sparse tensor support: ClusterGCN, SAGEConv, AGNNConv, APPNP, and FeaStConv (#6874)

Merged: 10 commits, Mar 8, 2023
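In short, this PR lets `ClusterGCNConv`, `SAGEConv`, `AGNNConv`, `APPNP`, and `FeaStConv` consume adjacency matrices given as native `torch.sparse` COO tensors, mirroring the existing `torch_sparse.SparseTensor` path. A minimal usage sketch (layer and shapes are illustrative, assuming a PyG build that includes this PR):

```python
import torch

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import to_torch_coo_tensor

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Build a torch.sparse COO adjacency; as with torch_sparse.SparseTensor,
# message passing expects the transposed matrix:
adj = to_torch_coo_tensor(edge_index, size=(4, 4))

conv = SAGEConv(8, 32)
out = conv(x, adj.t())
assert out.size() == (4, 32)
```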
CHANGELOG.md (1 addition, 1 deletion)

@@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))
 - Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))
 - Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
-- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868))
+- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874))
 - Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
test/nn/conv/test_agnn_conv.py (5 additions, 3 deletions)

@@ -11,13 +11,15 @@ def test_agnn_conv(requires_grad):
     x = torch.randn(4, 16)
     edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
     row, col = edge_index
-    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj1.to_torch_sparse_coo_tensor()

     conv = AGNNConv(requires_grad=requires_grad)
     assert str(conv) == 'AGNNConv()'
     out = conv(x, edge_index)
     assert out.size() == (4, 16)
-    assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6)

     if is_full_test():
         t = '(Tensor, Tensor) -> Tensor'
@@ -26,4 +28,4 @@ def test_agnn_conv(requires_grad):

         t = '(Tensor, SparseTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)
+        assert torch.allclose(jit(x, adj1.t()), out, atol=1e-6)
test/nn/conv/test_appnp.py (11 additions, 6 deletions)

@@ -9,18 +9,21 @@ def test_appnp():
     x = torch.randn(4, 16)
     edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
     row, col = edge_index
-    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj1.to_torch_sparse_coo_tensor()

     conv = APPNP(K=3, alpha=0.1, cached=True)
     assert str(conv) == 'APPNP(K=3, alpha=0.1)'
     out = conv(x, edge_index)
     assert out.size() == (4, 16)
-    assert torch.allclose(conv(x, adj.t()), out)
+    assert torch.allclose(conv(x, adj1.t()), out)
+    assert torch.allclose(conv(x, adj2.t()), out)

     # Run again to test the cached functionality:
     assert conv._cached_edge_index is not None
     assert conv._cached_adj_t is not None
-    assert torch.allclose(conv(x, edge_index), conv(x, adj.t()))
+    assert torch.allclose(conv(x, edge_index), conv(x, adj1.t()))
+    assert torch.allclose(conv(x, edge_index), conv(x, adj2.t()))

     conv.reset_parameters()
     assert conv._cached_edge_index is None
@@ -33,16 +36,18 @@ def test_appnp():

         t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x, adj.t()), out)
+        assert torch.allclose(jit(x, adj1.t()), out)


 def test_appnp_dropout():
     x = torch.randn(4, 16)
     edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
     row, col = edge_index
-    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj1.to_torch_sparse_coo_tensor()

     # With dropout probability of 1.0, the final output equals to alpha * x:
     conv = APPNP(K=2, alpha=0.1, dropout=1.0)
     assert torch.allclose(0.1 * x, conv(x, edge_index))
-    assert torch.allclose(0.1 * x, conv(x, adj.t()))
+    assert torch.allclose(0.1 * x, conv(x, adj1.t()))
+    assert torch.allclose(0.1 * x, conv(x, adj2.t()))
test/nn/conv/test_cluster_gcn_conv.py (5 additions, 3 deletions)

@@ -9,13 +9,15 @@ def test_cluster_gcn_conv():
     x = torch.randn(4, 16)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
-    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj1.to_torch_sparse_coo_tensor()

     conv = ClusterGCNConv(16, 32, diag_lambda=1.)
     assert str(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'
     out = conv(x, edge_index)
     assert out.size() == (4, 32)
-    assert torch.allclose(conv(x, adj.t()), out, atol=1e-5)
+    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-5)
+    assert torch.allclose(conv(x, adj2.t()), out, atol=1e-5)

     if is_full_test():
         t = '(Tensor, Tensor) -> Tensor'
@@ -24,4 +26,4 @@ def test_cluster_gcn_conv():

         t = '(Tensor, SparseTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x, adj.t()), out, atol=1e-5)
+        assert torch.allclose(jit(x, adj1.t()), out, atol=1e-5)
test/nn/conv/test_feast_conv.py (10 additions, 6 deletions)

@@ -10,14 +10,16 @@ def test_feast_conv():
     x2 = torch.randn(2, 16)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
-    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj1.to_torch_sparse_coo_tensor()

     conv = FeaStConv(16, 32, heads=2)
     assert str(conv) == 'FeaStConv(16, 32, heads=2)'

     out = conv(x1, edge_index)
     assert out.size() == (4, 32)
-    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

     if is_full_test():
         t = '(Tensor, Tensor) -> Tensor'
@@ -26,12 +28,14 @@ def test_feast_conv():

         t = '(Tensor, SparseTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)
+        assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6)

-    adj = adj.sparse_resize((4, 2))
+    adj1 = adj1.sparse_resize((4, 2))
+    adj2 = adj1.to_torch_sparse_coo_tensor()
     out = conv((x1, x2), edge_index)
     assert out.size() == (2, 32)
-    assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)

     if is_full_test():
         t = '(PairTensor, Tensor) -> Tensor'
@@ -40,4 +44,4 @@ def test_feast_conv():

         t = '(PairTensor, SparseTensor) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)
+        assert torch.allclose(jit((x1, x2), adj1.t()), out, atol=1e-6)
test/nn/conv/test_sage_conv.py (13 additions, 15 deletions)

@@ -13,17 +13,16 @@ def test_sage_conv(project, aggr):
     x2 = torch.randn(2, 16)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
-    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
-    adj2 = adj.to_torch_sparse_coo_tensor()
+    adj1 = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    adj2 = adj1.to_torch_sparse_coo_tensor()

     conv = SAGEConv(8, 32, project=project, aggr=aggr)
     assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})'
     out = conv(x1, edge_index)
     assert out.size() == (4, 32)
     assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)
-    assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)
-    if aggr == 'sum':
-        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)
+    assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

     if is_full_test():
         t = '(Tensor, Tensor, Size) -> Tensor'
@@ -33,22 +32,21 @@ def test_sage_conv(project, aggr):

         t = '(Tensor, SparseTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)
+        assert torch.allclose(jit(x1, adj1.t()), out, atol=1e-6)

-    adj = adj.sparse_resize((4, 2))
-    adj2 = adj.to_torch_sparse_coo_tensor()
+    adj1 = adj1.sparse_resize((4, 2))
+    adj2 = adj1.to_torch_sparse_coo_tensor()
     conv = SAGEConv((8, 16), 32, project=project, aggr=aggr)
     assert str(conv) == f'SAGEConv((8, 16), 32, aggr={aggr})'
     out1 = conv((x1, x2), edge_index)
     out2 = conv((x1, None), edge_index, (4, 2))
     assert out1.size() == (2, 32)
     assert out2.size() == (2, 32)
     assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)
-    assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)
-    assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)
-    if aggr == 'sum':
-        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
-        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)
+    assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)
+    assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)

     if is_full_test():
         t = '(OptPairTensor, Tensor, Size) -> Tensor'
@@ -61,8 +59,8 @@ def test_sage_conv(project, aggr):

         t = '(OptPairTensor, SparseTensor, Size) -> Tensor'
         jit = torch.jit.script(conv.jittable(t))
-        assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)
-        assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)
+        assert torch.allclose(jit((x1, x2), adj1.t()), out1, atol=1e-6)
+        assert torch.allclose(jit((x1, None), adj1.t()), out2, atol=1e-6)


 def test_lstm_aggr_sage_conv():
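Note that the `if aggr == 'sum':` guards are gone: the `torch.sparse` path is now tested for every aggregation, not just summation. A self-contained sketch of what this enables (shapes mirror the test above):

```python
import torch

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import to_torch_coo_tensor

x1 = torch.randn(4, 8)
adj2 = to_torch_coo_tensor(torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]),
                           size=(4, 4))

# Previously only aggr='sum' was exercised against torch.sparse inputs:
conv = SAGEConv(8, 32, aggr='mean')
assert conv(x1, adj2.t()).size() == (4, 32)
```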
torch_geometric/nn/conv/appnp.py (15 additions, 3 deletions)

@@ -6,7 +6,12 @@
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.conv.gcn_conv import gcn_norm
 from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor
-from torch_geometric.utils import spmm
+from torch_geometric.utils import (
+    is_torch_sparse_tensor,
+    spmm,
+    to_edge_index,
+    to_torch_coo_tensor,
+)


 class APPNP(MessagePassing):
@@ -107,8 +112,15 @@ def forward(self, x: Tensor, edge_index: Adj,
         for k in range(self.K):
             if self.dropout > 0 and self.training:
                 if isinstance(edge_index, Tensor):
-                    assert edge_weight is not None
-                    edge_weight = F.dropout(edge_weight, p=self.dropout)
+                    if is_torch_sparse_tensor(edge_index):
+                        edge_index, edge_weight = to_edge_index(edge_index)
+                        edge_weight = F.dropout(edge_weight, p=self.dropout)
+                        edge_index = to_torch_coo_tensor(
+                            edge_index, edge_weight,
+                            size=x.size(self.node_dim))
+                    else:
+                        assert edge_weight is not None
+                        edge_weight = F.dropout(edge_weight, p=self.dropout)
                 else:
                     value = edge_index.storage.value()
                     assert value is not None
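The notable case here is edge dropout when the adjacency arrives as a `torch.sparse` tensor: the PR applies `F.dropout` to the unpacked edge weights rather than to the sparse tensor itself, then rebuilds the matrix with its original size. A standalone sketch of that round trip (graph size is illustrative):

```python
import torch
import torch.nn.functional as F

from torch_geometric.utils import to_edge_index, to_torch_coo_tensor

adj = to_torch_coo_tensor(torch.tensor([[0, 1, 2], [1, 2, 0]]), size=(3, 3))

edge_index, edge_weight = to_edge_index(adj)  # unpack indices and values
edge_weight = F.dropout(edge_weight, p=0.5, training=True)
adj = to_torch_coo_tensor(edge_index, edge_weight, size=3)  # repack
```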
torch_geometric/nn/conv/cluster_gcn_conv.py (11 additions, 0 deletions)

@@ -6,8 +6,11 @@
 from torch_geometric.utils import (
     add_self_loops,
     degree,
+    is_torch_sparse_tensor,
     remove_self_loops,
     spmm,
+    to_edge_index,
+    to_torch_coo_tensor,
 )
@@ -69,6 +72,10 @@ def reset_parameters(self):
     def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
         edge_weight: OptTensor = None
         if isinstance(edge_index, Tensor):
+            is_sparse_tensor = is_torch_sparse_tensor(edge_index)
+            if is_sparse_tensor:
+                edge_index, _ = to_edge_index(edge_index.t())
+
             num_nodes = x.size(self.node_dim)
             if self.add_self_loops:
                 edge_index, _ = remove_self_loops(edge_index)
@@ -80,6 +87,10 @@ def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
             edge_weight = deg_inv[col]
             edge_weight[row == col] += self.diag_lambda * deg_inv

+            if is_sparse_tensor:
+                edge_index = to_torch_coo_tensor(edge_index.flip(0),
+                                                 edge_weight, size=num_nodes)
+
         elif isinstance(edge_index, SparseTensor):
             if self.add_self_loops:
                 edge_index = torch_sparse.set_diag(edge_index)
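A note on the two transposes above: the layer receives the adjacency already transposed (`adj.t()`), so `to_edge_index(edge_index.t())` first recovers plain `(row, col)` pairs for the degree computation, and `edge_index.flip(0)` swaps the rows back into the transposed `(col, row)` layout before the sparse matrix is rebuilt. A minimal sketch of that round trip (graph size is illustrative):

```python
import torch

from torch_geometric.utils import to_edge_index, to_torch_coo_tensor

edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])   # (row, col)
adj_t = to_torch_coo_tensor(edge_index, size=(4, 4)).t()  # what forward() sees

ei, _ = to_edge_index(adj_t.t())  # undo the transpose: back to (row, col)
new_adj_t = to_torch_coo_tensor(ei.flip(0), size=4)  # (col, row) layout again
```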
torch_geometric/utils/loop.py (3 additions, 2 deletions)

@@ -64,7 +64,7 @@ def remove_self_loops(
     is_sparse = is_torch_sparse_tensor(edge_index)
     if is_sparse:
         assert edge_attr is None
-        size = edge_index.size()
+        size = (edge_index.size(0), edge_index.size(1))
         edge_index, edge_attr = to_edge_index(edge_index)

     mask = edge_index[0] != edge_index[1]
@@ -201,6 +201,7 @@ def add_self_loops(
     is_sparse = is_torch_sparse_tensor(edge_index)
     if is_sparse:
         assert edge_attr is None
+        size = (edge_index.size(0), edge_index.size(1))
         edge_index, edge_attr = to_edge_index(edge_index)

     loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
@@ -230,7 +231,7 @@

     edge_index = torch.cat([edge_index, loop_index], dim=1)
     if is_sparse:
-        return to_torch_coo_tensor(edge_index, edge_attr), None
+        return to_torch_coo_tensor(edge_index, edge_attr, size=size), None
     return edge_index, edge_attr
Expand Down