
HeteroConv aggregation cannot handle tuple outputs #9677

Open
@jongyaoY


🐛 Describe the bug

I would like to use HeteroConv to wrap my custom message passing layer. The layer updates not just the node features but also the edge attributes, so it returns a tuple (x_updated, edge_attr_updated):

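```python
import torch
from torch import Tensor
from torch_geometric.nn import HeteroConv, MessagePassing
from torch_scatter import scatter


class InteractionNetwork(MessagePassing):
    def __init__(self, ...):  # constructor elided; it defines self.edge_fn and self.node_fn
        ...

    def forward(self, x, edge_index, edge_attr):
        r_dim = 1  # index of the receiver (destination) features in the (src, dst) pair
        if isinstance(x, Tensor):
            # HeteroConv passes a single tensor when the source and destination
            # node types are the same; treat it as a (src, dst) pair.
            x = (x, x)
        if isinstance(edge_attr, tuple):
            edge_attr = edge_attr[r_dim]
        # aggregate() below returns (per-edge messages, per-node aggregation),
        # and the default update() passes that tuple through unchanged.
        edge_attr_updated, aggr = self.propagate(
            edge_index=edge_index, x=x, edge_attr=edge_attr
        )
        x_updated = self.node_fn(torch.cat((x[r_dim], aggr), dim=-1))
        # Residual updates for both nodes and edges, returned as a tuple.
        return x[r_dim] + x_updated, edge_attr + edge_attr_updated

    def message(self, x_i, x_j, edge_attr):
        # x_i: receiver (target) features, x_j: sender (source) features
        e_latent = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.edge_fn(e_latent)

    def aggregate(self, inputs: Tensor, index: Tensor, dim_size=None):
        # Sum the messages onto the receiver nodes, but also return the raw
        # messages so that forward() can use them as updated edge attributes.
        out = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )
        return inputs, out


layer = HeteroConv(
    {
        EdgeType1: InteractionNetwork(**params),
        EdgeType2: InteractionNetwork(**params),
    },
    aggr=None,
)

...

out = layer(data.x_dict, data.edge_index_dict, data.edge_attr_dict)
```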
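For context, here is a simplified, stdlib-only sketch of how HeteroConv (2.4.0) appears to pick the positional arguments for each edge type's conv. It reflects my reading of the library's behaviour, not its source. The helper `build_conv_args` and the node types `"user"`/`"item"` are hypothetical, and strings stand in for tensors. It shows that `x` reaches the wrapped layer as a `(x_src, x_dst)` pair only when the source and destination node types differ. For a relation such as `("item", "similar", "item")` it arrives as a single tensor, so `x[r_dim]` selects the destination features only in the bipartite case.

```python
from typing import Any, Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def build_conv_args(edge_type: EdgeType, *value_dicts: Dict[Any, Any]) -> List[Any]:
    """Hypothetical helper: select per-edge-type arguments the way HeteroConv
    is understood to (simplified sketch, not the library source)."""
    src, _, dst = edge_type
    args = []
    for values in value_dicts:
        if edge_type in values:
            # Edge-level input (edge_index, edge_attr): passed through as-is.
            args.append(values[edge_type])
        elif src == dst and src in values:
            # Same source and destination type: a single node-feature input.
            args.append(values[src])
        elif src in values or dst in values:
            # Bipartite relation: a (source, destination) pair.
            args.append((values.get(src), values.get(dst)))
    return args


# Strings stand in for tensors.
x_dict = {"user": "x_user", "item": "x_item"}
edge_index_dict = {
    ("user", "rates", "item"): "ei_ui",
    ("item", "similar", "item"): "ei_ii",
}
edge_attr_dict = {
    ("user", "rates", "item"): "ea_ui",
    ("item", "similar", "item"): "ea_ii",
}

bipartite = build_conv_args(
    ("user", "rates", "item"), x_dict, edge_index_dict, edge_attr_dict
)
homogeneous = build_conv_args(
    ("item", "similar", "item"), x_dict, edge_index_dict, edge_attr_dict
)
print(bipartite)    # [('x_user', 'x_item'), 'ei_ui', 'ea_ui']
print(homogeneous)  # ['x_item', 'ei_ii', 'ea_ii']
```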

But HeteroConv's group function raises the following error:

```
    def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
        if len(xs) == 0:
            return None
        elif aggr is None:
>           return torch.stack(xs, dim=1)
E           TypeError: expected Tensor as element 0 in argument 0, but got tuple

torch_geometric/nn/conv/hetero_conv.py:18: TypeError
```
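The traceback suggests that HeteroConv gathers each conv's return value into a list keyed by the destination node type and then calls group on every list. The stdlib-only sketch below illustrates that collection step. The helper `collect_by_dst` is hypothetical, the edge and node type names are made up, and strings stand in for tensors. It shows that a tuple-returning layer puts tuples into those lists. It also shows that the edge part of each tuple belongs to a specific edge type, not to the destination node type.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def collect_by_dst(outputs: Dict[EdgeType, Any]) -> Dict[str, List[Any]]:
    """Hypothetical helper: gather per-edge-type outputs by destination node
    type, a simplified model of what happens before group() is called."""
    out_dict: Dict[str, List[Any]] = defaultdict(list)
    for (_, _, dst), out in outputs.items():
        out_dict[dst].append(out)
    return dict(out_dict)


outputs = {
    ("user", "rates", "item"): ("x_item_from_rates", "edge_attr_rates"),
    ("item", "similar", "item"): ("x_item_from_similar", "edge_attr_similar"),
}
grouped = collect_by_dst(outputs)
print(grouped["item"])
# [('x_item_from_rates', 'edge_attr_rates'), ('x_item_from_similar', 'edge_attr_similar')]
```

Calling torch.stack on such a list fails as in the traceback. Even if group unpacked the tuples, the edge parts generally could not be stacked, because each relation has its own number of edges.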

Would it be possible to make this group function, or the aggregation step as a whole, customizable?
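One possible workaround, sketched under assumptions and not an official PyG API, is to bypass HeteroConv and loop over the edge types yourself. Each output's node part is routed to its destination type and grouped, and its edge part is kept keyed by its edge type. The names `hetero_forward_with_edges`, `group_nodes` and `ToyConv` are hypothetical. `group_nodes` re-implements only the `None`, `"sum"`, `"mean"`, `"max"` and `"cat"` cases, loosely following PyG's group. To keep the sketch short, node features are always passed as an `(x_src, x_dst)` pair, which matches how the layer above indexes `x`.

```python
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn

EdgeType = Tuple[str, str, str]


def group_nodes(xs: List[Tensor], aggr: Optional[str]) -> Tensor:
    """Hypothetical helper: combine node outputs of relations that share a
    destination type (modelled loosely on PyG's group())."""
    if aggr is None:
        return torch.stack(xs, dim=1)  # [num_nodes, num_relations, F]
    if aggr == "cat":
        return torch.cat(xs, dim=-1)
    stacked = torch.stack(xs, dim=0)
    if aggr == "sum":
        return stacked.sum(dim=0)
    if aggr == "mean":
        return stacked.mean(dim=0)
    if aggr == "max":
        return stacked.max(dim=0).values
    raise ValueError(f"unsupported aggr: {aggr!r}")


def hetero_forward_with_edges(
    convs: Dict[EdgeType, nn.Module],
    x_dict: Dict[str, Tensor],
    edge_index_dict: Dict[EdgeType, Tensor],
    edge_attr_dict: Dict[EdgeType, Tensor],
    aggr: Optional[str] = "sum",
) -> Tuple[Dict[str, Tensor], Dict[EdgeType, Tensor]]:
    """Run one conv per edge type; each conv returns (x_dst_new, edge_attr_new)."""
    node_outs: Dict[str, List[Tensor]] = {}
    edge_outs: Dict[EdgeType, Tensor] = {}
    for edge_type, conv in convs.items():
        src, _, dst = edge_type
        x = (x_dict[src], x_dict[dst])  # always a (source, destination) pair here
        x_new, edge_attr_new = conv(x, edge_index_dict[edge_type], edge_attr_dict[edge_type])
        node_outs.setdefault(dst, []).append(x_new)  # grouped per node type
        edge_outs[edge_type] = edge_attr_new  # kept per edge type, never grouped
    return {k: group_nodes(v, aggr) for k, v in node_outs.items()}, edge_outs


class ToyConv(nn.Module):
    """Stand-in for InteractionNetwork: same call signature, tuple output."""

    def forward(self, x, edge_index, edge_attr):
        x_src, x_dst = x
        msg = x_src[edge_index[0]] + edge_attr  # one message per edge
        aggr = torch.zeros_like(x_dst).index_add_(0, edge_index[1], msg)
        return x_dst + aggr, edge_attr + msg


x_dict = {"user": torch.ones(2, 4), "item": torch.zeros(3, 4)}
edge_index_dict = {
    ("user", "rates", "item"): torch.tensor([[0, 1], [0, 2]]),
    ("item", "similar", "item"): torch.tensor([[0, 1, 2], [1, 2, 0]]),
}
edge_attr_dict = {et: torch.ones(ei.size(1), 4) for et, ei in edge_index_dict.items()}
convs = {et: ToyConv() for et in edge_index_dict}

x_out, edge_attr_out = hetero_forward_with_edges(
    convs, x_dict, edge_index_dict, edge_attr_dict, aggr=None
)
print(x_out["item"].shape)                               # torch.Size([3, 2, 4])
print(edge_attr_out[("item", "similar", "item")].shape)  # torch.Size([3, 4])
```

In a real model the convs would need to be registered as submodules (for example through an nn.ModuleDict, which requires string keys) so that their parameters are trained. The plain dict here only keeps the sketch short.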

Versions

PyTorch version: 1.12.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.21.1
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 15 2022, 12:22:08)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.29
Is CUDA available: True

[pip3] torch_geometric==2.4.0
