Skip to content

Commit

Permalink
Add extra_scores to mome repertoire as well
Browse files Browse the repository at this point in the history
  • Loading branch information
manon-but-yes committed Nov 28, 2022
1 parent edc9060 commit 3c8285e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
10 changes: 8 additions & 2 deletions qdax/core/containers/mome_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

from functools import partial
from typing import Any, Tuple
from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp
Expand All @@ -17,6 +17,7 @@
from qdax.types import (
Centroid,
Descriptor,
ExtraScores,
Fitness,
Genotype,
Mask,
Expand Down Expand Up @@ -259,6 +260,7 @@ def add(
batch_of_genotypes: Genotype,
batch_of_descriptors: Descriptor,
batch_of_fitnesses: Fitness,
batch_of_extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
"""Insert a batch of elements in the repertoire.
Expand All @@ -274,6 +276,8 @@ def add(
trying to add to the repertoire.
batch_of_fitnesses: the fitnesses of the genotypes we are trying
to add to the repertoire.
batch_of_extra_scores: unused tree that contains the extra_scores of
aforementioned genotypes.
Returns:
The updated repertoire with potential new individuals.
Expand Down Expand Up @@ -355,6 +359,7 @@ def init( # type: ignore
descriptors: Descriptor,
centroids: Centroid,
pareto_front_max_length: int,
extra_scores: Optional[ExtraScores] = None,
) -> MOMERepertoire:
"""
Initialize a Multi Objective Map-Elites repertoire with an initial population
Expand All @@ -373,6 +378,7 @@ def init( # type: ignore
of shape (batch_size, num_descriptors)
centroids: tesselation centroids of shape (batch_size, num_descriptors)
pareto_front_max_length: maximum size of the pareto fronts
extra_scores: unused extra_scores of the initial genotypes
Returns:
An initialized MAP-Elite repertoire
Expand Down Expand Up @@ -410,7 +416,7 @@ def init( # type: ignore
)

# add first batch of individuals in the repertoire
new_repertoire = repertoire.add(genotypes, descriptors, fitnesses)
new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)

return new_repertoire # type: ignore

Expand Down
1 change: 1 addition & 0 deletions qdax/core/mome.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def init(
descriptors=descriptors,
centroids=centroids,
pareto_front_max_length=pareto_front_max_length,
extra_scores=extra_scores,
)

# get initial state of the emitter
Expand Down

0 comments on commit 3c8285e

Please sign in to comment.