From abc0a3fae83bdc6093c89c9348fd63069cafd41a Mon Sep 17 00:00:00 2001
From: Luca Grillotti
Date: Tue, 6 Sep 2022 11:08:47 +0100
Subject: [PATCH 1/2] Doc: Add remark for installing QDax with GPU support in
 README (#77)

* add remark for installing QDax with GPU support in README

* Improve README about installation and GPUs

- providing information about default installation behaviour in pip and
  containers
---
 README.md | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 63eef09b..a5ceb3cc 100644
--- a/README.md
+++ b/README.md
@@ -14,13 +14,17 @@ QDax has been developed as a research framework: it is flexible and easy to exte
 
 ## Installation
 
-
-The latest stable release of QDax can be installed directly from source with:
+QDax is available on PyPI and can be installed with:
 ```bash
 pip install qdax
 ```
+Alternatively, the latest commit of QDax can be installed directly from source with:
+```bash
+pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main
+```
+Installing QDax via `pip` installs a CPU-only version of JAX by default. To use QDax with NVIDIA GPUs, you must first install [CUDA, cuDNN, and JAX with GPU support](https://github.com/google/jax#installation).
 
-However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).
+However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository; these provide GPU support by default. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).
 
 ## Basic API Usage
 For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).
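Note on the README change above: a minimal sketch of the GPU workflow it describes, assuming a machine with CUDA 11 and cuDNN 8.2 already installed. The jaxlib pin and wheel index below mirror the ones used in dev.Dockerfile in the next patch; adapt the versions to your setup, and see https://github.com/google/jax#installation for the authoritative command.

```bash
# Sketch: install a CUDA-enabled jaxlib first (version pin is illustrative,
# match it to your CUDA/cuDNN install), then install QDax on top.
pip install jaxlib==0.3.15+cuda11.cudnn82 \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install qdax
```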
From 2fb56198f176f0a9be90baa543b172d4915ba0f9 Mon Sep 17 00:00:00 2001
From: Luca Grillotti
Date: Thu, 6 Oct 2022 14:14:02 +0100
Subject: [PATCH 2/2] chore: Update jax, brax and flax versions (fixes the
 jax.tree_util warnings) (#76)

Update jax, brax and flax versions

+ tree-based functions imported from tree_util
+ update requirements.txt and setup.py
+ adds CompletedEvalWrapper that used to be in Brax
---
 dev.Dockerfile                               |  3 +-
 environment.yaml                             |  2 +-
 notebooks/cmamega_example.ipynb              |  4 +-
 qdax/baselines/dads.py                       | 10 +--
 qdax/baselines/diayn.py                      | 10 +--
 qdax/baselines/sac.py                        | 10 +--
 qdax/baselines/td3.py                        | 14 ++--
 qdax/core/containers/ga_repertoire.py        | 12 ++--
 qdax/core/containers/mapelites_repertoire.py |  6 +-
 qdax/core/containers/mome_repertoire.py      | 28 ++++----
 qdax/core/containers/nsga2_repertoire.py     |  6 +-
 qdax/core/containers/spea2_repertoire.py     |  6 +-
 qdax/core/emitters/cma_mega_emitter.py       |  6 +-
 qdax/core/emitters/mutation_operators.py     | 20 +++---
 qdax/core/emitters/omg_mega_emitter.py       | 12 ++--
 qdax/core/emitters/pga_me_emitter.py         | 17 ++---
 qdax/core/emitters/standard_emitters.py      |  2 +-
 qdax/core/neuroevolution/mdp_utils.py        |  6 +-
 qdax/environments/__init__.py                |  5 +-
 qdax/environments/exploration_wrappers.py    |  2 +-
 qdax/environments/wrappers.py                | 73 ++++++++++++++++++++
 requirements.txt                             | 10 +--
 setup.py                                     |  8 +--
 tests/core_test/cmamega_test.py              |  4 +-
 24 files changed, 187 insertions(+), 89 deletions(-)
 create mode 100644 qdax/environments/wrappers.py

diff --git a/dev.Dockerfile b/dev.Dockerfile
index 17b1c945..6504f0bf 100644
--- a/dev.Dockerfile
+++ b/dev.Dockerfile
@@ -10,7 +10,6 @@ COPY requirements.txt /tmp/requirements.txt
 COPY requirements-dev.txt /tmp/requirements-dev.txt
 COPY environment.yaml /tmp/environment.yaml
 
-
 RUN micromamba create -y --file /tmp/environment.yaml \
     && micromamba clean --all --yes \
     && find /opt/conda/ -follow -type f -name '*.pyc' -delete
@@ -41,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l
 ENV TZ=Europe/Paris
 RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
 
-RUN pip --no-cache-dir install jaxlib==0.3.10+cuda11.cudnn82 \
+RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \
     -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
     && rm -rf /tmp/*
 
diff --git a/environment.yaml b/environment.yaml
index f4f6ebe7..78058b9d 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -8,6 +8,6 @@ dependencies:
   - conda>=4.9.2
   - pip:
     - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
-    - jaxlib==0.3.10
+    - jaxlib==0.3.15
     - -r requirements.txt
     - -r requirements-dev.txt
diff --git a/notebooks/cmamega_example.ipynb b/notebooks/cmamega_example.ipynb
index 2af7103a..73920dc9 100644
--- a/notebooks/cmamega_example.ipynb
+++ b/notebooks/cmamega_example.ipynb
@@ -133,11 +133,11 @@
     "    gradients = jnp.nan_to_num(gradients)\n",
     "\n",
     "    # Compute normalized gradients\n",
-    "    norm_gradients = jax.tree_map(\n",
+    "    norm_gradients = jax.tree_util.tree_map(\n",
     "        lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),\n",
     "        gradients,\n",
     "    )\n",
-    "    grads = jax.tree_map(\n",
+    "    grads = jax.tree_util.tree_map(\n",
     "        lambda x, y: x / y, gradients, norm_gradients\n",
     "    )\n",
     "    grads = jnp.nan_to_num(grads)\n",
diff --git a/qdax/baselines/dads.py b/qdax/baselines/dads.py
index a70727e2..c2928981 100644
--- a/qdax/baselines/dads.py
+++ b/qdax/baselines/dads.py
@@ -25,6 +25,7 @@
     update_running_mean_std,
 )
 from qdax.core.neuroevolution.sac_utils import generate_unroll
+from qdax.environments import CompletedEvalWrapper
 from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor
 
 
@@ -144,7 +145,7 @@ def init(  # type: ignore
 
         random_key, subkey = jax.random.split(random_key)
         critic_params = self._critic.init(subkey, dummy_obs, dummy_action)
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x: jnp.asarray(x.copy()), critic_params
         )
 
@@ -373,16 +374,17 @@ def eval_policy_fn(
             play_step_fn=play_step_fn,
         )
 
+        eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
         true_return = (
-            state.info["eval_metrics"].completed_episodes_metrics["reward"]
-            / state.info["eval_metrics"].completed_episodes
+            state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+            / state.info[eval_metrics_key].completed_episodes
         )
 
         transitions = get_first_episode(transitions)
 
         true_returns = jnp.nansum(transitions.rewards, axis=0)
 
-        reshaped_transitions = jax.tree_map(
+        reshaped_transitions = jax.tree_util.tree_map(
             lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
             transitions,
         )
diff --git a/qdax/baselines/diayn.py b/qdax/baselines/diayn.py
index 632b5e57..a6d335b3 100644
--- a/qdax/baselines/diayn.py
+++ b/qdax/baselines/diayn.py
@@ -20,6 +20,7 @@
 from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode
 from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks
 from qdax.core.neuroevolution.sac_utils import generate_unroll
+from qdax.environments import CompletedEvalWrapper
 from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor
 
 
@@ -141,7 +142,7 @@ def init(  # type: ignore
 
         random_key, subkey = jax.random.split(random_key)
         critic_params = self._critic.init(subkey, dummy_obs, dummy_action)
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x: jnp.asarray(x.copy()), critic_params
         )
 
@@ -316,16 +317,17 @@ def eval_policy_fn(
             play_step_fn=play_step_fn,
         )
 
+        eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
         true_return = (
-            state.info["eval_metrics"].completed_episodes_metrics["reward"]
-            / state.info["eval_metrics"].completed_episodes
+            state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+            / state.info[eval_metrics_key].completed_episodes
         )
 
         transitions = get_first_episode(transitions)
 
         true_return_per_env = jnp.nansum(transitions.rewards, axis=0)
 
-        reshaped_transitions = jax.tree_map(
+        reshaped_transitions = jax.tree_util.tree_map(
             lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
             transitions,
         )
diff --git a/qdax/baselines/sac.py b/qdax/baselines/sac.py
index dec2bbf1..127bccb0 100644
--- a/qdax/baselines/sac.py
+++ b/qdax/baselines/sac.py
@@ -24,6 +24,7 @@
     update_running_mean_std,
 )
 from qdax.core.neuroevolution.sac_utils import generate_unroll
+from qdax.environments import CompletedEvalWrapper
 from qdax.types import Action, Metrics, Observation, Params, Reward, RNGKey
 
 
@@ -115,7 +116,7 @@ def init(
 
         random_key, subkey = jax.random.split(random_key)
         critic_params = self._critic.init(subkey, dummy_obs, dummy_action)
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x: jnp.asarray(x.copy()), critic_params
         )
 
@@ -298,9 +299,10 @@ def eval_policy_fn(
             play_step_fn=play_step_fn,
         )
 
+        eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
         true_return = (
-            state.info["eval_metrics"].completed_episodes_metrics["reward"]
-            / state.info["eval_metrics"].completed_episodes
+            state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+            / state.info[eval_metrics_key].completed_episodes
         )
 
         transitions = get_first_episode(transitions)
@@ -389,7 +391,7 @@ def _update_critic(
         critic_params = optax.apply_updates(
             training_state.critic_params, critic_updates
         )
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x1, x2: (1.0 - self._config.tau) * x1 + self._config.tau * x2,
             training_state.target_critic_params,
             critic_params,
diff --git a/qdax/baselines/td3.py b/qdax/baselines/td3.py
index 9ca5eadb..c4bfbc65 100644
--- a/qdax/baselines/td3.py
+++ b/qdax/baselines/td3.py
@@ -19,6 +19,7 @@
     get_first_episode,
 )
 from qdax.core.neuroevolution.networks.td3_networks import make_td3_networks
+from qdax.environments import CompletedEvalWrapper
 from qdax.types import Action, Metrics, Observation, Params, Reward, RNGKey
 
 
@@ -113,10 +114,10 @@ def init(
         policy_params = self._policy.init(subkey_2, fake_obs)
 
         # Initialize target networks
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x: jnp.asarray(x.copy()), critic_params
         )
-        target_policy_params = jax.tree_map(
+        target_policy_params = jax.tree_util.tree_map(
             lambda x: jnp.asarray(x.copy()), policy_params
         )
 
@@ -251,9 +252,10 @@ def eval_policy_fn(
             play_step_fn=play_step_fn,
         )
 
+        eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
         true_return = (
-            state.info["eval_metrics"].completed_episodes_metrics["reward"]
-            / state.info["eval_metrics"].completed_episodes
+            state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+            / state.info[eval_metrics_key].completed_episodes
         )
 
         transitions = get_first_episode(transitions)
@@ -303,7 +305,7 @@ def update(
             training_state.critic_params, critic_updates
         )
         # Soft update of target critic network
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
             + self._config.soft_tau_update * x2,
             training_state.target_critic_params,
@@ -325,7 +327,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
                 training_state.policy_params, policy_updates
             )
             # Soft update of target policy
-            target_policy_params = jax.tree_map(
+            target_policy_params = jax.tree_util.tree_map(
                 lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
                 + self._config.soft_tau_update * x2,
                 training_state.target_policy_params,
diff --git a/qdax/core/containers/ga_repertoire.py b/qdax/core/containers/ga_repertoire.py
index d95c4e10..87ade54f 100644
--- a/qdax/core/containers/ga_repertoire.py
+++ b/qdax/core/containers/ga_repertoire.py
@@ -34,7 +34,7 @@ class GARepertoire(Repertoire):
     @property
     def size(self) -> int:
         """Gives the size of the population."""
-        first_leaf = jax.tree_leaves(self.genotypes)[0]
+        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
         return int(first_leaf.shape[0])
 
     def save(self, path: str = "./") -> None:
@@ -95,7 +95,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
 
         # sample
         random_key, subkey = jax.random.split(random_key)
-        samples = jax.tree_map(
+        samples = jax.tree_util.tree_map(
             lambda x: jax.random.choice(
                 subkey, x, shape=(num_samples,), p=p, replace=False
             ),
@@ -122,7 +122,7 @@ def add(
         """
 
         # gather individuals and fitnesses
-        candidates = jax.tree_map(
+        candidates = jax.tree_util.tree_map(
             lambda x, y: jnp.concatenate((x, y), axis=0),
             self.genotypes,
             batch_of_genotypes,
@@ -138,7 +138,9 @@ def add(
         survivor_indices = indices[: self.size]
 
         # keep only the best ones
-        new_candidates = jax.tree_map(lambda x: x[survivor_indices], candidates)
+        new_candidates = jax.tree_util.tree_map(
+            lambda x: x[survivor_indices], candidates
+        )
 
         new_repertoire = self.replace(
             genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
@@ -172,7 +174,7 @@ def init(  # type: ignore
         )
 
         # create default genotypes
-        default_genotypes = jax.tree_map(
+        default_genotypes = jax.tree_util.tree_map(
             lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
         )
 
diff --git a/qdax/core/containers/mapelites_repertoire.py b/qdax/core/containers/mapelites_repertoire.py
index f9da638d..e2eb9433 100644
--- a/qdax/core/containers/mapelites_repertoire.py
+++ b/qdax/core/containers/mapelites_repertoire.py
@@ -220,7 +220,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
         p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)
 
         random_key, subkey = jax.random.split(random_key)
-        samples = jax.tree_map(
+        samples = jax.tree_util.tree_map(
             lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
             self.genotypes,
         )
@@ -283,7 +283,7 @@ def add(
         )
 
         # create new repertoire
-        new_repertoire_genotypes = jax.tree_map(
+        new_repertoire_genotypes = jax.tree_util.tree_map(
             lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
                 batch_of_indices.squeeze(axis=-1)
             ].set(new_genotypes),
@@ -337,7 +337,7 @@ def init(
         # Initialize repertoire with default values
         num_centroids = centroids.shape[0]
         default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
-        default_genotypes = jax.tree_map(
+        default_genotypes = jax.tree_util.tree_map(
             lambda x: jnp.zeros(shape=(num_centroids,) + x.shape[1:]),
             genotypes,
         )
diff --git a/qdax/core/containers/mome_repertoire.py b/qdax/core/containers/mome_repertoire.py
index 6fa5972b..04d4b680 100644
--- a/qdax/core/containers/mome_repertoire.py
+++ b/qdax/core/containers/mome_repertoire.py
@@ -54,7 +54,7 @@ def repertoire_capacity(self) -> int:
         Returns:
             The repertoire capacity.
         """
-        first_leaf = jax.tree_leaves(self.genotypes)[0]
+        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
         return int(first_leaf.shape[0] * first_leaf.shape[1])
 
     @jax.jit
@@ -80,7 +80,7 @@ def _sample_in_masked_pareto_front(
         """
         p = (1.0 - mask) / jnp.sum(1.0 - mask)
 
-        genotype_sample = jax.tree_map(
+        genotype_sample = jax.tree_util.tree_map(
             lambda x: jax.random.choice(random_key, x, shape=(1,), p=p),
             pareto_front_genotypes,
         )
@@ -116,7 +116,9 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
         cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)
 
         # get genotypes (front) from the chosen indices
-        pareto_front_genotypes = jax.tree_map(lambda x: x[cells_idx], self.genotypes)
+        pareto_front_genotypes = jax.tree_util.tree_map(
+            lambda x: x[cells_idx], self.genotypes
+        )
 
         # prepare second sampling function
         sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)
@@ -131,7 +133,9 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
         )
 
         # remove the dim coming from pareto front
-        sampled_genotypes = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_genotypes)
+        sampled_genotypes = jax.tree_util.tree_map(
+            lambda x: x.squeeze(axis=1), sampled_genotypes
+        )
 
         return sampled_genotypes, random_key
 
@@ -172,7 +176,7 @@ def _update_masked_pareto_front(
 
         pareto_front_len = pareto_front_fitnesses.shape[0]  # type: ignore
 
-        first_leaf = jax.tree_leaves(new_batch_of_genotypes)[0]
+        first_leaf = jax.tree_util.tree_leaves(new_batch_of_genotypes)[0]
         genotypes_dim = first_leaf.shape[1]
 
        descriptors_dim = new_batch_of_descriptors.shape[1]
@@ -182,7 +186,7 @@ def _update_masked_pareto_front(
         cat_fitnesses = jnp.concatenate(
             [pareto_front_fitnesses, new_batch_of_fitnesses], axis=0
         )
-        cat_genotypes = jax.tree_map(
+        cat_genotypes = jax.tree_util.tree_map(
             lambda x, y: jnp.concatenate([x, y], axis=0),
             pareto_front_genotypes,
             new_batch_of_genotypes,
@@ -205,7 +209,7 @@ def _update_masked_pareto_front(
 
         # get new fitness, genotypes and descriptors
         new_front_fitness = jnp.take(cat_fitnesses, indices, axis=0)
-        new_front_genotypes = jax.tree_map(
+        new_front_genotypes = jax.tree_util.tree_map(
             lambda x: jnp.take(x, indices, axis=0), cat_genotypes
         )
         new_front_descriptors = jnp.take(cat_descriptors, indices, axis=0)
@@ -232,10 +236,10 @@ def _update_masked_pareto_front(
         genotypes_mask = jnp.repeat(
             jnp.expand_dims(new_mask, axis=-1), genotypes_dim, axis=-1
         )
-        new_front_genotypes = jax.tree_map(
+        new_front_genotypes = jax.tree_util.tree_map(
             lambda x: x * genotypes_mask, new_front_genotypes
         )
-        new_front_genotypes = jax.tree_map(
+        new_front_genotypes = jax.tree_util.tree_map(
             lambda x: x[:front_size, :], new_front_genotypes
         )
 
@@ -289,7 +293,7 @@ def _add_one(
         index = index.astype(jnp.int32)
 
         # get cell data
-        cell_genotype = jax.tree_map(lambda x: x[index], carry.genotypes)
+        cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
         cell_fitness = carry.fitnesses[index]
         cell_descriptor = carry.descriptors[index]
         cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)
@@ -315,7 +319,7 @@ def _add_one(
         cell_fitness = cell_fitness - jnp.inf * jnp.expand_dims(cell_mask, axis=-1)
 
         # update grid
-        new_genotypes = jax.tree_map(
+        new_genotypes = jax.tree_util.tree_map(
             lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
         )
         new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
@@ -383,7 +387,7 @@ def init(  # type: ignore
         default_fitnesses = -jnp.inf * jnp.ones(
             shape=(num_centroids, pareto_front_max_length, num_criteria)
         )
-        default_genotypes = jax.tree_map(
+        default_genotypes = jax.tree_util.tree_map(
             lambda x: jnp.zeros(
                 shape=(
                     num_centroids,
diff --git a/qdax/core/containers/nsga2_repertoire.py b/qdax/core/containers/nsga2_repertoire.py
index afce8915..74b0f454 100644
--- a/qdax/core/containers/nsga2_repertoire.py
+++ b/qdax/core/containers/nsga2_repertoire.py
@@ -106,7 +106,7 @@ def add(
             The updated repertoire.
         """
         # All the candidates
-        candidates = jax.tree_map(
+        candidates = jax.tree_util.tree_map(
             lambda x, y: jnp.concatenate((x, y), axis=0),
             self.genotypes,
             batch_of_genotypes,
@@ -114,7 +114,7 @@ def add(
 
         candidate_fitnesses = jnp.concatenate((self.fitnesses, batch_of_fitnesses))
 
-        first_leaf = jax.tree_leaves(candidates)[0]
+        first_leaf = jax.tree_util.tree_leaves(candidates)[0]
         num_candidates = first_leaf.shape[0]
 
         def compute_current_front(
@@ -237,7 +237,7 @@ def condition_fn_2(val: Tuple[jnp.ndarray, jnp.ndarray]) -> bool:
         indices = indices - 1
 
         # keep only the survivors
-        new_candidates = jax.tree_map(lambda x: x[indices], candidates)
+        new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
         new_scores = candidate_fitnesses[indices]
 
         new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_scores)
diff --git a/qdax/core/containers/spea2_repertoire.py b/qdax/core/containers/spea2_repertoire.py
index 785cd8fa..54870db4 100644
--- a/qdax/core/containers/spea2_repertoire.py
+++ b/qdax/core/containers/spea2_repertoire.py
@@ -72,7 +72,7 @@ def add(
             Updated repertoire.
         """
         # All the candidates
-        candidates = jax.tree_map(
+        candidates = jax.tree_util.tree_map(
             lambda x, y: jnp.concatenate((x, y), axis=0),
             self.genotypes,
             batch_of_genotypes,
@@ -87,7 +87,7 @@ def add(
         indices = jnp.argsort(strength_scores)[: self.size]
 
         # keep the survivors
-        new_candidates = jax.tree_map(lambda x: x[indices], candidates)
+        new_candidates = jax.tree_util.tree_map(lambda x: x[indices], candidates)
         new_fitnesses = candidates_fitnesses[indices]
 
         new_repertoire = self.replace(genotypes=new_candidates, fitnesses=new_fitnesses)
@@ -121,7 +121,7 @@ def init(  # type: ignore
         )
 
         # create default genotypes
-        default_genotypes = jax.tree_map(
+        default_genotypes = jax.tree_util.tree_map(
             lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
         )
 
diff --git a/qdax/core/emitters/cma_mega_emitter.py b/qdax/core/emitters/cma_mega_emitter.py
index 8d008e12..3b8d6a20 100644
--- a/qdax/core/emitters/cma_mega_emitter.py
+++ b/qdax/core/emitters/cma_mega_emitter.py
@@ -102,7 +102,7 @@ def init(
         """
 
         # define init theta as 0
-        theta = jax.tree_map(
+        theta = jax.tree_util.tree_map(
             lambda x: jnp.zeros_like(x[:1, ...]),
             init_genotypes,
         )
@@ -161,7 +161,7 @@ def emit(
         update_grad = coeffs @ grads.T
 
         # Compute new candidates
-        new_thetas = jax.tree_map(lambda x, y: x + y, theta, update_grad)
+        new_thetas = jax.tree_util.tree_map(lambda x, y: x + y, theta, update_grad)
 
         return new_thetas, random_key
 
@@ -226,7 +226,7 @@ def state_update(
         gradient_step = jnp.sum(self._weights[sorted_indices] * update_grad, axis=0)
 
         # update theta
-        theta = jax.tree_map(
+        theta = jax.tree_util.tree_map(
             lambda x, y: x + self._learning_rate * y, theta, gradient_step
         )
 
diff --git a/qdax/core/emitters/mutation_operators.py b/qdax/core/emitters/mutation_operators.py
index 696ae056..f39b8060 100644
--- a/qdax/core/emitters/mutation_operators.py
+++ b/qdax/core/emitters/mutation_operators.py
@@ -104,7 +104,7 @@ def polynomial_mutation(
         New genotypes - same shape as input and a new RNG key
     """
     random_key, subkey = jax.random.split(random_key)
-    batch_size = jax.tree_leaves(x)[0].shape[0]
+    batch_size = jax.tree_util.tree_leaves(x)[0].shape[0]
     mutation_key = jax.random.split(subkey, num=batch_size)
     mutation_fn = partial(
         _polynomial_mutation,
@@ -114,7 +114,7 @@ def polynomial_mutation(
         maxval=maxval,
     )
     mutation_fn = jax.vmap(mutation_fn)
-    x = jax.tree_map(lambda x_: mutation_fn(x_, mutation_key), x)
+    x = jax.tree_util.tree_map(lambda x_: mutation_fn(x_, mutation_key), x)
     return x, random_key
 
@@ -165,7 +165,7 @@ def polynomial_crossover(
     """
 
     random_key, subkey = jax.random.split(random_key)
-    batch_size = jax.tree_leaves(x2)[0].shape[0]
+    batch_size = jax.tree_util.tree_leaves(x2)[0].shape[0]
     crossover_keys = jax.random.split(subkey, num=batch_size)
     crossover_fn = partial(
         _polynomial_crossover,
@@ -173,7 +173,9 @@ def polynomial_crossover(
     )
     crossover_fn = jax.vmap(crossover_fn)
     # TODO: check that key usage is correct
-    x = jax.tree_map(lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2)
+    x = jax.tree_util.tree_map(
+        lambda x1_, x2_: crossover_fn(x1_, x2_, crossover_keys), x1, x2
+    )
     return x, random_key
 
@@ -209,7 +211,7 @@ def isoline_variation(
 
     # Computing line_noise
     random_key, key_line_noise = jax.random.split(random_key)
-    batch_size = jax.tree_leaves(x1)[0].shape[0]
+    batch_size = jax.tree_util.tree_leaves(x1)[0].shape[0]
    line_noise = jax.random.normal(key_line_noise, shape=(batch_size,)) * line_sigma
 
     def _variation_fn(
@@ -224,12 +226,14 @@ def _variation_fn(
         return x
 
     # create a tree with random keys
-    nb_leaves = len(jax.tree_leaves(x1))
+    nb_leaves = len(jax.tree_util.tree_leaves(x1))
     random_key, subkey = jax.random.split(random_key)
     subkeys = jax.random.split(subkey, num=nb_leaves)
-    keys_tree = jax.tree_unflatten(jax.tree_structure(x1), subkeys)
+    keys_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(x1), subkeys)
 
     # apply isolinedd to each branch of the tree
-    x = jax.tree_map(lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree)
+    x = jax.tree_util.tree_map(
+        lambda y1, y2, key: _variation_fn(y1, y2, key), x1, x2, keys_tree
+    )
 
     return x, random_key
diff --git a/qdax/core/emitters/omg_mega_emitter.py b/qdax/core/emitters/omg_mega_emitter.py
index a3afc891..34b8d9fd 100644
--- a/qdax/core/emitters/omg_mega_emitter.py
+++ b/qdax/core/emitters/omg_mega_emitter.py
@@ -93,7 +93,7 @@ def init(
         # Initialize grid with default values
         num_centroids = self._centroids.shape[0]
         default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
-        default_gradients = jax.tree_map(
+        default_gradients = jax.tree_util.tree_map(
             lambda x: jnp.zeros(
                 shape=(num_centroids,) + x.shape[1:] + (self._num_descriptors + 1,)
             ),
@@ -151,10 +151,10 @@ def emit(
             random_key, num_samples=self._batch_size
         )
 
-        fitness_gradients = jax.tree_map(
+        fitness_gradients = jax.tree_util.tree_map(
             lambda x: jnp.expand_dims(x[:, :, 0], axis=-1), gradients
         )
-        descriptors_gradients = jax.tree_map(lambda x: x[:, :, 1:], gradients)
+        descriptors_gradients = jax.tree_util.tree_map(lambda x: x[:, :, 1:], gradients)
 
         # Normalize the gradients
         norm_fitness_gradients = jnp.linalg.norm(
@@ -177,7 +177,7 @@ def emit(
             cov=self._sigma,
         )
         coeffs = coeffs.at[:, 0].set(jnp.abs(coeffs[:, 0]))
-        grads = jax.tree_map(
+        grads = jax.tree_util.tree_map(
             lambda x, y: jnp.concatenate((x, y), axis=-1),
             fitness_gradients,
             descriptors_gradients,
@@ -185,7 +185,9 @@ def emit(
         update_grad = jnp.sum(jax.vmap(lambda x, y: x * y)(coeffs, grads), axis=-1)
 
         # update the genotypes
-        new_genotypes = jax.tree_map(lambda x, y: x + y, genotypes, update_grad)
+        new_genotypes = jax.tree_util.tree_map(
+            lambda x, y: x + y, genotypes, update_grad
+        )
 
         return new_genotypes, random_key
 
diff --git a/qdax/core/emitters/pga_me_emitter.py b/qdax/core/emitters/pga_me_emitter.py
index 1e3f84b3..0841ebb0 100644
--- a/qdax/core/emitters/pga_me_emitter.py
+++ b/qdax/core/emitters/pga_me_emitter.py
@@ -8,7 +8,6 @@
 import jax
 import optax
 from jax import numpy as jnp
-from jax.tree_util import tree_map
 
 from qdax.core.containers.repertoire import Repertoire
 from qdax.core.emitters.emitter import Emitter, EmitterState
@@ -126,10 +125,12 @@ def init(
         critic_params = self._critic_network.init(
             subkey, obs=fake_obs, actions=fake_action
         )
-        target_critic_params = tree_map(lambda x: x, critic_params)
+        target_critic_params = jax.tree_util.tree_map(lambda x: x, critic_params)
 
-        greedy_policy_params = tree_map(lambda x: x[0], init_genotypes)
-        target_greedy_policy_params = tree_map(lambda x: x[0], init_genotypes)
+        greedy_policy_params = jax.tree_util.tree_map(lambda x: x[0], init_genotypes)
+        target_greedy_policy_params = jax.tree_util.tree_map(
+            lambda x: x[0], init_genotypes
+        )
 
         # Prepare init optimizer states
         critic_optimizer_state = self._critic_optimizer.init(critic_params)
@@ -209,12 +210,12 @@ def emit(
         x_mutation_pg = jax.vmap(mutation_fn)(x1)
 
         # Add dimension for concatenation
-        greedy_policy_params = jax.tree_map(
+        greedy_policy_params = jax.tree_util.tree_map(
             lambda x: jnp.expand_dims(x, axis=0), emitter_state.greedy_policy_params
         )
 
         # gather offspring
-        genotypes = jax.tree_map(
+        genotypes = jax.tree_util.tree_map(
             lambda x, y, z: jnp.concatenate([x, y, z], axis=0),
             x_mutation_ga,
             x_mutation_pg,
@@ -316,7 +317,7 @@ def _train_critics(self, emitter_state: PGAMEEmitterState) -> PGAMEEmitterState:
         )
         critic_params = optax.apply_updates(emitter_state.critic_params, critic_updates)
         # Soft update of target critic network
-        target_critic_params = jax.tree_map(
+        target_critic_params = jax.tree_util.tree_map(
             lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
             + self._config.soft_tau_update * x2,
             emitter_state.target_critic_params,
@@ -340,7 +341,7 @@ def _train_critics(self, emitter_state: PGAMEEmitterState) -> PGAMEEmitterState:
             emitter_state.greedy_policy_params, policy_updates
         )
         # Soft update of target greedy policy
-        target_greedy_policy_params = jax.tree_map(
+        target_greedy_policy_params = jax.tree_util.tree_map(
             lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
             + self._config.soft_tau_update * x2,
             emitter_state.target_greedy_policy_params,
diff --git a/qdax/core/emitters/standard_emitters.py b/qdax/core/emitters/standard_emitters.py
index d15c87f5..bee4095f 100644
--- a/qdax/core/emitters/standard_emitters.py
+++ b/qdax/core/emitters/standard_emitters.py
@@ -69,7 +69,7 @@ def emit(
         elif n_mutation == 0:
             genotypes = x_variation
         else:
-            genotypes = jax.tree_map(
+            genotypes = jax.tree_util.tree_map(
                 lambda x_1, x_2: jnp.concatenate([x_1, x_2], axis=0),
                 x_variation,
                 x_mutation,
diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py
index 12b50528..012007d0 100644
--- a/qdax/core/neuroevolution/mdp_utils.py
+++ b/qdax/core/neuroevolution/mdp_utils.py
@@ -217,7 +217,9 @@ def reset_based_scoring_function(
     """
 
     random_key, subkey = jax.random.split(random_key)
-    keys = jax.random.split(subkey, jax.tree_leaves(policies_params)[0].shape[0])
+    keys = jax.random.split(
+        subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0]
+    )
     reset_fn = jax.vmap(play_reset_fn)
     init_states = reset_fn(keys)
 
@@ -314,4 +316,4 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray:
         # the double transpose trick is here to allow easy broadcasting
         return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T
 
-    return jax.tree_map(mask_episodes, transition)  # type: ignore
+    return jax.tree_util.tree_map(mask_episodes, transition)  # type: ignore
diff --git a/qdax/environments/__init__.py b/qdax/environments/__init__.py
index 6e980aca..0574006f 100644
--- a/qdax/environments/__init__.py
+++ b/qdax/environments/__init__.py
@@ -2,6 +2,7 @@
 from typing import Any, Callable, List, Optional, Union
 
 import brax
+import brax.envs
 
 from qdax.environments.base_wrappers import QDEnv, StateDescriptorResetWrapper
 from qdax.environments.bd_extractors import (
@@ -15,6 +16,7 @@
     XYPositionWrapper,
 )
 from qdax.environments.pointmaze import PointMaze
+from qdax.environments.wrappers import CompletedEvalWrapper
 
 # experimentally determinated offset (except for antmaze)
 # should be sufficient to have only positive rewards but no guarantee
@@ -109,7 +111,7 @@ def create(
     eval_metrics: bool = False,
     qdax_wrappers_kwargs: Optional[List] = None,
     **kwargs: Any,
-) -> Union[brax.envs.Env, QDEnv]:
+) -> Union[brax.envs.env.Env, QDEnv]:
     """Creates an Env with a specified brax system.
     Please use namespace to avoid confusion between this function and
     brax.envs.create.
@@ -145,6 +147,7 @@ def create(
         env = StateDescriptorResetWrapper(env)
 
     if eval_metrics:
         env = brax.envs.wrappers.EvalWrapper(env)
+        env = CompletedEvalWrapper(env)
 
     return env
diff --git a/qdax/environments/exploration_wrappers.py b/qdax/environments/exploration_wrappers.py
index ea3e078d..b6754c28 100644
--- a/qdax/environments/exploration_wrappers.py
+++ b/qdax/environments/exploration_wrappers.py
@@ -89,7 +89,7 @@
 # those are the configs from the official brax repo
 ENV_SYSTEM_CONFIG = {
     "ant": brax.envs.ant._SYSTEM_CONFIG,
-    "halfcheetah": brax.envs.halfcheetah._SYSTEM_CONFIG,
+    "halfcheetah": brax.envs.half_cheetah._SYSTEM_CONFIG,
     "walker2d": brax.envs.walker2d._SYSTEM_CONFIG,
     "hopper": brax.envs.hopper._SYSTEM_CONFIG,
     # "humanoid": brax.envs.humanoid._SYSTEM_CONFIG,
diff --git a/qdax/environments/wrappers.py b/qdax/environments/wrappers.py
new file mode 100644
index 00000000..45be0787
--- /dev/null
+++ b/qdax/environments/wrappers.py
@@ -0,0 +1,73 @@
+from typing import Dict
+
+import brax.envs
+import flax.struct
+import jax
+from brax import jumpy as jp
+
+
+class CompletedEvalMetrics(flax.struct.PyTreeNode):
+    current_episode_metrics: Dict[str, jp.ndarray]
+    completed_episodes_metrics: Dict[str, jp.ndarray]
+    completed_episodes: jp.ndarray
+    completed_episodes_steps: jp.ndarray
+
+
+class CompletedEvalWrapper(brax.envs.env.Wrapper):
+    """Brax env with eval metrics for completed episodes."""
+
+    STATE_INFO_KEY = "completed_eval_metrics"
+
+    def reset(self, rng: jp.ndarray) -> brax.envs.env.State:
+        reset_state = self.env.reset(rng)
+        reset_state.metrics["reward"] = reset_state.reward
+        eval_metrics = CompletedEvalMetrics(
+            current_episode_metrics=jax.tree_util.tree_map(
+                jp.zeros_like, reset_state.metrics
+            ),
+            completed_episodes_metrics=jax.tree_util.tree_map(
+                lambda x: jp.zeros_like(jp.sum(x)), reset_state.metrics
+            ),
+            completed_episodes=jp.zeros(()),
+            completed_episodes_steps=jp.zeros(()),
+        )
+        reset_state.info[self.STATE_INFO_KEY] = eval_metrics
+        return reset_state
+
+    def step(
+        self, state: brax.envs.env.State, action: jp.ndarray
+    ) -> brax.envs.env.State:
+        state_metrics = state.info[self.STATE_INFO_KEY]
+        if not isinstance(state_metrics, CompletedEvalMetrics):
+            raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
+        del state.info[self.STATE_INFO_KEY]
+        nstate = self.env.step(state, action)
+        nstate.metrics["reward"] = nstate.reward
+        # steps stores the highest step reached when done = True, and then
+        # the next steps becomes action_repeat
+        completed_episodes_steps = state_metrics.completed_episodes_steps + jp.sum(
+            nstate.info["steps"] * nstate.done
+        )
+        current_episode_metrics = jax.tree_util.tree_map(
+            lambda a, b: a + b, state_metrics.current_episode_metrics, nstate.metrics
+        )
+        completed_episodes = state_metrics.completed_episodes + jp.sum(nstate.done)
+        completed_episodes_metrics = jax.tree_util.tree_map(
+            lambda a, b: a + jp.sum(b * nstate.done),
+            state_metrics.completed_episodes_metrics,
+            current_episode_metrics,
+        )
+        current_episode_metrics = jax.tree_util.tree_map(
+            lambda a, b: a * (1 - nstate.done) + b * nstate.done,
+            current_episode_metrics,
+            nstate.metrics,
+        )
+
+        eval_metrics = CompletedEvalMetrics(
+            current_episode_metrics=current_episode_metrics,
+            completed_episodes_metrics=completed_episodes_metrics,
+            completed_episodes=completed_episodes,
+            completed_episodes_steps=completed_episodes_steps,
+        )
+        nstate.info[self.STATE_INFO_KEY] = eval_metrics
+        return nstate
diff --git a/requirements.txt b/requirements.txt
index ebd0d39a..907f1337 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,11 @@
 absl-py==1.0.0
-brax==0.0.12
-chex==0.1.3
+brax==0.0.15
+chex==0.1.4
 dm-haiku==0.0.5
-flax==0.4.1
+flax==0.6.0
 gym==0.23.1
 ipython
-jax==0.3.10
+jax==0.3.17
 jupyter
 numpy==1.22.3
 protobuf==3.19.4
@@ -14,4 +14,4 @@ scipy==1.8.0
 seaborn==0.11.2
 sklearn==0.0
 tensorflow-probability==0.15.0
-typing-extensions == 3.10
+typing-extensions==4.3.0
diff --git a/setup.py b/setup.py
index 6bef535e..d3059066 100644
--- a/setup.py
+++ b/setup.py
@@ -22,10 +22,10 @@
     long_description_content_type="text/markdown",
     install_requires=[
         "absl-py>=1.0.0",
-        "jax>=0.3.10",
-        "jaxlib>=0.3.10",  # necessary to build the doc atm
-        "flax>=0.4.1",
-        "brax>=0.0.12",
+        "jax>=0.3.16",
+        "jaxlib>=0.3.15",  # necessary to build the doc atm
+        "flax>=0.6",
+        "brax>=0.0.15",
         "gym>=0.23.1",
         "numpy>=1.22.3",
         "scikit-learn>=1.0.2",
diff --git a/tests/core_test/cmamega_test.py b/tests/core_test/cmamega_test.py
index ef4bbd8d..8f327a95 100644
--- a/tests/core_test/cmamega_test.py
+++ b/tests/core_test/cmamega_test.py
@@ -64,11 +64,11 @@ def scoring_function(x: jnp.ndarray) -> Tuple[Fitness, Descriptor, ExtraScores]:
     gradients = jnp.nan_to_num(gradients)
 
     # Compute normalized gradients
-    norm_gradients = jax.tree_map(
+    norm_gradients = jax.tree_util.tree_map(
         lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),
         gradients,
     )
-    grads = jax.tree_map(lambda x, y: x / y, gradients, norm_gradients)
+    grads = jax.tree_util.tree_map(lambda x, y: x / y, gradients, norm_gradients)
     grads = jnp.nan_to_num(grads)
 
     extra_scores = {"gradients": gradients, "normalized_grads": grads}
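Note on the migration above: most hunks in patch 2 mechanically move `jax.tree_map`, `jax.tree_leaves`, `jax.tree_unflatten` and `jax.tree_structure` to the `jax.tree_util` namespace, which is what silences the deprecation warnings under jax>=0.3.16. A minimal, self-contained sketch of the pattern (the pytree below is illustrative, not QDax code):

```python
import jax
import jax.numpy as jnp

# A batch of genotypes stored as a pytree (nested dict of arrays),
# as in the QDax repertoires.
genotypes = {"layer1": jnp.ones((4, 8)), "layer2": jnp.ones((4, 2))}

# Old spelling, deprecated from jax 0.3.16 onwards:
#   batch_size = jax.tree_leaves(genotypes)[0].shape[0]
# Namespaced spelling used throughout this patch:
batch_size = jax.tree_util.tree_leaves(genotypes)[0].shape[0]  # batch axis: 4
mutated = jax.tree_util.tree_map(lambda x: x + 0.1, genotypes)
```

The new `CompletedEvalWrapper` is applied on top of Brax's `EvalWrapper` whenever an environment is created with `eval_metrics=True`. A sketch of how the baselines read the metrics back after an evaluation rollout, mirroring the `eval_policy_fn` changes above (the environment name and elided stepping loop are illustrative):

```python
import jax
import qdax.environments
from qdax.environments import CompletedEvalWrapper

# eval_metrics=True wraps the env in EvalWrapper then CompletedEvalWrapper.
env = qdax.environments.create("pointmaze", eval_metrics=True)
state = env.reset(jax.random.PRNGKey(0))
# ... step the environment until episodes complete ...
eval_metrics = state.info[CompletedEvalWrapper.STATE_INFO_KEY]
mean_return = (
    eval_metrics.completed_episodes_metrics["reward"]
    / eval_metrics.completed_episodes
)
```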