Support normalization_resolver in BasicGNN (pyg-team#4958)
* Support normalization_resolver in basic_gnn

* changelog

* update

* update

* update

* fix

* update

* reset

Co-authored-by: Guohao Li <lighaime@gmail.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
3 people authored Jul 12, 2022
1 parent 5ce04c5 commit bac7021
Showing 4 changed files with 50 additions and 22 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `unbatch_edge_index` functionality for splitting an `edge_index` tensor according to a `batch` vector ([#4903](https://github.com/pyg-team/pytorch_geometric/pull/4903))
- Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944))
-- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951))
+- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951), [#4958](https://github.com/pyg-team/pytorch_geometric/pull/4958))
- Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))
- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))
- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))
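For context, a minimal sketch of what this changelog entry enables for the models touched below. It is a hedged illustration: the string alias "batch_norm" is taken from the `MLP` default changed later in this commit, and the `momentum` value is an arbitrary example, not part of the diff.

from torch_geometric.nn import GraphSAGE

# Previously, `norm` had to be an already-instantiated module, e.g.
#   model = GraphSAGE(16, 32, num_layers=3, norm=torch.nn.BatchNorm1d(32))
# With this commit, a string plus optional constructor kwargs is resolved internally:
model = GraphSAGE(in_channels=16, hidden_channels=32, num_layers=3,
                  norm="batch_norm", norm_kwargs=dict(momentum=0.1))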
3 changes: 2 additions & 1 deletion test/graphgym/test_config_store.py
@@ -41,7 +41,7 @@ def test_config_store():
assert cfg.dataset.transform.AddSelfLoops.fill_value is None

# Check `cfg.model`:
-assert len(cfg.model) == 11
+assert len(cfg.model) == 12
assert cfg.model._target_.split('.')[-1] == 'GCN'
assert cfg.model.in_channels == 34
assert cfg.model.out_channels == 4
@@ -50,6 +50,7 @@
assert cfg.model.dropout == 0.0
assert cfg.model.act == 'relu'
assert cfg.model.norm is None
+assert cfg.model.norm_kwargs is None
assert cfg.model.jk is None
assert not cfg.model.act_first
assert cfg.model.act_kwargs is None
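For reference, the twelve config fields asserted above map one-to-one onto the keyword arguments of the targeted model. A rough equivalent of instantiating that `GCN` directly would be the following sketch; the `hidden_channels` and `num_layers` values are hypothetical, since the test lines defining them are elided here.

from torch_geometric.nn import GCN

model = GCN(in_channels=34, hidden_channels=16, num_layers=2, out_channels=4,
            dropout=0.0, act='relu', norm=None, norm_kwargs=None, jk=None,
            act_first=False, act_kwargs=None)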
61 changes: 44 additions & 17 deletions torch_geometric/nn/models/basic_gnn.py
@@ -17,7 +17,10 @@
)
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge
-from torch_geometric.nn.resolver import activation_resolver
+from torch_geometric.nn.resolver import (
+    activation_resolver,
+    normalization_resolver,
+)
from torch_geometric.typing import Adj
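As a rough sketch of what the newly imported resolver provides, mirroring the call added further down in this file (the "batch_norm" alias is assumed to resolve to a batch-normalization layer, in line with the `MLP` default below):

from torch_geometric.nn.resolver import normalization_resolver

hidden_channels = 64
norm_kwargs = None
# A string plus optional kwargs is turned into an instantiated normalization layer:
norm_layer = normalization_resolver('batch_norm', hidden_channels,
                                    **(norm_kwargs or {}))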


@@ -34,18 +37,21 @@ class BasicGNN(torch.nn.Module):
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
-norm (torch.nn.Module, optional): The normalization operator to use.
+act_first (bool, optional): If set to :obj:`True`, activation is
+    applied before normalization. (default: :obj:`False`)
+act_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective activation function defined by :obj:`act`.
+    (default: :obj:`None`)
+norm (str or Callable, optional): The normalization function to
+    use. (default: :obj:`None`)
+norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
node embeddings to the expected output feature dimensionality.
(:obj:`None`, :obj:`"last"`, :obj:`"cat"`, :obj:`"max"`,
:obj:`"lstm"`). (default: :obj:`None`)
-act_first (bool, optional): If set to :obj:`True`, activation is
-    applied before normalization. (default: :obj:`False`)
-act_kwargs (Dict[str, Any], optional): Arguments passed to the
-    respective activation function defined by :obj:`act`.
-    (default: :obj:`None`)
**kwargs (optional): Additional arguments of the underlying
:class:`torch_geometric.nn.conv.MessagePassing` layers.
"""
@@ -57,10 +63,11 @@ def __init__(
out_channels: Optional[int] = None,
dropout: float = 0.0,
act: Union[str, Callable, None] = "relu",
-norm: Optional[torch.nn.Module] = None,
-jk: Optional[str] = None,
act_first: bool = False,
act_kwargs: Optional[Dict[str, Any]] = None,
+norm: Union[str, Callable, None] = None,
+norm_kwargs: Optional[Dict[str, Any]] = None,
+jk: Optional[str] = None,
**kwargs,
):
super().__init__()
@@ -98,11 +105,16 @@

self.norms = None
if norm is not None:
+norm_layer = normalization_resolver(
+    norm,
+    hidden_channels,
+    **(norm_kwargs or {}),
+)
self.norms = ModuleList()
for _ in range(num_layers - 1):
-self.norms.append(copy.deepcopy(norm))
+self.norms.append(copy.deepcopy(norm_layer))
if jk is not None:
-self.norms.append(copy.deepcopy(norm))
+self.norms.append(copy.deepcopy(norm_layer))

if jk is not None and jk != 'last':
self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)
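Note that the resolved layer is built once and then deep-copied for every message-passing layer, so each layer keeps its own parameters and running statistics. A standalone sketch of that pattern, outside the class and with assumed sizes:

import copy
from torch.nn import ModuleList
from torch_geometric.nn.resolver import normalization_resolver

num_layers, hidden_channels = 3, 64
norm_layer = normalization_resolver('batch_norm', hidden_channels)
norms = ModuleList([copy.deepcopy(norm_layer) for _ in range(num_layers - 1)])
assert norms[0] is not norms[1]  # independent copies, no shared state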
@@ -170,7 +182,10 @@ class GCN(BasicGNN):
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
-norm (torch.nn.Module, optional): The normalization operator to use.
+norm (str or Callable, optional): The normalization function to
+    use. (default: :obj:`None`)
+norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
@@ -205,7 +220,10 @@ class GraphSAGE(BasicGNN):
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
-norm (torch.nn.Module, optional): The normalization operator to use.
+norm (str or Callable, optional): The normalization function to
+    use. (default: :obj:`None`)
+norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
@@ -240,7 +258,10 @@ class GIN(BasicGNN):
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (Callable, optional): The non-linear activation function to use.
(default: :obj:`torch.nn.ReLU(inplace=True)`)
-norm (torch.nn.Module, optional): The normalization operator to use.
+norm (str or Callable, optional): The normalization function to
+    use. (default: :obj:`None`)
+norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
@@ -257,7 +278,7 @@ class GIN(BasicGNN):
"""
def init_conv(self, in_channels: int, out_channels: int,
**kwargs) -> MessagePassing:
-mlp = MLP([in_channels, out_channels, out_channels], batch_norm=True)
+mlp = MLP([in_channels, out_channels, out_channels], norm="batch_norm")
return GINConv(mlp, **kwargs)
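Here the `batch_norm=True` flag of `MLP` is replaced by its string-based `norm` argument, consistent with the `mlp.py` change at the end of this commit. A hedged equivalence sketch:

from torch_geometric.nn.models import MLP

# old call: MLP([16, 32, 32], batch_norm=True)
mlp = MLP([16, 32, 32], norm="batch_norm")  # batch normalization resolved by name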


@@ -282,7 +303,10 @@ class GAT(BasicGNN):
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
-norm (torch.nn.Module, optional): The normalization operator to use.
+norm (str or Callable, optional): The normalization function to
+    use. (default: :obj:`None`)
+norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
@@ -338,7 +362,10 @@ class PNA(BasicGNN):
dropout (float, optional): Dropout probability. (default: :obj:`0.`)
act (str or Callable, optional): The non-linear activation function to
use. (default: :obj:`"relu"`)
-norm (torch.nn.Module, optional): The normalization operator to use.
+norm (str or Callable, optional): The normalization function to
+    use. (default: :obj:`None`)
+norm_kwargs (Dict[str, Any], optional): Arguments passed to the
+    respective normalization function defined by :obj:`norm`.
(default: :obj:`None`)
jk (str, optional): The Jumping Knowledge mode. If specified, the model
will additionally apply a final linear transformation to transform
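Across GCN, GraphSAGE, GIN, GAT and PNA the pattern is identical: `norm` names (or directly provides) the normalization layer, and `norm_kwargs` is forwarded to its constructor. A hedged example, assuming "layer_norm" resolves to `torch_geometric.nn.LayerNorm` and that the node-wise mode mentioned in the changelog entry above is exposed as `mode="node"`:

from torch_geometric.nn import GAT

model = GAT(in_channels=16, hidden_channels=32, num_layers=2,
            norm="layer_norm", norm_kwargs=dict(mode="node"))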
6 changes: 3 additions & 3 deletions torch_geometric/nn/models/mlp.py
@@ -1,5 +1,5 @@
import warnings
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
@@ -78,10 +78,10 @@ def __init__(
out_channels: Optional[int] = None,
num_layers: Optional[int] = None,
dropout: float = 0.,
-act: str = "relu",
+act: Union[str, Callable, None] = "relu",
act_first: bool = False,
act_kwargs: Optional[Dict[str, Any]] = None,
-norm: Optional[str] = 'batch_norm',
+norm: Union[str, Callable, None] = "batch_norm",
norm_kwargs: Optional[Dict[str, Any]] = None,
plain_last: bool = True,
bias: bool = True,
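With the widened types above, `MLP` now accepts either strings or callables for both `act` and `norm`. A short sketch; the callable form for `act` is assumed to be passed through unchanged by `activation_resolver`:

import torch
from torch_geometric.nn.models import MLP

mlp_a = MLP([16, 32, 8], act="relu", norm="batch_norm")            # resolved by name
mlp_b = MLP([16, 32, 8], act=torch.nn.LeakyReLU(0.2), norm=None)   # passed as a module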
