Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ClusterPooling layer #9627

Merged
merged 13 commits into from
Sep 10, 2024
Prev Previous commit
Next Next commit
PR fixes
  • Loading branch information
thijssnelleman committed Aug 26, 2024
commit 37f9d34e7744f1bf285998ffee555eec5528690e
18 changes: 9 additions & 9 deletions torch_geometric/nn/pool/cluster_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class ClusterPooling(torch.nn.Module):
Pooling" <paper url>` paper.

In short, a score is computed for each edge.
Based on the selected edges, graph clusters are calculated and compressed to one
node using an injective aggregation function (sum). Edges are remapped based on
the node created by each cluster and the original edges.
Based on the selected edges, graph clusters are calculated and compressed
to one node using an injective aggregation function (sum). Edges are
remapped based on the node created by each cluster and the original edges.

Args:
in_channels (int): Size of each input sample.
Expand Down Expand Up @@ -104,23 +104,23 @@ def forward(
* **unpool_info** *(unpool_description)* - Information that is
consumed by :func:`ClusterPooling.unpool` for unpooling.
"""
#First we drop the self edges as those cannot be clustered
# First we drop the self edges as those cannot be clustered
msk = edge_index[0] != edge_index[1]
edge_index = edge_index[:, msk]
if not self.directed:
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
# We only evaluate each edge once, so we filter double edges from the list
# We only evaluate each edge once, remove double edges from the list
edge_index = coalesce(edge_index)

e = torch.cat(
[x[edge_index[0]], x[edge_index[1]]],
dim=-1) # Concatenates the source feature with the target features
dim=-1) # Concatenates source feature with target features
e = self.lin(e).view(
-1
) # Apply linear NN on the node pairs (edges) and reshape to 1 dimension
) # Apply linear NN on the node pairs (edges) and reshape
e = F.dropout(e, p=self.dropout, training=self.training)

e = self.compute_edge_score(e) #Non linear activation function
e = self.compute_edge_score(e) # Non linear activation function
x, edge_index, batch, unpool_info = self.__merge_edges__(
x, edge_index, batch, e)

Expand Down Expand Up @@ -182,7 +182,7 @@ def unpool(
* **edge_index** *(LongTensor)* - The new edge indices.
* **batch** *(LongTensor)* - The new batch vector.
"""
# We just copy the cluster feature into every node
# We copy the cluster features into every node
node_maps = unpool_info.cluster_map
n_nodes = 0
for c in node_maps:
Expand Down
Loading