File tree 1 file changed +10
-2
lines changed 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -97,7 +97,6 @@ def tensor_keys(self) -> _AcceptedKeys:
97
97
return self ._tensor_keys
98
98
99
99
def __new__ (cls , * args , ** kwargs ):
100
- cls .forward = set_exploration_type (ExplorationType .MODE )(cls .forward )
101
100
self = super ().__new__ (cls )
102
101
return self
103
102
@@ -110,7 +109,16 @@ def __init__(self):
110
109
self .value_type = self .default_value_estimator
111
110
self ._tensor_keys = self ._AcceptedKeys ()
112
111
self .register_forward_pre_hook (_updater_check_forward_prehook )
113
- # self.register_forward_pre_hook(_parameters_to_tensordict)
112
+ expl_mode = set_exploration_type (ExplorationType .MODE )
113
+
114
+ def _pre_hook (* args , expl_mode = expl_mode , ** kwargs ):
115
+ expl_mode .__enter__ ()
116
+
117
+ def _post_hook (* args , expl_mode = expl_mode , ** kwargs ):
118
+ expl_mode .__exit__ (exc_type = None , exc_value = None , traceback = None )
119
+
120
+ self .register_forward_pre_hook (_pre_hook )
121
+ self .register_forward_hook (_post_hook )
114
122
115
123
def _set_deprecated_ctor_keys (self , ** kwargs ) -> None :
116
124
for key , value in kwargs .items ():
You can’t perform that action at this time.
0 commit comments