Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minigrid doorkey environment #251

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
test passed
  • Loading branch information
wangsiping97 committed Jan 19, 2023
commit 1787f21ddf97c704a48b591aee8452077566ff04
1 change: 1 addition & 0 deletions envpool/minigrid/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,6 @@ py_test(
requirement("dm_env"),
requirement("gym"),
requirement("numpy"),
requirement("minigrid"),
],
)
9 changes: 7 additions & 2 deletions envpool/minigrid/impl/minigrid_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace minigrid {
void MiniGridEnv::MiniGridReset() {
GenGrid();
step_count_ = 0;
done_ = false;
CHECK(agent_pos_.first >= 0 && agent_pos_.second >= 0);
CHECK(agent_dir_ >= 0);
CHECK(grid_[agent_pos_.second][agent_pos_.first].GetType() == kEmpty);
Expand Down Expand Up @@ -54,8 +55,9 @@ float MiniGridEnv::MiniGridStep(Act act) {
// Get the forward cell object
if (act == kLeft) {
agent_dir_ -= 1;
if (agent_dir_ < 0)
if (agent_dir_ < 0) {
agent_dir_ += 4;
}
} else if (act == kRight) {
agent_dir_ = (agent_dir_ + 1) % 4;
} else if (act == kForward) {
Expand Down Expand Up @@ -182,8 +184,11 @@ void MiniGridEnv::GenImage(Array& obs) {
memset(vis_mask, 1, sizeof(vis_mask));
}
// Let the agent see what it's carrying
if (carrying_.GetType() != kEmpty)
if (carrying_.GetType() != kEmpty) {
agent_view_grid[agent_pos_y][agent_pos_x] = carrying_;
} else {
agent_view_grid[agent_pos_y][agent_pos_x] = WorldObj(kEmpty);
}
for (int y = 0; y < agent_view_size_; ++y) {
for (int x = 0; x < agent_view_size_; ++x) {
if (vis_mask[y * agent_view_size_ + x] == true) {
Expand Down
33 changes: 30 additions & 3 deletions envpool/minigrid/minigrid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

from typing import Any

import gymnasium as gym
import minigrid
import numpy as np
from absl.testing import absltest
from absl import logging

import envpool.minigrid.registration # noqa: F401
from envpool.registration import make_gym
Expand All @@ -28,11 +31,35 @@ def test_deterministic_check(
self,
task_id: str = "MiniGrid-Empty-5x5-v0",
num_envs: int = 1,
total: int = 100000,
**kwargs: Any,
) -> None:
env = make_gym(task_id, num_envs=num_envs, seed=0, **kwargs)
obs, info = env.reset()
print(obs)
env0 = gym.make(task_id)
env1 = make_gym(task_id, num_envs=num_envs, seed=0, **kwargs)
obs0, info0 = env0.reset()
obs1, info1 = env1.reset()
np.testing.assert_allclose(obs0["image"], obs1["image"][0])
done0 = False
acts = []
for i in range(total):
act = env0.action_space.sample()
acts.append(act)
if done0:
obs0, info0 = env0.reset()
auto_reset = True
term0 = trunc0 = False
else:
obs0, rew0, term0, trunc0, info0 = env0.step(act)
auto_reset = False
obs1, rew1, term1, trunc1, info1 = env1.step(np.array([act]))
self.assertEqual(obs0["image"].shape, (7, 7, 3))
self.assertEqual(obs1["image"].shape, (num_envs, 7, 7, 3))
done0 = term0 | trunc0
done1 = term1 | trunc1
if not auto_reset:
np.testing.assert_allclose(rew0, rew1[0], rtol=1e-6)
np.testing.assert_allclose(done0, done1[0])
np.testing.assert_allclose(obs0["image"], obs1["image"][0])


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions third_party/pip_requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ dm-control>=1.0.5
mujoco>=2.2.1,<2.3
mujoco_py>=2.1.2.14
pygame
minigrid