Replace prioritized_replay_buffer with PrioritizedBufferWrapper (#233)
* Add PER wrapper

* Add descriptions & Change config parameter

* Delete prioritized_replay_buffer & Add descriptions

* Change minor parameter names & descriptions

* Fix issues raised in review comments
jinPrelude authored Jun 10, 2020
1 parent b8ab7ab commit a302479
Showing 8 changed files with 130 additions and 85 deletions.
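
The core change is architectural: prioritized replay moves from the inheritance-based PrioritizedReplayBuffer(ReplayBuffer) to a PrioritizedBufferWrapper that composes a plain ReplayBuffer. A minimal before/after sketch of the construction pattern, assembled from the diffs below (the capacity, batch size, and alpha values here are illustrative, not taken from any config in this commit):

from rl_algorithms.common.buffer.replay_buffer import ReplayBuffer
from rl_algorithms.common.buffer.wrapper import PrioritizedBufferWrapper

# Before (inheritance): a single class owned both storage and priorities.
# memory = PrioritizedReplayBuffer(buffer_size=100000, batch_size=32, alpha=0.6)

# After (composition): plain storage is built first, then wrapped with PER bookkeeping.
memory = ReplayBuffer(max_len=100000, batch_size=32)
memory = PrioritizedBufferWrapper(memory, alpha=0.6)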
53 changes: 53 additions & 0 deletions rl_algorithms/common/abstract/buffer.py
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
"""Abstract Buffer & BufferWrapper class.
- Author: Euijin Jeong
- Contact: euijin.jeong@medipixel.io
"""

from abc import ABC, abstractmethod
from typing import Any, Tuple

import numpy as np


class BaseBuffer(ABC):
"""Abstract Buffer used for replay buffer."""

@abstractmethod
def add(self, transition: Tuple[Any, ...]) -> Tuple[Any, ...]:
pass

@abstractmethod
def sample(self) -> Tuple[np.ndarray, ...]:
pass

@abstractmethod
def __len__(self) -> int:
pass


class BufferWrapper(BaseBuffer):
"""Abstract BufferWrapper used for buffer wrapper.
Attributes:
buffer (BaseBuffer): the wrapped replay buffer, held as an attribute
"""

def __init__(self, base_buffer: BaseBuffer):
"""Initialize a ReplayBuffer object.
Args:
base_buffer (BaseBuffer): replay buffer to be wrapped
"""
self.buffer = base_buffer

def add(self, transition: Tuple[Any, ...]) -> Tuple[Any, ...]:
return self.buffer.add(transition)

def sample(self) -> Tuple[np.ndarray, ...]:
return self.buffer.sample()

def __len__(self) -> int:
"""Return the current size of internal memory."""
return len(self.buffer)
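
Since BufferWrapper simply delegates add, sample, and __len__ to the wrapped buffer, extra behaviour can be layered on any BaseBuffer without modifying ReplayBuffer itself. A minimal sketch of a hypothetical wrapper (CountingBufferWrapper is not part of this commit) that counts added transitions while delegating storage:

from typing import Any, Tuple

from rl_algorithms.common.abstract.buffer import BaseBuffer, BufferWrapper


class CountingBufferWrapper(BufferWrapper):
    """Hypothetical example wrapper: counts how many transitions have been added."""

    def __init__(self, base_buffer: BaseBuffer):
        BufferWrapper.__init__(self, base_buffer)
        self.n_added = 0

    def add(self, transition: Tuple[Any, ...]) -> Tuple[Any, ...]:
        self.n_added += 1
        return self.buffer.add(transition)  # storage itself stays in the wrapped buffer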
33 changes: 16 additions & 17 deletions rl_algorithms/common/buffer/replay_buffer.py
@@ -6,10 +6,11 @@

import numpy as np

from rl_algorithms.common.abstract.buffer import BaseBuffer
from rl_algorithms.common.helper_functions import get_n_step_info


class ReplayBuffer:
class ReplayBuffer(BaseBuffer):
"""Fixed-size buffer to store experience tuples.
Attributes:
@@ -21,7 +22,7 @@ class ReplayBuffer:
n_step_buffer (deque): recent n transitions
n_step (int): step size for n-step transition
gamma (float): discount factor
buffer_size (int): size of buffers
max_len (int): size of buffers
batch_size (int): batch size for training
demo_size (int): size of demo transitions
length (int): amount of memory filled
@@ -30,7 +31,7 @@ class ReplayBuffer:

def __init__(
self,
buffer_size: int,
max_len: int,
batch_size: int,
gamma: float = 0.99,
n_step: int = 1,
@@ -39,15 +40,15 @@ def __init__(
"""Initialize a ReplayBuffer object.
Args:
buffer_size (int): size of replay buffer for experience
max_len (int): size of replay buffer for experience
batch_size (int): size of a batch sampled from replay buffer for training
gamma (float): discount factor
n_step (int): step size for n-step transition
demo (list): transitions of human play
"""
assert 0 < batch_size <= buffer_size
assert 0 < batch_size <= max_len
assert 0.0 <= gamma <= 1.0
assert 1 <= n_step <= buffer_size
assert 1 <= n_step <= max_len

self.obs_buf: np.ndarray = None
self.acts_buf: np.ndarray = None
@@ -59,7 +60,7 @@ def __init__(
self.n_step = n_step
self.gamma = gamma

self.buffer_size = buffer_size
self.max_len = max_len
self.batch_size = batch_size
self.demo_size = len(demo) if demo else 0
self.demo = demo
@@ -68,7 +69,7 @@ def __init__(

# demo may have empty tuple list [()]
if self.demo and self.demo[0]:
self.buffer_size += self.demo_size
self.max_len += self.demo_size
self.length += self.demo_size
for idx, d in enumerate(self.demo):
state, action, reward, next_state, done = d
@@ -112,8 +113,8 @@ def add(
self.done_buf[self.idx] = done

self.idx += 1
self.idx = self.demo_size if self.idx % self.buffer_size == 0 else self.idx
self.length = min(self.length + 1, self.buffer_size)
self.idx = self.demo_size if self.idx % self.max_len == 0 else self.idx
self.length = min(self.length + 1, self.max_len)

# return a single step transition to insert to replay buffer
return self.n_step_buffer[0]
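
The index update above is what keeps demo transitions permanent: when the write index wraps past max_len it resets to demo_size instead of 0, so slots [0, demo_size) are never overwritten. A toy illustration of the same arithmetic (the buffer sizes here are made up for the example):

max_len, demo_size = 5, 2
idx = demo_size  # agent transitions are written after the demo block
for step in range(8):
    idx += 1
    idx = demo_size if idx % max_len == 0 else idx
    print(step, idx)
# idx cycles 3, 4, 2, 3, 4, 2, ... and never revisits the demo slots 0 and 1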
@@ -143,17 +144,15 @@ def sample(self, indices: List[int] = None) -> Tuple[np.ndarray, ...]:
def _initialize_buffers(self, state: np.ndarray, action: np.ndarray) -> None:
"""Initialze buffers for state, action, resward, next_state, done."""
# In case action of demo is not np.ndarray
self.obs_buf = np.zeros(
[self.buffer_size] + list(state.shape), dtype=state.dtype
)
self.obs_buf = np.zeros([self.max_len] + list(state.shape), dtype=state.dtype)
self.acts_buf = np.zeros(
[self.buffer_size] + list(action.shape), dtype=action.dtype
[self.max_len] + list(action.shape), dtype=action.dtype
)
self.rews_buf = np.zeros([self.buffer_size], dtype=float)
self.rews_buf = np.zeros([self.max_len], dtype=float)
self.next_obs_buf = np.zeros(
[self.buffer_size] + list(state.shape), dtype=state.dtype
[self.max_len] + list(state.shape), dtype=state.dtype
)
self.done_buf = np.zeros([self.buffer_size], dtype=float)
self.done_buf = np.zeros([self.max_len], dtype=float)

def __len__(self) -> int:
"""Return the current size of internal memory."""
rl_algorithms/common/buffer/wrapper.py
@@ -1,31 +1,33 @@
# -*- coding: utf-8 -*-
"""Prioritized Replay buffer for algorithms.
"""Wrappers for buffer.
- Author: Kyunghwan Kim
- Contact: kh.kim@medipixel.io
- Author: Kyunghwan Kim & Euijin Jeong
- Contact: kh.kim@medipixel.io & euijin.jeong@medipixel.io
- Paper: https://arxiv.org/pdf/1511.05952.pdf
https://arxiv.org/pdf/1707.08817.pdf
"""

import random
from typing import Any, List, Tuple
from typing import Any, Tuple

import numpy as np
import torch

from rl_algorithms.common.buffer.replay_buffer import ReplayBuffer
from rl_algorithms.common.abstract.buffer import BaseBuffer, BufferWrapper
from rl_algorithms.common.buffer.segment_tree import MinSegmentTree, SumSegmentTree

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class PrioritizedReplayBuffer(ReplayBuffer):
"""Create Prioritized Replay buffer.
class PrioritizedBufferWrapper(BufferWrapper):
"""Prioritized Experience Replay wrapper for Buffer.
Refer to OpenAI baselines github repository:
https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
Attributes:
buffer (BaseBuffer): the wrapped replay buffer, held as an attribute
alpha (float): alpha parameter for prioritized replay buffer
epsilon_d (float): small positive constants to add to the priorities
tree_idx (int): next index of tree
@@ -35,65 +37,56 @@ class PrioritizedReplayBuffer(ReplayBuffer):
"""

def __init__(
self,
buffer_size: int,
batch_size: int = 32,
gamma: float = 0.99,
n_step: int = 1,
alpha: float = 0.6,
epsilon_d: float = 1.0,
demo: List[Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]] = None,
self, base_buffer: BaseBuffer, alpha: float = 0.6, epsilon_d: float = 1.0
):
"""Initialize.
Args:
buffer_size (int): size of replay buffer for experience
batch_size (int): size of a batched sampled from replay buffer for training
base_buffer (BaseBuffer): replay buffer to be wrapped
alpha (float): alpha parameter for prioritized replay buffer
epsilon_d (float): small positive constants to add to the priorities
"""
super(PrioritizedReplayBuffer, self).__init__(
buffer_size, batch_size, gamma, n_step, demo
)
BufferWrapper.__init__(self, base_buffer)
assert alpha >= 0
self.alpha = alpha
self.epsilon_d = epsilon_d
self.tree_idx = 0

# capacity must be positive and a power of 2.
tree_capacity = 1
while tree_capacity < self.buffer_size:
while tree_capacity < self.buffer.max_len:
tree_capacity *= 2

self.sum_tree = SumSegmentTree(tree_capacity)
self.min_tree = MinSegmentTree(tree_capacity)
self._max_priority = 1.0

# for init priority of demo
self.tree_idx = self.demo_size
for i in range(self.demo_size):
self.tree_idx = self.buffer.demo_size
for i in range(self.buffer.demo_size):
self.sum_tree[i] = self._max_priority ** self.alpha
self.min_tree[i] = self._max_priority ** self.alpha

def add(
self, transition: Tuple[np.ndarray, np.ndarray, float, np.ndarray, bool]
) -> Tuple[Any, ...]:
"""Add experience and priority."""
n_step_transition = super().add(transition)
n_step_transition = self.buffer.add(transition)
if n_step_transition:
self.sum_tree[self.tree_idx] = self._max_priority ** self.alpha
self.min_tree[self.tree_idx] = self._max_priority ** self.alpha

self.tree_idx += 1
if self.tree_idx % self.buffer_size == 0:
self.tree_idx = self.demo_size
if self.tree_idx % self.buffer.max_len == 0:
self.tree_idx = self.buffer.demo_size

return n_step_transition

def _sample_proportional(self, batch_size: int) -> list:
"""Sample indices based on proportional."""
indices = []
p_total = self.sum_tree.sum(0, len(self) - 1)
p_total = self.sum_tree.sum(0, len(self.buffer) - 1)
segment = p_total / batch_size

for i in range(batch_size):
@@ -106,29 +99,29 @@ def _sample_proportional(self, batch_size: int) -> list:

def sample(self, beta: float = 0.4) -> Tuple[torch.Tensor, ...]:
"""Sample a batch of experiences."""
assert len(self) >= self.batch_size
assert len(self.buffer) >= self.buffer.batch_size
assert beta > 0

indices = self._sample_proportional(self.batch_size)
indices = self._sample_proportional(self.buffer.batch_size)

# get max weight
p_min = self.min_tree.min() / self.sum_tree.sum()
max_weight = (p_min * len(self)) ** (-beta)
max_weight = (p_min * len(self.buffer)) ** (-beta)

# calculate weights
weights_, eps_d = [], []
for i in indices:
eps_d.append(self.epsilon_d if i < self.demo_size else 0.0)
eps_d.append(self.epsilon_d if i < self.buffer.demo_size else 0.0)
p_sample = self.sum_tree[i] / self.sum_tree.sum()
weight = (p_sample * len(self)) ** (-beta)
weight = (p_sample * len(self.buffer)) ** (-beta)
weights_.append(weight / max_weight)

weights = np.array(weights_)
eps_d = np.array(eps_d)

weights = weights.reshape(-1, 1)

states, actions, rewards, next_states, dones = super().sample(indices)
states, actions, rewards, next_states, dones = self.buffer.sample(indices)

return states, actions, rewards, next_states, dones, weights, indices, eps_d

@@ -138,7 +131,7 @@ def update_priorities(self, indices: list, priorities: np.ndarray):

for idx, priority in zip(indices, priorities):
assert priority > 0
assert 0 <= idx < len(self)
assert 0 <= idx < len(self.buffer)

self.sum_tree[idx] = priority ** self.alpha
self.min_tree[idx] = priority ** self.alpha
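
The sample method above applies the standard PER importance-sampling correction: each drawn index i gets weight (P(i) * N) ** (-beta), normalized by the maximum weight so all weights stay at or below 1. A standalone sketch of the same computation on a toy priority array, using plain NumPy in place of the segment trees:

import numpy as np

priorities = np.array([4.0, 1.0, 1.0, 2.0])  # toy priorities, assumed already raised to alpha
beta = 0.4

probs = priorities / priorities.sum()          # P(i) = p_i / sum_k p_k
n = len(priorities)
max_weight = (probs.min() * n) ** (-beta)      # weight of the rarest transition
weights = (probs * n) ** (-beta) / max_weight  # normalized so every weight <= 1
print(weights)                                 # rare transitions get the largest correction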
11 changes: 6 additions & 5 deletions rl_algorithms/dqn/agent.py
@@ -23,8 +23,8 @@
import wandb

from rl_algorithms.common.abstract.agent import Agent
from rl_algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
from rl_algorithms.common.buffer.replay_buffer import ReplayBuffer
from rl_algorithms.common.buffer.wrapper import PrioritizedBufferWrapper
from rl_algorithms.common.helper_functions import numpy2floattensor
from rl_algorithms.dqn.learner import DQNLearner
from rl_algorithms.registry import AGENTS
@@ -105,10 +105,11 @@ def _initialize(self):
"""Initialize non-common things."""
if not self.args.test:
# replay memory for a single step
self.memory = PrioritizedReplayBuffer(
self.hyper_params.buffer_size,
self.hyper_params.batch_size,
alpha=self.hyper_params.per_alpha,
self.memory = ReplayBuffer(
self.hyper_params.buffer_size, self.hyper_params.batch_size,
)
self.memory = PrioritizedBufferWrapper(
self.memory, alpha=self.hyper_params.per_alpha
)

# replay memory for multi-steps
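
With the wrapper in place, sample(beta) returns the extended tuple defined in wrapper.py (states, actions, rewards, next_states, dones, weights, indices, eps_d), and priorities are refreshed through update_priorities. A minimal sketch of how a training step could consume that tuple, assuming memory is the wrapped buffer built above; the TD errors are placeholders, not the DQN learner's actual loss code:

import numpy as np

beta = 0.4  # commonly annealed toward 1.0 over the course of training

states, actions, rewards, next_states, dones, weights, indices, eps_d = memory.sample(beta)

# Placeholder TD errors; the real learner computes these with its Q-networks.
td_errors = np.random.rand(len(indices))

loss = float(np.mean(weights.squeeze() * td_errors ** 2))  # importance-sampling weighted loss
new_priorities = np.abs(td_errors) + 1e-6                  # keep priorities strictly positive
memory.update_priorities(indices, new_priorities)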
15 changes: 7 additions & 8 deletions rl_algorithms/fd/ddpg_agent.py
@@ -15,8 +15,8 @@
import numpy as np
import torch

from rl_algorithms.common.buffer.priortized_replay_buffer import PrioritizedReplayBuffer
from rl_algorithms.common.buffer.replay_buffer import ReplayBuffer
from rl_algorithms.common.buffer.wrapper import PrioritizedBufferWrapper
import rl_algorithms.common.helper_functions as common_utils
from rl_algorithms.common.helper_functions import numpy2floattensor
from rl_algorithms.ddpg.agent import DDPGAgent
@@ -56,20 +56,19 @@ def _initialize(self):

# replay memory for multi-steps
self.memory_n = ReplayBuffer(
buffer_size=self.hyper_params.buffer_size,
max_len=self.hyper_params.buffer_size,
batch_size=self.hyper_params.batch_size,
n_step=self.hyper_params.n_step,
gamma=self.hyper_params.gamma,
demo=demos_n_step,
)

# replay memory for a single step
self.memory = PrioritizedReplayBuffer(
self.hyper_params.buffer_size,
self.hyper_params.batch_size,
demo=demos,
alpha=self.hyper_params.per_alpha,
epsilon_d=self.hyper_params.per_eps_demo,
self.memory = ReplayBuffer(
self.hyper_params.buffer_size, self.hyper_params.batch_size,
)
self.memory = PrioritizedBufferWrapper(
self.memory, alpha=self.hyper_params.per_alpha
)

self.learner = DDPGfDLearner(
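
For the demonstration-based agents the wrapper also returns eps_d, which is epsilon_d (per_eps_demo) for demo indices and 0.0 otherwise. The hunk above is truncated, so the exact arguments are not visible here; the following is only a hedged sketch of how demos and the demo bonus could be passed and used under the new composition API:

# Hypothetical construction mirroring the removed PrioritizedReplayBuffer arguments.
memory = ReplayBuffer(
    hyper_params.buffer_size, hyper_params.batch_size, demo=demos,
)
memory = PrioritizedBufferWrapper(
    memory, alpha=hyper_params.per_alpha, epsilon_d=hyper_params.per_eps_demo,
)

# Adding eps_d when refreshing priorities keeps demo transitions favoured during sampling.
new_priorities = np.abs(td_errors) + eps_d + 1e-6
memory.update_priorities(indices, new_priorities)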