Skip to content

Commit

Permalink
feat(algorithms): add CMA-ME, fix CMA-ES and CMA-MEGA (#86)
Browse files Browse the repository at this point in the history
Add emitters introduced in CMA-ME, add tests, notebook, update the documentation, fix CMA-ES and fix CMA-MEGA
  • Loading branch information
felixchalumeau authored Nov 24, 2022
1 parent d7c6dc7 commit a5c19a2
Show file tree
Hide file tree
Showing 29 changed files with 2,089 additions and 128 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:
rev: 0.3.9
hooks:
- id: nbstripout
args: ["notebooks/"]
args: ["examples/"]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
Expand Down
29 changes: 15 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

QDax is a tool to accelerate Quality-Diversity (QD) and neuro-evolution algorithms through hardware accelerators and massive parallelization. QD algorithms usually take days/weeks to run on large CPU clusters. With QDax, QD algorithms can now be run in minutes! ⏩ ⏩ 🕛

QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb)
QDax has been developed as a research framework: it is flexible and easy to extend and build on and can be used for any problem setting. Get started with simple example and run a QD algorithm in minutes here! [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb)

- QDax [paper](https://arxiv.org/abs/2202.01258)
- QDax [documentation](https://qdax.readthedocs.io/en/latest/)
Expand All @@ -32,7 +32,7 @@ Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To
However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).

## Basic API Usage
For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).
For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/mapelites.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).

However, a summary of the main API usage is provided below:
```python
Expand Down Expand Up @@ -124,25 +124,26 @@ QDax currently supports the following algorithms:

| Algorithm | Example |
| --- | --- |
| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) |
| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mapelites_example.ipynb) |
| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/pgame_example.ipynb) |
| [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/omgmega_example.ipynb) |
| [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/cmamega_example.ipynb) |
| [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mome_example.ipynb) |
| [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/mees_example.ipynb) |
| [MAP-Elites](https://arxiv.org/abs/1504.04909) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) |
| [CVT MAP-Elites](https://arxiv.org/abs/1610.05729) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mapelites.ipynb) |
| [Policy Gradient Assisted MAP-Elites (PGA-ME)](https://hal.archives-ouvertes.fr/hal-03135723v2/file/PGA_MAP_Elites_GECCO.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/pgame.ipynb) |
| [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmame.ipynb) |
| [OMG-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/omgmega.ipynb) |
| [CMA-MEGA](https://arxiv.org/abs/2106.03894) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/cmamega.ipynb) |
| [Multi-Objective MAP-Elites (MOME)](https://arxiv.org/abs/2202.03057) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mome.ipynb) |
| [MAP-Elites Evolution Strategies (MEES)](https://dl.acm.org/doi/pdf/10.1145/3377930.3390217) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/mees.ipynb) |


## QDax baseline algorithms
The QDax library also provides implementations for some useful baseline algorithms:

| Algorithm | Example |
| --- | --- |
| [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/diayn_example.ipynb) |
| [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/dads_example.ipynb) |
| [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/smerl_example.ipynb) |
| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
| [DIAYN](https://arxiv.org/abs/1802.06070) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/diayn.ipynb) |
| [DADS](https://arxiv.org/abs/1907.01657) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/dads.ipynb) |
| [SMERL](https://arxiv.org/abs/2010.14484) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/smerl.ipynb) |
| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb) |
| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/examples/nsga2_spea2.ipynb) |

## QDax Tasks
The QDax library also provides numerous implementations for several standard Quality-Diversity tasks.
Expand Down
13 changes: 13 additions & 0 deletions docs/api_documentation/core/cmame.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Covariance Matrix Adaptation MAP Elites (CMAME)

To create an instance of CMAME, one need to use an instance of [MAP-Elites](map_elites.md) with the desired CMA Emitter - optimizing, random direction, improvement - detailed below.To use the pool of emitter mechanism, use the CMAPoolEmitter.

Three emitter types:

::: qdax.core.emitters.cma_emitter.CMAEmitter
::: qdax.core.emitters.cma_rnd_emitter.CMARndEmitter
::: qdax.core.emitters.cma_opt_emitter.CMAOptimizingEmitter

Pool of homogeneous emitters:

::: qdax.core.emitters.cma_pool_emitter.CMAPoolEmitter
1 change: 1 addition & 0 deletions docs/examples
1 change: 0 additions & 1 deletion docs/notebooks

This file was deleted.

308 changes: 308 additions & 0 deletions examples/cmaes.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "222bbe00",
"metadata": {},
"source": [
"# Optimizing with CMA-ES in Jax\n",
"\n",
"This notebook shows how to use QDax to find performing parameters on Rastrigin and Sphere problems with [CMA-ES](https://arxiv.org/pdf/1604.00772.pdf). It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n",
"\n",
"- how to define the problem\n",
"- how to create a CMA-ES optimizer\n",
"- how to launch a certain number of optimizing steps\n",
"- how to visualise the optimization process"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d731f067",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Ellipse\n",
"\n",
"from qdax.core.cmaes import CMAES"
]
},
{
"cell_type": "markdown",
"id": "7b6e910b",
"metadata": {},
"source": [
"## Set the hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "404fb0dc",
"metadata": {},
"outputs": [],
"source": [
"#@title Hyperparameters\n",
"#@markdown ---\n",
"num_iterations = 1000 #@param {type:\"integer\"}\n",
"num_dimensions = 100 #@param {type:\"integer\"}\n",
"batch_size = 36 #@param {type:\"integer\"}\n",
"num_best = 18 #@param {type:\"integer\"}\n",
"sigma_g = 0.5 # 0.5 #@param {type:\"number\"}\n",
"minval = -5.12 #@param {type:\"number\"}\n",
"optim_problem = \"sphere\" #@param[\"rastrigin\", \"sphere\"]\n",
"#@markdown ---"
]
},
{
"cell_type": "markdown",
"id": "ccc7cbeb",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Define the fitness function - choose rastrigin or sphere"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "436dccbb",
"metadata": {},
"outputs": [],
"source": [
"def rastrigin_scoring(x: jnp.ndarray):\n",
" first_term = 10 * x.shape[-1]\n",
" second_term = jnp.sum((x + minval * 0.4) ** 2 - 10 * jnp.cos(2 * jnp.pi * (x + minval * 0.4)))\n",
" return -(first_term + second_term)\n",
"\n",
"def sphere_scoring(x: jnp.ndarray):\n",
" return -jnp.sum((x + minval * 0.4) * (x + minval * 0.4), axis=-1)\n",
"\n",
"if optim_problem == \"sphere\":\n",
" fitness_fn = sphere_scoring\n",
"elif optim_problem == \"rastrigin\":\n",
" fitness_fn = jax.vmap(rastrigin_scoring)\n",
"else:\n",
" raise Exception(\"Invalid opt function name given\")"
]
},
{
"cell_type": "markdown",
"id": "62bdd2a4",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Define a CMA-ES optimizer instance"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cf03f55",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"cmaes = CMAES(\n",
" population_size=batch_size,\n",
" num_best=num_best,\n",
" search_dim=num_dimensions,\n",
" fitness_function=fitness_fn,\n",
" mean_init=jnp.zeros((num_dimensions,)),\n",
" init_sigma=sigma_g,\n",
" delay_eigen_decomposition=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f1f69f50",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Init the CMA-ES optimizer state"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a95b74d",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"state = cmaes.init()\n",
"random_key = jax.random.PRNGKey(0)"
]
},
{
"cell_type": "markdown",
"id": "ac2d5c0d",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Run optimization iterations"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "363198ca",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"%%time\n",
"\n",
"means = [state.mean]\n",
"covs = [(state.sigma**2) * state.cov_matrix]\n",
"\n",
"iteration_count = 0\n",
"for _ in range(num_iterations):\n",
" iteration_count += 1\n",
" \n",
" # sample\n",
" samples, random_key = cmaes.sample(state, random_key)\n",
" \n",
" # udpate\n",
" state = cmaes.update(state, samples)\n",
" \n",
" # check stop condition\n",
" stop_condition = cmaes.stop_condition(state)\n",
"\n",
" if stop_condition:\n",
" break\n",
" \n",
" # store data for plotting\n",
" means.append(state.mean)\n",
" covs.append((state.sigma**2) * state.cov_matrix)\n",
" \n",
"print(\"Num iterations before stop condition: \", iteration_count)"
]
},
{
"cell_type": "markdown",
"id": "0e5820b8",
"metadata": {},
"source": [
"## Check final fitnesses and distribution mean"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e4a2c7b",
"metadata": {},
"outputs": [],
"source": [
"# checking final fitness values\n",
"fitnesses = fitness_fn(samples)\n",
"\n",
"print(\"Min fitness in the final population: \", jnp.min(fitnesses))\n",
"print(\"Mean fitness in the final population: \", jnp.mean(fitnesses))\n",
"print(\"Max fitness in the final population: \", jnp.max(fitnesses))\n",
"\n",
"# checking mean of the final distribution\n",
"print(\"Final mean of the distribution: \\n\", means[-1])\n",
"# print(\"Final covariance matrix of the distribution: \", covs[-1])"
]
},
{
"cell_type": "markdown",
"id": "f3bd2b0f",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## Visualization of the optimization trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad85551c",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"fig, ax = plt.subplots(figsize=(12, 6))\n",
"\n",
"# sample points to show fitness landscape\n",
"random_key, subkey = jax.random.split(random_key)\n",
"x = jax.random.uniform(subkey, minval=-4, maxval=8, shape=(100000, 2))\n",
"f_x = fitness_fn(x)\n",
"\n",
"# plot fitness landscape\n",
"points = ax.scatter(x[:, 0], x[:, 1], c=f_x, s=0.1)\n",
"fig.colorbar(points)\n",
"\n",
"# plot cma-es trajectory\n",
"traj_min = 0\n",
"traj_max = iteration_count\n",
"for mean, cov in zip(means[traj_min:traj_max], covs[traj_min:traj_max]):\n",
" ellipse = Ellipse((mean[0], mean[1]), cov[0, 0], cov[1, 1], fill=False, color='k', ls='--')\n",
" ax.add_patch(ellipse)\n",
" ax.plot(mean[0], mean[1], color='k', marker='x')\n",
" \n",
"ax.set_title(f\"Optimization trajectory of CMA-ES between step {traj_min} and step {traj_max}\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "9ae46cf6a59eb5e192bc4f27fbb5c33d8a30eb9acb43edbb510eeaf7c819ab64"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit a5c19a2

Please sign in to comment.