HeteroConv aggregation cannot handle tuple outputs #9677

Open
@jongyaoY

Description

🐛 Describe the bug

I would like to use HeteroConv to wrap my custom message passing layer, which updates both the node features and the edge attributes and returns them as a tuple (x_updated, edge_attr_updated):

import torch
from torch_geometric.nn import HeteroConv, MessagePassing
from torch_scatter import scatter

class InteractionNetwork(MessagePassing):
    def __init__(self, ...

    def forward(self, x, edge_index, edge_attr):
        r_dim = 1
        if isinstance(edge_attr, tuple):
            edge_attr = edge_attr[r_dim]
        edge_attr_updated, aggr = self.propagate(
            x=x, edge_index=edge_index, edge_attr=edge_attr
        )
        x_updated = self.node_fn(torch.cat((x[r_dim], aggr), dim=-1))
        return x[r_dim] + x_updated, edge_attr + edge_attr_updated

    def message(self, x_i, x_j, edge_attr):  # x_i: receiver, x_j: sender
        e_latent = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.edge_fn(e_latent)

    def aggregate(
        self, inputs: torch.Tensor, index: torch.Tensor, dim_size=None
    ):
        out = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )
        return inputs, out

layer = HeteroConv(
    {
        EdgeType1: InteractionNetwork(**params),
        EdgeType2: InteractionNetwork(**params),
    },
    aggr=None,
)

...

out = layer(data.x_dict, data.edge_index_dict, data.edge_attr_dict)

However, the group function raises the following error:

    def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
        if len(xs) == 0:
            return None
        elif aggr is None:
>           return torch.stack(xs, dim=1)
E           TypeError: expected Tensor as element 0 in argument 0, but got tuple

torch_geometric/nn/conv/hetero_conv.py:18: TypeError

Would it be possible to make the group function, or even the whole aggregation process, customizable?
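To make the request more concrete, here is a minimal sketch of the kind of tuple-aware grouping I have in mind. It is not part of torch_geometric: the name group_tuple_outputs and the reduce_fn parameter are hypothetical, and it assumes each wrapped layer returns exactly a 2-tuple (x_updated, edge_attr_updated). Only the node part is grouped per destination node type; the edge part is kept per edge type, since edge attributes of different relations should not be stacked together.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple

# (src, relation, dst), following the PyG edge type convention.
EdgeType = Tuple[str, str, str]


def group_tuple_outputs(
    outs: Dict[EdgeType, Tuple[Any, Any]],
    reduce_fn: Callable[[List[Any]], Any],
) -> Tuple[Dict[str, Any], Dict[EdgeType, Any]]:
    """Hypothetical helper: split (x, edge_attr) outputs per edge type.

    Node outputs are collected per destination node type and combined
    with ``reduce_fn``; edge outputs are passed through per edge type.
    """
    x_by_dst: Dict[str, List[Any]] = defaultdict(list)
    edge_attr_dict: Dict[EdgeType, Any] = {}
    for edge_type, (x_out, edge_attr_out) in outs.items():
        dst = edge_type[-1]  # last entry of the edge type is the dst node type
        x_by_dst[dst].append(x_out)
        edge_attr_dict[edge_type] = edge_attr_out
    x_dict = {dst: reduce_fn(xs) for dst, xs in x_by_dst.items()}
    return x_dict, edge_attr_dict
```

With tensors, passing reduce_fn=lambda xs: torch.stack(xs, dim=1) would presumably mirror what group does for aggr=None, but I have not checked this against every aggregation mode.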

Versions

PyTorch version: 1.12.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.21.1
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 15 2022, 12:22:08)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.29
Is CUDA available: True

[pip3] torch_geometric==2.4.0
