diff --git a/examples/aurora.ipynb b/examples/aurora.ipynb index fb955a98..09072715 100644 --- a/examples/aurora.ipynb +++ b/examples/aurora.ipynb @@ -31,7 +31,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -46,37 +64,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax.core.aurora import AURORA\n", "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", "from qdax import environments\n", diff --git a/examples/cmaes.ipynb b/examples/cmaes.ipynb index 8b87473d..ba059cdc 100644 --- a/examples/cmaes.ipynb +++ b/examples/cmaes.ipynb @@ -30,39 +30,28 @@ "metadata": {}, "outputs": [], "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", "import matplotlib.pyplot as plt\n", "from matplotlib.patches import Ellipse\n", "\n", @@ -71,7 +60,7 @@ }, { "cell_type": "markdown", - "id": "3", + "id": "4", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -80,7 +69,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -98,7 +87,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": { "pycharm": { "name": "#%% md\n" @@ -111,7 +100,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -133,7 +122,7 @@ }, { "cell_type": "markdown", - "id": "7", + "id": "8", "metadata": { "pycharm": { "name": "#%% md\n" @@ -146,7 +135,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": { "pycharm": { "name": "#%%\n" @@ -167,7 +156,7 @@ }, { "cell_type": "markdown", - "id": "9", + "id": "10", "metadata": { "pycharm": { "name": "#%% md\n" @@ -180,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -194,7 +183,7 @@ }, { "cell_type": "markdown", - "id": "11", + "id": "12", "metadata": { "pycharm": { "name": "#%% md\n" @@ -207,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -245,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "13", + "id": "14", "metadata": {}, "source": [ "## Check final fitnesses and distribution mean" @@ -254,7 +243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -272,7 +261,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "16", "metadata": { "pycharm": { "name": "#%% md\n" @@ -285,7 +274,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/cmame.ipynb b/examples/cmame.ipynb index 7da832eb..f7aa235e 100644 --- a/examples/cmame.ipynb +++ b/examples/cmame.ipynb @@ -23,6 +23,24 @@ "- how to visualise the optimization process" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -38,36 +56,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.emitters.cma_opt_emitter import CMAOptimizingEmitter\n", "from qdax.core.emitters.cma_rnd_emitter import CMARndEmitter\n", diff --git a/examples/cmamega.ipynb b/examples/cmamega.ipynb index a90f8309..d37bf80e 100644 --- a/examples/cmamega.ipynb +++ b/examples/cmamega.ipynb @@ -29,39 +29,27 @@ "metadata": {}, "outputs": [], "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.emitters.cma_mega_emitter import CMAMEGAEmitter\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n", diff --git a/examples/dads.ipynb b/examples/dads.ipynb index 19348348..57d1df05 100644 --- a/examples/dads.ipynb +++ b/examples/dads.ipynb @@ -28,7 +28,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -42,37 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax import environments\n", "from qdax.baselines.dads import DADS, DadsConfig, DadsTrainingState\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition, ReplayBuffer\n", diff --git a/examples/dcrlme.ipynb b/examples/dcrlme.ipynb index 5c367e37..eae0e6b3 100644 --- a/examples/dcrlme.ipynb +++ b/examples/dcrlme.ipynb @@ -13,7 +13,7 @@ "source": [ "# Optimizing with DCRL-ME in Jax\n", "\n", - "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME)](https://arxiv.org/abs/2401.08632), also known as *Descriptor-Conditioned Gradients MAP-Elites with Actor Injection (DCG-ME-AI)*. \n", + "This notebook shows how to use QDax to find diverse and performing controllers in MDPs with [Descriptor-Conditioned Reinforcement Learning MAP-Elites (DCRL-ME)](https://arxiv.org/abs/2401.08632).\n", "This algorithm extends and improves upon [Descriptor-Conditioned Gradients MAP-Elites (DCG-ME)](https://dl.acm.org/doi/abs/10.1145/3583131.3590503)\n", "It can be run locally or on Google Colab. We recommand to use a GPU. This notebook will show:\n", "\n", @@ -31,7 +31,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -46,39 +64,6 @@ "\n", "import jax\n", "import jax.numpy as jnp\n", - "import pytest\n", - "from brax.envs import Env, State, Wrapper\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "\n", "from qdax import environments\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", @@ -137,9 +122,9 @@ "\n", "# DCRL emitter\n", "critic_hidden_layer_size = (256, 256) #@param {type:\"raw\"}\n", - "num_critic_training_steps = 3000\n", - "num_pg_training_steps = 150\n", - "replay_buffer_size = 1_000_000\n", + "num_critic_training_steps = 3000 #@param {type:\"integer\"}\n", + "num_pg_training_steps = 150 #@param {type:\"integer\"}\n", + "replay_buffer_size = 1_000_000 #@param {type:\"integer\"}\n", "discount = 0.99 #@param {type:\"number\"}\n", "reward_scaling = 1.0 #@param {type:\"number\"}\n", "critic_learning_rate = 3e-4 #@param {type:\"number\"}\n", diff --git a/examples/diayn.ipynb b/examples/diayn.ipynb index fba2055f..cdee8b4b 100644 --- a/examples/diayn.ipynb +++ b/examples/diayn.ipynb @@ -28,7 +28,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -42,36 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "\n", "from qdax import environments\n", "from qdax.baselines.diayn import DIAYN, DiaynConfig, DiaynTrainingState\n", diff --git a/examples/distributed_mapelites.ipynb b/examples/distributed_mapelites.ipynb index 0fe1094c..7a7b5296 100644 --- a/examples/distributed_mapelites.ipynb +++ b/examples/distributed_mapelites.ipynb @@ -29,7 +29,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -51,37 +69,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax.core.distributed_map_elites import DistributedMAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax import environments\n", diff --git a/examples/jumanji_snake.ipynb b/examples/jumanji_snake.ipynb index dc915524..78ec01c8 100644 --- a/examples/jumanji_snake.ipynb +++ b/examples/jumanji_snake.ipynb @@ -19,44 +19,35 @@ "metadata": {}, "outputs": [], "source": [ - "from functools import partial\n", - "from typing import Tuple, Type\n", - "\n", - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "from typing import Tuple, Type\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", "import functools\n", "\n", + "import jumanji\n", + "\n", "import numpy as np\n", "\n", "from qdax.baselines.genetic_algorithm import GeneticAlgorithm\n", @@ -78,7 +69,7 @@ }, { "cell_type": "markdown", - "id": "2", + "id": "3", "metadata": {}, "source": [ "## Define hyperparameters" @@ -87,7 +78,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -105,7 +96,7 @@ }, { "cell_type": "markdown", - "id": "4", + "id": "5", "metadata": {}, "source": [ "## Instantiate the snake environment" @@ -114,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +123,7 @@ }, { "cell_type": "markdown", - "id": "6", + "id": "7", "metadata": {}, "source": [ "## Define the type of policy that will be used to solve the problem" @@ -141,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -161,7 +152,7 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ "## Utils to interact with the environment\n", @@ -172,7 +163,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -219,7 +210,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "11", "metadata": {}, "source": [ "## Init a population of policies\n", @@ -230,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -255,7 +246,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "13", "metadata": {}, "source": [ "## Define a method to extract behavior descriptor when relevant" @@ -264,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +302,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "15", "metadata": {}, "source": [ "## Define the scoring function" @@ -320,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -333,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "17", "metadata": {}, "source": [ "## Define the emitter used" @@ -342,7 +333,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -360,7 +351,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "19", "metadata": {}, "source": [ "## Define the algorithm used and apply the initial step\n", @@ -371,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -415,7 +406,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "21", "metadata": {}, "source": [ "## Run the optimization loop" @@ -424,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -442,7 +433,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -452,7 +443,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -462,7 +453,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -472,7 +463,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -489,7 +480,7 @@ }, { "cell_type": "markdown", - "id": "26", + "id": "27", "metadata": {}, "source": [ "## Play snake with the best policy\n", @@ -500,7 +491,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -511,7 +502,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -524,7 +515,7 @@ { "cell_type": "code", "execution_count": null, - "id": "29", + "id": "30", "metadata": {}, "outputs": [], "source": [ @@ -537,7 +528,7 @@ { "cell_type": "code", "execution_count": null, - "id": "30", + "id": "31", "metadata": {}, "outputs": [], "source": [ @@ -550,7 +541,7 @@ { "cell_type": "code", "execution_count": null, - "id": "31", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -563,7 +554,7 @@ { "cell_type": "code", "execution_count": null, - "id": "32", + "id": "33", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/mapelites.ipynb b/examples/mapelites.ipynb index 575ee0c0..b7a0a256 100644 --- a/examples/mapelites.ipynb +++ b/examples/mapelites.ipynb @@ -31,7 +31,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -46,37 +64,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n", "from qdax import environments\n", diff --git a/examples/me_sac_pbt.ipynb b/examples/me_sac_pbt.ipynb index 778a7a5f..86deebc4 100644 --- a/examples/me_sac_pbt.ipynb +++ b/examples/me_sac_pbt.ipynb @@ -1,5 +1,23 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -13,36 +31,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "import optax\n", "from brax.v1.io import html\n", "from IPython.display import HTML\n", diff --git a/examples/me_td3_pbt.ipynb b/examples/me_td3_pbt.ipynb index da8d7311..f72ccda1 100644 --- a/examples/me_td3_pbt.ipynb +++ b/examples/me_td3_pbt.ipynb @@ -1,5 +1,23 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -14,36 +32,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "import matplotlib.pyplot as plt\n", "from brax.v1.io import html\n", "from IPython.display import HTML\n", diff --git a/examples/mees.ipynb b/examples/mees.ipynb index e9e37c2a..3cbf890f 100644 --- a/examples/mees.ipynb +++ b/examples/mees.ipynb @@ -28,6 +28,24 @@ "- how to visualize the results of the training process" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -36,7 +54,7 @@ }, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -51,36 +69,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax import environments\n", diff --git a/examples/mels.ipynb b/examples/mels.ipynb index 4f3fdc74..dae02a95 100644 --- a/examples/mels.ipynb +++ b/examples/mels.ipynb @@ -24,6 +24,24 @@ "- how to save/load a repertoire" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -32,7 +50,7 @@ }, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -47,37 +65,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax.core.mels import MELS\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax.core.containers.mels_repertoire import MELSRepertoire\n", diff --git a/examples/mome.ipynb b/examples/mome.ipynb index 4661d406..0d005dbe 100644 --- a/examples/mome.ipynb +++ b/examples/mome.ipynb @@ -32,42 +32,31 @@ "metadata": {}, "outputs": [], "source": [ - "import jax.numpy as jnp\n", - "import jax\n", - "from typing import Tuple\n", - "\n", - "from functools import partial\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import jax\n", + "from typing import Tuple\n", + "\n", + "from functools import partial\n", + "\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax.core.mome import MOME\n", "from qdax.core.emitters.mutation_operators import (\n", @@ -86,7 +75,7 @@ }, { "cell_type": "markdown", - "id": "3", + "id": "4", "metadata": {}, "source": [ "## Set the hyperparameters" @@ -95,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +108,7 @@ }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": {}, "source": [ "## Define the scoring function: rastrigin multi-objective\n", @@ -130,7 +119,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -165,7 +154,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -178,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "8", + "id": "9", "metadata": {}, "source": [ "## Define the metrics function that will be used" @@ -187,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": {}, "outputs": [], "source": [ @@ -202,7 +191,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "11", "metadata": {}, "source": [ "## Define the initial population and the emitter" @@ -211,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -248,7 +237,7 @@ }, { "cell_type": "markdown", - "id": "12", + "id": "13", "metadata": {}, "source": [ "## Compute the centroids" @@ -257,7 +246,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": {}, "outputs": [], "source": [ @@ -273,7 +262,7 @@ }, { "cell_type": "markdown", - "id": "14", + "id": "15", "metadata": {}, "source": [ "## Define a MOME instance" @@ -282,7 +271,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -295,7 +284,7 @@ }, { "cell_type": "markdown", - "id": "16", + "id": "17", "metadata": {}, "source": [ "## Init the algorithm" @@ -304,7 +293,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": {}, "outputs": [], "source": [ @@ -318,7 +307,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "19", "metadata": {}, "source": [ "## Run MOME iterations" @@ -327,7 +316,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +333,7 @@ }, { "cell_type": "markdown", - "id": "20", + "id": "21", "metadata": {}, "source": [ "## Plot the results" @@ -353,7 +342,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -363,7 +352,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +380,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "24", "metadata": {}, "outputs": [], "source": [ diff --git a/examples/nsga2_spea2.ipynb b/examples/nsga2_spea2.ipynb index d6385291..b418bd31 100644 --- a/examples/nsga2_spea2.ipynb +++ b/examples/nsga2_spea2.ipynb @@ -24,6 +24,24 @@ "- how to visualise the optimization process" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -39,36 +57,6 @@ "\n", "from functools import partial\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from qdax.baselines.nsga2 import (\n", " NSGA2\n", ")\n", diff --git a/examples/omgmega.ipynb b/examples/omgmega.ipynb index c09deefe..bde9f5ed 100644 --- a/examples/omgmega.ipynb +++ b/examples/omgmega.ipynb @@ -30,40 +30,28 @@ "metadata": {}, "outputs": [], "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import math\n", - "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", + "from IPython.display import clear_output\n", "\n", "try:\n", " import qdax\n", "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", " import qdax\n", "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import math\n", + "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.emitters.omg_mega_emitter import OMGMEGAEmitter\n", "from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids, MapElitesRepertoire\n", diff --git a/examples/pga_aurora.ipynb b/examples/pga_aurora.ipynb index 29f4cc74..56ed4f01 100644 --- a/examples/pga_aurora.ipynb +++ b/examples/pga_aurora.ipynb @@ -31,7 +31,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -46,37 +64,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax.core.aurora import AURORA\n", "from qdax.core.containers.unstructured_repertoire import UnstructuredRepertoire\n", "from qdax import environments\n", diff --git a/examples/pgame.ipynb b/examples/pgame.ipynb index d60e246a..31e4f831 100644 --- a/examples/pgame.ipynb +++ b/examples/pgame.ipynb @@ -30,7 +30,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -45,37 +63,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax.core.map_elites import MAPElites\n", "from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n", "from qdax import environments\n", diff --git a/examples/qdpg.ipynb b/examples/qdpg.ipynb index 2082dcaa..8c47ffe6 100644 --- a/examples/qdpg.ipynb +++ b/examples/qdpg.ipynb @@ -30,7 +30,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -45,36 +63,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from qdax.core.containers.archive import score_euclidean_novelty\n", "from qdax.core.emitters.dpg_emitter import DiversityPGConfig\n", "from qdax.core.emitters.qdpg_emitter import QDPGEmitter, QDPGEmitterConfig\n", diff --git a/examples/sac_pbt.ipynb b/examples/sac_pbt.ipynb index c71559f7..915cc272 100644 --- a/examples/sac_pbt.ipynb +++ b/examples/sac_pbt.ipynb @@ -4,6 +4,25 @@ "cell_type": "code", "execution_count": null, "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", "metadata": { "jupyter": { "outputs_hidden": false @@ -19,36 +38,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from brax.v1.io import html\n", "from IPython.display import HTML\n", "from tqdm import tqdm\n", @@ -61,7 +50,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1", + "id": "2", "metadata": { "jupyter": { "outputs_hidden": false @@ -78,7 +67,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": { "jupyter": { "outputs_hidden": false @@ -98,7 +87,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": { "jupyter": { "outputs_hidden": false @@ -139,7 +128,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": { "jupyter": { "outputs_hidden": false @@ -170,7 +159,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", "metadata": { "jupyter": { "outputs_hidden": false @@ -209,7 +198,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": { "jupyter": { "outputs_hidden": false @@ -232,7 +221,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": { "jupyter": { "outputs_hidden": false @@ -261,7 +250,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": { "jupyter": { "outputs_hidden": false @@ -288,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": { "jupyter": { "outputs_hidden": false @@ -306,7 +295,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": { "jupyter": { "outputs_hidden": false @@ -331,7 +320,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": { "jupyter": { "outputs_hidden": false @@ -357,7 +346,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": { "jupyter": { "outputs_hidden": false @@ -379,7 +368,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": { "jupyter": { "outputs_hidden": false @@ -402,7 +391,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": { "jupyter": { "outputs_hidden": false @@ -442,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n" @@ -456,7 +445,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "17", "metadata": { "pycharm": { "name": "#%%\n" @@ -472,7 +461,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": { "pycharm": { "name": "#%%\n" @@ -486,7 +475,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": { "pycharm": { "name": "#%%\n" @@ -504,7 +493,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "20", "metadata": { "pycharm": { "name": "#%%\n" @@ -518,7 +507,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "21", "metadata": { "pycharm": { "name": "#%%\n" @@ -547,7 +536,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "22", "metadata": { "pycharm": { "name": "#%%\n" @@ -564,7 +553,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "23", "metadata": { "pycharm": { "name": "#%%\n" diff --git a/examples/smerl.ipynb b/examples/smerl.ipynb index 5f08e582..0e332192 100644 --- a/examples/smerl.ipynb +++ b/examples/smerl.ipynb @@ -28,7 +28,25 @@ "metadata": {}, "outputs": [], "source": [ - "#@title Installs and Imports\n", + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", "!pip install ipympl |tail -n 1\n", "# %matplotlib widget\n", "# from google.colab import output\n", @@ -42,37 +60,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", - "\n", "from qdax import environments\n", "from qdax.baselines.diayn_smerl import DIAYNSMERL, DiaynSmerlConfig, DiaynTrainingState\n", "from qdax.core.neuroevolution.buffers.buffer import QDTransition\n", diff --git a/examples/td3_pbt.ipynb b/examples/td3_pbt.ipynb index b3d2cbe1..d2d98f85 100644 --- a/examples/td3_pbt.ipynb +++ b/examples/td3_pbt.ipynb @@ -4,6 +4,25 @@ "cell_type": "code", "execution_count": null, "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output\n", + "\n", + "try:\n", + " import qdax\n", + "except:\n", + " print(\"QDax not found. Installing...\")\n", + " !pip install qdax[cuda12]\n", + " import qdax\n", + "\n", + "clear_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", "metadata": { "pycharm": { "name": "#%%\n" @@ -16,36 +35,6 @@ "import jax\n", "import jax.numpy as jnp\n", "\n", - "try:\n", - " import brax\n", - "except:\n", - " !pip install git+https://github.com/google/brax.git@v0.10.4 |tail -n 1\n", - " import brax\n", - "\n", - "try:\n", - " import flax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/google/flax.git@v0.8.5 |tail -n 1\n", - " import flax\n", - "\n", - "try:\n", - " import chex\n", - "except:\n", - " !pip install --no-deps git+https://github.com/deepmind/chex.git@v0.1.86 |tail -n 1\n", - " import chex\n", - "\n", - "try:\n", - " import jumanji\n", - "except:\n", - " !pip install \"jumanji==0.3.1\"\n", - " import jumanji\n", - "\n", - "try:\n", - " import qdax\n", - "except:\n", - " !pip install --no-deps git+https://github.com/adaptive-intelligent-robotics/QDax@main |tail -n 1\n", - " import qdax\n", - "\n", "from tqdm import tqdm\n", "\n", "from qdax import environments\n", @@ -56,7 +45,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1", + "id": "2", "metadata": { "pycharm": { "name": "#%%\n" @@ -70,7 +59,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": { "pycharm": { "name": "#%%\n" @@ -87,7 +76,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3", + "id": "4", "metadata": { "pycharm": { "name": "#%%\n" @@ -124,7 +113,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": { "pycharm": { "name": "#%%\n" @@ -152,7 +141,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "6", "metadata": { "pycharm": { "name": "#%%\n" @@ -183,7 +172,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": { "pycharm": { "name": "#%%\n" @@ -203,7 +192,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "8", "metadata": { "pycharm": { "name": "#%%\n" @@ -227,7 +216,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": { "pycharm": { "name": "#%%\n" @@ -251,7 +240,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "10", "metadata": { "pycharm": { "name": "#%%\n" @@ -266,7 +255,7 @@ { "cell_type": "code", "execution_count": null, - "id": "10", + "id": "11", "metadata": { "pycharm": { "name": "#%%\n" @@ -288,7 +277,7 @@ { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "12", "metadata": { "pycharm": { "name": "#%%\n" @@ -311,7 +300,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "13", "metadata": { "pycharm": { "name": "#%%\n" @@ -330,7 +319,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "14", "metadata": { "pycharm": { "name": "#%%\n" @@ -350,7 +339,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "15", "metadata": { "pycharm": { "name": "#%%\n" @@ -387,7 +376,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15", + "id": "16", "metadata": { "pycharm": { "name": "#%%\n"