From 9eaa8c074f77fbaf34ec6457051d6df6c1c89d8b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Julian=20F=C3=B6rster?=
Date: Wed, 13 Jul 2022 20:12:24 +0200
Subject: [PATCH] Fix `graphgym` activation function dictionary, now not including instances anymore (#4978)

* Individual activation function instances

* Removed unnecessary import

* Updated changelog

* Revert "Removed unnecessary import"

This reverts commit f48bc2c8f06cb54ff50fccebbcc1cbc28750a7f9.

* fix test

* fix test
---
 CHANGELOG.md                             |  1 +
 test/graphgym/test_graphgym.py           |  4 +--
 test/graphgym/test_register.py           |  2 +-
 torch_geometric/graphgym/models/act.py   | 43 ++++++++++++++++++++----
 torch_geometric/graphgym/models/layer.py |  2 +-
 5 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6d4425f422b..3d5f4f31e5ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -54,6 +54,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- Changed `act_dict` (part of `graphgym`) to create individual instances instead of reusing the same ones everywhere ([#4978](https://github.com/pyg-team/pytorch_geometric/pull/4978))
 - Fixed issue where one-hot tensors were passed to `F.one_hot` ([4970](https://github.com/pyg-team/pytorch_geometric/pull/4970))
 - Fixed `bool` arguments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967))
 - Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943))
diff --git a/test/graphgym/test_graphgym.py b/test/graphgym/test_graphgym.py
index aa518443add0..d8fcb3e09b31 100644
--- a/test/graphgym/test_graphgym.py
+++ b/test/graphgym/test_graphgym.py
@@ -95,7 +95,7 @@ def test_run_single_graphgym(auto_resume, skip_train_eval, use_trivial_metric):
     assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)
 
     cfg.params = params_count(model)
-    assert cfg.params == 23880
+    assert cfg.params == 23883
 
     train(model, datamodule, logger=True,
           trainer_config={"enable_progress_bar": False})
@@ -136,7 +136,7 @@ def test_graphgym_module(tmpdir):
     assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)
 
     cfg.params = params_count(model)
-    assert cfg.params == 23880
+    assert cfg.params == 23883
 
     keys = {"loss", "true", "pred_score", "step_end_time"}
     # test training step
diff --git a/test/graphgym/test_register.py b/test/graphgym/test_register.py
index 4d046c01775d..fd445b71f81e 100644
--- a/test/graphgym/test_register.py
+++ b/test/graphgym/test_register.py
@@ -16,7 +16,7 @@ def test_register():
         'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05',
         'identity'
     ]
-    assert str(register.act_dict['relu']) == 'ReLU()'
+    assert str(register.act_dict['relu']()) == 'ReLU()'
 
     register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3))
     assert len(register.act_dict) == 9
diff --git a/torch_geometric/graphgym/models/act.py b/torch_geometric/graphgym/models/act.py
index 088b22e1141a..4f9bc4d6811c 100644
--- a/torch_geometric/graphgym/models/act.py
+++ b/torch_geometric/graphgym/models/act.py
@@ -3,11 +3,40 @@
 from torch_geometric.graphgym.config import cfg
 from torch_geometric.graphgym.register import register_act
 
+
+def relu():
+    return nn.ReLU(inplace=cfg.mem.inplace)
+
+
+def selu():
+    return nn.SELU(inplace=cfg.mem.inplace)
+
+
+def prelu():
+    return nn.PReLU()
+
+
+def elu():
+    return nn.ELU(inplace=cfg.mem.inplace)
+
+
+def lrelu_01():
+    return nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)
+
+
+def lrelu_025():
+    return nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)
+
+
+def lrelu_05():
+    return nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)
+
+
 if cfg is not None:
-    register_act('relu', nn.ReLU(inplace=cfg.mem.inplace))
-    register_act('selu', nn.SELU(inplace=cfg.mem.inplace))
-    register_act('prelu', nn.PReLU())
-    register_act('elu', nn.ELU(inplace=cfg.mem.inplace))
-    register_act('lrelu_01', nn.LeakyReLU(0.1, inplace=cfg.mem.inplace))
-    register_act('lrelu_025', nn.LeakyReLU(0.25, inplace=cfg.mem.inplace))
-    register_act('lrelu_05', nn.LeakyReLU(0.5, inplace=cfg.mem.inplace))
+    register_act('relu', relu)
+    register_act('selu', selu)
+    register_act('prelu', prelu)
+    register_act('elu', elu)
+    register_act('lrelu_01', lrelu_01)
+    register_act('lrelu_025', lrelu_025)
+    register_act('lrelu_05', lrelu_05)
diff --git a/torch_geometric/graphgym/models/layer.py b/torch_geometric/graphgym/models/layer.py
index c92de97f31af..a4cd2f138e41 100644
--- a/torch_geometric/graphgym/models/layer.py
+++ b/torch_geometric/graphgym/models/layer.py
@@ -96,7 +96,7 @@ def __init__(self, name, layer_config: LayerConfig, **kwargs):
                 nn.Dropout(p=layer_config.dropout,
                            inplace=layer_config.mem_inplace))
         if layer_config.has_act:
-            layer_wrapper.append(register.act_dict[layer_config.act])
+            layer_wrapper.append(register.act_dict[layer_config.act]())
         self.post_layer = nn.Sequential(*layer_wrapper)
 
     def forward(self, batch):
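
Note on the behavioral change: `act_dict` entries are now zero-argument callables rather than module instances, so consumers must instantiate via `register.act_dict[name]()`, as `GeneralLayer` in `layer.py` now does. Below is a minimal standalone sketch of why factories matter here, assuming only `torch`; the plain dicts stand in for graphgym's registry, and the names `act_dict_old`/`act_dict_new`/`layers_*` are illustrative, not PyG API.

import torch.nn as nn

# Before this patch, the registry held module *instances*, so every layer
# appended the very same `nn.PReLU()` object: its learnable slope was
# silently shared across layers and counted only once.
act_dict_old = {'prelu': nn.PReLU()}
layers_old = [act_dict_old['prelu'] for _ in range(3)]
assert layers_old[0] is layers_old[2]  # one shared module
assert sum(p.numel() for p in nn.Sequential(*layers_old).parameters()) == 1

# After this patch, the registry holds zero-argument *factories*; each
# lookup is called, producing a fresh module with its own parameters.
act_dict_new = {'prelu': nn.PReLU}  # equivalently: lambda: nn.PReLU()
layers_new = [act_dict_new['prelu']() for _ in range(3)]
assert layers_new[0] is not layers_new[2]  # independent modules
assert sum(p.numel() for p in nn.Sequential(*layers_new).parameters()) == 3

This also explains the expected parameter count in the tests moving from 23880 to 23883: presumably the test model uses `PReLU` activations, which each contribute their own weight once instances are no longer shared.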