forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[air] add to_air_checkpoint method for inference only workload. (ray-…
…project#25444) Follow up on our last discussion for supporting piecemeal fashion air users. Only did for tensorflow for now, want to collect some feedback on API naming, package structure etc and I will add others.
- Loading branch information
1 parent
3257994
commit 76b34d4
Showing
19 changed files
with
270 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# flake8: noqa | ||
|
||
# __use_pretrained_model_start__ | ||
import ray | ||
import tensorflow as tf | ||
from ray.air.batch_predictor import BatchPredictor | ||
from ray.air.predictors.integrations.tensorflow import ( | ||
to_air_checkpoint, | ||
TensorflowPredictor, | ||
) | ||
|
||
|
||
# to simulate having a pretrained model. | ||
def build_model() -> tf.keras.Model: | ||
model = tf.keras.Sequential( | ||
[ | ||
tf.keras.layers.InputLayer(input_shape=(1,)), | ||
tf.keras.layers.Dense(1), | ||
] | ||
) | ||
return model | ||
|
||
|
||
model = build_model() | ||
checkpoint = to_air_checkpoint(model) | ||
batch_predictor = BatchPredictor( | ||
checkpoint, TensorflowPredictor, model_definition=build_model | ||
) | ||
predict_dataset = ray.data.range(3) | ||
predictions = batch_predictor.predict(predict_dataset) | ||
|
||
# __use_pretrained_model_end__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
.. _use-pretrained-model: | ||
|
||
Use a pretrained model for batch or online inference | ||
===================================================== | ||
|
||
Ray Air moves end to end machine learning workloads seamlessly through the construct of ``Checkpoint``. ``Checkpoint`` | ||
is the output of training and tuning as well as the input to downstream inference tasks. | ||
|
||
Having said that, it is entirely possible and supported to use Ray Air in a piecemeal fashion. | ||
|
||
Say you already have a model trained elsewhere, you can use Ray Air for downstream tasks such as batch and | ||
online inference. To do that, you would need to convert the pretrained model together with any preprocessing | ||
steps into ``Checkpoint``. | ||
|
||
To facilitate this, we have prepared framework specific ``to_air_checkpoint`` helper function. | ||
|
||
Examples: | ||
|
||
.. literalinclude:: doc_code/use_pretrained_model.py | ||
:language: python | ||
:start-after: __use_pretrained_model_start__ | ||
:end-before: __use_pretrained_model_end__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from ray.air.predictors.integrations.lightgbm.lightgbm_predictor import ( | ||
LightGBMPredictor, | ||
) | ||
from ray.air.predictors.integrations.lightgbm.utils import to_air_checkpoint | ||
|
||
__all__ = ["LightGBMPredictor"] | ||
__all__ = ["LightGBMPredictor", "to_air_checkpoint"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Optional | ||
|
||
import os | ||
import lightgbm | ||
|
||
from ray.air.checkpoint import Checkpoint | ||
from ray.air.constants import MODEL_KEY | ||
from ray.air.preprocessor import Preprocessor | ||
from ray.air._internal.checkpointing import ( | ||
save_preprocessor_to_dir, | ||
) | ||
|
||
|
||
def to_air_checkpoint( | ||
path: str, | ||
booster: lightgbm.Booster, | ||
preprocessor: Optional[Preprocessor] = None, | ||
) -> Checkpoint: | ||
"""Convert a pretrained model to AIR checkpoint for serve or inference. | ||
Args: | ||
path: The directory path where model and preprocessor steps are stored to. | ||
booster: A pretrained lightgbm model. | ||
preprocessor: A fitted preprocessor. The preprocessing logic will | ||
be applied to serve/inference. | ||
Returns: | ||
A Ray Air checkpoint. | ||
""" | ||
booster.save_model(os.path.join(path, MODEL_KEY)) | ||
|
||
if preprocessor: | ||
save_preprocessor_to_dir(preprocessor, path) | ||
|
||
checkpoint = Checkpoint.from_directory(path) | ||
|
||
return checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from ray.air.predictors.integrations.sklearn.sklearn_predictor import ( | ||
SklearnPredictor, | ||
) | ||
from ray.air.predictors.integrations.sklearn.utils import to_air_checkpoint | ||
|
||
__all__ = ["SklearnPredictor"] | ||
__all__ = ["SklearnPredictor", "to_air_checkpoint"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Optional | ||
|
||
import os | ||
from sklearn.base import BaseEstimator | ||
|
||
from ray.air.checkpoint import Checkpoint | ||
from ray.air.constants import MODEL_KEY | ||
from ray.air.preprocessor import Preprocessor | ||
from ray.air._internal.checkpointing import ( | ||
save_preprocessor_to_dir, | ||
) | ||
import ray.cloudpickle as cpickle | ||
|
||
|
||
def to_air_checkpoint( | ||
path: str, | ||
estimator: BaseEstimator, | ||
preprocessor: Optional[Preprocessor] = None, | ||
) -> Checkpoint: | ||
"""Convert a pretrained model to AIR checkpoint for serve or inference. | ||
Args: | ||
path: The directory path where model and preprocessor steps are stored to. | ||
estimator: A pretrained model. | ||
preprocessor: A fitted preprocessor. The preprocessing logic will | ||
be applied to serve/inference. | ||
Returns: | ||
A Ray Air checkpoint. | ||
""" | ||
with open(os.path.join(path, MODEL_KEY), "wb") as f: | ||
cpickle.dump(estimator, f) | ||
|
||
if preprocessor: | ||
save_preprocessor_to_dir(preprocessor, path) | ||
|
||
checkpoint = Checkpoint.from_directory(path) | ||
|
||
return checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from ray.air.predictors.integrations.tensorflow.tensorflow_predictor import ( | ||
TensorflowPredictor, | ||
) | ||
from ray.air.predictors.integrations.tensorflow.utils import to_air_checkpoint | ||
|
||
__all__ = ["TensorflowPredictor"] | ||
__all__ = ["TensorflowPredictor", "to_air_checkpoint"] |
25 changes: 25 additions & 0 deletions
25
python/ray/air/predictors/integrations/tensorflow/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Optional | ||
|
||
from tensorflow import keras | ||
|
||
from ray.air.checkpoint import Checkpoint | ||
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY | ||
from ray.air.preprocessor import Preprocessor | ||
|
||
|
||
def to_air_checkpoint( | ||
model: keras.Model, preprocessor: Optional[Preprocessor] = None | ||
) -> Checkpoint: | ||
"""Convert a pretrained model to AIR checkpoint for serve or inference. | ||
Args: | ||
model: A pretrained model. | ||
preprocessor: A fitted preprocessor. The preprocessing logic will | ||
be applied to serve/inference. | ||
Returns: | ||
A Ray Air checkpoint. | ||
""" | ||
checkpoint = Checkpoint.from_dict( | ||
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()} | ||
) | ||
return checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from ray.air.predictors.integrations.torch.torch_predictor import TorchPredictor | ||
from ray.air.predictors.integrations.torch.utils import to_air_checkpoint | ||
|
||
__all__ = ["TorchPredictor"] | ||
__all__ = ["TorchPredictor", "to_air_checkpoint"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from ray.air.checkpoint import Checkpoint | ||
from ray.air.constants import MODEL_KEY, PREPROCESSOR_KEY | ||
from ray.air.preprocessor import Preprocessor | ||
|
||
|
||
def to_air_checkpoint( | ||
model: torch.nn.Module, preprocessor: Optional[Preprocessor] = None | ||
) -> Checkpoint: | ||
"""Convert a pretrained model to AIR checkpoint for serve or inference. | ||
Args: | ||
model: A pretrained model. | ||
preprocessor: A fitted preprocessor. The preprocessing logic will | ||
be applied to serve/inference. | ||
Returns: | ||
A Ray Air checkpoint. | ||
""" | ||
checkpoint = Checkpoint.from_dict( | ||
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model} | ||
) | ||
return checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from ray.air.predictors.integrations.xgboost.xgboost_predictor import XGBoostPredictor | ||
from ray.air.predictors.integrations.xgboost.utils import to_air_checkpoint | ||
|
||
__all__ = ["XGBoostPredictor"] | ||
__all__ = ["XGBoostPredictor", "to_air_checkpoint"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Optional | ||
|
||
import os | ||
import xgboost | ||
|
||
from ray.air.checkpoint import Checkpoint | ||
from ray.air.constants import MODEL_KEY | ||
from ray.air.preprocessor import Preprocessor | ||
from ray.air._internal.checkpointing import ( | ||
save_preprocessor_to_dir, | ||
) | ||
|
||
|
||
def to_air_checkpoint( | ||
path: str, | ||
booster: xgboost.Booster, | ||
preprocessor: Optional[Preprocessor] = None, | ||
) -> Checkpoint: | ||
"""Convert a pretrained model to AIR checkpoint for serve or inference. | ||
Args: | ||
path: The directory path where model and preprocessor steps are stored to. | ||
booster: A pretrained xgboost model. | ||
preprocessor: A fitted preprocessor. The preprocessing logic will | ||
be applied to serve/inference. | ||
Returns: | ||
A Ray Air checkpoint. | ||
""" | ||
booster.save_model(os.path.join(path, MODEL_KEY)) | ||
|
||
if preprocessor: | ||
save_preprocessor_to_dir(preprocessor, path) | ||
|
||
checkpoint = Checkpoint.from_directory(path) | ||
|
||
return checkpoint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.