Skip to content

Commit 9e897ad

Browse files
cyoon1729khkimMrSyee
authored
Incorporate distributed RL framework, Ape-X and Ape-X DQN (#246)
* Take context as init_communication input; all processes share the same context. * implement abstract classes for distributed and ApeX Learner wrapper * Implement params2numpy method that loads torch state_dict as array of np.ndarray. * add __init__ * implement worker as abstract class, not wrapper base class * Change apex_learner file name to learner. * Implement Ape-X worker and learner base classes * implement Ape-X DQN worker * Create base class for distributed architectures * Implement and test Ape-X DQN working on Pong * Accept current change (master) for PongNoFrameskip-v4 dqn config * Make env_info more explicit in run_pong script (accept incoming change) * Make learner return cpu state_dict (accept incoming change) * Fix minor errors * Implement ApeXWorker as a wrapper ApeXWorkerWrapper Implement Logger and test wandb functionality Add worker and logger render in argparse Implement load_param() method in logger and worker * Move num_workers to hyperparams, and add logger_interval to hyperparams. * Implement safe exit condition for all ray actors. * Change _init_communication -> init_communication and call outside of __init__ for all ApeX actors Implement test() in distributed architectures (load from checkpoint and run logger test()) * * Add documentation * Move collect_data from worker class to ApeX Wrapper * Change hyperparameters around * Add worker-verbose as argparse flag * * Move num_worker to hyper_param cfg * * Add author * Add separate integration test for ApeX * Add integration test flag to pong * argparse integration test flag store_false->store_true * Change default config to dqn. * * Log worker scores per update step on Wandb. * Modify integration test * Modify apex buffer config for integration test * Change distributed directory structure * Add documentation * Modify readme.md * Modify readme.md * Add Ape-X to README. * Add description about args flags for distributed training. Co-authored-by: khkim <kh.kim@medipixel.io> Co-authored-by: Kyunghwan Kim <khsyee@gmail.com>
1 parent 07743f6 commit 9e897ad

28 files changed

+1538
-49
lines changed

LICENSE.md

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# Our repository
2-
MIT License
1+
The MIT License (MIT)
32

43
Copyright (c) 2019 Medipixel
54

@@ -20,16 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
2019
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
2120
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2221
SOFTWARE.
23-
24-
# Mujoco models
25-
This work is derived from [MuJuCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license:
26-
```
27-
This file is part of MuJoCo.
28-
Copyright 2009-2015 Roboti LLC.
29-
Mujoco :: Advanced physics simulation engine
30-
Source : www.roboti.us
31-
Version : 1.31
32-
Released : 23Apr16
33-
Author :: Vikash Kumar
34-
Contacts : kumar@roboti.us
35-
```

README.md

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
<p align="center">
22
<img src="https://user-images.githubusercontent.com/17582508/52845370-4a930200-314a-11e9-9889-e00007043872.jpg" align="center">
33

4-
[![CircleCI](https://circleci.com/gh/circleci/circleci-docs.svg?style=shield)](https://circleci.com/gh/medipixel)
54
[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/medipixel/rl_algorithms.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/medipixel/rl_algorithms/context:python)
6-
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
5+
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)
6+
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
77
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
88
<!-- ALL-CONTRIBUTORS-BADGE:END -->
99

@@ -63,8 +63,8 @@ This project follows the [all-contributors](https://github.com/all-contributors/
6363
7. [Rainbow DQN](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn)
6464
8. [Rainbow IQN (without DuelingNet)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn) - DuelingNet [degrades performance](https://github.com/medipixel/rl_algorithms/pull/137)
6565
9. Rainbow IQN (with [ResNet](https://github.com/medipixel/rl_algorithms/blob/master/rl_algorithms/common/networks/backbones/resnet.py))
66-
10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent/dqn_agent.py)
67-
66+
10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent)
67+
11. [Distributed Pioritized Experience Replay (Ape-X)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/common/distributed)
6868

6969
## Performance
7070

@@ -205,6 +205,16 @@ python <run-file> -h
205205
- Start rendering after the number of episodes.
206206
- `--load-from <save-file-path>`
207207
- Load the saved models and optimizers at the beginning.
208+
209+
#### Arguments for distributed training in run-files
210+
- `--max-episode-steps <int>`
211+
- Set maximum update step for learner as a stopping criterion for training loop. If the number is less than or equal to 0, it uses the default maximum step number of the environment.
212+
- `--off-worker-render`
213+
- Turn off rendering of individual workers.
214+
- `--off-logger-render`
215+
- Turn off rendering of logger tests.
216+
- `--worker-verbose`
217+
- Turn on printing episode run info for individual workers
208218
209219
210220
#### Show feature map with Grad-CAM
@@ -252,3 +262,4 @@ This won't be frequently updated.
252262
17. [Ramprasaath R. Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." arXiv preprint arXiv:1610.02391, 2016.](https://arxiv.org/pdf/1610.02391.pdf)
253263
18. [Kaiming He et al., "Deep Residual Learning for Image Recognition." arXiv preprint arXiv:1512.03385, 2015.](https://arxiv.org/pdf/1512.03385)
254264
19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations https://openreview.net/forum?id=r1lyTjAqYX, 2019.](https://openreview.net/forum?id=r1lyTjAqYX)
265+
20. [Horgan et al., "Distributed Prioritized Experience Replay." in International Conference on Learning Representations, 2018](https://arxiv.org/pdf/1803.00933.pdf)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Config for ApeX-DQN on Pong-No_FrameSkip-v4.
2+
3+
- Author: Chris Yoon
4+
- Contact: chris.yoon@medipixel.io
5+
"""
6+
7+
from rl_algorithms.common.helper_functions import identity
8+
9+
agent = dict(
10+
type="ApeX",
11+
hyper_params=dict(
12+
gamma=0.99,
13+
tau=5e-3,
14+
buffer_size=int(2.5e5), # openai baselines: int(1e4)
15+
batch_size=512, # openai baselines: 32
16+
update_starts_from=int(1e5), # openai baselines: int(1e4)
17+
multiple_update=1, # multiple learning updates
18+
train_freq=1, # in openai baselines, train_freq = 4
19+
gradient_clip=10.0, # dueling: 10.0
20+
n_step=5,
21+
w_n_step=1.0,
22+
w_q_reg=0.0,
23+
per_alpha=0.6, # openai baselines: 0.6
24+
per_beta=0.4,
25+
per_eps=1e-6,
26+
loss_type=dict(type="DQNLoss"),
27+
# Epsilon Greedy
28+
max_epsilon=1.0,
29+
min_epsilon=0.1, # openai baselines: 0.01
30+
epsilon_decay=1e-6, # openai baselines: 1e-7 / 1e-1
31+
# grad_cam
32+
grad_cam_layer_list=[
33+
"backbone.cnn.cnn_0.cnn",
34+
"backbone.cnn.cnn_1.cnn",
35+
"backbone.cnn.cnn_2.cnn",
36+
],
37+
num_workers=4,
38+
local_buffer_max_size=1000,
39+
worker_update_interval=50,
40+
logger_interval=2000,
41+
),
42+
learner_cfg=dict(
43+
type="DQNLearner",
44+
device="cuda",
45+
backbone=dict(
46+
type="CNN",
47+
configs=dict(
48+
input_sizes=[4, 32, 64],
49+
output_sizes=[32, 64, 64],
50+
kernel_sizes=[8, 4, 3],
51+
strides=[4, 2, 1],
52+
paddings=[1, 0, 0],
53+
),
54+
),
55+
head=dict(
56+
type="DuelingMLP",
57+
configs=dict(
58+
use_noisy_net=False, hidden_sizes=[512], output_activation=identity
59+
),
60+
),
61+
optim_cfg=dict(
62+
lr_dqn=0.0003, # dueling: 6.25e-5, openai baselines: 1e-4
63+
weight_decay=0.0, # this makes saturation in cnn weights
64+
adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8
65+
),
66+
),
67+
worker_cfg=dict(type="DQNWorker", device="cpu",),
68+
logger_cfg=dict(type="DQNLogger",),
69+
comm_cfg=dict(
70+
learner_buffer_port=6554,
71+
learner_worker_port=6555,
72+
worker_buffer_port=6556,
73+
learner_logger_port=6557,
74+
send_batch_port=6558,
75+
priorities_port=6559,
76+
),
77+
)

requirements.txt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,17 @@ cloudpickle
99
opencv-python
1010
wandb
1111
addict
12-
1312
# mujoco
1413

14+
# for distributed learning
15+
ray
16+
ray[debug]
17+
pyzmq
18+
pyarrow
19+
20+
# for log
21+
matplotlib
22+
plotly
23+
1524
setuptools
1625
wheel

rl_algorithms/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
from .bc.her import LunarLanderContinuousHER, ReacherHER
66
from .bc.sac_agent import BCSACAgent
77
from .bc.sac_learner import BCSACLearner
8+
from .common.distributed.apex import ApeX
89
from .common.networks.backbones import CNN, ResNet
910
from .ddpg.agent import DDPGAgent
1011
from .ddpg.learner import DDPGLearner
1112
from .dqn.agent import DQNAgent
1213
from .dqn.learner import DQNLearner
14+
from .dqn.logger import DQNLogger
1315
from .dqn.losses import C51Loss, DQNLoss, IQNLoss
16+
from .dqn.worker import DQNWorker
1417
from .fd.ddpg_agent import DDPGfDAgent
1518
from .fd.ddpg_learner import DDPGfDLearner
1619
from .fd.dqn_agent import DQfDAgent
@@ -65,4 +68,7 @@
6568
"R2D1IQNLoss",
6669
"R2D1C51Loss",
6770
"R2D1DQNLoss",
71+
"ApeX",
72+
"DQNWorker",
73+
"DQNLogger",
6874
]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Abstract class for distributed architectures.
2+
3+
- Author: Chris Yoon
4+
- Contact: chris.yoon@medipixel.io
5+
"""
6+
7+
from abc import ABC, abstractmethod
8+
9+
10+
class Architecture(ABC):
11+
"""Abstract class for distributed architectures"""
12+
13+
@abstractmethod
14+
def _spawn(self):
15+
pass
16+
17+
@abstractmethod
18+
def train(self):
19+
pass
20+
21+
@abstractmethod
22+
def test(self):
23+
pass

0 commit comments

Comments
 (0)