Skip to content

Commit

Permalink
Rename all arguments named "unused" by "_"
Browse files Browse the repository at this point in the history
  • Loading branch information
maxencefaldor committed Sep 22, 2024
1 parent cdb9573 commit 83d4201
Show file tree
Hide file tree
Showing 12 changed files with 16 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@
"centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n",
"\n",
"@jax.jit\n",
"def update_scan_fn(carry: Any, unused: Any) -> Any:\n",
"def update_scan_fn(carry: Any, _: Any) -> Any:\n",
" \"\"\"Scan the update function.\"\"\"\n",
" repertoire, key, aurora_extra_info = carry\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/pga_aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@
"centroids = jnp.zeros(shape=(num_centroids, aurora_dims))\n",
"\n",
"@jax.jit\n",
"def update_scan_fn(carry: Any, unused: Any) -> Any:\n",
"def update_scan_fn(carry: Any, _: Any) -> Any:\n",
" \"\"\"Scan the update function.\"\"\"\n",
" (\n",
" repertoire,\n",
Expand Down
4 changes: 2 additions & 2 deletions qdax/baselines/genetic_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,15 @@ def update(
def scan_update(
self,
carry: Tuple[GARepertoire, Optional[EmitterState], RNGKey],
unused: Any,
_: Any,
) -> Tuple[Tuple[GARepertoire, Optional[EmitterState], RNGKey], Metrics]:
"""Rewrites the update function in a way that makes it compatible with the
jax.lax.scan primitive.
Args:
carry: a tuple containing the repertoire, the emitter state and a
random key.
unused: unused element, necessary to respect jax.lax.scan API.
_: unused element, necessary to respect jax.lax.scan API.
Returns:
The updated repertoire and emitter state, with a new random key and metrics.
Expand Down
2 changes: 1 addition & 1 deletion qdax/core/containers/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def top_1(data: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
return data, value, indice

def scannable_top_1(
carry: jnp.ndarray, unused: Any
carry: jnp.ndarray, _: Any
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
data = carry
data, value, indice = top_1(data)
Expand Down
2 changes: 1 addition & 1 deletion qdax/core/distributed_map_elites.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_distributed_update_fn(
@jax.jit
def _scan_update(
carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
unused: Any,
_: Any,
) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
"""Rewrites the update function in a way that makes it compatible with the
jax.lax.scan primitive."""
Expand Down
6 changes: 2 additions & 4 deletions qdax/core/emitters/cma_pool_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def batch_size(self) -> int:
Returns:
the batch size emitted by the emitter.
"""
return self._emitter.batch_size
return self._emitter.batch_size # type: ignore

@partial(jax.jit, static_argnames=("self",))
def init(
Expand All @@ -69,9 +69,7 @@ def init(
The initial state of the emitter.
"""

def scan_emitter_init(
carry: RNGKey, unused: Any
) -> Tuple[RNGKey, CMAEmitterState]:
def scan_emitter_init(carry: RNGKey, _: Any) -> Tuple[RNGKey, CMAEmitterState]:
key = carry
key, subkey = jax.random.split(key)
emitter_state = self._emitter.init(
Expand Down
4 changes: 2 additions & 2 deletions qdax/core/emitters/qpg_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def state_update(
emitter_state = emitter_state.replace(replay_buffer=replay_buffer)

def scan_train_critics(
carry: QualityPGEmitterState, unused: Any
carry: QualityPGEmitterState, _: Any
) -> Tuple[QualityPGEmitterState, Any]:
emitter_state = carry
new_emitter_state = self._train_critics(emitter_state)
Expand Down Expand Up @@ -501,7 +501,7 @@ def _mutation_function_pg(

def scan_train_policy(
carry: Tuple[QualityPGEmitterState, Genotype, optax.OptState],
unused: Any,
_: Any,
) -> Tuple[Tuple[QualityPGEmitterState, Genotype, optax.OptState], Any]:
emitter_state, policy_params, policy_optimizer_state = carry
(
Expand Down
4 changes: 2 additions & 2 deletions qdax/core/map_elites.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,15 @@ def update(
def scan_update(
self,
carry: Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey],
unused: Any,
_: Any,
) -> Tuple[Tuple[MapElitesRepertoire, Optional[EmitterState], RNGKey], Metrics]:
"""Rewrites the update function in a way that makes it compatible with the
jax.lax.scan primitive.
Args:
carry: a tuple containing the repertoire, the emitter state and a
random key.
unused: unused element, necessary to respect jax.lax.scan API.
_: unused element, necessary to respect jax.lax.scan API.
Returns:
The updated repertoire and emitter state, with a new random key and metrics.
Expand Down
2 changes: 1 addition & 1 deletion qdax/utils/uncertainty_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def _perform_reevaluation(

def _sampling_scan(
key: RNGKey,
unused: Tuple[()],
_: Tuple[()],
) -> Tuple[Tuple[RNGKey], Tuple[Fitness, Descriptor, ExtraScores]]:
key, subkey = jax.random.split(key)
(
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/mees_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict:
repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey)

@jax.jit
def update_scan_fn(carry: Any, unused: Any) -> Any:
def update_scan_fn(carry: Any, _: Any) -> Any:
# iterate over grid
repertoire, emitter_state, key = carry
key, subkey = jax.random.split(key)
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/pgame_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict:
repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey)

@jax.jit
def update_scan_fn(carry: Any, unused: Any) -> Any:
def update_scan_fn(carry: Any, _: Any) -> Any:
# iterate over grid
repertoire, emitter_state, key = carry
key, subkey = jax.random.split(key)
Expand Down
2 changes: 1 addition & 1 deletion tests/baselines_test/qdpg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def metrics_function(repertoire: MapElitesRepertoire) -> Dict:
repertoire, emitter_state = map_elites.init(init_variables, centroids, subkey)

@jax.jit
def update_scan_fn(carry: Any, unused: Any) -> Any:
def update_scan_fn(carry: Any, _: Any) -> Any:
# iterate over grid
repertoire, emitter_state, key = carry
key, subkey = jax.random.split(key)
Expand Down

0 comments on commit 83d4201

Please sign in to comment.