Skip to content

[Feature] custom info_dict reader methods #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .vec_env import *
from .transforms import *
from .env_creator import *
from .gym_like import *
78 changes: 73 additions & 5 deletions torchrl/envs/gym_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,43 @@
from torchrl.envs.common import _EnvWrapper
from torchrl.envs.utils import step_tensordict

__all__ = ["GymLikeEnv", "default_info_dict_reader"]


class default_info_dict_reader:
"""
Default info-key reader.

In cases where keys can be directly written to a tensordict (mostly if they abide to the
tensordict shape), one simply needs to indicate the keys to be registered during
instantiation.

Examples:
>>> from torchrl.envs import GymWrapper, default_info_dict_reader
>>> reader = default_info_dict_reader(["my_info_key"])
>>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
>>> env = GymWrapper(gym.make("some_env-v0"))
>>> env.set_info_dict_reader(info_dict_reader=reader)
>>> tensordict = env.reset()
>>> tensordict = env.rand_step(tensordict)
>>> assert "my_info_key" in tensordict.keys()

"""

def __init__(self, keys=None):
if keys is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to ensure that keys like 'observation', 'next_observation', 'action', 'reward', 'done' are not passed so that they arent read from info_dict

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point but on the other hand some crazy researcher may want to do that no? Like overriding the observation based on the info. Since the default is reading nothing we can assume that those inputting similar keys to read know what they're doing, what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense - good to go from my side!

keys = []
self.keys = keys

def __call__(self, info_dict: dict, tensordict: _TensorDict) -> _TensorDict:
for key in self.keys:
if key in info_dict:
tensordict[key] = info_dict[key]
return tensordict


class GymLikeEnv(_EnvWrapper):
info_keys = []
_info_dict_reader: callable

"""
A gym-like env is an environment whose behaviour is similar to gym environments in what
Expand All @@ -25,7 +59,7 @@ class GymLikeEnv(_EnvWrapper):

where the outputs are the observation, reward and done state respectively.
In this implementation, the info output is discarded (but specific keys can be read
by updating the `"info_keys"` class attribute).
by updating info_dict_reader, see `set_info_dict_reader` class method).

By default, the first output is written at the "next_observation" key-value pair in the output tensordict, unless
the first output is a dictionary. In that case, each observation output will be put at the corresponding
Expand Down Expand Up @@ -65,9 +99,7 @@ def _step(self, tensordict: _TensorDict) -> _TensorDict:
)
tensordict_out.set("reward", reward)
tensordict_out.set("done", done)
for key in self.info_keys:
data = info[0][key]
tensordict_out.set(key, data)
self.info_dict_reader(info, tensordict_out)

self.current_tensordict = step_tensordict(tensordict_out)
return tensordict_out
Expand Down Expand Up @@ -100,6 +132,42 @@ def _output_transform(self, step_outputs_tuple: Tuple) -> Tuple:
)
return step_outputs_tuple

def set_info_dict_reader(self, info_dict_reader: callable) -> GymLikeEnv:
"""
Sets an info_dict_reader function. This function should take as input an
info_dict dictionary and the tensordict returned by the step function, and
write values in an ad-hoc manner from one to the other.

Args:
info_dict_reader (callable): a callable taking a input dictionary and
output tensordict as arguments. This function should modify the
tensordict in-place.

Returns: the same environment with the dict_reader registered.

Examples:
>>> from torchrl.envs import GymWrapper, default_info_dict_reader
>>> reader = default_info_dict_reader(["my_info_key"])
>>> # assuming "some_env-v0" returns a dict with a key "my_info_key"
>>> env = GymWrapper(gym.make("some_env-v0")).set_info_dict_reader(info_dict_reader=reader)
>>> tensordict = env.reset()
>>> tensordict = env.rand_step(tensordict)
>>> assert "my_info_key" in tensordict.keys()

"""
self.info_dict_reader = info_dict_reader
return self

@property
def info_dict_reader(self):
if "_info_dict_reader" not in self.__dir__():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that this is checking if _info_dict_reader attribute exists or not. I am not sure if checking the self.__dir__() is the recommended method. I have seen use of hasattr method lot more. So just flagging this, in case it turns out to be relevant.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Problem with hasattr is that it'll look into env.env (ie the gym env) if it can't find it, and I can't promise what env.env has and hasn't.
With dir I'm sure it won't propagate to the wrapped env

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I didnt realise hashattr would do that but it makes sense!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's not a clever function
It just runs getattr and tells you ok if it did not return an error...

self._info_dict_reader = default_info_dict_reader()
return self._info_dict_reader

@info_dict_reader.setter
def info_dict_reader(self, value: callable):
self._info_dict_reader = value

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(env={self._env}, batch_size={self.batch_size})"
Expand Down