Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DCG-MAP-Elites #167

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/distributed_mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@
"repertoire, emitter_state, random_key = map_elites.get_distributed_init_fn(\n",
" centroids=centroids,\n",
" devices=devices,\n",
")(init_genotypes=init_variables, random_key=random_key)"
")(genotypes=init_variables, random_key=random_key)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@
"# initialize map-elites\n",
"repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n",
" devices=devices, centroids=centroids\n",
")(init_genotypes=training_states, random_key=keys)"
")(genotypes=training_states, random_key=keys)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/me_td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@
"# initialize map-elites\n",
"repertoire, emitter_state, keys = map_elites.get_distributed_init_fn(\n",
" devices=devices, centroids=centroids\n",
")(init_genotypes=training_states, random_key=keys)"
")(genotypes=training_states, random_key=keys)"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions examples/mome.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@
"# initial population\n",
"random_key = jax.random.PRNGKey(42)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"init_genotypes = jax.random.uniform(\n",
"genotypes = jax.random.uniform(\n",
" random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n",
")\n",
"\n",
Expand Down Expand Up @@ -303,7 +303,7 @@
"outputs": [],
"source": [
"repertoire, emitter_state, random_key = mome.init(\n",
" init_genotypes,\n",
" genotypes,\n",
" centroids,\n",
" pareto_front_max_length,\n",
" random_key\n",
Expand Down
6 changes: 3 additions & 3 deletions examples/nsga2_spea2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@
"# Initial population\n",
"random_key = jax.random.PRNGKey(0)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"init_genotypes = jax.random.uniform(\n",
"genotypes = jax.random.uniform(\n",
" subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n",
")\n",
"\n",
Expand Down Expand Up @@ -238,7 +238,7 @@
"\n",
"# init nsga2\n",
"repertoire, emitter_state, random_key = nsga2.init(\n",
" init_genotypes,\n",
" genotypes,\n",
" population_size,\n",
" random_key\n",
")"
Expand Down Expand Up @@ -303,7 +303,7 @@
"\n",
"# init spea2\n",
"repertoire, emitter_state, random_key = spea2.init(\n",
" init_genotypes,\n",
" genotypes,\n",
" population_size,\n",
" num_neighbours,\n",
" random_key\n",
Expand Down
21 changes: 8 additions & 13 deletions qdax/baselines/genetic_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ def __init__(

@partial(jax.jit, static_argnames=("self", "population_size"))
def init(
self, init_genotypes: Genotype, population_size: int, random_key: RNGKey
self, genotypes: Genotype, population_size: int, random_key: RNGKey
) -> Tuple[GARepertoire, Optional[EmitterState], RNGKey]:
"""Initialize a GARepertoire with an initial population of genotypes.

Args:
init_genotypes: the initial population of genotypes
genotypes: the initial population of genotypes
population_size: the maximal size of the repertoire
random_key: a random key to handle stochastic operations

Expand All @@ -54,26 +54,21 @@ def init(

# score initial genotypes
fitnesses, extra_scores, random_key = self._scoring_function(
init_genotypes, random_key
genotypes, random_key
)

# init the repertoire
repertoire = GARepertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
)

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
random_key=random_key,
repertoire=repertoire,
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
Expand Down Expand Up @@ -108,7 +103,7 @@ def update(
"""

# generate offsprings
genotypes, random_key = self._emitter.emit(
genotypes, extra_info, random_key = self._emitter.emit(
repertoire, emitter_state, random_key
)

Expand All @@ -127,7 +122,7 @@ def update(
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
extra_scores={**extra_scores, **extra_info},
)

# update the metrics
Expand Down
15 changes: 10 additions & 5 deletions qdax/baselines/nsga2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,36 @@ class NSGA2(GeneticAlgorithm):

@partial(jax.jit, static_argnames=("self", "population_size"))
def init(
self, init_genotypes: Genotype, population_size: int, random_key: RNGKey
self, genotypes: Genotype, population_size: int, random_key: RNGKey
) -> Tuple[NSGA2Repertoire, Optional[EmitterState], RNGKey]:

# score initial genotypes
fitnesses, extra_scores, random_key = self._scoring_function(
init_genotypes, random_key
genotypes, random_key
)

# init the repertoire
repertoire = NSGA2Repertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
)

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
random_key=random_key,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
extra_scores=extra_scores,
)
Expand Down
15 changes: 10 additions & 5 deletions qdax/baselines/spea2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,40 @@ class SPEA2(GeneticAlgorithm):
)
def init(
self,
init_genotypes: Genotype,
genotypes: Genotype,
population_size: int,
num_neighbours: int,
random_key: RNGKey,
) -> Tuple[SPEA2Repertoire, Optional[EmitterState], RNGKey]:

# score initial genotypes
fitnesses, extra_scores, random_key = self._scoring_function(
init_genotypes, random_key
genotypes, random_key
)

# init the repertoire
repertoire = SPEA2Repertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
population_size=population_size,
num_neighbours=num_neighbours,
)

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
random_key=random_key,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=None,
extra_scores=extra_scores,
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
extra_scores=extra_scores,
)
Expand Down
24 changes: 11 additions & 13 deletions qdax/core/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def container_size_control(

def init(
self,
init_genotypes: Genotype,
genotypes: Genotype,
aurora_extra_info: AuroraExtraInfo,
l_value: jnp.ndarray,
max_size: int,
Expand All @@ -128,7 +128,7 @@ def init(
genotypes. Also performs the first training of the AURORA encoder.

Args:
init_genotypes: initial genotypes, pytree in which leaves
genotypes: initial genotypes, pytree in which leaves
have shape (batch_size, num_features)
aurora_extra_info: information to perform AURORA encodings,
such as the encoder parameters
Expand All @@ -141,7 +141,7 @@ def init(
the emitter, and the updated information to perform AURORA encodings
"""
fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
init_genotypes,
genotypes,
random_key,
)

Expand All @@ -150,7 +150,7 @@ def init(
descriptors = self._encoder_fn(observations, aurora_extra_info)

repertoire = UnstructuredRepertoire.init(
genotypes=init_genotypes,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
observations=observations,
Expand All @@ -160,13 +160,9 @@ def init(

# get initial state of the emitter
emitter_state, random_key = self._emitter.init(
init_genotypes=init_genotypes, random_key=random_key
)

# update emitter state
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
genotypes=init_genotypes,
random_key=random_key,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
Expand Down Expand Up @@ -208,9 +204,10 @@ def update(
a new key
"""
# generate offsprings with the emitter
genotypes, random_key = self._emitter.emit(
genotypes, extra_info, random_key = self._emitter.emit(
repertoire, emitter_state, random_key
)

# scores the offsprings
fitnesses, descriptors, extra_scores, random_key = self._scoring_function(
genotypes,
Expand All @@ -232,10 +229,11 @@ def update(
# update emitter state after scoring is made
emitter_state = self._emitter.state_update(
emitter_state=emitter_state,
repertoire=repertoire,
genotypes=genotypes,
fitnesses=fitnesses,
descriptors=descriptors,
extra_scores=extra_scores,
extra_scores=extra_scores | extra_info,
)

# update the metrics
Expand Down
32 changes: 32 additions & 0 deletions qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,38 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey

return samples, random_key

@partial(jax.jit, static_argnames=("num_samples",))
def sample_with_descs(
    self,
    random_key: RNGKey,
    num_samples: int,
) -> Tuple[Genotype, Descriptor, RNGKey]:
    """Sample genotypes from the repertoire together with their descriptors.

    Cells whose fitness equals ``-jnp.inf`` are treated as empty and get
    probability zero; all remaining cells are equally likely. Sampling uses
    the default behaviour of ``jax.random.choice`` (with replacement).

    Args:
        random_key: a jax PRNG random key
        num_samples: the number of elements to be sampled

    Returns:
        samples: a batch of genotypes sampled in the repertoire
        descs: the descriptors stored in the same cells as ``samples``
            (``descs[i]`` belongs to ``samples[i]``)
        random_key: an updated jax PRNG random key
    """

    # Uniform distribution over the occupied cells only.
    # NOTE(review): if every cell is empty this is 0 / 0 — callers are
    # presumably expected to sample from a non-empty repertoire; confirm.
    repertoire_empty = self.fitnesses == -jnp.inf
    p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

    random_key, subkey = jax.random.split(random_key)

    # Draw the cell indices once and reuse them for every leaf, so that the
    # genotype/descriptor pairing is explicit instead of relying on two
    # independent `choice` calls that happen to share the same key.
    indices = jax.random.choice(
        subkey, self.fitnesses.shape[0], shape=(num_samples,), p=p
    )

    def _gather(leaf):  # type: ignore
        # Select the sampled cells along the leading (cell) axis.
        return jnp.take(leaf, indices, axis=0)

    samples = jax.tree_util.tree_map(_gather, self.genotypes)
    descs = jax.tree_util.tree_map(_gather, self.descriptors)

    return samples, descs, random_key

@jax.jit
def add(
self,
Expand Down
Loading
Loading