Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ClusterPooling layer #9627

Merged
merged 13 commits into from
Sep 10, 2024
Prev Previous commit
Next Next commit
update
  • Loading branch information
rusty1s committed Sep 10, 2024
commit e0a87c345c8bf9161e070820c38499427c39ad6a
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))
Expand Down Expand Up @@ -40,7 +41,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))

### Changed

Expand Down
30 changes: 30 additions & 0 deletions test/nn/pool/test_cluster_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
import torch

from torch_geometric.nn import ClusterPooling


@pytest.mark.parametrize(
    'edge_score_method',
    ['tanh', 'sigmoid', 'log_softmax'],
)
def test_cluster_pooling(edge_score_method):
    """Smoke-test ``ClusterPooling`` for each supported edge score method.

    Only structural invariants of the pooled output are asserted (node count,
    edge index shape and value range, batch vector length); the concrete
    clustering depends on the layer's freshly reset weights.
    """
    # Seven nodes with a single feature each.
    feat = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, -1.0]).view(-1, 1)

    # Source and target rows of the edge index, kept separate for readability.
    src = [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6]
    dst = [1, 2, 3, 6, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0]
    edge_index = torch.tensor([src, dst])

    # Nodes 4 and 5 form the second example; every other node is in the first.
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 0])

    pool = ClusterPooling(in_channels=1, edge_score_method=edge_score_method)
    assert str(pool) == 'ClusterPooling(1)'
    pool.reset_parameters()

    out_x, out_edge_index, out_batch, _ = pool(feat, edge_index, batch)
    num_clusters = out_x.size(0)

    # Pooling can only merge nodes, never create new ones.
    assert num_clusters <= 7
    assert out_edge_index.size(0) == 2
    if out_edge_index.numel() > 0:
        # Every remaining edge must reference a valid pooled node.
        assert out_edge_index.min() >= 0
        assert out_edge_index.max() < num_clusters
    assert out_batch.size() == (num_clusters, )
12 changes: 7 additions & 5 deletions torch_geometric/nn/pool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,19 @@
import torch_geometric.typing
from torch_geometric.typing import OptTensor, torch_cluster

from .asap import ASAPooling
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
from .edge_pool import EdgePooling
from .glob import global_add_pool, global_max_pool, global_mean_pool
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
ApproxMIPSKNNIndex)
from .graclus import graclus
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
from .mem_pool import MemPooling
from .pan_pool import PANPooling
from .sag_pool import SAGPooling
from .topk_pool import TopKPooling
from .sag_pool import SAGPooling
from .edge_pool import EdgePooling
from .cluster_pool import ClusterPooling
from .asap import ASAPooling
from .pan_pool import PANPooling
from .mem_pool import MemPooling
from .voxel_grid import voxel_grid
from .approx_knn import approx_knn, approx_knn_graph

Expand Down Expand Up @@ -344,6 +345,7 @@ def nearest(
'TopKPooling',
'SAGPooling',
'EdgePooling',
'ClusterPooling',
'ASAPooling',
'PANPooling',
'MemPooling',
Expand Down
219 changes: 83 additions & 136 deletions torch_geometric/nn/pool/cluster_pool.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,74 @@
from typing import Callable, NamedTuple, Optional, Tuple
from typing import NamedTuple, Optional, Tuple

import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import connected_components
from torch import Tensor

from torch_geometric.utils import (
coalesce,
dense_to_sparse,
one_hot,
to_dense_adj,
to_scipy_sparse_matrix,
)


class UnpoolInfo(NamedTuple):
    """Record of a single pooling step, returned alongside the pooled graph."""

    # Edge index of the graph that was pooled.
    edge_index: Tensor
    # Cluster id assigned to each node of the graph that was pooled.
    cluster: Tensor
    # Batch vector of the graph that was pooled.
    batch: Tensor


class ClusterPooling(torch.nn.Module):
r"""The cluster pooling operator from the `"Edge-Based Graph Component
Pooling" <paper url>` paper.
Pooling" <paper url>`_ paper.

In short, a score is computed for each edge.
:class:`ClusterPooling` computes a score for each edge.
Based on the selected edges, graph clusters are calculated and compressed
to one node using an injective aggregation function (sum). Edges are
remapped based on the node created by each cluster and the original edges.
to one node using the injective :obj:`"sum"` aggregation function.
Edges are remapped based on the nodes created by each cluster and the
original edges.

Args:
in_channels (int): Size of each input sample.
edge_score_method (function, optional): The function to apply
to compute the edge score from raw edge scores. By default,
this is the tanh over all incoming edges for each node.
This function takes in a :obj:`raw_edge_score` tensor of shape
:obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of
nodes :obj:`num_nodes`, and produces a new tensor of the same size
as :obj:`raw_edge_score` describing normalized edge scores.
Included functions are
:func:`ClusterPooling.compute_edge_score_tanh`,
:func:`ClusterPooling.compute_edge_score_sigmoid` and
:func:`ClusterPooling.compute_edge_score_logsoftmax`.
(default: :func:`ClusterPooling.compute_edge_score_tanh`)
edge_score_method (str, optional): The function to apply
to compute the edge score from raw edge scores (:obj:`"tanh"`,
:obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
dropout (float, optional): The probability with
which to drop edge scores during training. (default: :obj:`0`)
which to drop edge scores during training. (default: :obj:`0.0`)
threshold (float, optional): The threshold of edge scores. If set to
:obj:`None`, will be automatically inferred depending on
:obj:`edge_score_method`. (default: :obj:`None`)
"""
unpool_description = NamedTuple("UnpoolDescription",
["edge_index", "batch", "cluster_map"])

def __init__(self, in_channels: int,
edge_score_method: Optional[Callable] = None,
dropout: Optional[float] = 0.0,
threshold: Optional[float] = None, directed: bool = False):
def __init__(
self,
in_channels: int,
edge_score_method: str = 'tanh',
dropout: float = 0.0,
threshold: Optional[float] = None,
):
super().__init__()
self.in_channels = in_channels
if edge_score_method is None:
edge_score_method = self.compute_edge_score_tanh
assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']

if threshold is None:
if edge_score_method is self.compute_edge_score_sigmoid:
threshold = 0.5
else:
threshold = 0.0
self.compute_edge_score = edge_score_method
self.threshhold = threshold
threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0

self.in_channels = in_channels
self.edge_score_method = edge_score_method
self.dropout = dropout
self.directed = directed
self.lin = torch.nn.Linear(2 * in_channels, 1)
self.threshhold = threshold

self.reset_parameters()
self.lin = torch.nn.Linear(2 * in_channels, 1)

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
self.lin.reset_parameters()

@staticmethod
def compute_edge_score_tanh(raw_edge_score: Tensor):
r"""Normalizes edge scores via hyperbolic tangent application."""
return torch.tanh(raw_edge_score)

@staticmethod
def compute_edge_score_sigmoid(raw_edge_score: Tensor):
r"""Normalizes edge scores via sigmoid application."""
return torch.sigmoid(raw_edge_score)

@staticmethod
def compute_edge_score_logsoftmax(raw_edge_score: Tensor):
r"""Normalizes edge scores via logsoftmax application."""
return torch.nn.functional.log_softmax(raw_edge_score, dim=0)

def forward(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]:
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
r"""Forward pass.

Args:
Expand All @@ -104,95 +85,61 @@ def forward(
* **unpool_info** *(UnpoolInfo)* - Information describing the
pooling step (the edge indices after self-loop removal, the cluster
assignment of each node and the original batch vector), which can be
used for unpooling.
"""
# First we drop the self edges as those cannot be clustered
msk = edge_index[0] != edge_index[1]
edge_index = edge_index[:, msk]
if not self.directed:
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
# We only evaluate each edge once, remove double edges from the list
edge_index = coalesce(edge_index)

e = torch.cat(
[x[edge_index[0]], x[edge_index[1]]],
dim=-1) # Concatenates source feature with target features
e = self.lin(e).view(
-1) # Apply linear NN on the node pairs (edges) and reshape
e = F.dropout(e, p=self.dropout, training=self.training)

e = self.compute_edge_score(e) # Non linear activation function
x, edge_index, batch, unpool_info = self.__merge_edges__(
x, edge_index, batch, e)

return x, edge_index, batch, unpool_info

def __merge_edges__(
self, X: Tensor, edge_index: Tensor, batch: Tensor,
edge_score: Tensor) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]:
edges_contract = edge_index[..., edge_score > self.threshhold]

adj = to_scipy_sparse_matrix(edges_contract, num_nodes=X.size(0))
_, cluster_index = connected_components(adj, directed=True,
connection="weak")

cluster_index = torch.tensor(cluster_index, dtype=torch.int64,
device=X.device)
C = F.one_hot(cluster_index).type(torch.float)
A = to_dense_adj(edge_index, max_num_nodes=X.size(0)).squeeze(0)
S = to_dense_adj(edge_index, edge_attr=edge_score,
max_num_nodes=X.size(0)).squeeze(0)
mask = edge_index[0] != edge_index[1]
edge_index = edge_index[:, mask]

A_contract = to_dense_adj(
edges_contract, max_num_nodes=X.size(0)).type(torch.int).squeeze(0)
nodes_single = ((A_contract.sum(-1) +
A_contract.sum(-2)) == 0).nonzero()
S[nodes_single, nodes_single] = 1

X_new = (S @ C).T @ X
edge_index_new, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
edge_attr = torch.cat(
[x[edge_index[0]], x[edge_index[1]]],
dim=-1,
)
edge_score = self.lin(edge_attr).view(-1)
edge_score = F.dropout(edge_score, p=self.dropout,
training=self.training)

if self.edge_score_method == 'tanh':
edge_score = edge_score.tanh()
elif self.edge_score_method == 'sigmoid':
edge_score = edge_score.sigmoid()
else:
assert self.edge_score_method == 'log_softmax'
edge_score = F.log_softmax(edge_score, dim=0)

return self._merge_edges(x, edge_index, batch, edge_score)

def _merge_edges(
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
edge_score: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:

new_batch = X.new_empty(X_new.size(0), dtype=torch.long)
new_batch = new_batch.scatter_(0, cluster_index, batch)
from scipy.sparse.csgraph import connected_components

unpool_info = self.unpool_description(edge_index=edge_index,
batch=batch,
cluster_map=cluster_index)
edge_contract = edge_index[:, edge_score > self.threshhold]

return X_new.to(X.device), edge_index_new.to(
X.device), new_batch, unpool_info
adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
_, cluster_np = connected_components(adj, directed=True,
connection="weak")

def unpool(
self,
x: Tensor,
unpool_info: NamedTuple,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Unpools a previous cluster pooling step.
cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
C = one_hot(cluster)
A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
S = to_dense_adj(edge_index, edge_attr=edge_score,
max_num_nodes=x.size(0)).squeeze(0)

For unpooling, :obj:`x` should be of same shape as those produced by
this layer's :func:`forward` function. Then, it will produce an
unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`.
A_contract = to_dense_adj(edge_contract,
max_num_nodes=x.size(0)).squeeze(0)
nodes_single = ((A_contract.sum(dim=-1) +
A_contract.sum(dim=-2)) == 0).nonzero()
S[nodes_single, nodes_single] = 1.0

Args:
x (Tensor): The node features.
unpool_info (unpool_description): Information that has
been produced by :func:`ClusterPooling.forward`.
x_out = (S @ C).t() @ x
edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
unpool_info = UnpoolInfo(edge_index, cluster, batch)

Return types:
* **x** *(Tensor)* - The unpooled node features.
* **edge_index** *(LongTensor)* - The new edge indices.
* **batch** *(LongTensor)* - The new batch vector.
"""
# We copy the cluster features into every node
node_maps = unpool_info.cluster_map
n_nodes = 0
for c in node_maps:
node_maps += len(c)
import numpy as np
repack = np.array([-1 for _ in range(n_nodes)])
for i, c in enumerate(node_maps):
repack[c] = i
new_x = x[repack]

return new_x, unpool_info.edge_index, unpool_info.batch
return x_out, edge_index_out, batch_out, unpool_info

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.in_channels})'
2 changes: 1 addition & 1 deletion torch_geometric/nn/pool/edge_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self,
in_channels: int,
edge_score_method: Optional[Callable] = None,
dropout: Optional[float] = 0.0,
dropout: float = 0.0,
add_to_edge_score: float = 0.5,
):
super().__init__()
Expand Down
Loading