Skip to content

Commit

Permalink
Add test for state equality
Browse files Browse the repository at this point in the history
  • Loading branch information
dtch1997 committed Oct 26, 2021
1 parent 2702fcc commit 2915913
Showing 1 changed file with 44 additions and 9 deletions.
53 changes: 44 additions & 9 deletions tests/pybullet/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,33 @@ def _freeze_pybullet_gym_env(env_name: str):
assert reward == pytest.approx(ref_reward)


def _is_eq(a, b) -> bool:
if not type(a) == type(b):
return False
if isinstance(a, np.ndarray):
return np.all(a == b)
elif isinstance(a, dict):
if not set(a.keys()) == set(b.keys()):
return False
for key in a.keys():
aval, bval = a[key], b[key]
if not _is_eq(aval, bval):
return False
return True
else:
return a == b


def _state_eq(state1, state2) -> bool:
if not len(state1) == len(state2):
return False
# skip the first element since that is a unique file name
for elem1, elem2 in zip(state1[1:], state2[1:]):
if not _is_eq(elem1, elem2):
return False
return True


def _get_and_set_state(env_name):
""" Test that state getter and setter can run without error """
handler = create_handler_from_str(env_name)
Expand All @@ -48,6 +75,7 @@ def _get_and_set_state(env_name):
handler.set_env_state(state, env)
# test if we can restore the state multiple times
handler.set_env_state(state, env)
assert _state_eq(state, handler.get_current_state(env))


def _transfer_state(env_name):
Expand All @@ -59,21 +87,28 @@ def _transfer_state(env_name):
env2 = handler.make_env_from_str(env_name)
env2.reset()
handler.set_env_state(state, env2)
assert _state_eq(state, handler.get_current_state(env2))


test_env_names = (
"pybulletgym___HalfCheetahPyBulletEnv-v0",
"pybulletgym___HopperPyBulletEnv-v0",
"pybulletgym___HumanoidPyBulletEnv-v0",
"pybulletgym___ReacherPyBulletEnv-v0",
"pybulletgym___InvertedPendulumPyBulletEnv-v0",
)


def test_freeze():
_freeze_pybullet_gym_env("pybulletgym___HalfCheetahPyBulletEnv-v0")
_freeze_pybullet_gym_env("pybulletgym___HopperPyBulletEnv-v0")
_freeze_pybullet_gym_env("pybulletgym___HumanoidPyBulletEnv-v0")
for env_name in test_env_names:
_freeze_pybullet_gym_env(env_name)


def test_get_and_set_state():
_get_and_set_state("pybulletgym___HalfCheetahPyBulletEnv-v0")
_get_and_set_state("pybulletgym___HopperPyBulletEnv-v0")
_get_and_set_state("pybulletgym___HumanoidPyBulletEnv-v0")
for env_name in test_env_names:
_get_and_set_state(env_name)


def test_transfer_state():
_transfer_state("pybulletgym___HalfCheetahPyBulletEnv-v0")
_transfer_state("pybulletgym___HopperPyBulletEnv-v0")
_transfer_state("pybulletgym___HumanoidPyBulletEnv-v0")
for env_name in test_env_names:
_transfer_state(env_name)

0 comments on commit 2915913

Please sign in to comment.