[air] Use checkpoint.as_directory() instead of cleaning up manually (#24113)

Follow-up from #23908

Instead of manually deleting checkpoint paths after calling `to_directory()`, we should utilize `Checkpoint.as_directory()` when possible.
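As a rough sketch of the pattern this commit adopts (the dict payload is illustrative, and the import path follows the `python/ray/ml/checkpoint.py` location shown in this diff):

```python
from ray.ml.checkpoint import Checkpoint

# Illustrative checkpoint; any dict payload works the same way.
checkpoint = Checkpoint.from_dict({"checkpoint_data": 5})

# Before: materialize the checkpoint and remember to clean up manually.
# path = checkpoint.to_directory()
# try:
#     ...  # read files under `path`
# finally:
#     shutil.rmtree(path)

# After: `as_directory()` yields a local directory to read from and, when the
# checkpoint had to be materialized into a temporary directory, removes it on
# exit from the `with` block.
with checkpoint.as_directory() as path:
    ...  # read files under `path`
```

This is the same replacement applied below in `Checkpoint.to_dict()`, `Checkpoint.to_uri()`, the LightGBM/XGBoost predictors, and the checkpoint tests.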
krfricke authored Apr 23, 2022
1 parent 3c0a3f4 commit 0360100
Showing 4 changed files with 40 additions and 63 deletions.
52 changes: 17 additions & 35 deletions python/ray/ml/checkpoint.py
@@ -228,30 +228,21 @@ def to_dict(self) -> dict:
             return ray.get(self._obj_ref)
         elif self._local_path or self._uri:
             # Else, checkpoint is either on FS or external storage
-            cleanup = False
-
-            local_path = self._local_path
-            if not local_path:
-                # Checkpoint does not exist on local path. Save
-                # in temporary directory, but clean up later
-                local_path = self.to_directory()
-                cleanup = True
-
-            checkpoint_data_path = os.path.join(local_path, _DICT_CHECKPOINT_FILE_NAME)
-            if os.path.exists(checkpoint_data_path):
-                # If we are restoring a dict checkpoint, load the dict
-                # from the checkpoint file.
-                with open(checkpoint_data_path, "rb") as f:
-                    checkpoint_data = pickle.load(f)
-            else:
-                data = _pack(local_path)
-
-                checkpoint_data = {
-                    _FS_CHECKPOINT_KEY: data,
-                }
-
-            if cleanup:
-                shutil.rmtree(local_path)
+            with self.as_directory() as local_path:
+                checkpoint_data_path = os.path.join(
+                    local_path, _DICT_CHECKPOINT_FILE_NAME
+                )
+                if os.path.exists(checkpoint_data_path):
+                    # If we are restoring a dict checkpoint, load the dict
+                    # from the checkpoint file.
+                    with open(checkpoint_data_path, "rb") as f:
+                        checkpoint_data = pickle.load(f)
+                else:
+                    data = _pack(local_path)
+
+                    checkpoint_data = {
+                        _FS_CHECKPOINT_KEY: data,
+                    }

             return checkpoint_data
         else:
@@ -406,17 +397,8 @@ def to_uri(self, uri: str) -> str:
                 f"Hint: {fs_hint(uri)}"
             )

-        cleanup = False
-
-        local_path = self._local_path
-        if not local_path:
-            cleanup = True
-            local_path = self.to_directory()
-
-        upload_to_uri(local_path=local_path, uri=uri)
-
-        if cleanup:
-            shutil.rmtree(local_path)
+        with self.as_directory() as local_path:
+            upload_to_uri(local_path=local_path, uri=uri)

         return uri
19 changes: 9 additions & 10 deletions python/ray/ml/predictors/integrations/lightgbm/lightgbm_predictor.py
@@ -1,6 +1,5 @@
 from typing import Optional, List, Union
 import os
-import shutil
 import numpy as np
 import pandas as pd

@@ -40,15 +39,15 @@ def from_checkpoint(cls, checkpoint: Checkpoint) -> "LightGBMPredictor":
                 ``LightGBMTrainer`` run.
         """
-        path = checkpoint.to_directory()
-        bst = lightgbm.Booster(model_file=os.path.join(path, MODEL_KEY))
-        preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
-        if os.path.exists(preprocessor_path):
-            with open(preprocessor_path, "rb") as f:
-                preprocessor = cpickle.load(f)
-        else:
-            preprocessor = None
-        shutil.rmtree(path)
+        with checkpoint.as_directory() as path:
+            bst = lightgbm.Booster(model_file=os.path.join(path, MODEL_KEY))
+            preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
+            if os.path.exists(preprocessor_path):
+                with open(preprocessor_path, "rb") as f:
+                    preprocessor = cpickle.load(f)
+            else:
+                preprocessor = None
+
         return LightGBMPredictor(model=bst, preprocessor=preprocessor)

     def predict(
20 changes: 9 additions & 11 deletions python/ray/ml/predictors/integrations/xgboost/xgboost_predictor.py
@@ -1,6 +1,5 @@
 from typing import Optional, List, Union, Dict, Any
 import os
-import shutil
 import numpy as np
 import pandas as pd

@@ -40,16 +39,15 @@ def from_checkpoint(cls, checkpoint: Checkpoint) -> "XGBoostPredictor":
                 ``XGBoostTrainer`` run.
         """
-        path = checkpoint.to_directory()
-        bst = xgboost.Booster()
-        bst.load_model(os.path.join(path, MODEL_KEY))
-        preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
-        if os.path.exists(preprocessor_path):
-            with open(preprocessor_path, "rb") as f:
-                preprocessor = cpickle.load(f)
-        else:
-            preprocessor = None
-        shutil.rmtree(path)
+        with checkpoint.as_directory() as path:
+            bst = xgboost.Booster()
+            bst.load_model(os.path.join(path, MODEL_KEY))
+            preprocessor_path = os.path.join(path, PREPROCESSOR_KEY)
+            if os.path.exists(preprocessor_path):
+                with open(preprocessor_path, "rb") as f:
+                    preprocessor = cpickle.load(f)
+            else:
+                preprocessor = None
         return XGBoostPredictor(model=bst, preprocessor=preprocessor)

     def predict(
12 changes: 5 additions & 7 deletions python/ray/ml/tests/test_checkpoints.py
@@ -325,13 +325,11 @@ def testLocalCheckpointSerde(self):
         # Local checkpoints are converted to bytes on serialization. Currently
         # this is a pickled dict, so we compare with a dict checkpoint.
         source_checkpoint = Checkpoint.from_dict({"checkpoint_data": 5})
-        tmpdir = source_checkpoint.to_directory()
-        self.addCleanup(shutil.rmtree, tmpdir)
-
-        checkpoint = Checkpoint.from_directory(tmpdir)
-        self._testCheckpointSerde(
-            checkpoint, *source_checkpoint.get_internal_representation()
-        )
+        with source_checkpoint.as_directory() as tmpdir:
+            checkpoint = Checkpoint.from_directory(tmpdir)
+            self._testCheckpointSerde(
+                checkpoint, *source_checkpoint.get_internal_representation()
+            )

     def testBytesCheckpointSerde(self):
         # Bytes checkpoints are just dict checkpoints constructed
