Skip to content

Commit

Permalink
TD3: fixed dimension of clipped_noise for target actions, added noise…
Browse files Browse the repository at this point in the history
… … (vwxyzjn#281)

* TD3: fixed dimension of clippednoise for target actions, added noise scaling based on action_scale

* minor refactor

* update benchmark script

* update results

Co-authored-by: Costa Huang <costa.huang@outlook.com>
  • Loading branch information
dosssman and vwxyzjn authored Oct 19, 2022
1 parent 423650a commit 331cb39
Show file tree
Hide file tree
Showing 9 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions benchmark/td3.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
poetry install --with mujoco,pybullet
python -c "import mujoco_py"
OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
WANDB_TAGS=$(git describe --tags) OMP_NUM_THREADS=1 xvfb-run -a python -m cleanrl_utils.benchmark \
--env-ids HalfCheetah-v2 Walker2d-v2 Hopper-v2 InvertedPendulum-v2 Humanoid-v2 Pusher-v2 \
--command "poetry run python cleanrl/td3_continuous_action.py --track --capture-video" \
--num-seeds 3 \
--workers 3
--workers 1

poetry install --with mujoco,jax
poetry run pip install --upgrade "jax[cuda]==0.3.14" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down
6 changes: 3 additions & 3 deletions cleanrl/td3_continuous_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ def forward(self, x):
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)
with torch.no_grad():
clipped_noise = (torch.randn_like(torch.Tensor(actions[0])) * args.policy_noise).clamp(
clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp(
-args.noise_clip, args.noise_clip
)
) * target_actor.action_scale

next_state_actions = (target_actor(data.next_observations) + clipped_noise.to(device)).clamp(
next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp(
envs.single_action_space.low[0], envs.single_action_space.high[0]
)
qf1_next_target = qf1_target(data.next_observations, next_state_actions)
Expand Down
12 changes: 6 additions & 6 deletions docs/rl-algorithms/td3.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ Below are the average episodic returns for [`td3_continuous_action.py`](https://

| Environment | [`td3_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/td3_continuous_action.py) | [`TD3.py`](https://github.com/sfujim/TD3/blob/master/TD3.py) (Fujimoto et al., 2018, Table 1)[^2] |
| ----------- | ----------- | ----------- |
| HalfCheetah | 9018.31 ± 1078.31 |9636.95 ± 859.065 |
| Walker2d | 4246.07 ± 1210.84 | 4682.82 ± 539.64 |
| Hopper | 3391.78 ± 232.21 | 3564.07 ± 114.74 |
| Humanoid | 4822.64 ± 321.85 | not available |
| Pusher | -42.24 ± 6.74 | not available |
| InvertedPendulum | 964.59 ± 43.91 | 1000.00 ± 0.00 |
| HalfCheetah | 9449.94 ± 1586.49 |9636.95 ± 859.065 |
| Walker2d | 3851.55 ± 335.29 | 4682.82 ± 539.64 |
| Hopper | 3162.21 ± 261.08 | 3564.07 ± 114.74 |
| Humanoid | 5011.05 ± 254.89 | not available |
| Pusher | -37.49 ± 10.22 | not available |
| InvertedPendulum | 996.81 ± 4.50 | 1000.00 ± 0.00 |



Expand Down
Binary file modified docs/rl-algorithms/td3/HalfCheetah-v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/rl-algorithms/td3/Hopper-v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/rl-algorithms/td3/Humanoid-v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/rl-algorithms/td3/InvertedPendulum-v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/rl-algorithms/td3/Pusher-v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/rl-algorithms/td3/Walker2d-v2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 331cb39

Please sign in to comment.