diff --git a/cleanrl/ppo_atari_envpool_xla_jax_scan.py b/cleanrl/ppo_atari_envpool_xla_jax_scan.py
index 9cbbe67f..7f374cef 100644
--- a/cleanrl/ppo_atari_envpool_xla_jax_scan.py
+++ b/cleanrl/ppo_atari_envpool_xla_jax_scan.py
@@ -323,8 +323,11 @@ def compute_gae(
         _, advantages = jax.lax.scan(
             compute_gae_once, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True
         )
-        storage = storage.replace(advantages=advantages)
-        storage = storage.replace(returns=storage.advantages + storage.values)
+        storage = storage.replace(
+            advantages=advantages,
+            # Use the local `advantages`: `storage.advantages` still holds the pre-scan value here.
+            returns=advantages + storage.values,
+        )
         return storage
 
     def ppo_loss(params, x, a, logp, mb_advantages, mb_returns):
diff --git a/tests/test_jax_compute_gae.py b/tests/test_jax_compute_gae.py
index de0763fa..9f4f0f73 100644
--- a/tests/test_jax_compute_gae.py
+++ b/tests/test_jax_compute_gae.py
@@ -39,8 +39,11 @@ def compute_gae_scan(
         _, advantages = jax.lax.scan(
             compute_gae_once_fn, advantages, (dones[1:], values[1:], values[:-1], storage.rewards), reverse=True
         )
-        storage = storage.replace(advantages=advantages)
-        storage = storage.replace(returns=storage.advantages + storage.values)
+        storage = storage.replace(
+            advantages=advantages,
+            # Use the local `advantages`: `storage.advantages` still holds the pre-scan value here.
+            returns=advantages + storage.values,
+        )
         return storage
 
     def compute_gae_python_loop(