Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[air] add to_air_checkpoint method for inference only workload. #25444

Merged
merged 14 commits into from
Jun 7, 2022
3 changes: 2 additions & 1 deletion python/ray/ml/predictors/integrations/tensorflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ray.ml.predictors.integrations.tensorflow.tensorflow_predictor import (
TensorflowPredictor,
)
from ray.ml.predictors.integrations.tensorflow.utils import to_air_checkpoint

__all__ = ["TensorflowPredictor"]
__all__ = ["TensorflowPredictor", "to_air_checkpoint"]
23 changes: 23 additions & 0 deletions python/ray/ml/predictors/integrations/tensorflow/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from tensorflow import keras

from ray.ml.checkpoint import Checkpoint
from ray.ml.constants import MODEL_KEY, PREPROCESSOR_KEY
from ray.ml.preprocessor import Preprocessor


def to_air_checkpoint(
model: keras.Model, preprocessor: Optional[Preprocessor] = None
) -> Checkpoint:
"""Convert a pretrained model to AIR checkpoint for serve or inference.

Args:
model: A pretrained model.
preprocessor: Stateless preprocessor only. The preprocessing logic will
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this have to be a stateless preprocessor only? This can be a stateful preprocessor that's already been fit right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rephrased..

be applied to serve/inference.
"""
checkpoint = Checkpoint.from_dict(
{PREPROCESSOR_KEY: preprocessor, MODEL_KEY: model.get_weights()}
)
return checkpoint
13 changes: 12 additions & 1 deletion python/ray/ml/tests/test_tensorflow_predictor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor
import ray
from ray.ml.batch_predictor import BatchPredictor
from ray.ml.predictors.integrations.tensorflow import TensorflowPredictor, to_air_checkpoint
from ray.ml.preprocessor import Preprocessor
from ray.ml.checkpoint import Checkpoint
from ray.ml.constants import PREPROCESSOR_KEY, MODEL_KEY
Expand Down Expand Up @@ -93,6 +95,15 @@ def test_predict_dataframe_with_feature_columns():
assert predictions.to_numpy().flatten().tolist() == [1, 3]


def test_tensorflow_predictor_no_training():
model = build_model()
checkpoint = to_air_checkpoint(model)
batch_predictor = BatchPredictor(checkpoint, TensorflowPredictor, model_definition=build_model)
predict_dataset = ray.data.range(3)
predictions = batch_predictor.predict(predict_dataset)
assert predictions.count() == 3


if __name__ == "__main__":
import pytest
import sys
Expand Down