Fix graphgym activation function dictionary, now not including instances anymore (pyg-team#4978)

* Individual activation function instances

* Removed unnecessary import

* Updated changelog

* Revert "Removed unnecessary import"

This reverts commit f48bc2c.

* fix test

* fix test
fjulian authored Jul 13, 2022
1 parent 9b129b8 commit 9eaa8c0
Showing 5 changed files with 41 additions and 11 deletions.
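
In brief: `act_dict` previously mapped activation names to already-constructed `nn.Module` instances, so every GraphGym layer that looked up, say, `prelu` received the same module object. For stateless activations this is mostly harmless, but `nn.PReLU` carries a learnable weight, so all layers silently shared a single parameter. After this commit the dictionary stores zero-argument factories, and callers instantiate with `act_dict[name]()`. A minimal standalone sketch of the difference (illustrative only, not GraphGym code):

```python
import torch.nn as nn

# Before: the dict holds one shared instance per name.
act_dict_old = {'prelu': nn.PReLU()}
assert act_dict_old['prelu'] is act_dict_old['prelu']  # same weight everywhere

# After: the dict holds factories; every lookup site instantiates its own.
act_dict_new = {'prelu': nn.PReLU}  # any zero-argument callable works
a, b = act_dict_new['prelu'](), act_dict_new['prelu']()
assert a is not b  # independent learnable parameters
```
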
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Changed `act_dict` (part of `graphgym`) to create individual instances instead of reusing the same ones everywhere ([4978](https://github.com/pyg-team/pytorch_geometric/pull/4978))
 - Fixed issue where one-hot tensors were passed to `F.one_hot` ([4970](https://github.com/pyg-team/pytorch_geometric/pull/4970))
 - Fixed `bool` arugments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967))
 - Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943))
4 changes: 2 additions & 2 deletions test/graphgym/test_graphgym.py
@@ -95,7 +95,7 @@ def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
     assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)
 
     cfg.params = params_count(model)
-    assert cfg.params == 23880
+    assert cfg.params == 23883
 
     train(model, datamodule, logger=True,
           trainer_config={"enable_progress_bar": False})
@@ -136,7 +136,7 @@ def test_graphgym_module(tmpdir):
     assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)
 
     cfg.params = params_count(model)
-    assert cfg.params == 23880
+    assert cfg.params == 23883
 
     keys = {"loss", "true", "pred_score", "step_end_time"}
     # test training step
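
Why the expected parameter counts move from 23880 to 23883: `nn.PReLU` has one learnable weight per instance (with the default `num_parameters=1`). Under the old shared-instance scheme a `prelu` activation contributed a single parameter to the whole model; with per-layer instances, each activation site owns its own. The +3 delta is consistent with four activation sites that previously shared one weight (4 × 1 − 1 = 3), though the exact count depends on the test configuration.
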
2 changes: 1 addition & 1 deletion test/graphgym/test_register.py
@@ -16,7 +16,7 @@ def test_register():
         'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05',
         'identity'
     ]
-    assert str(register.act_dict['relu']) == 'ReLU()'
+    assert str(register.act_dict['relu']()) == 'ReLU()'
 
     register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3))
     assert len(register.act_dict) == 9
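
Note that the test's remaining context line still registers an instance (`torch.nn.LeakyReLU(0.3)`); under the new convention, `act_dict[...]()` expects the stored value to be a zero-argument callable. A hedged sketch of registering a parameterized activation accordingly, using `functools.partial` (the `lrelu_03` name is just an example):

```python
from functools import partial

import torch.nn as nn

from torch_geometric.graphgym.register import register_act

# Store a factory, not an instance: each act_dict['lrelu_03']() call
# then builds a fresh LeakyReLU(0.3) module.
register_act('lrelu_03', partial(nn.LeakyReLU, 0.3))
```
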
43 changes: 36 additions & 7 deletions torch_geometric/graphgym/models/act.py
@@ -3,11 +3,40 @@
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.graphgym.register import register_act
 
+
+def relu():
+    return nn.ReLU(inplace=cfg.mem.inplace)
+
+
+def selu():
+    return nn.SELU(inplace=cfg.mem.inplace)
+
+
+def prelu():
+    return nn.PReLU()
+
+
+def elu():
+    return nn.ELU(inplace=cfg.mem.inplace)
+
+
+def lrelu_01():
+    return nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)
+
+
+def lrelu_025():
+    return nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)
+
+
+def lrelu_05():
+    return nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)
+
+
 if cfg is not None:
-    register_act('relu', nn.ReLU(inplace=cfg.mem.inplace))
-    register_act('selu', nn.SELU(inplace=cfg.mem.inplace))
-    register_act('prelu', nn.PReLU())
-    register_act('elu', nn.ELU(inplace=cfg.mem.inplace))
-    register_act('lrelu_01', nn.LeakyReLU(0.1, inplace=cfg.mem.inplace))
-    register_act('lrelu_025', nn.LeakyReLU(0.25, inplace=cfg.mem.inplace))
-    register_act('lrelu_05', nn.LeakyReLU(0.5, inplace=cfg.mem.inplace))
+    register_act('relu', relu)
+    register_act('selu', selu)
+    register_act('prelu', prelu)
+    register_act('elu', elu)
+    register_act('lrelu_01', lrelu_01)
+    register_act('lrelu_025', lrelu_025)
+    register_act('lrelu_05', lrelu_05)
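
A side effect of moving construction into factories: `cfg.mem.inplace` is now read when the activation is built rather than when this module is imported, so configuration loaded after import (e.g., from a YAML file) takes effect. A minimal sketch, assuming the default registrations are in place:

```python
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg

cfg.mem.inplace = True             # e.g., set after loading an experiment config
act = register.act_dict['relu']()  # the factory reads cfg at call time
assert act.inplace is True
```
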
2 changes: 1 addition & 1 deletion torch_geometric/graphgym/models/layer.py
@@ -96,7 +96,7 @@ def __init__(self, name, layer_config: LayerConfig, **kwargs):
                 nn.Dropout(p=layer_config.dropout,
                            inplace=layer_config.mem_inplace))
         if layer_config.has_act:
-            layer_wrapper.append(register.act_dict[layer_config.act])
+            layer_wrapper.append(register.act_dict[layer_config.act]())
         self.post_layer = nn.Sequential(*layer_wrapper)
 
     def forward(self, batch):
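
For downstream code with custom GraphGym layers, the trailing `()` now matters: a lookup yields a factory, not a module, and `nn.Sequential` rejects plain functions. A short sketch of the failure mode (again assuming the default registrations are available):

```python
import torch.nn as nn

import torch_geometric.graphgym.register as register

layers = [nn.Linear(16, 16)]
layers.append(register.act_dict['relu']())   # correct: instantiate first
# layers.append(register.act_dict['relu'])   # TypeError: not a Module subclass
model = nn.Sequential(*layers)
```
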
