Implemented the Learnable Commutative Monoid Aggregation (#7976)
Initial version of the LCM aggregation, requested in issue #7574. A
possible feature to incorporate is for `forward` to additionally compute
and return an associativity loss, as described in the paper.

---------

Co-authored-by: Rishi Puri <puririshi98@berkeley.edu>
Co-authored-by: Jintang Li <cnljt@outlook.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
4 people authored and JakubPietrakIntel committed Sep 27, 2023
1 parent 338baa6 commit 3adcf5b
Showing 4 changed files with 135 additions and 0 deletions.
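
The associativity loss mentioned in the commit message is not implemented in this PR. For context, a minimal sketch of what it could look like, assuming the paper's penalty on the gap between f(f(a, b), c) and f(a, f(b, c)) over randomly sampled triples (the `associativity_loss` helper and its sampling scheme are hypothetical):

import torch
from torch_geometric.nn import LCMAggregation


def associativity_loss(aggr: LCMAggregation,
                       num_samples: int = 32) -> torch.Tensor:
    # Hypothetical sketch, not part of this PR: sample random triples in
    # the operator's output space and penalize non-associativity of the
    # learned binary operator f = aggr.binary_op.
    a = torch.randn(num_samples, aggr.out_channels)
    b = torch.randn(num_samples, aggr.out_channels)
    c = torch.randn(num_samples, aggr.out_channels)
    lhs = aggr.binary_op(aggr.binary_op(a, b), c)  # f(f(a, b), c)
    rhs = aggr.binary_op(a, aggr.binary_op(b, c))  # f(a, f(b, c))
    return (lhs - rhs).pow(2).mean()

Such a term could then be added to the task loss with a small weight, e.g. `loss = task_loss + 0.01 * associativity_loss(aggr)`.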
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `LCMAggregation`, an implementation of Learnable Commutative Monoids ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976))
- Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995))
- Added `utils.cumsum` implementation ([#7994](https://github.com/pyg-team/pytorch_geometric/pull/7994))
- Added the `BrcaTcga` dataset ([#7905](https://github.com/pyg-team/pytorch_geometric/pull/7905))
31 changes: 31 additions & 0 deletions test/nn/aggr/test_lcm.py
@@ -0,0 +1,31 @@
import pytest
import torch

from torch_geometric.nn import LCMAggregation


def test_lcm_aggregation_with_project():
    x = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])

    aggr = LCMAggregation(16, 32)
    assert str(aggr) == 'LCMAggregation(16, 32, project=True)'

    out = aggr(x, index)
    assert out.size() == (3, 32)


def test_lcm_aggregation_without_project():
    x = torch.randn(6, 16)
    index = torch.tensor([0, 0, 1, 1, 1, 2])

    aggr = LCMAggregation(16, 16, project=False)
    assert str(aggr) == 'LCMAggregation(16, 16, project=False)'

    out = aggr(x, index)
    assert out.size() == (3, 16)


def test_lcm_aggregation_error_handling():
    with pytest.raises(ValueError, match="must be projected"):
        LCMAggregation(16, 32, project=False)
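
The tests above only check output shapes. A small hypothetical snippet (not part of this PR's test suite) illustrating the permutation sensitivity that the class docstring below warns about:

import torch
from torch_geometric.nn import LCMAggregation

aggr = LCMAggregation(16, 16, project=False)
x = torch.randn(5, 16)
index = torch.zeros(5, dtype=torch.long)  # a single group of 5 elements

out1 = aggr(x, index)
out2 = aggr(x[torch.randperm(5)], index)  # same multiset, shuffled order
print(torch.allclose(out1, out2))  # typically False: element order matters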
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
@@ -23,6 +23,7 @@
from .mlp import MLPAggregation
from .deep_sets import DeepSetsAggregation
from .set_transformer import SetTransformerAggregation
from .lcm import LCMAggregation

__all__ = classes = [
    'Aggregation',
@@ -49,4 +50,5 @@
    'MLPAggregation',
    'DeepSetsAggregation',
    'SetTransformerAggregation',
    'LCMAggregation',
]
101 changes: 101 additions & 0 deletions torch_geometric/nn/aggr/lcm.py
@@ -0,0 +1,101 @@
from math import ceil, log2
from typing import Optional

from torch import Tensor
from torch.nn import GRUCell, Linear

from torch_geometric.experimental import disable_dynamic_shapes
from torch_geometric.nn.aggr import Aggregation


class LCMAggregation(Aggregation):
    r"""The Learnable Commutative Monoid aggregation from the
    `"Learnable Commutative Monoids for Graph Neural Networks"
    <https://arxiv.org/abs/2212.08541>`_ paper, in which the elements are
    aggregated using a binary tree reduction with
    :math:`\mathcal{O}(\log |\mathcal{V}|)` depth.

    .. note::

        :class:`LCMAggregation` requires sorted indices :obj:`index` as input.
        Specifically, if you use this aggregation as part of
        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that
        :obj:`edge_index` is sorted by destination nodes, either by manually
        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`
        or by calling :meth:`torch_geometric.data.Data.sort`.

    .. warning::

        :class:`LCMAggregation` is not a permutation-invariant operator.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        project (bool, optional): If set to :obj:`True`, the layer will apply a
            linear transformation followed by an activation function before
            aggregation. (default: :obj:`True`)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        project: bool = True,
    ):
        super().__init__()

        if in_channels != out_channels and not project:
            raise ValueError(f"Inputs of '{self.__class__.__name__}' must be "
                             f"projected if `in_channels != out_channels`")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.project = project

        if self.project:
            self.lin = Linear(in_channels, out_channels)
        else:
            self.lin = None

        self.gru_cell = GRUCell(out_channels, out_channels)

    def reset_parameters(self):
        if self.project:
            self.lin.reset_parameters()
        self.gru_cell.reset_parameters()
    def binary_op(self, left: Tensor, right: Tensor) -> Tensor:
        # Symmetrize the (inherently ordered) GRU cell by averaging both
        # argument orders, making the learned binary operator commutative:
        return (self.gru_cell(left, right) + self.gru_cell(right, left)) / 2.0

    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])
    def forward(
        self,
        x: Tensor,
        index: Optional[Tensor] = None,
        ptr: Optional[Tensor] = None,
        dim_size: Optional[int] = None,
        dim: int = -2,
        max_num_elements: Optional[int] = None,
    ) -> Tensor:

        if self.project:
            x = self.lin(x).relu()

        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,
                                   max_num_elements=max_num_elements)

        x = x.permute(1, 0, 2)  # [num_neighbors, num_nodes, num_features]

        # Reduce pairs along the neighbor dimension until a single element
        # remains, forming a binary tree of logarithmic depth. A trailing
        # unpaired element is carried over to the next level unchanged.
        depth = ceil(log2(x.size(0)))
        for _ in range(depth):
            x = [
                self.binary_op(x[2 * i], x[2 * i + 1])
                if (2 * i + 1) < len(x) else x[2 * i]
                for i in range(ceil(len(x) / 2))
            ]

        assert len(x) == 1
        return x[0]

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, project={self.project})')
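
A usage sketch of the sorted-index requirement from the docstring note, on a made-up toy graph (`SimpleConv` and `sort_edge_index` are existing PyG utilities; the rest is illustrative):

import torch
from torch_geometric.nn import LCMAggregation, SimpleConv
from torch_geometric.utils import sort_edge_index

x = torch.randn(3, 16)
edge_index = torch.tensor([[0, 2, 1, 0],
                           [2, 1, 0, 1]])

# Sort edges by destination node so the aggregation receives sorted indices:
edge_index = sort_edge_index(edge_index, sort_by_row=False)

conv = SimpleConv(aggr=LCMAggregation(16, 16, project=False))
out = conv(x, edge_index)
print(out.size())  # torch.Size([3, 16])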
