Skip to content

Commit

Permalink
Merge pull request #58 from PettingZoo-Team/utilities_fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jkterry1 authored Feb 21, 2021
2 parents 808e6a4 + d00234b commit 79f2980
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 45 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pettingzoo[butterfly]>=1.4.0
pettingzoo[butterfly]>=1.6.0
opencv-python
cloudpickle
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_version():
keywords=["Reinforcement Learning", "gym"],
packages=setuptools.find_packages(),
python_requires=">=3.6",
install_requires=["pettingzoo>=1.4.0", "opencv-python~=3.4.0", "cloudpickle"],
install_requires=["pettingzoo>=1.6.0", "opencv-python~=3.4.0", "cloudpickle"],
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
Expand Down
4 changes: 2 additions & 2 deletions supersuit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import gym
from pettingzoo.utils.to_parallel import to_parallel, ParallelEnv, from_parallel
from pettingzoo.utils.conversions import to_parallel, ParallelEnv, from_parallel
from pettingzoo.utils.env import AECEnv
from . import aec_wrappers
from . import gym_wrappers
from . import parallel_wrappers
from . import vector_constructors
from . import aec_vector

__version__ = "2.5.1"
__version__ = "2.6.0"


class WrapperFactory:
Expand Down
2 changes: 2 additions & 0 deletions supersuit/aec_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _update_items(self):
self._cumulative_rewards[agent] = self.env._cumulative_rewards.get(agent, 0)

def step(self, action):
self._has_updated = True
if self.dones[self.agent_selection]:
return self._was_done_step(action)

Expand Down Expand Up @@ -327,6 +328,7 @@ def observe(self, agent):
return fin_observe if fin_observe is not None else super().observe(agent)

def step(self, action):
self._has_updated = True
if self.dones[self.agent_selection]:
if self.env.agents and self.agent_selection == self.env.agent_selection:
self.env.step(None)
Expand Down
2 changes: 1 addition & 1 deletion supersuit/parallel_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pettingzoo.utils.to_parallel import ParallelEnv
from pettingzoo.utils.conversions import ParallelEnv
import gym
from gym.spaces import Box, Discrete
from .adv_transforms.frame_stack import stack_obs_space, stack_init, stack_obs
Expand Down
2 changes: 1 addition & 1 deletion test/parallel_env_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pettingzoo.utils.to_parallel import ParallelEnv
from pettingzoo.utils import ParallelEnv
from gym.spaces import Box, Discrete
import numpy as np
import supersuit
Expand Down
60 changes: 30 additions & 30 deletions test/pettingzoo_api_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from pettingzoo.test import api_test, seed_test, parallel_test
from pettingzoo.mpe import simple_push_v2, simple_world_comm_v2
from pettingzoo.butterfly import knights_archers_zombies_v6, prison_v2
from pettingzoo.butterfly import knights_archers_zombies_v7, prison_v3

import supersuit
import pytest
Expand Down Expand Up @@ -42,30 +42,30 @@ def test_pettingzoo_pad_action_space():
def test_pettingzoo_parallel_env():
_env = simple_world_comm_v2.parallel_env()
wrapped_env = pad_action_space_v0(_env)
parallel_test.parallel_play_test(wrapped_env)
parallel_test.parallel_api_test(wrapped_env)


wrappers = [
supersuit.color_reduction_v0(knights_archers_zombies_v6.env(), "R"),
supersuit.resize_v0(dtype_v0(knights_archers_zombies_v6.env(), np.uint8), x_size=5, y_size=10),
supersuit.resize_v0(dtype_v0(knights_archers_zombies_v6.env(), np.uint8), x_size=5, y_size=10, linear_interp=True),
supersuit.dtype_v0(knights_archers_zombies_v6.env(), np.int32),
supersuit.flatten_v0(knights_archers_zombies_v6.env()),
supersuit.reshape_v0(knights_archers_zombies_v6.env(), (512 * 512, 3)),
supersuit.normalize_obs_v0(dtype_v0(knights_archers_zombies_v6.env(), np.float32), env_min=-1, env_max=5.0),
supersuit.frame_stack_v1(knights_archers_zombies_v6.env(), 8),
supersuit.pad_observations_v0(knights_archers_zombies_v6.env()),
supersuit.pad_action_space_v0(knights_archers_zombies_v6.env()),
supersuit.black_death_v1(knights_archers_zombies_v6.env()),
supersuit.agent_indicator_v0(knights_archers_zombies_v6.env(), True),
supersuit.agent_indicator_v0(knights_archers_zombies_v6.env(), False),
supersuit.reward_lambda_v0(knights_archers_zombies_v6.env(), lambda x: x / 10),
supersuit.clip_reward_v0(knights_archers_zombies_v6.env()),
supersuit.clip_actions_v0(prison_v2.env(continuous=True)),
supersuit.frame_skip_v0(knights_archers_zombies_v6.env(), 4),
supersuit.sticky_actions_v0(knights_archers_zombies_v6.env(), 0.75),
supersuit.delay_observations_v0(knights_archers_zombies_v6.env(), 3),
supersuit.max_observation_v0(knights_archers_zombies_v6.env(), 3),
supersuit.color_reduction_v0(knights_archers_zombies_v7.env(), "R"),
supersuit.resize_v0(dtype_v0(knights_archers_zombies_v7.env(), np.uint8), x_size=5, y_size=10),
supersuit.resize_v0(dtype_v0(knights_archers_zombies_v7.env(), np.uint8), x_size=5, y_size=10, linear_interp=True),
supersuit.dtype_v0(knights_archers_zombies_v7.env(), np.int32),
supersuit.flatten_v0(knights_archers_zombies_v7.env()),
supersuit.reshape_v0(knights_archers_zombies_v7.env(), (512 * 512, 3)),
supersuit.normalize_obs_v0(dtype_v0(knights_archers_zombies_v7.env(), np.float32), env_min=-1, env_max=5.0),
supersuit.frame_stack_v1(knights_archers_zombies_v7.env(), 8),
supersuit.pad_observations_v0(knights_archers_zombies_v7.env()),
supersuit.pad_action_space_v0(knights_archers_zombies_v7.env()),
supersuit.black_death_v1(knights_archers_zombies_v7.env()),
supersuit.agent_indicator_v0(knights_archers_zombies_v7.env(), True),
supersuit.agent_indicator_v0(knights_archers_zombies_v7.env(), False),
supersuit.reward_lambda_v0(knights_archers_zombies_v7.env(), lambda x: x / 10),
supersuit.clip_reward_v0(knights_archers_zombies_v7.env()),
supersuit.clip_actions_v0(prison_v3.env(continuous=True)),
supersuit.frame_skip_v0(knights_archers_zombies_v7.env(), 4),
supersuit.sticky_actions_v0(knights_archers_zombies_v7.env(), 0.75),
supersuit.delay_observations_v0(knights_archers_zombies_v7.env(), 3),
supersuit.max_observation_v0(knights_archers_zombies_v7.env(), 3),
]


Expand All @@ -75,16 +75,16 @@ def test_pettingzoo_aec_api(env):


parallel_wrappers = [
supersuit.frame_stack_v1(knights_archers_zombies_v6.parallel_env(), 8),
supersuit.reward_lambda_v0(knights_archers_zombies_v6.parallel_env(), lambda x: x / 10),
supersuit.delay_observations_v0(knights_archers_zombies_v6.parallel_env(), 3),
supersuit.color_reduction_v0(knights_archers_zombies_v6.parallel_env(), "R"),
supersuit.frame_skip_v0(knights_archers_zombies_v6.parallel_env(), 4),
supersuit.max_observation_v0(knights_archers_zombies_v6.parallel_env(), 4),
supersuit.black_death_v1(knights_archers_zombies_v6.parallel_env()),
supersuit.frame_stack_v1(knights_archers_zombies_v7.parallel_env(), 8),
supersuit.reward_lambda_v0(knights_archers_zombies_v7.parallel_env(), lambda x: x / 10),
supersuit.delay_observations_v0(knights_archers_zombies_v7.parallel_env(), 3),
supersuit.color_reduction_v0(knights_archers_zombies_v7.parallel_env(), "R"),
supersuit.frame_skip_v0(knights_archers_zombies_v7.parallel_env(), 4),
supersuit.max_observation_v0(knights_archers_zombies_v7.parallel_env(), 4),
supersuit.black_death_v1(knights_archers_zombies_v7.parallel_env()),
]


@pytest.mark.parametrize("env", parallel_wrappers)
def test_pettingzoo_parallel_api(env):
parallel_test.parallel_play_test(env)
parallel_test.parallel_api_test(env)
8 changes: 4 additions & 4 deletions test/test_vector/test_aec_vector_identity_env.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pettingzoo.butterfly import knights_archers_zombies_v6
from pettingzoo.butterfly import knights_archers_zombies_v7
from pettingzoo.mpe import simple_push_v2
from pettingzoo.sisl import waterworld_v3
from supersuit import frame_skip_v0, vectorize_aec_env_v0
Expand All @@ -7,12 +7,12 @@

def test_identical():
def env_fn():
return knights_archers_zombies_v6.env() # ,20)
return knights_archers_zombies_v7.env() # ,20)

n_envs = 2
# single threaded
env1 = vectorize_aec_env_v0(knights_archers_zombies_v6.env(), n_envs)
env2 = vectorize_aec_env_v0(knights_archers_zombies_v6.env(), n_envs, num_cpus=1)
env1 = vectorize_aec_env_v0(knights_archers_zombies_v7.env(), n_envs)
env2 = vectorize_aec_env_v0(knights_archers_zombies_v7.env(), n_envs, num_cpus=1)
env1.seed(42)
env2.seed(42)
env1.reset()
Expand Down
4 changes: 2 additions & 2 deletions test/test_vector/test_aec_vector_values.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from supersuit import vectorize_aec_env_v0
from pettingzoo.classic import rps_v1
from pettingzoo.classic import mahjong_v2, hanabi_v3
from pettingzoo.butterfly import knights_archers_zombies_v6
from pettingzoo.butterfly import knights_archers_zombies_v7
from pettingzoo.mpe import simple_world_comm_v2
from pettingzoo.classic import chess_v0
import numpy as np
Expand Down Expand Up @@ -67,5 +67,5 @@ def select_action(vec_env, passes, i):
test_vec_env(vectorize_aec_env_v0(mahjong_maker(), NUM_ENVS, num_cpus=num_cpus))
test_infos(vectorize_aec_env_v0(hanabi_maker(), NUM_ENVS, num_cpus=num_cpus))
test_some_done(vectorize_aec_env_v0(mahjong_maker(), NUM_ENVS, num_cpus=num_cpus))
test_vec_env(vectorize_aec_env_v0(knights_archers_zombies_v6.env(), NUM_ENVS, num_cpus=num_cpus))
test_vec_env(vectorize_aec_env_v0(knights_archers_zombies_v7.env(), NUM_ENVS, num_cpus=num_cpus))
test_vec_env(vectorize_aec_env_v0(simple_world_comm_v2.env(), NUM_ENVS, num_cpus=num_cpus))
6 changes: 3 additions & 3 deletions test/test_vector/test_pettingzoo_to_vec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pettingzoo.mpe import simple_spread_v2, simple_world_comm_v2
from pettingzoo.butterfly import knights_archers_zombies_v6
from pettingzoo.butterfly import knights_archers_zombies_v7
from supersuit import pettingzoo_env_to_vec_env_v0, black_death_v1
import pytest

Expand Down Expand Up @@ -30,7 +30,7 @@ def test_bad_action_spaces_env():


def test_env_black_death_assertion():
env = knights_archers_zombies_v6.parallel_env(spawn_rate=50, max_cycles=2000)
env = knights_archers_zombies_v7.parallel_env(spawn_rate=50, max_cycles=2000)
env = pettingzoo_env_to_vec_env_v0(env)
with pytest.raises(AssertionError):
for i in range(100):
Expand All @@ -41,7 +41,7 @@ def test_env_black_death_assertion():


def test_env_black_death_wrapper():
env = knights_archers_zombies_v6.parallel_env(spawn_rate=50, max_cycles=300)
env = knights_archers_zombies_v7.parallel_env(spawn_rate=50, max_cycles=300)
env = black_death_v1(env)
env = pettingzoo_env_to_vec_env_v0(env)
env.reset()
Expand Down

0 comments on commit 79f2980

Please sign in to comment.