-
Notifications
You must be signed in to change notification settings - Fork 373
[Feature] custom info_dict reader methods #234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,4 @@ | |
from .vec_env import * | ||
from .transforms import * | ||
from .env_creator import * | ||
from .gym_like import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,9 +10,43 @@ | |
from torchrl.envs.common import _EnvWrapper | ||
from torchrl.envs.utils import step_tensordict | ||
|
||
__all__ = ["GymLikeEnv", "default_info_dict_reader"] | ||
|
||
|
||
class default_info_dict_reader:
    """Default info-key reader.

    Copies a fixed set of keys from an ``info`` dictionary into a tensordict.
    In cases where keys can be directly written to a tensordict (mostly if
    they abide by the tensordict shape), one simply needs to indicate the keys
    to be registered during instantiation.

    Args:
        keys (str or iterable of str, optional): key(s) to copy from the info
            dictionary. A single string is treated as one key (it is not
            iterated character by character). Defaults to no keys, in which
            case calling the reader leaves the tensordict unchanged.

    Examples:
        >>> import gym
        >>> from torchrl.envs import GymWrapper, default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0"))
        >>> env.set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()

    """

    def __init__(self, keys=None):
        if keys is None:
            keys = []
        elif isinstance(keys, str):
            # A bare string is itself iterable; without this guard
            # "my_key" would be read as the keys "m", "y", "_", ...
            keys = [keys]
        self.keys = keys

    def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict:
        """Write the registered keys found in ``info_dict`` into ``tensordict``.

        Args:
            info_dict (dict): dictionary to read values from.
            tensordict: container the values are written to, in-place.

        Returns:
            The same ``tensordict`` object that was passed in.

        """
        for key in self.keys:
            # Registered keys absent from this info_dict are skipped silently.
            if key in info_dict:
                tensordict[key] = info_dict[key]
        return tensordict
|
||
|
||
class GymLikeEnv(_EnvWrapper): | ||
info_keys = [] | ||
_info_dict_reader: callable | ||
|
||
""" | ||
A gym-like env is an environment whose behaviour is similar to gym environments in what | ||
|
@@ -25,7 +59,7 @@ class GymLikeEnv(_EnvWrapper): | |
|
||
where the outputs are the observation, reward and done state respectively. | ||
In this implementation, the info output is discarded (but specific keys can be read | ||
by updating the `"info_keys"` class attribute). | ||
by registering a custom info_dict_reader; see the `set_info_dict_reader` method). | ||
|
||
By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless | ||
the first output is a dictionary. In that case, each observation output will be put at the corresponding | ||
|
@@ -65,9 +99,7 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict: | |
) | ||
tensordict_out.set("reward", reward) | ||
tensordict_out.set("done", done) | ||
for key in self.info_keys: | ||
data = info[0][key] | ||
tensordict_out.set(key, data) | ||
self.info_dict_reader(info, tensordict_out) | ||
|
||
self.current_tensordict = step_tensordict(tensordict_out) | ||
return tensordict_out | ||
|
@@ -100,6 +132,42 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple: | |
) | ||
return step_outputs_tuple | ||
|
||
def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
    """Register the callable used to read the ``info`` dictionary.

    The reader is called with the info dictionary and the tensordict
    produced by the step function, and is expected to copy whatever values
    it needs from the former into the latter.

    Args:
        info_dict_reader (callable): a callable taking an input dictionary
            and an output tensordict as arguments. It should modify the
            tensordict in-place.

    Returns:
        The same environment, with the reader registered, so that calls
        can be chained.

    Examples:
        >>> import gym
        >>> from torchrl.envs import GymWrapper, default_info_dict_reader
        >>> reader = default_info_dict_reader(["my_info_key"])
        >>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
        >>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
        >>> tensordict = env.reset()
        >>> tensordict = env.rand_step(tensordict)
        >>> assert "my_info_key" in tensordict.keys()

    """
    # Assignment goes through the ``info_dict_reader`` property.
    self.info_dict_reader = info_dict_reader
    return self
|
||
@property
def info_dict_reader(self):
    """Callable that copies values from the info dict into the step output.

    If no reader has been registered yet, a ``default_info_dict_reader``
    with no keys is created, stored on the instance and returned.
    """
    # A plain hasattr()/getattr() is deliberately avoided: per the review
    # discussion on this change, a failed lookup on the wrapper falls
    # through to the wrapped env (env.env), whose attributes are not under
    # our control. object.__getattribute__ never triggers that fallback and,
    # unlike building the full ``__dir__()`` list, is a constant-time lookup.
    try:
        return object.__getattribute__(self, "_info_dict_reader")
    except AttributeError:
        reader = default_info_dict_reader()
        self._info_dict_reader = reader
        return reader

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
    """Store ``value`` as the reader returned by ``info_dict_reader``."""
    self._info_dict_reader = value
|
||
def __repr__(self) -> str: | ||
return ( | ||
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to ensure that keys like 'observation', 'next_observation', 'action', 'reward', and 'done' are not passed, so that they aren't read from info_dict?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point, but on the other hand some crazy researcher may want to do that, no? For example, overriding the observation based on the info. Since the default is to read nothing, we can assume that those who pass such keys know what they're doing. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense - good to go from my side!