Skip to content

Commit 2cfd9b6

Browse files
author
Vincent Moens
authored
[BugFix] Solve recursion issue in losses hook (#1897)
1 parent 89213f9 commit 2cfd9b6

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

torchrl/objectives/common.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def tensor_keys(self) -> _AcceptedKeys:
9797
return self._tensor_keys
9898

9999
def __new__(cls, *args, **kwargs):
100-
cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
101100
self = super().__new__(cls)
102101
return self
103102

@@ -110,7 +109,16 @@ def __init__(self):
110109
self.value_type = self.default_value_estimator
111110
self._tensor_keys = self._AcceptedKeys()
112111
self.register_forward_pre_hook(_updater_check_forward_prehook)
113-
# self.register_forward_pre_hook(_parameters_to_tensordict)
112+
expl_mode = set_exploration_type(ExplorationType.MODE)
113+
114+
def _pre_hook(*args, expl_mode=expl_mode, **kwargs):
115+
expl_mode.__enter__()
116+
117+
def _post_hook(*args, expl_mode=expl_mode, **kwargs):
118+
expl_mode.__exit__(exc_type=None, exc_value=None, traceback=None)
119+
120+
self.register_forward_pre_hook(_pre_hook)
121+
self.register_forward_hook(_post_hook)
114122

115123
def _set_deprecated_ctor_keys(self, **kwargs) -> None:
116124
for key, value in kwargs.items():

0 commit comments

Comments
 (0)