From b221c682c8d3e01951239e4869e58e0544eed8a7 Mon Sep 17 00:00:00 2001 From: Luca Grillotti Date: Sun, 10 Dec 2023 20:11:08 +0900 Subject: [PATCH] Fix imports, CI and formatting issues (#163) * fix imports jumpy * fix formatting issues * update ci to handle internal and external PRs * adapt Jumanji test to new Jumanji API --- README.md | 1 + examples/jumanji_snake.ipynb | 2 +- qdax/environments/__init__.py | 10 +++++++--- qdax/environments/wrappers.py | 6 ++---- tests/default_tasks_test/jumanji_envs_test.py | 12 ++++++++---- tests/environments_test/pointmaze_test.py | 9 ++++----- tests/environments_test/wrapper_test.py | 8 ++++---- 7 files changed, 27 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index a2e4ce83..304fe843 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ qdax_logo + # QDax: Accelerated Quality-Diversity [![Documentation Status](https://readthedocs.org/projects/qdax/badge/?version=latest)](https://qdax.readthedocs.io/en/latest/?badge=latest) diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index 20aaac14..0b206f34 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -91,7 +91,7 @@ "outputs": [], "source": [ "# Instantiate a Jumanji environment using the registry\n", - "env = jumanji.make('Snake-6x6-v0')\n", + "env = jumanji.make('Snake-v1')\n", "\n", "# Reset your (jit-able) environment\n", "key = jax.random.PRNGKey(0)\n", diff --git a/qdax/environments/__init__.py b/qdax/environments/__init__.py index be336a14..054c75f7 100644 --- a/qdax/environments/__init__.py +++ b/qdax/environments/__init__.py @@ -1,9 +1,13 @@ import functools from typing import Any, Callable, List, Optional, Union -from brax.v1.envs import Env -from brax.v1.envs import _envs -from brax.v1.envs.wrappers import EpisodeWrapper, AutoResetWrapper, EvalWrapper, VectorWrapper +from brax.v1.envs import Env, _envs +from brax.v1.envs.wrappers import ( + AutoResetWrapper, + EpisodeWrapper, + EvalWrapper, + VectorWrapper, +) from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper from qdax.environments.bd_extractors import ( diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py index 274b9073..720f662a 100644 --- a/qdax/environments/wrappers.py +++ b/qdax/environments/wrappers.py @@ -1,9 +1,9 @@ from typing import Dict -from brax.v1.envs import State, Wrapper import flax.struct import jax from brax.v1 import jumpy as jp +from brax.v1.envs import State, Wrapper class CompletedEvalMetrics(flax.struct.PyTreeNode): @@ -34,9 +34,7 @@ def reset(self, rng: jp.ndarray) -> State: reset_state.info[self.STATE_INFO_KEY] = eval_metrics return reset_state - def step( - self, state: State, action: jp.ndarray - ) -> State: + def step(self, state: State, action: jp.ndarray) -> State: state_metrics = state.info[self.STATE_INFO_KEY] if not isinstance(state_metrics, CompletedEvalMetrics): raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}") diff --git a/tests/default_tasks_test/jumanji_envs_test.py b/tests/default_tasks_test/jumanji_envs_test.py index bb41c555..eed90127 100644 --- a/tests/default_tasks_test/jumanji_envs_test.py +++ b/tests/default_tasks_test/jumanji_envs_test.py @@ -5,6 +5,7 @@ import jax import jax.numpy as jnp import jumanji +import jumanji.environments.routing.snake import numpy as np import pytest @@ -26,7 +27,7 @@ def test_jumanji_utils() -> None: batch_size = population_size # Instantiate a Jumanji environment using the registry - env = jumanji.make("Snake-6x6-v0") + env = jumanji.make("Snake-v1") # Reset your (jit-able) environment key = jax.random.PRNGKey(0) @@ -49,8 +50,10 @@ def test_jumanji_utils() -> None: final_activation=jax.nn.softmax, ) - def observation_processing(observation: jumanji.types.Observation) -> Observation: - network_input = jnp.ravel(observation) + def observation_processing( + observation: jumanji.environments.routing.snake.types.Observation, + ) -> Observation: + network_input = jnp.ravel(observation.grid) return network_input play_step_fn = make_policy_network_play_step_fn_jumanji( @@ -64,7 +67,7 @@ def observation_processing(observation: jumanji.types.Observation) -> Observatio keys = jax.random.split(subkey, num=batch_size) # compute observation size from observation spec - observation_size = np.prod(np.array(env.observation_spec().shape)) + observation_size = np.prod(np.array(env.observation_spec().grid.shape)) fake_batch = jnp.zeros(shape=(batch_size, observation_size)) init_variables = jax.vmap(policy_network.init)(keys, fake_batch) @@ -136,4 +139,5 @@ def bd_extraction( if __name__ == "__main__": + pytest.assume test_jumanji_utils() diff --git a/tests/environments_test/pointmaze_test.py b/tests/environments_test/pointmaze_test.py index 76c96451..a13f41cc 100644 --- a/tests/environments_test/pointmaze_test.py +++ b/tests/environments_test/pointmaze_test.py @@ -1,10 +1,9 @@ from typing import Any, Tuple -import brax -import brax.envs import jax import pytest -from brax import jumpy as jp +from brax.v1 import jumpy as jp +from brax.v1.envs import Env import qdax from qdax.environments.pointmaze import PointMaze @@ -15,7 +14,7 @@ def test_pointmaze() -> None: # create env with class qd_env = PointMaze() # verify class - pytest.assume(isinstance(qd_env, brax.envs.Env)) + pytest.assume(isinstance(qd_env, Env)) # check state_descriptor_length pytest.assume(qd_env.state_descriptor_length == 2) @@ -25,7 +24,7 @@ def test_pointmaze() -> None: qd_env = qdax.environments.create(env_name="pointmaze") # type: ignore # verify class - pytest.assume(isinstance(qd_env, brax.envs.Env)) + pytest.assume(isinstance(qd_env, Env)) # check state_descriptor_length pytest.assume(qd_env.state_descriptor_length == 2) diff --git a/tests/environments_test/wrapper_test.py b/tests/environments_test/wrapper_test.py index 8c8155c2..f5e035ea 100644 --- a/tests/environments_test/wrapper_test.py +++ b/tests/environments_test/wrapper_test.py @@ -1,12 +1,12 @@ from typing import Dict, List, Union -import brax.envs +import brax import jax import jax.numpy as jnp import pytest -from brax import jumpy as jp -from brax.physics.base import vec_to_arr -from brax.physics.config_pb2 import Joint +from brax.v1 import jumpy as jp +from brax.v1.physics.base import vec_to_arr +from brax.v1.physics.config_pb2 import Joint from qdax import environments