Skip to content

Commit aa9a9f2

Browse files
authored
[Feature] custom info_dict reader methods (#234)
1 parent 61a00e8 commit aa9a9f2

File tree

2 files changed

+74
-5
lines changed

2 files changed

+74
-5
lines changed

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .vec_env import *
99
from .transforms import *
1010
from .env_creator import *
11+
from .gym_like import *

torchrl/envs/gym_like.py

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,43 @@
1010
from torchrl.envs.common import _EnvWrapper
1111
from torchrl.envs.utils import step_tensordict
1212

13+
__all__ = ["GymLikeEnv", "default_info_dict_reader"]
14+
15+
16+
class default_info_dict_reader:
17+
"""
18+
Default info-key reader.
19+
20+
In cases where keys can be directly written to a tensordict (mostly if they abide to the
21+
tensordict shape), one simply needs to indicate the keys to be registered during
22+
instantiation.
23+
24+
Examples:
25+
>>> from torchrl.envs import GymWrapper, default_info_dict_reader
26+
>>> reader = default_info_dict_reader(["my_info_key"])
27+
>>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
28+
>>> env = GymWrapper(gym.make("some_env-v0"))
29+
>>> env.set_info_dict_reader(info_dict_reader=reader)
30+
>>> tensordict = env.reset()
31+
>>> tensordict = env.rand_step(tensordict)
32+
>>> assert "my_info_key" in tensordict.keys()
33+
34+
"""
35+
36+
def __init__(self, keys=None):
37+
if keys is None:
38+
keys = []
39+
self.keys = keys
40+
41+
def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict:
42+
for key in self.keys:
43+
if key in info_dict:
44+
tensordict[key] = info_dict[key]
45+
return tensordict
46+
1347

1448
class GymLikeEnv(_EnvWrapper):
15-
info_keys = []
49+
_info_dict_reader: callable
1650

1751
"""
1852
A gym-like env is an environment whose behaviour is similar to gym environments in what
@@ -25,7 +59,7 @@ class GymLikeEnv(_EnvWrapper):
2559
2660
where the outputs are the observation, reward and done state respectively.
2761
In this implementation, the info output is discarded (but specific keys can be read
28-
by updating the `"info_keys"` class attribute).
62+
by updating info_dict_reader, see `set_info_dict_reader` class method).
2963
3064
By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless
3165
the first output is a dictionary. In that case, each observation output will be put at the corresponding
@@ -65,9 +99,7 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
6599
)
66100
tensordict_out.set("reward", reward)
67101
tensordict_out.set("done", done)
68-
for key in self.info_keys:
69-
data = info[0][key]
70-
tensordict_out.set(key, data)
102+
self.info_dict_reader(info, tensordict_out)
71103

72104
self.current_tensordict = step_tensordict(tensordict_out)
73105
return tensordict_out
@@ -100,6 +132,42 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple:
100132
)
101133
return step_outputs_tuple
102134

135+
def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
136+
"""
137+
Sets an info_dict_reader function. This function should take as input an
138+
info_dict dictionary and the tensordict returned by the step function, and
139+
write values in an ad-hoc manner from one to the other.
140+
141+
Args:
142+
info_dict_reader (callable): a callable taking a input dictionary and
143+
output tensordict as arguments. This function should modify the
144+
tensordict in-place.
145+
146+
Returns: the same environment with the dict_reader registered.
147+
148+
Examples:
149+
>>> from torchrl.envs import GymWrapper, default_info_dict_reader
150+
>>> reader = default_info_dict_reader(["my_info_key"])
151+
>>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
152+
>>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
153+
>>> tensordict = env.reset()
154+
>>> tensordict = env.rand_step(tensordict)
155+
>>> assert "my_info_key" in tensordict.keys()
156+
157+
"""
158+
self.info_dict_reader = info_dict_reader
159+
return self
160+
161+
@property
162+
def info_dict_reader(self):
163+
if "_info_dict_reader" not in self.__dir__():
164+
self._info_dict_reader = default_info_dict_reader()
165+
return self._info_dict_reader
166+
167+
@info_dict_reader.setter
168+
def info_dict_reader(self, value: callable):
169+
self._info_dict_reader = value
170+
103171
def __repr__(self) -> str:
104172
return (
105173
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"

0 commit comments

Comments
 (0)