Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Treat checkpoints with nan value as worst #23862

Merged
merged 6 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions python/ray/train/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ray.train.session import TrainingResult
from ray.train.utils import construct_path
from ray.util import PublicAPI
from ray.util.ml_utils.util import is_nan

if TUNE_INSTALLED:
from ray import tune
Expand Down Expand Up @@ -212,10 +213,13 @@ def write_checkpoint(self, checkpoint: Dict):
)

def priority(checkpoint_score_order, checkpoint_score):
if checkpoint_score_order == MAX:
return checkpoint_score
else:
return -checkpoint_score
# Treat NaN as worst
# The tuple structure is (not is_nan(), metric), which makes
# the nan values to be always considered as the worst
# metrics by the heap
if checkpoint_score_order != MAX:
checkpoint_score = -checkpoint_score
return (not is_nan(checkpoint_score), checkpoint_score)

checkpoint_priority = priority(checkpoint_score_order, checkpoint_score)

Expand Down
9 changes: 5 additions & 4 deletions python/ray/train/tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def test_persisted_checkpoint_strategy(ray_start_2_cpus):
)

def train_func():
train.save_checkpoint(loss=float("nan")) # nan, deleted
train.save_checkpoint(loss=3) # best
train.save_checkpoint(loss=7) # worst, deleted
train.save_checkpoint(loss=5)
Expand All @@ -530,16 +531,16 @@ def train_func():
assert trainer.logdir == Path(logdir).expanduser().resolve()
assert trainer.latest_checkpoint_dir.is_dir()
assert trainer.best_checkpoint_path.is_file()
assert trainer.best_checkpoint_path.name == f"checkpoint_{1:06d}"
assert trainer.best_checkpoint_path.name == f"checkpoint_{2:06d}"
assert trainer.latest_checkpoint["loss"] == 5
assert trainer.best_checkpoint["loss"] == 3

checkpoint_dir = trainer.latest_checkpoint_dir
file_names = [f.name for f in checkpoint_dir.iterdir()]
assert len(file_names) == 2
assert f"checkpoint_{1:06d}" in file_names
assert f"checkpoint_{2:06d}" not in file_names
assert f"checkpoint_{3:06d}" in file_names
assert f"checkpoint_{2:06d}" in file_names
assert f"checkpoint_{3:06d}" not in file_names
assert f"checkpoint_{4:06d}" in file_names

def validate():
checkpoint = train.load_checkpoint()
Expand Down
10 changes: 8 additions & 2 deletions python/ray/tune/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable, Optional

from ray.tune.result import NODE_IP
from ray.tune.utils.util import flatten_dict
from ray.tune.utils.util import flatten_dict, is_nan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use the ml_utils is_nan directly and remove it from tune.utils?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to use an alias - we have a precedent for that already.


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -168,6 +168,10 @@ def on_checkpoint(self, checkpoint: _TuneCheckpoint):
self.delete(old_checkpoint)

try:
# NaN metrics are treated as worst checkpoint
# The tuple structure is (not is_nan(), metric), which makes
# the nan values to be always considered as the worst
# metrics by the heap
queue_item = QueueItem(self._priority(checkpoint), checkpoint)
except KeyError:
logger.error(
Expand Down Expand Up @@ -198,7 +202,9 @@ def best_checkpoints(self):
def _priority(self, checkpoint):
result = flatten_dict(checkpoint.result)
priority = result[self._checkpoint_score_attr]
return -priority if self._checkpoint_score_desc else priority
if self._checkpoint_score_desc:
priority = -priority
return (not is_nan(priority), priority, checkpoint.order)
Comment on lines +205 to +207
Copy link
Contributor

@XuehaiPan XuehaiPan Apr 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When priority is nan, sorting by tuple key:

(not is_nan(priority), priority, checkpoint.order)

won't give the correct order by checkpoint.order. Because both nan < nan and nan > nan return False.

Suggested change
if self._checkpoint_score_desc:
priority = -priority
return (not is_nan(priority), priority, checkpoint.order)
if self._checkpoint_score_desc:
priority = -priority
if is_nan(priority):
return (0, checkpoint.order, priority)
return (1, priority, checkpoint.order)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it does:

>>> import numpy as np
>>> (False, np.nan, 3) < (False, np.nan, 4)
True
>>> (False, np.nan, 4) < (False, np.nan, 3)
False

Copy link
Contributor

@XuehaiPan XuehaiPan Apr 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it does:

>>> import numpy as np
>>> (False, np.nan, 3) < (False, np.nan, 4)
True
>>> (False, np.nan, 4) < (False, np.nan, 3)
False

Seems that tuple.__lt__ skips items when lhs[i] is rhs[i].

In [1]: (False, float('nan'), 3) < (False, float('nan'), 4)
Out[1]: False

In [2]: (False, float('nan'), 3) > (False, float('nan'), 4)
Out[2]: False

In [3]: (False, float('nan'), 4) < (False, float('nan'), 3)
Out[3]: False

In [4]: (False, float('nan'), 4) > (False, float('nan'), 3)
Out[4]: False

In [5]: import numpy as np

In [6]: (False, np.nan, 3) < (False, np.nan, 4)
Out[6]: True

In [7]: (False, np.nan, 3) > (False, np.nan, 4)
Out[7]: False

In [8]: import math

In [9]: (False, math.nan, 3) < (False, math.nan, 4)
Out[9]: True

In [10]: (False, math.nan, 3) > (False, math.nan, 4)
Out[10]: False

In [11]: float('nan') is float('nan')
Out[11]: False

In [12]: np.nan is np.nan
Out[12]: True

In [13]: math.nan is math.nan
Out[13]: True

In [14]: float('nan') is math.nan
Out[14]: False

In [15]: math.nan is np.nan
Out[15]: False

In [16]: (False, math.nan, 3) < (False, np.nan, 4)
Out[16]: False

np.nan is a single variable, but each call of float('nan') will create a new variable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah thanks, you're right. It seems indeed like float('nan') <= float('nan') is False, unlike for np.

Fix here: #23909


def __getstate__(self):
state = self.__dict__.copy()
Expand Down
52 changes: 39 additions & 13 deletions python/ray/tune/tests/test_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@

class CheckpointManagerTest(unittest.TestCase):
@staticmethod
def mock_result(i):
return {"i": i, TRAINING_ITERATION: i}
def mock_result(metric, i):
return {"i": metric, TRAINING_ITERATION: i}

def checkpoint_manager(self, keep_checkpoints_num):
return CheckpointManager(keep_checkpoints_num, "i", delete_fn=lambda c: None)

def testNewestCheckpoint(self):
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
memory_checkpoint = _TuneCheckpoint(
_TuneCheckpoint.MEMORY, {0}, self.mock_result(0)
_TuneCheckpoint.MEMORY, {0}, self.mock_result(0, 0)
)
checkpoint_manager.on_checkpoint(memory_checkpoint)
persistent_checkpoint = _TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1)
_TuneCheckpoint.PERSISTENT, {1}, self.mock_result(1, 1)
)
checkpoint_manager.on_checkpoint(persistent_checkpoint)
self.assertEqual(
Expand All @@ -40,7 +40,7 @@ def testOnCheckpointOrdered(self):
keep_checkpoints_num = 2
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i))
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i))
for i in range(3)
]

Expand All @@ -66,7 +66,7 @@ def testOnCheckpointUnordered(self):
keep_checkpoints_num = 2
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i))
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, {i}, self.mock_result(i, i))
for i in range(3, -1, -1)
]

Expand All @@ -91,7 +91,7 @@ def testBestCheckpoints(self):
keep_checkpoints_num = 4
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i))
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, i, self.mock_result(i, i))
for i in range(16)
]
random.shuffle(checkpoints)
Expand All @@ -104,6 +104,32 @@ def testBestCheckpoints(self):
for i in range(len(best_checkpoints)):
self.assertEqual(best_checkpoints[i].value, i + 12)

def testBestCheckpointsWithNan(self):
"""
Tests that checkpoints with nan priority are handled correctly.
"""
keep_checkpoints_num = 2
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num)
checkpoints = [
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, None, self.mock_result(float("nan"), i)
)
for i in range(2)
]
checkpoints += [
_TuneCheckpoint(_TuneCheckpoint.PERSISTENT, 3, self.mock_result(0, 3))
]
random.shuffle(checkpoints)

for checkpoint in checkpoints:
checkpoint_manager.on_checkpoint(checkpoint)

best_checkpoints = checkpoint_manager.best_checkpoints()
# best_checkpoints is sorted from worst to best
self.assertEqual(len(best_checkpoints), keep_checkpoints_num)
self.assertEqual(best_checkpoints[0].value, None)
self.assertEqual(best_checkpoints[1].value, 3)

def testOnCheckpointUnavailableAttribute(self):
"""
Tests that an error is logged when the associated result of the
Expand All @@ -122,8 +148,8 @@ def testOnCheckpointUnavailableAttribute(self):

def testOnMemoryCheckpoint(self):
checkpoints = [
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0)),
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0)),
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)),
_TuneCheckpoint(_TuneCheckpoint.MEMORY, 0, self.mock_result(0, 0)),
]
checkpoint_manager = self.checkpoint_manager(keep_checkpoints_num=1)
checkpoint_manager.on_checkpoint(checkpoints[0])
Expand All @@ -147,16 +173,16 @@ def testSameCheckpoint(self):

checkpoints = [
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5)
_TuneCheckpoint.PERSISTENT, tmpfiles[0], self.mock_result(5, 5)
),
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10)
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(10, 10)
),
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0)
_TuneCheckpoint.PERSISTENT, tmpfiles[2], self.mock_result(0, 0)
),
_TuneCheckpoint(
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20)
_TuneCheckpoint.PERSISTENT, tmpfiles[1], self.mock_result(20, 20)
),
]
for checkpoint in checkpoints:
Expand Down
12 changes: 4 additions & 8 deletions python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
unflatten_list_dict,
unflattened_lookup,
)
from ray.util.ml_utils.util import ( # noqa: F401
is_nan,
is_nan_or_inf,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -277,14 +281,6 @@ def date_str():
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")


def is_nan(value):
return np.isnan(value)


def is_nan_or_inf(value):
return is_nan(value) or np.isinf(value)


def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.

Expand Down
9 changes: 9 additions & 0 deletions python/ray/util/ml_utils/util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from contextlib import closing
import socket
import numpy as np


def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]


def is_nan(value):
return np.isnan(value)


def is_nan_or_inf(value):
return is_nan(value) or np.isinf(value)