-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[tune/rllib] Add checkpoint eraser #4490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9421009
41c84ca
73b28ed
21fa087
017f50c
c8e46c2
8e695f8
0436a94
de0232c
791b288
cd06f37
ec4e942
c6c5681
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -253,6 +253,8 @@ def __init__(self, | |
stopping_criterion=None, | ||
checkpoint_freq=0, | ||
checkpoint_at_end=False, | ||
keep_checkpoints_num=None, | ||
checkpoint_score_attr="", | ||
export_formats=None, | ||
restore_path=None, | ||
upload_dir=None, | ||
|
@@ -288,6 +290,16 @@ def __init__(self, | |
self.last_update_time = -float("inf") | ||
self.checkpoint_freq = checkpoint_freq | ||
self.checkpoint_at_end = checkpoint_at_end | ||
|
||
self.history = [] | ||
self.keep_checkpoints_num = keep_checkpoints_num | ||
self._cmp_greater = not checkpoint_score_attr.startswith("min-") | ||
self.best_checkpoint_attr_value = -float("inf") \ | ||
if self._cmp_greater else float("inf") | ||
# Strip off "min-" from checkpoint attribute | ||
self.checkpoint_score_attr = checkpoint_score_attr \ | ||
if self._cmp_greater else checkpoint_score_attr[4:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could add a comment that this strips off the "min-" part. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added! |
||
|
||
self._checkpoint = Checkpoint( | ||
storage=Checkpoint.DISK, value=restore_path) | ||
self.export_formats = export_formats | ||
|
@@ -299,7 +311,6 @@ def __init__(self, | |
self.trial_id = Trial.generate_id() if trial_id is None else trial_id | ||
self.error_file = None | ||
self.num_failures = 0 | ||
|
||
self.custom_trial_name = None | ||
|
||
# AutoML fields | ||
|
@@ -495,6 +506,28 @@ def update_last_result(self, result, terminate=False): | |
self.last_update_time = time.time() | ||
self.result_logger.on_result(self.last_result) | ||
|
||
def compare_checkpoints(self, attr_mean): | ||
"""Compares two checkpoints based on the attribute attr_mean param. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you change this to match the Google Style for docstrings? https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html |
||
Greater than is used by default. If command-line parameter | ||
checkpoint_score_attr starts with "min-" less than is used. | ||
|
||
Arguments: | ||
attr_mean: mean of attribute value for the current checkpoint | ||
|
||
Returns: | ||
True: when attr_mean is greater than previous checkpoint attr_mean | ||
and greater than function is selected | ||
when attr_mean is less than previous checkpoint attr_mean and | ||
less than function is selected | ||
False: when attr_mean is not in alignment with selected cmp fn | ||
""" | ||
if self._cmp_greater and attr_mean > self.best_checkpoint_attr_value: | ||
return True | ||
elif (not self._cmp_greater | ||
and attr_mean < self.best_checkpoint_attr_value): | ||
return True | ||
return False | ||
|
||
def _get_trainable_cls(self): | ||
return ray.tune.registry._global_registry.get( | ||
ray.tune.registry.TRAINABLE_CLASS, self.trainable_name) | ||
|
Uh oh!
There was an error while loading. Please reload this page.