Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.openapi.utils import get_openapi
from fastapi import testclient
import httpx

import uvicorn
import requests
Expand Down Expand Up @@ -240,7 +241,7 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
"""
if isinstance(endpoint, testclient.TestClient):
requester = endpoint
endpoint = "/predict"
endpoint = requester.app.root_path
else:
requester = requests

Expand All @@ -255,21 +256,21 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
elif isinstance(data, dict):
response = requester.post(endpoint, json=data, **kw)
else:
try:
response = requester.post(endpoint, json=data, **kw)
except TypeError:
response = requester.post(endpoint, json=data, **kw)

try:
response.raise_for_status()
except (requests.exceptions.HTTPError, httpx.HTTPStatusError) as e:
if response.status_code == 422:
raise TypeError(
f"Predict expects a DataFrame or dict. Given type is {type(data)}"
f"Predict expects DataFrame, Series, or dict. Given type is {type(data)}"
)
raise requests.exceptions.HTTPError(
f"Could not obtain data from endpoint with error: {e}"
)

response_df = pd.DataFrame.from_dict(response.json())

if isinstance(response_df.iloc[0, 0], dict):
if "type_error.dict" in response_df.iloc[0, 0].values():
raise TypeError(
f"Predict expects a DataFrame or dict. Given type is {type(data)}"
)

return response_df


Expand Down
105 changes: 45 additions & 60 deletions vetiver/tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import numpy as np
import pandas as pd
from requests.exceptions import HTTPError
from fastapi.testclient import TestClient

from vetiver import mock, VetiverModel, VetiverAPI
from vetiver.server import predict


def test_predict_sklearn_dict_ptype():
@pytest.fixture
def vetiver_model():
np.random.seed(500)
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
Expand All @@ -19,95 +21,78 @@ def test_predict_sklearn_dict_ptype():
versioned=None,
description="A regression model for testing purposes",
)
app = VetiverAPI(v, check_ptype=True)

return v


@pytest.fixture
def vetiver_client(vetiver_model): # With check_ptype=True
app = VetiverAPI(vetiver_model, check_ptype=True)
app.app.root_path = "/predict"
client = TestClient(app.app)

return client


@pytest.fixture
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels much cleaner! Thank you!

def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False
app = VetiverAPI(vetiver_model, check_ptype=False)
app.app.root_path = "/predict"
client = TestClient(app.app)

return client


def test_predict_sklearn_dict_ptype(vetiver_client):
data = {"B": 0, "C": 0, "D": 0}

response = predict(endpoint=client, data=data)
response = predict(endpoint=vetiver_client, data=data)

assert isinstance(response, pd.DataFrame), response
assert response.iloc[0, 0] == 44.47
assert len(response) == 1


def test_predict_sklearn_no_ptype():
np.random.seed(500)
def test_predict_sklearn_no_ptype(vetiver_client_check_ptype_false):
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
model=model,
ptype_data=X,
model_name="my_model",
versioned=None,
description="A regression model for testing purposes",
)
app = VetiverAPI(v, check_ptype=False)
client = TestClient(app.app)

response = predict(endpoint=client, data=X)
response = predict(endpoint=vetiver_client_check_ptype_false, data=X)

assert isinstance(response, pd.DataFrame), response
assert response.iloc[0, 0] == 44.47
assert len(response) == 100


def test_predict_sklearn_df_check_ptype():
np.random.seed(500)
def test_predict_sklearn_df_check_ptype(vetiver_client):
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
model=model,
ptype_data=X,
model_name="my_model",
versioned=None,
description="A regression model for testing purposes",
)
app = VetiverAPI(v, check_ptype=True)
client = TestClient(app.app)

response = predict(endpoint=client, data=X)
response = predict(endpoint=vetiver_client, data=X)

assert isinstance(response, pd.DataFrame), response
assert response.iloc[0, 0] == 44.47
assert len(response) == 100


def test_predict_sklearn_series_check_ptype():
np.random.seed(500)
X, y = mock.get_mock_data()
def test_predict_sklearn_series_check_ptype(vetiver_client):
ser = pd.Series(data=[0, 0, 0])
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
model=model,
ptype_data=X,
model_name="my_model",
versioned=None,
description="A regression model for testing purposes",
)
app = VetiverAPI(v, check_ptype=True)
client = TestClient(app.app)

response = predict(endpoint=client, data=ser)
response = predict(endpoint=vetiver_client, data=ser)

assert isinstance(response, pd.DataFrame), response
assert response.iloc[0, 0] == 44.47
assert len(response) == 1


def test_predict_sklearn_type_error():
np.random.seed(500)
@pytest.mark.parametrize("data", [(0, 0), 0, 0.0, "0"])
def test_predict_sklearn_type_error(data, vetiver_client):
msg = f"Predict expects DataFrame, Series, or dict. Given type is {type(data)}"

with pytest.raises(TypeError, match=msg):
predict(endpoint=vetiver_client, data=data)


def test_predict_server_error(vetiver_model):
X, y = mock.get_mock_data()
model = mock.get_mock_model().fit(X, y)
v = VetiverModel(
model=model,
ptype_data=X,
model_name="my_model",
versioned=None,
description="A regression model for testing purposes",
)
app = VetiverAPI(v, check_ptype=True)
app = VetiverAPI(vetiver_model, check_ptype=True)
app.app.root_path = "/i_do_not_exists"
client = TestClient(app.app)
data = (0, 0)

with pytest.raises(TypeError):
predict(endpoint=client, data=data)
with pytest.raises(HTTPError):
predict(endpoint=client, data=X)
42 changes: 27 additions & 15 deletions vetiver/tests/test_statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


@pytest.fixture
def build_sm():
def sm_model():

X, y = vetiver.get_mock_data()
glm = sm.GLM(y, X).fit()
Expand All @@ -23,43 +23,55 @@ def build_sm():
return v


def test_vetiver_build(build_sm):
api = vetiver.VetiverAPI(build_sm)
client = TestClient(api.app)
@pytest.fixture
def vetiver_client(sm_model): # With check_ptype=True
app = vetiver.VetiverAPI(sm_model, check_ptype=True)
app.app.root_path = "/predict"
client = TestClient(app.app)

return client


@pytest.fixture
def vetiver_client_check_ptype_false(sm_model): # With check_ptype=True
app = vetiver.VetiverAPI(sm_model, check_ptype=False)
app.app.root_path = "/predict"
client = TestClient(app.app)

return client


def test_vetiver_build(vetiver_client):
data = [{"B": 0, "C": 0, "D": 0}]

response = vetiver.predict(endpoint=client, data=data)
response = vetiver.predict(endpoint=vetiver_client, data=data)

assert response.iloc[0, 0] == 0.0
assert len(response) == 1


def test_batch(build_sm):
api = vetiver.VetiverAPI(build_sm)
client = TestClient(api.app)
def test_batch(vetiver_client):
data = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))

response = vetiver.predict(endpoint=client, data=data)
response = vetiver.predict(endpoint=vetiver_client, data=data)

assert len(response) == 100


def test_no_ptype(build_sm):
api = vetiver.VetiverAPI(build_sm, check_ptype=False)
client = TestClient(api.app)
def test_no_ptype(vetiver_client_check_ptype_false):
data = [0, 0, 0]

response = vetiver.predict(endpoint=client, data=data)
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)

assert response.iloc[0, 0] == 0.0
assert len(response) == 1


def test_serialize(build_sm):
def test_serialize(sm_model):
import pins

board = pins.board_temp(allow_pickle_read=True)
vetiver.vetiver_pin_write(board=board, model=build_sm)
vetiver.vetiver_pin_write(board=board, model=sm_model)
assert isinstance(
board.pin_read("glm"),
statsmodels.genmod.generalized_linear_model.GLMResultsWrapper,
Expand Down
42 changes: 27 additions & 15 deletions vetiver/tests/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@pytest.fixture
def build_xgb():
def xgb_model():
# read in data
dtrain = xgb.DMatrix(mtcars.drop(columns="mpg"), label=mtcars["mpg"])
# specify parameters via map
Expand All @@ -27,44 +27,56 @@ def build_xgb():
return vetiver.VetiverModel(fit, "xgb", mtcars.drop(columns="mpg"))


def test_vetiver_build(build_xgb):
api = vetiver.VetiverAPI(build_xgb)
client = TestClient(api.app)
@pytest.fixture
def vetiver_client(xgb_model): # With check_ptype=True
app = vetiver.VetiverAPI(xgb_model, check_ptype=True)
app.app.root_path = "/predict"
client = TestClient(app.app)

return client


@pytest.fixture
def vetiver_client_check_ptype_false(xgb_model): # With check_ptype=True
app = vetiver.VetiverAPI(xgb_model, check_ptype=False)
app.app.root_path = "/predict"
client = TestClient(app.app)

return client


def test_vetiver_build(vetiver_client):
data = mtcars.head(1).drop(columns="mpg")

response = vetiver.predict(endpoint=client, data=data)
response = vetiver.predict(endpoint=vetiver_client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert len(response) == 1


def test_batch(build_xgb):
api = vetiver.VetiverAPI(build_xgb)
client = TestClient(api.app)
def test_batch(vetiver_client):
data = mtcars.head(3).drop(columns="mpg")

response = vetiver.predict(endpoint=client, data=data)
response = vetiver.predict(endpoint=vetiver_client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert len(response) == 3


def test_no_ptype(build_xgb):
api = vetiver.VetiverAPI(build_xgb, check_ptype=False)
client = TestClient(api.app)
def test_no_ptype(vetiver_client_check_ptype_false):
data = mtcars.head(1).drop(columns="mpg")

response = vetiver.predict(endpoint=client, data=data)
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert len(response) == 1


def test_serialize(build_xgb):
def test_serialize(xgb_model):
import pins

board = pins.board_temp(allow_pickle_read=True)
vetiver.vetiver_pin_write(board=board, model=build_xgb)
vetiver.vetiver_pin_write(board=board, model=xgb_model)
assert isinstance(
board.pin_read("xgb"),
xgb.Booster,
Expand Down