fix bad imports in notebooks + remove unused functions from mdp utils
felixchalumeau committed Mar 9, 2023
1 parent 7b694fd commit 414cdc3
Showing 5 changed files with 8 additions and 121 deletions.
2 changes: 1 addition & 1 deletion examples/dads.ipynb
@@ -64,7 +64,7 @@
"from qdax import environments\n",
"from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
-"from qdax.core.neuroevolution.sac_utils import do_iteration_fn, warmstart_buffer\n",
+"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",
"\n",
"from qdax.utils.plotting import plot_skills_trajectory\n",
"\n",
2 changes: 1 addition & 1 deletion examples/diayn.ipynb
@@ -64,7 +64,7 @@
"from qdax import environments\n",
"from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n",
-"from qdax.core.neuroevolution.sac_utils import do_iteration_fn, warmstart_buffer\n",
+"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",
"\n",
"from qdax.utils.plotting import plot_skills_trajectory\n",
"\n",
2 changes: 1 addition & 1 deletion examples/smerl.ipynb
@@ -65,7 +65,7 @@
"from qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig, DiaynTrainingState\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
"from qdax.core.neuroevolution.buffers.trajectory_buffer import TrajectoryBuffer\n",
-"from qdax.core.neuroevolution.sac_utils import do_iteration_fn, warmstart_buffer\n",
+"from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer\n",
"\n",
"from qdax.utils.plotting import plot_skills_trajectory\n",
"\n",
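All three notebooks now pull do_iteration_fn and warmstart_buffer from the sac_td3_utils module. Stripped of the notebook-JSON quoting, the corrected import is simply:

from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer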
119 changes: 2 additions & 117 deletions qdax/core/neuroevolution/mdp_utils.py
@@ -8,8 +8,8 @@
from brax.envs import State as EnvState
from flax.struct import PyTreeNode

-from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition
-from qdax.types import Genotype, Metrics, Params, RNGKey
+from qdax.core.neuroevolution.buffers.buffer import Transition
+from qdax.types import Genotype, Params, RNGKey


class TrainingState(PyTreeNode):
@@ -22,53 +22,6 @@ class TrainingState(PyTreeNode):
    pass


-@partial(
-    jax.jit,
-    static_argnames=(
-        "num_warmstart_steps",
-        "play_step_fn",
-        "env_batch_size",
-    ),
-)
-def warmstart_buffer(
-    replay_buffer: ReplayBuffer,
-    policy_params: Params,
-    random_key: RNGKey,
-    env_state: EnvState,
-    play_step_fn: Callable[
-        [EnvState, Params, RNGKey],
-        Tuple[
-            EnvState,
-            Params,
-            RNGKey,
-            Transition,
-        ],
-    ],
-    num_warmstart_steps: int,
-    env_batch_size: int,
-) -> Tuple[ReplayBuffer, EnvState]:
-    """Pre-populates the buffer with transitions. Returns the warmstarted buffer
-    and the new state of the environment.
-    """
-
-    def _scan_play_step_fn(
-        carry: Tuple[EnvState, Params, RNGKey], unused_arg: Any
-    ) -> Tuple[Tuple[EnvState, Params, RNGKey], Transition]:
-        env_state, policy_params, random_key, transitions = play_step_fn(*carry)
-        return (env_state, policy_params, random_key), transitions
-
-    random_key, subkey = jax.random.split(random_key)
-    (state, _, _), transitions = jax.lax.scan(
-        _scan_play_step_fn,
-        (env_state, policy_params, subkey),
-        (),
-        length=num_warmstart_steps // env_batch_size,
-    )
-    replay_buffer = replay_buffer.insert(transitions)
-
-    return replay_buffer, env_state


@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
    init_state: EnvState,
@@ -114,74 +67,6 @@ def _scan_play_step_fn(
    return state, transitions


-@partial(
-    jax.jit,
-    static_argnames=(
-        "env_batch_size",
-        "grad_updates_per_step",
-        "play_step_fn",
-        "update_fn",
-    ),
-)
-def do_iteration_fn(
-    training_state: TrainingState,
-    env_state: EnvState,
-    replay_buffer: ReplayBuffer,
-    env_batch_size: int,
-    grad_updates_per_step: float,
-    play_step_fn: Callable[
-        [EnvState, Params, RNGKey],
-        Tuple[
-            EnvState,
-            Params,
-            RNGKey,
-            Transition,
-        ],
-    ],
-    update_fn: Callable[
-        [TrainingState, ReplayBuffer],
-        Tuple[
-            TrainingState,
-            ReplayBuffer,
-            Metrics,
-        ],
-    ],
-) -> Tuple[TrainingState, EnvState, ReplayBuffer, Metrics]:
-    """Performs one environment step (over all env simultaneously) followed by one
-    training step. The number of updates is controlled by the parameter
-    `grad_updates_per_step` (0 means no update while 1 means `env_batch_size`
-    updates). Returns the updated states, the updated buffer and the aggregated
-    metrics.
-    """
-
-    def _scan_update_fn(
-        carry: Tuple[TrainingState, ReplayBuffer], unused_arg: Any
-    ) -> Tuple[Tuple[TrainingState, ReplayBuffer], Metrics]:
-        training_state, replay_buffer, metrics = update_fn(*carry)
-        return (training_state, replay_buffer), metrics
-
-    # play steps in the environment
-    random_key = training_state.random_key
-    env_state, _, random_key, transitions = play_step_fn(
-        env_state,
-        training_state.policy_params,
-        random_key,
-    )
-
-    # insert transitions in replay buffer
-    replay_buffer = replay_buffer.insert(transitions)
-    num_updates = int(grad_updates_per_step * env_batch_size)
-
-    (training_state, replay_buffer), metrics = jax.lax.scan(
-        _scan_update_fn,
-        (training_state, replay_buffer),
-        (),
-        length=num_updates,
-    )
-
-    return training_state, env_state, replay_buffer, metrics


@jax.jit
def get_first_episode(transition: Transition) -> Transition:
    """Extracts the first episode from a batch of transitions, returns the batch of
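For context, here is a minimal sketch of how the two removed helpers are typically chained in the notebooks' training loops, assuming the sac_td3_utils versions keep the signatures of the duplicates deleted above. replay_buffer, env_state, policy_params, training_state, play_step_fn and update_fn are placeholders that the chosen baseline (e.g. the DADS/DIAYN setups from the notebooks above) would provide; the constants are illustrative.

import jax

from qdax.core.neuroevolution.sac_td3_utils import do_iteration_fn, warmstart_buffer

env_batch_size = 128
num_warmstart_steps = 10 * env_batch_size
grad_updates_per_step = 1.0  # 1.0 means env_batch_size gradient updates per iteration
num_iterations = 1_000
random_key = jax.random.PRNGKey(0)

# Pre-populate the replay buffer before any gradient update is taken.
replay_buffer, env_state = warmstart_buffer(
    replay_buffer=replay_buffer,
    policy_params=policy_params,
    random_key=random_key,
    env_state=env_state,
    play_step_fn=play_step_fn,
    num_warmstart_steps=num_warmstart_steps,
    env_batch_size=env_batch_size,
)

# One iteration = one batched environment step, buffer insertion, then
# int(grad_updates_per_step * env_batch_size) gradient updates.
for _ in range(num_iterations):
    training_state, env_state, replay_buffer, metrics = do_iteration_fn(
        training_state=training_state,
        env_state=env_state,
        replay_buffer=replay_buffer,
        env_batch_size=env_batch_size,
        grad_updates_per_step=grad_updates_per_step,
        play_step_fn=play_step_fn,
        update_fn=update_fn,
    )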
4 changes: 3 additions & 1 deletion qdax/tasks/brax_envs.py
@@ -199,7 +199,9 @@ def reset_based_scoring_function_brax_envs(
    """

    random_key, subkey = jax.random.split(random_key)
-    keys = jax.random.split(subkey, jax.tree_leaves(policies_params)[0].shape[0])
+    keys = jax.random.split(
+        subkey, jax.tree_util.tree_leaves(policies_params)[0].shape[0]
+    )
    reset_fn = jax.vmap(play_reset_fn)
    init_states = reset_fn(keys)
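The hunk above also swaps the top-level jax.tree_leaves alias (deprecated in recent JAX releases) for jax.tree_util.tree_leaves. A self-contained sketch of the pattern, one reset key per policy in the batch, with the batch size read off the first leaf of the params pytree; the dummy params below are purely illustrative:

import jax
import jax.numpy as jnp

# Dummy batch of 8 policies: every leaf carries a leading batch axis.
policies_params = {"w": jnp.zeros((8, 3)), "b": jnp.zeros((8,))}

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)

batch_size = jax.tree_util.tree_leaves(policies_params)[0].shape[0]
keys = jax.random.split(subkey, batch_size)  # one PRNG key per policy
print(keys.shape)  # (8, 2)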
