Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[dev,torch,statsmodels]
python -m pip install -e .[dev,torch,statsmodels,xgboost]
- name: Run Tests
run: |
pytest -m 'not rsc_test' --cov --cov-report xml
Expand Down Expand Up @@ -65,8 +65,8 @@ jobs:
run: |
pytest vetiver -m 'rsc_test'

test-no-torch:
name: "Test no-torch"
test-no-extras:
name: "Test no exra ml frameworks"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
14 changes: 14 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ You can use vetiver with:

- `scikit-learn <https://scikit-learn.org/stable/>`_
- `pytorch <https://pytorch.org/>`_
- `statsmodels <https://www.statsmodels.org/>`_
- `xgboost <https://xgboost.readthedocs.io/>`_

You can install the released version of vetiver from `PyPI <https://pypi.org/project/vetiver/>`_:

Expand Down Expand Up @@ -65,6 +67,18 @@ Monitor
~pin_metrics
~plot_metrics

Model Handlers
==================
.. autosummary::
:toctree: reference/
:caption: Monitor

~BaseHandler
~SKLearnHandler
~TorchHandler
~StatsmodelsHandler
~XGBoostHandler

Advanced Usage
==================
.. toctree::
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,6 @@ torch =

statsmodels =
statsmodels

xgboost =
xgboost
1 change: 1 addition & 0 deletions vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .handlers.sklearn import SKLearnHandler # noqa
from .handlers.torch import TorchHandler # noqa
from .handlers.statsmodels import StatsmodelsHandler # noqa
from .handlers.xgboost import XGBoostHandler # noqa
from .rsconnect import deploy_rsconnect # noqa
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa

Expand Down
19 changes: 4 additions & 15 deletions vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@ class SKLearnHandler(BaseHandler):

model_class = staticmethod(lambda: sklearn.base.BaseEstimator)

def __init__(self, model, ptype_data):
super().__init__(model, ptype_data)

def describe(self):
"""Create description for sklearn model"""
desc = f"Scikit-learn {self.model.__class__} model"
return desc

def construct_meta(
def create_meta(
user: list = None,
version: str = None,
url: str = None,
Expand Down Expand Up @@ -54,17 +51,9 @@ def handler_predict(self, input_data, check_ptype):
Prediction from model
"""

if check_ptype:
if isinstance(input_data, pd.DataFrame):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])

# do not check ptype
else:
if not isinstance(input_data, list):
input_data = [input_data.split(",")] # user delimiter ?

if not check_ptype or isinstance(input_data, pd.DataFrame):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])

return prediction
12 changes: 6 additions & 6 deletions vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def handler_predict(self, input_data, check_ptype):
prediction
Prediction from model
"""
if sm_exists:
if isinstance(input_data, (list, pd.DataFrame)):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])
else:
if not sm_exists:
raise ImportError("Cannot import `statsmodels`")

if isinstance(input_data, (list, pd.DataFrame)):
prediction = self.model.predict(input_data)
else:
prediction = self.model.predict([input_data])

return prediction
21 changes: 8 additions & 13 deletions vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class TorchHandler(BaseHandler):

model_class = staticmethod(lambda: torch.nn.Module)

def __init__(self, model, ptype_data):
super().__init__(model, ptype_data)

def describe(self):
"""Create description for torch model"""
desc = f"Pytorch model of type {type(self.model)}"
Expand Down Expand Up @@ -58,17 +55,15 @@ def handler_predict(self, input_data, check_ptype):
prediction
Prediction from model
"""
if torch_exists:
if check_ptype:
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
prediction = self.model(torch.from_numpy(input_data))

# do not check ptype
else:
input_data = torch.tensor(input_data)
prediction = self.model(input_data)
if not torch_exists:
raise ImportError("Cannot import `torch`.")
if check_ptype:
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
prediction = self.model(torch.from_numpy(input_data))

# do not check ptype
else:
raise ImportError("Cannot import `torch`.")
input_data = torch.tensor(input_data)
prediction = self.model(input_data)

return prediction
71 changes: 71 additions & 0 deletions vetiver/handlers/xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pandas as pd

from ..meta import _model_meta
from .base import BaseHandler

xgb_exists = True
try:
import xgboost
except ImportError:
xgb_exists = False


class XGBoostHandler(BaseHandler):
"""Handler class for creating VetiverModels with xgboost.

Parameters
----------
model :
a trained and fit xgboost model
"""

model_class = staticmethod(lambda: xgboost.Booster)

def describe(self):
"""Create description for xgboost model"""
desc = f"XGBoost {self.model.__class__} model."
return desc

def create_meta(
user: list = None,
version: str = None,
url: str = None,
required_pkgs: list = [],
):
"""Create metadata for xgboost"""
required_pkgs = required_pkgs + ["xgboost"]
meta = _model_meta(user, version, url, required_pkgs)

return meta

def handler_predict(self, input_data, check_ptype):
"""Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed
at each API call.

Parameters
----------
input_data:
Test data

Returns
-------
prediction
Prediction from model
"""

if not xgb_exists:
raise ImportError("Cannot import `xgboost`")

if not isinstance(input_data, pd.DataFrame):
try:
input_data = pd.DataFrame(input_data)
except ValueError:
raise (f"Expected a dict or DataFrame, got {type(input_data)}")
input_data = xgboost.DMatrix(input_data)

prediction = self.model.predict(input_data)

return prediction
6 changes: 3 additions & 3 deletions vetiver/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_predict_endpoint_ptype_error():
def test_predict_endpoint_no_ptype():
np.random.seed(500)
client = TestClient(_start_application(save_ptype=False).app)
data = "0,0,0"
data = [{"B": 0, "C": 0, "D": 0}]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()
Expand All @@ -58,7 +58,7 @@ def test_predict_endpoint_no_ptype():
def test_predict_endpoint_no_ptype_batch():
np.random.seed(500)
client = TestClient(_start_application(save_ptype=False).app)
data = [["0,0,0"], ["0,0,0"]]
data = [[0, 0, 0], [0, 0, 0]]
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()
Expand All @@ -69,4 +69,4 @@ def test_predict_endpoint_no_ptype_error():
client = TestClient(_start_application(save_ptype=False).app)
data = {"hell0", 9, 32.0}
with pytest.raises(TypeError):
client.post("/predictt", json=data)
client.post("/predict", json=data)
72 changes: 72 additions & 0 deletions vetiver/tests/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest

xgb = pytest.importorskip("xgboost", reason="xgboost library not installed")

from vetiver.data import mtcars # noqa
from vetiver.handlers.xgboost import XGBoostHandler # noqa
import numpy as np # noqa
from fastapi.testclient import TestClient # noqa

import vetiver # noqa


@pytest.fixture
def build_xgb():
# read in data
dtrain = xgb.DMatrix(mtcars.drop(columns="mpg"), label=mtcars["mpg"])
# specify parameters via map
param = {
"max_depth": 2,
"eta": 1,
"objective": "reg:squarederror",
"random_state": 0,
}
num_round = 2
fit = xgb.train(param, dtrain, num_round)

return vetiver.VetiverModel(fit, "xgb", mtcars.drop(columns="mpg"))


def test_vetiver_build(build_xgb):
api = vetiver.VetiverAPI(build_xgb)
client = TestClient(api.app)
data = mtcars.head(1).drop(columns="mpg")

response = vetiver.predict(endpoint=client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert len(response) == 1


def test_batch(build_xgb):
api = vetiver.VetiverAPI(build_xgb)
client = TestClient(api.app)
data = mtcars.head(3).drop(columns="mpg")

response = vetiver.predict(endpoint=client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert len(response) == 3


def test_no_ptype(build_xgb):
api = vetiver.VetiverAPI(build_xgb, check_ptype=False)
client = TestClient(api.app)
data = mtcars.head(1).drop(columns="mpg")

response = vetiver.predict(endpoint=client, data=data)

assert response.iloc[0, 0] == 21.064373016357422
assert len(response) == 1


def test_serialize(build_xgb):
import pins

board = pins.board_temp(allow_pickle_read=True)
vetiver.vetiver_pin_write(board=board, model=build_xgb)
assert isinstance(
board.pin_read("xgb"),
xgb.Booster,
)
board.pin_delete("xgb")