-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[tune] Treat checkpoints with nan value as worst #23862
Conversation
cc @XuehaiPan |
It needs similar changes with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @Yard1 and thanks for raising the issue @XuehaiPan!
@@ -5,7 +5,7 @@ | |||
from typing import Any, Callable, Optional | |||
|
|||
from ray.tune.result import NODE_IP | |||
from ray.tune.utils.util import flatten_dict | |||
from ray.tune.utils.util import flatten_dict, is_nan |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we use the ml_utils is_nan
directly and remove it from tune.utils?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine to use an alias - we have a precedent for that already.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally good, one suggestion
This looks great @Yard1! Can we resolve the conflicts, and then I can merge! |
if self._checkpoint_score_desc: | ||
priority = -priority | ||
return (not is_nan(priority), priority, checkpoint.order) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When priority
is nan
, sorting by tuple key:
(not is_nan(priority), priority, checkpoint.order)
won't give the correct order by checkpoint.order
. Because both nan < nan
and nan > nan
return False
.
if self._checkpoint_score_desc: | |
priority = -priority | |
return (not is_nan(priority), priority, checkpoint.order) | |
if self._checkpoint_score_desc: | |
priority = -priority | |
if is_nan(priority): | |
return (0, checkpoint.order, priority) | |
return (1, priority, checkpoint.order) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually it does:
>>> import numpy as np
>>> (False, np.nan, 3) < (False, np.nan, 4)
True
>>> (False, np.nan, 4) < (False, np.nan, 3)
False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually it does:
>>> import numpy as np >>> (False, np.nan, 3) < (False, np.nan, 4) True >>> (False, np.nan, 4) < (False, np.nan, 3) False
Seems that tuple.__lt__
skips items when lhs[i] is rhs[i]
.
In [1]: (False, float('nan'), 3) < (False, float('nan'), 4)
Out[1]: False
In [2]: (False, float('nan'), 3) > (False, float('nan'), 4)
Out[2]: False
In [3]: (False, float('nan'), 4) < (False, float('nan'), 3)
Out[3]: False
In [4]: (False, float('nan'), 4) > (False, float('nan'), 3)
Out[4]: False
In [5]: import numpy as np
In [6]: (False, np.nan, 3) < (False, np.nan, 4)
Out[6]: True
In [7]: (False, np.nan, 3) > (False, np.nan, 4)
Out[7]: False
In [8]: import math
In [9]: (False, math.nan, 3) < (False, math.nan, 4)
Out[9]: True
In [10]: (False, math.nan, 3) > (False, math.nan, 4)
Out[10]: False
In [11]: float('nan') is float('nan')
Out[11]: False
In [12]: np.nan is np.nan
Out[12]: True
In [13]: math.nan is math.nan
Out[13]: True
In [14]: float('nan') is math.nan
Out[14]: False
In [15]: math.nan is np.nan
Out[15]: False
In [16]: (False, math.nan, 3) < (False, np.nan, 4)
Out[16]: False
np.nan
is a single variable, but each call of float('nan')
will create a new variable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah thanks, you're right. It seems indeed like float('nan') <= float('nan')
is False
, unlike for np.
Fix here: #23909
Following #23862, there was an uncaught bug when comparing nan-priority checkpoints. This is because float("nan") <= float("nan") is always False (unlike e.g. np.nan <= np.nan, which is True). This PR fixes this bug and adds a new test to ensure correct behavior.
Why are these changes needed?
Changes the logic in
CheckpointManager
to consider checkpoints withnan
value of the metric as worst values, meaning they will be deleted first ifkeep_checkpoints_num
is set.Related issue number
Closes #23856
Checks
scripts/format.sh
to lint the changes in this PR.