Skip to content

Commit

Permalink
support forecasting
Browse files Browse the repository at this point in the history
  • Loading branch information
jamesbchao committed Jan 22, 2024
1 parent f601e3b commit 107bc3c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
17 changes: 8 additions & 9 deletions responsibleai/responsibleai/_internal/_served_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

import json

import requests

from responsibleai.serialization_utilities import serialize_json_safe
from raiutils.webservice import post_with_retries


class ServedModelWrapper:
Expand Down Expand Up @@ -37,14 +36,14 @@ def forecast(self, X):
# request formatting according to mlflow docs
# https://mlflow.org/docs/latest/cli.html#mlflow-models-serve
# JSON safe serialization takes care of datetime columns
response = requests.post(
url=f"http://localhost:{self.port}/invocations",
headers={"Content-Type": "application/json"},
data=json.dumps(
{"dataframe_split": X.to_dict(orient='split')},
default=serialize_json_safe))
uri = f"http://localhost:{self.port}/invocations"
input_data = json.dumps(
{"dataframe_split": X.to_dict(orient='split')},
default=serialize_json_safe)
headers = {"Content-Type": "application/json"}
try:
response.raise_for_status()
response = post_with_retries(uri, input_data, headers,
max_retries=15, retry_delay=30)
except Exception:
raise RuntimeError(
"Could not retrieve predictions. "
Expand Down
8 changes: 6 additions & 2 deletions responsibleai/responsibleai/rai_insights/rai_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,10 +1285,14 @@ def _get_feature_ranges(
res_object[_UNIQUE_VALUES] = unique_value.tolist()
elif datetime_features is not None and col in datetime_features:
res_object[_RANGE_TYPE] = "datetime"
min_value = test[col].min()
min_value = pd.to_datetime(min_value)
res_object[_MIN_VALUE] = \
test[col].min().strftime(_STRF_TIME_FORMAT)
min_value.strftime(_STRF_TIME_FORMAT)
max_value = test[col].max()
max_value = pd.to_datetime(max_value)
res_object[_MAX_VALUE] = \
test[col].max().strftime(_STRF_TIME_FORMAT)
max_value.strftime(_STRF_TIME_FORMAT)
else:
col_min = test[col].min()
col_max = test[col].max()
Expand Down

0 comments on commit 107bc3c

Please sign in to comment.