-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented the Learnable Communitive Monoid Aggregation (#7976)
Initial version of the LCM aggregation, requested in issue #7574. A possible feature to incorporate is for `forward` to additionally compute and return an associativity loss, as described in the paper. --------- Co-authored-by: Rishi Puri <puririshi98@berkeley.edu> Co-authored-by: Jintang Li <cnljt@outlook.com> Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
- Loading branch information
1 parent
338baa6
commit 3adcf5b
Showing
4 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.nn import LCMAggregation | ||
|
||
|
||
def test_lcm_aggregation_with_project(): | ||
x = torch.randn(6, 16) | ||
index = torch.tensor([0, 0, 1, 1, 1, 2]) | ||
|
||
aggr = LCMAggregation(16, 32) | ||
assert str(aggr) == 'LCMAggregation(16, 32, project=True)' | ||
|
||
out = aggr(x, index) | ||
assert out.size() == (3, 32) | ||
|
||
|
||
def test_lcm_aggregation_without_project(): | ||
x = torch.randn(6, 16) | ||
index = torch.tensor([0, 0, 1, 1, 1, 2]) | ||
|
||
aggr = LCMAggregation(16, 16, project=False) | ||
assert str(aggr) == 'LCMAggregation(16, 16, project=False)' | ||
|
||
out = aggr(x, index) | ||
assert out.size() == (3, 16) | ||
|
||
|
||
def test_lcm_aggregation_error_handling(): | ||
with pytest.raises(ValueError, match="must be projected"): | ||
LCMAggregation(16, 32, project=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from math import ceil, log2 | ||
from typing import Optional | ||
|
||
from torch import Tensor | ||
from torch.nn import GRUCell, Linear | ||
|
||
from torch_geometric.experimental import disable_dynamic_shapes | ||
from torch_geometric.nn.aggr import Aggregation | ||
|
||
|
||
class LCMAggregation(Aggregation): | ||
r"""The Learnable Commutative Monoid aggregation from the | ||
`"Learnable Commutative Monoids for Graph Neural Networks" | ||
<https://arxiv.org/abs/2212.08541>`_ paper, in which the elements are | ||
aggregated using a binary tree reduction with | ||
:math:`\mathcal{O}(\log |\mathcal{V}|)` depth. | ||
.. note:: | ||
:class:`LCMAggregation` requires sorted indices :obj:`index` as input. | ||
Specifically, if you use this aggregation as part of | ||
:class:`~torch_geometric.nn.conv.MessagePassing`, ensure that | ||
:obj:`edge_index` is sorted by destination nodes, either by manually | ||
sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` | ||
or by calling :meth:`torch_geometric.data.Data.sort`. | ||
.. warning:: | ||
:class:`LCMAggregation` is not a permutation-invariant operator. | ||
Args: | ||
in_channels (int): Size of each input sample. | ||
out_channels (int): Size of each output sample. | ||
project (bool, optional): If set to :obj:`True`, the layer will apply a | ||
linear transformation followed by an activation function before | ||
aggregation. (default: :obj:`True`) | ||
""" | ||
def __init__( | ||
self, | ||
in_channels: int, | ||
out_channels: int, | ||
project: bool = True, | ||
): | ||
super().__init__() | ||
|
||
if in_channels != out_channels and not project: | ||
raise ValueError(f"Inputs of '{self.__class__.__name__}' must be " | ||
f"projected if `in_channels != out_channels`") | ||
|
||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
self.project = project | ||
|
||
if self.project: | ||
self.lin = Linear(in_channels, out_channels) | ||
else: | ||
self.lin = None | ||
|
||
self.gru_cell = GRUCell(out_channels, out_channels) | ||
|
||
def reset_parameters(self): | ||
if self.project: | ||
self.lin.reset_parameters() | ||
self.gru_cell.reset_parameters() | ||
|
||
def binary_op(self, left: Tensor, right: Tensor) -> Tensor: | ||
return (self.gru_cell(left, right) + self.gru_cell(right, left)) / 2.0 | ||
|
||
@disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) | ||
def forward( | ||
self, | ||
x: Tensor, | ||
index: Optional[Tensor] = None, | ||
ptr: Optional[Tensor] = None, | ||
dim_size: Optional[int] = None, | ||
dim: int = -2, | ||
max_num_elements: Optional[int] = None, | ||
) -> Tensor: | ||
|
||
if self.project: | ||
x = self.lin(x).relu() | ||
|
||
x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, | ||
max_num_elements=max_num_elements) | ||
|
||
x = x.permute(1, 0, 2) # [num_neighbors, num_nodes, num_features] | ||
|
||
depth = ceil(log2(x.size(0))) | ||
for _ in range(depth): | ||
x = [ | ||
self.binary_op(x[2 * i], x[2 * i + 1]) if | ||
(2 * i + 1) < len(x) else x[2 * i] | ||
for i in range(ceil(len(x) / 2)) | ||
] | ||
|
||
assert len(x) == 1 | ||
return x[0] | ||
|
||
def __repr__(self) -> str: | ||
return (f'{self.__class__.__name__}({self.in_channels}, ' | ||
f'{self.out_channels}, project={self.project})') |