Add R2D1 agents (#248)
* Add R2D1 DQNAgent

Change grandiosely used functions

Fix zero-padding & torch contiguous

Fix zero-padding & Change indices sampling function

Change hyperparameters

Remove redundant code

Add CNN compatibility to R2D1Agent

Remove redundant code

Implement rlpyt forward style

Add previous_action & previous_reward GRU input structure

Fix error

Fix prev_action bug & Use make_one_hot function

Fix error

Update descriptions & move leading_dims functions to helper_functions.py

Move valid_from_done from R2D1Loss to helper_functions.py

Fix parameters

Add r2d1_iqn loss & agent

Make GRUBrain compatible with C51

Add R2D1C51Loss

Add r2d1_c51 configs

Fix priority > 0 assert error

Change parameters

* Fix descriptions & Rename R2D1 to R2D1DQN

Change parameters

Add total_step to wandb log

Add upndown env & configs

Fix test score

Fix test score sum to mean

Add total step to recurrent dqn_agent

Fix test log position

Add framestack argument

Remove upndown environment

Fix no_framestack argument

Add r2d1 resnet configs

Delete lunarlander iqn & Fix R2D1C51 lunarlander config description

Fix configs

Change total_step count startpoint after warmup

Change test startpoint

Fix epsilon decay

Change r2d1 agent epsilon_decay

Fix several issues commented

* Rebase commit to a302479

* Rebase commit to 990c78a

* Fix several issues commented

* Delete PrioritizedRecurrentReplayBuffer

* Resolved issues commented

* Fix issues commented

* Change R2D1Learner parent class & Add descriptions

* Fix issues commented

* Change descriptions

* Delete unnecessary configs

* Change descriptions due to the length limit

* Fix several issues commented

* Fix contiguous issues

* Add __init__.py in recurrent

* Rebase commit to 815a1ca

* Use no grad tensor to select action

* Modify documentation

* Add torch.no_grad to select action in other class

* Add R2D1DQN ResNet config

* Fix R2D1DQN ResNet config

* Merge recurrent_replay_buffer into replay_buffer

* Add R2D1 on readme

* Remove off-framestack explanation

* Change r2d1 configs' framestack 1 to 4

Co-authored-by: khkim <kh.kim@medipixel.io>
jinPrelude and khkim authored Jun 23, 2020
1 parent 815a1ca commit 07743f6
Showing 32 changed files with 1,651 additions and 78 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -63,6 +63,7 @@ This project follows the [all-contributors](https://github.com/all-contributors/
7. [Rainbow DQN](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn)
8. [Rainbow IQN (without DuelingNet)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn) - DuelingNet [degrades performance](https://github.com/medipixel/rl_algorithms/pull/137)
9. Rainbow IQN (with [ResNet](https://github.com/medipixel/rl_algorithms/blob/master/rl_algorithms/common/networks/backbones/resnet.py))
10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent/dqn_agent.py)


## Performance
@@ -204,6 +205,7 @@ python <run-file> -h
- Start rendering after the number of episodes.
- `--load-from <save-file-path>`
- Load the saved models and optimizers at the beginning.


#### Show feature map with Grad-CAM
You can show a feature map that the trained agent extracts using **[Grad-CAM (Gradient-weighted Class Activation Mapping)](https://arxiv.org/pdf/1610.02391.pdf)**. Grad-CAM combines feature maps using the gradient signal to produce a coarse localization map of the important regions in the image. You can use it by adding a [Grad-CAM config](https://github.com/medipixel/rl_algorithms/blob/master/configs/pong_no_frameskip_v4/dqn.py#L39) and the `--grad-cam` flag when you run. For example:
@@ -249,3 +251,4 @@ This won't be frequently updated.
16. [W. Dabney et al., "Implicit Quantile Networks for Distributional Reinforcement Learning." arXiv preprint arXiv:1806.06923, 2018.](https://arxiv.org/pdf/1806.06923.pdf)
17. [Ramprasaath R. Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." arXiv preprint arXiv:1610.02391, 2016.](https://arxiv.org/pdf/1610.02391.pdf)
18. [Kaiming He et al., "Deep Residual Learning for Image Recognition." arXiv preprint arXiv:1512.03385, 2015.](https://arxiv.org/pdf/1512.03385)
19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations (ICLR), 2019.](https://openreview.net/forum?id=r1lyTjAqYX)
52 changes: 52 additions & 0 deletions configs/lunarlander_v2/r2d1.py
@@ -0,0 +1,52 @@
"""Config for R2D1 on LunarLander-v2.
- Author: Euijin Jeong
- Contact: euijin.jeong@medipixel.io
"""
from rl_algorithms.common.helper_functions import identity

agent = dict(
type="R2D1Agent",
hyper_params=dict(
gamma=0.99,
tau=5e-3,
buffer_size=int(1e4), # openai baselines: int(1e4)
batch_size=64, # openai baselines: 32
update_starts_from=int(1e3), # openai baselines: int(1e4)
multiple_update=1, # multiple learning updates
train_freq=1, # in openai baselines, train_freq = 4
gradient_clip=10.0, # dueling: 10.0
n_step=3,
w_n_step=1.0,
w_q_reg=0.0,
per_alpha=0.6, # openai baselines: 0.6
per_beta=0.4,
per_eps=1e-6,
# R2D1
sequence_size=32,
overlap_size=16,
loss_type=dict(type="R2D1C51Loss"),
# Epsilon Greedy
max_epsilon=1.0,
min_epsilon=0.01, # openai baselines: 0.01
epsilon_decay=2e-5, # openai baselines: 1e-7 / 1e-1
),
learner_cfg=dict(
type="R2D1Learner",
backbone=dict(),
gru=dict(rnn_hidden_size=64, burn_in_step=16,),
head=dict(
type="C51DuelingMLP",
configs=dict(
hidden_sizes=[128, 64],
v_min=-300,
v_max=300,
atom_size=51,
output_activation=identity,
# NoisyNet
use_noisy_net=False,
),
),
optim_cfg=dict(lr_dqn=1e-4, weight_decay=1e-7, adam_eps=1e-8),
),
)
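The R2D1-specific keys above (`sequence_size`, `overlap_size`, and `burn_in_step` under `gru`) control how episodes are stored and replayed as overlapping, fixed-length step sequences. The buffer code is not part of this excerpt, so the following is only a rough sketch of that slicing under assumed names (`slice_episode`) and an assumed zero-padding scheme, not the repository's implementation:

```python
import numpy as np


def slice_episode(transitions, sequence_size=32, overlap_size=16):
    """Illustrative only: cut one episode into overlapping sequences.

    `transitions` is a list of per-step arrays; the short tail sequence
    is zero-padded to the fixed length, as the commit messages suggest.
    """
    stride = sequence_size - overlap_size
    sequences = []
    for start in range(0, max(len(transitions) - overlap_size, 1), stride):
        seq = transitions[start:start + sequence_size]
        pad = sequence_size - len(seq)
        if pad > 0:
            seq = seq + [np.zeros_like(seq[0])] * pad  # zero-padding
        sequences.append(seq)
    return sequences
```

With `sequence_size=32` and `overlap_size=16`, consecutive stored sequences share 16 steps, which is the overlap the config above asks for.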
2 changes: 1 addition & 1 deletion configs/pong_no_frameskip_v4/dqn.py
@@ -1,4 +1,4 @@
"""Config for DQN on Pong-No_FrameSkip-v4.
"""Config for DQN(IQN) on Pong-No_FrameSkip-v4.
- Author: Kyunghwan Kim
- Contact: kh.kim@medipixel.io
2 changes: 1 addition & 1 deletion configs/pong_no_frameskip_v4/dqn_resnet.py
@@ -1,4 +1,4 @@
"""Config for DQN on Pong-No_FrameSkip-v4.
"""Config for DQN(IQN with ResNet) on Pong-No_FrameSkip-v4.
- Author: Kyunghwan Kim
- Contact: kh.kim@medipixel.io
65 changes: 65 additions & 0 deletions configs/pong_no_frameskip_v4/r2d1.py
@@ -0,0 +1,65 @@
"""Config for R2D1 on PongNoFrameskip-v4.
- Author: Euijin Jeong
- Contact: euijin.jeong@medipixel.io
"""
from rl_algorithms.common.helper_functions import identity

agent = dict(
type="R2D1Agent",
hyper_params=dict(
gamma=0.99,
tau=5e-3,
buffer_size=int(4e3), # openai baselines: int(1e4)
batch_size=32, # openai baselines: 32
update_starts_from=int(4e3), # openai baselines: int(1e4)
multiple_update=1, # multiple learning updates
train_freq=4, # in openai baselines, train_freq = 4
gradient_clip=10.0, # dueling: 10.0
n_step=5,
w_n_step=1.0,
w_q_reg=0.0,
per_alpha=0.6, # openai baselines: 0.6
per_beta=0.4,
per_eps=1e-6,
# R2D1
sequence_size=20,
overlap_size=10,
loss_type=dict(type="R2D1DQNLoss"),
# Epsilon Greedy
max_epsilon=1.0,
min_epsilon=0.01, # openai baselines: 0.01
epsilon_decay=3e-6, # openai baselines: 1e-7 / 1e-1
# grad_cam
grad_cam_layer_list=[
"backbone.cnn.cnn_0.cnn",
"backbone.cnn.cnn_1.cnn",
"backbone.cnn.cnn_2.cnn",
],
),
learner_cfg=dict(
type="R2D1Learner",
backbone=dict(
type="CNN",
configs=dict(
input_sizes=[4, 32, 64],
output_sizes=[32, 64, 64],
kernel_sizes=[8, 4, 3],
strides=[4, 2, 1],
paddings=[1, 0, 0],
),
),
gru=dict(rnn_hidden_size=512, burn_in_step=10,),
head=dict(
type="DuelingMLP",
configs=dict(
hidden_sizes=[512], use_noisy_net=False, output_activation=identity,
),
),
optim_cfg=dict(
lr_dqn=1e-4, # dueling: 6.25e-5, openai baselines: 1e-4
weight_decay=0.0, # this makes saturation in cnn weights
adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8
),
),
)
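The commit history above mentions adding a previous_action and previous_reward input structure to the GRU and a make_one_hot helper. The GRUBrain forward code is not shown in this excerpt, so the following is only a hedged sketch of what that input construction typically looks like; the function name `build_gru_input` and the exact tensor layout are assumptions:

```python
import torch
import torch.nn.functional as F


def build_gru_input(features, prev_action, prev_reward, n_actions):
    """Illustrative only: concatenate backbone features with a one-hot of
    the previous action and the previous reward before feeding the GRU."""
    one_hot_action = F.one_hot(prev_action.long(), num_classes=n_actions).float()
    return torch.cat([features, one_hot_action, prev_reward.unsqueeze(-1)], dim=-1)
```

For a Pong-like setup with 6 discrete actions and the 512-dim CNN features configured above, the GRU would then see a 519-dimensional input per step.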
69 changes: 69 additions & 0 deletions configs/pong_no_frameskip_v4/r2d1_dqn_resnet.py
@@ -0,0 +1,69 @@
"""Config for R2D1DQN on PongNoFrameskip-v4.
- Author: Kyunghwan Kim, Euijin Jeong
- Contact: kh.kim@medipixel.io, euijin.jeong@medipixel.io
"""
from rl_algorithms.common.helper_functions import identity

agent = dict(
type="R2D1Agent",
hyper_params=dict(
gamma=0.99,
tau=5e-3,
buffer_size=int(4e3), # openai baselines: int(1e4)
batch_size=16, # openai baselines: 32
update_starts_from=int(4e3), # openai baselines: int(1e4)
multiple_update=1, # multiple learning updates
train_freq=4, # in openai baselines, train_freq = 4
gradient_clip=10.0, # dueling: 10.0
n_step=5,
w_n_step=1.0,
w_q_reg=0.0,
per_alpha=0.6, # openai baselines: 0.6
per_beta=0.4,
per_eps=1e-6,
# R2D1
sequence_size=20,
overlap_size=10,
loss_type=dict(type="R2D1DQNLoss"),
# Epsilon Greedy
max_epsilon=1.0,
min_epsilon=0.01, # openai baselines: 0.01
epsilon_decay=3e-6, # openai baselines: 1e-7 / 1e-1
# grad_cam
grad_cam_layer_list=[
"backbone.layer1.0.conv2",
"backbone.layer2.0.shortcut.0",
"backbone.layer3.0.shortcut.0",
"backbone.layer4.0.shortcut.0",
"backbone.conv_out",
],
),
learner_cfg=dict(
type="R2D1Learner",
backbone=dict(
type="ResNet",
configs=dict(
use_bottleneck=False,
num_blocks=[1, 1, 1, 1],
block_output_sizes=[32, 32, 64, 64],
block_strides=[1, 2, 2, 2],
first_input_size=4,
first_output_size=32,
expansion=1,
channel_compression=4, # compression ratio
),
),
gru=dict(rnn_hidden_size=512, burn_in_step=10,),
head=dict(
type="DuelingMLP",
configs=dict(
hidden_sizes=[512], use_noisy_net=False, output_activation=identity,
),
),
optim_cfg=dict(
lr_dqn=1e-4, # dueling: 6.25e-5, openai baselines: 1e-4
weight_decay=0.0, # this makes saturation in cnn weights
adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8
),
),
)
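Both Atari configs set `burn_in_step=10` in the `gru` block. In R2D1 (Kapturowski et al., 2019), the first steps of each replayed sequence only warm up the recurrent state, and the loss is computed on the remaining steps. The learner is not included in this excerpt, so the snippet below is just a sketch of that idea under assumed shapes and names, not the repository's code:

```python
import torch
import torch.nn as nn


def burn_in_hidden(gru: nn.GRU, seq_features: torch.Tensor, burn_in_step: int = 10):
    """Illustrative only: warm up the recurrent state on the first
    `burn_in_step` steps without tracking gradients, then unroll the rest.

    seq_features: (seq_len, batch, feature_dim) features for one stored sequence.
    """
    with torch.no_grad():
        _, hidden = gru(seq_features[:burn_in_step])
    # Gradients flow only through the remaining (seq_len - burn_in_step) steps.
    out, _ = gru(seq_features[burn_in_step:], hidden)
    return out, hidden
```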
8 changes: 8 additions & 0 deletions rl_algorithms/__init__.py
@@ -19,6 +19,9 @@
from .fd.sac_learner import SACfDLearner
from .ppo.agent import PPOAgent
from .ppo.learner import PPOLearner
from .recurrent.dqn_agent import R2D1Agent
from .recurrent.learner import R2D1Learner
from .recurrent.losses import R2D1C51Loss, R2D1DQNLoss, R2D1IQNLoss
from .registry import build_agent, build_her
from .sac.agent import SACAgent
from .sac.learner import SACLearner
@@ -33,6 +36,7 @@
    "DQNAgent",
    "DDPGfDAgent",
    "DQfDAgent",
    "R2D1Agent",
    "SACfDAgent",
    "PPOAgent",
    "SACAgent",
@@ -48,6 +52,7 @@
    "PPOLearner",
    "SACLearner",
    "TD3Learner",
    "R2D1Learner",
    "LunarLanderContinuousHER",
    "ReacherHER",
    "build_agent",
@@ -57,4 +62,7 @@
    "IQNLoss",
    "C51Loss",
    "DQNLoss",
    "R2D1IQNLoss",
    "R2D1C51Loss",
    "R2D1DQNLoss",
]
13 changes: 11 additions & 2 deletions rl_algorithms/common/abstract/agent.py
@@ -51,6 +51,7 @@ def __init__(
        self.log_cfg = log_cfg
        self.log_cfg.env_name = env.spec.id if env.spec is not None else env.name

        self.total_step = 0
        self.learner = None

        if isinstance(env.action_space, Discrete):
@@ -83,6 +84,7 @@ def set_wandb(self):
            name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}",
        )
        wandb.config.update(vars(self.args))
        wandb.config.update(self.hyper_params)
        shutil.copy(self.args.cfg_path, os.path.join(wandb.run.dir, "config.py"))

    def interim_test(self):
@@ -122,6 +124,7 @@ def _test(self, interim_test: bool = False):
        else:
            test_num = self.args.episode_num

        score_list = []
        for i_episode in range(test_num):
            state = self.env.reset()
            done = False
@@ -142,9 +145,15 @@
            print(
                "[INFO] test %d\tstep: %d\ttotal score: %d" % (i_episode, step, score)
            )
            score_list.append(score)

            if self.args.log:
                wandb.log({"test score": score})
        if self.args.log:
            wandb.log(
                {
                    "test score": round(sum(score_list) / len(score_list), 2),
                    "test total step": self.total_step,
                }
            )

    def test_with_gradcam(self):
        """Test agent with Grad-CAM."""