Update jax, brax and flax versions (fixes the jax.tree_util warnings) #76

Merged 15 commits on Oct 6, 2022

Changes from all commits
3 changes: 1 addition & 2 deletions dev.Dockerfile
@@ -10,7 +10,6 @@ COPY requirements.txt /tmp/requirements.txt
COPY requirements-dev.txt /tmp/requirements-dev.txt
COPY environment.yaml /tmp/environment.yaml

-
RUN micromamba create -y --file /tmp/environment.yaml \
&& micromamba clean --all --yes \
&& find /opt/conda/ -follow -type f -name '*.pyc' -delete
@@ -41,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l

ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
-RUN pip --no-cache-dir install jaxlib==0.3.10+cuda11.cudnn82 \
+RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& rm -rf /tmp/*

2 changes: 1 addition & 1 deletion environment.yaml
@@ -8,6 +8,6 @@ dependencies:
- conda>=4.9.2
- pip:
- --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
-  - jaxlib==0.3.10
+  - jaxlib==0.3.15
- -r requirements.txt
- -r requirements-dev.txt
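Both pins move jaxlib from 0.3.10 to 0.3.15 (the matching jax pin presumably lives in requirements.txt, which is not shown in this diff). A minimal sanity check after rebuilding the environment — a sketch, not part of the PR:

```python
import jax
import jaxlib

# jax and jaxlib are released in lockstep; a mismatched pair typically
# fails at import time or on the first jit compilation.
print("jax", jax.__version__, "jaxlib", jaxlib.__version__)

# The CUDA image built from dev.Dockerfile should list GPU devices here;
# a CPU-only install falls back to a single CpuDevice.
print(jax.devices())
```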
4 changes: 2 additions & 2 deletions notebooks/cmamega_example.ipynb
@@ -133,11 +133,11 @@
" gradients = jnp.nan_to_num(gradients)\n",
"\n",
" # Compute normalized gradients\n",
" norm_gradients = jax.tree_map(\n",
" norm_gradients = jax.tree_util.tree_map(\n",
" lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),\n",
" gradients,\n",
" )\n",
" grads = jax.tree_map(\n",
" grads = jax.tree_util.tree_map(\n",
" lambda x, y: x / y, gradients, norm_gradients\n",
" )\n",
" grads = jnp.nan_to_num(grads)\n",
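These renames track JAX's relocation of the pytree helpers under jax.tree_util; the old top-level aliases such as jax.tree_map are what triggered the deprecation warnings this PR fixes. The notebook cell above, recast as a self-contained sketch with toy shapes standing in for the real QD gradients:

```python
import jax
import jax.numpy as jnp

# Toy pytree of per-sample gradients: (batch, dim) arrays.
gradients = {"w": jnp.ones((4, 3)), "b": jnp.ones((4, 1))}

# Row-wise norm of every leaf, then divide; nan_to_num guards against
# zero-norm rows, as in the notebook cell.
norm_gradients = jax.tree_util.tree_map(
    lambda x: jnp.linalg.norm(x, axis=1, keepdims=True), gradients
)
grads = jax.tree_util.tree_map(lambda x, y: x / y, gradients, norm_gradients)
grads = jax.tree_util.tree_map(jnp.nan_to_num, grads)
```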
10 changes: 6 additions & 4 deletions qdax/baselines/dads.py
@@ -25,6 +25,7 @@
update_running_mean_std,
)
from qdax.core.neuroevolution.sac_utils import generate_unroll
+from qdax.environments import CompletedEvalWrapper
from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor


@@ -144,7 +145,7 @@ def init( # type: ignore
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

-target_critic_params = jax.tree_map(
+target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)

@@ -373,16 +374,17 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

+eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
-state.info["eval_metrics"].completed_episodes_metrics["reward"]
-/ state.info["eval_metrics"].completed_episodes
+state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)

true_returns = jnp.nansum(transitions.rewards, axis=0)

-reshaped_transitions = jax.tree_map(
+reshaped_transitions = jax.tree_util.tree_map(
lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
transitions,
)
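The second change here (repeated in diayn.py, sac.py and td3.py below) swaps the hardcoded "eval_metrics" string for the constant the wrapper itself exposes. A minimal sketch of the pattern, assuming only what the diff shows — that CompletedEvalWrapper defines STATE_INFO_KEY and stores its episode metrics under that key in state.info:

```python
from qdax.environments import CompletedEvalWrapper

def completed_eval_return(state):
    """Mean undiscounted return over completed evaluation episodes."""
    # Reading through the class constant instead of a magic string means
    # a rename inside the wrapper cannot silently break this caller.
    metrics = state.info[CompletedEvalWrapper.STATE_INFO_KEY]
    return (
        metrics.completed_episodes_metrics["reward"]
        / metrics.completed_episodes
    )
```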
10 changes: 6 additions & 4 deletions qdax/baselines/diayn.py
@@ -20,6 +20,7 @@
from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode
from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks
from qdax.core.neuroevolution.sac_utils import generate_unroll
+from qdax.environments import CompletedEvalWrapper
from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor


@@ -141,7 +142,7 @@ def init( # type: ignore
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

-target_critic_params = jax.tree_map(
+target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)

@@ -316,16 +317,17 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

+eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
-state.info["eval_metrics"].completed_episodes_metrics["reward"]
-/ state.info["eval_metrics"].completed_episodes
+state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)

true_return_per_env = jnp.nansum(transitions.rewards, axis=0)

-reshaped_transitions = jax.tree_map(
+reshaped_transitions = jax.tree_util.tree_map(
lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
transitions,
)
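dads.py and diayn.py both flatten the unrolled transitions leaf by leaf with the same tree_map. A toy sketch of that reshape (hypothetical shapes; the real leaves have leading (episode_length, env_batch_size) axes):

```python
import jax
import jax.numpy as jnp

episode_length, env_batch_size = 5, 2

# Hypothetical transition pytree with leading (time, batch) axes.
transitions = {"obs": jnp.ones((episode_length, env_batch_size, 3))}

# Collapse time and batch into a single leading axis on every leaf.
reshaped = jax.tree_util.tree_map(
    lambda x: x.reshape((episode_length * env_batch_size, -1)), transitions
)
print(reshaped["obs"].shape)  # (10, 3)
```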
10 changes: 6 additions & 4 deletions qdax/baselines/sac.py
@@ -24,6 +24,7 @@
update_running_mean_std,
)
from qdax.core.neuroevolution.sac_utils import generate_unroll
+from qdax.environments import CompletedEvalWrapper
from qdax.types import Action, Metrics, Observation, Params, Reward, RNGKey


@@ -115,7 +116,7 @@ def init(
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

-target_critic_params = jax.tree_map(
+target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)

@@ -298,9 +299,10 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

+eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
-state.info["eval_metrics"].completed_episodes_metrics["reward"]
-/ state.info["eval_metrics"].completed_episodes
+state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)
@@ -389,7 +391,7 @@ def _update_critic(
critic_params = optax.apply_updates(
training_state.critic_params, critic_updates
)
-target_critic_params = jax.tree_map(
+target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.tau) * x1 + self._config.tau * x2,
training_state.target_critic_params,
critic_params,
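_update_critic performs a Polyak soft update across the whole parameter pytree. The same pattern as a standalone sketch with toy parameters (tau stands in for self._config.tau):

```python
import jax
import jax.numpy as jnp

def soft_update(target_params, online_params, tau):
    """target <- (1 - tau) * target + tau * online, leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda t, o: (1.0 - tau) * t + tau * o, target_params, online_params
    )

target = {"w": jnp.zeros((2, 2))}
online = {"w": jnp.ones((2, 2))}
target = soft_update(target, online, tau=0.005)  # every entry becomes 0.005
```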
14 changes: 8 additions & 6 deletions qdax/baselines/td3.py
@@ -19,6 +19,7 @@
get_first_episode,
)
from qdax.core.neuroevolution.networks.td3_networks import make_td3_networks
+from qdax.environments import CompletedEvalWrapper
from qdax.types import Action, Metrics, Observation, Params, Reward, RNGKey


@@ -113,10 +114,10 @@ def init(
policy_params = self._policy.init(subkey_2, fake_obs)

# Initialize target networks
-target_critic_params = jax.tree_map(
+target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)
-target_policy_params = jax.tree_map(
+target_policy_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), policy_params
)

@@ -251,9 +252,10 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

+eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
-state.info["eval_metrics"].completed_episodes_metrics["reward"]
-/ state.info["eval_metrics"].completed_episodes
+state.info[eval_metrics_key].completed_episodes_metrics["reward"]
+/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)
@@ -303,7 +305,7 @@ def update(
training_state.critic_params, critic_updates
)
# Soft update of target critic network
-target_critic_params = jax.tree_map(
+target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_critic_params,
@@ -325,7 +327,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
training_state.policy_params, policy_updates
)
# Soft update of target policy
-target_policy_params = jax.tree_map(
+target_policy_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_policy_params,
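In init, the target networks start as detached copies of the online parameters. A minimal sketch (toy parameters standing in for the real critic params):

```python
import jax
import jax.numpy as jnp

critic_params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}

# Copy every leaf so subsequent soft updates to the target never
# alias the online parameters.
target_critic_params = jax.tree_util.tree_map(
    lambda x: jnp.asarray(x.copy()), critic_params
)
```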
12 changes: 7 additions & 5 deletions qdax/core/containers/ga_repertoire.py
@@ -34,7 +34,7 @@ class GARepertoire(Repertoire):
@property
def size(self) -> int:
"""Gives the size of the population."""
-first_leaf = jax.tree_leaves(self.genotypes)[0]
+first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
return int(first_leaf.shape[0])

def save(self, path: str = "./") -> None:
@@ -95,7 +95,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey

# sample
random_key, subkey = jax.random.split(random_key)
-samples = jax.tree_map(
+samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(
subkey, x, shape=(num_samples,), p=p, replace=False
),
@@ -122,7 +122,7 @@ def add(
"""

# gather individuals and fitnesses
-candidates = jax.tree_map(
+candidates = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate((x, y), axis=0),
self.genotypes,
batch_of_genotypes,
@@ -138,7 +138,9 @@ def add(
survivor_indices = indices[: self.size]

# keep only the best ones
-new_candidates = jax.tree_map(lambda x: x[survivor_indices], candidates)
+new_candidates = jax.tree_util.tree_map(
+    lambda x: x[survivor_indices], candidates
+)

new_repertoire = self.replace(
genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
@@ -172,7 +174,7 @@ def init( # type: ignore
)

# create default genotypes
-default_genotypes = jax.tree_map(
+default_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
)

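Note that sample reuses one subkey across the whole tree_map, which is what keeps the drawn indices aligned across leaves. A toy sketch of why that matters (hypothetical genotype pytree):

```python
import jax
import jax.numpy as jnp

# Two leaves describing the same 10 individuals.
genotypes = {
    "x": jnp.arange(10.0).reshape(10, 1),
    "y": jnp.arange(10.0).reshape(10, 1),
}
p = jnp.full((10,), 0.1)  # uniform selection probabilities

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

# The shared subkey makes every leaf pick the same rows, so the sampled
# "x" and "y" entries still describe the same individuals.
samples = jax.tree_util.tree_map(
    lambda leaf: jax.random.choice(
        subkey, leaf, shape=(3,), p=p, replace=False
    ),
    genotypes,
)
```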
6 changes: 3 additions & 3 deletions qdax/core/containers/mapelites_repertoire.py
@@ -220,7 +220,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

random_key, subkey = jax.random.split(random_key)
-samples = jax.tree_map(
+samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.genotypes,
)
@@ -283,7 +283,7 @@ def add(
)

# create new repertoire
-new_repertoire_genotypes = jax.tree_map(
+new_repertoire_genotypes = jax.tree_util.tree_map(
lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
batch_of_indices.squeeze(axis=-1)
].set(new_genotypes),
@@ -337,7 +337,7 @@ def init(
# Initialize repertoire with default values
num_centroids = centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
-default_genotypes = jax.tree_map(
+default_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(num_centroids,) + x.shape[1:]),
genotypes,
)
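In add, a batch of genotypes is scattered into its cells with an indexed .at[].set on every leaf. A toy sketch of that update (hypothetical shapes):

```python
import jax
import jax.numpy as jnp

repertoire_genotypes = {"x": jnp.zeros((5, 2))}  # 5 centroids
new_genotypes = {"x": jnp.ones((2, 2))}          # batch of 2 solutions
batch_of_indices = jnp.array([0, 3])             # their target cells

# Out-of-place scatter: each new genotype lands in its cell, leaf by leaf.
updated = jax.tree_util.tree_map(
    lambda rep, new: rep.at[batch_of_indices].set(new),
    repertoire_genotypes,
    new_genotypes,
)
```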
28 changes: 16 additions & 12 deletions qdax/core/containers/mome_repertoire.py
@@ -54,7 +54,7 @@ def repertoire_capacity(self) -> int:
Returns:
The repertoire capacity.
"""
-first_leaf = jax.tree_leaves(self.genotypes)[0]
+first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
return int(first_leaf.shape[0] * first_leaf.shape[1])

@jax.jit
@@ -80,7 +80,7 @@ def _sample_in_masked_pareto_front(
"""
p = (1.0 - mask) / jnp.sum(1.0 - mask)

-genotype_sample = jax.tree_map(
+genotype_sample = jax.tree_util.tree_map(
lambda x: jax.random.choice(random_key, x, shape=(1,), p=p),
pareto_front_genotypes,
)
@@ -116,7 +116,9 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
cells_idx = jax.random.choice(subkey, indices, shape=(num_samples,), p=p)

# get genotypes (front) from the chosen indices
-pareto_front_genotypes = jax.tree_map(lambda x: x[cells_idx], self.genotypes)
+pareto_front_genotypes = jax.tree_util.tree_map(
+    lambda x: x[cells_idx], self.genotypes
+)

# prepare second sampling function
sample_in_fronts = jax.vmap(self._sample_in_masked_pareto_front)
@@ -131,7 +133,9 @@
)

# remove the dim coming from pareto front
-sampled_genotypes = jax.tree_map(lambda x: x.squeeze(axis=1), sampled_genotypes)
+sampled_genotypes = jax.tree_util.tree_map(
+    lambda x: x.squeeze(axis=1), sampled_genotypes
+)

return sampled_genotypes, random_key

@@ -172,7 +176,7 @@ def _update_masked_pareto_front(

pareto_front_len = pareto_front_fitnesses.shape[0] # type: ignore

-first_leaf = jax.tree_leaves(new_batch_of_genotypes)[0]
+first_leaf = jax.tree_util.tree_leaves(new_batch_of_genotypes)[0]
genotypes_dim = first_leaf.shape[1]

descriptors_dim = new_batch_of_descriptors.shape[1]
@@ -182,7 +186,7 @@
cat_fitnesses = jnp.concatenate(
[pareto_front_fitnesses, new_batch_of_fitnesses], axis=0
)
-cat_genotypes = jax.tree_map(
+cat_genotypes = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate([x, y], axis=0),
pareto_front_genotypes,
new_batch_of_genotypes,
@@ -205,7 +209,7 @@

# get new fitness, genotypes and descriptors
new_front_fitness = jnp.take(cat_fitnesses, indices, axis=0)
-new_front_genotypes = jax.tree_map(
+new_front_genotypes = jax.tree_util.tree_map(
lambda x: jnp.take(x, indices, axis=0), cat_genotypes
)
new_front_descriptors = jnp.take(cat_descriptors, indices, axis=0)
@@ -232,10 +236,10 @@
genotypes_mask = jnp.repeat(
jnp.expand_dims(new_mask, axis=-1), genotypes_dim, axis=-1
)
-new_front_genotypes = jax.tree_map(
+new_front_genotypes = jax.tree_util.tree_map(
lambda x: x * genotypes_mask, new_front_genotypes
)
-new_front_genotypes = jax.tree_map(
+new_front_genotypes = jax.tree_util.tree_map(
lambda x: x[:front_size, :], new_front_genotypes
)

@@ -289,7 +293,7 @@ def _add_one(
index = index.astype(jnp.int32)

# get cell data
-cell_genotype = jax.tree_map(lambda x: x[index], carry.genotypes)
+cell_genotype = jax.tree_util.tree_map(lambda x: x[index], carry.genotypes)
cell_fitness = carry.fitnesses[index]
cell_descriptor = carry.descriptors[index]
cell_mask = jnp.any(cell_fitness == -jnp.inf, axis=-1)
@@ -315,7 +319,7 @@ def _add_one(
cell_fitness = cell_fitness - jnp.inf * jnp.expand_dims(cell_mask, axis=-1)

# update grid
-new_genotypes = jax.tree_map(
+new_genotypes = jax.tree_util.tree_map(
lambda x, y: x.at[index].set(y), carry.genotypes, cell_genotype
)
new_fitnesses = carry.fitnesses.at[index].set(cell_fitness)
@@ -383,7 +387,7 @@ def init( # type: ignore
default_fitnesses = -jnp.inf * jnp.ones(
shape=(num_centroids, pareto_front_max_length, num_criteria)
)
-default_genotypes = jax.tree_map(
+default_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(
shape=(
num_centroids,
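repertoire_capacity reads the size off the first leaf, which relies on every leaf sharing the leading (num_centroids, pareto_front_max_length) axes. A toy sketch (hypothetical shapes):

```python
import jax
import jax.numpy as jnp

# One leaf: (num_centroids, pareto_front_max_length, genotype_dim).
genotypes = {"x": jnp.zeros((8, 4, 2))}

first_leaf = jax.tree_util.tree_leaves(genotypes)[0]
capacity = int(first_leaf.shape[0] * first_leaf.shape[1])
print(capacity)  # at most 32 solutions across all Pareto fronts
```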