
[tune/rllib] Add checkpoint eraser #4490


Merged
merged 13 commits on Apr 7, 2019
2 changes: 2 additions & 0 deletions python/ray/rllib/train.py
@@ -104,6 +104,8 @@ def run(args, parser):
args.experiment_name: { # i.e. log to ~/ray_results/default
"run": args.run,
"checkpoint_freq": args.checkpoint_freq,
"keep_checkpoints_num": args.keep_checkpoints_num,
"checkpoint_score_attr": args.checkpoint_score_attr,
"local_dir": args.local_dir,
"resources_per_trial": (
args.resources_per_trial and
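For context, a hedged sketch of the experiment spec that run() assembles above, with the two new keys forwarded from the CLI arguments (all values here are illustrative):

# Illustrative only: the spec train.py builds and hands to Tune.
experiments = {
    "default": {  # i.e. log to ~/ray_results/default
        "run": "PPO",
        "checkpoint_freq": 5,
        "keep_checkpoints_num": 3,  # keep only the 3 best checkpoints
        "checkpoint_score_attr": "min-validation_loss",  # lower is better
        "local_dir": "~/ray_results",
    }
}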
18 changes: 18 additions & 0 deletions python/ray/tune/config_parser.py
@@ -103,6 +103,20 @@ def make_parser(parser_creator=None, **kwargs):
action="store_true",
help="Whether to checkpoint at the end of the experiment. "
"Default is False.")
parser.add_argument(
"--keep-checkpoints-num",
default=None,
type=int,
help="Number of last checkpoints to keep. Others get "
"deleted. Default (None) keeps all checkpoints.")
parser.add_argument(
"--checkpoint-score-attr",
default="training_iteration",
type=str,
help="Specifies by which attribute to rank the best checkpoint. "
"Default is increasing order. If attribute starts with min- it "
"will rank attribute in decreasing order. Example: "
"min-validation_loss")
parser.add_argument(
"--export-formats",
default=None,
@@ -143,6 +157,8 @@ def to_argv(config):
for k, v in config.items():
if "-" in k:
raise ValueError("Use '_' instead of '-' in `{}`".format(k))
if v is None:
continue
if not isinstance(v, bool) or v: # for argparse flags
argv.append("--{}".format(k.replace("_", "-")))
if isinstance(v, string_types):
@@ -188,6 +204,8 @@ def create_trial_from_spec(spec, output_path, parser, **trial_kwargs):
stopping_criterion=spec.get("stop", {}),
checkpoint_freq=args.checkpoint_freq,
checkpoint_at_end=args.checkpoint_at_end,
keep_checkpoints_num=args.keep_checkpoints_num,
checkpoint_score_attr=args.checkpoint_score_attr,
export_formats=spec.get("export_formats", []),
# str(None) doesn't create None
restore_path=spec.get("restore"),
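The new "if v is None: continue" guard above matters because the new flags default to None: without it, to_argv would serialize unset options as the literal string "None". A simplified standalone sketch of the skip logic (not the actual Tune function, which also special-cases strings):

# Hypothetical, simplified reimplementation for illustration.
def to_argv_sketch(config):
    argv = []
    for k, v in config.items():
        if v is None:
            continue  # unset options (e.g. keep_checkpoints_num=None) emit no flag
        if not isinstance(v, bool) or v:  # True booleans become bare flags
            argv.append("--{}".format(k.replace("_", "-")))
        if not isinstance(v, bool):
            argv.append(str(v))
    return argv

assert to_argv_sketch({"keep_checkpoints_num": None, "checkpoint_freq": 5}) == ["--checkpoint-freq", "5"]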
4 changes: 4 additions & 0 deletions python/ray/tune/experiment.py
@@ -71,6 +71,8 @@ def __init__(self,
sync_function=None,
checkpoint_freq=0,
checkpoint_at_end=False,
keep_checkpoints_num=None,
checkpoint_score_attr=None,
export_formats=None,
max_failures=3,
restore=None,
@@ -102,6 +104,8 @@ def __init__(self,
"sync_function": sync_function,
"checkpoint_freq": checkpoint_freq,
"checkpoint_at_end": checkpoint_at_end,
"keep_checkpoints_num": keep_checkpoints_num,
"checkpoint_score_attr": checkpoint_score_attr,
"export_formats": export_formats or [],
"max_failures": max_failures,
"restore": restore
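A hedged usage sketch of the new Experiment arguments (the experiment name and score attribute here are illustrative):

from ray.tune import Experiment

# Keep only the 3 checkpoints with the lowest validation_loss.
exp = Experiment(
    name="my_experiment",  # hypothetical name
    run="PPO",
    checkpoint_freq=5,
    keep_checkpoints_num=3,
    checkpoint_score_attr="min-validation_loss",
)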
39 changes: 37 additions & 2 deletions python/ray/tune/ray_trial_executor.py
@@ -4,6 +4,7 @@
from __future__ import print_function

import logging
import math
import os
import random
import time
@@ -469,10 +470,44 @@ def save(self, trial, storage=Checkpoint.DISK):
if storage == Checkpoint.MEMORY:
trial._checkpoint.value = trial.runner.save_to_object.remote()
else:
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(trial.runner.save.remote())
# Keeps only highest performing checkpoints if enabled
if trial.keep_checkpoints_num:
try:
last_attr_val = trial.last_result[
trial.checkpoint_score_attr]
if (trial.compare_checkpoints(last_attr_val)
and not math.isnan(last_attr_val)):
trial.best_checkpoint_attr_value = last_attr_val
self._checkpoint_and_erase(trial)
except KeyError:
logger.warning(
"Result dict has no key: {}. keep"
"_checkpoints_num flag will not work".format(
trial.checkpoint_score_attr))
else:
with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(
trial.runner.save.remote())

return trial._checkpoint.value

def _checkpoint_and_erase(self, trial):
"""Checkpoints the model and erases old checkpoints
if needed.
Parameters
----------
trial : trial to save
"""

with warn_if_slow("save_to_disk"):
trial._checkpoint.value = ray.get(trial.runner.save.remote())

if len(trial.history) >= trial.keep_checkpoints_num:
ray.get(trial.runner.delete_checkpoint.remote(trial.history[-1]))
trial.history.pop()

trial.history.insert(0, trial._checkpoint.value)

def restore(self, trial, checkpoint=None):
"""Restores training state from a given model checkpoint.

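The eviction logic in _checkpoint_and_erase amounts to a bounded list kept newest-first: once the history already holds keep_checkpoints_num entries, the tail (the oldest surviving checkpoint) is deleted before the new value is inserted at the front. A self-contained sketch of the same bookkeeping, independent of Ray:

# Standalone sketch of the keep-N eviction bookkeeping (not Tune's API).
class CheckpointHistory:
    def __init__(self, keep_num):
        self.keep_num = keep_num
        self.history = []  # newest first, mirroring Trial.history

    def add(self, checkpoint, delete_fn):
        if len(self.history) >= self.keep_num:
            delete_fn(self.history[-1])  # erase the oldest checkpoint
            self.history.pop()
        self.history.insert(0, checkpoint)

# With keep_num=2, adding "a", "b", "c" leaves ["c", "b"]; "a" is deleted.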
11 changes: 11 additions & 0 deletions python/ray/tune/trainable.py
@@ -210,6 +210,17 @@ def train(self):

return result

def delete_checkpoint(self, checkpoint_dir):
"""Removes subdirectory within checkpoint_folder
Parameters
----------
checkpoint_dir : path to checkpoint
"""
if os.path.isfile(checkpoint_dir):
shutil.rmtree(os.path.dirname(checkpoint_dir))
else:
shutil.rmtree(checkpoint_dir)

def save(self, checkpoint_dir=None):
"""Saves the current model state to a checkpoint.

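A hedged demo of the deletion semantics using only the standard library (paths are hypothetical): passing a checkpoint file removes its enclosing directory, matching the isfile branch above.

import os
import shutil
import tempfile

root = tempfile.mkdtemp()
ckpt_dir = os.path.join(root, "checkpoint_10")
os.makedirs(ckpt_dir)
ckpt_file = os.path.join(ckpt_dir, "checkpoint-10")
open(ckpt_file, "w").close()

if os.path.isfile(ckpt_file):
    shutil.rmtree(os.path.dirname(ckpt_file))  # removes checkpoint_10/ entirely
assert not os.path.exists(ckpt_dir)
shutil.rmtree(root)  # clean up the scratch directory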
35 changes: 34 additions & 1 deletion python/ray/tune/trial.py
@@ -253,6 +253,8 @@ def __init__(self,
stopping_criterion=None,
checkpoint_freq=0,
checkpoint_at_end=False,
keep_checkpoints_num=None,
checkpoint_score_attr="",
export_formats=None,
restore_path=None,
upload_dir=None,
@@ -288,6 +290,16 @@ def __init__(self,
self.last_update_time = -float("inf")
self.checkpoint_freq = checkpoint_freq
self.checkpoint_at_end = checkpoint_at_end

self.history = []
self.keep_checkpoints_num = keep_checkpoints_num
self._cmp_greater = not checkpoint_score_attr.startswith("min-")
self.best_checkpoint_attr_value = -float("inf") \
if self._cmp_greater else float("inf")
# Strip off "min-" from checkpoint attribute
self.checkpoint_score_attr = checkpoint_score_attr \
if self._cmp_greater else checkpoint_score_attr[4:]
[Review] Contributor: could add a comment that this strips off the "min-" part.
[Review] Contributor Author: Added!

self._checkpoint = Checkpoint(
storage=Checkpoint.DISK, value=restore_path)
self.export_formats = export_formats
@@ -299,7 +311,6 @@ def __init__(self,
self.trial_id = Trial.generate_id() if trial_id is None else trial_id
self.error_file = None
self.num_failures = 0

self.custom_trial_name = None

# AutoML fields
Expand Down Expand Up @@ -495,6 +506,28 @@ def update_last_result(self, result, terminate=False):
self.last_update_time = time.time()
self.result_logger.on_result(self.last_result)

def compare_checkpoints(self, attr_mean):
"""Compares two checkpoints based on the attribute attr_mean param.
[Review] Contributor: Can you change this to match the Google Style for docstrings? https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
Greater than is used by default. If command-line parameter
checkpoint_score_attr starts with "min-" less than is used.

Arguments:
attr_mean: mean of attribute value for the current checkpoint

Returns:
True: when attr_mean is greater than previous checkpoint attr_mean
and greater than function is selected
when attr_mean is less than previous checkpoint attr_mean and
less than function is selected
False: when attr_mean is not in alignment with selected cmp fn
"""
if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value:
return True
elif (not self._cmp_greater
and attr_mean < self.best_checkpoint_attr_value):
return True
return False

def _get_trainable_cls(self):
return ray.tune.registry._global_registry.get(
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name)
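Taken together, the "min-" stripping in __init__ and compare_checkpoints implement a configurable best-so-far test. A standalone sketch of the combined behavior (a hypothetical helper, not Tune's API):

# Hypothetical helper reproducing the min- handling and comparison.
def make_comparator(checkpoint_score_attr):
    greater = not checkpoint_score_attr.startswith("min-")
    attr = checkpoint_score_attr if greater else checkpoint_score_attr[4:]
    best = -float("inf") if greater else float("inf")

    def is_better(attr_mean, best_so_far):
        return attr_mean > best_so_far if greater else attr_mean < best_so_far

    return attr, best, is_better

attr, best, is_better = make_comparator("min-validation_loss")
# attr == "validation_loss"; is_better(0.3, best) is True because 0.3 < inf.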