Skip to content

[Refactor] Pass all keys at reset (prototype) #2956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4083,7 +4083,7 @@ def test_reset_white_to_move(self, stateful, include_pgn, include_fen):
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 w - - 0 1"
td = env.reset(TensorDict({"fen": fen}))
td = env.reset(TensorDict({"fen_reset": fen}))
if include_fen:
assert td["fen"] == fen
assert env.board.fen() == fen
Expand All @@ -4097,7 +4097,7 @@ def test_reset_black_to_move(self, stateful, include_pgn, include_fen):
stateful=stateful, include_pgn=include_pgn, include_fen=include_fen
)
fen = "5k2/4r3/8/8/8/1Q6/2K5/8 b - - 0 1"
td = env.reset(TensorDict({"fen": fen}))
td = env.reset(TensorDict({"fen_reset": fen}))
assert td["fen"] == fen
assert env.board.fen() == fen
assert td["turn"] == env.lib.BLACK
Expand All @@ -4111,7 +4111,7 @@ def test_reset_done_error(self, stateful, include_pgn, include_fen):
)
fen = "1R3k2/2R5/8/8/8/8/2K5/8 b - - 0 1"
with pytest.raises(ValueError) as e_info:
env.reset(TensorDict({"fen": fen}))
env.reset(TensorDict({"fen_reset": fen}))

assert "Cannot reset to a fen that is a gameover state" in str(e_info)

Expand Down Expand Up @@ -4181,7 +4181,7 @@ def test_reward(
if reset_without_fen:
td = TensorDict({"fen": fen})
else:
td = env.reset(TensorDict({"fen": fen}))
td = env.reset(TensorDict({"fen_reset": fen}))
assert td["turn"] == expected_turn

td["action"] = env._san_moves.index(move)
Expand Down Expand Up @@ -4230,16 +4230,18 @@ def test_env_reset_with_hash(self, stateful, include_san):
]
for fen, num_legal_moves in cases:
# Load the state by fen.
td = env.reset(TensorDict({"fen": fen}))
td = env.reset(TensorDict({"fen_reset": fen}))
assert td["fen"] == fen
assert td["action_mask"].sum() == num_legal_moves

# Reset to initial state just to make sure that the next reset
# actually changes the state.
assert env.reset()["action_mask"].sum() == 20

# Load the state by fen hash and make sure it gives the same output
# as before.
td_check = env.reset(td.select("fen_hash"))
assert (td_check == td).all()
assert assert_allclose_td(td_check, td, intersection=True)

@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("include_pgn", [False, True])
Expand Down
5 changes: 3 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2782,9 +2782,10 @@ def reset(
# Therefore, maybe_reset tells reset to temporarily hide the non-reset keys.
# To make step_and_maybe_reset handle custom reset states, some version of TensorDictPrimer should be used.
tensordict_reset = self._reset(
tensordict.select(*self.reset_keys, strict=False), **kwargs
tensordict.exclude(*self.state_keys), **kwargs
)
else:
print('tensordict', tensordict)
tensordict_reset = self._reset(tensordict, **kwargs)
# We assume that this is done properly
# if reset.device != self.device:
Expand Down Expand Up @@ -3634,7 +3635,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
any_done = self.any_done(tensordict)
if any_done:
tensordict = self.reset(tensordict, select_reset_only=True)
tensordict = self.reset(tensordict)
return tensordict

def empty_cache(self):
Expand Down
9 changes: 6 additions & 3 deletions torchrl/envs/custom/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
The action space is structured as a categorical distribution over all possible SAN moves, with the legal moves
being a subset of this space. The environment uses a mask to ensure only legal moves are selected.

.. note:: You can reset the env at a given state by passing `"fen_reset"` or `"pgn_reset"` to the TensorDict passed
to the reset method.

Examples:
>>> import torch
>>> from torchrl.envs import ChessEnv
Expand Down Expand Up @@ -322,7 +325,7 @@ def __init__(
self.stateful = stateful

# state_spec is loosely defined as such - it's not really an issue that extra keys
# can go missing but it allows us to reset the env using fen passed to the reset
# can go missing, but it allows us to reset the env using fen passed to the reset
# method.
self.full_state_spec = self.full_observation_spec.clone()

Expand Down Expand Up @@ -374,11 +377,11 @@ def _reset(self, tensordict=None):
if tensordict is not None:
dest = tensordict.empty()
if self.include_fen:
fen = tensordict.get("fen", None)
fen = tensordict.get("fen_reset", None)
if fen is not None:
fen = fen.data
elif self.include_pgn:
pgn = tensordict.get("pgn", None)
pgn = tensordict.get("pgn_reset", None)
if pgn is not None:
pgn = pgn.data
else:
Expand Down
6 changes: 2 additions & 4 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1173,10 +1173,7 @@ def _set_seed(self, seed: int | None) -> None:
def _reset(self, tensordict: TensorDictBase | None = None, **kwargs):
if tensordict is not None:
# We must avoid modifying the original tensordict so a shallow copy is necessary.
# We just select the input data and reset signal, which is all we need.
tensordict = tensordict.select(
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
)
tensordict = tensordict.copy()
# We always call _reset_env_preprocess, even if tensordict is None - that way one can augment that
# method to do any pre-reset operation.
# By default, within _reset_env_preprocess we will skip the inv call when tensordict is None.
Expand Down Expand Up @@ -7225,6 +7222,7 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
"""Resets episode rewards."""
print(f'{tensordict=}')
for in_key, reset_key, out_key in _zip_strict(
self.in_keys, self.reset_keys, self.out_keys
):
Expand Down
Loading