DMoNPooling: Doc and clean-up (#4242)
* update

* update

* update

* type hints

* update

* typo
rusty1s authored Mar 12, 2022
1 parent 592ca50 commit f5d7ca6
Showing 9 changed files with 79 additions and 62 deletions.
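
In brief: the commit renames `DMonPooling` to `DMoNPooling`, moves its module from `dense_dmon_pool.py` to `dmon_pool.py`, and cleans up the related examples and docs. For orientation, a minimal usage sketch of the renamed operator; the constructor arguments, the six return values, and the tensor shapes mirror the updated test shown further below (the random inputs are placeholders):

    import torch
    from torch_geometric.nn import DMoNPooling

    batch_size, num_nodes, channels, num_clusters = 2, 20, 16, 10
    x = torch.randn(batch_size, num_nodes, channels)    # dense node features
    adj = torch.ones(batch_size, num_nodes, num_nodes)  # dense adjacency matrix
    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)

    pool = DMoNPooling([channels, channels], num_clusters)

    # Returns the soft cluster assignment, the pooled feature and adjacency
    # matrices, and the three auxiliary DMoN losses:
    s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)
    print(s.size())  # torch.Size([2, 20, 10])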
README.md (3 changes: 2 additions & 1 deletion)
@@ -251,7 +251,8 @@ It is commonly applied to graph-level tasks, which require combining node features
 * **[GlobalAttention](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.GlobalAttention)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)]
 * **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)]
 * **[Sort Pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_sort_pool)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)]
-* **[Dense MinCUT Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool)** from Bianchi *et al.*: [MinCUT Pooling in Graph Neural Networks](https://arxiv.org/abs/1907.00481) (CoRR 2019) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)]
+* **[MinCUT Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool)** from Bianchi *et al.*: [MinCUT Pooling in Graph Neural Networks](https://arxiv.org/abs/1907.00481) (CoRR 2019) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)]
+* **[DMoN Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.dmon_pool.DMoNPooling)** from Tsitsulin *et al.*: [Graph Clustering with Graph Neural Networks](https://arxiv.org/abs/2006.16904) (CoRR 2020) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_dmon_pool.py)]
 * **[Graclus Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.graclus)** from Dhillon *et al.*: [Weighted Graph Cuts without Eigenvectors: A Multilevel Approach](http://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf) (PAMI 2007) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_graclus.py)]
 * **[Voxel Grid Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.voxel_grid)** from, *e.g.*, Simonovsky and Komodakis: [Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_voxel_grid.py)]
 * **[SAG Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.SAGPooling)** from Lee *et al.*: [Self-Attention Graph Pooling](https://arxiv.org/abs/1904.08082) (ICML 2019) and Knyazev *et al.*: [Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sag_pool.py)]
examples/proteins_diff_pool.py (14 changes: 7 additions & 7 deletions)
@@ -61,14 +61,14 @@ def forward(self, x, adj, mask=None):
         batch_size, num_nodes, in_channels = x.size()

         x0 = x
-        x1 = self.bn(1, F.relu(self.conv1(x0, adj, mask)))
-        x2 = self.bn(2, F.relu(self.conv2(x1, adj, mask)))
-        x3 = self.bn(3, F.relu(self.conv3(x2, adj, mask)))
+        x1 = self.bn(1, self.conv1(x0, adj, mask).relu())
+        x2 = self.bn(2, self.conv2(x1, adj, mask).relu())
+        x3 = self.bn(3, self.conv3(x2, adj, mask).relu())

         x = torch.cat([x1, x2, x3], dim=-1)

         if self.lin is not None:
-            x = F.relu(self.lin(x))
+            x = self.lin(x).relu()

         return x

@@ -104,7 +104,7 @@ def forward(self, x, adj, mask=None):
         x = self.gnn3_embed(x, adj)

         x = x.mean(dim=1)
-        x = F.relu(self.lin1(x))
+        x = self.lin1(x).relu()
         x = self.lin2(x)
         return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2

@@ -124,7 +124,7 @@ def train(epoch):
         output, _, _ = model(data.x, data.adj, data.mask)
         loss = F.nll_loss(output, data.y.view(-1))
         loss.backward()
-        loss_all += data.y.size(0) * loss.item()
+        loss_all += data.y.size(0) * float(loss)
         optimizer.step()
     return loss_all / len(train_dataset)

@@ -137,7 +137,7 @@ def test(loader):
     for data in loader:
         data = data.to(device)
         pred = model(data.x, data.adj, data.mask)[0].max(dim=1)[1]
-        correct += pred.eq(data.y.view(-1)).sum().item()
+        correct += int(pred.eq(data.y.view(-1)).sum())
     return correct / len(loader.dataset)


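Aside from the rename, the examples consistently replace `Tensor.item()` with plain `float(...)`/`int(...)` casts when accumulating metrics. Both forms read the scalar value out of a one-element tensor, as this standalone sketch shows (the tensor values here are made up):

    import torch

    loss = torch.tensor(0.25)                          # stand-in for a scalar loss
    correct = torch.tensor([True, False, True]).sum()  # one-element count tensor

    # float()/int() on a one-element tensor are equivalent to .item():
    assert float(loss) == loss.item() == 0.25
    assert int(correct) == correct.item() == 2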
examples/proteins_dmon_pool.py (24 changes: 14 additions & 10 deletions)
@@ -1,3 +1,4 @@
+import os.path as osp
 from math import ceil

 import torch
@@ -6,10 +7,11 @@

 from torch_geometric.datasets import TUDataset
 from torch_geometric.loader import DataLoader
-from torch_geometric.nn import DenseGraphConv, DMonPooling, GCNConv
+from torch_geometric.nn import DenseGraphConv, DMoNPooling, GCNConv
 from torch_geometric.utils import to_dense_adj, to_dense_batch

-dataset = TUDataset(root='/tmp/PROTEINS', name='PROTEINS').shuffle()
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')
+dataset = TUDataset(path, name='PROTEINS').shuffle()
 avg_num_nodes = int(dataset.data.x.size(0) / len(dataset))
 n = (len(dataset) + 9) // 10
 test_dataset = dataset[:n]
@@ -26,25 +28,27 @@ def __init__(self, in_channels, out_channels, hidden_channels=32):

         self.conv1 = GCNConv(in_channels, hidden_channels)
         num_nodes = ceil(0.5 * avg_num_nodes)
-        self.pool1 = DMonPooling([hidden_channels, hidden_channels], num_nodes)
+        self.pool1 = DMoNPooling([hidden_channels, hidden_channels], num_nodes)

         self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)
         num_nodes = ceil(0.5 * num_nodes)
-        self.pool2 = DMonPooling([hidden_channels, hidden_channels], num_nodes)
+        self.pool2 = DMoNPooling([hidden_channels, hidden_channels], num_nodes)

         self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)

         self.lin1 = Linear(hidden_channels, hidden_channels)
         self.lin2 = Linear(hidden_channels, out_channels)

-    def forward(self, x, edge_index, batch=None):
+    def forward(self, x, edge_index, batch):
         x = self.conv1(x, edge_index).relu()

         x, mask = to_dense_batch(x, batch)
         adj = to_dense_adj(edge_index, batch)

         _, x, adj, sp1, o1, c1 = self.pool1(x, adj, mask)

         x = self.conv2(x, adj).relu()

         _, x, adj, sp2, o2, c2 = self.pool2(x, adj)

         x = self.conv3(x, adj)
@@ -57,7 +61,7 @@ def forward(self, x, edge_index, batch=None):

 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = Net(dataset.num_features, dataset.num_classes).to(device)
-optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


 def train(train_loader):
@@ -70,7 +74,7 @@ def train(train_loader):
         out, tot_loss = model(data.x, data.edge_index, data.batch)
         loss = F.nll_loss(out, data.y.view(-1)) + tot_loss
         loss.backward()
-        loss_all += data.y.size(0) * loss.item()
+        loss_all += data.y.size(0) * float(loss)
         optimizer.step()
     return loss_all / len(train_dataset)

@@ -85,13 +89,13 @@ def test(loader):
         data = data.to(device)
         pred, tot_loss = model(data.x, data.edge_index, data.batch)
         loss = F.nll_loss(pred, data.y.view(-1)) + tot_loss
-        loss_all += data.y.size(0) * loss.item()
-        correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item()
+        loss_all += data.y.size(0) * float(loss)
+        correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum())

     return loss_all / len(loader.dataset), correct / len(loader.dataset)


-for epoch in range(100):
+for epoch in range(1, 101):
     train_loss = train(train_loader)
     _, train_acc = test(train_loader)
     val_loss, val_acc = test(val_loader)
examples/proteins_mincut_pool.py (12 changes: 6 additions & 6 deletions)
@@ -40,23 +40,23 @@ def __init__(self, in_channels, out_channels, hidden_channels=32):
         self.lin2 = Linear(hidden_channels, out_channels)

     def forward(self, x, edge_index, batch):
-        x = F.relu(self.conv1(x, edge_index))
+        x = self.conv1(x, edge_index).relu()

         x, mask = to_dense_batch(x, batch)
         adj = to_dense_adj(edge_index, batch)

         s = self.pool1(x)
         x, adj, mc1, o1 = dense_mincut_pool(x, adj, s, mask)

-        x = F.relu(self.conv2(x, adj))
+        x = self.conv2(x, adj).relu()
         s = self.pool2(x)

         x, adj, mc2, o2 = dense_mincut_pool(x, adj, s)

         x = self.conv3(x, adj)

         x = x.mean(dim=1)
-        x = F.relu(self.lin1(x))
+        x = self.lin1(x).relu()
         x = self.lin2(x)
         return F.log_softmax(x, dim=-1), mc1 + mc2, o1 + o2

@@ -76,7 +76,7 @@ def train(epoch):
         out, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)
         loss = F.nll_loss(out, data.y.view(-1)) + mc_loss + o_loss
         loss.backward()
-        loss_all += data.y.size(0) * loss.item()
+        loss_all += data.y.size(0) * float(loss)
         optimizer.step()
     return loss_all / len(train_dataset)

@@ -91,8 +91,8 @@ def test(loader):
         data = data.to(device)
         pred, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)
         loss = F.nll_loss(pred, data.y.view(-1)) + mc_loss + o_loss
-        loss_all += data.y.size(0) * loss.item()
-        correct += pred.max(dim=1)[1].eq(data.y.view(-1)).sum().item()
+        loss_all += data.y.size(0) * float(loss)
+        correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum())

     return loss_all / len(loader.dataset), correct / len(loader.dataset)

@@ -1,15 +1,16 @@
 import torch

-from torch_geometric.nn import DMonPooling
+from torch_geometric.nn import DMoNPooling


-def test_dense_dmon_pool():
+def test_dmon_pooling():
     batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)
     x = torch.randn((batch_size, num_nodes, channels))
     adj = torch.ones((batch_size, num_nodes, num_nodes))
     mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)

-    pool = DMonPooling([channels, channels], num_clusters)
+    pool = DMoNPooling([channels, channels], num_clusters)
     assert str(pool) == 'DMoNPooling(16, num_clusters=10)'

     s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)
     assert s.size() == (2, 20, 10)
torch_geometric/nn/dense/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -5,7 +5,7 @@
 from .dense_gin_conv import DenseGINConv
 from .diff_pool import dense_diff_pool
 from .mincut_pool import dense_mincut_pool
-from .dense_dmon_pool import DMonPooling
+from .dmon_pool import DMoNPooling

 __all__ = [
     'Linear',
@@ -16,7 +16,7 @@
     'DenseSAGEConv',
     'dense_diff_pool',
     'dense_mincut_pool',
-    'DMonPooling',
+    'DMoNPooling',
 ]

 lin_classes = __all__[:2]
torch_geometric/nn/dense/diff_pool.py (8 changes: 4 additions & 4 deletions)
@@ -4,7 +4,7 @@


 def dense_diff_pool(x, adj, s, mask=None):
-    r"""Differentiable pooling operator from the `"Hierarchical Graph
+    r"""The differentiable pooling operator from the `"Hierarchical Graph
     Representation Learning with Differentiable Pooling"
     <https://arxiv.org/abs/1806.08804>`_ paper
@@ -17,15 +17,15 @@ def dense_diff_pool(x, adj, s, mask=None):
     based on dense learned assignments :math:`\mathbf{S} \in \mathbb{R}^{B
     \times N \times C}`.

-    Returns pooled node feature matrix, coarsened adjacency matrix and two
-    auxiliary objectives: (1) The link prediction loss
+    Returns the pooled node feature matrix, the coarsened adjacency matrix and
+    two auxiliary objectives: (1) The link prediction loss

     .. math::
         \mathcal{L}_{LP} = {\| \mathbf{A} -
         \mathrm{softmax}(\mathbf{S}) {\mathrm{softmax}(\mathbf{S})}^{\top}
         \|}_F,

-    and (2) the entropy regularization
+    and (2) the entropy regularization

     .. math::
         \mathcal{L}_E = \frac{1}{N} \sum_{n=1}^N H(\mathbf{S}_n).
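
As a companion to this docstring, a minimal sketch of calling `dense_diff_pool` with the shapes it documents (the random tensors are placeholders, not taken from the repository):

    import torch
    from torch_geometric.nn import dense_diff_pool

    B, N, F, C = 2, 20, 16, 5  # batch size, nodes, features, clusters
    x = torch.randn(B, N, F)   # dense node feature matrix
    adj = torch.rand(B, N, N)  # dense adjacency matrix
    s = torch.randn(B, N, C)   # dense learned assignments (softmax applied internally)

    # Returns the pooled node features, the coarsened adjacency matrix, and
    # the two auxiliary objectives described above: the link prediction loss
    # and the entropy regularization.
    out, out_adj, link_loss, ent_loss = dense_diff_pool(x, adj, s)
    print(out.size(), out_adj.size())  # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])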