diff --git a/README.md b/README.md index a5ceb3cc..e9093301 100644 --- a/README.md +++ b/README.md @@ -27,18 +27,75 @@ Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/). ## Basic API Usage -For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default). +For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default). However, a summary of the main API usage is provided below: ```python -import qdax +import jax +import functools from qdax.core.map_elites import MAPElites +from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids +from qdax.tasks.arm import arm_scoring_function +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.utils.metrics import default_qd_metrics + +seed = 42 +num_param_dimensions = 100 # num DoF arm +init_batch_size = 100 +batch_size = 1024 +num_iterations = 50 +grid_shape = (100, 100) +min_param = 0.0 +max_param = 1.0 +min_bd = 0.0 +max_bd = 1.0 + +# Init a random key +random_key = jax.random.PRNGKey(seed) + +# Init population of controllers +random_key, subkey = jax.random.split(random_key) +init_variables = jax.random.uniform( + subkey, + shape=(init_batch_size, num_param_dimensions), + minval=min_param, + maxval=max_param, +) + +# Define emitter +variation_fn = functools.partial( + isoline_variation, + iso_sigma=0.05, + line_sigma=0.1, + minval=min_param, + maxval=max_param, +) +mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, +) + +# Define a metrics function +metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=0.0, +) # Instantiate MAP-Elites map_elites = MAPElites( - scoring_function=scoring_fn, + scoring_function=arm_scoring_function, emitter=mixing_emitter, - metrics_function=metrics_function, + metrics_function=metrics_fn, +) + +# Compute the centroids +centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, ) # Initializes repertoire and emitter state @@ -81,6 +138,11 @@ The QDax library also provides implementations for some useful baseline algorith | [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) | | [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All 
Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
+## QDax Tasks +The QDax library also provides implementations of several standard Quality-Diversity tasks. + +All of these implementations, along with their descriptions, are provided in the [tasks directory](./qdax/tasks). +
## Contributing Issues and contributions are welcome. Please refer to the [contribution guide](https://qdax.readthedocs.io/en/latest/guides/CONTRIBUTING/) in the documentation for more details.
diff --git a/notebooks/cmamega_example.ipynb b/examples/notebooks/cmamega_example.ipynb similarity index 100% rename from notebooks/cmamega_example.ipynb rename to examples/notebooks/cmamega_example.ipynb
diff --git a/notebooks/dads_example.ipynb b/examples/notebooks/dads_example.ipynb similarity index 100% rename from notebooks/dads_example.ipynb rename to examples/notebooks/dads_example.ipynb
diff --git a/notebooks/diayn_example.ipynb b/examples/notebooks/diayn_example.ipynb similarity index 100% rename from notebooks/diayn_example.ipynb rename to examples/notebooks/diayn_example.ipynb
diff --git a/notebooks/mapelites_example.ipynb b/examples/notebooks/mapelites_example.ipynb similarity index 99% rename from notebooks/mapelites_example.ipynb rename to examples/notebooks/mapelites_example.ipynb index 7713243f..c59f4e13 100644 --- a/notebooks/mapelites_example.ipynb +++ b/examples/notebooks/mapelites_example.ipynb @@ -62,7 +62,7 @@ "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n", "from qdax import environments\n", - "from qdax.core.neuroevolution.mdp_utils import scoring_function\n", + "from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators import isoline_variation\n",
diff --git a/notebooks/mome_example.ipynb b/examples/notebooks/mome_example.ipynb similarity index 100% rename from notebooks/mome_example.ipynb rename to examples/notebooks/mome_example.ipynb
diff --git a/notebooks/nsga2_spea2_example.ipynb b/examples/notebooks/nsga2_spea2_example.ipynb similarity index 100% rename from notebooks/nsga2_spea2_example.ipynb rename to examples/notebooks/nsga2_spea2_example.ipynb
diff --git a/notebooks/omgmega_example.ipynb b/examples/notebooks/omgmega_example.ipynb similarity index 100% rename from notebooks/omgmega_example.ipynb rename to examples/notebooks/omgmega_example.ipynb
diff --git a/notebooks/pgame_example.ipynb b/examples/notebooks/pgame_example.ipynb similarity index 99% rename from notebooks/pgame_example.ipynb rename to examples/notebooks/pgame_example.ipynb index f35c7133..084bbe98 100644 --- a/notebooks/pgame_example.ipynb +++ b/examples/notebooks/pgame_example.ipynb @@ -61,7 +61,7 @@ "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax import environments\n", - "from qdax.core.neuroevolution.mdp_utils import scoring_function\n", + "from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", "from qdax.core.neuroevolution.networks.networks import MLP\n", "from qdax.core.emitters.mutation_operators 
import isoline_variation\n", diff --git a/notebooks/smerl_example.ipynb b/examples/notebooks/smerl_example.ipynb similarity index 100% rename from notebooks/smerl_example.ipynb rename to examples/notebooks/smerl_example.ipynb diff --git a/examples/scripts/me_example.py b/examples/scripts/me_example.py new file mode 100644 index 00000000..699c6aba --- /dev/null +++ b/examples/scripts/me_example.py @@ -0,0 +1,103 @@ +import functools + +import jax +import matplotlib.pyplot as plt + +from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.map_elites import MAPElites +from qdax.tasks.arm import arm_scoring_function +from qdax.utils.metrics import default_qd_metrics +from qdax.utils.plotting import plot_2d_map_elites_repertoire + + +def run_me() -> None: + seed = 42 + num_param_dimensions = 8 # num DoF arm + init_batch_size = 100 + batch_size = 2048 + num_evaluations = int(1e6) + num_iterations = num_evaluations // batch_size + grid_shape = (100, 100) + min_param = 0.0 + max_param = 1.0 + min_bd = 0.0 + max_bd = 1.0 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + init_variables = jax.random.uniform( + subkey, + shape=(init_batch_size, num_param_dimensions), + minval=min_param, + maxval=max_param, + ) + + # Define emitter + variation_fn = functools.partial( + isoline_variation, + iso_sigma=0.005, + line_sigma=0, + minval=min_param, + maxval=max_param, + ) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Define a metrics function + metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=0.0, + ) + + # Instantiate MAP-Elites + map_elites = MAPElites( + scoring_function=arm_scoring_function, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + # Compute the centroids + centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, + ) + + # Initializes repertoire and emitter state + repertoire, emitter_state, random_key = map_elites.init( + init_variables, centroids, random_key + ) + + # Run MAP-Elites loop + for _ in range(num_iterations): + (repertoire, emitter_state, metrics, random_key,) = map_elites.update( + repertoire, + emitter_state, + random_key, + ) + + # plot archive + fig, axes = plot_2d_map_elites_repertoire( + centroids=repertoire.centroids, + repertoire_fitnesses=repertoire.fitnesses, + minval=min_bd, + maxval=max_bd, + repertoire_descriptors=repertoire.descriptors, + # vmin=-0.2, + # vmax=0.0, + ) + + plt.show() + + +if __name__ == "__main__": + run_me() diff --git a/qdax/core/neuroevolution/mdp_utils.py b/qdax/core/neuroevolution/mdp_utils.py index 012007d0..80a4b6f4 100644 --- a/qdax/core/neuroevolution/mdp_utils.py +++ b/qdax/core/neuroevolution/mdp_utils.py @@ -1,26 +1,15 @@ from functools import partial from typing import Any, Callable, Tuple -import brax +import brax.envs +import flax.linen as nn import jax import jax.numpy as jnp from brax.envs import State as EnvState from flax.struct import PyTreeNode -from qdax.core.neuroevolution.buffers.buffer import ( - QDTransition, - ReplayBuffer, - Transition, -) -from qdax.types import ( - Descriptor, - ExtraScores, - Fitness, - Genotype, - Metrics, - Params, - RNGKey, -) +from 
qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition +from qdax.types import Genotype, Metrics, Params, RNGKey class TrainingState(PyTreeNode): @@ -125,116 +114,6 @@ def _scan_play_step_fn( return state, transitions -@partial( - jax.jit, - static_argnames=( - "episode_length", - "play_step_fn", - "behavior_descriptor_extractor", - ), -) -def scoring_function( - policies_params: Genotype, - random_key: RNGKey, - init_states: brax.envs.State, - episode_length: int, - play_step_fn: Callable[ - [EnvState, Params, RNGKey, brax.envs.Env], - Tuple[EnvState, Params, RNGKey, QDTransition], - ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """Evaluates policies contained in policies_params in parallel in - deterministic or pseudo-deterministic environments. - - This rollout is only deterministic when all the init states are the same. - If the init states are fixed but different, as a policy is not necessarly - evaluated with the same environment everytime, this won't be determinist. - When the init states are different, this is not purely stochastic. - """ - - # Perform rollouts with each policy - random_key, subkey = jax.random.split(random_key) - unroll_fn = partial( - generate_unroll, - episode_length=episode_length, - play_step_fn=play_step_fn, - random_key=subkey, - ) - - _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) - - # create a mask to extract data properly - is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) - mask = jnp.roll(is_done, 1, axis=1) - mask = mask.at[:, 0].set(0) - - # Scores - add offset to ensure positive fitness (through positive rewards) - fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) - descriptors = behavior_descriptor_extractor(data, mask) - - return ( - fitnesses, - descriptors, - { - "transitions": data, - }, - random_key, - ) - - -@partial( - jax.jit, - static_argnames=( - "episode_length", - "play_reset_fn", - "play_step_fn", - "behavior_descriptor_extractor", - ), -) -def reset_based_scoring_function( - policies_params: Genotype, - random_key: RNGKey, - episode_length: int, - play_reset_fn: Callable[[RNGKey], brax.envs.State], - play_step_fn: Callable[ - [brax.envs.State, Params, RNGKey, brax.envs.Env], - Tuple[brax.envs.State, Params, RNGKey, QDTransition], - ], - behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], -) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: - """Evaluates policies contained in policies_params in parallel. - The play_reset_fn function allows for a more general scoring_function that can be - called with different batch-size and not only with a batch-size of the same - dimension as init_states. - - To define purely stochastic environments, using the reset function from the - environment, use "play_reset_fn = env.reset". - - To define purely deterministic environments, as in "scoring_function", generate - a single init_state using "init_state = env.reset(random_key)", then use - "play_reset_fn = lambda random_key: init_state". 
- """ - - random_key, subkey = jax.random.split(random_key) - keys = jax.random.split( - subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0] - ) - reset_fn = jax.vmap(play_reset_fn) - init_states = reset_fn(keys) - - fitnesses, descriptors, extra_scores, random_key = scoring_function( - policies_params=policies_params, - random_key=random_key, - init_states=init_states, - episode_length=episode_length, - play_step_fn=play_step_fn, - behavior_descriptor_extractor=behavior_descriptor_extractor, - ) - - return (fitnesses, descriptors, extra_scores, random_key) - - @partial( jax.jit, static_argnames=( @@ -316,4 +195,32 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray: # the double transpose trick is here to allow easy broadcasting return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T - return jax.tree_util.tree_map(mask_episodes, transition) # type: ignore + return jax.tree_map(mask_episodes, transition) # type: ignore + + +def init_population_controllers( + policy_network: nn.Module, + env: brax.envs.Env, + batch_size: int, + random_key: RNGKey, +) -> Tuple[Genotype, RNGKey]: + """ + Initializes the population of controllers using a policy_network. + + Args: + policy_network: The policy network structure used for creating policy + controllers. + env: the BRAX environment. + batch_size: the number of environments we play simultaneously. + random_key: a JAX random key. + + Returns: + A tuple of the initial population and the new random key. + """ + random_key, subkey = jax.random.split(random_key) + + keys = jax.random.split(subkey, num=batch_size) + fake_batch = jnp.zeros(shape=(batch_size, env.observation_size)) + init_variables = jax.vmap(policy_network.init)(keys, fake_batch) + + return init_variables, random_key diff --git a/qdax/tasks/README.md b/qdax/tasks/README.md new file mode 100644 index 00000000..0aa418c3 --- /dev/null +++ b/qdax/tasks/README.md @@ -0,0 +1,178 @@ +# QD Tasks +The `tasks` directory provides default `scoring_function`'s to import easily to perform experiments without the boilerplate code so that the main script is kept simple and is not bloated. It provides a set of fixed tasks that is not meant to be modified. If you are developing and require the flexibility of modifying the task and the details that come along with it, we recommend copying and writing your own custom `scoring_function` in your main script instead of importing from `tasks`. + +The `tasks` directory also serves as a way to maintain a QD benchmark task suite that can be easily accesed. We implement several benchmark task across a range of domains. The tasks here are classical tasks from QD literature as well as more recent benchmarks tasks proposed at the [QD Benchmarks Workshop at GECCO 2022](https://quality-diversity.github.io/workshop). 
+ +## Arm +| Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | +|-----------|----------------------|------------------|-----------------------|-------------------|-------------| +| Arm | n | $[0,1]^n$ | 2 | $[0,1]^2$ | | +| Noisy Arm | n | $[0,1]^n$ | 2 | $[0,1]^2$ | |
+ +Notes: +- the parameter space is normalized to $[0,1]$, which corresponds to $[0,2\pi]$ +- the descriptor space (end-effector x-y position) is normalized to $[0,1]$
+ +### Example Usage
+ +```python +import jax +from qdax.tasks.arm import arm_scoring_function + +random_key = jax.random.PRNGKey(0) + +# Get scoring function +scoring_fn = arm_scoring_function + +# Get Task Properties (parameter space, descriptor space, etc.) +min_param, max_param = 0., 1. +min_desc, max_desc = 0., 1. + +# Get initial batch of parameters +num_param_dimensions = ... +init_batch_size = ... +random_key, _subkey = jax.random.split(random_key) +initial_params = jax.random.uniform( + _subkey, + shape=(init_batch_size, num_param_dimensions), + minval=min_param, + maxval=max_param, +) + +# Get number of descriptor dimensions +desc_size = 2 +```
+ +## Standard Functions +| Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | +|----------------------|----------------------|------------------|-----------------------|-------------------|-------------| +| Sphere | n | $[0,1]^n$ | 2 | $[0,1]^2$ | | +| Rastrigin | n | $[0,1]^n$ | 2 | $[0,1]^2$ | | +| Rastrigin-Projection | n | $[0,1]^n$ | 2 | $[0,1]^2$ | |
+ +### Example Usage
+ +```python +import jax +from qdax.tasks.standard_functions import sphere_scoring_function + +random_key = jax.random.PRNGKey(0) + +# Get scoring function +scoring_fn = sphere_scoring_function + +# Get Task Properties (parameter space, descriptor space, etc.) +min_param, max_param = 0., 1. +min_desc, max_desc = 0., 1. + +# Get initial batch of parameters +num_param_dimensions = ... +init_batch_size = ... +random_key, _subkey = jax.random.split(random_key) +initial_params = jax.random.uniform( + _subkey, + shape=(init_batch_size, num_param_dimensions), + minval=min_param, + maxval=max_param, +) + +# Get number of descriptor dimensions +desc_size = 2 +```
+ + +## Hyper-Volume Functions +"Hypervolume-based Benchmark Functions for Quality Diversity Algorithms" by Jean-Baptiste Mouret
+ +| Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | +|-----------------------|----------------------|------------------|-----------------------|-------------------|-------------| +| Square | n | $[0,1]^n$ | n | $[0,1]^n$ | | +| Checkered | n | $[0,1]^n$ | n | $[0,1]^n$ | | +| Empty Circle | n | $[0,1]^n$ | n | $[0,1]^n$ | | +| Non-continous Islands | n | $[0,1]^n$ | n | $[0,1]^n$ | | +| Continous Islands | n | $[0,1]^n$ | n | $[0,1]^n$ | |
+ +### Example Usage
+ +```python +import jax +from qdax.tasks.hypervolume_functions import square_scoring_function + +random_key = jax.random.PRNGKey(0) + +# Get scoring function +scoring_fn = square_scoring_function + +# Get Task Properties (parameter space, descriptor space, etc.) +min_param, max_param = 0., 1. +min_desc, max_desc = 0., 1. + +# Get initial batch of parameters +num_param_dimensions = ... +init_batch_size = ... 
+random_key, _subkey = jax.random.split(random_key) +initial_params = jax.random.uniform( + _subkey, + shape=(init_batch_size, num_param_dimensions), + minval=min_param, + maxval=max_param, +) + +# Get number of descriptor dimensions +desc_size = num_param_dimensions +```
+ +## QD Suite +"Towards QD-suite: developing a set of benchmarks for Quality-Diversity algorithms" by Achkan Salehi and Stephane Doncieux
+ +| Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | +|--------------------------------|----------------------|--------------------------------------------------------------------------------|---------------------------------------|-----------------------------------------------------------------------------|-------------| +| archimedean-spiral-v0 | 1 | $[0,\alpha\pi]$ (angle param.) <br/> $[0,max-arc-length]$ (arc length param.) | 1 (geodesic BD) <br/> 2 (euclidean BD) | $[0,max-arc-length]$ (geodesic BD) <br/> $[-radius,radius]^2$ (euclidean BD) | | +| SSF-v0 | $n$ | Unbounded | 1 | $[0,\infty)$ | | +| deceptive-evolvability-v0
| $n$ (2 by default) | Bounded area including the two Gaussian peaks | 1 | $[0,max-sum-gaussians]$ | |
+ +### Example Usage
+ +```python +from qdax.tasks.qd_suite import archimedean_spiral_v0_angle_euclidean_task + +task = archimedean_spiral_v0_angle_euclidean_task + +# Get scoring function +scoring_fn = task.scoring_function + +# Get Task Properties (parameter space, descriptor space, etc.) +min_param, max_param = task.get_min_max_params() +min_desc, max_desc = task.get_bounded_min_max_descriptor()  # To consider bounded Descriptor space +# If the task has a descriptor space that is not bounded, the unbounded descriptor +# space can be obtained via the following: +# min_desc, max_desc = task.get_min_max_descriptor() + +# Get initial batch of parameters +initial_params = task.get_initial_parameters(batch_size=...) + +# Get number of descriptor dimensions +desc_size = task.get_descriptor_size() +```
+ +## Brax-RL +| Task | Parameter Dimensions | Parameter Bounds | Descriptor Dimensions | Descriptor Bounds | Description | +|-----------------|----------------------|------------------|-----------------------|-------------------|-------------| +| pointmaze | NN params | Unbounded | 2 | $[-1,1]^2$ | | +| hopper_uni | NN params | Unbounded | 1 | $[0,1]$ | | +| walker2d_uni | NN params | Unbounded | 2 | $[0,1]^2$ | | +| halfcheetah_uni | NN params | Unbounded | 2 | $[0,1]^2$ | | +| ant_uni | NN params | Unbounded | 4 | $[0,1]^4$ | | +| humanoid_uni | NN params | Unbounded | 2 | $[0,1]^2$ | | +| ant_omni | NN params | Unbounded | 2 | $[-30,30]^2$ | | +| humanoid_omni | NN params | Unbounded | 2 | $[-30,30]^2$ | | +| anttrap | NN params | Unbounded | 2 | $[-8,8]\times[0,30]$ | | +| antmaze | NN params | Unbounded | 2 | $[-5,40]\times[-5,40]$ | |
+ +Notes: +- the parameter dimensions for default Brax-RL tasks depend on the size and architecture of the neural network used, and can easily be customized. If not set, a network with two hidden layers of size 64 is used.
+ +### Example Usage
+ +See [Example in Notebook](../../examples/notebooks/mapelites_example.ipynb)
diff --git a/qdax/tasks/__init__.py b/qdax/tasks/__init__.py new file mode 100644 index 00000000..e69de29b
diff --git a/qdax/tasks/arm.py b/qdax/tasks/arm.py new file mode 100644 index 00000000..7122ed63 --- /dev/null +++ b/qdax/tasks/arm.py @@ -0,0 +1,91 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp + +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +def arm(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Compute the fitness and BD of one individual in the Planar Arm task. + Based on the Planar Arm implementation in fast_map_elites + (https://github.com/hucebot/fast_map-elites). + + Args: + params: genotype of the individual to evaluate, corresponding to + the normalised angles for each DoF of the arm. + Params should be between [0, 1]. + + Returns: + f: the fitness of the individual, given as the variance of the angles. + bd: the bd of the individual, given as the [x, y] position of the + end-effector of the arm. + BD is normalized to [0, 1] regardless of the num of DoF. + Arm is centered at 0.5, 0.5. 
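+ For example, with all normalised angles equal to 0.5, every joint + angle is 0 and the arm is fully extended along the x-axis, giving a + fitness of 0 (zero variance) and a descriptor of [1.0, 0.5].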
+ """ + + x = jnp.clip(params, 0, 1) + size = params.shape[0] + + f = jnp.sqrt(jnp.mean(jnp.square(x - jnp.mean(x)))) + + # Compute the end-effector position - forward kinemateics + cum_angles = jnp.cumsum(2 * jnp.pi * x - jnp.pi) + x_pos = jnp.sum(jnp.cos(cum_angles)) / (2 * size) + 0.5 + y_pos = jnp.sum(jnp.sin(cum_angles)) / (2 * size) + 0.5 + + return -f, jnp.array([x_pos, y_pos]) + + +def arm_scoring_function( + params: Genotype, + random_key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Evaluate policies contained in params in parallel. + """ + fitnesses, descriptors = jax.vmap(arm)(params) + + return ( + fitnesses, + descriptors, + {}, + random_key, + ) + + +def noisy_arm_scoring_function( + params: Genotype, + random_key: RNGKey, + fit_variance: float, + desc_variance: float, + params_variance: float, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Evaluate policies contained in params in parallel. + """ + + random_key, f_subkey, d_subkey, p_subkey = jax.random.split(random_key, num=4) + + # Add noise to the parameters + params = params + jax.random.normal(p_subkey, shape=params.shape) * params_variance + + # Evaluate + fitnesses, descriptors = jax.vmap(arm)(params) + + # Add noise to the fitnesses and descriptors + fitnesses = ( + fitnesses + jax.random.normal(f_subkey, shape=fitnesses.shape) * fit_variance + ) + descriptors = ( + descriptors + + jax.random.normal(d_subkey, shape=descriptors.shape) * desc_variance + ) + + return ( + fitnesses, + descriptors, + {}, + random_key, + ) diff --git a/qdax/tasks/brax_envs.py b/qdax/tasks/brax_envs.py new file mode 100644 index 00000000..693ec1a4 --- /dev/null +++ b/qdax/tasks/brax_envs.py @@ -0,0 +1,346 @@ +import functools +from functools import partial +from typing import Callable, Optional, Tuple + +import brax.envs +import flax.linen as nn +import jax +import jax.numpy as jnp + +import qdax.environments +from qdax import environments +from qdax.core.neuroevolution.buffers.buffer import QDTransition +from qdax.core.neuroevolution.mdp_utils import generate_unroll +from qdax.core.neuroevolution.networks.networks import MLP +from qdax.types import ( + Descriptor, + EnvState, + ExtraScores, + Fitness, + Genotype, + Params, + RNGKey, +) + + +def create_policy_network_play_step_fn( + env: brax.envs.Env, + policy_network: nn.Module, +) -> Callable[ + [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] +]: + """ + Creates a function that when called, plays a step of the environment. + + Args: + env: The BRAX environment. + policy_network: The policy network structure used for creating and evaluating + policy controllers. + + Returns: + default_play_step_fn: A function that plays a step of the environment. + """ + # Define the function to play a step with the policy in the environment + def default_play_step_fn( + env_state: EnvState, + policy_params: Params, + random_key: RNGKey, + ) -> Tuple[EnvState, Params, RNGKey, QDTransition]: + """ + Play an environment step and return the updated EnvState and the transition. + + Args: env_state: The state of the environment (containing for instance the + actor joint positions and velocities, the reward...). policy_params: The + parameters of policies/controllers. random_key: JAX random key. + + Returns: + next_state: The updated environment state. + policy_params: The parameters of policies/controllers (unchanged). + random_key: The updated random key. 
+ transition: The transition, containing the observation, reward, + next observation, policy action... + """ + + actions = policy_network.apply(policy_params, env_state.obs) + + state_desc = env_state.info["state_descriptor"] + next_state = env.step(env_state, actions) + + transition = QDTransition( + obs=env_state.obs, + next_obs=next_state.obs, + rewards=next_state.reward, + dones=next_state.done, + actions=actions, + truncations=next_state.info["truncation"], + state_desc=state_desc, + next_state_desc=next_state.info["state_descriptor"], + ) + + return next_state, policy_params, random_key, transition + + return default_play_step_fn + + +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_step_fn", + "behavior_descriptor_extractor", + ), +) +def scoring_function_brax_envs( + policies_params: Genotype, + random_key: RNGKey, + init_states: EnvState, + episode_length: int, + play_step_fn: Callable[ + [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policies_params in parallel in + deterministic or pseudo-deterministic environments. + + This rollout is only deterministic when all the init states are the same. + If the init states are fixed but different, a policy is not necessarily + evaluated with the same environment every time, so this won't be deterministic. + When the init states are different, this is not purely stochastic. + + Args: + policies_params: The parameters of closed-loop controllers/policies to evaluate. + random_key: A JAX random key. + init_states: The initial states of the environments, one per evaluated policy. + episode_length: The maximal rollout length. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from evaluation + random_key: The updated random key. + """ + + # Perform rollouts with each policy + random_key, subkey = jax.random.split(random_key) + unroll_fn = partial( + generate_unroll, + episode_length=episode_length, + play_step_fn=play_step_fn, + random_key=subkey, + ) + + _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params) + + # create a mask to extract data properly + is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1) + mask = jnp.roll(is_done, 1, axis=1) + mask = mask.at[:, 0].set(0) + + # Scores - add offset to ensure positive fitness (through positive rewards) + fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1) + descriptors = behavior_descriptor_extractor(data, mask) + + return ( + fitnesses, + descriptors, + { + "transitions": data, + }, + random_key, + ) + + +@partial( + jax.jit, + static_argnames=( + "episode_length", + "play_reset_fn", + "play_step_fn", + "behavior_descriptor_extractor", + ), +) +def reset_based_scoring_function_brax_envs( + policies_params: Genotype, + random_key: RNGKey, + episode_length: int, + play_reset_fn: Callable[[RNGKey], EnvState], + play_step_fn: Callable[ + [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] + ], + behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor], +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """Evaluates policies contained in policies_params in parallel. 
The play_reset_fn function allows for a more general scoring_function that can be + called with different batch sizes, not only with a batch size matching the + dimension of init_states. + + To define purely stochastic environments, using the reset function from the + environment, use "play_reset_fn = env.reset". + + To define purely deterministic environments, as in scoring_function_brax_envs, + generate a single init_state using "init_state = env.reset(random_key)", then use + "play_reset_fn = lambda random_key: init_state". + + Args: + policies_params: The parameters of closed-loop controllers/policies to evaluate. + random_key: A JAX random key. + episode_length: The maximal rollout length. + play_reset_fn: The function to reset the environment and obtain initial states. + play_step_fn: The function to play a step of the environment. + behavior_descriptor_extractor: The function to extract the behavior descriptor. + + Returns: + fitness: Array of fitnesses of all evaluated policies + descriptor: Behavioural descriptors of all evaluated policies + extra_scores: Additional information resulting from the evaluation + random_key: The updated random key. + """ + + random_key, subkey = jax.random.split(random_key) + keys = jax.random.split(subkey, jax.tree_leaves(policies_params)[0].shape[0]) + reset_fn = jax.vmap(play_reset_fn) + init_states = reset_fn(keys) + + fitnesses, descriptors, extra_scores, random_key = scoring_function_brax_envs( + policies_params=policies_params, + random_key=random_key, + init_states=init_states, + episode_length=episode_length, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=behavior_descriptor_extractor, + ) + + return fitnesses, descriptors, extra_scores, random_key + + +def create_brax_scoring_fn( + env: brax.envs.Env, + policy_network: nn.Module, + batch_size: int, + bd_extraction_fn: Callable[[QDTransition, jnp.ndarray], Descriptor], + random_key: RNGKey, + play_step_fn: Optional[ + Callable[ + [EnvState, Params, RNGKey], Tuple[EnvState, Params, RNGKey, QDTransition] + ] + ] = None, + episode_length: int = 100, + is_reset_based: bool = False, + play_reset_fn: Optional[Callable[[RNGKey], EnvState]] = None, +) -> Tuple[ + Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], + RNGKey, +]: + """ + Creates a scoring function to evaluate a policy in a BRAX task. + + Args: + env: The BRAX environment. + policy_network: The policy network controller. + batch_size: the number of environments we play simultaneously. + bd_extraction_fn: The behaviour descriptor extraction function. + random_key: a random key used for stochastic operations. + play_step_fn: the function used to perform environment rollouts and collect + evaluation episodes. If None, we use create_policy_network_play_step_fn + to generate it. + episode_length: The maximal episode length. + is_reset_based: Whether we reset the initial state of the robot before each + evaluation or not. + play_reset_fn: the function used to reset the environment to an initial state. + Only used if is_reset_based is True. If None, we take env.reset as + default reset function. + + Returns: + The scoring function: a function that takes a batch of genotypes and computes + their fitnesses and descriptors. + The updated random key. 
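+ + Note: create_default_brax_task_components below wraps this function and + builds the environment, the policy network and the scoring function + together for a named Brax task.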
+ """ + if play_step_fn is None: + play_step_fn = create_policy_network_play_step_fn(env, policy_network) + if play_reset_fn is None: + play_reset_fn = env.reset + + if not is_reset_based: + # Create the initial environment states + random_key, subkey = jax.random.split(random_key) + keys = jnp.repeat(jnp.expand_dims(subkey, axis=0), repeats=batch_size, axis=0) + reset_fn = jax.jit(jax.vmap(env.reset)) + init_states = reset_fn(keys) + + scoring_fn = functools.partial( + scoring_function_brax_envs, + init_states=init_states, + episode_length=episode_length, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + else: + scoring_fn = functools.partial( + reset_based_scoring_function_brax_envs, + episode_length=episode_length, + play_reset_fn=play_reset_fn, + play_step_fn=play_step_fn, + behavior_descriptor_extractor=bd_extraction_fn, + ) + + return scoring_fn, random_key + + +def create_default_brax_task_components( + env_name: str, + batch_size: int, + random_key: RNGKey, + episode_length: int = 100, + mlp_policy_hidden_layer_sizes: Tuple[int, ...] = (64, 64), + is_reset_based: bool = False, +) -> Tuple[ + brax.envs.Env, + MLP, + Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]], + RNGKey, +]: + """ + Creates default environment, policy network and scoring function for a BRAX task. + + Args: + env_name: Name of the BRAX environment (e.g. "ant_omni", "walker2d_uni"...). + batch_size: The number of environments we play simultaneously. + random_key: Jax random key + episode_length: The maximal rollout length. + mlp_policy_hidden_layer_sizes: Hidden layer sizes of the policy network. + is_reset_based: Whether we reset the initial state of the robot before each + evaluation or not. + + Returns: + env: The BRAX environment. + policy_network: The policy network structure used for creating and evaluating + policy controllers. + scoring_fn: a function that takes a batch of genotypes and compute + their fitnesses and descriptors. + random_key: The updated random key. + """ + env = environments.create(env_name, episode_length=episode_length) + + # Init policy network + policy_layer_sizes = mlp_policy_hidden_layer_sizes + (env.action_size,) + policy_network = MLP( + layer_sizes=policy_layer_sizes, + kernel_init=jax.nn.initializers.lecun_uniform(), + final_activation=jnp.tanh, + ) + + bd_extraction_fn = qdax.environments.behavior_descriptor_extractor[env_name] + + scoring_fn, random_key = create_brax_scoring_fn( + env, + policy_network, + batch_size, + bd_extraction_fn, + random_key, + episode_length=episode_length, + is_reset_based=is_reset_based, + ) + + return env, policy_network, scoring_fn, random_key diff --git a/qdax/tasks/hypervolume_functions.py b/qdax/tasks/hypervolume_functions.py new file mode 100644 index 00000000..f4936574 --- /dev/null +++ b/qdax/tasks/hypervolume_functions.py @@ -0,0 +1,95 @@ +""" +Hypervolume Benchmark Functions in the paper by +J.B. 
Mouret, "Hypervolume-based Benchmark Functions for Quality Diversity Algorithms" +""" + +from typing import Callable, Tuple + +import jax +import jax.numpy as jnp + +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +def square(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Seach space should be [0,1]^n + BD space should be [0,1]^n + """ + freq = 5 + f = 1 - jnp.prod(params) + bd = jnp.sin(freq * params) + return f, bd + + +def checkered(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Seach space should be [0,1]^n + BD space should be [0,1]^n + """ + freq = 5 + f = jnp.prod(jnp.sin(params * 50)) + bd = jnp.sin(params * freq) + return f, bd + + +def empty_circle(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Seach space should be [0,1]^n + BD space should be [0,1]^n + """ + + def _gaussian(x: jnp.ndarray, mu: float, sig: float) -> jnp.ndarray: + return jnp.exp(-jnp.power(x - mu, 2.0) / (2 * jnp.power(sig, 2.0))) + + freq = 40 + centre = jnp.ones_like(params) * 0.5 + distance_from_centre = jnp.linalg.norm(params - centre) + f = _gaussian(distance_from_centre, mu=0.5, sig=0.3) + bd = jnp.sin(freq * params) + return f, bd + + +def non_continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Seach space should be [0,1]^n + BD space should be [0,1]^n + """ + f = jnp.prod(params) + bd = jnp.round(10 * params) / 10 + return f, bd + + +def continous_islands(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Seach space should be [0,1]^n + BD space should be [0,1]^n + """ + coeff = 20 + f = jnp.prod(params) + bd = params - jnp.sin(coeff * jnp.pi * params) / (coeff * jnp.pi) + return f, bd + + +def get_scoring_function( + task_fn: Callable[[Genotype], Tuple[Fitness, Descriptor]] +) -> Callable[[Genotype, RNGKey], Tuple[Fitness, Descriptor, ExtraScores, RNGKey]]: + def scoring_function( + params: Genotype, + random_key: RNGKey, + ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Evaluate params in parallel + """ + fitnesses, descriptors = jax.vmap(task_fn)(params) + + return (fitnesses, descriptors, {}, random_key) + + return scoring_function + + +square_scoring_function = get_scoring_function(square) +checkered_scoring_function = get_scoring_function(checkered) +empty_circle_scoring_function = get_scoring_function(empty_circle) +non_continous_islands_scoring_function = get_scoring_function(non_continous_islands) +continous_islands_scoring_function = get_scoring_function(continous_islands) diff --git a/qdax/tasks/qd_suite/__init__.py b/qdax/tasks/qd_suite/__init__.py new file mode 100644 index 00000000..ae3de3d6 --- /dev/null +++ b/qdax/tasks/qd_suite/__init__.py @@ -0,0 +1,27 @@ +from qdax.tasks.qd_suite.archimedean_spiral import ( + ArchimedeanBD, + ArchimedeanSpiralV0, + ParameterizationGenotype, +) +from qdax.tasks.qd_suite.deceptive_evolvability import DeceptiveEvolvabilityV0 +from qdax.tasks.qd_suite.ssf import SsfV0 + +archimedean_spiral_v0_angle_euclidean_task = ArchimedeanSpiralV0( + ParameterizationGenotype.angle, + ArchimedeanBD.euclidean, +) +archimedean_spiral_v0_angle_geodesic_task = ArchimedeanSpiralV0( + ParameterizationGenotype.angle, + ArchimedeanBD.geodesic, +) +archimedean_spiral_v0_arc_length_euclidean_task = ArchimedeanSpiralV0( + ParameterizationGenotype.arc_length, + ArchimedeanBD.euclidean, +) +archimedean_spiral_v0_arc_length_geodesic_task = ArchimedeanSpiralV0( + ParameterizationGenotype.arc_length, + ArchimedeanBD.geodesic, +) +deceptive_evolvability_v0_task = 
DeceptiveEvolvabilityV0() +ssf_v0_param_size_1_task = SsfV0(param_size=1) +ssf_v0_param_size_2_task = SsfV0(param_size=2) diff --git a/qdax/tasks/qd_suite/archimedean_spiral.py b/qdax/tasks/qd_suite/archimedean_spiral.py new file mode 100644 index 00000000..5784f596 --- /dev/null +++ b/qdax/tasks/qd_suite/archimedean_spiral.py @@ -0,0 +1,264 @@ +from enum import Enum +from typing import Optional, Tuple, Union + +import jax.lax +import jax.numpy as jnp + +from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask +from qdax.types import Descriptor, Fitness, Genotype + + +class ParameterizationGenotype(Enum): + angle = "angle" + arc_length = "arc_length" + + +class ArchimedeanBD(Enum): + euclidean = "euclidean" + geodesic = "geodesic" + + +class ArchimedeanSpiralV0(QDSuiteTask): + def __init__( + self, + parameterization: ParameterizationGenotype, + archimedean_bd: ArchimedeanBD, + amplitude: float = 0.01, + precision: Optional[float] = None, + alpha: float = 40.0, + ): + """ + Implements the Archimedean spiral task from Salehi et al. (2022): + https://arxiv.org/abs/2205.03162 + + Args: + parameterization: The parameterization of the genotype, + can be either angle or arc length. + archimedean_bd: The Archimedean BD, can be either euclidean or + geodesic. + amplitude: The amplitude of the Archimedean spiral. + precision: The precision of the approximation of the angle from the + arc length. + alpha: controls the length/maximal angle of the Archimedean spiral. + """ + self.parameterization = parameterization + self.archimedean_bd = archimedean_bd + self.amplitude = amplitude + if precision is None: + self.precision = alpha * jnp.pi / 1e7 + else: + self.precision = precision + self.alpha = alpha + + def _gamma(self, angle: Union[float, jnp.ndarray]) -> jnp.ndarray: + """ + The function gamma is the function that maps the angle to the euclidean + coordinates of the Archimedean spiral. + + Args: + angle: The angle of the Archimedean spiral. + + Returns: + The euclidean coordinates of the Archimedean spiral. + """ + return jnp.hstack( + [ + self.amplitude * angle * jnp.cos(angle), + self.amplitude * angle * jnp.sin(angle), + ] + ) + + def get_arc_length(self, angle: Union[float, jnp.ndarray]) -> jnp.ndarray: + """ + The function arc_length is the function that maps the angle to the arc + length of the Archimedean spiral. + + Args: + angle: The angle of the Archimedean spiral. + + Returns: + The arc length of the Archimedean spiral. + """ + return (self.amplitude / 2) * ( + angle * jnp.sqrt(1 + jnp.power(angle, 2)) + + jnp.log(angle + jnp.sqrt(1 + jnp.power(angle, 2))) + ) + + def _cond_fun(self, elem: Tuple[float, float, float]) -> jnp.bool_: + """ + The function cond_fun is the function that checks if the precision has + been reached. + + Args: + elem: The tuple containing the lower bound, the upper bound and the + target arc length. + + Returns: + True if the precision has been reached, False otherwise. + """ + inf, sup, target = elem + return (sup - inf) > self.precision + + def _body_fun(self, elem: Tuple[float, float, float]) -> Tuple[float, float, float]: + """ + The function body_fun is the function that computes the next iteration + of the while loop. + + Args: + elem: The tuple containing the lower bound, the upper bound and the + target arc length. + + Returns: + The tuple containing the lower bound, the upper bound and the + target arc length. 
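+ (One bisection step: the midpoint of [inf, sup] replaces whichever + bound keeps the target arc length bracketed, so the search interval + is halved on every call.)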
+ """ + inf, sup, target_angle_length = elem + middle = (sup + inf) / 2.0 + arc_length_middle = self.get_arc_length(middle) + new_inf, new_sup = jax.lax.cond( + target_angle_length < arc_length_middle, + lambda: (inf, middle), + lambda: (middle, sup), + ) + return new_inf, new_sup, target_angle_length + + def _approximate_angle_from_arc_length( + self, target_arc_length: float + ) -> jnp.ndarray: + """ + The function approximate_angle_from_arc_length is the function that + approximates the angle from the arc length. + + Args: + target_arc_length: The target arc length. + + Returns: + The angle. + """ + inf, sup, _ = jax.lax.while_loop( + self._cond_fun, + self._body_fun, + init_val=(0.0, self.alpha * jnp.pi, target_arc_length), + ) + middle = (sup + inf) / 2.0 + return jnp.asarray(middle) + + def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + The function evaluation computes the fitness and the descriptor of the + parameters passed as input. The fitness is always 1.0 as no elitism is + considered. + + Args: + params: The parameters of the Archimedean spiral. + + Returns: + The fitness and the descriptor of the parameters. + """ + constant_fitness = jnp.asarray(1.0) + + if ( + self.archimedean_bd == ArchimedeanBD.geodesic + and self.parameterization == ParameterizationGenotype.arc_length + ): + arc_length = params + return constant_fitness, arc_length + elif ( + self.archimedean_bd == ArchimedeanBD.geodesic + and self.parameterization == ParameterizationGenotype.angle + ): + angle = params + arc_length = self.get_arc_length(angle) + return constant_fitness, arc_length + elif ( + self.archimedean_bd == ArchimedeanBD.euclidean + and self.parameterization == ParameterizationGenotype.arc_length + ): + arc_length = params + angle = self._approximate_angle_from_arc_length(arc_length[0]) + euclidean_bd = self._gamma(angle) + return constant_fitness, euclidean_bd + elif ( + self.archimedean_bd == ArchimedeanBD.euclidean + and self.parameterization == ParameterizationGenotype.angle + ): + angle = params + return constant_fitness, self._gamma(angle) + else: + raise ValueError("Invalid parameterization and/or BD") + + def get_descriptor_size(self) -> int: + """ + The function get_descriptor_size returns the size of the descriptor. + + Returns: + The size of the descriptor. + """ + if self.archimedean_bd == ArchimedeanBD.euclidean: + return 2 + elif self.archimedean_bd == ArchimedeanBD.geodesic: + return 1 + else: + raise ValueError("Invalid BD") + + def get_min_max_descriptor(self) -> Tuple[float, float]: + """ + The function get_min_max_descriptor returns the minimum and maximum + bounds of the descriptor space. + + Returns: + The minimum and maximum value of the descriptor. + """ + max_angle = self.alpha * jnp.pi + max_norm = jnp.linalg.norm(self._gamma(max_angle)) + + if self.archimedean_bd == ArchimedeanBD.euclidean: + return -max_norm, max_norm + elif self.archimedean_bd == ArchimedeanBD.geodesic: + max_arc_length = self.get_arc_length(max_angle) + return 0.0, max_arc_length.item() + else: + raise ValueError("Invalid BD") + + def get_min_max_params(self) -> Tuple[float, float]: + """ + The function get_min_max_params returns the minimum and maximum value + of the parameter space. + + Returns: + The minimum and maximum value of the parameters. 
+ """ + if self.parameterization == ParameterizationGenotype.angle: + max_angle = self.alpha * jnp.pi + return 0.0, max_angle + elif self.parameterization == ParameterizationGenotype.arc_length: + max_angle = self.alpha * jnp.pi + max_arc_length = self.get_arc_length(max_angle) + return 0, max_arc_length.item() + else: + raise ValueError("Invalid parameterization") + + def get_initial_parameters(self, batch_size: int) -> Genotype: + """ + The function get_initial_parameters returns the initial parameters. + + Args: + batch_size: The batch size. + + Returns: + The initial parameters (of size batch_size). + """ + max_angle = self.alpha * jnp.pi + mid_angle = max_angle / 2.0 + mid_number_turns = 1 + int(mid_angle / (2.0 * jnp.pi)) + horizontal_left_mid_angle = mid_number_turns * jnp.pi * 2 + + if self.parameterization == ParameterizationGenotype.angle: + angle_array = jnp.asarray(horizontal_left_mid_angle).reshape((1, 1)) + return jnp.repeat(angle_array, batch_size, axis=0) + elif self.parameterization == ParameterizationGenotype.arc_length: + arc_length = self.get_arc_length(horizontal_left_mid_angle) + length_array = jnp.asarray(arc_length).reshape((1, 1)) + return jnp.repeat(length_array, batch_size, axis=0) + else: + raise ValueError("Invalid parameterization") diff --git a/qdax/tasks/qd_suite/deceptive_evolvability.py b/qdax/tasks/qd_suite/deceptive_evolvability.py new file mode 100644 index 00000000..d5be0688 --- /dev/null +++ b/qdax/tasks/qd_suite/deceptive_evolvability.py @@ -0,0 +1,158 @@ +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp + +from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask +from qdax.types import Descriptor, Fitness, Genotype + + +def multivariate_normal( + params: Genotype, + mu: jnp.ndarray, + sigma: float, +) -> jnp.ndarray: + """ + Compute the un-normalised multivariate normal density. + """ + params = params.reshape((-1, 1)) + mu = mu.reshape((-1, 1)) + + x = params - mu + return jnp.exp(-0.5 * x.T.dot(x).ravel() / (sigma * sigma)) * ( + 2 * jnp.pi * sigma * sigma + ) ** (-params.shape[0] / 2) + + +class DeceptiveEvolvabilityV0(QDSuiteTask): + default_mu_1 = jnp.array([50.0, 125.0]) + default_sigma_1 = jnp.sqrt(70.0) + default_beta = 20.0 + default_mu_2 = jnp.array([150.0, 125.0]) + default_sigma_2 = jnp.sqrt(1e3) + + def __init__( + self, + mu_1: Optional[Genotype] = None, + sigma_1: Optional[float] = None, + beta: Optional[float] = None, + mu_2: Optional[Genotype] = None, + sigma_2: Optional[float] = None, + ): + """ + Initialize the deceptive evolvability task + from Achkan Salehi and Stephane Doncieux + + Args: + mu_1: The mean of the first Gaussian. + sigma_1: The standard deviation of the first Gaussian. + beta: The weight of the second Gaussian. + mu_2: The mean of the second Gaussian. + sigma_2: The standard deviation of the second Gaussian. + """ + if mu_1 is None: + mu_1 = self.default_mu_1 + if sigma_1 is None: + sigma_1 = self.default_sigma_1 + if beta is None: + beta = self.default_beta + if mu_2 is None: + mu_2 = self.default_mu_2 + if sigma_2 is None: + sigma_2 = self.default_sigma_2 + + self.mu_1 = mu_1 + self.sigma_1 = sigma_1 + self.beta = beta + self.mu_2 = mu_2 + self.sigma_2 = sigma_2 + + def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + Compute the fitness and descriptor of the deceptive evolvability task. + + The fitness is always 1.0, as no elitism is considered. + + Args: + params: The parameters to evaluate. + + Returns: + The fitness and descriptor. 
+ """ + bd = multivariate_normal( + params, self.mu_1, self.sigma_1 + ) + self.beta * multivariate_normal(params, self.mu_2, self.sigma_2) + constant_fitness = jnp.asarray(1.0) + return constant_fitness, bd + + def get_saddle_point(self) -> Genotype: + """ + Compute the saddle point of the deceptive evolvability task. + + Returns: + The saddle point. + """ + + def _func_to_minimize(theta: Genotype) -> Descriptor: + return self.evaluation(theta)[1] + + t = jnp.linspace(0.0, 1.0, 1000) + + considered_points = jax.vmap( + lambda _t: _t * (self.mu_2 - self.mu_1) + self.mu_1 + )(t) + + results = jax.vmap(_func_to_minimize)(considered_points) + + index_min = jnp.argmin(results) + + return considered_points[index_min] + + def get_descriptor_size(self) -> int: + """ + Get the size of the descriptor. + + Returns: + The size of the descriptor. + """ + return 1 + + def get_min_max_descriptor(self) -> Tuple[float, float]: + """ + Get the minimum and maximum descriptor values. + + Returns: + The minimum and maximum descriptor values. + """ + potential_max_1 = self.evaluation(self.mu_1)[1] + potential_max_2 = self.evaluation(self.mu_2)[1] + return 0.0, jnp.maximum(potential_max_1, potential_max_2)[0] + + def get_min_max_params(self) -> Tuple[float, float]: + """ + Get the minimum and maximum parameter values. + + Returns: + The minimum and maximum parameter values. + """ + potential_max_1 = jnp.max(self.mu_1 + 3 * self.sigma_1) + potential_max_2 = jnp.max(self.mu_2 + 3 * self.sigma_2) + max_final = jnp.maximum(potential_max_1, potential_max_2) + + potential_min_1 = jnp.min(self.mu_1 - 3 * self.sigma_1) + potential_min_2 = jnp.min(self.mu_2 - 3 * self.sigma_2) + min_final = jnp.minimum(potential_min_1, potential_min_2) + return min_final, max_final + + def get_initial_parameters(self, batch_size: int) -> Genotype: + """ + Get the initial parameters. + + Args: + batch_size: The batch size. + + Returns: + The initial parameters. + """ + saddle_point = self.get_saddle_point() + return jnp.repeat(jnp.expand_dims(saddle_point, axis=0), batch_size, axis=0) diff --git a/qdax/tasks/qd_suite/qd_suite_task.py b/qdax/tasks/qd_suite/qd_suite_task.py new file mode 100644 index 00000000..6f1af76f --- /dev/null +++ b/qdax/tasks/qd_suite/qd_suite_task.py @@ -0,0 +1,98 @@ +import abc +from typing import Tuple, Union + +import jax +from jax import numpy as jnp + +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +class QDSuiteTask(abc.ABC): + @abc.abstractmethod + def evaluation(self, params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + The function evaluation computes the fitness and the descriptor of the + parameters passed as input. + + Args: + params: The batch of parameters to evaluate + + Returns: + The fitnesses and the descriptors of the parameters. + """ + ... + + def scoring_function( + self, + params: Genotype, + random_key: RNGKey, + ) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Evaluate params in parallel + """ + fitnesses, descriptors = jax.vmap(self.evaluation)(params) + + return fitnesses, descriptors, {}, random_key + + @abc.abstractmethod + def get_descriptor_size(self) -> int: + """ + The function get_descriptor_size returns the size of the descriptor. + + Returns: + The size of the descriptor. + """ + ... + + @abc.abstractmethod + def get_min_max_descriptor( + self, + ) -> Tuple[Union[float, jnp.ndarray], Union[float, jnp.ndarray]]: + """ + Get the minimum and maximum descriptor values. + + Returns: + The minimum and maximum descriptor values. + """ + ... 
+ + def get_bounded_min_max_descriptor( + self, + ) -> Tuple[Union[float, jnp.ndarray], Union[float, jnp.ndarray]]: + """ + Returns: + The minimum and maximum descriptor assuming that + the descriptor space is bounded. + """ + min_bd, max_bd = self.get_min_max_descriptor() + if jnp.isinf(max_bd) or jnp.isinf(min_bd): + raise NotImplementedError( + "Boundedness has not been implemented " "for this unbounded task" + ) + else: + return min_bd, max_bd + + @abc.abstractmethod + def get_min_max_params( + self, + ) -> Tuple[Union[float, jnp.ndarray], Union[float, jnp.ndarray]]: + """ + Get the minimum and maximum parameter values. + + Returns: + The minimum and maximum parameter values. + """ + ... + + @abc.abstractmethod + def get_initial_parameters(self, batch_size: int) -> Genotype: + """ + Get the initial parameters. + + Args: + batch_size: The batch size. + + Returns: + The initial parameters. + """ + ... diff --git a/qdax/tasks/qd_suite/ssf.py b/qdax/tasks/qd_suite/ssf.py new file mode 100644 index 00000000..547bee8d --- /dev/null +++ b/qdax/tasks/qd_suite/ssf.py @@ -0,0 +1,120 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp + +from qdax.tasks.qd_suite.qd_suite_task import QDSuiteTask +from qdax.types import Descriptor, Fitness, Genotype + + +class SsfV0(QDSuiteTask): + def __init__( + self, + param_size: int, + ): + """ + Implements the Self-Similar Function (SSF) task + from Achkan Salehi and Stephane Doncieux. + + Args: + param_size: The number of parameters in the genotype. + """ + self.param_size = param_size + + def evaluation( + self, + params: Genotype, + ) -> Tuple[Fitness, Descriptor]: + """ + The function evaluation computes the fitness and the descriptor of the + parameters passed as input. The fitness is always 1.0 as the task does + not consider elitism. + + Args: + params: The batch of parameters to evaluate + + Returns: + The fitnesses and the descriptors of the parameters. + """ + norm = jnp.linalg.norm(params, ord=2) + r_2k_plus_1, _, k = self._get_k(params) + index = jnp.floor(norm / r_2k_plus_1) + bd = jax.lax.cond(index == 0.0, lambda: norm, lambda: r_2k_plus_1) + constant_fitness = jnp.asarray(1.0) + bd = jnp.asarray(bd).reshape((self.get_descriptor_size(),)) + return constant_fitness, bd + + def get_descriptor_size(self) -> int: + """ + Returns: + The descriptor size. + """ + return 1 + + def get_min_max_descriptor(self) -> Tuple[float, float]: + """ + Returns: + The minimum and maximum descriptor. + """ + return 0.0, jnp.inf + + def get_bounded_min_max_descriptor(self) -> Tuple[float, float]: + """ + Returns: + The minimum and maximum descriptor assuming that + the descriptor space is bounded. + """ + return 0.0, 1000.0 + + def get_min_max_params(self) -> Tuple[float, float]: + """ + Returns: + The minimum and maximum parameters (here + the parameter space is unbounded). + """ + return -jnp.inf, jnp.inf + + def get_initial_parameters(self, batch_size: int) -> Genotype: + """ + Returns: + The initial parameters (of size batch_size x param_size). + """ + return jnp.zeros(shape=(batch_size, self.param_size)) + + def _get_k(self, params: Genotype) -> Tuple[float, float, int]: + """ + Computes the k-th level of the SSF. + + Args: + params: The parameters to evaluate. 
+ + Returns: + (R_2k_plus_1, norm of params, k) + """ + norm_params = jnp.linalg.norm(params, ord=2) + init_k = 0 + r_0 = 0.0 + r_1 = r_0 + self._get_r_add_odd(init_k) + (r_2k, r_2k_plus_1), norm, k = jax.lax.while_loop( + self._cond_fun, self._body_fun, ((0.0, r_1), norm_params, init_k) + ) + return r_2k_plus_1, norm, k + + def _cond_fun(self, elem: Tuple[Tuple[float, float], float, int]) -> jnp.bool_: + (r_2k, r_2k_plus_1), norm, k = elem + return r_2k_plus_1 + self._get_r_add_even(k + 1) < norm + + def _body_fun( + self, elem: Tuple[Tuple[float, float], float, int] + ) -> Tuple[Tuple[float, float], float, int]: + (r_2k, r_2k_plus_1), norm, k = elem + k_plus_1 = k + 1 + r_2k_plus_2 = r_2k_plus_1 + self._get_r_add_even(k_plus_1) + r_2k_plus_3 = r_2k_plus_2 + self._get_r_add_odd(k_plus_1) + return (r_2k_plus_2, r_2k_plus_3), norm, k + 1 + + def _get_r_add_odd(self, k: int) -> float: + return 2 * (k**3) + 1 + + def _get_r_add_even(self, k: int) -> float: + return 2 * ((k - 1) ** 3) + 1 diff --git a/qdax/tasks/standard_functions.py b/qdax/tasks/standard_functions.py new file mode 100644 index 00000000..53d5b492 --- /dev/null +++ b/qdax/tasks/standard_functions.py @@ -0,0 +1,120 @@ +from typing import Tuple + +import jax +import jax.numpy as jnp + +from qdax.types import Descriptor, ExtraScores, Fitness, Genotype, RNGKey + + +def rastrigin(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + 2-D BD + """ + x = params * 10 - 5 # scaling to [-5, 5] + f = jnp.asarray(10.0 * x.shape[0]) + jnp.sum(x * x - 10 * jnp.cos(2 * jnp.pi * x)) + return -f, jnp.asarray([params[0], params[1]]) + + +def sphere(params: Genotype) -> Tuple[Fitness, Descriptor]: + """ + 2-D BD + """ + x = params * 10 - 5 # scaling to [-5, 5] + f = (x * x).sum() + return -f, jnp.array([params[0], params[1]]) + + +def rastrigin_scoring_function( + params: Genotype, + random_key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Scoring function for the rastrigin function + """ + fitnesses, descriptors = jax.vmap(rastrigin)(params) + + return fitnesses, descriptors, {}, random_key + + +def sphere_scoring_function( + params: Genotype, + random_key: RNGKey, +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Scoring function for the sphere function + """ + fitnesses, descriptors = jax.vmap(sphere)(params) + + return fitnesses, descriptors, {}, random_key + + +def _rastrigin_proj_scoring( + params: Genotype, minval: float, maxval: float +) -> Tuple[Fitness, Descriptor, ExtraScores]: + """ + Rastrigin function with a folding of the behaviour space. 
+ + Args: + params: Genotype + minval: minimum value of the parameters + maxval: maximum value of the parameters + + Returns: + fitnesses + descriptors + extra_scores (containing the gradients of the + fitnesses and descriptors) + """ + + def rastrigin_scoring(x: jnp.ndarray) -> jnp.ndarray: + return -( + jnp.asarray(10 * x.shape[-1]) + + jnp.sum( + (x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)) + ) + ) + + def clip(x: jnp.ndarray) -> jnp.ndarray: + return x * (x <= maxval) * (x >= +minval) + maxval / x * ( + (x > maxval) + (x < +minval) + ) + + def _rastrigin_descriptor_1(x: jnp.ndarray) -> jnp.ndarray: + return jnp.mean(clip(x[: x.shape[0] // 2])) + + def _rastrigin_descriptor_2(x: jnp.ndarray) -> jnp.ndarray: + return jnp.mean(clip(x[x.shape[0] // 2 :])) + + def rastrigin_descriptors(x: jnp.ndarray) -> jnp.ndarray: + return jnp.array([_rastrigin_descriptor_1(x), _rastrigin_descriptor_2(x)]) + + # gradient function + rastrigin_grad_scores = jax.grad(rastrigin_scoring) + + fitnesses, descriptors = rastrigin_scoring(params), rastrigin_descriptors(params) + gradients = jnp.array( + [ + rastrigin_grad_scores(params), + jax.grad(_rastrigin_descriptor_1)(params), + jax.grad(_rastrigin_descriptor_2)(params), + ] + ).T + gradients = jnp.nan_to_num(gradients) + + return fitnesses, descriptors, {"gradients": gradients} + + +def rastrigin_proj_scoring_function( + params: Genotype, random_key: RNGKey, minval: float = -5.12, maxval: float = 5.12 +) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]: + """ + Scoring function for the rastrigin function with + a folding of the behaviour space. + """ + + # vmap only over the Genotypes + fitnesses, descriptors, extra_scores = jax.vmap( + _rastrigin_proj_scoring, in_axes=(0, None, None) + )(params, minval, maxval) + + return fitnesses, descriptors, extra_scores, random_key diff --git a/qdax/types.py b/qdax/types.py index be4b3d99..0699fb56 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -2,6 +2,7 @@ from typing import Dict, Generic, TypeVar, Union +import brax.envs import jax import jax.numpy as jnp from chex import ArrayTree @@ -12,7 +13,7 @@ Action: TypeAlias = jnp.ndarray Reward: TypeAlias = jnp.ndarray Done: TypeAlias = jnp.ndarray -EnvState: TypeAlias = jnp.ndarray +EnvState: TypeAlias = brax.envs.State Params: TypeAlias = ArrayTree # Evolution types diff --git a/qdax/utils/plotting.py b/qdax/utils/plotting.py index 5df4b48f..9b107c7e 100644 --- a/qdax/utils/plotting.py +++ b/qdax/utils/plotting.py @@ -603,7 +603,6 @@ def plot_multidimensional_map_elites_grid( ) -> Tuple[Optional[Figure], Axes]: """Plot a visual 2D representation of a multidimensional MAP-Elites repertoire (where the dimensionality of descriptors can be greater than 2). - Args: repertoire: the MAP-Elites repertoire to plot. minval: minimum values for the descriptors @@ -612,10 +611,8 @@ def plot_multidimensional_map_elites_grid( ax: a matplotlib axe for the figure to plot. Defaults to None. vmin: minimum value for the fitness. Defaults to None. vmax: maximum value for the fitness. Defaults to None. - Raises: ValueError: the resolution should be an int or a tuple - Returns: A figure and axes object, corresponding to the visualisation of the repertoire. 
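The task modules added above are easiest to see in action through short usage sketches. First, the `QDSuiteTask` interface from `qd_suite_task.py`: a concrete task exposes a batched scoring function obtained by vmapping its `evaluation`. The sketch below pairs it with `deceptive_evolvability_v0_task`, the preconfigured instance that the tests further down import from `qdax.tasks.qd_suite`; the batch size here is arbitrary.

```python
import jax

# Preconfigured task instance; this import path is the one used by the
# qd_suite tests added later in this diff.
from qdax.tasks.qd_suite import deceptive_evolvability_v0_task as task

random_key = jax.random.PRNGKey(0)

# The task provides its own initial genotypes, batched on the leading axis.
init_variables = task.get_initial_parameters(batch_size=8)

# scoring_function vmaps `evaluation` over the batch and returns the
# (fitnesses, descriptors, extra_scores, random_key) tuple expected
# by MAP-Elites.
fitnesses, descriptors, extra_scores, random_key = task.scoring_function(
    init_variables, random_key
)
```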
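Next, `SsfV0._get_k` from `ssf.py`: the `jax.lax.while_loop` walks a sequence of nested radii whose increments grow cubically with the level `k`. A plain-Python rendering of the same loop (an illustrative sketch, not library code) makes the band structure easier to follow.

```python
from typing import Tuple


def ssf_band_for_norm(norm: float) -> Tuple[float, int]:
    """Return (r_2k_plus_1, k) for the band containing `norm`.

    Mirrors SsfV0._cond_fun / _body_fun, with radius increments
    _get_r_add_odd(k) = 2k^3 + 1 and _get_r_add_even(k) = 2(k-1)^3 + 1.
    """
    k = 0
    r_2k = 0.0
    r_2k_plus_1 = r_2k + 2 * k**3 + 1  # _get_r_add_odd(0) == 1
    # Same stopping rule as _cond_fun: advance one level while the next
    # even increment still leaves the running radius below the norm.
    while r_2k_plus_1 + 2 * ((k + 1) - 1) ** 3 + 1 < norm:
        k += 1
        r_2k = r_2k_plus_1 + 2 * (k - 1) ** 3 + 1  # _get_r_add_even(k)
        r_2k_plus_1 = r_2k + 2 * k**3 + 1  # _get_r_add_odd(k)
    return r_2k_plus_1, k
```

As in `evaluation` above, a genotype whose norm is below its band radius `r_2k_plus_1` keeps the norm itself as descriptor; otherwise the descriptor saturates at `r_2k_plus_1`.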
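Last, `rastrigin_proj_scoring_function` from `standard_functions.py` is the one scoring variant that also returns gradients, exposed through `extra_scores` for gradient-aware emitters such as the MEGA variants. A short usage sketch, using only the signatures defined above:

```python
import jax

from qdax.tasks.standard_functions import rastrigin_proj_scoring_function

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
params = jax.random.uniform(subkey, shape=(4, 10))  # batch of 4 genotypes

fitnesses, descriptors, extra_scores, random_key = rastrigin_proj_scoring_function(
    params, random_key
)

# Per genotype: the fitness gradient and both descriptor gradients, stacked.
gradients = extra_scores["gradients"]
```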
@@ -661,16 +658,16 @@ def plot_multidimensional_map_elites_grid( size_grid_y = np.prod(np.array(grid_shape[1::2]), dtype=int) # initialise the grid - grid_2d = jnp.full( - (size_grid_x, size_grid_y), + grid_2d = np.full( + (size_grid_x.item(), size_grid_y.item()), fill_value=jnp.nan, ) # put solutions in the grid according to their projected 2-dimensional coordinates - for _, (desc, fit) in enumerate(zip(descriptors_integers, non_empty_fitnesses)): + for desc, fit in zip(descriptors_integers, non_empty_fitnesses): projection_2d = _get_projection_in_2d(desc, grid_shape) if jnp.isnan(grid_2d[projection_2d]) or fit.item() > grid_2d[projection_2d]: - grid_2d = grid_2d.at[projection_2d].set(fit.item()) + grid_2d[projection_2d] = fit.item() # set plot parameters font_size = 12 @@ -724,11 +721,9 @@ def _get_ticks_positions( ) -> jnp.ndarray: """ Get the positions of the ticks on the grid axis. - Args: total_size_grid_axis: total size of the grid axis step_ticks_on_axis: step of the ticks - Returns: The positions of the ticks on the plot. """ @@ -777,23 +772,44 @@ def _get_ticks_positions( ) ax.grid(which="minor", alpha=1.0, color="#000000", linewidth=0.5) - ax.grid(which="major", alpha=1.0, color="#000000", linewidth=2.5) - - ax.set_xticklabels( - [ - f"{x:.2}" - for x in jnp.around( - jnp.linspace(minval[0], maxval[0], num=len(major_ticks_x)), decimals=2 + if len(grid_shape) > 2: + ax.grid(which="major", alpha=1.0, color="#000000", linewidth=2.5) + + def _get_positions_labels( + _minval: float, _maxval: float, _number_ticks: int, _step_labels_ticks: int + ) -> List[str]: + positions = jnp.linspace(_minval, _maxval, num=_number_ticks) + + list_str_positions = [] + for index_tick, position in enumerate(positions): + if index_tick % _step_labels_ticks != 0: + character = "" + else: + character = f"{position:.2E}" + list_str_positions.append(character) + # forcing the last tick label + list_str_positions[-1] = f"{positions[-1]:.2E}" + return list_str_positions + + number_label_ticks = 4 + + if len(major_ticks_x) // number_label_ticks > 0: + ax.set_xticklabels( + _get_positions_labels( + _minval=minval[0], + _maxval=maxval[0], + _number_ticks=len(major_ticks_x), + _step_labels_ticks=len(major_ticks_x) // number_label_ticks, ) - ] - ) - ax.set_yticklabels( - [ - f"{y:.2}" - for y in jnp.around( - jnp.linspace(minval[1], maxval[1], num=len(major_ticks_y)), decimals=2 + ) + if len(major_ticks_y) // number_label_ticks > 0: + ax.set_yticklabels( + _get_positions_labels( + _minval=minval[1], + _maxval=maxval[1], + _number_ticks=len(major_ticks_y), + _step_labels_ticks=len(major_ticks_y) // number_label_ticks, ) - ] - ) + ) return fig, ax diff --git a/tests/baselines_test/pgame_test.py b/tests/baselines_test/pgame_test.py index 88791dda..17402847 100644 --- a/tests/baselines_test/pgame_test.py +++ b/tests/baselines_test/pgame_test.py @@ -14,8 +14,8 @@ from qdax.core.emitters.pga_me_emitter import PGAMEConfig, PGAMEEmitter from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.core.neuroevolution.mdp_utils import scoring_function from qdax.core.neuroevolution.networks.networks import MLP +from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.types import EnvState, Params, RNGKey @@ -151,7 +151,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict: # Prepare the scoring function bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function, + 
scoring_function_brax_envs, init_states=init_states, episode_length=episode_length, play_step_fn=play_step_fn, diff --git a/tests/baselines_test/sac_test.py b/tests/baselines_test/sac_test.py index 115d7d64..6d171460 100644 --- a/tests/baselines_test/sac_test.py +++ b/tests/baselines_test/sac_test.py @@ -3,7 +3,6 @@ from functools import partial from typing import Any, Tuple -import brax import jax import pytest @@ -11,6 +10,7 @@ from qdax.baselines.sac import SAC, SacConfig, TrainingState from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition from qdax.core.neuroevolution.sac_utils import do_iteration_fn, warmstart_buffer +from qdax.types import EnvState def test_sac() -> None: @@ -129,9 +129,9 @@ def test_sac() -> None: @jax.jit def _scan_do_iteration( - carry: Tuple[TrainingState, brax.envs.State, ReplayBuffer], + carry: Tuple[TrainingState, EnvState, ReplayBuffer], unused_arg: Any, - ) -> Tuple[Tuple[TrainingState, brax.envs.State, ReplayBuffer], Any]: + ) -> Tuple[Tuple[TrainingState, EnvState, ReplayBuffer], Any]: ( training_state, env_state, diff --git a/tests/core_test/map_elites_test.py b/tests/core_test/map_elites_test.py index 6b2de7bb..66748079 100644 --- a/tests/core_test/map_elites_test.py +++ b/tests/core_test/map_elites_test.py @@ -16,8 +16,8 @@ from qdax.core.emitters.standard_emitters import MixingEmitter from qdax.core.map_elites import MAPElites from qdax.core.neuroevolution.buffers.buffer import QDTransition -from qdax.core.neuroevolution.mdp_utils import scoring_function from qdax.core.neuroevolution.networks.networks import MLP +from qdax.tasks.brax_envs import scoring_function_brax_envs from qdax.types import EnvState, Params, RNGKey @@ -94,7 +94,7 @@ def play_step_fn( # Prepare the scoring function bd_extraction_fn = environments.behavior_descriptor_extractor[env_name] scoring_fn = functools.partial( - scoring_function, + scoring_function_brax_envs, init_states=init_states, episode_length=episode_length, play_step_fn=play_step_fn, diff --git a/tests/default_tasks_test/arm_test.py b/tests/default_tasks_test/arm_test.py new file mode 100644 index 00000000..e71e761c --- /dev/null +++ b/tests/default_tasks_test/arm_test.py @@ -0,0 +1,178 @@ +"""Test the arm and noisy-arm tasks using MAP-Elites""" + +import functools + +import jax +import jax.numpy as jnp +import pytest + +from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.map_elites import MAPElites +from qdax.tasks.arm import arm_scoring_function, noisy_arm_scoring_function +from qdax.utils.metrics import default_qd_metrics + +scoring_functions = { + "arm": functools.partial(arm_scoring_function), + "noisy_arm": functools.partial( + noisy_arm_scoring_function, + fit_variance=0.1, + desc_variance=0.1, + params_variance=0.05, + ), +} + + +@pytest.mark.parametrize( + "task_name, batch_size", + [("arm", 1), ("noisy_arm", 10)], +) +def test_arm(task_name: str, batch_size: int) -> None: + seed = 42 + num_param_dimensions = 100 # num DoF arm + init_batch_size = 100 + batch_size = batch_size + num_iterations = 5 + grid_shape = (100, 100) + min_param = 0.0 + max_param = 1.0 + min_bd = 0.0 + max_bd = 1.0 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + init_variables = jax.random.uniform( + subkey, + shape=(init_batch_size,
num_param_dimensions), + minval=min_param, + maxval=max_param, + ) + + # Prepare the scoring function + scoring_fn = scoring_functions[task_name] + + # Define emitter + variation_fn = functools.partial( + isoline_variation, + iso_sigma=0.05, + line_sigma=0.1, + minval=min_param, + maxval=max_param, + ) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Define a metrics function + metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=0.0, + ) + + # Instantiate MAP-Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + # Compute the centroids + centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + map_elites.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +def test_arm_scoring_function() -> None: + + # Init a random key + seed = 42 + random_key = jax.random.PRNGKey(seed) + + # arm has xy BD centered at 0.5 0.5 and min max range is [0,1] + # 0 params of first genotype is horizontal and points towards negative x axis + # angles move in anticlockwise direction + genotypes_1 = jnp.ones(shape=(1, 4)) * 0.5 # 0.5 + genotypes_2 = jnp.zeros( + shape=(1, 6) + ) # zeros - this folds upon itself (if even number ends up at origin) + genotypes_3 = jnp.ones( + shape=(1, 10) + ) # ones - this also folds upon itself (if even number ends up at origin) + genotypes_4 = jnp.array([[0, 0.5]]) + genotypes_5 = jnp.array([[0.25, 0.5]]) + genotypes_6 = jnp.array([[0.5, 0.5]]) + genotypes_7 = jnp.array([[0.75, 0.5]]) + + fitness_1, descriptors_1, _, random_key = arm_scoring_function( + genotypes_1, random_key + ) + fitness_2, descriptors_2, _, random_key = arm_scoring_function( + genotypes_2, random_key + ) + fitness_3, descriptors_3, _, random_key = arm_scoring_function( + genotypes_3, random_key + ) + fitness_4, descriptors_4, _, random_key = arm_scoring_function( + genotypes_4, random_key + ) + fitness_5, descriptors_5, _, random_key = arm_scoring_function( + genotypes_5, random_key + ) + fitness_6, descriptors_6, _, random_key = arm_scoring_function( + genotypes_6, random_key + ) + fitness_7, descriptors_7, _, random_key = arm_scoring_function( + genotypes_7, random_key + ) + + # use rounding to avoid some numerical floating point errors + pytest.assume( + jnp.array_equal(jnp.around(descriptors_1, decimals=1), jnp.array([[1.0, 0.5]])) + ) + pytest.assume( + jnp.array_equal(jnp.around(descriptors_2, decimals=1), jnp.array([[0.5, 0.5]])) + ) + pytest.assume( + jnp.array_equal(jnp.around(descriptors_3, decimals=1), jnp.array([[0.5, 0.5]])) + ) + pytest.assume( + jnp.array_equal(jnp.around(descriptors_4, decimals=1), jnp.array([[0.0, 0.5]])) + ) + pytest.assume( + jnp.array_equal(jnp.around(descriptors_5, decimals=1), jnp.array([[0.5, 0.0]])) + ) + pytest.assume( + jnp.array_equal(jnp.around(descriptors_6, decimals=1), jnp.array([[1.0, 0.5]])) + ) + pytest.assume( + jnp.array_equal(jnp.around(descriptors_7, decimals=1), jnp.array([[0.5, 1.0]])) + ) + + +if __name__ == "__main__": + test_arm(task_name="arm", batch_size=128) + test_arm_scoring_function() diff --git 
a/tests/default_tasks_test/brax_task_test.py b/tests/default_tasks_test/brax_task_test.py new file mode 100644 index 00000000..992efaeb --- /dev/null +++ b/tests/default_tasks_test/brax_task_test.py @@ -0,0 +1,99 @@ +"""Tests MAP Elites implementation""" + +import functools + +import jax +import pytest + +import qdax.environments +from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.map_elites import MAPElites +from qdax.core.neuroevolution.mdp_utils import init_population_controllers +from qdax.tasks.brax_envs import create_default_brax_task_components +from qdax.utils.metrics import default_qd_metrics + + +@pytest.mark.parametrize( + "env_name, batch_size, is_task_reset_based", + [ + ("walker2d_uni", 5, False), + ("walker2d_uni", 5, True), + ], +) +def test_map_elites(env_name: str, batch_size: int, is_task_reset_based: bool) -> None: + batch_size = batch_size + env_name = env_name + episode_length = 100 + num_iterations = 5 + seed = 42 + num_init_cvt_samples = 1000 + num_centroids = 50 + min_bd = 0.0 + max_bd = 1.0 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + env, policy_network, scoring_fn, random_key = create_default_brax_task_components( + env_name=env_name, + batch_size=batch_size, + random_key=random_key, + ) + + # Define emitter + variation_fn = functools.partial(isoline_variation, iso_sigma=0.05, line_sigma=0.1) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Define metrics function + metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=qdax.environments.reward_offset[env_name] * episode_length, + ) + + # Instantiate MAP-Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + # Compute the centroids + centroids, random_key = compute_cvt_centroids( + num_descriptors=env.behavior_descriptor_length, + num_init_cvt_samples=num_init_cvt_samples, + num_centroids=num_centroids, + minval=min_bd, + maxval=max_bd, + random_key=random_key, + ) + + # Init population of controllers + init_variables, random_key = init_population_controllers( + policy_network, env, batch_size, random_key + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + map_elites.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_map_elites(env_name="walker2d_uni", batch_size=10, is_task_reset_based=False) diff --git a/tests/default_tasks_test/hypervolume_functions_test.py b/tests/default_tasks_test/hypervolume_functions_test.py new file mode 100644 index 00000000..a390f709 --- /dev/null +++ b/tests/default_tasks_test/hypervolume_functions_test.py @@ -0,0 +1,116 @@ +"""Test the hypervolume benchmark functions using MAP-Elites""" + +import functools + +import jax +import pytest + +from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.map_elites import MAPElites +from
qdax.tasks.hypervolume_functions import ( + checkered_scoring_function, + continous_islands_scoring_function, + empty_circle_scoring_function, + non_continous_islands_scoring_function, + square_scoring_function, +) +from qdax.utils.metrics import default_qd_metrics + +scoring_functions = { + "square": square_scoring_function, + "checkered": checkered_scoring_function, + "empty_circle": empty_circle_scoring_function, + "non_continous_islands": non_continous_islands_scoring_function, + "continous_islands": continous_islands_scoring_function, +} + + +@pytest.mark.parametrize( + "task_name, batch_size", + [ + ("square", 1), + ("checkered", 10), + ("empty_circle", 20), + ("non_continous_islands", 30), + ("continous_islands", 40), + ], +) +def test_standard_functions(task_name: str, batch_size: int) -> None: + seed = 42 + num_param_dimensions = 2 + init_batch_size = 100 + batch_size = batch_size + num_iterations = 5 + grid_shape = (100, 100) + min_param = 0.0 + max_param = 1.0 + min_bd = 0.0 + max_bd = 1.0 + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + init_variables = jax.random.uniform( + subkey, shape=(init_batch_size, num_param_dimensions) + ) + + # Prepare the scoring function + scoring_fn = scoring_functions[task_name] + + # Define emitter + variation_fn = functools.partial( + isoline_variation, + iso_sigma=0.05, + line_sigma=0.1, + minval=min_param, + maxval=max_param, + ) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Define a metrics function + metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=0.0, + ) + + # Instantiate MAP-Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + # Compute the centroids + centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + map_elites.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_standard_functions(task_name="square", batch_size=128) diff --git a/tests/default_tasks_test/qd_suite_test.py b/tests/default_tasks_test/qd_suite_test.py new file mode 100644 index 00000000..a0542e9b --- /dev/null +++ b/tests/default_tasks_test/qd_suite_test.py @@ -0,0 +1,131 @@ +"""Test qd suite tasks using MAP Elites""" + +import functools +import math +from typing import Tuple + +import jax +import pytest + +from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter +from qdax.core.map_elites import MAPElites +from qdax.tasks.qd_suite import ( + archimedean_spiral_v0_angle_euclidean_task, + archimedean_spiral_v0_angle_geodesic_task, + archimedean_spiral_v0_arc_length_euclidean_task, + archimedean_spiral_v0_arc_length_geodesic_task, + deceptive_evolvability_v0_task, + ssf_v0_param_size_1_task, + ssf_v0_param_size_2_task, +) +from qdax.utils.metrics import default_qd_metrics + +task_dict = { +
"archimedean_spiral_v0_angle_euclidean": archimedean_spiral_v0_angle_euclidean_task, + "archimedean_spiral_v0_angle_geodesic": archimedean_spiral_v0_angle_geodesic_task, + "archimedean_spiral_v0_arc_length_euclidean": archimedean_spiral_v0_arc_length_euclidean_task, # noqa: E501 + "archimedean_spiral_v0_arc_length_geodesic": archimedean_spiral_v0_arc_length_geodesic_task, # noqa: E501 + "deceptive_evolvability_v0": deceptive_evolvability_v0_task, + "ssf_v0_param_size_1": ssf_v0_param_size_1_task, + "ssf_v0_param_size_2": ssf_v0_param_size_2_task, +} + + +@pytest.mark.parametrize( + "task_name, batch_size", + [ + ("archimedean_spiral_v0_angle_euclidean", 1), + ("archimedean_spiral_v0_angle_geodesic", 10), + ("archimedean_spiral_v0_arc_length_euclidean", 128), + ("archimedean_spiral_v0_arc_length_geodesic", 30), + ("deceptive_evolvability_v0", 64), + ("ssf_v0_param_size_1", 256), + ("ssf_v0_param_size_2", 1), + ], +) +def test_qd_suite(task_name: str, batch_size: int) -> None: + seed = 42 + + # get task from parameterization for test + task = task_dict[task_name] + + init_batch_size = 100 + batch_size = batch_size + num_iterations = 5 + min_param, max_param = task.get_min_max_params() + min_bd, max_bd = task.get_bounded_min_max_descriptor() + bd_size = task.get_descriptor_size() + + grid_shape: Tuple[int, ...] + if bd_size == 1: + grid_shape = (100,) + elif bd_size == 2: + grid_shape = (100, 100) + else: + resolution_per_axis = math.floor(math.pow(10000.0, 1.0 / bd_size)) + grid_shape = tuple([resolution_per_axis for _ in range(bd_size)]) + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init population of parameters + init_variables = task.get_initial_parameters(init_batch_size) + + # Define scoring function + scoring_fn = task.scoring_function + + # Define emitter + variation_fn = functools.partial( + isoline_variation, + iso_sigma=0.05, + line_sigma=0.1, + minval=min_param, + maxval=max_param, + ) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Define a metrics function + metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=0.0, + ) + + # Instantiate MAP-Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + # Compute the centroids + centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + map_elites.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_qd_suite(task_name="archimedean_spiral_v0_angle_geodesic", batch_size=128) diff --git a/tests/default_tasks_test/standard_functions_test.py b/tests/default_tasks_test/standard_functions_test.py new file mode 100644 index 00000000..7b310389 --- /dev/null +++ b/tests/default_tasks_test/standard_functions_test.py @@ -0,0 +1,106 @@ +"""Test default rastrigin using MAP Elites""" + +import functools + +import jax +import pytest + +from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids +from qdax.core.emitters.mutation_operators import isoline_variation +from qdax.core.emitters.standard_emitters import MixingEmitter 
+from qdax.core.map_elites import MAPElites +from qdax.tasks.standard_functions import ( + rastrigin_proj_scoring_function, + rastrigin_scoring_function, + sphere_scoring_function, +) +from qdax.utils.metrics import default_qd_metrics + +scoring_functions = { + "rastrigin": functools.partial(rastrigin_scoring_function), + "sphere": functools.partial(sphere_scoring_function), + "rastrigin_proj": functools.partial(rastrigin_proj_scoring_function), +} + + +@pytest.mark.parametrize( + "task_name, batch_size", + [("rastrigin", 1), ("sphere", 10), ("rastrigin_proj", 20)], +) +def test_standard_functions(task_name: str, batch_size: int) -> None: + seed = 42 + num_param_dimensions = 100 + init_batch_size = 100 + batch_size = batch_size + num_iterations = 5 + grid_shape = (100, 100) + min_param = 0.0 + max_param = 1.0 + min_bd = min_param + max_bd = max_param + + # Init a random key + random_key = jax.random.PRNGKey(seed) + + # Init population of controllers + random_key, subkey = jax.random.split(random_key) + init_variables = jax.random.uniform( + subkey, shape=(init_batch_size, num_param_dimensions) + ) + + # Prepare the scoring function + scoring_fn = scoring_functions[task_name] + + # Define emitter + variation_fn = functools.partial( + isoline_variation, + iso_sigma=0.05, + line_sigma=0.1, + minval=min_param, + maxval=max_param, + ) + mixing_emitter = MixingEmitter( + mutation_fn=lambda x, y: (x, y), + variation_fn=variation_fn, + variation_percentage=1.0, + batch_size=batch_size, + ) + + # Define a metrics function + metrics_fn = functools.partial( + default_qd_metrics, + qd_offset=0.0, + ) + + # Instantiate MAP-Elites + map_elites = MAPElites( + scoring_function=scoring_fn, + emitter=mixing_emitter, + metrics_function=metrics_fn, + ) + + # Compute the centroids + centroids = compute_euclidean_centroids( + grid_shape=grid_shape, + minval=min_bd, + maxval=max_bd, + ) + + # Compute initial repertoire + repertoire, emitter_state, random_key = map_elites.init( + init_variables, centroids, random_key + ) + + # Run the algorithm + (repertoire, emitter_state, random_key,), metrics = jax.lax.scan( + map_elites.scan_update, + (repertoire, emitter_state, random_key), + (), + length=num_iterations, + ) + + pytest.assume(repertoire is not None) + + +if __name__ == "__main__": + test_standard_functions(task_name="rastrigin", batch_size=128) diff --git a/tests/environments_test/pointmaze_test.py b/tests/environments_test/pointmaze_test.py index a09d6177..76c96451 100644 --- a/tests/environments_test/pointmaze_test.py +++ b/tests/environments_test/pointmaze_test.py @@ -1,12 +1,14 @@ from typing import Any, Tuple import brax +import brax.envs import jax import pytest from brax import jumpy as jp import qdax from qdax.environments.pointmaze import PointMaze +from qdax.types import EnvState def test_pointmaze() -> None: @@ -61,8 +63,8 @@ def test_pointmaze() -> None: state = qd_env.reset(rng=jp.random_prngkey(seed=0)) @jax.jit - def run_n_steps(state: brax.envs.State) -> brax.envs.State: - def run_step(carry: brax.envs.State, _: Any) -> Tuple[brax.envs.State, Any]: + def run_n_steps(state: EnvState) -> EnvState: + def run_step(carry: Tuple[EnvState], _: Any) -> Tuple[Tuple[EnvState], Any]: (state,) = carry action = jp.zeros((qd_env.action_size,)) state = qd_env.step(state, action)
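The `run_n_steps` change in `pointmaze_test.py` above boils down to `jax.lax.scan` with a 1-tuple carry, unpacked and repacked at every step. A minimal self-contained sketch of that pattern (the state and step logic here are dummies):

```python
import jax
import jax.numpy as jnp


def run_step(carry, _):
    # Unpack the 1-tuple carry, advance the dummy state, repack.
    (state,) = carry
    state = state + 1.0
    return (state,), None


(final_state,), _ = jax.lax.scan(run_step, (jnp.zeros(()),), None, length=100)
```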
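Finally, a note on the grid sizing rule used in `qd_suite_test.py` above: the per-axis resolution is chosen so that the total number of cells stays near 10,000 whatever the descriptor dimensionality. A standalone sketch of that rule (the helper name is hypothetical):

```python
import math
from typing import Tuple


def grid_shape_for(bd_size: int) -> Tuple[int, ...]:
    """Grid shape with roughly 10,000 cells in total."""
    if bd_size == 1:
        return (100,)
    if bd_size == 2:
        return (100, 100)
    resolution_per_axis = math.floor(math.pow(10000.0, 1.0 / bd_size))
    return tuple(resolution_per_axis for _ in range(bd_size))


assert grid_shape_for(2) == (100, 100)
assert grid_shape_for(3) == (21, 21, 21)  # 21 ** 3 == 9261 cells
```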