Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds checkpoint frequencies for serial and batch Agents. #525

Merged
merged 10 commits into from
Sep 13, 2019
Prev Previous commit
Next Next commit
Adds checkpointing for batch training
  • Loading branch information
prabhatnagarajan committed Aug 20, 2019
commit 5832fe3b885b78528a83539c4ab2aa1a23bd70a4
11 changes: 10 additions & 1 deletion chainerrl/experiments/train_agent_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from chainerrl.misc.makedirs import makedirs


def train_agent_batch(agent, env, steps, outdir, log_interval=None,
def train_agent_batch(agent, env, steps, outdir,
checkpoint_freq=None, log_interval=None,
max_episode_len=None, eval_interval=None,
step_offset=0, evaluator=None, successful_score=None,
step_hooks=(), return_window_size=100, logger=None):
Expand All @@ -28,6 +29,7 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,
steps (int): Number of total time steps for training.
eval_interval (int): Interval of evaluation.
outdir (str): Path to the directory to output things.
checkpoint_freq (int): Frequency with which to store networks.
log_interval (int): Interval of logging.
max_episode_len (int): Maximum episode length.
step_offset (int): Time step from which training starts.
Expand Down Expand Up @@ -92,6 +94,10 @@ def train_agent_batch(agent, env, steps, outdir, log_interval=None,

for _ in range(num_envs):
t += 1
if checkpoint_freq:
if t % checkpoint_freq == 0:
save_agent(agent, t, outdir, logger, suffix='_checkpoint')

for hook in step_hooks:
hook(env, agent, t)

Expand Down Expand Up @@ -141,6 +147,7 @@ def train_agent_batch_with_evaluation(agent,
eval_n_episodes,
eval_interval,
outdir,
checkpoint_freq=None,
max_episode_len=None,
step_offset=0,
eval_max_episode_len=None,
Expand All @@ -163,6 +170,7 @@ def train_agent_batch_with_evaluation(agent,
eval_interval (int): Interval of evaluation.
outdir (str): Path to the directory to output things.
log_interval (int): Interval of logging.
checkpoint_freq (int): Frequency with which to store networks.
max_episode_len (int): Maximum episode length.
step_offset (int): Time step from which training starts.
return_window_size (int): Number of training episodes used to estimate
Expand Down Expand Up @@ -204,6 +212,7 @@ def train_agent_batch_with_evaluation(agent,

train_agent_batch(
agent, env, steps, outdir,
checkpoint_freq=checkpoint_freq,
max_episode_len=max_episode_len,
step_offset=step_offset,
eval_interval=eval_interval,
Expand Down
2 changes: 1 addition & 1 deletion examples/atari/train_dqn_ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def phi(x):
experiments.train_agent_with_evaluation(
agent=agent, env=env, steps=args.steps,
eval_n_steps=None,
checkpoint_freq=args.checkpoint_freq,
checkpoint_freq=args.checkpoint_frequency,
eval_n_episodes=args.eval_n_runs,
eval_interval=args.eval_interval,
outdir=args.outdir,
Expand Down
4 changes: 4 additions & 0 deletions examples/atari/train_ppo_ale.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def main():
parser.add_argument('--no-frame-stack', action='store_true', default=False,
help='Disable frame stacking so that the agent can'
' only see the current screen.')
parser.add_argument('--checkpoint-frequency', type=int,
default=None,
help='Frequency with which to checkpoint networks')
args = parser.parse_args()

import logging
Expand Down Expand Up @@ -233,6 +236,7 @@ def lr_setter(env, agent, value):
steps=args.steps,
eval_n_steps=None,
eval_n_episodes=args.eval_n_runs,
checkpoint_freq=args.checkpoint_frequency,
eval_interval=args.eval_interval,
log_interval=args.log_interval,
save_best_so_far_agent=False,
Expand Down