From 08cd5a8728d40537720a501c5f9cc8f47fcbeedc Mon Sep 17 00:00:00 2001
From: James Chao
Date: Mon, 22 Jan 2024 17:03:57 -0500
Subject: [PATCH] support forecasting

---
 erroranalysis/requirements-dev.txt          |  1 +
 .../_internal/_served_model_wrapper.py      | 17 ++++++-------
 .../rai_insights/rai_insights.py            |  8 ++++--
 .../tests/rai_insights/test_served_model.py | 25 +-------------------
 4 files changed, 16 insertions(+), 35 deletions(-)

diff --git a/erroranalysis/requirements-dev.txt b/erroranalysis/requirements-dev.txt
index c30e933b74..1d648cdb29 100644
--- a/erroranalysis/requirements-dev.txt
+++ b/erroranalysis/requirements-dev.txt
@@ -4,4 +4,5 @@
 pytest-mock==3.6.1
 requirements-parser==0.2.0
 rai_test_utils[object_detection]
+scikit-learn<=1.3.2
 interpret-core[required]<=0.3.2
\ No newline at end of file
diff --git a/responsibleai/responsibleai/_internal/_served_model_wrapper.py b/responsibleai/responsibleai/_internal/_served_model_wrapper.py
index bb6006bab4..79d83bd221 100644
--- a/responsibleai/responsibleai/_internal/_served_model_wrapper.py
+++ b/responsibleai/responsibleai/_internal/_served_model_wrapper.py
@@ -3,8 +3,7 @@
 
 import json
 
-import requests
-
+from raiutils.webservice import post_with_retries
 from responsibleai.serialization_utilities import serialize_json_safe
 
 
@@ -37,14 +36,14 @@ def forecast(self, X):
         # request formatting according to mlflow docs
         # https://mlflow.org/docs/latest/cli.html#mlflow-models-serve
         # JSON safe serialization takes care of datetime columns
-        response = requests.post(
-            url=f"http://localhost:{self.port}/invocations",
-            headers={"Content-Type": "application/json"},
-            data=json.dumps(
-                {"dataframe_split": X.to_dict(orient='split')},
-                default=serialize_json_safe))
+        uri = f"http://localhost:{self.port}/invocations"
+        input_data = json.dumps(
+            {"dataframe_split": X.to_dict(orient='split')},
+            default=serialize_json_safe)
+        headers = {"Content-Type": "application/json"}
         try:
-            response.raise_for_status()
+            response = post_with_retries(uri, input_data, headers,
+                                         max_retries=15, retry_delay=30)
         except Exception:
             raise RuntimeError(
                 "Could not retrieve predictions. "
diff --git a/responsibleai/responsibleai/rai_insights/rai_insights.py b/responsibleai/responsibleai/rai_insights/rai_insights.py
index 6199331b09..6d67007c38 100644
--- a/responsibleai/responsibleai/rai_insights/rai_insights.py
+++ b/responsibleai/responsibleai/rai_insights/rai_insights.py
@@ -1285,10 +1285,14 @@ def _get_feature_ranges(
             res_object[_UNIQUE_VALUES] = unique_value.tolist()
         elif datetime_features is not None and col in datetime_features:
             res_object[_RANGE_TYPE] = "datetime"
+            min_value = test[col].min()
+            min_value = pd.to_datetime(min_value)
             res_object[_MIN_VALUE] = \
-                test[col].min().strftime(_STRF_TIME_FORMAT)
+                min_value.strftime(_STRF_TIME_FORMAT)
+            max_value = test[col].max()
+            max_value = pd.to_datetime(max_value)
             res_object[_MAX_VALUE] = \
-                test[col].max().strftime(_STRF_TIME_FORMAT)
+                max_value.strftime(_STRF_TIME_FORMAT)
         else:
             col_min = test[col].min()
             col_max = test[col].max()
diff --git a/responsibleai/tests/rai_insights/test_served_model.py b/responsibleai/tests/rai_insights/test_served_model.py
index d70c57cf4a..7149afc747 100644
--- a/responsibleai/tests/rai_insights/test_served_model.py
+++ b/responsibleai/tests/rai_insights/test_served_model.py
@@ -6,7 +6,6 @@
 from unittest import mock
 
 import pytest
-import requests
 
 from tests.common_utils import (RandomForecastingModel,
                                 create_tiny_forecasting_dataset)
@@ -41,7 +40,7 @@
 
 
 @mock.patch("requests.post")
-@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
+@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5432"})
 def test_served_model(
         mock_post,
         rai_forecasting_insights_for_served_model):
@@ -58,25 +57,3 @@
     forecasts = rai_insights.model.forecast(X_test)
     assert len(forecasts) == len(X_test)
     assert mock_post.call_count == 1
-
-
-@mock.patch("requests.post")
-@mock.patch.dict("os.environ", {"RAI_MODEL_SERVING_PORT": "5123"})
-def test_served_model_failed(
-        mock_post,
-        rai_forecasting_insights_for_served_model):
-    _, X_test, _, _ = create_tiny_forecasting_dataset()
-
-    response = requests.Response()
-    response.status_code = 400
-    response._content = b"Could not connect to host since it actively " \
-        b"refuses the connection."
-    mock_post.return_value = response
-
-    rai_insights = RAIInsights.load(RAI_INSIGHTS_DIR_NAME)
-    with pytest.raises(
-            Exception,
-            match="Could not retrieve predictions. "
-                  "Model server returned status code 400 "
-                  f"and the following response: {response.content}"):
-        rai_insights.model.forecast(X_test)