
Commit 07743f6

jinPrelude and khkim authored
Add R2D1 agents (#248)
* Add R2D1 DQNAgent
  Change grandiosely used functions
  Fix zero-padding & torch contiguous
  Fix zero-padding & Change indices sampling function
  Change hyperparameters
  Remove redundant codes
  Add CNN compatibility to R2D1Agent
  Remove redundant code
  Implement rlpyt forward style
  Add previous_action & previous_reward GRU input structure
  Fix error
  Fix prev_action bug & Use make_one_hot function
  Fix error
  Update descriptions & move leading_dims functions to helper_functions.py
  Move valid_from_done from R2D1Loss to helper_functions.py
  Fix parameters of r2d1_iqn loss & agent
  Make GRUBrain compatible with c51
  Add R2D1C51Loss
  Add r2d1_c51 configs
  Fix priority > 0 assert error
  Change parameters
* Fix descriptions & Rename R2D1 to R2D1DQN
  Change parameters
  Add total_step to wandb log
  Add upndown env & configs
  Fix test score
  Fix test score sum to mean
  Add total step to recurrent dqn_agent
  Fix test log position
  Add framestack argument
  Remove upndown environment
  Fix no_framestack argument
  Add r2d1 resnet configs
  Delete lunarlander iqn & Fix R2D1C51 lunarlander config description
  Fix configs
  Change total_step count startpoint after warmup
  Change test startpoint
  Fix epsilon decay
  Change r2d1 agent epsilon_decay
  Fix several issues commented
* Rebase commit to a302479
* Rebase commit to 990c78a
* Fix several issues commented
* Delete PrioritizedRecurrentReplayBuffer
* Resolve issues commented
* Fix issues commented
* Change R2D1Learner parent class & Add descriptions
* Fix issues commented
* Change descriptions
* Delete unnecessary configs
* Change descriptions due to the length limit
* Fix several issues commented
* Fix contiguous issues
* Add __init__.py in recurrent
* Rebase commit to 815a1ca
* Use no grad tensor to select action
* Modify documentation
* Add torch.no_grad to select action in other class
* Add R2D1DQN ResNet config
* Fix R2D1DQN ResNet config
* Merge recurrent_replay_buffer into replay_buffer
* Add R2D1 on readme
* Remove off-framestack explanation
* Change r2d1 configs' framestack 1 to 4

Co-authored-by: khkim <kh.kim@medipixel.io>
1 parent 815a1ca commit 07743f6

32 files changed, +1651 −78 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,7 @@ This project follows the [all-contributors](https://github.com/all-contributors/
 7. [Rainbow DQN](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn)
 8. [Rainbow IQN (without DuelingNet)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn) - DuelingNet [degrades performance](https://github.com/medipixel/rl_algorithms/pull/137)
 9. Rainbow IQN (with [ResNet](https://github.com/medipixel/rl_algorithms/blob/master/rl_algorithms/common/networks/backbones/resnet.py))
+10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent/dqn_agent.py)


 ## Performance

@@ -204,6 +205,7 @@ python <run-file> -h
   - Start rendering after the number of episodes.
 - `--load-from <save-file-path>`
   - Load the saved models and optimizers at the beginning.
+

 #### Show feature map with Grad-CAM
 You can show a feature map that the trained agent extracts using **[Grad-CAM (Gradient-weighted Class Activation Mapping)](https://arxiv.org/pdf/1610.02391.pdf)**. Grad-CAM combines feature maps using the gradient signal and produces a coarse localization map of the important regions in the image. You can use it by adding a [Grad-CAM config](https://github.com/medipixel/rl_algorithms/blob/master/configs/pong_no_frameskip_v4/dqn.py#L39) and the `--grad-cam` flag when you run. For example:
@@ -249,3 +251,4 @@ This won't be frequently updated.
 16. [W. Dabney et al., "Implicit Quantile Networks for Distributional Reinforcement Learning." arXiv preprint arXiv:1806.06923, 2018.](https://arxiv.org/pdf/1806.06923.pdf)
 17. [Ramprasaath R. Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." arXiv preprint arXiv:1610.02391, 2016.](https://arxiv.org/pdf/1610.02391.pdf)
 18. [Kaiming He et al., "Deep Residual Learning for Image Recognition." arXiv preprint arXiv:1512.03385, 2015.](https://arxiv.org/pdf/1512.03385)
+19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations, 2019.](https://openreview.net/forum?id=r1lyTjAqYX)

configs/lunarlander_v2/r2d1.py

Lines changed: 52 additions & 0 deletions
"""Config for R2D1 on LunarLander-v2.

- Author: Euijin Jeong
- Contact: euijin.jeong@medipixel.io
"""
from rl_algorithms.common.helper_functions import identity

agent = dict(
    type="R2D1Agent",
    hyper_params=dict(
        gamma=0.99,
        tau=5e-3,
        buffer_size=int(1e4),  # openai baselines: int(1e4)
        batch_size=64,  # openai baselines: 32
        update_starts_from=int(1e3),  # openai baselines: int(1e4)
        multiple_update=1,  # multiple learning updates
        train_freq=1,  # in openai baselines, train_freq = 4
        gradient_clip=10.0,  # dueling: 10.0
        n_step=3,
        w_n_step=1.0,
        w_q_reg=0.0,
        per_alpha=0.6,  # openai baselines: 0.6
        per_beta=0.4,
        per_eps=1e-6,
        # R2D1
        sequence_size=32,
        overlap_size=16,
        loss_type=dict(type="R2D1C51Loss"),
        # Epsilon Greedy
        max_epsilon=1.0,
        min_epsilon=0.01,  # openai baselines: 0.01
        epsilon_decay=2e-5,  # openai baselines: 1e-7 / 1e-1
    ),
    learner_cfg=dict(
        type="R2D1Learner",
        backbone=dict(),
        gru=dict(rnn_hidden_size=64, burn_in_step=16),
        head=dict(
            type="C51DuelingMLP",
            configs=dict(
                hidden_sizes=[128, 64],
                v_min=-300,
                v_max=300,
                atom_size=51,
                output_activation=identity,
                # NoisyNet
                use_noisy_net=False,
            ),
        ),
        optim_cfg=dict(lr_dqn=1e-4, weight_decay=1e-7, adam_eps=1e-8),
    ),
)
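A note on the Epsilon Greedy block above: this diff does not show how R2D1Agent applies `max_epsilon`, `min_epsilon`, and `epsilon_decay`, so the snippet below is only an illustrative sketch of the per-step linear annealing pattern these three values suggest, not the repository's actual schedule.

# Illustrative sketch only (assumed linear annealing; not taken from this commit).
def decay_epsilon(epsilon: float, max_eps: float, min_eps: float, decay: float) -> float:
    """Move epsilon toward min_eps by a fixed fraction of the epsilon range per step."""
    return max(min_eps, epsilon - (max_eps - min_eps) * decay)

# With max_epsilon=1.0, min_epsilon=0.01 and epsilon_decay=2e-5 as in the config above,
# such a schedule would hit its minimum after about 50,000 environment steps.
epsilon = 1.0
for _ in range(60_000):
    epsilon = decay_epsilon(epsilon, max_eps=1.0, min_eps=0.01, decay=2e-5)
print(epsilon)  # 0.01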

configs/pong_no_frameskip_v4/dqn.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""Config for DQN on Pong-No_FrameSkip-v4.
+"""Config for DQN(IQN) on Pong-No_FrameSkip-v4.

 - Author: Kyunghwan Kim
 - Contact: kh.kim@medipixel.io

configs/pong_no_frameskip_v4/dqn_resnet.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""Config for DQN on Pong-No_FrameSkip-v4.
+"""Config for DQN(IQN with ResNet) on Pong-No_FrameSkip-v4.

 - Author: Kyunghwan Kim
 - Contact: kh.kim@medipixel.io

configs/pong_no_frameskip_v4/r2d1.py

Lines changed: 65 additions & 0 deletions
"""Config for R2D1 on PongNoFrameskip-v4.

- Author: Euijin Jeong
- Contact: euijin.jeong@medipixel.io
"""
from rl_algorithms.common.helper_functions import identity

agent = dict(
    type="R2D1Agent",
    hyper_params=dict(
        gamma=0.99,
        tau=5e-3,
        buffer_size=int(4e3),  # openai baselines: int(1e4)
        batch_size=32,  # openai baselines: 32
        update_starts_from=int(4e3),  # openai baselines: int(1e4)
        multiple_update=1,  # multiple learning updates
        train_freq=4,  # in openai baselines, train_freq = 4
        gradient_clip=10.0,  # dueling: 10.0
        n_step=5,
        w_n_step=1.0,
        w_q_reg=0.0,
        per_alpha=0.6,  # openai baselines: 0.6
        per_beta=0.4,
        per_eps=1e-6,
        # R2D1
        sequence_size=20,
        overlap_size=10,
        loss_type=dict(type="R2D1DQNLoss"),
        # Epsilon Greedy
        max_epsilon=1.0,
        min_epsilon=0.01,  # openai baselines: 0.01
        epsilon_decay=3e-6,  # openai baselines: 1e-7 / 1e-1
        # grad_cam
        grad_cam_layer_list=[
            "backbone.cnn.cnn_0.cnn",
            "backbone.cnn.cnn_1.cnn",
            "backbone.cnn.cnn_2.cnn",
        ],
    ),
    learner_cfg=dict(
        type="R2D1Learner",
        backbone=dict(
            type="CNN",
            configs=dict(
                input_sizes=[4, 32, 64],
                output_sizes=[32, 64, 64],
                kernel_sizes=[8, 4, 3],
                strides=[4, 2, 1],
                paddings=[1, 0, 0],
            ),
        ),
        gru=dict(rnn_hidden_size=512, burn_in_step=10),
        head=dict(
            type="DuelingMLP",
            configs=dict(
                hidden_sizes=[512], use_noisy_net=False, output_activation=identity,
            ),
        ),
        optim_cfg=dict(
            lr_dqn=1e-4,  # dueling: 6.25e-5, openai baselines: 1e-4
            weight_decay=0.0,  # this makes saturation in cnn weights
            adam_eps=1e-8,  # rainbow: 1.5e-4, openai baselines: 1e-8
        ),
    ),
)
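For readers unfamiliar with R2D1's `sequence_size`, `overlap_size`, and `burn_in_step` (set to 20, 10, and 10 above): stored transitions are grouped into fixed-length sequences that overlap, and the first `burn_in_step` steps of each sampled sequence are used only to warm up the GRU hidden state before the loss is computed. The sketch below is a conceptual illustration of that layout, not the replay-buffer code added by this commit.

# Conceptual sketch of overlapping R2D1 sequences (not this repository's buffer code).
from typing import List

def split_into_sequences(episode_len: int, sequence_size: int, overlap_size: int) -> List[range]:
    """Index ranges of fixed-length sequences that overlap by `overlap_size` steps."""
    stride = sequence_size - overlap_size
    starts = range(0, max(episode_len - sequence_size, 0) + 1, stride)
    return [range(s, s + sequence_size) for s in starts]

# With sequence_size=20 and overlap_size=10, a 60-step episode yields sequences
# starting every 10 steps; with burn_in_step=10, only the last 10 steps of each
# sequence would contribute to the loss.
for seq in split_into_sequences(episode_len=60, sequence_size=20, overlap_size=10):
    print(seq.start, seq.stop)  # prints 0 20, 10 30, 20 40, 30 50, 40 60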
configs/pong_no_frameskip_v4/r2d1_resnet.py

Lines changed: 69 additions & 0 deletions
"""Config for R2D1DQN on PongNoFrameskip-v4.
- Author: Kyunghwan Kim, Euijin Jeong
- Contact: kh.kim@medipixel.io, euijin.jeong@medipixel.io
"""
from rl_algorithms.common.helper_functions import identity

agent = dict(
    type="R2D1Agent",
    hyper_params=dict(
        gamma=0.99,
        tau=5e-3,
        buffer_size=int(4e3),  # openai baselines: int(1e4)
        batch_size=16,  # openai baselines: 32
        update_starts_from=int(4e3),  # openai baselines: int(1e4)
        multiple_update=1,  # multiple learning updates
        train_freq=4,  # in openai baselines, train_freq = 4
        gradient_clip=10.0,  # dueling: 10.0
        n_step=5,
        w_n_step=1.0,
        w_q_reg=0.0,
        per_alpha=0.6,  # openai baselines: 0.6
        per_beta=0.4,
        per_eps=1e-6,
        # R2D1
        sequence_size=20,
        overlap_size=10,
        loss_type=dict(type="R2D1DQNLoss"),
        # Epsilon Greedy
        max_epsilon=1.0,
        min_epsilon=0.01,  # openai baselines: 0.01
        epsilon_decay=3e-6,  # openai baselines: 1e-7 / 1e-1
        # grad_cam
        grad_cam_layer_list=[
            "backbone.layer1.0.conv2",
            "backbone.layer2.0.shortcut.0",
            "backbone.layer3.0.shortcut.0",
            "backbone.layer4.0.shortcut.0",
            "backbone.conv_out",
        ],
    ),
    learner_cfg=dict(
        type="R2D1Learner",
        backbone=dict(
            type="ResNet",
            configs=dict(
                use_bottleneck=False,
                num_blocks=[1, 1, 1, 1],
                block_output_sizes=[32, 32, 64, 64],
                block_strides=[1, 2, 2, 2],
                first_input_size=4,
                first_output_size=32,
                expansion=1,
                channel_compression=4,  # compression ratio
            ),
        ),
        gru=dict(rnn_hidden_size=512, burn_in_step=10),
        head=dict(
            type="DuelingMLP",
            configs=dict(
                hidden_sizes=[512], use_noisy_net=False, output_activation=identity,
            ),
        ),
        optim_cfg=dict(
            lr_dqn=1e-4,  # dueling: 6.25e-5, openai baselines: 1e-4
            weight_decay=0.0,  # this makes saturation in cnn weights
            adam_eps=1e-8,  # rainbow: 1.5e-4, openai baselines: 1e-8
        ),
    ),
)
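The `grad_cam_layer_list` entries above name nested submodules by dotted attribute/index paths (e.g. `backbone.layer2.0.shortcut.0`). How this commit's Grad-CAM code consumes them is not shown here; the following is only a generic sketch of how such dotted names can be resolved on a `torch.nn.Module` (recent PyTorch also provides `nn.Module.get_submodule` for the same purpose).

# Generic helper (assumption, not this repository's Grad-CAM implementation).
import torch.nn as nn

def resolve_layer(model: nn.Module, dotted_name: str) -> nn.Module:
    """Walk a dotted path like 'backbone.layer2.0.shortcut.0' down to a submodule."""
    module = model
    for part in dotted_name.split("."):
        # Numeric parts index into containers such as nn.Sequential; others are attributes.
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module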

rl_algorithms/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,9 @@
 from .fd.sac_learner import SACfDLearner
 from .ppo.agent import PPOAgent
 from .ppo.learner import PPOLearner
+from .recurrent.dqn_agent import R2D1Agent
+from .recurrent.learner import R2D1Learner
+from .recurrent.losses import R2D1C51Loss, R2D1DQNLoss, R2D1IQNLoss
 from .registry import build_agent, build_her
 from .sac.agent import SACAgent
 from .sac.learner import SACLearner
@@ -33,6 +36,7 @@
     "DQNAgent",
     "DDPGfDAgent",
     "DQfDAgent",
+    "R2D1Agent",
     "SACfDAgent",
     "PPOAgent",
     "SACAgent",
@@ -48,6 +52,7 @@
     "PPOLearner",
     "SACLearner",
     "TD3Learner",
+    "R2D1Learner",
     "LunarLanderContinuousHER",
     "ReacherHER",
     "build_agent",
@@ -57,4 +62,7 @@
     "IQNLoss",
     "C51Loss",
     "DQNLoss",
+    "R2D1IQNLoss",
+    "R2D1C51Loss",
+    "R2D1DQNLoss",
 ]
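With these exports in place, the new classes become importable from the package root, e.g.:

# Minimal usage of the new package-level exports added above.
from rl_algorithms import R2D1Agent, R2D1Learner, R2D1C51Loss, R2D1DQNLoss, R2D1IQNLoss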

rl_algorithms/common/abstract/agent.py

Lines changed: 11 additions & 2 deletions
@@ -51,6 +51,7 @@ def __init__(
         self.log_cfg = log_cfg
         self.log_cfg.env_name = env.spec.id if env.spec is not None else env.name

+        self.total_step = 0
         self.learner = None

         if isinstance(env.action_space, Discrete):
@@ -83,6 +84,7 @@ def set_wandb(self):
             name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}",
         )
         wandb.config.update(vars(self.args))
+        wandb.config.update(self.hyper_params)
         shutil.copy(self.args.cfg_path, os.path.join(wandb.run.dir, "config.py"))

     def interim_test(self):
@@ -122,6 +124,7 @@ def _test(self, interim_test: bool = False):
         else:
             test_num = self.args.episode_num

+        score_list = []
         for i_episode in range(test_num):
             state = self.env.reset()
             done = False
@@ -142,9 +145,15 @@
             print(
                 "[INFO] test %d\tstep: %d\ttotal score: %d" % (i_episode, step, score)
             )
+            score_list.append(score)

-            if self.args.log:
-                wandb.log({"test score": score})
+        if self.args.log:
+            wandb.log(
+                {
+                    "test score": round(sum(score_list) / len(score_list), 2),
+                    "test total step": self.total_step,
+                }
+            )

     def test_with_gradcam(self):
         """Test agent with Grad-CAM."""
