HeteroConv aggregation cannot handle tuple outputs #9677
Description
🐛 Describe the bug
I would like to use HeteroConv to wrap a custom message passing layer that updates not only the node features but also the edge attributes, and therefore returns a tuple (x_updated, edge_attr_updated):
import torch
from torch_geometric.nn import HeteroConv, MessagePassing
from torch_scatter import scatter


class InteractionNetwork(MessagePassing):
    def __init__(self, ...  # elided in the report; must define self.edge_fn and self.node_fn

    def forward(self, x, edge_index, edge_attr):
        r_dim = 1  # index of the receiver (destination) entry in (x_src, x_dst)
        if isinstance(x, torch.Tensor):
            # HeteroConv passes a single tensor when source and destination types match
            x = (x, x)
        if isinstance(edge_attr, tuple):
            edge_attr = edge_attr[r_dim]
        edge_attr_updated, aggr = self.propagate(
            x=x, edge_index=edge_index, edge_attr=edge_attr
        )
        x_updated = self.node_fn(torch.cat((x[r_dim], aggr), dim=-1))
        return x[r_dim] + x_updated, edge_attr + edge_attr_updated

    def message(self, x_i, x_j, edge_attr):  # x_i: receiver, x_j: sender
        e_latent = torch.cat((x_i, x_j, edge_attr), dim=-1)
        return self.edge_fn(e_latent)

    def aggregate(
        self, inputs: torch.Tensor, index: torch.Tensor, dim_size=None
    ):
        # Sum the messages per receiver, but also return the per-edge messages,
        # so that propagate() yields (edge_attr_updated, aggregated_messages)
        out = scatter(
            inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum"
        )
        return inputs, out


layer = HeteroConv(
    {
        EdgeType1: InteractionNetwork(**params),
        EdgeType2: InteractionNetwork(**params),
    },
    aggr=None,
)
...
out = layer(data.x_dict, data.edge_index_dict, data.edge_attr_dict)
However, the group function in torch_geometric/nn/conv/hetero_conv.py raises the following error:
    def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:
        if len(xs) == 0:
            return None
        elif aggr is None:
>           return torch.stack(xs, dim=1)
E           TypeError: expected Tensor as element 0 in argument 0, but got tuple

torch_geometric/nn/conv/hetero_conv.py:18: TypeError
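The failure can be reproduced with plain PyTorch. HeteroConv collects the outputs of all relations that share a destination node type into a list, and with aggr=None the group function passes that list to torch.stack, which accepts only tensors. A minimal sketch; the shapes are arbitrary and chosen only for illustration:

```python
import torch

# Hypothetical shapes: 4 receiver nodes, 10 edges, 8 features.
x_a, x_b = torch.zeros(4, 8), torch.zeros(4, 8)
e_a, e_b = torch.zeros(10, 8), torch.zeros(10, 8)

# Plain node tensors (what a standard conv returns) stack fine along dim=1,
# giving one slot per relation: [num_nodes, num_relations, features].
stacked = torch.stack([x_a, x_b], dim=1)
print(stacked.shape)  # torch.Size([4, 2, 8])

# A layer returning (x, edge_attr) puts tuples into the list instead,
# and torch.stack rejects them with the TypeError from the traceback.
caught = None
try:
    torch.stack([(x_a, e_a), (x_b, e_b)], dim=1)
except TypeError as err:
    caught = err
print(type(caught).__name__)  # TypeError
```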
Is it possible to make this group function, or the aggregation process in general, customizable?
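Note that HeteroConv returns one entry per destination node type, so even a tuple-aware group would have no place to keep the updated edge attributes per edge type. One possible workaround is to loop over the relations manually instead of using HeteroConv. The sketch below is hypothetical (hetero_apply is not a PyG API): it always passes an (x_src, x_dst) pair, since InteractionNetwork indexes x[r_dim], reduces node outputs per destination type with a user-supplied node_aggr, and returns the edge outputs keyed by edge type. It is written against generic callables so the control flow can be checked without PyG:

```python
from typing import Any, Callable, Dict, List, Tuple

EdgeType = Tuple[str, str, str]


def hetero_apply(
    convs: Dict[EdgeType, Callable[..., Tuple[Any, Any]]],
    x_dict: Dict[str, Any],
    edge_index_dict: Dict[EdgeType, Any],
    edge_attr_dict: Dict[EdgeType, Any],
    node_aggr: Callable[[List[Any]], Any],
) -> Tuple[Dict[str, Any], Dict[EdgeType, Any]]:
    """Hypothetical HeteroConv replacement for layers returning (x, edge_attr)."""
    node_out: Dict[str, List[Any]] = {}
    edge_out: Dict[EdgeType, Any] = {}
    for edge_type, conv in convs.items():
        src, _, dst = edge_type
        if edge_type not in edge_index_dict:
            continue  # relation absent from this batch
        # Always pass (x_src, x_dst); the layer picks the receiver via x[1].
        x = (x_dict[src], x_dict[dst])
        x_new, e_new = conv(
            x, edge_index_dict[edge_type], edge_attr_dict.get(edge_type)
        )
        # Node outputs are grouped per destination type, like HeteroConv does.
        node_out.setdefault(dst, []).append(x_new)
        # Edge outputs stay per edge type, since relations cannot be merged.
        edge_out[edge_type] = e_new
    x_out = {node_type: node_aggr(xs) for node_type, xs in node_out.items()}
    return x_out, edge_out


# Toy check with numbers standing in for tensors.
convs = {
    ("a", "to", "b"): lambda x, ei, ea: (x[1] + 1, ea * 2),
    ("c", "to", "b"): lambda x, ei, ea: (x[1] + 10, ea * 3),
}
x_out, e_out = hetero_apply(
    convs,
    {"a": 1, "b": 5, "c": 7},
    {("a", "to", "b"): None, ("c", "to", "b"): None},
    {("a", "to", "b"): 1, ("c", "to", "b"): 2},
    node_aggr=sum,
)
print(x_out)  # {'b': 21}
```

With real tensors, node_aggr=lambda xs: torch.stack(xs, dim=0).sum(dim=0) sums over relations, mirroring HeteroConv's aggr="sum". In a real model the InteractionNetwork instances must also be registered as submodules, for example in a torch.nn.ModuleDict; its keys must be strings, so tuple edge types need converting (e.g. "__".join(edge_type)).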
Versions
PyTorch version: 1.12.1+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.21.1
Libc version: glibc-2.31
Python version: 3.8.10 (default, Mar 15 2022, 12:22:08) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.29
Is CUDA available: True
[pip3] torch_geometric==2.4.0