From 3adcf5b9fa227ea26d62d12bdad991302c91b5cc Mon Sep 17 00:00:00 2001 From: ArchieGertsman Date: Mon, 11 Sep 2023 02:56:45 -0500 Subject: [PATCH] Implemented the Learnable Communitive Monoid Aggregation (#7976) Initial version of the LCM aggregation, requested in issue #7574. A possible feature to incorporate is for `forward` to additionally compute and return an associativity loss, as described in the paper. --------- Co-authored-by: Rishi Puri Co-authored-by: Jintang Li Co-authored-by: rusty1s --- CHANGELOG.md | 1 + test/nn/aggr/test_lcm.py | 31 +++++++++ torch_geometric/nn/aggr/__init__.py | 2 + torch_geometric/nn/aggr/lcm.py | 101 ++++++++++++++++++++++++++++ 4 files changed, 135 insertions(+) create mode 100644 test/nn/aggr/test_lcm.py create mode 100644 torch_geometric/nn/aggr/lcm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b0caa744a64..9d6cf049029e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `LCMAggregation`, an implementation of Learnable Communitive Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976)) - Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995)) - Added `utils.cumsum` implementation ([#7994](https://github.com/pyg-team/pytorch_geometric/pull/7994)) - Added the `BrcaTcga` dataset ([#7905](https://github.com/pyg-team/pytorch_geometric/pull/7905)) diff --git a/test/nn/aggr/test_lcm.py b/test/nn/aggr/test_lcm.py new file mode 100644 index 000000000000..60cb2335d527 --- /dev/null +++ b/test/nn/aggr/test_lcm.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from torch_geometric.nn import LCMAggregation + + +def test_lcm_aggregation_with_project(): + x = torch.randn(6, 16) + index = torch.tensor([0, 0, 1, 1, 1, 2]) + + aggr = LCMAggregation(16, 32) + assert str(aggr) == 'LCMAggregation(16, 32, project=True)' + + out = aggr(x, index) + assert out.size() == (3, 32) + + +def test_lcm_aggregation_without_project(): + x = torch.randn(6, 16) + index = torch.tensor([0, 0, 1, 1, 1, 2]) + + aggr = LCMAggregation(16, 16, project=False) + assert str(aggr) == 'LCMAggregation(16, 16, project=False)' + + out = aggr(x, index) + assert out.size() == (3, 16) + + +def test_lcm_aggregation_error_handling(): + with pytest.raises(ValueError, match="must be projected"): + LCMAggregation(16, 32, project=False) diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index bbb67089f62e..c41d038b8ba2 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -23,6 +23,7 @@ from .mlp import MLPAggregation from .deep_sets import DeepSetsAggregation from .set_transformer import SetTransformerAggregation +from .lcm import LCMAggregation __all__ = classes = [ 'Aggregation', @@ -49,4 +50,5 @@ 'MLPAggregation', 'DeepSetsAggregation', 'SetTransformerAggregation', + 'LCMAggregation', ] diff --git a/torch_geometric/nn/aggr/lcm.py b/torch_geometric/nn/aggr/lcm.py new file mode 100644 index 000000000000..ac96a6d14d2d --- /dev/null +++ b/torch_geometric/nn/aggr/lcm.py @@ -0,0 +1,101 @@ +from math import ceil, log2 +from typing import Optional + +from torch import Tensor +from torch.nn import GRUCell, Linear + +from torch_geometric.experimental import disable_dynamic_shapes +from torch_geometric.nn.aggr import Aggregation + + +class LCMAggregation(Aggregation): + r"""The Learnable Commutative Monoid aggregation from the + `"Learnable Commutative Monoids for Graph Neural Networks" + `_ paper, in which the elements are + aggregated using a binary tree reduction with + :math:`\mathcal{O}(\log |\mathcal{V}|)` depth. + + .. note:: + + :class:`LCMAggregation` requires sorted indices :obj:`index` as input. + Specifically, if you use this aggregation as part of + :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index` + or by calling :meth:`torch_geometric.data.Data.sort`. + + .. warning:: + + :class:`LCMAggregation` is not a permutation-invariant operator. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + project (bool, optional): If set to :obj:`True`, the layer will apply a + linear transformation followed by an activation function before + aggregation. (default: :obj:`True`) + """ + def __init__( + self, + in_channels: int, + out_channels: int, + project: bool = True, + ): + super().__init__() + + if in_channels != out_channels and not project: + raise ValueError(f"Inputs of '{self.__class__.__name__}' must be " + f"projected if `in_channels != out_channels`") + + self.in_channels = in_channels + self.out_channels = out_channels + self.project = project + + if self.project: + self.lin = Linear(in_channels, out_channels) + else: + self.lin = None + + self.gru_cell = GRUCell(out_channels, out_channels) + + def reset_parameters(self): + if self.project: + self.lin.reset_parameters() + self.gru_cell.reset_parameters() + + def binary_op(self, left: Tensor, right: Tensor) -> Tensor: + return (self.gru_cell(left, right) + self.gru_cell(right, left)) / 2.0 + + @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements']) + def forward( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + if self.project: + x = self.lin(x).relu() + + x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim, + max_num_elements=max_num_elements) + + x = x.permute(1, 0, 2) # [num_neighbors, num_nodes, num_features] + + depth = ceil(log2(x.size(0))) + for _ in range(depth): + x = [ + self.binary_op(x[2 * i], x[2 * i + 1]) if + (2 * i + 1) < len(x) else x[2 * i] + for i in range(ceil(len(x) / 2)) + ] + + assert len(x) == 1 + return x[0] + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, project={self.project})')