Skip to content

Commit

Permalink
feat: extend functionality of compute_euclidean_centroids (#40)
Browse files Browse the repository at this point in the history
Extend the functionality of compute_euclidean_centroids to be more general to any dimension grid and any number of centroids per dimension

Co-authored-by: Felix Chalumeau <f.chalumeau@instadeep.com>
  • Loading branch information
limbryan and felixchalumeau authored Jul 5, 2022
1 parent fc78101 commit 1a434a1
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
7 changes: 5 additions & 2 deletions notebooks/omgmega_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"source": [
"import jax \n",
"import jax.numpy as jnp\n",
"import math\n",
"\n",
"try:\n",
" import qdax\n",
Expand Down Expand Up @@ -178,9 +179,11 @@
"# defines the population\n",
"random_key, subkey = jax.random.split(random_key)\n",
"initial_population = jax.random.uniform(subkey, shape=(init_population_size, num_dimensions))\n",
"\n",
"sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid \n",
"grid_shape = (sqrt_centroids, sqrt_centroids)\n",
"centroids = compute_euclidean_centroids(\n",
" num_descriptors = num_descriptors,\n",
" num_centroids = num_centroids,\n",
" grid_shape = grid_shape,\n",
" minval = minval,\n",
" maxval = maxval\n",
") \n",
Expand Down
30 changes: 15 additions & 15 deletions qdax/core/containers/repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

import math
from functools import partial
from typing import Callable, List, Tuple, Union

Expand Down Expand Up @@ -66,36 +65,37 @@ def compute_cvt_centroids(


def compute_euclidean_centroids(
num_descriptors: int,
num_centroids: int,
grid_shape: Tuple[int, ...],
minval: Union[float, List[float]],
maxval: Union[float, List[float]],
) -> jnp.ndarray:
"""
Compute centroids for square Euclidean tesselation.
Args:
num_descriptors: number od scalar descriptors
num_centroids: number of centroids
grid_shape: number of centroids per BD dimension
minval: minimum descriptors value
maxval: maximum descriptors value
Returns:
the centroids with shape (num_centroids, num_descriptors)
"""
if num_descriptors != 2:
raise NotImplementedError("This function supports 2 descriptors only for now.")

sqrt_centroids = math.sqrt(num_centroids)
# get number of descriptors
num_descriptors = len(grid_shape)

if math.floor(sqrt_centroids) != sqrt_centroids:
raise ValueError("Num centroids should be a squared number.")
# prepare list of linspaces
linspace_list = []
for num_centroids_in_dim in grid_shape:
offset = 1 / (2 * num_centroids_in_dim)
linspace = jnp.linspace(offset, 1.0 - offset, num_centroids_in_dim)
linspace_list.append(linspace)

offset = 1 / (2 * int(sqrt_centroids))
meshes = jnp.meshgrid(*linspace_list, sparse=False)

linspace = jnp.linspace(offset, 1.0 - offset, int(sqrt_centroids))
meshes = jnp.meshgrid(linspace, linspace, sparse=False)
centroids = jnp.stack([jnp.ravel(meshes[0]), jnp.ravel(meshes[1])], axis=-1)
# create centroids
centroids = jnp.stack(
[jnp.ravel(meshes[i]) for i in range(num_descriptors)], axis=-1
)
minval = jnp.array(minval)
maxval = jnp.array(maxval)
return jnp.asarray(centroids) * (maxval - minval) + minval
Expand Down
8 changes: 5 additions & 3 deletions tests/core_test/containers_test/repertoire_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ def test_repertoire() -> None:
batch_size = 2
genotype_size = 12
num_centroids = 4
num_descriptors = 2
grid_shape = (2, 2)

# get num descriptors from grid shape
num_descriptors = len(grid_shape)

centroids = compute_euclidean_centroids(
num_descriptors=num_descriptors,
num_centroids=num_centroids,
grid_shape=grid_shape,
minval=0.0,
maxval=1.0,
)
Expand Down
7 changes: 5 additions & 2 deletions tests/core_test/omgmega_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from typing import Dict, Tuple

import jax
Expand Down Expand Up @@ -86,9 +87,11 @@ def metrics_fn(repertoire: MapElitesRepertoire) -> Dict[str, jnp.ndarray]:
# defines the population
random_key, subkey = jax.random.split(random_key)
initial_population = jax.random.uniform(subkey, shape=(100, num_dimensions))

sqrt_centroids = int(math.sqrt(num_centroids)) # 2-D grid
grid_shape = (sqrt_centroids, sqrt_centroids)
centroids = compute_euclidean_centroids(
num_descriptors=num_descriptors,
num_centroids=num_centroids,
grid_shape=grid_shape,
minval=minval,
maxval=maxval,
)
Expand Down

0 comments on commit 1a434a1

Please sign in to comment.