Commit
feat: Default Scoring Functions for Sphere, Rastrigin, Arm, Brax environments, Hypervolume functions and QD Suite (#73)

*  adding sphere, rastrigin, arm, and noisy_arm scoring functions

*  adding the possibility to create a default scoring_function for Brax environments.

*  changing type of EnvState to brax.envs.State

*  fix some typing inconsistencies throughout the code.

*  update README to include the default functions and make it a usable example

*  create an examples directory for scripts and notebooks
Lookatator authored Oct 13, 2022
1 parent ab4d4ca commit 13272b0
Showing 34 changed files with 2,385 additions and 167 deletions.
70 changes: 66 additions & 4 deletions README.md
@@ -27,18 +27,75 @@ Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To
However, we also provide, and recommend using, either Docker, Singularity or conda environments to use the repository, which provide GPU support by default. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).

## Basic API Usage
-For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).
+For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./examples/notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).

However, a summary of the main API usage is provided below:
```python
import qdax
import jax
import functools
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.tasks.arm import arm_scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.metrics import default_qd_metrics

seed = 42
num_param_dimensions = 100 # num DoF arm
init_batch_size = 100
batch_size = 1024
num_iterations = 50
grid_shape = (100, 100)
min_param = 0.0
max_param = 1.0
min_bd = 0.0
max_bd = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
init_variables = jax.random.uniform(
    subkey,
    shape=(init_batch_size, num_param_dimensions),
    minval=min_param,
    maxval=max_param,
)

# Define emitter
variation_fn = functools.partial(
    isoline_variation,
    iso_sigma=0.05,
    line_sigma=0.1,
    minval=min_param,
    maxval=max_param,
)
mixing_emitter = MixingEmitter(
    mutation_fn=lambda x, y: (x, y),
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size,
)

# Define a metrics function
metrics_fn = functools.partial(
    default_qd_metrics,
    qd_offset=0.0,
)

# Instantiate MAP-Elites
map_elites = MAPElites(
-    scoring_function=scoring_fn,
+    scoring_function=arm_scoring_function,
    emitter=mixing_emitter,
-    metrics_function=metrics_function,
+    metrics_function=metrics_fn,
)

# Compute the centroids
centroids = compute_euclidean_centroids(
    grid_shape=grid_shape,
    minval=min_bd,
    maxval=max_bd,
)

# Initializes repertoire and emitter state
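# The repertoire/emitter initialization mirrors examples/scripts/me_example.py
# shown later in this commit:
repertoire, emitter_state, random_key = map_elites.init(
    init_variables, centroids, random_key
)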
```
@@ -81,6 +138,11 @@ The QDax library also provides implementations for some useful baseline algorithms
| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |

## QDax Tasks
The QDax library also provides implementations of several standard Quality-Diversity tasks.

All these implementations, along with their descriptions, are provided in the [tasks directory](./qdax/tasks).
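
As a quick illustration, the default scoring functions added by this commit can be called directly on a batch of genotypes. Below is a minimal sketch, assuming the `(genotypes, random_key)` scoring-function signature used in the API example above; the batch size, arm dimensionality, and the 2-D descriptor shape in the comment are illustrative assumptions:

```python
import jax

from qdax.tasks.arm import arm_scoring_function

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)

# A batch of 10 genotypes for an 8-DoF arm, with parameters in [0, 1]
genotypes = jax.random.uniform(subkey, shape=(10, 8), minval=0.0, maxval=1.0)

# Score the whole batch at once: fitnesses, behavior descriptors and extra
# scores are returned together with the updated random key
fitnesses, descriptors, extra_scores, random_key = arm_scoring_function(
    genotypes, random_key
)
print(fitnesses.shape, descriptors.shape)  # expected: (10,) and (10, 2)
```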

## Contributing
Issues and contributions are welcome. Please refer to the [contribution guide](https://qdax.readthedocs.io/en/latest/guides/CONTRIBUTING/) in the documentation for more details.

File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -62,7 +62,7 @@
"from qdax.core.map_elites import MAPElites\n",
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n",
"from qdax import environments\n",
"from qdax.core.neuroevolution.mdp_utils import scoring_function\n",
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
"from qdax.core.neuroevolution.networks.networks import MLP\n",
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -61,7 +61,7 @@
"from qdax.core.map_elites import MAPElites\n",
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
"from qdax import environments\n",
"from qdax.core.neuroevolution.mdp_utils import scoring_function\n",
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
"from qdax.core.neuroevolution.networks.networks import MLP\n",
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
File renamed without changes.
103 changes: 103 additions & 0 deletions examples/scripts/me_example.py
@@ -0,0 +1,103 @@
import functools

import jax
import matplotlib.pyplot as plt

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.core.map_elites import MAPElites
from qdax.tasks.arm import arm_scoring_function
from qdax.utils.metrics import default_qd_metrics
from qdax.utils.plotting import plot_2d_map_elites_repertoire


def run_me() -> None:
    seed = 42
    num_param_dimensions = 8  # num DoF arm
    init_batch_size = 100
    batch_size = 2048
    num_evaluations = int(1e6)
    num_iterations = num_evaluations // batch_size
    grid_shape = (100, 100)
    min_param = 0.0
    max_param = 1.0
    min_bd = 0.0
    max_bd = 1.0

    # Init a random key
    random_key = jax.random.PRNGKey(seed)

    # Init population of controllers
    random_key, subkey = jax.random.split(random_key)
    init_variables = jax.random.uniform(
        subkey,
        shape=(init_batch_size, num_param_dimensions),
        minval=min_param,
        maxval=max_param,
    )

    # Define emitter
    variation_fn = functools.partial(
        isoline_variation,
        iso_sigma=0.005,
        line_sigma=0,
        minval=min_param,
        maxval=max_param,
    )
    mixing_emitter = MixingEmitter(
        mutation_fn=lambda x, y: (x, y),
        variation_fn=variation_fn,
        variation_percentage=1.0,
        batch_size=batch_size,
    )

    # Define a metrics function
    metrics_fn = functools.partial(
        default_qd_metrics,
        qd_offset=0.0,
    )

    # Instantiate MAP-Elites
    map_elites = MAPElites(
        scoring_function=arm_scoring_function,
        emitter=mixing_emitter,
        metrics_function=metrics_fn,
    )

    # Compute the centroids
    centroids = compute_euclidean_centroids(
        grid_shape=grid_shape,
        minval=min_bd,
        maxval=max_bd,
    )

    # Initializes repertoire and emitter state
    repertoire, emitter_state, random_key = map_elites.init(
        init_variables, centroids, random_key
    )

    # Run MAP-Elites loop
    for _ in range(num_iterations):
        (repertoire, emitter_state, metrics, random_key,) = map_elites.update(
            repertoire,
            emitter_state,
            random_key,
        )

    # plot archive
    fig, axes = plot_2d_map_elites_repertoire(
        centroids=repertoire.centroids,
        repertoire_fitnesses=repertoire.fitnesses,
        minval=min_bd,
        maxval=max_bd,
        repertoire_descriptors=repertoire.descriptors,
        # vmin=-0.2,
        # vmax=0.0,
    )

    plt.show()


if __name__ == "__main__":
    run_me()
159 changes: 33 additions & 126 deletions qdax/core/neuroevolution/mdp_utils.py
@@ -1,26 +1,15 @@
from functools import partial
from typing import Any, Callable, Tuple

import brax
import brax.envs
import flax.linen as nn
import jax
import jax.numpy as jnp
from brax.envs import State as EnvState
from flax.struct import PyTreeNode

-from qdax.core.neuroevolution.buffers.buffer import (
-    QDTransition,
-    ReplayBuffer,
-    Transition,
-)
-from qdax.types import (
-    Descriptor,
-    ExtraScores,
-    Fitness,
-    Genotype,
-    Metrics,
-    Params,
-    RNGKey,
-)
+from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition
+from qdax.types import Genotype, Metrics, Params, RNGKey


class TrainingState(PyTreeNode):
@@ -125,116 +114,6 @@ def _scan_play_step_fn(
return state, transitions


@partial(
    jax.jit,
    static_argnames=(
        "episode_length",
        "play_step_fn",
        "behavior_descriptor_extractor",
    ),
)
def scoring_function(
    policies_params: Genotype,
    random_key: RNGKey,
    init_states: brax.envs.State,
    episode_length: int,
    play_step_fn: Callable[
        [EnvState, Params, RNGKey, brax.envs.Env],
        Tuple[EnvState, Params, RNGKey, QDTransition],
    ],
    behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor],
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Evaluates policies contained in policies_params in parallel in
    deterministic or pseudo-deterministic environments.

    This rollout is only deterministic when all the init states are the same.
    If the init states are fixed but different, as a policy is not necessarily
    evaluated with the same environment every time, this won't be deterministic.
    When the init states are different, this is not purely stochastic.
    """

    # Perform rollouts with each policy
    random_key, subkey = jax.random.split(random_key)
    unroll_fn = partial(
        generate_unroll,
        episode_length=episode_length,
        play_step_fn=play_step_fn,
        random_key=subkey,
    )

    _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params)

    # create a mask to extract data properly
    is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
    mask = jnp.roll(is_done, 1, axis=1)
    mask = mask.at[:, 0].set(0)

    # Scores - add offset to ensure positive fitness (through positive rewards)
    fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)
    descriptors = behavior_descriptor_extractor(data, mask)

    return (
        fitnesses,
        descriptors,
        {
            "transitions": data,
        },
        random_key,
    )


@partial(
    jax.jit,
    static_argnames=(
        "episode_length",
        "play_reset_fn",
        "play_step_fn",
        "behavior_descriptor_extractor",
    ),
)
def reset_based_scoring_function(
    policies_params: Genotype,
    random_key: RNGKey,
    episode_length: int,
    play_reset_fn: Callable[[RNGKey], brax.envs.State],
    play_step_fn: Callable[
        [brax.envs.State, Params, RNGKey, brax.envs.Env],
        Tuple[brax.envs.State, Params, RNGKey, QDTransition],
    ],
    behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor],
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Evaluates policies contained in policies_params in parallel.

    The play_reset_fn function allows for a more general scoring_function that
    can be called with any batch size, not only a batch size matching the number
    of init_states.
    To define purely stochastic environments, use the reset function of the
    environment: "play_reset_fn = env.reset".
    To define purely deterministic environments, as in "scoring_function",
    generate a single init_state with "init_state = env.reset(random_key)", then
    use "play_reset_fn = lambda random_key: init_state".
    """

    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(
        subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0]
    )
    reset_fn = jax.vmap(play_reset_fn)
    init_states = reset_fn(keys)

    fitnesses, descriptors, extra_scores, random_key = scoring_function(
        policies_params=policies_params,
        random_key=random_key,
        init_states=init_states,
        episode_length=episode_length,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=behavior_descriptor_extractor,
    )

    return (fitnesses, descriptors, extra_scores, random_key)


@partial(
    jax.jit,
    static_argnames=(
@@ -316,4 +195,32 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray:
        # the double transpose trick is here to allow easy broadcasting
        return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T

-    return jax.tree_util.tree_map(mask_episodes, transition)  # type: ignore
+    return jax.tree_map(mask_episodes, transition)  # type: ignore
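
# A shape sketch of the double-transpose trick above (hypothetical shapes):
# with mask: (batch, time) and x: (batch, time, obs_dim), mask.T has shape
# (time, batch) and x.T has shape (obs_dim, time, batch); these broadcast
# against each other from the trailing axes, so no explicit mask[..., None]
# expansion is needed, and the final .T restores (batch, time, obs_dim).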


def init_population_controllers(
    policy_network: nn.Module,
    env: brax.envs.Env,
    batch_size: int,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """Initializes the population of controllers using a policy_network.

    Args:
        policy_network: The policy network structure used for creating policy
            controllers.
        env: the BRAX environment.
        batch_size: the number of environments we play simultaneously.
        random_key: a JAX random key.

    Returns:
        A tuple of the initial population and the new random key.
    """
    random_key, subkey = jax.random.split(random_key)

    keys = jax.random.split(subkey, num=batch_size)
    fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
    init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

    return init_variables, random_key
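
For context, here is a minimal usage sketch of this new helper. The environment name, episode length, and MLP layer sizes below are illustrative assumptions rather than values taken from this commit:

```python
import jax

from qdax import environments
from qdax.core.neuroevolution.mdp_utils import init_population_controllers
from qdax.core.neuroevolution.networks.networks import MLP

# Create a Brax environment and a policy network whose output matches
# the environment's action dimension
env = environments.create("walker2d_uni", episode_length=100)
policy_network = MLP(layer_sizes=(64, 64, env.action_size))

# Initialize a population of 100 controllers
random_key = jax.random.PRNGKey(0)
init_variables, random_key = init_population_controllers(
    policy_network, env, batch_size=100, random_key=random_key
)
```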
