Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default Scoring Functions for Sphere, Rastrigin, Arm and Brax environments #73

Merged
merged 56 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
a75970b
first commit - push a draft of the strcuture and arm and rastrigin tasks
limbryan Aug 1, 2022
fb5c5a2
complete arm and standard functions - rastrigin, sphere, rastrigin-proj
limbryan Aug 1, 2022
e314310
adding create_brax_scoring_function_fn to make scoring functions for …
Lookatator Aug 1, 2022
0aa9d55
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Aug 1, 2022
418d2bc
add test for default standard functions
limbryan Aug 2, 2022
c28ec64
add test for default arm task and test make better docstring for arm
limbryan Aug 2, 2022
ab7edd0
add rastrigin proj test as well
limbryan Aug 3, 2022
4f38cb4
fix weird styling issues to do with types for the tests script
limbryan Aug 3, 2022
ff4847c
adding function for init controllers, and a function for creating def…
Lookatator Aug 4, 2022
34ac363
simpler metrics_fn + fix type of bd_extractor
Lookatator Aug 4, 2022
15b931c
fix catch of return of init_population_controllers
Lookatator Aug 4, 2022
2ce365e
add Docstrings
Lookatator Aug 8, 2022
a24ff3c
fix metadata notebook
Lookatator Aug 9, 2022
7323ce7
fix import scoring function in notebook
Lookatator Aug 9, 2022
3c83225
add init file into tasks directory, create an examples folder for scr…
limbryan Aug 11, 2022
5d5ded3
update README to be a usable readme that uses the arm function instea…
limbryan Aug 11, 2022
d62baa6
upload new benchmark functions from QDbenchmark workshop with corresp…
limbryan Aug 19, 2022
c750ae1
add README for summary of tasks in the directory along wiht some desc…
limbryan Aug 19, 2022
c7f5f64
fix implementation hypervolume functions
Lookatator Aug 21, 2022
b1fb1f1
adding run_me() example function
Lookatator Aug 21, 2022
3a04dc4
adding basis functions qd_suite benchmarking
Lookatator Aug 22, 2022
632ff41
fix keymanagement in noisy arm and add option to add noise on params …
limbryan Aug 22, 2022
e44c7b8
Merge branch 'feat/default_tasks' of https://github.com/adaptive-inte…
limbryan Aug 22, 2022
93fdadd
add draft of tasks summary README
Aug 22, 2022
338d7c2
refactoring benchmarking functions
Lookatator Aug 22, 2022
fa1acf0
fix key splitting in noisy arm task and references to hypervolume fun…
Aug 22, 2022
03a7702
improve plotting plot_multidimensional_map_elites_grid when using hig…
Lookatator Aug 22, 2022
b4a5a31
add default tasks qd_suite
Lookatator Aug 22, 2022
53c206a
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Aug 22, 2022
460b5a8
completing description of QD Suite functions
Lookatator Aug 22, 2022
22ffba8
move all qd_suite tasks
Lookatator Aug 22, 2022
3639982
fix README latex
Lookatator Aug 22, 2022
3e38f11
fix README latex
Lookatator Aug 23, 2022
c40b939
move default tasks to qd_suite __init__
Lookatator Aug 23, 2022
b865988
add example usage qd_suite tasks
Lookatator Aug 23, 2022
6dab609
add example usage hypervolume functions
Lookatator Aug 23, 2022
57d0d51
add examples for standard function
Lookatator Aug 23, 2022
3c1419c
example BRAX usage
Lookatator Aug 23, 2022
41831a5
add test for qd suite tasks
Aug 23, 2022
7371d35
Merge branch 'feat/default_tasks' of https://github.com/adaptive-inte…
Aug 23, 2022
5be8b0d
add test for qd suite tasks
Aug 23, 2022
f33ea17
remove type aliases
Lookatator Sep 7, 2022
4b1f333
specify type of grid_shape
Lookatator Sep 7, 2022
6e0baa7
Merge remote-tracking branch 'origin/develop' into feat/default_tasks
Lookatator Sep 7, 2022
73d32ce
fix styling issue
Lookatator Sep 7, 2022
75a456c
add Docstrings default scores
Lookatator Sep 8, 2022
c299e4f
Mention QDax tasks doc in the main README
Lookatator Sep 8, 2022
1908362
stochastic arm -> noisy arm
Lookatator Sep 9, 2022
169d1ac
add task specific test for arm
limbryan Oct 11, 2022
f5b88fb
Merge branch 'develop' into feat/default_tasks
Lookatator Oct 11, 2022
ed52ea6
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Oct 11, 2022
57e0caa
reformat tests/default_tasks_test/arm_test.py
Lookatator Oct 11, 2022
1f08215
complete missing descriptor bounds
felixchalumeau Oct 12, 2022
fd98af0
QDSuiteTask inherits from abc.ABC instead of using ABCMeta
Lookatator Oct 12, 2022
76cc8a5
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Oct 12, 2022
1de5311
create examples folder with notebooks and scripts
Lookatator Oct 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactoring benchmarking functions
  • Loading branch information
Lookatator committed Aug 22, 2022
commit 338d7c2ae1fe91bda4f6f3e9f1cfc63f024eab86
218 changes: 151 additions & 67 deletions qdax/tasks/archimedean_spiral.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import abc
import time
from enum import Enum
from functools import partial
from typing import Tuple
from typing import Tuple, Optional

import jax.lax
import jax.numpy as jnp
Expand All @@ -20,97 +21,180 @@ class ArchimedeanBD(Enum):
geodesic = "geodesic"


def archimedean_spiral(params: Genotype,
parameterization: ParameterizationGenotype,
archimedean_bd: ArchimedeanBD,
parameter=0.01,
precision=None,
alpha=40,
) -> Tuple[Fitness, Descriptor]:
"""
Seach space should be [0,alpha * pi]^n
BD space should be [0,1]^n
"""
if precision is None:
precision = alpha * jnp.pi / 1e7

constant_fitness = jnp.asarray(1.)

def _gamma(angle):
return jnp.asarray([parameter * angle * jnp.cos(angle),
parameter * angle * jnp.sin(angle)])

def _arc_length(angle):
return (parameter / 2) * \
class QDBenchmarkTask(metaclass=abc.ABCMeta):
@abc.abstractmethod
def scoring_function(self, params: Genotype) -> Tuple[Fitness, Descriptor]:
...

@abc.abstractmethod
def get_bd_size(self):
...

@abc.abstractmethod
def get_min_max_bd(self):
...

def get_bounded_min_max_bd(self):
min_bd, max_bd = self.get_min_max_bd()
if jnp.isinf(max_bd) or jnp.isinf(min_bd):
raise NotImplementedError("Boundedness has not been implemented "
"for this unbounded task")
else:
return min_bd, max_bd

@abc.abstractmethod
def get_min_max_params(self):
...

@abc.abstractmethod
def get_initial_parameters(self, batch_size: int) -> Genotype:
...


class ArchimedeanSpiralV0(QDBenchmarkTask):
def __init__(self,
parameterization: ParameterizationGenotype,
archimedean_bd: ArchimedeanBD,
parameter: float = 0.01,
precision: Optional[float] = None,
alpha: float = 40.,
):
self.parameterization = parameterization
self.archimedean_bd = archimedean_bd
self.parameter = parameter
if precision is None:
self.precision = alpha * jnp.pi / 1e7
else:
self.precision = precision
self.alpha = alpha

def _gamma(self, angle):
return jnp.asarray([self.parameter * angle * jnp.cos(angle),
self.parameter * angle * jnp.sin(angle)])

def get_arc_length(self, angle):
return (self.parameter / 2) * \
(angle * jnp.sqrt(1 + jnp.power(angle,2))
+ jnp.log(angle + jnp.sqrt(1 + jnp.power(angle, 2))))

def _cond_fun(elem: Tuple[float, float, float]) -> jnp.bool_:
def _cond_fun(self, elem: Tuple[float, float, float]) -> jnp.bool_:
inf, sup, target = elem
return (sup - inf) > precision
return (sup - inf) > self.precision

def _body_fun(elem: Tuple[float, float, float]) -> Tuple[float, float, float]:
def _body_fun(self, elem: Tuple[float, float, float]) -> Tuple[float, float, float]:
inf, sup, target_angle_length = elem
middle = (sup + inf) / 2.
arc_length_middle = _arc_length(middle)
arc_length_middle = self.get_arc_length(middle)
new_inf, new_sup = jax.lax.cond(target_angle_length < arc_length_middle,
lambda: (inf, middle),
lambda: (middle, sup))
return new_inf, new_sup, target_angle_length

def _approximate_angle_from_arc_length(target_arc_length: float) -> jnp.ndarray:
inf, sup, _ = jax.lax.while_loop(_cond_fun,
_body_fun,
def _approximate_angle_from_arc_length(self,
target_arc_length: float
) -> jnp.ndarray:
inf, sup, _ = jax.lax.while_loop(self._cond_fun,
self._body_fun,
init_val=(0.,
alpha * jnp.pi,
target_arc_length))
self.alpha * jnp.pi,
target_arc_length)
)
middle = (sup + inf) / 2.
return jnp.asarray(middle)

if archimedean_bd == ArchimedeanBD.geodesic and parameterization == ParameterizationGenotype.arc_length:
arc_length = params
return constant_fitness, arc_length
elif archimedean_bd == ArchimedeanBD.geodesic and parameterization == ParameterizationGenotype.angle:
arc_length = _arc_length(params)
return constant_fitness, arc_length
elif archimedean_bd == ArchimedeanBD.euclidean and parameterization == ParameterizationGenotype.arc_length:
arc_length = params
angle = _approximate_angle_from_arc_length(arc_length[0])
euclidean_bd = _gamma(angle)
return constant_fitness, euclidean_bd
elif archimedean_bd == ArchimedeanBD.euclidean and parameterization == ParameterizationGenotype.angle:
return constant_fitness, _gamma(params)




def get_arc_length(angle,
a
):
return (a / 2) * (angle * jnp.sqrt(1 + jnp.power(angle,
2)) + jnp.log(angle + jnp.sqrt(1 + jnp.power(angle, 2))))
def scoring_function(self, params: Genotype) -> Tuple[Fitness, Descriptor]:
constant_fitness = jnp.asarray(1.)

if self.archimedean_bd == ArchimedeanBD.geodesic \
and self.parameterization == ParameterizationGenotype.arc_length:
arc_length = params
return constant_fitness, arc_length
elif self.archimedean_bd == ArchimedeanBD.geodesic \
and self.parameterization == ParameterizationGenotype.angle:
angle = params
arc_length = self.get_arc_length(angle)
return constant_fitness, arc_length
elif self.archimedean_bd == ArchimedeanBD.euclidean \
and self.parameterization == ParameterizationGenotype.arc_length:
arc_length = params
angle = self._approximate_angle_from_arc_length(arc_length[0])
euclidean_bd = self._gamma(angle)
return constant_fitness, euclidean_bd
elif self.archimedean_bd == ArchimedeanBD.euclidean \
and self.parameterization == ParameterizationGenotype.angle:
angle = params
return constant_fitness, self._gamma(angle)
else:
raise ValueError("Invalid parameterization and/or BD")

def get_bd_size(self) -> int:
if self.archimedean_bd == ArchimedeanBD.euclidean:
return 2
elif self.archimedean_bd == ArchimedeanBD.geodesic:
return 1
else:
raise ValueError("Invalid BD")

def get_min_max_bd(self) -> Tuple[Optional[float], Optional[float]]:
max_angle = self.alpha * jnp.pi
max_norm = jnp.linalg.norm(self._gamma(max_angle))

if self.archimedean_bd == ArchimedeanBD.euclidean:
return -max_norm, max_norm
elif self.archimedean_bd == ArchimedeanBD.geodesic:
max_arc_length = self.get_arc_length(max_angle)
return 0., max_arc_length
else:
raise ValueError("Invalid BD")

def get_min_max_params(self) -> Tuple[Optional[float], Optional[float]]:
if self.parameterization == ParameterizationGenotype.angle:
max_angle = self.alpha * jnp.pi
return 0., max_angle
elif self.parameterization == ParameterizationGenotype.arc_length:
max_angle = self.alpha * jnp.pi
max_arc_length = self.get_arc_length(max_angle)
return 0, max_arc_length
else:
raise ValueError("Invalid parameterization")

def get_initial_parameters(self, batch_size: int) -> Genotype:
max_angle = self.alpha * jnp.pi
mid_angle = max_angle / 2.
mid_number_turns = 1 + int(mid_angle / (2. * jnp.pi))
horizontal_left_mid_angle = mid_number_turns * jnp.pi * 2

if self.parameterization == ParameterizationGenotype.angle:
angle_array = jnp.asarray(horizontal_left_mid_angle).reshape((1, 1))
return jnp.repeat(angle_array, batch_size, axis=0)
elif self.parameterization == ParameterizationGenotype.arc_length:
arc_length = self.get_arc_length(horizontal_left_mid_angle)
length_array = jnp.asarray(arc_length).reshape((1, 1))
return jnp.repeat(length_array, batch_size, axis=0)
else:
raise ValueError("Invalid parameterization")


if __name__ == '__main__':
parameter = 200
alpha = 10

archimedean_spiral_fn = partial(archimedean_spiral,
parameterization=ParameterizationGenotype.arc_length,
archimedean_bd=ArchimedeanBD.euclidean,
parameter=parameter,)

max_length = get_arc_length(jnp.pi * alpha,
parameter)
x = jnp.linspace(0,
max_length,
16000).reshape((-1, 1))
# parameter = 200
# alpha = 10
task = ArchimedeanSpiralV0(parameterization=ParameterizationGenotype.arc_length,
archimedean_bd=ArchimedeanBD.euclidean)
archimedean_spiral_fn = task.scoring_function

# max_length = task.get_arc_length(jnp.pi * task.alpha)
x = jnp.linspace(*task.get_min_max_params(), num=16000).reshape((-1, 1))
f = jax.jit(jax.vmap(archimedean_spiral_fn))
f(x)
start = time.time()
res = f(x)

point = task.get_initial_parameters(126)
red_point = f(point)

print(res)
print("time taken: {}".format(time.time() - start))
plt.plot(res[1][:, 0],
res[1][:, 1])
plt.scatter(red_point[1][0, 0], red_point[1][0, 1], c='r')
plt.show()
Loading