BasicGNN.jittable() support #7865

Merged
merged 4 commits into from Aug 10, 2023
Changes from 1 commit
update
rusty1s committed Aug 10, 2023
commit 896d163ebeeb3df2c621a49d7d807213989b8f70
10 changes: 10 additions & 0 deletions test/nn/models/test_basic_gnn.py
@@ -138,6 +138,16 @@ def test_edge_cnn(out_dim, dropout, act, norm, jk):
assert model(x, edge_index).size() == (3, out_channels)


def test_jittable():
x = torch.randn(3, 8)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

model = GCN(8, 16, num_layers=2).jittable()
model = torch.jit.script(model)

assert model(x, edge_index).size() == (3, 16)


@pytest.mark.parametrize('out_dim', out_dims)
@pytest.mark.parametrize('jk', jks)
def test_one_layer_gnn(out_dim, jk):
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
@@ -768,7 +768,8 @@ def register_edge_update_forward_hook(self,
@torch.jit.unused
def jittable(self, typing: Optional[str] = None) -> 'MessagePassing':
r"""Analyzes the :class:`MessagePassing` instance and produces a new
jittable module.
jittable module that can be used in combination with
:meth:`torch.jit.script`.

Args:
typing (str, optional): If given, will generate a concrete instance
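For reference, a minimal usage sketch of the existing MessagePassing.jittable() workflow that this docstring now points to (the GCNConv layer and shapes below are illustrative, not part of this diff):

    import torch
    from torch_geometric.nn import GCNConv

    conv = GCNConv(16, 32).jittable()  # produce a TorchScript-compatible clone
    conv = torch.jit.script(conv)      # scripting now succeeds

    x = torch.randn(3, 16)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    out = conv(x, edge_index)          # shape: [3, 32]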
209 changes: 159 additions & 50 deletions torch_geometric/nn/models/basic_gnn.py
@@ -1,5 +1,5 @@
import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -23,7 +23,7 @@
activation_resolver,
normalization_resolver,
)
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils.trim_to_layer import TrimToLayer


@@ -61,6 +61,9 @@ class BasicGNN(torch.nn.Module):
**kwargs (optional): Additional arguments of the underlying
:class:`torch_geometric.nn.conv.MessagePassing` layers.
"""
supports_edge_weight: Final[bool]
supports_edge_attr: Final[bool]

def __init__(
self,
in_channels: int,
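Side note on the new Final annotations (my gloss, not part of the diff): TorchScript treats attributes annotated with typing.Final as compile-time constants, so the per-class branches on supports_edge_weight / supports_edge_attr inside forward can be resolved while scripting rather than at run time. A minimal sketch of that mechanism, using a made-up Toy module:

    from typing import Final

    import torch
    from torch import Tensor

    class Toy(torch.nn.Module):
        use_weight: Final[bool] = True  # read as a constant by torch.jit.script

        def forward(self, x: Tensor) -> Tensor:
            if self.use_weight:  # branch is constant-folded during scripting
                return 2 * x
            return x

    scripted = torch.jit.script(Toy())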
@@ -117,18 +120,21 @@ def __init__(
self.convs.append(
self.init_conv(in_channels, hidden_channels, **kwargs))

self.norms = None
if norm is not None:
norm_layer = normalization_resolver(
norm,
hidden_channels,
**(norm_kwargs or {}),
)
self.norms = ModuleList()
for _ in range(num_layers - 1):
self.norms.append(copy.deepcopy(norm_layer))
if jk is not None:
self.norms.append(copy.deepcopy(norm_layer))
self.norms = ModuleList()
norm_layer = normalization_resolver(
norm,
hidden_channels,
**(norm_kwargs or {}),
)
if norm_layer is None:
norm_layer = torch.nn.Identity()
for _ in range(num_layers - 1):
self.norms.append(copy.deepcopy(norm_layer))

if jk is not None:
self.norms.append(copy.deepcopy(norm_layer))
else:
self.norms.append(torch.nn.Identity())

if jk is not None and jk != 'last':
self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)
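Another gloss on the __init__ refactor above (not part of the diff): TorchScript cannot, as far as I know, index a ModuleList with a runtime integer or handle an Optional ModuleList, so keeping self.norms as a ModuleList that always matches len(self.convs), padded with torch.nn.Identity() where no normalization is wanted, is what lets the scripted forward iterate zip(self.convs, self.norms) directly. A minimal sketch of that iteration pattern, with a hypothetical Stack module:

    import torch
    from torch import Tensor

    class Stack(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.layers = torch.nn.ModuleList([
                torch.nn.Linear(4, 4),
                torch.nn.Identity(),  # placeholder keeps the list length fixed
            ])

        def forward(self, x: Tensor) -> Tensor:
            for layer in self.layers:  # iteration (not self.layers[i]) is scriptable
                x = layer(x)
            return x

    scripted = torch.jit.script(Stack())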
@@ -152,18 +158,42 @@ def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
for conv in self.convs:
conv.reset_parameters()
for norm in self.norms or []:
norm.reset_parameters()
for norm in self.norms:
if hasattr(norm, 'reset_parameters'):
norm.reset_parameters()
if hasattr(self, 'jk'):
self.jk.reset_parameters()
if hasattr(self, 'lin'):
self.lin.reset_parameters()

def forward(
@torch.jit._overload_method
def forward( # noqa
x,
edge_index,
edge_weight=None,
edge_attr=None,
num_sampled_nodes_per_hop=None,
num_sampled_edges_per_hop=None,
):
# type: (Tensor, Tensor, OptTensor, OptTensor, Optional[List[int]], Optional[List[int]]) -> Tensor # noqa
pass

@torch.jit._overload_method
def forward( # noqa
x,
edge_index,
edge_weight=None,
edge_attr=None,
num_sampled_nodes_per_hop=None,
num_sampled_edges_per_hop=None,
):
# type: (Tensor, SparseTensor, OptTensor, OptTensor, Optional[List[int]], Optional[List[int]]) -> Tensor # noqa
pass

def forward( # noqa
self,
x: Tensor,
edge_index: Adj,
*,
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
@@ -172,7 +202,7 @@ def forward(
r"""
Args:
x (torch.Tensor): The input node features.
edge_index (torch.Tensor): The edge indices.
edge_index (torch.Tensor or SparseTensor): The edge indices.
edge_weight (torch.Tensor, optional): The edge weights (if
supported by the underlying GNN layer). (default: :obj:`None`)
edge_attr (torch.Tensor, optional): The edge features (if supported
@@ -196,8 +226,10 @@
"'edge_weight' and 'edge_attr'")

xs: List[Tensor] = []
for i in range(self.num_layers):
if num_sampled_nodes_per_hop is not None:
assert len(self.convs) == len(self.norms)
for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
if (num_sampled_nodes_per_hop is not None
and not torch.jit.is_scripting()):
x, edge_index, value = self._trim(
i,
num_sampled_nodes_per_hop,
@@ -215,28 +247,28 @@
# As such, we rely on a static solution to pass optional edge
# weights and edge attributes to the module.
if self.supports_edge_weight and self.supports_edge_attr:
x = self.convs[i](x, edge_index, edge_weight=edge_weight,
edge_attr=edge_attr)
x = conv(x, edge_index, edge_weight=edge_weight,
edge_attr=edge_attr)
elif self.supports_edge_weight:
x = self.convs[i](x, edge_index, edge_weight=edge_weight)
x = conv(x, edge_index, edge_weight=edge_weight)
elif self.supports_edge_attr:
x = self.convs[i](x, edge_index, edge_attr=edge_attr)
x = conv(x, edge_index, edge_attr=edge_attr)
else:
x = self.convs[i](x, edge_index)
if i == self.num_layers - 1 and self.jk_mode is None:
break
if self.act is not None and self.act_first:
x = self.act(x)
if self.norms is not None:
x = self.norms[i](x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = self.dropout(x)
if hasattr(self, 'jk'):
xs.append(x)
x = conv(x, edge_index)

if i < self.num_layers - 1 or self.jk_mode is not None:
if self.act is not None and self.act_first:
x = self.act(x)
x = norm(x)
if self.act is not None and not self.act_first:
x = self.act(x)
x = self.dropout(x)
if hasattr(self, 'jk'):
xs.append(x)

x = self.jk(xs) if hasattr(self, 'jk') else x
x = self.lin(x) if hasattr(self, 'lin') else x

return x
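For context on the torch.jit.is_scripting() guard added above (my note, not part of the diff): the call returns False in eager mode but compiles to a constant True inside scripted code, so the layer-trimming branch is dropped entirely under TorchScript, consistent with the TODO in jittable() below. A small illustration:

    import torch
    from torch import Tensor

    def f(x: Tensor) -> Tensor:
        if not torch.jit.is_scripting():
            return x + 1  # eager-only path
        return x

    scripted = torch.jit.script(f)
    print(f(torch.zeros(1)))         # tensor([1.])
    print(scripted(torch.zeros(1)))  # tensor([0.])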

@torch.no_grad()
@@ -328,6 +360,76 @@ def inference(

return x_all

def jittable(self, use_sparse_tensor: bool = False) -> 'BasicGNN':
r"""Produces a new jittable instance module that can be used in
combination with :meth:`torch.jit.script`."""
class EdgeIndexJittable(torch.nn.Module):
def __init__(self, child: BasicGNN):
super().__init__()
self.child = child

def reset_parameters(self):
self.child.reset_parameters()

def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
num_sampled_edges_per_hop: Optional[List[int]] = None,
) -> Tensor:
return self.child(
x,
edge_index,
edge_weight,
edge_attr,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
)

def __repr__(self) -> str:
return str(self.child)

class SparseTensorJittable(torch.nn.Module):
def __init__(self, child: BasicGNN):
super().__init__()
self.child = child

def reset_parameters(self):
self.child.reset_parameters()

def forward(
self,
x: Tensor,
edge_index: SparseTensor,
edge_weight: OptTensor = None,
edge_attr: OptTensor = None,
num_sampled_nodes_per_hop: Optional[List[int]] = None,
num_sampled_edges_per_hop: Optional[List[int]] = None,
) -> Tensor:
return self.child(
x,
edge_index,
edge_weight,
edge_attr,
num_sampled_nodes_per_hop,
num_sampled_edges_per_hop,
)

def __repr__(self) -> str:
return str(self.child)

out = copy.deepcopy(self)
convs = [conv.jittable() for conv in out.convs]
out.convs = torch.nn.ModuleList(convs)
out._trim = None # TODO Trimming is currently not supported in JIT mode.

if use_sparse_tensor:
return SparseTensorJittable(out)
return EdgeIndexJittable(out)
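A hedged usage sketch of the use_sparse_tensor=True path (assuming torch_sparse is installed; the SparseTensor construction below mirrors the Tensor-based test added in this PR but is itself illustrative):

    import torch
    from torch_sparse import SparseTensor

    from torch_geometric.nn import GraphSAGE

    x = torch.randn(3, 8)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3)).t()

    model = GraphSAGE(8, 16, num_layers=2).jittable(use_sparse_tensor=True)
    model = torch.jit.script(model)
    assert model(x, adj_t).size() == (3, 16)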

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, num_layers={self.num_layers})')
@@ -368,8 +470,8 @@ class GCN(BasicGNN):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GCNConv`.
"""
supports_edge_weight = True
supports_edge_attr = False
supports_edge_weight: Final[bool] = True
supports_edge_attr: Final[bool] = False

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -412,8 +514,8 @@ class GraphSAGE(BasicGNN):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.SAGEConv`.
"""
supports_edge_weight = False
supports_edge_attr = False
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = False

def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
@@ -453,8 +555,8 @@ class GIN(BasicGNN):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.GINConv`.
"""
supports_edge_weight = False
supports_edge_attr = False
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = False

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -511,8 +613,8 @@ class GAT(BasicGNN):
:class:`torch_geometric.nn.conv.GATConv` or
:class:`torch_geometric.nn.conv.GATv2Conv`.
"""
supports_edge_weight = False
supports_edge_attr = True
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = True

def init_conv(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int, **kwargs) -> MessagePassing:
@@ -573,8 +675,8 @@ class PNA(BasicGNN):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.PNAConv`.
"""
supports_edge_weight = False
supports_edge_attr = True
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = True

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -614,8 +716,8 @@ class EdgeCNN(BasicGNN):
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.EdgeConv`.
"""
supports_edge_weight = False
supports_edge_attr = False
supports_edge_weight: Final[bool] = False
supports_edge_attr: Final[bool] = False

def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
@@ -629,4 +731,11 @@ def init_conv(self, in_channels: int, out_channels: int,
return EdgeConv(mlp, **kwargs)


__all__ = ['GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA', 'EdgeCNN']
__all__ = [
'GCN',
'GraphSAGE',
'GIN',
'GAT',
'PNA',
'EdgeCNN',
]