Skip to content

Commit 1bd5ec6

Browse files
author
Vincent Moens
authored
[BugFix] Fix exploration in losses (#1898)
1 parent 2cfd9b6 commit 1bd5ec6

File tree

3 files changed

+26
-11
lines changed

3 files changed

+26
-11
lines changed

test/test_cost.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TensorDictSequential,
2626
TensorDictSequential as Seq,
2727
)
28+
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
2829

2930
from torchrl.modules.models import QMixer
3031

@@ -12391,6 +12392,22 @@ def __init__(self):
1239112392
assert p.device == dest
1239212393

1239312394

12395+
def test_loss_exploration():
12396+
class DummyLoss(LossModule):
12397+
def forward(self, td):
12398+
assert exploration_type() == InteractionType.MODE
12399+
with set_exploration_type(ExplorationType.RANDOM):
12400+
assert exploration_type() == ExplorationType.RANDOM
12401+
assert exploration_type() == ExplorationType.MODE
12402+
return td
12403+
12404+
loss_fn = DummyLoss()
12405+
with set_exploration_type(ExplorationType.RANDOM):
12406+
assert exploration_type() == ExplorationType.RANDOM
12407+
loss_fn(None)
12408+
assert exploration_type() == ExplorationType.RANDOM
12409+
12410+
1239412411
if __name__ == "__main__":
1239512412
args, unknown = argparse.ArgumentParser().parse_known_args()
1239612413
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_exploration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
class TestEGreedy:
5555
@pytest.mark.parametrize("eps_init", [0.0, 0.5, 1])
5656
@pytest.mark.parametrize("module", [True, False])
57+
@set_exploration_type(InteractionType.RANDOM)
5758
def test_egreedy(self, eps_init, module):
5859
torch.manual_seed(0)
5960
spec = BoundedTensorSpec(1, 1, torch.Size([4]))

torchrl/objectives/common.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import abc
89
import warnings
910
from copy import deepcopy
1011
from dataclasses import dataclass
@@ -31,7 +32,13 @@ def _updater_check_forward_prehook(module, *args, **kwargs):
3132
)
3233

3334

34-
class LossModule(TensorDictModuleBase):
35+
class _LossMeta(abc.ABCMeta):
36+
def __init__(cls, name, bases, attr_dict):
37+
super().__init__(name, bases, attr_dict)
38+
cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
39+
40+
41+
class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
3542
"""A parent class for RL losses.
3643
3744
LossModule inherits from nn.Module. It is designed to read an input
@@ -109,16 +116,6 @@ def __init__(self):
109116
self.value_type = self.default_value_estimator
110117
self._tensor_keys = self._AcceptedKeys()
111118
self.register_forward_pre_hook(_updater_check_forward_prehook)
112-
expl_mode = set_exploration_type(ExplorationType.MODE)
113-
114-
def _pre_hook(*args, expl_mode=expl_mode, **kwargs):
115-
expl_mode.__enter__()
116-
117-
def _post_hook(*args, expl_mode=expl_mode, **kwargs):
118-
expl_mode.__exit__(exc_type=None, exc_value=None, traceback=None)
119-
120-
self.register_forward_pre_hook(_pre_hook)
121-
self.register_forward_hook(_post_hook)
122119

123120
def _set_deprecated_ctor_keys(self, **kwargs) -> None:
124121
for key, value in kwargs.items():

0 commit comments

Comments
 (0)