Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ lint:
test: clean-test
pytest -m 'not rsc_test and not docker'

test-pdb: clean-test
pytest -m 'not rsc_test and not docker' --pdb

test-rsc: clean-test
pytest

Expand Down
1 change: 1 addition & 0 deletions vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .handlers.torch import TorchHandler # noqa
from .handlers.statsmodels import StatsmodelsHandler # noqa
from .handlers.xgboost import XGBoostHandler # noqa
from .helpers import api_data_to_frame # noqa
from .rsconnect import deploy_rsconnect # noqa
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
from .model_card import model_card # noqa
Expand Down
3 changes: 1 addition & 2 deletions vetiver/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from vetiver.handlers import base
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this (or immediately after this PR) is the good time to add a module docstring for all the whats and the whys about Handlers. I will do a PR after this one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome! That would be a great addition for more context on Handlers.

from functools import singledispatch
from contextlib import suppress

Expand Down Expand Up @@ -145,7 +144,7 @@ def handler_predict(self, input_data, check_prototype):


@create_handler.register
def _(model: base.BaseHandler, prototype_data):
def _(model: BaseHandler, prototype_data):
if model.prototype_data is None and prototype_data is not None:
model.prototype_data = prototype_data

Expand Down
2 changes: 1 addition & 1 deletion vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def handler_predict(self, input_data, check_prototype):
else:
prediction = self.model.predict([input_data])

return prediction
return prediction.tolist()
2 changes: 1 addition & 1 deletion vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def handler_predict(self, input_data, check_prototype):
else:
prediction = self.model.predict([input_data])

return prediction
return prediction.tolist()
5 changes: 2 additions & 3 deletions vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ def handler_predict(self, input_data, check_prototype):
"""
if not torch_exists:
raise ImportError("Cannot import `torch`.")

if check_prototype:
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
prediction = self.model(torch.from_numpy(input_data))

# do not check ptype
else:
input_data = torch.tensor(input_data)
prediction = self.model(input_data)

return prediction
return prediction.tolist()
3 changes: 2 additions & 1 deletion vetiver/handlers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def handler_predict(self, input_data, check_prototype):
input_data = pd.DataFrame(input_data)
except ValueError:
raise (f"Expected a dict or DataFrame, got {type(input_data)}")

input_data = xgboost.DMatrix(input_data)

prediction = self.model.predict(input_data)

return prediction
return prediction.tolist()
51 changes: 51 additions & 0 deletions vetiver/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from functools import singledispatch
import pandas as pd
import pydantic


@singledispatch
def api_data_to_frame(pred_data) -> pd.DataFrame:
"""Convert prototype to dataframe data

Parameters
----------
pred_data : pydantic.BaseModel
User data from given to API endpoint

Returns
-------
pd.DataFrame
BaseModel data translated into DataFrame
"""

raise TypeError("Data should be list, pydantic.BaseModel, pd.DataFrame")


@api_data_to_frame.register(pydantic.BaseModel)
@api_data_to_frame.register(list)
def _(pred_data):

return pd.DataFrame([dict(s) for s in pred_data])


@api_data_to_frame.register(dict)
def _dict(pred_data):
return api_data_to_frame([pred_data])


def response_to_frame(response: dict) -> pd.DataFrame:
"""Convert API JSON response to data frame

Parameters
----------
response : dict
Response from API endpoint

Returns
-------
pd.DataFrame
Response translated into DataFrame
"""
response_df = pd.DataFrame.from_dict(response.json())

return response_df
3 changes: 2 additions & 1 deletion vetiver/rsconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import typing

from rsconnect.actions import deploy_python_fastapi
from rsconnect.api import RSConnectServer as ConnectServer

from .write_fastapi import write_app


def deploy_rsconnect(
connect_server,
connect_server: ConnectServer,
board,
pin_name: str,
version: str = None,
Expand Down
66 changes: 28 additions & 38 deletions vetiver/server.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import Any, Callable, Dict, List, Union
from typing import Callable, List, Union
from urllib.parse import urljoin

import re
import httpx
import pandas as pd
import requests
import uvicorn
from fastapi import FastAPI, Request, testclient
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import PlainTextResponse
from warnings import warn

from .utils import _jupyter_nb
from .vetiver_model import VetiverModel
from .meta import VetiverMeta
from .helpers import api_data_to_frame, response_to_frame


class VetiverAPI:
Expand Down Expand Up @@ -138,6 +142,10 @@ async def rapidoc():
</html>
"""

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return PlainTextResponse(str(exc), status_code=422)

return app

def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
Expand Down Expand Up @@ -167,26 +175,26 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
if self.check_prototype is True:

@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
async def custom_endpoint(
input_data: Union[self.model.prototype, List[self.model.prototype]]
):
async def custom_endpoint(input_data: List[self.model.prototype]):
_to_frame = api_data_to_frame(input_data)
predictions = endpoint_fx(_to_frame, **kw)

if isinstance(input_data, List):
served_data = _batch_data(input_data)
if isinstance(predictions, List):
return {endpoint_name: predictions}
else:
served_data = _prepare_data(input_data)

new = endpoint_fx(served_data, **kw)
return {endpoint_name: new.tolist()}
return predictions

else:

@self.app.post(urljoin("/", endpoint_name))
async def custom_endpoint(input_data: Request):
served_data = await input_data.json()
new = endpoint_fx(served_data, **kw)
predictions = endpoint_fx(served_data, **kw)

return {endpoint_name: new.tolist()}
if isinstance(predictions, List):
return {endpoint_name: predictions}
else:
return predictions

def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
"""
Expand Down Expand Up @@ -261,46 +269,28 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
# TO DO: dispatch

if isinstance(data, pd.DataFrame):
data_json = data.to_json(orient="records")
response = requester.post(endpoint, data=data_json, **kw)
response = requester.post(
endpoint, data=data.to_json(orient="records"), **kw
) # TO DO: httpx deprecating data in favor of content for TestClient
elif isinstance(data, pd.Series):
data_dict = data.to_json()
response = requester.post(endpoint, data=data_dict, **kw)
response = requester.post(endpoint, json=[data.to_dict()], **kw)
elif isinstance(data, dict):
response = requester.post(endpoint, json=data, **kw)
response = requester.post(endpoint, json=[data], **kw)
else:
response = requester.post(endpoint, json=data, **kw)

try:
response.raise_for_status()
except (requests.exceptions.HTTPError, httpx.HTTPStatusError) as e:
if response.status_code == 422:
raise TypeError(
f"Predict expects DataFrame, Series, or dict. Given type is {type(data)}"
)
raise TypeError(re.sub(r"\n", ": ", response.text))
raise requests.exceptions.HTTPError(
f"Could not obtain data from endpoint with error: {e}"
)

response_df = pd.DataFrame.from_dict(response.json())

return response_df


def _prepare_data(pred_data: Dict[str, Any]) -> List[Any]:
served_data = []
for key, value in pred_data:
served_data.append(value)
return served_data


def _batch_data(pred_data: List[Any]) -> pd.DataFrame:
columns = pred_data[0].dict().keys()

data = [line.dict() for line in pred_data]
response_frame = response_to_frame(response)

served_data = pd.DataFrame(data, columns=columns)
return served_data
return response_frame


def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str:
Expand Down
10 changes: 5 additions & 5 deletions vetiver/tests/test_add_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fastapi.testclient import TestClient

from vetiver import mock, VetiverModel, VetiverAPI
from vetiver.helpers import api_data_to_frame
import vetiver


Expand All @@ -25,12 +26,11 @@ def vetiver_model():


def sum_values(x):
return x.sum()
return x.sum().to_list()


def sum_dict(x):
x = pd.DataFrame(x)
return x.sum()
def sum_values_no_prototype(x):
return api_data_to_frame(x).sum().to_list()


@pytest.fixture
Expand All @@ -49,7 +49,7 @@ def vetiver_client(vetiver_model): # With check_prototype=True
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False

app = VetiverAPI(vetiver_model, check_prototype=False)
app.vetiver_post(sum_dict, "sum")
app.vetiver_post(sum_values_no_prototype, "sum")

app.app.root_path = "/sum"
client = TestClient(app.app)
Expand Down
98 changes: 0 additions & 98 deletions vetiver/tests/test_predict.py

This file was deleted.

Loading