Version 0.2.0 of QDax #126

Merged
28 commits merged on Nov 30, 2022

Changes from 1 commit

Commits (28)
be464f2
Merge branch 'main' into develop
Lookatator Aug 9, 2022
a04fadc
fix: correct irrelevant factor 0.25 in td3 loss (#78)
felixchalumeau Sep 6, 2022
82262e5
Fix the replay buffer overflow issue (#75)
felixchalumeau Sep 6, 2022
ab4d4ca
Merge branch 'main' into develop
Lookatator Oct 6, 2022
13272b0
feat: Default Scoring Functions for Sphere, Rastrigin, Arm, Brax envi…
Lookatator Oct 13, 2022
b691c7b
docs: add caveats and logo (#99)
felixchalumeau Nov 1, 2022
e58472e
hotfix(images): re-add deleted logos to the repo
felixchalumeau Nov 1, 2022
7d094ef
fix(pointmaze): scale after the clip of actions (#101)
felixchalumeau Nov 2, 2022
9d35be7
fix: add update policy delay in PG emitter
limbryan Nov 4, 2022
0180665
fix: optimizer state reinitialization for PG variations (#104)
maxencefaldor Nov 4, 2022
7607290
fix(style): mypy issue in controller training
felixchalumeau Nov 7, 2022
aee7e51
feat(envs): wrapper for fixed initial state of environments (#92)
limbryan Nov 7, 2022
46e82b6
fix(docs): avoid using flax 0.6.2 in setup (#112)
felixchalumeau Nov 22, 2022
df67142
fix(examples): brax version in colab examples (#108)
limbryan Nov 22, 2022
335a0ef
feat(algorithms): add ME-ES to QDax (#81)
manon-but-yes Nov 23, 2022
f20a32a
fix: reset_based scoring in brax_env default task (#109)
manon-but-yes Nov 23, 2022
211d176
feat(algorithms): Add Multi-Emitter (#90)
Lookatator Nov 23, 2022
d7c6dc7
fix(mees): add batch size property (#114)
felixchalumeau Nov 23, 2022
a5c19a2
feat(algorithms): add CMA-ME, fix CMA-ES and CMA-MEGA (#86)
felixchalumeau Nov 24, 2022
6370def
feat(algorithms): add QDPG emitter + refactor PGAME (#110)
felixchalumeau Nov 28, 2022
2afc4ea
feat(github): add GitHub template for PR (#120)
felixchalumeau Nov 29, 2022
db730e0
fix(test): inverse fitness and desc names in sampling test (#119)
manon-but-yes Nov 29, 2022
0f07770
feat(algorithms): add MAP-Elites distributed on multiple devices (#117)
felixchalumeau Nov 29, 2022
062b52f
fix(docker): fix run-image docker stage (#121)
Egiob Nov 29, 2022
ab13dbd
fix(jit): avoid consecutive jits of same method in for loops (#122)
Lookatator Nov 29, 2022
f5b5d94
feat(repertoire): optional extra-scores for repertoire addition (#118)
manon-but-yes Nov 29, 2022
f84f887
fix(doc): add colab links, missing doc, update version (#125)
felixchalumeau Nov 30, 2022
cfaf3c8
fix(envs): order of wrappers to ensure update of state descriptor whe…
felixchalumeau Nov 30, 2022
fix(style): mypy issue in controller training
felixchalumeau committed Nov 7, 2022
commit 7607290cd0ab34d54c2cf3925a7c2ab6ea618bd5
36 changes: 24 additions & 12 deletions qdax/core/emitters/pga_me_emitter.py
@@ -395,22 +395,32 @@ def _mutation_function_pg(
         """

         # Define new controller optimizer state
-        controller_optimizer_state = self._controllers_optimizer.init(
-            controller_params
-        )
+        controller_optimizer_state = self._controllers_optimizer.init(controller_params)

         def scan_train_controller(
-            carry: Tuple[PGAMEEmitterState, Genotype], unused: Any
-        ) -> Tuple[Tuple[PGAMEEmitterState, Genotype], Any]:
+            carry: Tuple[PGAMEEmitterState, Genotype, optax.OptState], unused: Any
+        ) -> Tuple[Tuple[PGAMEEmitterState, Genotype, optax.OptState], Any]:
             emitter_state, controller_params, controller_optimizer_state = carry
             (
                 new_emitter_state,
                 new_controller_params,
-                new_controller_optimizer_state
-            ) = self._train_controller(emitter_state, controller_params, controller_optimizer_state,)
-            return (new_emitter_state, new_controller_params, new_controller_optimizer_state), ()
+                new_controller_optimizer_state,
+            ) = self._train_controller(
+                emitter_state,
+                controller_params,
+                controller_optimizer_state,
+            )
+            return (
+                new_emitter_state,
+                new_controller_params,
+                new_controller_optimizer_state,
+            ), ()

-        (emitter_state, controller_params, controller_optimizer_state), _ = jax.lax.scan(
+        (
+            emitter_state,
+            controller_params,
+            controller_optimizer_state,
+        ), _ = jax.lax.scan(
             scan_train_controller,
             (emitter_state, controller_params, controller_optimizer_state),
             (),
@@ -425,7 +435,7 @@ def _train_controller(
         emitter_state: PGAMEEmitterState,
         controller_params: Params,
         controller_optimizer_state: optax.OptState,
-    ) -> Tuple[PGAMEEmitterState, Params]:
+    ) -> Tuple[PGAMEEmitterState, Params, optax.OptState]:
         """Apply one gradient step to a policy (called controllers_params).

         Args:
@@ -449,7 +459,10 @@ def _train_controller(
             samples,
         )
         # Compute gradient and update policies
-        (policy_updates, controller_optimizer_state,) = self._controllers_optimizer.update(
+        (
+            policy_updates,
+            controller_optimizer_state,
+        ) = self._controllers_optimizer.update(
             policy_gradient, controller_optimizer_state
         )
         controller_params = optax.apply_updates(controller_params, policy_updates)
@@ -468,4 +481,3 @@ def _train_controller(
         )

         return new_emitter_state, controller_params, controller_optimizer_state
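
For readers less familiar with the pattern the diff above enforces, here is a minimal, self-contained sketch (not taken from QDax) of threading parameters and an optax optimizer state through a jax.lax.scan carry so that several gradient steps run inside a single jitted call. The loss function, parameter shapes, step count, and all names (loss_fn, train_step, NUM_STEPS, train) are hypothetical placeholders.

# Minimal sketch: gradient steps inside jax.lax.scan, with the optax optimizer
# state threaded through the carry. All names and shapes here are placeholders.
from typing import Any, Tuple

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(learning_rate=1e-3)
NUM_STEPS = 10  # hypothetical number of gradient steps per call


def loss_fn(params: jnp.ndarray) -> jnp.ndarray:
    # Placeholder quadratic loss; a real emitter would plug in its
    # policy-gradient loss instead.
    return jnp.sum(params**2)


def train_step(
    carry: Tuple[jnp.ndarray, optax.OptState], unused: Any
) -> Tuple[Tuple[jnp.ndarray, optax.OptState], Any]:
    params, opt_state = carry
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # The returned carry must keep the same structure and types as the input
    # carry, which is why the OptState shows up in the type annotations.
    return (params, opt_state), ()


@jax.jit
def train(params: jnp.ndarray) -> jnp.ndarray:
    opt_state = optimizer.init(params)
    (params, opt_state), _ = jax.lax.scan(
        train_step, (params, opt_state), (), length=NUM_STEPS
    )
    return params


trained = train(jnp.ones(3))

Carrying the optimizer state explicitly, as the emitter change above does, keeps the inner training loop purely functional, which is what lets jax.lax.scan and jit compile it once instead of re-tracing each Python iteration.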