Skip to content

Commit d882ea2

Browse files
jeguzzi and Vincent Moens authored
[Environment] Complete PettingZooWrapper state support (#2953)
Co-authored-by: Vincent Moens <vmoens@meta.com>
1 parent f61078d commit d882ea2

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

test/test_libs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3948,6 +3948,20 @@ def __call__(self, td):
39483948
td[-1]["next", "player", "reward"] == torch.tensor([[-1], [1]])
39493949
).all()
39503950

3951+
@pytest.mark.parametrize("task", ["simple_v3"])
def test_return_state(self, task):
    """Check that a PettingZooEnv built with ``return_state=True`` exposes state.

    The environment is constructed in parallel mode with a fixed seed and
    masking disabled, its specs are validated with ``check_env_specs``, and a
    rollout obtained from ``env.rollout(10)`` is inspected: both the root
    ``"state"`` entry and the ``("next", "state")`` entry must contain at
    least one non-zero value.
    """
    env_kwargs = {
        "task": task,
        "parallel": True,
        "seed": 0,
        "use_mask": False,
        "return_state": True,
    }
    petting_env = PettingZooEnv(**env_kwargs)

    # Spec validation comes first, before the rollout is collected.
    check_env_specs(petting_env)

    rollout = petting_env.rollout(10)
    # The root entry is checked before the "next" entry, in that order.
    for state_key in ("state", ("next", "state")):
        assert (rollout[state_key] != 0).any()
3964+
39513965
@pytest.mark.parametrize(
39523966
"task",
39533967
[

torchrl/envs/libs/pettingzoo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,10 @@ def _reset(
584584
value, device=self.device
585585
)
586586

587+
if self.return_state:
588+
state = torch.as_tensor(self.state(), device=self.device)
589+
tensordict_out.set("state", state)
590+
587591
return tensordict_out
588592

589593
def _reset_aec(self, **kwargs) -> tuple[dict, dict]:
@@ -702,6 +706,11 @@ def _step(
702706
tensordict_out.set("done", done)
703707
tensordict_out.set("terminated", terminated)
704708
tensordict_out.set("truncated", truncated)
709+
710+
if self.return_state:
711+
state = torch.as_tensor(self.state(), device=self.device)
712+
tensordict_out.set("state", state)
713+
705714
return tensordict_out
706715

707716
def _aggregate_done(self, tensordict_out, use_any):

0 commit comments

Comments
 (0)