diff --git a/test/nn/test_to_hetero_module.py b/test/nn/test_to_hetero_module.py
index 9838236b6344..bc65d4d729c8 100644
--- a/test/nn/test_to_hetero_module.py
+++ b/test/nn/test_to_hetero_module.py
@@ -1,72 +1,59 @@
 import pytest  # noqa
 import torch

-from torch_geometric.data import HeteroData
-from torch_geometric.nn.conv import GCNConv
+from torch_geometric.nn.conv import SAGEConv
 from torch_geometric.nn.dense import Linear
-from torch_geometric.nn.to_hetero_module import ToHeteroModule
+from torch_geometric.nn.to_hetero_module import (
+    ToHeteroLinear,
+    ToHeteroMessagePassing,
+)

-heterodata = HeteroData()
-heterodata['v0'].x = torch.randn(20, 10)
-heterodata['v1'].x = torch.randn(20, 10)
-heterodata[('v0', 'r0',
-            'v0')].edge_index = torch.randint(high=20,
-                                              size=(2, 30)).to(torch.long)
-heterodata[('v0', 'r2',
-            'v1')].edge_index = torch.randint(high=20,
-                                              size=(2, 30)).to(torch.long)
-heterodata[('v1', 'r3',
-            'v0')].edge_index = torch.randint(high=20,
-                                              size=(2, 30)).to(torch.long)
-heterodata[('v1', 'r4', 'v1')] = torch.randn(2, 50)

+@pytest.mark.parametrize('LinearCls', [torch.nn.Linear, Linear])
+def test_to_hetero_linear(LinearCls):
+    x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)}
+    x = torch.cat([x_dict['1'], x_dict['2']], dim=0)
+    type_vec = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])

-def test_to_hetero_linear():
-    lin = Linear(10, 5)
-    heterolin = ToHeteroModule(lin, heterodata.metadata())
-    # test dict input
-    x_dict = heterodata.collect('x')
-    out = heterolin(x_dict)
-    assert out['v0'].shape == (20, 5)
-    assert out['v1'].shape == (20, 5)
-    # test fused input
-    x = torch.cat([x_j for x_j in x_dict.values()])
-    node_type = torch.cat([(j * torch.ones(x_j.shape[0])).long()
-                           for j, x_j in enumerate(x_dict.values())])
-    out = heterolin(x=x, node_type=node_type)
-    assert out.shape == (40, 5)
+    module = ToHeteroLinear(LinearCls(16, 32), list(x_dict.keys()))

+    out_dict = module(x_dict)
+    assert len(out_dict) == 2
+    assert out_dict['1'].size() == (5, 32)
+    assert out_dict['2'].size() == (4, 32)

-def test_to_hetero_gcn():
-    gcnconv = GCNConv(10, 5)
-    rgcnconv = ToHeteroModule(gcnconv, heterodata.metadata())
-    # test dict input
-    x_dict = heterodata.collect('x')
-    e_idx_dict = heterodata.collect('edge_index')
-    out = rgcnconv(x_dict, edge_index=e_idx_dict)
-    assert out['v0'].shape == (20, 5)
-    assert out['v1'].shape == (20, 5)
+    out = module(x, type_vec)
+    assert out.size() == (9, 32)

-    x = torch.cat(list(x_dict.values()), dim=0)
+    assert torch.allclose(out_dict['1'], out[0:5])
+    assert torch.allclose(out_dict['2'], out[5:9])

-    num_node_dict = heterodata.collect('num_nodes')
-    increment_dict = {}
-    ctr = 0
-    for node_type in num_node_dict:
-        increment_dict[node_type] = ctr
-        ctr += num_node_dict[node_type]
-    etypes_list = []
-    for i, e_type in enumerate(e_idx_dict.keys()):
-        src_type, dst_type = e_type[0], e_type[-1]
-        if torch.numel(e_idx_dict[e_type]) != 0:
-            e_idx_dict[e_type][
-                0, :] = e_idx_dict[e_type][0, :] + increment_dict[src_type]
-            e_idx_dict[e_type][
-                1, :] = e_idx_dict[e_type][1, :] + increment_dict[dst_type]
-        etypes_list.append(torch.ones(e_idx_dict[e_type].shape[-1]) * i)
-    edge_type = torch.cat(etypes_list).to(torch.long)
-    edge_index = torch.cat(list(e_idx_dict.values()), dim=1)
-    # test fused input
-    out = rgcnconv(x, edge_index=edge_index, edge_type=edge_type)
-    assert out.shape == (40, 5)
+
+def test_to_hetero_message_passing():
+    x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)}
+    x = torch.cat([x_dict['1'], x_dict['2']], dim=0)
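+    # Fused layout: `x` stacks all node features sorted by type, and
+    # `node_type` assigns each row its type index (five nodes of type '1'
+    # followed by four nodes of type '2').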
+    node_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])
+
+    edge_index_dict = {
+        ('1', 'to', '2'): torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 3]]),
+        ('2', 'to', '1'): torch.tensor([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]]),
+    }
+    edge_index = torch.tensor([
+        [0, 1, 2, 3, 4, 5, 5, 6, 7, 8],
+        [5, 5, 6, 7, 8, 0, 1, 2, 3, 4],
+    ])
+    edge_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
+
+    module = ToHeteroMessagePassing(SAGEConv(16, 32), list(x_dict.keys()),
+                                    list(edge_index_dict.keys()))
+
+    out_dict = module(x_dict, edge_index_dict)
+    assert len(out_dict) == 2
+    assert out_dict['1'].size() == (5, 32)
+    assert out_dict['2'].size() == (4, 32)
+
+    out = module(x, edge_index, node_type, edge_type)
+    assert out.size() == (9, 32)
+
+    assert torch.allclose(out_dict['1'], out[0:5])
+    assert torch.allclose(out_dict['2'], out[5:9])
diff --git a/torch_geometric/nn/to_hetero_module.py b/torch_geometric/nn/to_hetero_module.py
index 88c70676189a..b7ac18db5095 100644
--- a/torch_geometric/nn/to_hetero_module.py
+++ b/torch_geometric/nn/to_hetero_module.py
@@ -1,199 +1,176 @@
 import copy
 import warnings
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union

 import torch
 from torch import Tensor

 import torch_geometric
-from torch_geometric.typing import EdgeType, Metadata, NodeType, OptTensor
-from torch_geometric.utils.hetero import get_unused_node_types
+from torch_geometric.typing import EdgeType, NodeType, OptTensor
+from torch_geometric.utils import scatter


-class ToHeteroModule(torch.nn.Module):
-    aggrs = {
-        'sum': torch.add,
-        # For 'mean' aggregation, we first sum up all feature matrices, and
-        # divide by the number of matrices in a later step.
-        'mean': torch.add,
-        'max': torch.max,
-        'min': torch.min,
-        'mul': torch.mul,
-    }
+class ToHeteroLinear(torch.nn.Module):
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        types: Union[List[NodeType], List[EdgeType]],
+    ):
+        from torch_geometric.nn import HeteroLinear, Linear
+
+        super().__init__()
+
+        self.types = types
+
+        if isinstance(module, Linear):
+            in_channels = module.in_channels
+            out_channels = module.out_channels
+            bias = module.bias is not None
+
+        elif isinstance(module, torch.nn.Linear):
+            in_channels = module.in_features
+            out_channels = module.out_features
+            bias = module.bias is not None
+
+        else:
+            raise ValueError(f"Expected 'Linear' module "
+                             f"(got '{type(module)}')")
+
+        # TODO: Need to handle `in_channels=-1` case.
+        # TODO: We currently assume that `x` is sorted according to `type`.
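+        # `HeteroLinear` holds one weight matrix per type; with `pyg_lib`
+        # available it can apply all of them in a single grouped matmul,
+        # otherwise it falls back to per-type `Linear` layers (see
+        # `dict_forward` below).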
+        self.hetero_module = HeteroLinear(
+            in_channels,
+            out_channels,
+            num_types=len(types),
+            is_sorted=True,
+            bias=bias,
+        )
+
+    def fused_forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
+        return self.hetero_module(x, type_vec)
+
+    def dict_forward(
+        self,
+        x_dict: Dict[Union[NodeType, EdgeType], Tensor],
+    ) -> Dict[Union[NodeType, EdgeType], Tensor]:
+
+        if not torch_geometric.typing.WITH_PYG_LIB:
+            return {
+                key: self.hetero_module.lins[i](x_dict[key])
+                for i, key in enumerate(self.types)
+            }
+
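+        # Fuse: concatenate all types into one tensor, apply the typed
+        # transformation in a single pass, and split the output back into
+        # a dictionary.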
+        x = torch.cat([x_dict[key] for key in self.types], dim=0)
+        sizes = [x_dict[key].size(0) for key in self.types]
+        type_vec = torch.arange(len(self.types), device=x.device)
+        size = torch.tensor(sizes, device=x.device)
+        type_vec = type_vec.repeat_interleave(size)
+        outs = self.hetero_module(x, type_vec).split(sizes)
+        return {key: out for key, out in zip(self.types, outs)}
+
+    def forward(
+        self,
+        x: Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]],
+        type_vec: Optional[Tensor] = None,
+    ) -> Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]]:
+
+        if isinstance(x, dict):
+            return self.dict_forward(x)
+
+        elif isinstance(x, Tensor) and type_vec is not None:
+            return self.fused_forward(x, type_vec)
+
+        raise ValueError(f"Encountered invalid forward types in "
+                         f"'{self.__class__.__name__}'")
+
+
+class ToHeteroMessagePassing(torch.nn.Module):
     def __init__(
         self,
         module: torch.nn.Module,
-        metadata: Metadata,
+        node_types: List[NodeType],
+        edge_types: List[EdgeType],
         aggr: str = 'sum',
     ):
+        from torch_geometric.nn import HeteroConv, MessagePassing
+
         super().__init__()
-        self.metadata = metadata
-        self.node_types = metadata[0]
-        self.edge_types = metadata[1]
-        self.aggr = aggr
-        assert len(metadata) == 2
-        assert aggr in self.aggrs.keys()
-        # check wether module is linear
-        self.is_lin = isinstance(module, torch.nn.Linear) or isinstance(
-            module, torch_geometric.nn.dense.Linear)
-        # check metadata[0] has node types
-        # check metadata[1] has edge types if module is MessagePassing
-        assert len(metadata[0]) > 0 and (len(metadata[1]) > 0
-                                         or not self.is_lin)
-        if self.is_lin:
-            # make HeteroLinear layer based on metadata
-            if isinstance(module, torch.nn.Linear):
-                in_ft = module.in_features
-                out_ft = module.out_features
-            else:
-                in_ft = module.in_channels
-                out_ft = module.out_channels
-            heteromodule = torch_geometric.nn.dense.HeteroLinear(
-                in_ft, out_ft,
-                len(self.node_types)).to(list(module.parameters())[0].device)
-            heteromodule.reset_parameters()
-        else:
-            # copy MessagePassing module for each edge type
-            unused_node_types = get_unused_node_types(*metadata)
-            if len(unused_node_types) > 0:
-                warnings.warn(
-                    f"There exist node types ({unused_node_types}) whose "
-                    f"representations do not get updated during message "
-                    f"passing as they do not occur as destination type in any "
-                    f"edge type. This may lead to unexpected behaviour.")
-            heteromodule = {}
-            for edge_type in self.edge_types:
-                heteromodule[edge_type] = copy.deepcopy(module)
-                if hasattr(module, 'reset_parameters'):
-                    module.reset_parameters()
-                elif sum([p.numel() for p in module.parameters()]) > 0:
-                    warnings.warn(
-                        f"'{module}' will be duplicated, but its parameters"
-                        f"cannot be reset. To suppress this warning, add a"
-                        f"'reset_parameters()' method to '{module}'")
-
-        self.heteromodule = heteromodule
-
-    def fused_forward(self, x: Tensor, edge_index: OptTensor = None,
-                      node_type: OptTensor = None,
-                      edge_type: OptTensor = None) -> Tensor:
-        r"""
-        Args:
-            x: The input node features. :obj:`[num_nodes, in_channels]`
-                node feature matrix.
-            edge_index (LongTensor): The edge indices.
-            node_type: The one-dimensional node type/index for each node in
-                :obj:`x`.
-            edge_type: The one-dimensional edge type/index for each edge in
-                :obj:`edge_index`.
-        """
-        # (TODO) Add Sparse Tensor support
-        if self.is_lin:
-            # call HeteroLinear layer
-            out = self.heteromodule(x, node_type)
-        else:
-            # iterate over each edge type
-            for j, module in enumerate(self.heteromodule.values()):
-                e_idx_type_j = edge_index[:, edge_type == j]
-                o_j = module(x, e_idx_type_j)
-                if j == 0:
-                    out = o_j
-                else:
-                    out += o_j
-        return out
+
+        self.node_types = node_types
+        self.node_type_to_index = {key: i for i, key in enumerate(node_types)}
+        self.edge_types = edge_types
+
+        if not isinstance(module, MessagePassing):
+            raise ValueError(f"Expected 'MessagePassing' module "
+                             f"(got '{type(module)}')")
+
+        if (not hasattr(module, 'reset_parameters')
+                and sum([p.numel() for p in module.parameters()]) > 0):
+            warnings.warn(f"'{module}' will be duplicated, but its parameters "
+                          f"cannot be reset. To suppress this warning, add a "
+                          f"'reset_parameters()' method to '{module}'")
+
+        convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}
+        self.hetero_module = HeteroConv(convs, aggr)
+        self.hetero_module.reset_parameters()
+
+    def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor,
+                      edge_type: Tensor) -> Tensor:
+        # TODO: This currently does not fuse at all :(
+        # TODO: We currently assume that `x` and `edge_index` are both sorted
+        # according to `type`.
+
+        node_sizes = scatter(torch.ones_like(node_type), node_type, dim=0,
+                             dim_size=len(self.node_types), reduce='sum')
+        edge_sizes = scatter(torch.ones_like(edge_type), edge_type, dim=0,
+                             dim_size=len(self.edge_types), reduce='sum')
+
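+        # Exclusive prefix sum over the per-type node counts: `cumsum[i]`
+        # is the row at which nodes of type `i` start in the fused `x`.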
- """ - # (TODO) Add Sparse Tensor support - if self.is_lin: - # fuse inputs - x = torch.cat([x_j for x_j in x_dict.values()]) - size_list = [feat.shape[0] for feat in x_dict.values()] - sizes = torch.tensor(size_list, dtype=torch.long, device=x.device) - node_type = torch.arange(len(sizes), device=x.device) - node_type = node_type.repeat_interleave(sizes) - # HeteroLinear layer - o = self.heteromodule(x, node_type) - o_dict = { - key: o_i.squeeze() - for key, o_i in zip(x_dict.keys(), o.split(size_list)) - } - else: - o_dict = {} - # iterate over each edge_type - for j, (etype_j, module) in enumerate(self.heteromodule.items()): - e_idx_type_j = edge_index_dict[etype_j] - src_node_type_j = etype_j[0] - dst_node_type_j = etype_j[-1] - o_j = module(x_dict[src_node_type_j], e_idx_type_j) - if dst_node_type_j not in o_dict.keys(): - o_dict[dst_node_type_j] = o_j - else: - o_dict[dst_node_type_j] += o_j - return o_dict + return self.hetero_module(x_dict, edge_index_dict, **kwargs) def forward( self, - x: Union[Dict[NodeType, Tensor], Tensor], - edge_index: Optional[Union[Dict[EdgeType, Tensor], Tensor]] = None, + x: Union[Tensor, Dict[NodeType, Tensor]], + edge_index: Union[Tensor, Dict[EdgeType, Tensor]], node_type: OptTensor = None, edge_type: OptTensor = None, - ) -> Union[Dict[NodeType, Tensor], Tensor]: - r""" - Args: - x (Dict[str, Tensor] or Tensor): A dictionary holding node feature - information for each individual node type or the same - features combined into one tensor. - edge_index (Dict[Tuple[str, str, str], Tensor] or Tensor): - A dictionary holding graph connectivity information for - each individual edge type or the same values combined - into one tensor. - node_type: The one-dimensional relation type/index for each node in - :obj:`x` if it is provided as a single tensor. - Should be only :obj:`None` in case :obj:`x` is of type - Dict[str, Tensor]. - (default: :obj:`None`) - edge_type: The one-dimensional relation type/index for each edge in - :obj:`edge_index` if it is provided as a single tensor. - Should be only :obj:`None` in case :obj:`edge_index` is of type - Dict[Tuple[str, str, str], Tensor]. 
+        edge_indices = edge_index.clone().split(edge_sizes.tolist(), dim=1)
+        for (src, _, dst), index in zip(self.edge_types, edge_indices):
+            index[0] -= cumsum[self.node_type_to_index[src]]
+            index[1] -= cumsum[self.node_type_to_index[dst]]
+
+        edge_index_dict = {
+            edge_type: edge_index
+            for edge_type, edge_index in zip(self.edge_types, edge_indices)
+        }
+
+        out_dict = self.hetero_module(x_dict, edge_index_dict)
+        return torch.cat([out_dict[key] for key in self.node_types], dim=0)

     def dict_forward(
         self,
         x_dict: Dict[NodeType, Tensor],
-        edge_index_dict: Optional[Dict[EdgeType, Tensor]] = None,
+        edge_index_dict: Dict[EdgeType, Tensor],
+        **kwargs,
     ) -> Dict[NodeType, Tensor]:
-        r"""
-        Args:
-            x_dict (Dict[str, Tensor]): A dictionary holding node feature
-                information for each individual node type.
-            edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A dictionary
-                holding graph connectivity information for each individual
-                edge type.
-        """
-        # (TODO) Add Sparse Tensor support
-        if self.is_lin:
-            # fuse inputs
-            x = torch.cat([x_j for x_j in x_dict.values()])
-            size_list = [feat.shape[0] for feat in x_dict.values()]
-            sizes = torch.tensor(size_list, dtype=torch.long, device=x.device)
-            node_type = torch.arange(len(sizes), device=x.device)
-            node_type = node_type.repeat_interleave(sizes)
-            # HeteroLinear layer
-            o = self.heteromodule(x, node_type)
-            o_dict = {
-                key: o_i.squeeze()
-                for key, o_i in zip(x_dict.keys(), o.split(size_list))
-            }
-        else:
-            o_dict = {}
-            # iterate over each edge_type
-            for j, (etype_j, module) in enumerate(self.heteromodule.items()):
-                e_idx_type_j = edge_index_dict[etype_j]
-                src_node_type_j = etype_j[0]
-                dst_node_type_j = etype_j[-1]
-                o_j = module(x_dict[src_node_type_j], e_idx_type_j)
-                if dst_node_type_j not in o_dict.keys():
-                    o_dict[dst_node_type_j] = o_j
-                else:
-                    o_dict[dst_node_type_j] += o_j
-        return o_dict
+        return self.hetero_module(x_dict, edge_index_dict, **kwargs)

     def forward(
         self,
-        x: Union[Dict[NodeType, Tensor], Tensor],
-        edge_index: Optional[Union[Dict[EdgeType, Tensor], Tensor]] = None,
+        x: Union[Tensor, Dict[NodeType, Tensor]],
+        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
         node_type: OptTensor = None,
         edge_type: OptTensor = None,
-    ) -> Union[Dict[NodeType, Tensor], Tensor]:
-        r"""
-        Args:
-            x (Dict[str, Tensor] or Tensor): A dictionary holding node feature
-                information for each individual node type or the same
-                features combined into one tensor.
-            edge_index (Dict[Tuple[str, str, str], Tensor] or Tensor):
-                A dictionary holding graph connectivity information for
-                each individual edge type or the same values combined
-                into one tensor.
-            node_type: The one-dimensional relation type/index for each node in
-                :obj:`x` if it is provided as a single tensor.
-                Should be only :obj:`None` in case :obj:`x` is of type
-                Dict[str, Tensor].
-                (default: :obj:`None`)
-            edge_type: The one-dimensional relation type/index for each edge in
-                :obj:`edge_index` if it is provided as a single tensor.
-                Should be only :obj:`None` in case :obj:`edge_index` is of type
-                Dict[Tuple[str, str, str], Tensor].
-                (default: :obj:`None`)
-        """
-        # check if x is passed as a dict or fused
-        if isinstance(x, dict):
-            # check what inputs to pass
-            if self.is_lin:
-                return self.dict_forward(x)
-            else:
-                if not isinstance(edge_index, dict):
-                    raise TypeError("If x is provided as a dictionary, \
-                        edge_index must be as well")
-                return self.dict_forward(x, edge_index_dict=edge_index)
-        else:
-            if self.is_lin:
-                if node_type is None:
-                    raise ValueError('If x is a single tensor, \
-                        node_type argument must be provided.')
-                return self.fused_forward(x, node_type=node_type)
-            else:
-                if not isinstance(edge_index, Tensor):
-                    raise TypeError("If x is provided as a Tensor, \
-                        edge_index must be as well")
-                if edge_type is None:
-                    raise ValueError(
-                        'If x and edge_indices are single tensors, \
-                        node_type and edge_type arguments must be provided.')
-                return self.fused_forward(x, edge_index=edge_index,
-                                          edge_type=edge_type)
+        **kwargs,
+    ) -> Union[Tensor, Dict[NodeType, Tensor]]:
+
+        if isinstance(x, dict) and isinstance(edge_index, dict):
+            return self.dict_forward(x, edge_index, **kwargs)
+
+        elif (isinstance(x, Tensor) and isinstance(edge_index, Tensor)
+              and node_type is not None and edge_type is not None):
+
+            if len(kwargs) > 0:
+                raise ValueError("Additional forward arguments not yet "
+                                 "supported in fused mode")
+
+            return self.fused_forward(x, edge_index, node_type, edge_type)
+
+        raise ValueError(f"Encountered invalid forward types in "
+                         f"'{self.__class__.__name__}'")