Commit

Added fine-grained options for setting bias and dropout per layer in the `MLP` model (pyg-team#4981)

* Added more fine-grained options for biases and dropouts

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update torch_geometric/nn/models/mlp.py

Co-authored-by: Padarn Wilson <padarn.wilson@grabtaxi.com>

* Update torch_geometric/nn/models/mlp.py

Co-authored-by: Padarn Wilson <padarn.wilson@grabtaxi.com>

* Updated CHANGELOG.md and fixed some CI errors

* reset

* update

* add test

Co-authored-by: Bram <b.t.ton@saxion.nl>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Padarn Wilson <padarn.wilson@grabtaxi.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
5 people authored Jul 17, 2022
1 parent ad357c5 commit 1acd235
Showing 3 changed files with 43 additions and 12 deletions.
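
For reference, a minimal usage sketch of the options this commit adds (not part of the diff; it assumes `torch` and a `torch_geometric` build that already contains this change). The channel list `[16, 32, 32, 64]` defines three layers, so three dropout values and three bias flags are expected:

import torch
from torch_geometric.nn import MLP

# One dropout probability and one bias flag per Linear layer (3 layers here).
mlp = MLP(
    [16, 32, 32, 64],
    dropout=[0.1, 0.2, 0.3],
    bias=[False, True, False],
)
out = mlp(torch.randn(4, 16))
print(out.size())  # torch.Size([4, 64])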
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
+- Added fine-grained options for setting `bias` and `dropout` per layer in the `MLP` model ([#4981](https://github.com/pyg-team/pytorch_geometric/pull/4981))
- Added `EdgeCNN` model ([#4991](https://github.com/pyg-team/pytorch_geometric/pull/4991))
- Added scalable `inference` mode in `BasicGNN` with layer-wise neighbor loading ([#4977](https://github.com/pyg-team/pytorch_geometric/pull/4977))
- Added inference benchmarks ([#4892](https://github.com/pyg-team/pytorch_geometric/pull/4892))
10 changes: 10 additions & 0 deletions test/nn/models/test_mlp.py
@@ -40,3 +40,13 @@ def test_mlp(norm, act_first, plain_last):
        plain_last=plain_last,
    )
    assert torch.allclose(mlp(x), out)


+@pytest.mark.parametrize('plain_last', [False, True])
+def test_fine_grained_mlp(plain_last):
+    mlp = MLP(
+        [16, 32, 32, 64],
+        dropout=[0.1, 0.2, 0.3],
+        bias=[False, True, False],
+    )
+    assert mlp(torch.randn(4, 16)).size() == (4, 64)
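
The length checks added in torch_geometric/nn/models/mlp.py (see the diff below) raise a `ValueError` when a list of the wrong length is passed; a small sketch of that behaviour, not part of the commit:

import pytest
from torch_geometric.nn import MLP

# [16, 32, 64] defines two layers, so a single dropout value in list form
# fails the per-layer length check added by this commit.
with pytest.raises(ValueError, match='does not'):
    MLP([16, 32, 64], dropout=[0.1])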
44 changes: 32 additions & 12 deletions torch_geometric/nn/models/mlp.py
@@ -48,8 +48,9 @@ class MLP(torch.nn.Module):
            Will override :attr:`channel_list`. (default: :obj:`None`)
        num_layers (int, optional): The number of layers.
            Will override :attr:`channel_list`. (default: :obj:`None`)
-        dropout (float, optional): Dropout probability of each hidden
-            embedding. (default: :obj:`0.`)
+        dropout (float or List[float], optional): Dropout probability of each
+            hidden embedding. If a list is provided, sets the dropout value per
+            layer. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_first (bool, optional): If set to :obj:`True`, activation is
@@ -65,8 +66,9 @@ class MLP(torch.nn.Module):
        plain_last (bool, optional): If set to :obj:`False`, will apply
            non-linearity, batch normalization and dropout to the last layer as
            well. (default: :obj:`True`)
-        bias (bool, optional): If set to :obj:`False`, the module will not
-            learn additive biases. (default: :obj:`True`)
+        bias (bool or List[bool], optional): If set to :obj:`False`, the module
+            will not learn additive biases. If a list is provided, sets the
+            bias per layer. (default: :obj:`True`)
        **kwargs (optional): Additional deprecated arguments of the MLP layer.
    """
    def __init__(
@@ -77,14 +79,14 @@ def __init__(
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: Optional[int] = None,
-        dropout: float = 0.,
+        dropout: Union[float, List[float]] = 0.,
        act: Union[str, Callable, None] = "relu",
        act_first: bool = False,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Union[str, Callable, None] = "batch_norm",
        norm_kwargs: Optional[Dict[str, Any]] = None,
        plain_last: bool = True,
-        bias: bool = True,
+        bias: Union[bool, List[bool]] = True,
        **kwargs,
    ):
        super().__init__()
@@ -111,15 +113,32 @@ def __init__(
        assert len(channel_list) >= 2
        self.channel_list = channel_list

-        self.dropout = dropout
        self.act = activation_resolver(act, **(act_kwargs or {}))
        self.act_first = act_first
        self.plain_last = plain_last

+        if isinstance(dropout, float):
+            dropout = [dropout] * (len(channel_list) - 1)
+            if plain_last:
+                dropout[-1] = 0.
+        if len(dropout) != len(channel_list) - 1:
+            raise ValueError(
+                f"Number of dropout values provided ({len(dropout)}) does not "
+                f"match the number of layers specified "
+                f"({len(channel_list)-1})")
+        self.dropout = dropout
+
+        if isinstance(bias, bool):
+            bias = [bias] * (len(channel_list) - 1)
+        if len(bias) != len(channel_list) - 1:
+            raise ValueError(
+                f"Number of bias values provided ({len(bias)}) does not match "
+                f"the number of layers specified ({len(channel_list)-1})")
+
        self.lins = torch.nn.ModuleList()
-        iterator = zip(channel_list[:-1], channel_list[1:])
-        for in_channels, out_channels in iterator:
-            self.lins.append(Linear(in_channels, out_channels, bias=bias))
+        iterator = zip(channel_list[:-1], channel_list[1:], bias)
+        for in_channels, out_channels, _bias in iterator:
+            self.lins.append(Linear(in_channels, out_channels, bias=_bias))

        self.norms = torch.nn.ModuleList()
        iterator = channel_list[1:-1] if plain_last else channel_list[1:]
@@ -160,18 +179,19 @@ def reset_parameters(self):

    def forward(self, x: Tensor, return_emb: NoneType = None) -> Tensor:
        """"""
-        for lin, norm in zip(self.lins, self.norms):
+        for lin, norm, dropout in zip(self.lins, self.norms, self.dropout):
            x = lin(x)
            if self.act is not None and self.act_first:
                x = self.act(x)
            x = norm(x)
            if self.act is not None and not self.act_first:
                x = self.act(x)
-            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = F.dropout(x, p=dropout, training=self.training)
            emb = x

        if self.plain_last:
            x = self.lins[-1](x)
+            x = F.dropout(x, p=self.dropout[-1], training=self.training)

        return (x, emb) if isinstance(return_emb, bool) else x


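As the constructor logic above shows, a scalar `dropout` is still broadcast to one value per layer, with the last entry forced to `0.` when `plain_last=True` (the default). A short sketch of that equivalence, assuming `self.dropout` is stored as the expanded list shown in the diff:

from torch_geometric.nn import MLP

# Scalar dropout expands to a per-layer list; with the default
# plain_last=True the final (output) layer gets no dropout.
mlp = MLP([16, 32, 64], dropout=0.5)
assert mlp.dropout == [0.5, 0.0]

mlp = MLP([16, 32, 64], dropout=0.5, plain_last=False)
assert mlp.dropout == [0.5, 0.5]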