10
10
from torchrl .envs .common import _EnvWrapper
11
11
from torchrl .envs .utils import step_tensordict
12
12
13
+ __all__ = ["GymLikeEnv" , "default_info_dict_reader" ]
14
+
15
+
16
class default_info_dict_reader:
    """
    Default info-key reader.

    In cases where keys can be directly written to a tensordict (mostly if they abide to the
    tensordict shape), one simply needs to indicate the keys to be registered during
    instantiation.

    Args:
        keys (iterable of str, optional): info-dict keys to copy into the
            tensordict. Defaults to no keys (the reader is then a no-op).

    Examples:
        >>> from torchrl.envs import GymWrapper, default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0"))
        >>> env.set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()

    """

    def __init__(self, keys=None):
        # Copy into a fresh list: storing the caller's sequence by reference
        # would let later external mutation silently change this reader. The
        # copy also generalizes the accepted input to any iterable of keys.
        self.keys = [] if keys is None else list(keys)

    # String annotations: the _TensorDict import is not visible in this chunk,
    # so forward references keep the class importable regardless of import
    # order -- TODO confirm against the file's full import block.
    def __call__(self, info_dict: dict, tensordict: "_TensorDict") -> "_TensorDict":
        """Copy each registered key found in ``info_dict`` into ``tensordict``.

        Missing keys are skipped silently; ``tensordict`` is modified in-place
        and returned.
        """
        for key in self.keys:
            if key in info_dict:
                tensordict[key] = info_dict[key]
        return tensordict
46
+
13
47
14
48
class GymLikeEnv (_EnvWrapper ):
15
- info_keys = []
49
+ _info_dict_reader : callable
16
50
17
51
"""
18
52
A gym-like env is an environment whose behaviour is similar to gym environments in what
@@ -25,7 +59,7 @@ class GymLikeEnv(_EnvWrapper):
25
59
26
60
where the outputs are the observation, reward and done state respectively.
27
61
In this implementation, the info output is discarded (but specific keys can be read
28
- by updating the `"info_keys" ` class attribute ).
62
+ by updating info_dict_reader, see `set_info_dict_reader ` class method ).
29
63
30
64
By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless
31
65
the first output is a dictionary. In that case, each observation output will be put at the corresponding
@@ -65,9 +99,7 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
65
99
)
66
100
tensordict_out .set ("reward" , reward )
67
101
tensordict_out .set ("done" , done )
68
- for key in self .info_keys :
69
- data = info [0 ][key ]
70
- tensordict_out .set (key , data )
102
+ self .info_dict_reader (info , tensordict_out )
71
103
72
104
self .current_tensordict = step_tensordict (tensordict_out )
73
105
return tensordict_out
@@ -100,6 +132,42 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple:
100
132
)
101
133
return step_outputs_tuple
102
134
135
+ def set_info_dict_reader (self , info_dict_reader : callable ) -> GymLikeEnv :
136
+ """
137
+ Sets an info_dict_reader function. This function should take as input an
138
+ info_dict dictionary and the tensordict returned by the step function, and
139
+ write values in an ad-hoc manner from one to the other.
140
+
141
+ Args:
142
+ info_dict_reader (callable): a callable taking a input dictionary and
143
+ output tensordict as arguments. This function should modify the
144
+ tensordict in-place.
145
+
146
+ Returns: the same environment with the dict_reader registered.
147
+
148
+ Examples:
149
+ >>> from torchrl.envs import GymWrapper, default_info_dict_reader
150
+ >>> reader = default_info_dict_reader(["my_info_key"])
151
+ >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
152
+ >>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
153
+ >>> tensordict = env.reset()
154
+ >>> tensordict = env.rand_step(tensordict)
155
+ >>> assert "my_info_key" in tensordict.keys()
156
+
157
+ """
158
+ self .info_dict_reader = info_dict_reader
159
+ return self
160
+
161
+ @property
162
+ def info_dict_reader (self ):
163
+ if "_info_dict_reader" not in self .__dir__ ():
164
+ self ._info_dict_reader = default_info_dict_reader ()
165
+ return self ._info_dict_reader
166
+
167
+ @info_dict_reader .setter
168
+ def info_dict_reader (self , value : callable ):
169
+ self ._info_dict_reader = value
170
+
103
171
def __repr__ (self ) -> str :
104
172
return (
105
173
f"{ self .__class__ .__name__ } (env={ self ._env } , batch_size={ self .batch_size } )"
0 commit comments