From 9e897adfe93600c1db85ce1a7e064064b025c2c3 Mon Sep 17 00:00:00 2001
From: Chris Yoon <33583101+cyoon1729@users.noreply.github.com>
Date: Tue, 23 Jun 2020 17:27:59 +0900
Subject: [PATCH] Incorporate distributed RL framework, Ape-X and Ape-X DQN
 (#246)

* Take context as init_communication input; all processes share the same context.
* implement abstract classes for distributed and ApeX Learner wrapper
* Implement params2numpy method that loads torch state_dict as array of np.ndarray.
* add __init__
* implement worker as abstract class, not wrapper base class
* Change apex_learner file name to learner.
* Implement Ape-X worker and learner base classes
* implement Ape-X DQN worker
* Create base class for distributed architectures
* Implement and test Ape-X DQN working on Pong
* Accept current change (master) for PongNoFrameskip-v4 dqn config
* Make env_info more explicit in run_pong script (accept incoming change)
* Make learner return cpu state_dict (accept incoming change)
* Fix minor errors
* Implement ApeXWorker as a wrapper ApeXWorkerWrapper
Implement Logger and test wandb functionality
Add worker and logger render in argparse
Implement load_param() method in logger and worker
* Move num_workers to hyperparams, and add logger_interval to hyperparams.
* Implement safe exit condition for all ray actors.
* Change _init_communication -> init_communication and call outside of __init__ for all ApeX actors
Implement test() in distributed architectures (load from checkpoint and run logger test())
* * Add documentation
* Move collect_data from worker class to ApeX Wrapper
* Change hyperparameters around
* Add worker-verbose as argparse flag
* * Move num_worker to hyper_param cfg
* * Add author
* Add separate integration test for ApeX
* Add integration test flag to pong
* argparse integration test flag store_false->store_true
* Change default config to dqn.
* * Log worker scores per update step on Wandb.
* Modify integration test
* Modify apex buffer config for integration test
* Change distributed directory structure
* Add documentation
* Modify readme.md
* Modify readme.md
* Add Ape-X to README.
* Add description about args flags for distributed training.
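For reference, a minimal way to launch training with the new architecture is to point the existing Pong run script at the Ape-X config added in this patch (flag names as introduced here; paths may differ in your checkout):

    python run_pong_no_frameskip_v4.py --cfg-path ./configs/pong_no_frameskip_v4/apex_dqn.py --log --max-update-step 100000

Testing reuses the same script with --test and --load-from pointing at a saved checkpoint; in that case only the logger process is spawned.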
Co-authored-by: khkim Co-authored-by: Kyunghwan Kim --- LICENSE.md | 16 +- README.md | 19 +- configs/pong_no_frameskip_v4/apex_dqn.py | 77 ++++++ requirements.txt | 11 +- rl_algorithms/__init__.py | 6 + rl_algorithms/common/abstract/architecture.py | 23 ++ .../common/abstract/distributed_logger.py | 239 ++++++++++++++++++ rl_algorithms/common/abstract/learner.py | 57 ++++- rl_algorithms/common/abstract/worker.py | 162 ++++++++++++ rl_algorithms/common/distributed/__init__.py | 0 rl_algorithms/common/distributed/apex.py | 150 +++++++++++ rl_algorithms/common/distributed/buffer.py | 101 ++++++++ rl_algorithms/common/distributed/learner.py | 113 +++++++++ rl_algorithms/common/distributed/worker.py | 147 +++++++++++ rl_algorithms/common/helper_functions.py | 37 ++- rl_algorithms/dqn/agent.py | 3 +- rl_algorithms/dqn/learner.py | 16 +- rl_algorithms/dqn/logger.py | 74 ++++++ rl_algorithms/dqn/worker.py | 117 +++++++++ rl_algorithms/registry.py | 15 +- rl_algorithms/utils/__init__.py | 4 +- rl_algorithms/utils/registry.py | 36 +++ run_lunarlander_continuous_v2.py | 15 +- run_lunarlander_v2.py | 13 +- run_pong_no_frameskip_v4.py | 48 +++- run_reacher_v2.py | 13 +- tests/integration/test_run_agent.py | 8 +- tests/integration/test_run_apex.py | 67 +++++ 28 files changed, 1538 insertions(+), 49 deletions(-) create mode 100644 configs/pong_no_frameskip_v4/apex_dqn.py create mode 100644 rl_algorithms/common/abstract/architecture.py create mode 100644 rl_algorithms/common/abstract/distributed_logger.py create mode 100644 rl_algorithms/common/abstract/worker.py create mode 100644 rl_algorithms/common/distributed/__init__.py create mode 100644 rl_algorithms/common/distributed/apex.py create mode 100644 rl_algorithms/common/distributed/buffer.py create mode 100644 rl_algorithms/common/distributed/learner.py create mode 100644 rl_algorithms/common/distributed/worker.py create mode 100644 rl_algorithms/dqn/logger.py create mode 100644 rl_algorithms/dqn/worker.py create mode 100644 tests/integration/test_run_apex.py diff --git a/LICENSE.md b/LICENSE.md index d343fe5f..7694fad1 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,5 +1,4 @@ -# Our repository -MIT License +The MIT License (MIT) Copyright (c) 2019 Medipixel @@ -20,16 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -# Mujoco models -This work is derived from [MuJuCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license: -``` -This file is part of MuJoCo. -Copyright 2009-2015 Roboti LLC. -Mujoco :: Advanced physics simulation engine -Source : www.roboti.us -Version : 1.31 -Released : 23Apr16 -Author :: Vikash Kumar -Contacts : kumar@roboti.us -``` diff --git a/README.md b/README.md index cf5081cd..f87bc36b 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@

-[![CircleCI](https://circleci.com/gh/circleci/circleci-docs.svg?style=shield)](https://circleci.com/gh/medipixel) [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/medipixel/rl_algorithms.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/medipixel/rl_algorithms/context:python) -[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) +[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-) @@ -63,8 +63,8 @@ This project follows the [all-contributors](https://github.com/all-contributors/ 7. [Rainbow DQN](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn) 8. [Rainbow IQN (without DuelingNet)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/dqn) - DuelingNet [degrades performance](https://github.com/medipixel/rl_algorithms/pull/137) 9. Rainbow IQN (with [ResNet](https://github.com/medipixel/rl_algorithms/blob/master/rl_algorithms/common/networks/backbones/resnet.py)) -10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent/dqn_agent.py) - +10. [Recurrent Replay DQN (R2D1)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/recurrent) +11. [Distributed Pioritized Experience Replay (Ape-X)](https://github.com/medipixel/rl_algorithms/tree/master/rl_algorithms/common/distributed) ## Performance @@ -205,6 +205,16 @@ python -h - Start rendering after the number of episodes. - `--load-from ` - Load the saved models and optimizers at the beginning. + +#### Arguments for distributed training in run-files +- `--max-episode-steps ` + - Set maximum update step for learner as a stopping criterion for training loop. If the number is less than or equal to 0, it uses the default maximum step number of the environment. +- `--off-worker-render` + - Turn off rendering of individual workers. +- `--off-logger-render` + - Turn off rendering of logger tests. +- `--worker-verbose` + - Turn on printing episode run info for individual workers #### Show feature map with Grad-CAM @@ -252,3 +262,4 @@ This won't be frequently updated. 17. [Ramprasaath R. Selvaraju et al., "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." arXiv preprint arXiv:1610.02391, 2016.](https://arxiv.org/pdf/1610.02391.pdf) 18. [Kaiming He et al., "Deep Residual Learning for Image Recognition." arXiv preprint arXiv:1512.03385, 2015.](https://arxiv.org/pdf/1512.03385) 19. [Steven Kapturowski et al., "Recurrent Experience Replay in Distributed Reinforcement Learning." in International Conference on Learning Representations https://openreview.net/forum?id=r1lyTjAqYX, 2019.](https://openreview.net/forum?id=r1lyTjAqYX) +20. [Horgan et al., "Distributed Prioritized Experience Replay." in International Conference on Learning Representations, 2018](https://arxiv.org/pdf/1803.00933.pdf) \ No newline at end of file diff --git a/configs/pong_no_frameskip_v4/apex_dqn.py b/configs/pong_no_frameskip_v4/apex_dqn.py new file mode 100644 index 00000000..d04d6d2c --- /dev/null +++ b/configs/pong_no_frameskip_v4/apex_dqn.py @@ -0,0 +1,77 @@ +"""Config for ApeX-DQN on Pong-No_FrameSkip-v4. 
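+The agent dict below wires together the Ape-X components added in this patch:
+hyper_params carries the distributed settings (num_workers, local_buffer_max_size,
+worker_update_interval, logger_interval) on top of the usual DQN/PER values,
+learner_cfg / worker_cfg / logger_cfg select the component classes from the registry,
+and comm_cfg lists the ZeroMQ ports used for inter-process communication.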
+ +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +from rl_algorithms.common.helper_functions import identity + +agent = dict( + type="ApeX", + hyper_params=dict( + gamma=0.99, + tau=5e-3, + buffer_size=int(2.5e5), # openai baselines: int(1e4) + batch_size=512, # openai baselines: 32 + update_starts_from=int(1e5), # openai baselines: int(1e4) + multiple_update=1, # multiple learning updates + train_freq=1, # in openai baselines, train_freq = 4 + gradient_clip=10.0, # dueling: 10.0 + n_step=5, + w_n_step=1.0, + w_q_reg=0.0, + per_alpha=0.6, # openai baselines: 0.6 + per_beta=0.4, + per_eps=1e-6, + loss_type=dict(type="DQNLoss"), + # Epsilon Greedy + max_epsilon=1.0, + min_epsilon=0.1, # openai baselines: 0.01 + epsilon_decay=1e-6, # openai baselines: 1e-7 / 1e-1 + # grad_cam + grad_cam_layer_list=[ + "backbone.cnn.cnn_0.cnn", + "backbone.cnn.cnn_1.cnn", + "backbone.cnn.cnn_2.cnn", + ], + num_workers=4, + local_buffer_max_size=1000, + worker_update_interval=50, + logger_interval=2000, + ), + learner_cfg=dict( + type="DQNLearner", + device="cuda", + backbone=dict( + type="CNN", + configs=dict( + input_sizes=[4, 32, 64], + output_sizes=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + paddings=[1, 0, 0], + ), + ), + head=dict( + type="DuelingMLP", + configs=dict( + use_noisy_net=False, hidden_sizes=[512], output_activation=identity + ), + ), + optim_cfg=dict( + lr_dqn=0.0003, # dueling: 6.25e-5, openai baselines: 1e-4 + weight_decay=0.0, # this makes saturation in cnn weights + adam_eps=1e-8, # rainbow: 1.5e-4, openai baselines: 1e-8 + ), + ), + worker_cfg=dict(type="DQNWorker", device="cpu",), + logger_cfg=dict(type="DQNLogger",), + comm_cfg=dict( + learner_buffer_port=6554, + learner_worker_port=6555, + worker_buffer_port=6556, + learner_logger_port=6557, + send_batch_port=6558, + priorities_port=6559, + ), +) diff --git a/requirements.txt b/requirements.txt index 58d1b5c1..8626fd81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,17 @@ cloudpickle opencv-python wandb addict - # mujoco +# for distributed learning +ray +ray[debug] +pyzmq +pyarrow + +# for log +matplotlib +plotly + setuptools wheel diff --git a/rl_algorithms/__init__.py b/rl_algorithms/__init__.py index c85fff7f..6281e51b 100644 --- a/rl_algorithms/__init__.py +++ b/rl_algorithms/__init__.py @@ -5,12 +5,15 @@ from .bc.her import LunarLanderContinuousHER, ReacherHER from .bc.sac_agent import BCSACAgent from .bc.sac_learner import BCSACLearner +from .common.distributed.apex import ApeX from .common.networks.backbones import CNN, ResNet from .ddpg.agent import DDPGAgent from .ddpg.learner import DDPGLearner from .dqn.agent import DQNAgent from .dqn.learner import DQNLearner +from .dqn.logger import DQNLogger from .dqn.losses import C51Loss, DQNLoss, IQNLoss +from .dqn.worker import DQNWorker from .fd.ddpg_agent import DDPGfDAgent from .fd.ddpg_learner import DDPGfDLearner from .fd.dqn_agent import DQfDAgent @@ -65,4 +68,7 @@ "R2D1IQNLoss", "R2D1C51Loss", "R2D1DQNLoss", + "ApeX", + "DQNWorker", + "DQNLogger", ] diff --git a/rl_algorithms/common/abstract/architecture.py b/rl_algorithms/common/abstract/architecture.py new file mode 100644 index 00000000..0a70f7f9 --- /dev/null +++ b/rl_algorithms/common/abstract/architecture.py @@ -0,0 +1,23 @@ +"""Abstract class for distributed architectures. 
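+Concrete architectures (e.g. ApeX) create their component processes in _spawn()
+and drive them through train() and test().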
+ +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +from abc import ABC, abstractmethod + + +class Architecture(ABC): + """Abstract class for distributed architectures""" + + @abstractmethod + def _spawn(self): + pass + + @abstractmethod + def train(self): + pass + + @abstractmethod + def test(self): + pass diff --git a/rl_algorithms/common/abstract/distributed_logger.py b/rl_algorithms/common/abstract/distributed_logger.py new file mode 100644 index 00000000..c5de04c4 --- /dev/null +++ b/rl_algorithms/common/abstract/distributed_logger.py @@ -0,0 +1,239 @@ +"""Base class for loggers use in distributed training. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +from abc import ABC, abstractmethod +import argparse +from collections import deque +import os +import shutil +from typing import List + +import gym +import numpy as np +import plotly.graph_objects as go +import pyarrow as pa +import torch +import wandb +import zmq + +from rl_algorithms.common.env.atari_wrappers import atari_env_generator +import rl_algorithms.common.env.utils as env_utils +from rl_algorithms.common.networks.brain import Brain +from rl_algorithms.utils.config import ConfigDict + + +class DistributedLogger(ABC): + """Base class for loggers use in distributed training. + + Attributes: + args (argparse.Namespace): arguments including hyperparameters and training settings + env_info (ConfigDict): information about environment + log_cfg (ConfigDict): configuration for saving log and checkpoint + comm_config (ConfigDict): configs for communication + backbone (ConfigDict): backbone configs for building network + head (ConfigDict): head configs for building network + brain (Brain): logger brain for evaluation + update_step (int): tracker for learner update step + device (torch.device): device, cpu by default + log_info_queue (deque): queue for storing log info received from learner + env (gym.Env): gym environment for running test + + """ + + def __init__( + self, + args: argparse.Namespace, + env_info: ConfigDict, + log_cfg: ConfigDict, + comm_cfg: ConfigDict, + backbone: ConfigDict, + head: ConfigDict, + ): + self.args = args + self.env_info = env_info + self.log_cfg = log_cfg + self.comm_cfg = comm_cfg + self.device = torch.device("cpu") # Logger only runs on cpu + self.brain = Brain(backbone, head).to(self.device) + + self.update_step = 0 + self.log_info_queue = deque(maxlen=100) + + self._init_env() + + # pylint: disable=attribute-defined-outside-init + def _init_env(self): + """Initialize gym environment.""" + if self.env_info.is_atari: + self.env = atari_env_generator( + self.env_info.name, self.args.max_episode_steps + ) + else: + self.env = gym.make(self.env_info.name) + env_utils.set_env(self.env, self.args) + + @abstractmethod + def load_params(self, path: str): + if not os.path.exists(path): + raise Exception( + f"[ERROR] the input path does not exist. 
Wrong path: {path}" + ) + + # pylint: disable=attribute-defined-outside-init + def init_communication(self): + """Initialize inter-process communication sockets.""" + ctx = zmq.Context() + self.pull_socket = ctx.socket(zmq.PULL) + self.pull_socket.bind(f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}") + + @abstractmethod + def select_action(self, state: np.ndarray): + pass + + @abstractmethod + def write_log(self, log_value: dict): + pass + + # pylint: disable=no-self-use + @staticmethod + def _preprocess_state(state: np.ndarray, device: torch.device) -> torch.Tensor: + state = torch.FloatTensor(state).to(device) + return state + + def set_wandb(self): + """Set configuration for wandb logging.""" + wandb.init( + project=self.env_info.name, + name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}", + ) + wandb.config.update(vars(self.args)) + shutil.copy(self.args.cfg_path, os.path.join(wandb.run.dir, "config.py")) + + def recv_log_info(self): + """Receive info from learner.""" + received = False + try: + log_info_id = self.pull_socket.recv(zmq.DONTWAIT) + received = True + except zmq.Again: + pass + + if received: + self.log_info_queue.append(log_info_id) + + def run(self): + """Run main logging loop; continuously receive data and log.""" + if self.args.log: + self.set_wandb() + + while self.update_step < self.args.max_update_step: + self.recv_log_info() + if self.log_info_queue: # if non-empty + log_info_id = self.log_info_queue.pop() + log_info = pa.deserialize(log_info_id) + state_dict = log_info["state_dict"] + log_value = log_info["log_value"] + self.update_step = log_value["update_step"] + + self.synchronize(state_dict) + avg_score = self.test(self.update_step) + log_value["avg_score"] = avg_score + self.write_log(log_value) + + def write_worker_log(self, worker_logs: List[dict]): + """Log the mean scores of each episode per update step to wandb.""" + # NOTE: Worker plots are passed onto wandb.log as matplotlib.pyplot + # since wandb doesn't support logging multiple lines to single plot + if self.args.log: + self.set_wandb() + # Plot individual workers + fig = go.Figure() + worker_id = 0 + for worker_log in worker_logs: + fig.add_trace( + go.Scatter( + x=list(worker_log.keys()), + y=list(worker_log.values()), + mode="lines", + name=f"Worker {worker_id}", + line=dict(width=2), + ) + ) + worker_id = worker_id + 1 + + # Plot mean scores + steps = worker_logs[0].keys() + mean_scores = [] + for step in steps: + each_scores = [worker_log[step] for worker_log in worker_logs] + mean_scores.append(np.mean(each_scores)) + + fig.add_trace( + go.Scatter( + x=list(worker_logs[0].keys()), + y=mean_scores, + mode="lines+markers", + name="Mean scores", + line=dict(width=5), + ) + ) + + # Write to wandb + wandb.log({"Worker scores": fig}) + + def test(self, update_step: int, interim_test: bool = True): + """Test the agent.""" + avg_score = self._test(update_step, interim_test) + + # termination + self.env.close() + return avg_score + + def _test(self, update_step: int, interim_test: bool) -> float: + """Common test routine.""" + if interim_test: + test_num = self.args.interim_test_num + else: + test_num = self.args.episode_num + + scores = [] + for i_episode in range(test_num): + state = self.env.reset() + done = False + score = 0 + step = 0 + + while not done: + if self.args.logger_render: + self.env.render() + + action = self.select_action(state) + next_state, reward, done, _ = self.env.step(action) + + state = next_state + score += reward + step += 1 + + scores.append(score) + + if 
interim_test: + print( + "[INFO] update step: %d\ttest %d\tstep: %d\ttotal score: %d" + % (update_step, i_episode, step, score) + ) + else: + print( + "[INFO] test %d\tstep: %d\ttotal score: %d" + % (i_episode, step, score) + ) + + return np.mean(scores) + + def synchronize(self, new_params: List[np.ndarray]): + """Copy parameters from numpy arrays""" + for param, new_param in zip(self.brain.parameters(), new_params): + new_param = torch.FloatTensor(new_param).to(self.device) + param.data.copy_(new_param) diff --git a/rl_algorithms/common/abstract/learner.py b/rl_algorithms/common/abstract/learner.py index 383a5ee3..fe3aebdc 100644 --- a/rl_algorithms/common/abstract/learner.py +++ b/rl_algorithms/common/abstract/learner.py @@ -1,3 +1,9 @@ +"""Base Learner & LearnerWrapper class. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + from abc import ABC, abstractmethod import argparse from collections import OrderedDict @@ -54,18 +60,17 @@ def __init__( env_info: ConfigDict, hyper_params: ConfigDict, log_cfg: ConfigDict, - device: torch.device, + device: str, ): """Initialize.""" self.args = args self.env_info = env_info self.hyper_params = hyper_params - self.device = device + self.device = torch.device(device) if not self.args.test: self.ckpt_path = ( - "./checkpoint/" - f"{env_info.env_name}/{log_cfg.agent}/{log_cfg.curr_time}/" + f"./checkpoint/{env_info.name}/{log_cfg.agent}/{log_cfg.curr_time}/" ) os.makedirs(self.ckpt_path, exist_ok=True) @@ -117,7 +122,7 @@ def get_policy(self) -> nn.Module: class LearnerWrapper(BaseLearner): - """Base class for all learner wrappers""" + """Base class for all learner wrappers.""" def __init__(self, learner: BaseLearner): """Initialize.""" @@ -134,3 +139,45 @@ def load_params(self, path: str): def get_state_dict(self) -> Union[OrderedDict, Tuple[OrderedDict]]: return self.learner.get_state_dict() + + +class DistributedLearnerWrapper(LearnerWrapper): + """Base wrapper class for distributed learners. + + Attributes: + learner (Learner): learner + comm_config (ConfigDict): configs for communication + + """ + + def __init__(self, learner: Learner, comm_cfg: ConfigDict): + LearnerWrapper.__init__(self, learner) + self.comm_cfg = comm_cfg + + @abstractmethod + def init_communication(self): + pass + + def update_model(self, experience: Union[TensorTuple, Tuple[TensorTuple]]) -> tuple: + """Run one step of learner model update.""" + return self.learner.update_model(experience) + + def save_params(self, n_update_step: int): + """Save learner params at defined directory.""" + self.learner.save_params(n_update_step) + + def load_params(self, path: str): + """Load params at start.""" + self.learner.load_params(path) + + def get_policy(self): + """Return model (policy) used for action selection, used only in grad cam.""" + return self.learner.get_policy() + + def get_state_dict(self): + """Return state dicts.""" + return self.learner.get_state_dict() + + @abstractmethod + def run(self): + pass diff --git a/rl_algorithms/common/abstract/worker.py b/rl_algorithms/common/abstract/worker.py new file mode 100644 index 00000000..b9944cc0 --- /dev/null +++ b/rl_algorithms/common/abstract/worker.py @@ -0,0 +1,162 @@ +"""Worker classes for distributed training. 
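+BaseWorker defines the minimal interface (select_action, step, synchronize),
+Worker adds per-process environment creation and seeding, and
+DistributedWorkerWrapper wraps a Worker with communication and data collection.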
+ +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +from abc import ABC, abstractmethod +import argparse +import os +import random +from typing import Deque, Dict, List, Tuple + +import gym +import numpy as np +import torch + +from rl_algorithms.common.env.atari_wrappers import atari_env_generator +import rl_algorithms.common.env.utils as env_utils +from rl_algorithms.common.helper_functions import set_random_seed +from rl_algorithms.utils.config import ConfigDict + + +class BaseWorker(ABC): + """Base class for Worker classes.""" + + @abstractmethod + def select_action(self, state: np.ndarray) -> np.ndarray: + pass + + @abstractmethod + def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]: + pass + + @abstractmethod + def synchronize(self, new_params: list): + pass + + # pylint: disable=no-self-use + def _synchronize(self, network, new_params: List[np.ndarray]): + """Copy parameters from numpy arrays.""" + for param, new_param in zip(network.parameters(), new_params): + new_param = torch.FloatTensor(new_param).to(self.device) + param.data.copy_(new_param) + + +class Worker(BaseWorker): + """Base class for all functioning RL workers. + + Attributes: + rank (int): rank (ID) of worker + args (argparse.Namespace): args from run script + env_info (ConfigDict): information about environment + hyper_params (ConfigDict): algorithm hyperparameters + device (torch.Device): device on which worker process runs + env (gym.ENV): gym environment + """ + + def __init__( + self, + rank: int, + args: argparse.Namespace, + env_info: ConfigDict, + hyper_params: ConfigDict, + device: str, + ): + """Initialize.""" + self.rank = rank + self.args = args + self.env_info = env_info + self.hyper_params = hyper_params + self.device = torch.device(device) + + self._init_env() + + # pylint: disable=attribute-defined-outside-init, no-self-use + def _init_env(self): + """Intialize worker local environment.""" + if self.env_info.is_atari: + self.env = atari_env_generator( + self.env_info.name, self.args.max_episode_steps, frame_stack=True + ) + else: + self.env = gym.make(self.env_info.name) + env_utils.set_env(self.env, self.args) + + random.seed(self.rank) + env_seed = random.randint(0, 999) + set_random_seed(env_seed, self.env) + + @abstractmethod + def load_params(self, path: str): + if not os.path.exists(path): + raise Exception( + f"[ERROR] the input path does not exist. 
Wrong path: {path}" + ) + + @abstractmethod + def select_action(self, state: np.ndarray) -> np.ndarray: + pass + + @abstractmethod + def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]: + pass + + # NOTE: No need to explicitly implement for non-PER/non-Ape-X workers + @abstractmethod + def compute_priorities(self, experience: Dict[str, np.ndarray]): + pass + + @abstractmethod + def synchronize(self, new_params: list): + pass + + @staticmethod + def _preprocess_state(state: np.ndarray, device: torch.device) -> torch.Tensor: + """Preprocess state so that actor selects an action.""" + state = torch.FloatTensor(state).to(device) + return state + + +class DistributedWorkerWrapper(BaseWorker): + """Base wrapper class for distributed worker wrappers.""" + + def __init__(self, worker: Worker, args: argparse.Namespace, comm_cfg: ConfigDict): + self.worker = worker + self.args = args + self.comm_cfg = comm_cfg + + @abstractmethod + def init_communication(self): + pass + + def select_action(self, state: np.ndarray) -> np.ndarray: + """Select an action from the input space.""" + return self.worker.select_action(state) + + def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]: + """Take an action and return the response of the env.""" + return self.worker.step(action) + + def synchronize(self, new_params: list): + """Synchronize worker brain with learner brain.""" + self.worker.synchronize(new_params) + + @abstractmethod + def collect_data(self) -> Dict[str, np.ndarray]: + pass + + @abstractmethod + def run(self): + pass + + def preprocess_nstep(self, nstepqueue: Deque) -> Tuple[np.ndarray, ...]: + """Return n-step transition with discounted reward.""" + discounted_reward = 0 + _, _, _, last_state, done = nstepqueue[-1] + for transition in list(reversed(nstepqueue)): + state, action, reward, _, _ = transition + discounted_reward = reward + self.hyper_params.gamma * discounted_reward + nstep_data = (state, action, discounted_reward, last_state, done) + + return nstep_data diff --git a/rl_algorithms/common/distributed/__init__.py b/rl_algorithms/common/distributed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rl_algorithms/common/distributed/apex.py b/rl_algorithms/common/distributed/apex.py new file mode 100644 index 00000000..e7ed5543 --- /dev/null +++ b/rl_algorithms/common/distributed/apex.py @@ -0,0 +1,150 @@ +"""General Ape-X architecture for distributed training. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +- Paper: https://arxiv.org/pdf/1803.00933.pdf +- Reference: https://github.com/haje01/distper +""" + +import gym +import ray + +from rl_algorithms.common.abstract.architecture import Architecture +from rl_algorithms.common.buffer.replay_buffer import ReplayBuffer +from rl_algorithms.common.buffer.wrapper import PrioritizedBufferWrapper +from rl_algorithms.common.distributed.buffer import ApeXBufferWrapper +from rl_algorithms.common.distributed.learner import ApeXLearnerWrapper +from rl_algorithms.common.distributed.worker import ApeXWorkerWrapper +from rl_algorithms.registry import AGENTS, build_learner, build_logger, build_worker +from rl_algorithms.utils.config import ConfigDict + + +@AGENTS.register_module +class ApeX(Architecture): + """General Ape-X architecture for distributed training. 
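+    The architecture spawns one GPU learner, num_workers CPU workers, a
+    PER-wrapped global replay buffer, and a logger as ray actors, and connects
+    them with ZeroMQ: PUB/SUB to broadcast learner parameters to workers,
+    PUSH/PULL for worker experience into the buffer and for learner log info to
+    the logger, and REQ/REP between buffer and learner for batches and priorities.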
+ + Attributes: + rank (int): rank (ID) of worker + args (argparse.Namespace): args from run script + env_info (ConfigDict): information about environment + hyper_params (ConfigDict): algorithm hyperparameters + learner_cfg (ConfigDict): configs for learner class + worker_cfg (ConfigDict): configs for worker class + logger_cfg (ConfigDict): configs for logger class + comm_cfg (ConfigDict): configs for inter-process communication + log_cfg (ConfigDict): configs for logging, passed on to logger_cfg + learner (Learner): distributed learner class + workers (list): List of distributed worker class + global buffer (ReplayBuffer): centralized buffer wrapped with PER and ApeX + logger (DistributedLogger): logger class + processes (list): List of all processes + + """ + + def __init__( + self, + args: ConfigDict, + env: gym.Env, + env_info: ConfigDict, + hyper_params: ConfigDict, + learner_cfg: ConfigDict, + worker_cfg: ConfigDict, + logger_cfg: ConfigDict, + comm_cfg: ConfigDict, + log_cfg: ConfigDict, + ): + self.args = args + self.env = env + self.env_info = env_info + self.hyper_params = hyper_params + self.learner_cfg = learner_cfg + self.worker_cfg = worker_cfg + self.logger_cfg = logger_cfg + self.comm_cfg = comm_cfg + self.log_cfg = log_cfg + + self._organize_configs() + + # pylint: disable=attribute-defined-outside-init + def _organize_configs(self): + """Organize configs for initializing components from registry.""" + # organize learner configs + self.learner_cfg.args = self.args + self.learner_cfg.env_info = self.env_info + self.learner_cfg.hyper_params = self.hyper_params + self.learner_cfg.log_cfg = self.log_cfg + self.learner_cfg.head.configs.state_size = self.env_info.observation_space.shape + self.learner_cfg.head.configs.output_size = self.env_info.action_space.n + + # organize worker configs + self.worker_cfg.env_info = self.env_info + self.worker_cfg.hyper_params = self.hyper_params + self.worker_cfg.backbone = self.learner_cfg.backbone + self.worker_cfg.head = self.learner_cfg.head + + # organize logger configs + self.logger_cfg.args = self.args + self.logger_cfg.env_info = self.env_info + self.logger_cfg.log_cfg = self.log_cfg + self.logger_cfg.comm_cfg = self.comm_cfg + self.logger_cfg.backbone = self.learner_cfg.backbone + self.logger_cfg.head = self.learner_cfg.head + + def _spawn(self): + """Intialize distributed worker, learner and centralized replay buffer.""" + replay_buffer = ReplayBuffer( + self.hyper_params.buffer_size, self.hyper_params.batch_size, + ) + per_buffer = PrioritizedBufferWrapper( + replay_buffer, alpha=self.hyper_params.per_alpha + ) + self.global_buffer = ApeXBufferWrapper.remote( + per_buffer, self.args, self.hyper_params, self.comm_cfg + ) + + learner = build_learner(self.learner_cfg) + self.learner = ApeXLearnerWrapper.remote(learner, self.comm_cfg) + + state_dict = learner.get_state_dict() + worker_build_args = dict(args=self.args, state_dict=state_dict) + + self.workers = [] + self.num_workers = self.hyper_params.num_workers + for rank in range(self.num_workers): + worker_build_args["rank"] = rank + worker = build_worker(self.worker_cfg, build_args=worker_build_args) + apex_worker = ApeXWorkerWrapper.remote(worker, self.args, self.comm_cfg) + self.workers.append(apex_worker) + + self.logger = build_logger(self.logger_cfg) + + self.processes = self.workers + [self.learner, self.global_buffer, self.logger] + + def train(self): + """Spawn processes and run training loop.""" + print("Spawning and initializing communication...") + # Spawn processes: + 
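+        # _spawn() builds the ray actors (learner, workers, global buffer, logger)
+        # and gathers them in self.processes so they can be started together below.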
self._spawn() + # Initialize communication + for proc in self.processes: + proc.init_communication.remote() + + # Run main training loop + print("Running main training loop...") + run_procs = [proc.run.remote() for proc in self.processes] + futures = ray.get(run_procs) + + # Retreive workers' data and write to wandb + # NOTE: Logger logs the mean scores of each episode per update step + if self.args.log: + worker_logs = [f for f in futures if f is not None] + self.logger.write_worker_log.remote(worker_logs) + print("Exiting training...") + + def test(self): + """Load model from checkpoint and run logger for testing.""" + # NOTE: You could also load the Ape-X trained model on the single agent DQN + self.logger = build_logger(self.logger_cfg) + self.logger.load_params.remote(self.args.load_from) + ray.get([self.logger.test.remote(update_step=0, interim_test=False)]) + print("Exiting testing...") diff --git a/rl_algorithms/common/distributed/buffer.py b/rl_algorithms/common/distributed/buffer.py new file mode 100644 index 00000000..797fd279 --- /dev/null +++ b/rl_algorithms/common/distributed/buffer.py @@ -0,0 +1,101 @@ +"""Wrapper for Ape-X global buffer. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +import argparse + +import pyarrow as pa +import ray +import zmq + +from rl_algorithms.common.abstract.buffer import BufferWrapper +from rl_algorithms.common.buffer.wrapper import PrioritizedBufferWrapper +from rl_algorithms.utils.config import ConfigDict + + +@ray.remote +class ApeXBufferWrapper(BufferWrapper): + """Wrapper for Ape-X global buffer. + + Attributes: + per_buffer (ReplayBuffer): replay buffer wrappped in PER wrapper + args (arpgarse.Namespace): args from run script + hyper_params (ConfigDict): algorithm hyperparameters + comm_config (ConfigDict): configs for communication + + """ + + def __init__( + self, + per_buffer: PrioritizedBufferWrapper, + args: argparse.Namespace, + hyper_params: ConfigDict, + comm_cfg: ConfigDict, + ): + BufferWrapper.__init__(self, per_buffer) + self.args = args + self.hyper_params = hyper_params + self.comm_cfg = comm_cfg + self.per_beta = hyper_params.per_beta + self.num_sent = 0 + + # pylint: disable=attribute-defined-outside-init + def init_communication(self): + """Initialize sockets for communication.""" + ctx = zmq.Context() + self.req_socket = ctx.socket(zmq.REQ) + self.req_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.learner_buffer_port}") + + self.pull_socket = ctx.socket(zmq.PULL) + self.pull_socket.bind(f"tcp://127.0.0.1:{self.comm_cfg.worker_buffer_port}") + + def recv_worker_data(self): + """Receive replay data from worker and incorporate to buffer.""" + received = False + try: + new_replay_data_id = self.pull_socket.recv(zmq.DONTWAIT) + received = True + except zmq.Again: + pass + + if received: + new_replay_data = pa.deserialize(new_replay_data_id) + experience, priorities = new_replay_data + for idx in range(len(experience["states"])): + transition = ( + experience["states"][idx], + experience["actions"][idx], + experience["rewards"][idx], + experience["next_states"][idx], + experience["dones"][idx], + ) + self.buffer.add(transition) + self.buffer.update_priorities([len(self.buffer) - 1], priorities[idx]) + + def send_batch_to_learner(self): + """Send batch to learner and receive priorities.""" + # Send batch and request priorities (blocking recv) + batch = self.buffer.sample(self.per_beta) + batch_id = pa.serialize(batch).to_buffer() + self.req_socket.send(batch_id) + self.num_sent = self.num_sent + 1 + + # 
Receive priorities + new_priors_id = self.req_socket.recv() + idxes, new_priorities = pa.deserialize(new_priors_id) + self.buffer.update_priorities(idxes, new_priorities) + + def update_priority_beta(self): + """Update important sampling ratio for prioritized buffer.""" + fraction = min(float(self.num_sent) / self.args.max_update_step, 1.0) + self.per_beta = self.per_beta + fraction * (1.0 - self.per_beta) + + def run(self): + """Run main buffer loop to communicate data.""" + while self.num_sent < self.args.max_update_step: + self.recv_worker_data() + if len(self.buffer) >= self.hyper_params.update_starts_from: + self.send_batch_to_learner() + self.update_priority_beta() diff --git a/rl_algorithms/common/distributed/learner.py b/rl_algorithms/common/distributed/learner.py new file mode 100644 index 00000000..211b31eb --- /dev/null +++ b/rl_algorithms/common/distributed/learner.py @@ -0,0 +1,113 @@ +"""Learner Wrapper to enable Ape-X distributed training. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +from typing import List + +import numpy as np +import pyarrow as pa +import ray +import zmq + +from rl_algorithms.common.abstract.learner import DistributedLearnerWrapper, Learner +from rl_algorithms.common.helper_functions import numpy2floattensor, state_dict2numpy +from rl_algorithms.utils.config import ConfigDict + + +@ray.remote(num_gpus=1) +class ApeXLearnerWrapper(DistributedLearnerWrapper): + """Learner Wrapper to enable Ape-X distributed training. + + Attributes: + learner (Learner): learner + comm_config (ConfigDict): configs for communication + update_step (int): counts update steps + pub_socket (zmq.Socket): publisher socket for broadcasting params + rep_socket (zmq.Socket): reply socket for receiving replay data & sending new priorities + update_step (int): number of update steps + max_update_step (int): maximum update steps per run + worker_update_interval (int): num update steps between worker synchronization + logger_interval (int): num update steps between logging + + """ + + def __init__(self, learner: Learner, comm_cfg: ConfigDict): + """Initialize.""" + DistributedLearnerWrapper.__init__(self, learner, comm_cfg) + self.update_step = 0 + self.max_update_step = self.learner.args.max_update_step + self.worker_update_interval = self.learner.hyper_params.worker_update_interval + self.logger_interval = self.learner.hyper_params.logger_interval + + # NOTE: disable because learner uses preprocessed n_step experience + self.learner.use_n_step = False + + # pylint: disable=attribute-defined-outside-init + def init_communication(self): + """Initialize sockets for communication.""" + ctx = zmq.Context() + # Socket to send updated network parameters to worker + self.pub_socket = ctx.socket(zmq.PUB) + self.pub_socket.bind(f"tcp://127.0.0.1:{self.comm_cfg.learner_worker_port}") + + # Socket to receive replay data and send new priorities to buffer + self.rep_socket = ctx.socket(zmq.REP) + self.rep_socket.bind(f"tcp://127.0.0.1:{self.comm_cfg.learner_buffer_port}") + + # Socket to send logging data to logger + self.push_socket = ctx.socket(zmq.PUSH) + self.push_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}") + + def recv_replay_data(self): + """Receive replay data from gloal buffer.""" + replay_data_id = self.rep_socket.recv() + replay_data = pa.deserialize(replay_data_id) + return replay_data + + def send_new_priorities(self, indices: np.ndarray, priorities: np.ndarray): + """Send new priority values and corresponding indices to buffer.""" + 
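+        # The global buffer is blocking on its REQ socket after sending a batch;
+        # replying here with (indices, priorities) completes that REQ/REP round trip.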
new_priors = [indices, priorities] + new_priors_id = pa.serialize(new_priors).to_buffer() + self.rep_socket.send(new_priors_id) + + def publish_params(self, update_step: int, np_state_dict: List[np.ndarray]): + """Broadcast updated params to all workers.""" + param_info = [update_step, np_state_dict] + new_params_id = pa.serialize(param_info).to_buffer() + self.pub_socket.send(new_params_id) + + def send_info_to_logger( + self, np_state_dict: List[np.ndarray], step_info: list, + ): + """Send new params and log info to logger.""" + log_value = dict(update_step=self.update_step, step_info=step_info) + log_info = dict(log_value=log_value, state_dict=np_state_dict) + log_info_id = pa.serialize(log_info).to_buffer() + self.push_socket.send(log_info_id) + + def run(self): + """Run main training loop.""" + self.telapsed = 0 + while self.update_step < self.max_update_step: + replay_data = self.recv_replay_data() + if replay_data is not None: + replay_data = numpy2floattensor(replay_data[:6]) + replay_data[6:] + info = self.update_model(replay_data) + indices, new_priorities = info[-2:] + step_info = info[:-2] + self.update_step = self.update_step + 1 + + self.send_new_priorities(indices, new_priorities) + + if self.update_step % self.worker_update_interval == 0: + state_dict = self.get_state_dict() + np_state_dict = state_dict2numpy(state_dict) + self.publish_params(self.update_step, np_state_dict) + + if self.update_step % self.logger_interval == 0: + state_dict = self.get_state_dict() + np_state_dict = state_dict2numpy(state_dict) + self.send_info_to_logger(np_state_dict, step_info) + self.learner.save_params(self.update_step) diff --git a/rl_algorithms/common/distributed/worker.py b/rl_algorithms/common/distributed/worker.py new file mode 100644 index 00000000..e8a78236 --- /dev/null +++ b/rl_algorithms/common/distributed/worker.py @@ -0,0 +1,147 @@ +"""Wrapper class for ApeX based distributed workers. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +import argparse +from collections import deque +from typing import Dict + +import numpy as np +import pyarrow as pa +import ray +import zmq + +from rl_algorithms.common.abstract.worker import DistributedWorkerWrapper, Worker +from rl_algorithms.utils.config import ConfigDict + + +@ray.remote(num_cpus=1) +class ApeXWorkerWrapper(DistributedWorkerWrapper): + """Wrapper class for ApeX based distributed workers. 
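+    Each wrapped worker interacts with its own copy of the environment, fills a
+    local buffer of local_buffer_max_size transitions, computes initial priorities,
+    pushes the data to the global buffer, and periodically receives new network
+    parameters from the learner over the SUB socket.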
+ + Attributes: + hyper_params (ConfigDict): worker hyper_params + update_step (int): tracker for learner update step + use_n_step (int): indication for using n-step transitions + sub_socket (zmq.Context): subscriber socket for receiving params from learner + push_socket (zmq.Context): push socket for sending experience to global buffer + + """ + + def __init__(self, worker: Worker, args: argparse.Namespace, comm_cfg: ConfigDict): + DistributedWorkerWrapper.__init__(self, worker, args, comm_cfg) + self.update_step = 0 + self.hyper_params = self.worker.hyper_params + self.use_n_step = self.hyper_params.n_step > 1 + self.scores = dict() + + self.worker._init_env() + + # pylint: disable=attribute-defined-outside-init + def init_communication(self): + """Initialize sockets connecting worker-learner, worker-buffer.""" + # for receiving params from learner + ctx = zmq.Context() + self.sub_socket = ctx.socket(zmq.SUB) + self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "") + self.sub_socket.setsockopt(zmq.CONFLATE, 1) + self.sub_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.learner_worker_port}") + + # for sending replay data to buffer + self.push_socket = ctx.socket(zmq.PUSH) + self.push_socket.connect(f"tcp://127.0.0.1:{self.comm_cfg.worker_buffer_port}") + + def send_data_to_buffer(self, replay_data): + """Send replay data to global buffer.""" + replay_data_id = pa.serialize(replay_data).to_buffer() + self.push_socket.send(replay_data_id) + + def recv_params_from_learner(self): + """Get new params and sync. return True if success, False otherwise.""" + received = False + try: + new_params_id = self.sub_socket.recv(zmq.DONTWAIT) + received = True + except zmq.Again: + # Although learner doesn't send params, don't wait + pass + + if received: + new_param_info = pa.deserialize(new_params_id) + update_step, new_params = new_param_info + self.update_step = update_step + self.worker.synchronize(new_params) + + # Add new entry for scores dict + self.scores[self.update_step] = [] + + def compute_priorities(self, experience: Dict[str, np.ndarray]): + """Compute priority values (TD error) of collected experience.""" + return self.worker.compute_priorities(experience) + + def collect_data(self) -> dict: + """Fill and return local buffer.""" + local_memory = dict(states=[], actions=[], rewards=[], next_states=[], dones=[]) + local_memory_keys = local_memory.keys() + if self.use_n_step: + nstep_queue = deque(maxlen=self.hyper_params.n_step) + + while len(local_memory["states"]) < self.hyper_params.local_buffer_max_size: + state = self.worker.env.reset() + done = False + score = 0 + num_steps = 0 + while not done: + if self.args.worker_render: + self.worker.env.render() + num_steps += 1 + action = self.select_action(state) + next_state, reward, done, _ = self.step(action) + transition = (state, action, reward, next_state, int(done)) + if self.use_n_step: + nstep_queue.append(transition) + if self.hyper_params.n_step == len(nstep_queue): + nstep_exp = self.preprocess_nstep(nstep_queue) + for entry, keys in zip(nstep_exp, local_memory_keys): + local_memory[keys].append(entry) + else: + for entry, keys in zip(transition, local_memory_keys): + local_memory[keys].append(entry) + + state = next_state + score += reward + + self.recv_params_from_learner() + + self.scores[self.update_step].append(score) + + if self.args.worker_verbose: + print( + "[TRAIN] [Worker %d] Score: %d, Epsilon: %.5f " + % (self.worker.rank, score, self.worker.epsilon) + ) + + for key in local_memory_keys: + local_memory[key] = 
np.array(local_memory[key]) + + return local_memory + + def run(self) -> Dict[int, float]: + """Run main worker loop.""" + self.scores[0] = [] + while self.update_step < self.args.max_update_step: + experience = self.collect_data() + priority_values = self.compute_priorities(experience) + worker_data = [experience, priority_values] + self.send_data_to_buffer(worker_data) + + mean_scores_per_ep_step = self.compute_mean_scores(self.scores) + return mean_scores_per_ep_step + + @staticmethod + def compute_mean_scores(scores: Dict[int, list]): + for step in scores.keys(): + scores[step] = np.mean(scores[step]) + return scores diff --git a/rl_algorithms/common/helper_functions.py b/rl_algorithms/common/helper_functions.py index af5dfe04..bc2e0729 100644 --- a/rl_algorithms/common/helper_functions.py +++ b/rl_algorithms/common/helper_functions.py @@ -4,7 +4,6 @@ - Author: Curt Park - Contact: curt.park@medipixel.io """ - from collections import deque import random from typing import Deque, List, Tuple @@ -14,6 +13,8 @@ import torch import torch.nn as nn +from rl_algorithms.utils.config import ConfigDict + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -94,8 +95,36 @@ def numpy2floattensor(arrays: Tuple[np.ndarray]) -> Tuple[np.ndarray]: """Convert numpy arrays to torch float tensor.""" tensors = [] for array in arrays: - tensor = torch.FloatTensor(array).to(device) - if torch.cuda.is_available(): - tensor = tensor.cuda(non_blocking=True) + tensor = torch.FloatTensor(array).to(device, non_blocking=True) tensors.append(tensor) return tuple(tensors) + + +def state_dict2numpy(state_dict) -> List[np.ndarray]: + """Convert Pytorch state dict to list of numpy arrays.""" + params = [] + for param in list(state_dict): + params.append(state_dict[param].numpy()) + return params + + +def set_cfg_for_intergration_test(cfg: ConfigDict) -> ConfigDict: + """Set specific values in config for intergration test.""" + if "batch_size" in cfg.agent.hyper_params: + cfg.agent.hyper_params.batch_size = 10 + if "update_starts_from" in cfg.agent.hyper_params: + cfg.agent.hyper_params.update_starts_from = 50 + if "initial_random_action" in cfg.agent.hyper_params: + cfg.agent.hyper_params.initial_random_action = 10 + if cfg.agent.type == "ApeX": + cfg.agent.hyper_params.num_workers = 1 + cfg.agent.hyper_params.worker_update_interval = 1 + cfg.agent.hyper_params.logger_interval = 1 + cfg.agent.hyper_params.buffer_size = 50 + if cfg.agent.type == "PPOAgent": + cfg.agent.hyper_params.epoch = 1 + cfg.agent.hyper_params.n_workers = 1 + cfg.agent.hyper_params.rollout_len = 10 + if "fD" in cfg.agent.type: + cfg.agent.hyper_params.pretrain_step = 1 + return cfg diff --git a/rl_algorithms/dqn/agent.py b/rl_algorithms/dqn/agent.py index 06930336..7d4d1a4b 100644 --- a/rl_algorithms/dqn/agent.py +++ b/rl_algorithms/dqn/agent.py @@ -200,8 +200,10 @@ def pretrain(self): pass def sample_experience(self) -> Tuple[torch.Tensor, ...]: + """Sample experience from replay buffer.""" experiences_1 = self.memory.sample(self.per_beta) experiences_1 = numpy2floattensor(experiences_1[:6]) + experiences_1[6:] + if self.use_n_step: indices = experiences_1[-2] experiences_n = self.memory_n.sample(indices) @@ -231,7 +233,6 @@ def train(self): while not done: if self.args.render and self.i_episode >= self.args.render_after: self.env.render() - action = self.select_action(state) next_state, reward, done, _ = self.step(action) self.total_step += 1 diff --git a/rl_algorithms/dqn/learner.py b/rl_algorithms/dqn/learner.py index 
e2528ce7..388b9f3c 100644 --- a/rl_algorithms/dqn/learner.py +++ b/rl_algorithms/dqn/learner.py @@ -1,5 +1,11 @@ +"""Learner for DQN Agent. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" import argparse from collections import OrderedDict +from copy import deepcopy from typing import Tuple, Union import numpy as np @@ -38,10 +44,9 @@ def __init__( backbone: ConfigDict, head: ConfigDict, optim_cfg: ConfigDict, - device: torch.device, + device: str, ): Learner.__init__(self, args, env_info, hyper_params, log_cfg, device) - self.backbone_cfg = backbone self.head_cfg = head self.head_cfg.configs.state_size = self.env_info.observation_space.shape @@ -75,7 +80,7 @@ def _init_network(self): def update_model( self, experience: Union[TensorTuple, Tuple[TensorTuple]] ) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]: # type: ignore - """Update dqn and dqn target""" + """Update dqn and dqn target.""" if self.use_n_step: experience_1, experience_n = experience @@ -157,8 +162,9 @@ def load_params(self, path: str): def get_state_dict(self) -> OrderedDict: """Return state dicts, mainly for distributed worker.""" - return self.dqn.state_dict() + dqn = deepcopy(self.dqn) + return dqn.cpu().state_dict() def get_policy(self) -> nn.Module: - """Return model (policy) used for action selection.""" + """Return model (policy) used for action selection, used only in grad cam.""" return self.dqn diff --git a/rl_algorithms/dqn/logger.py b/rl_algorithms/dqn/logger.py new file mode 100644 index 00000000..07d84713 --- /dev/null +++ b/rl_algorithms/dqn/logger.py @@ -0,0 +1,74 @@ +"""DQN Logger for distributed training. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +import argparse + +import numpy as np +import torch +import wandb + +from rl_algorithms.common.abstract.distributed_logger import DistributedLogger +from rl_algorithms.registry import LOGGERS +from rl_algorithms.utils.config import ConfigDict + + +@LOGGERS.register_module +class DQNLogger(DistributedLogger): + """DQN Logger for distributed training.""" + + def __init__( + self, + args: argparse.Namespace, + env_info: ConfigDict, + log_cfg: ConfigDict, + comm_cfg: ConfigDict, + backbone: ConfigDict, + head: ConfigDict, + ): + DistributedLogger.__init__( + self, args, env_info, log_cfg, comm_cfg, backbone, head + ) + + def load_params(self, path: str): + """Load model and optimizer parameters.""" + # Logger only runs on cpu + DistributedLogger.load_params(self, path) + + params = torch.load(path, map_location="cpu") + self.brain.load_state_dict(params["dqn_state_dict"]) + print("[INFO] loaded the model and optimizer from", path) + + def select_action(self, state: np.ndarray): + """Select action to be executed at given state.""" + with torch.no_grad(): + state = self._preprocess_state(state, self.device) + selected_action = self.brain(state).argmax() + selected_action = selected_action.cpu().numpy() + + return selected_action + + def write_log(self, log_value: dict): + """Write log about loss and score.""" + print( + "[INFO] update_step %d, average score: %f, " + "loss: %f, avg q-value: %f" + % ( + log_value["update_step"], + log_value["avg_score"], + log_value["step_info"][0], + log_value["step_info"][1], + ) + ) + + if self.args.log: + wandb.log( + { + "avg_test_score": log_value["avg_score"], + "loss": log_value["step_info"][0], + "avg_q_value": log_value["step_info"][1], + }, + step=log_value["update_step"], + ) diff --git a/rl_algorithms/dqn/worker.py b/rl_algorithms/dqn/worker.py new file mode 100644 index 
00000000..d9d1664f --- /dev/null +++ b/rl_algorithms/dqn/worker.py @@ -0,0 +1,117 @@ +"""DQN worker for distributed training. + +- Author: Chris Yoon +- Contact: chris.yoon@medipixel.io +""" + +import argparse +from collections import OrderedDict +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from rl_algorithms.common.abstract.worker import Worker +from rl_algorithms.common.networks.brain import Brain +from rl_algorithms.registry import WORKERS, build_loss +from rl_algorithms.utils.config import ConfigDict + + +@WORKERS.register_module +class DQNWorker(Worker): + """DQN worker for distributed training. + + Attributes: + backbone (ConfigDict): backbone configs for building network + head (ConfigDict): head configs for building network + state_dict (ConfigDict): initial network state dict received form learner + device (str): literal to indicate cpu/cuda use + + """ + + def __init__( + self, + rank: int, + args: argparse.Namespace, + env_info: ConfigDict, + hyper_params: ConfigDict, + backbone: ConfigDict, + head: ConfigDict, + state_dict: OrderedDict, + device: str, + ): + Worker.__init__(self, rank, args, env_info, hyper_params, device) + self.loss_fn = build_loss(self.hyper_params.loss_type) + self.backbone_cfg = backbone + self.head_cfg = head + self.head_cfg.configs.state_size = self.env_info.observation_space.shape + self.head_cfg.configs.output_size = self.env_info.action_space.n + + self.use_n_step = self.hyper_params.n_step > 1 + + self.max_epsilon = self.hyper_params.max_epsilon + self.min_epsilon = self.hyper_params.min_epsilon + self.epsilon = self.hyper_params.max_epsilon + + self._init_networks(state_dict) + + # pylint: disable=attribute-defined-outside-init + def _init_networks(self, state_dict: OrderedDict): + """Initialize DQN policy with learner state dict.""" + self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device) + self.dqn.load_state_dict(state_dict) + + def load_params(self, path: str): + """Load model and optimizer parameters.""" + Worker.load_params(self, path) + + params = torch.load(path) + self.dqn.load_state_dict(params["dqn_state_dict"]) + print("[INFO] loaded the model and optimizer from", path) + + def select_action(self, state: np.ndarray) -> np.ndarray: + """Select an action from the input space.""" + # epsilon greedy policy + # pylint: disable=comparison-with-callable + if self.epsilon > np.random.random(): + selected_action = np.array(self.env.action_space.sample()) + else: + with torch.no_grad(): + state = self._preprocess_state(state, self.device) + selected_action = self.dqn(state).argmax() + selected_action = selected_action.cpu().numpy() + + # Decay epsilon + self.epsilon = max( + self.epsilon + - (self.max_epsilon - self.min_epsilon) * self.hyper_params.epsilon_decay, + self.min_epsilon, + ) + + return selected_action + + def step(self, action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]: + """Take an action and return the response of the env.""" + next_state, reward, done, info = self.env.step(action) + return next_state, reward, done, info + + def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray: + """Compute initial priority values of experiences in local memory.""" + states = torch.FloatTensor(memory["states"]).to(self.device) + actions = torch.LongTensor(memory["actions"]).to(self.device) + rewards = torch.FloatTensor(memory["rewards"].reshape(-1, 1)).to(self.device) + next_states = torch.FloatTensor(memory["next_states"]).to(self.device) + dones = 
torch.FloatTensor(memory["dones"].reshape(-1, 1)).to(self.device) + memory_tensors = (states, actions, rewards, next_states, dones) + + dq_loss_element_wise, _ = self.loss_fn( + self.dqn, self.dqn, memory_tensors, self.hyper_params.gamma, self.head_cfg + ) + loss_for_prior = dq_loss_element_wise.detach().cpu().numpy() + new_priorities = loss_for_prior + self.hyper_params.per_eps + + return new_priorities + + def synchronize(self, new_params: List[np.ndarray]): + """Synchronize worker dqn with learner dqn.""" + self._synchronize(self.dqn, new_params) diff --git a/rl_algorithms/registry.py b/rl_algorithms/registry.py index 752daddb..cadc1bab 100644 --- a/rl_algorithms/registry.py +++ b/rl_algorithms/registry.py @@ -1,4 +1,4 @@ -from rl_algorithms.utils import Registry, build_from_cfg +from rl_algorithms.utils import Registry, build_from_cfg, build_ray_obj_from_cfg from rl_algorithms.utils.config import ConfigDict AGENTS = Registry("agents") @@ -7,6 +7,8 @@ HEADS = Registry("heads") LOSSES = Registry("losses") HERS = Registry("hers") +WORKERS = Registry("workers") +LOGGERS = Registry("loggers") def build_agent(cfg: ConfigDict, build_args: dict = None): @@ -37,3 +39,14 @@ def build_loss(cfg: ConfigDict, build_args: dict = None): def build_her(cfg: ConfigDict, build_args: dict = None): """Build her using config and additional arguments.""" return build_from_cfg(cfg, HERS, build_args) + + +def build_worker(cfg: ConfigDict, build_args: dict = None): + """Build ray worker using config and additional arguments.""" + # return build_ray_obj_from_cfg(cfg, WORKERS, build_args) + return build_from_cfg(cfg, WORKERS, build_args) + + +def build_logger(cfg: ConfigDict, build_args: dict = None): + """Build ray worker using config and additional arguments.""" + return build_ray_obj_from_cfg(cfg, LOGGERS, build_args) diff --git a/rl_algorithms/utils/__init__.py b/rl_algorithms/utils/__init__.py index 2d84667e..70a53b80 100644 --- a/rl_algorithms/utils/__init__.py +++ b/rl_algorithms/utils/__init__.py @@ -1,4 +1,4 @@ from .config import Config -from .registry import Registry, build_from_cfg +from .registry import Registry, build_from_cfg, build_ray_obj_from_cfg -__all__ = ["Registry", "build_from_cfg", "Config"] +__all__ = ["Registry", "build_from_cfg", "build_ray_obj_from_cfg", "Config"] diff --git a/rl_algorithms/utils/registry.py b/rl_algorithms/utils/registry.py index 61600017..036bd8a9 100644 --- a/rl_algorithms/utils/registry.py +++ b/rl_algorithms/utils/registry.py @@ -1,5 +1,7 @@ import inspect +import ray + from rl_algorithms.utils.config import ConfigDict @@ -76,3 +78,37 @@ def build_from_cfg(cfg: ConfigDict, registry: Registry, default_args: dict = Non for name, value in default_args.items(): args.setdefault(name, value) return obj_cls(**args) + + +def build_ray_obj_from_cfg( + cfg: ConfigDict, registry: Registry, default_args: dict = None +): + """Build a module from config dict. + Args: + cfg (:obj: `ConfigDict`): Config dict. It should at least contain the key "type". + registry (:obj:`Registry`): The registry to search the type from. + default_args (dict, optional): Default initialization arguments. + Returns: + obj: The constructed object. 
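+    Note:
+        Unlike build_from_cfg, the resolved class is wrapped with
+        ray.remote(num_cpus=1) and instantiated as a ray actor, so methods on
+        the returned handle must be invoked with .remote().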
+ """ + assert isinstance(cfg, dict) and "type" in cfg + assert isinstance(default_args, dict) or default_args is None + args = cfg.copy() + obj_type = args.pop("type") + if isinstance(obj_type, str): + obj_cls = registry.get(obj_type) + if obj_cls is None: + raise KeyError( + "{} is not in the {} registry".format(obj_type, registry.name) + ) + elif inspect.isclass(obj_type): + obj_cls = obj_type + else: + raise TypeError( + "type must be a str or valid type, but got {}".format(type(obj_type)) + ) + + if default_args is not None: + for name, value in default_args.items(): + args.setdefault(name, value) + return ray.remote(num_cpus=1)(obj_cls).remote(**args) diff --git a/run_lunarlander_continuous_v2.py b/run_lunarlander_continuous_v2.py index a3b0c5c4..364eb075 100644 --- a/run_lunarlander_continuous_v2.py +++ b/run_lunarlander_continuous_v2.py @@ -25,7 +25,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--cfg-path", type=str, - default="./configs/lunarlander_continuous_v2/per_ddpg.py", + default="./configs/lunarlander_continuous_v2/ddpg.py", help="config path", ) parser.add_argument( @@ -70,6 +70,12 @@ def parse_args() -> argparse.Namespace: default="data/lunarlander_continuous_demo.pkl", help="demonstration path for learning from demo", ) + parser.add_argument( + "--integration-test", + dest="integration_test", + action="store_true", + help="indicate integration test", + ) return parser.parse_args() @@ -91,8 +97,13 @@ def main(): curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S") cfg = Config.fromfile(args.cfg_path) + + # If running integration test, simplify experiment + if args.integration_test: + cfg = common_utils.set_cfg_for_intergration_test(cfg) + cfg.agent.env_info = dict( - env_name=env_name, + name=env_name, observation_space=env.observation_space, action_space=env.action_space, is_discrete=False, diff --git a/run_lunarlander_v2.py b/run_lunarlander_v2.py index 4e65b7c3..41930569 100644 --- a/run_lunarlander_v2.py +++ b/run_lunarlander_v2.py @@ -70,6 +70,12 @@ def parse_args() -> argparse.Namespace: default="data/lunarlander_discrete_demo.pkl", help="demonstration path for learning from demo", ) + parser.add_argument( + "--integration-test", + dest="integration_test", + action="store_true", + help="indicate integration test", + ) return parser.parse_args() @@ -91,8 +97,13 @@ def main(): curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S") cfg = Config.fromfile(args.cfg_path) + + # If running integration test, simplify experiment + if args.integration_test: + cfg = common_utils.set_cfg_for_intergration_test(cfg) + cfg.agent.env_info = dict( - env_name=env_name, + name=env_name, observation_space=env.observation_space, action_space=env.action_space, is_discrete=True, diff --git a/run_pong_no_frameskip_v4.py b/run_pong_no_frameskip_v4.py index 9e12df9c..19153f4a 100644 --- a/run_pong_no_frameskip_v4.py +++ b/run_pong_no_frameskip_v4.py @@ -8,6 +8,8 @@ import argparse import datetime +import ray + from rl_algorithms import build_agent from rl_algorithms.common.env.atari_wrappers import atari_env_generator import rl_algorithms.common.helper_functions as common_utils @@ -18,7 +20,7 @@ def parse_args() -> argparse.Namespace: # configurations parser = argparse.ArgumentParser(description="Pytorch RL algorithms") parser.add_argument( - "--seed", type=int, default=777, help="random seed for reproducibility" + "--seed", type=int, default=161, help="random seed for reproducibility" ) parser.add_argument( "--cfg-path", @@ -44,12 +46,30 @@ def parse_args() -> argparse.Namespace: 
diff --git a/run_pong_no_frameskip_v4.py b/run_pong_no_frameskip_v4.py
index 9e12df9c..19153f4a 100644
--- a/run_pong_no_frameskip_v4.py
+++ b/run_pong_no_frameskip_v4.py
@@ -8,6 +8,8 @@
 import argparse
 import datetime
 
+import ray
+
 from rl_algorithms import build_agent
 from rl_algorithms.common.env.atari_wrappers import atari_env_generator
 import rl_algorithms.common.helper_functions as common_utils
@@ -18,7 +20,7 @@ def parse_args() -> argparse.Namespace:
     # configurations
     parser = argparse.ArgumentParser(description="Pytorch RL algorithms")
     parser.add_argument(
-        "--seed", type=int, default=777, help="random seed for reproducibility"
+        "--seed", type=int, default=161, help="random seed for reproducibility"
     )
     parser.add_argument(
         "--cfg-path",
@@ -44,12 +46,30 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--off-render", dest="render", action="store_false", help="turn off rendering"
     )
+    parser.add_argument(
+        "--off-worker-render",
+        dest="worker_render",
+        action="store_false",
+        help="turn off worker rendering",
+    )
+    parser.add_argument(
+        "--off-logger-render",
+        dest="logger_render",
+        action="store_false",
+        help="turn off logger rendering",
+    )
     parser.add_argument(
         "--render-after",
         type=int,
         default=0,
         help="start rendering after the input number of episode",
     )
+    parser.add_argument(
+        "--worker-verbose",
+        dest="worker_verbose",
+        action="store_true",
+        help="turn on worker print statements",
+    )
     parser.add_argument(
         "--log", dest="log", action="store_true", help="turn on logging"
     )
@@ -57,11 +77,20 @@
     parser.add_argument(
         "--episode-num", type=int, default=500, help="total episode num"
     )
+    parser.add_argument(
+        "--max-update-step", type=int, default=100000, help="max update step"
+    )
     parser.add_argument(
         "--max-episode-steps", type=int, default=None, help="max episode step"
     )
     parser.add_argument(
-        "--interim-test-num", type=int, default=10, help="interim test number"
+        "--interim-test-num", type=int, default=5, help="interim test number"
+    )
+    parser.add_argument(
+        "--integration-test",
+        dest="integration_test",
+        action="store_true",
+        help="indicate integration test",
     )
     parser.add_argument(
         "--off-framestack",
@@ -90,16 +119,27 @@ def main():
     curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S")
 
     cfg = Config.fromfile(args.cfg_path)
+
+    # If running integration test, simplify experiment
+    if args.integration_test:
+        cfg = common_utils.set_cfg_for_intergration_test(cfg)
+
     cfg.agent.env_info = dict(
-        env_name=env_name,
+        name="PongNoFrameskip-v4",
+        is_atari=True,
+        is_discrete=True,
         observation_space=env.observation_space,
         action_space=env.action_space,
-        is_discrete=True,
     )
     cfg.agent.log_cfg = dict(agent=cfg.agent.type, curr_time=curr_time)
     build_args = dict(args=args, env=env)
+
     agent = build_agent(cfg.agent, build_args)
 
+    # Initialize ray if using ApeX
+    if cfg.agent.type == "ApeX":
+        ray.init()
+
     if not args.test:
         agent.train()
     elif args.test and args.grad_cam:
diff --git a/run_reacher_v2.py b/run_reacher_v2.py
index b050c39b..46416938 100644
--- a/run_reacher_v2.py
+++ b/run_reacher_v2.py
@@ -71,6 +71,12 @@
         default="data/reacher_demo.pkl",
         help="demonstration path for learning from demo",
     )
+    parser.add_argument(
+        "--integration-test",
+        dest="integration_test",
+        action="store_true",
+        help="indicate integration test",
+    )
 
     return parser.parse_args()
 
@@ -92,8 +98,13 @@ def main():
     curr_time = NOWTIMES.strftime("%y%m%d_%H%M%S")
 
     cfg = Config.fromfile(args.cfg_path)
+
+    # If running integration test, simplify experiment
+    if args.integration_test:
+        cfg = common_utils.set_cfg_for_intergration_test(cfg)
+
     cfg.agent.env_info = dict(
-        env_name=env_name,
+        name=env_name,
         observation_space=env.observation_space,
         action_space=env.action_space,
         is_discrete=False,
diff --git a/tests/integration/test_run_agent.py b/tests/integration/test_run_agent.py
index 38576b30..52109c09 100644
--- a/tests/integration/test_run_agent.py
+++ b/tests/integration/test_run_agent.py
@@ -17,12 +17,13 @@ def check_run_env(config_root: str, run_file: str):
     configs = os.listdir(config_root)
     for cfg in configs:
         # except such as __init__, __pycache__
-        if "__" in cfg:
+        if "__" in cfg or "apex" in cfg:
            continue
 
         cmd = (
-            f"python {run_file} --cfg-path {config_root}{cfg} "
-            + f"--off-render --episode-num 1 --max-episode-step 1 --seed 12345"
+            f"python {run_file} --cfg-path {config_root}{cfg} --integration-test "
+            + f"--off-render --episode-num 1 --max-episode-step 1 --seed 12345 "
+            + f"--interim-test-num 1"
         )
 
         p = subprocess.Popen(
@@ -33,6 +34,7 @@
             shell=True,
         )
         output, _ = p.communicate()
+        print(str(output))
         assert p.returncode == 0
 
         # Find saved checkpoint path
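A quick reminder of the argparse semantics used by the new flags in `run_pong_no_frameskip_v4.py`: `--off-worker-render` and `--off-logger-render` use `action="store_false"`, so rendering is enabled by default and the flag turns it off, while `--worker-verbose` uses `action="store_true"` and is off by default. A minimal, standalone sketch of that behavior:

```python
import argparse

# Minimal sketch of the flag semantics added above: "store_false" flags are
# on by default and disabled by passing the flag; "store_true" is the reverse.
parser = argparse.ArgumentParser()
parser.add_argument("--off-worker-render", dest="worker_render", action="store_false")
parser.add_argument("--worker-verbose", dest="worker_verbose", action="store_true")

print(parser.parse_args([]))                       # worker_render=True, worker_verbose=False
print(parser.parse_args(["--off-worker-render"]))  # worker_render=False, worker_verbose=False
```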
diff --git a/tests/integration/test_run_apex.py b/tests/integration/test_run_apex.py
new file mode 100644
index 00000000..0e6bd442
--- /dev/null
+++ b/tests/integration/test_run_apex.py
@@ -0,0 +1,67 @@
+"""Test only one step of run file for training."""
+
+import json
+import os
+import os.path as osp
+import re
+import shutil
+import subprocess
+
+import ray
+
+from rl_algorithms.utils.config import Config
+
+
+def check_run_apex(config_root: str, run_file: str):
+    """Test that one update step of the Ape-X run file works well."""
+    test_dir = osp.dirname(osp.abspath(__file__))
+    pkg_root_dir = osp.dirname(osp.dirname(test_dir))
+    os.chdir(pkg_root_dir)
+
+    # loop of configs
+    configs = os.listdir(config_root)
+    for cfg in configs:
+        # run only Ape-X configs; skip __init__, __pycache__, etc.
+        if "__" in cfg or "apex" not in cfg:
+            continue
+
+        cmd = (
+            f"python {run_file} --cfg-path {config_root}{cfg} --integration-test "
+            + f"--off-worker-render --off-logger-render --max-update-step 1 --seed 12345 "
+            + f"--interim-test-num 1"
+        )
+
+        p = subprocess.Popen(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            universal_newlines=True,
+            shell=True,
+        )
+        output, _ = p.communicate()
+        print(str(output))
+        assert p.returncode == 0
+
+        # Find saved checkpoint path
+        pattern = r"./checkpoint/.+/"
+        save_path = re.findall(pattern, str(output))[0]
+        print(save_path)
+
+        check_save_path(save_path)
+
+
+def check_save_path(save_path: str):
+    """Check checkpoint that tested run file makes and remove the checkpoint."""
+    assert os.path.exists(save_path)
+
+    # Remove checkpoint dir
+    shutil.rmtree(save_path)
+
+
+def test_run_pong_no_frame_skip():
+    """Test Ape-X agents that train the PongNoFrameskip-v4 env."""
+    check_run_apex("configs/pong_no_frameskip_v4/", "run_pong_no_frameskip_v4.py")
+
+
+if __name__ == "__main__":
+    test_run_pong_no_frame_skip()
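As a usage note on the integration tests above, both `test_run_agent.py` and `test_run_apex.py` recover the checkpoint directory by scanning the captured subprocess output with the regex `r"./checkpoint/.+/"`. The snippet below replays that extraction on a made-up output string; the directory name is hypothetical.

```python
import re

# Illustrative only: a line resembling what the run script might print.
sample_output = "[INFO] checkpoint saved at ./checkpoint/pong_apexdqn/ step 1"
pattern = r"./checkpoint/.+/"

# Greedy match grabs the full checkpoint directory path, as in the tests above.
save_path = re.findall(pattern, sample_output)[0]
print(save_path)  # ./checkpoint/pong_apexdqn/
```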