16 changes: 11 additions & 5 deletions chainerrl/agents/pcl.py
@@ -50,8 +50,8 @@ class PCL(agent.AttributeSavingMixin, agent.AsyncAgent):
- action distributions (Distribution)
- state values (chainer.Variable)
optimizer (chainer.Optimizer): optimizer used to train the model
t_max (int or None): The model is updated after every t_max local
steps. If set None, the model is updated after every episode.
update_interval (int): The model is updated after every update_interval
local steps.
gamma (float): Discount factor [0,1]
tau (float): Weight coefficient for the entropy regularization term.
phi (callable): Feature extractor function
@@ -64,6 +64,10 @@ class PCL(agent.AttributeSavingMixin, agent.AsyncAgent):
(batchsize x t_max).
disable_online_update (bool): If set true, disable online on-policy
update and rely only on experience replay.
t_max (int or None): Maximum length of trajectories sampled from the
replay buffer. If set to None, there is no limit and complete
trajectories / episodes will be sampled. Refer to the
behavior of AbstractEpisodicReplayBuffer for more details.
n_times_replay (int): Number of times experience replay is repeated per
one time of online update.
replay_start_size (int): Experience replay is disabled if the number of
@@ -95,7 +99,7 @@ class PCL(agent.AttributeSavingMixin, agent.AsyncAgent):

def __init__(self, model, optimizer,
replay_buffer=None,
t_max=None,
update_interval=1,
gamma=0.99,
tau=1e-2,
phi=lambda x: x,
@@ -104,6 +108,7 @@ def __init__(self, model, optimizer,
rollout_len=10,
batchsize=1,
disable_online_update=False,
t_max=None,
n_times_replay=1,
replay_start_size=10 ** 2,
normalize_loss_by_steps=True,
@@ -131,7 +136,7 @@ def __init__(self, model, optimizer,
self.optimizer = optimizer

self.replay_buffer = replay_buffer
self.t_max = t_max
self.update_interval = update_interval
self.gamma = gamma
self.tau = tau
self.phi = phi
@@ -148,6 +153,7 @@ def __init__(self, model, optimizer,
self.normalize_loss_by_steps = normalize_loss_by_steps
self.act_deterministically = act_deterministically
self.disable_online_update = disable_online_update
self.t_max = t_max
self.n_times_replay = n_times_replay
self.replay_start_size = replay_start_size
self.average_loss_decay = average_loss_decay
@@ -385,7 +391,7 @@ def act_and_train(self, obs, reward):
if self.last_state is not None:
self.past_rewards[self.t - 1] = reward

if self.t - self.t_start == self.t_max:
if self.t - self.t_start == self.update_interval:
self.update_on_policy(statevar)
if len(self.online_batch_losses) == 0:
for _ in range(self.n_times_replay):
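The core behavioral change in pcl.py is that the on-policy update is now triggered by update_interval rather than t_max, while t_max only caps the length of trajectories drawn from the replay buffer. Below is a minimal, self-contained sketch (not part of the diff) of the update trigger implied by the changed condition in act_and_train; the class and method names here are illustrative only.

class UpdateScheduleSketch:
    """Illustrative stand-in for the counters used in PCL.act_and_train."""

    def __init__(self, update_interval=1):
        self.update_interval = update_interval
        self.t = 0        # total number of local steps taken so far
        self.t_start = 0  # step at which the current on-policy segment began

    def step(self):
        """Advance one local step; return True when an on-policy update fires."""
        self.t += 1
        if self.t - self.t_start == self.update_interval:
            self.t_start = self.t
            return True
        return False


# With update_interval=1 the trigger fires on every step, matching the new default.
schedule = UpdateScheduleSketch(update_interval=1)
assert all(schedule.step() for _ in range(3))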
4 changes: 3 additions & 1 deletion examples/gym/train_pcl_gym.py
@@ -49,6 +49,7 @@ def main():
parser.add_argument('--n-hidden-layers', type=int, default=2)
parser.add_argument('--n-times-replay', type=int, default=1)
parser.add_argument('--replay-start-size', type=int, default=10000)
parser.add_argument('--update-interval', type=int, default=1)
parser.add_argument('--t-max', type=int, default=None)
parser.add_argument('--tau', type=float, default=1e-2)
parser.add_argument('--profile', action='store_true')
@@ -156,10 +157,11 @@ def make_env(process_idx, test):

agent = chainerrl.agents.PCL(
model, opt, replay_buffer=replay_buffer,
t_max=args.t_max, gamma=0.99,
update_interval=args.update_interval, gamma=0.99,
tau=args.tau,
phi=lambda x: x.astype(np.float32, copy=False),
rollout_len=args.rollout_len,
t_max=args.t_max,
n_times_replay=args.n_times_replay,
replay_start_size=args.replay_start_size,
batchsize=args.batchsize,
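In the example script, the new --update-interval flag and the existing --t-max flag are passed to the agent independently. A small sketch of how the two flags are parsed and what each one controls (the argument names mirror the diff; the values are just the defaults, shown for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--update-interval', type=int, default=1)
parser.add_argument('--t-max', type=int, default=None)
args = parser.parse_args([])  # use the defaults for illustration

# args.update_interval: number of local steps between on-policy updates
# args.t_max: cap on replayed trajectory length (None = full episodes)
print(args.update_interval, args.t_max)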
21 changes: 13 additions & 8 deletions tests/agents_tests/test_pcl.py
@@ -27,7 +27,8 @@

@testing.parameterize(*(
testing.product({
't_max': [1],
'update_interval': [1],
't_max': [10],
'use_lstm': [False],
'episodic': [True], # PCL doesn't work well with continuing envs
'disable_online_update': [True, False],
@@ -36,6 +37,7 @@
'batchsize': [1, 5],
}) +
testing.product({
'update_interval': [1],
't_max': [None],
'use_lstm': [True, False],
'episodic': [True],
@@ -53,25 +55,27 @@ def setUp(self):

@testing.attr.slow
def test_abc_discrete(self):
self._test_abc(self.t_max, self.use_lstm, episodic=self.episodic)
self._test_abc(self.t_max, self.update_interval, self.use_lstm,
episodic=self.episodic)

def test_abc_discrete_fast(self):
self._test_abc(self.t_max, self.use_lstm, episodic=self.episodic,
steps=10, require_success=False)
self._test_abc(self.t_max, self.update_interval, self.use_lstm,
episodic=self.episodic, steps=10,
require_success=False)

@testing.attr.slow
def test_abc_gaussian(self):
self._test_abc(self.t_max, self.use_lstm,
self._test_abc(self.t_max, self.update_interval, self.use_lstm,
discrete=False, episodic=self.episodic,
steps=100000)

def test_abc_gaussian_fast(self):
self._test_abc(self.t_max, self.use_lstm,
self._test_abc(self.t_max, self.update_interval, self.use_lstm,
discrete=False, episodic=self.episodic,
steps=10, require_success=False)

def _test_abc(self, t_max, use_lstm, discrete=True, episodic=True,
steps=100000, require_success=True):
def _test_abc(self, t_max, update_interval, use_lstm, discrete=True,
episodic=True, steps=100000, require_success=True):

nproc = 8

@@ -182,6 +186,7 @@ def phi(x):
agent = pcl.PCL(model, opt,
replay_buffer=replay_buffer,
t_max=t_max,
update_interval=update_interval,
gamma=gamma,
tau=tau,
phi=phi,