
Commit b14faec

Merge pull request #525 from prabhatnagarajan/chkpt_freq
Adds checkpoint frequencies for serial and batch Agents.
prabhatnagarajan authored Sep 13, 2019
2 parents e521a4e + 649b3c1 commit b14faec
Showing 4 changed files with 26 additions and 4 deletions.
11 changes: 8 additions & 3 deletions chainerrl/experiments/train_agent.py
@@ -27,9 +27,9 @@ def ask_and_save_agent_replay_buffer(agent, t, outdir, suffix=''):
         save_agent_replay_buffer(agent, t, outdir, suffix=suffix)
 
 
-def train_agent(agent, env, steps, outdir, max_episode_len=None,
-                step_offset=0, evaluator=None, successful_score=None,
-                step_hooks=(), logger=None):
+def train_agent(agent, env, steps, outdir, checkpoint_freq=None,
+                max_episode_len=None, step_offset=0, evaluator=None,
+                successful_score=None, step_hooks=(), logger=None):
 
     logger = logger or logging.getLogger(__name__)

@@ -80,6 +80,8 @@ def train_agent(agent, env, steps, outdir, max_episode_len=None,
                 episode_len = 0
                 obs = env.reset()
                 r = 0
+            if checkpoint_freq and t % checkpoint_freq == 0:
+                save_agent(agent, t, outdir, logger, suffix='_checkpoint')
 
     except (Exception, KeyboardInterrupt):
         # Save the current model before being killed
@@ -97,6 +99,7 @@ def train_agent_with_evaluation(agent,
                                 eval_n_episodes,
                                 eval_interval,
                                 outdir,
+                                checkpoint_freq=None,
                                 train_max_episode_len=None,
                                 step_offset=0,
                                 eval_max_episode_len=None,
@@ -116,6 +119,7 @@ def train_agent_with_evaluation(agent,
         eval_n_episodes (int): Number of episodes at each evaluation phase.
         eval_interval (int): Interval of evaluation.
         outdir (str): Path to the directory to output data.
+        checkpoint_freq (int): Frequency, in steps, at which agents are stored.
         train_max_episode_len (int): Maximum episode length during training.
         step_offset (int): Time step from which training starts.
         eval_max_episode_len (int or None): Maximum episode length of
@@ -155,6 +159,7 @@ def train_agent_with_evaluation(agent,
 
     train_agent(
         agent, env, steps, outdir,
+        checkpoint_freq=checkpoint_freq,
         max_episode_len=train_max_episode_len,
         step_offset=step_offset,
         evaluator=evaluator,
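
Note: the serial trigger above is a plain modulo test on the global step counter t, so checkpoints land on exact multiples of checkpoint_freq, and the default of None short-circuits the test and disables checkpointing entirely (each save goes through save_agent with the '_checkpoint' suffix, keeping it distinguishable from the crash-time save in the except block). A minimal sketch of the predicate's behaviour, with an assumed 1,000-step run; the helper checkpoint_steps is ours, not part of the diff:

# Minimal sketch: which steps fire the new checkpoint predicate.
# checkpoint_freq=None disables checkpointing because `None and ...`
# short-circuits to a falsy value.
def checkpoint_steps(steps, checkpoint_freq):
    return [t for t in range(1, steps + 1)
            if checkpoint_freq and t % checkpoint_freq == 0]

print(checkpoint_steps(1000, 300))   # [300, 600, 900]
print(checkpoint_steps(1000, None))  # []
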
11 changes: 10 additions & 1 deletion chainerrl/experiments/train_agent_batch.py
@@ -16,7 +16,8 @@
 from chainerrl.misc.makedirs import makedirs
 
 
-def train_agent_batch(agent, env, steps, outdir, log_interval=None,
+def train_agent_batch(agent, env, steps, outdir,
+                      checkpoint_freq=None, log_interval=None,
                       max_episode_len=None, eval_interval=None,
                       step_offset=0, evaluator=None, successful_score=None,
                       step_hooks=(), return_window_size=100, logger=None):
@@ -28,6 +29,7 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
         steps (int): Number of total time steps for training.
         eval_interval (int): Interval of evaluation.
         outdir (str): Path to the directory to output things.
+        checkpoint_freq (int): Frequency, in steps, at which agents are stored.
         log_interval (int): Interval of logging.
         max_episode_len (int): Maximum episode length.
         step_offset (int): Time step from which training starts.
@@ -92,6 +94,10 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
 
             for _ in range(num_envs):
                 t += 1
+                if checkpoint_freq and t % checkpoint_freq == 0:
+                    save_agent(agent, t, outdir, logger,
+                               suffix='_checkpoint')
+
                 for hook in step_hooks:
                     hook(env, agent, t)

@@ -141,6 +147,7 @@ def train_agent_batch_with_evaluation(agent,
                                       eval_n_episodes,
                                       eval_interval,
                                       outdir,
+                                      checkpoint_freq=None,
                                       max_episode_len=None,
                                       step_offset=0,
                                       eval_max_episode_len=None,
@@ -163,6 +170,7 @@ def train_agent_batch_with_evaluation(agent,
         eval_interval (int): Interval of evaluation.
         outdir (str): Path to the directory to output things.
         log_interval (int): Interval of logging.
+        checkpoint_freq (int): Frequency, in steps, at which agents are stored.
         max_episode_len (int): Maximum episode length.
         step_offset (int): Time step from which training starts.
         return_window_size (int): Number of training episodes used to estimate
@@ -204,6 +212,7 @@ def train_agent_batch_with_evaluation(agent,
 
     train_agent_batch(
         agent, env, steps, outdir,
+        checkpoint_freq=checkpoint_freq,
        max_episode_len=max_episode_len,
         step_offset=step_offset,
         eval_interval=eval_interval,
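
Note: in the batch trainer, t advances once per environment per iteration, and the check runs after every increment, so every multiple of checkpoint_freq triggers exactly one save even when checkpoint_freq is not divisible by num_envs. A small sketch of that accounting, with assumed values for num_envs and checkpoint_freq:

# Sketch of the batch step accounting: 8 envs, checkpoint every 1000
# steps. The per-increment check catches every multiple of
# checkpoint_freq even though t jumps by num_envs per iteration.
num_envs = 8
checkpoint_freq = 1000

t = 0
saved_at = []
for _ in range(500):              # 500 batch iterations
    for _ in range(num_envs):     # t advances once per env
        t += 1
        if checkpoint_freq and t % checkpoint_freq == 0:
            saved_at.append(t)

print(saved_at)  # [1000, 2000, 3000, 4000]
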
4 changes: 4 additions & 0 deletions examples/atari/train_dqn_ale.py
@@ -133,6 +133,9 @@ def main():
                         help='Learning rate.')
     parser.add_argument('--prioritized', action='store_true', default=False,
                         help='Use prioritized experience replay.')
+    parser.add_argument('--checkpoint-frequency', type=int,
+                        default=None,
+                        help='Frequency at which agents are stored.')
     args = parser.parse_args()
 
     import logging
@@ -234,6 +237,7 @@ def phi(x):
     experiments.train_agent_with_evaluation(
         agent=agent, env=env, steps=args.steps,
         eval_n_steps=None,
+        checkpoint_freq=args.checkpoint_frequency,
         eval_n_episodes=args.eval_n_runs,
         eval_interval=args.eval_interval,
         outdir=args.outdir,
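
Note: from the example's side, the new --checkpoint-frequency flag is threaded straight through as one keyword argument. A usage sketch of the call, assuming agent and env are constructed as in the rest of examples/atari/train_dqn_ale.py; the remaining values are illustrative, not from the diff:

from chainerrl import experiments

# Usage sketch: checkpoint every 100,000 steps alongside periodic
# evaluation. `agent` and `env` are assumed to be built as in the
# example script; argument names follow the calls shown above.
experiments.train_agent_with_evaluation(
    agent=agent,
    env=env,
    steps=10 ** 7,
    eval_n_steps=None,
    eval_n_episodes=10,
    eval_interval=10 ** 5,
    outdir='results',
    checkpoint_freq=10 ** 5,
)

On the command line this corresponds to passing, e.g., --checkpoint-frequency 100000.
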
4 changes: 4 additions & 0 deletions examples/atari/train_ppo_ale.py
@@ -84,6 +84,9 @@ def main():
     parser.add_argument('--no-frame-stack', action='store_true', default=False,
                         help='Disable frame stacking so that the agent can'
                              ' only see the current screen.')
+    parser.add_argument('--checkpoint-frequency', type=int,
+                        default=None,
+                        help='Frequency at which agents are stored.')
     args = parser.parse_args()
 
     import logging
@@ -233,6 +236,7 @@ def lr_setter(env, agent, value):
         steps=args.steps,
         eval_n_steps=None,
         eval_n_episodes=args.eval_n_runs,
+        checkpoint_freq=args.checkpoint_frequency,
         eval_interval=args.eval_interval,
         log_interval=args.log_interval,
         save_best_so_far_agent=False,
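
Note: the PPO example wires the same flag into what appears to be the batch code path (the surrounding call also passes log_interval, matching train_agent_batch_with_evaluation above), so here checkpoint_freq counts steps summed across all parallel environments rather than per environment. With 8 envs, running python examples/atari/train_ppo_ale.py --checkpoint-frequency 100000 would save roughly every 12,500 batch iterations.
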
