Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[dev,torch]
python -m pip install -e .[dev,torch,statsmodels]
- name: Run Tests
run: |
pytest -m 'not rsc_test' --cov --cov-report xml
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@ dev =

torch =
torch

statsmodels =
statsmodels
1 change: 1 addition & 0 deletions vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
from .handlers.sklearn import SKLearnHandler # noqa
from .handlers.torch import TorchHandler # noqa
from .handlers.statsmodels import StatsmodelsHandler # noqa
from .rsconnect import deploy_rsconnect # noqa
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa

Expand Down
12 changes: 6 additions & 6 deletions vetiver/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class InvalidModelError(Exception):

def __init__(
self,
message="The `model` argument must be a scikit-learn or torch model.",
message="The `model` argument must be a supported or custom type.",
):
self.message = message
super().__init__(self.message)
Expand Down Expand Up @@ -47,9 +47,9 @@ def create_handler(model, ptype_data):
"""

raise InvalidModelError(
"Model must be an sklearn or torch model, or a \
custom handler must be used. See the docs for more info on custom handlers. \
https://rstudio.github.io/vetiver-python/advancedusage/custom_handler.html"
"Model must be a supported model type, or a "
"custom handler must be used. See the docs for more info on custom handlers and"
"supported types https://rstudio.github.io/vetiver-python/"
)


Expand Down Expand Up @@ -88,13 +88,13 @@ def create_meta(
url: str = None,
required_pkgs: list = [],
):
"""Create metadata for sklearn model"""
"""Create metadata for a model"""
meta = _model_meta(user, version, url, required_pkgs)

return meta

def construct_ptype(self):
"""Create data prototype for torch model
"""Create data prototype for a model

Parameters
----------
Expand Down
69 changes: 69 additions & 0 deletions vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pandas as pd

from ..meta import _model_meta
from .base import BaseHandler

sm_exists = True
try:
import statsmodels.api
except ImportError:
sm_exists = False


class StatsmodelsHandler(BaseHandler):
    """Handler that adapts a fitted statsmodels result for use as a VetiverModel.

    Parameters
    ----------
    model : statsmodels
        a trained and fit statsmodels model
    ptype_data :
        Passed through unchanged to ``BaseHandler.__init__``
    """

    # The attribute lookup happens only when `model_class()` is called, so
    # defining this class does not itself touch the `statsmodels` name.
    # Author's review note: for models with a predict method, using this class
    # seems to be sufficient to identify statsmodels models.
    @staticmethod
    def model_class():
        return statsmodels.base.wrapper.ResultsWrapper

    def __init__(self, model, ptype_data):
        super().__init__(model, ptype_data)

    def describe(self):
        """Return a one-line description naming the wrapped model's class."""
        return f"Statsmodels {self.model.__class__} model."

    def create_meta(
        user: list = None,
        version: str = None,
        url: str = None,
        required_pkgs: list = [],
    ):
        """Build model metadata, adding "statsmodels" to the required packages.

        Parameters
        ----------
        user, version, url :
            Forwarded as-is to ``_model_meta``
        required_pkgs : list
            Package names; a new list with "statsmodels" appended is forwarded,
            the argument itself is not modified
        """
        # NOTE(review): there is no `self` parameter, so when this is invoked
        # on an instance the instance is bound to `user` -- confirm against
        # callers before changing the signature.
        return _model_meta(user, version, url, required_pkgs + ["statsmodels"])

    def handler_predict(self, input_data, check_ptype):
        """Generates method for /predict endpoint in VetiverAPI

        The `handler_predict` function executes at each API call. Use this
        function for calling `predict()` and any other tasks that must be executed
        at each API call.

        Parameters
        ----------
        input_data:
            Test data; a list or DataFrame is passed to `predict()` directly,
            anything else is wrapped in a one-element list first
        check_ptype:
            Accepted for interface compatibility; not consulted here

        Returns
        -------
        prediction
            Prediction from model

        Raises
        ------
        ImportError
            If `statsmodels` could not be imported when this module was loaded
        """
        if not sm_exists:
            raise ImportError("Cannot import `statsmodels`")

        if isinstance(input_data, (list, pd.DataFrame)):
            batch = input_data
        else:
            batch = [input_data]

        return self.model.predict(batch)
2 changes: 1 addition & 1 deletion vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ async def prediction(

@app.post("/predict")
async def prediction(input_data: Request):

y = await input_data.json()

prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)

return {"prediction": prediction.tolist()}
Expand Down
67 changes: 67 additions & 0 deletions vetiver/tests/test_statsmodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest

sm = pytest.importorskip("statsmodels.api", reason="statsmodels library not installed")

statsmodels = pytest.importorskip(
"statsmodels", reason="statsmodels library not installed"
)

import numpy as np # noqa
import pandas as pd # noqa
from fastapi.testclient import TestClient # noqa

import vetiver # noqa


@pytest.fixture
def build_sm():
    """Fixture: a VetiverModel named "glm" wrapping a GLM fit on vetiver's mock data."""
    features, target = vetiver.get_mock_data()
    fitted = sm.GLM(target, features).fit()

    # The feature frame is also supplied as the model's prototype data.
    return vetiver.VetiverModel(fitted, "glm", features)


def test_vetiver_build(build_sm):
    """One all-zero record sent through the API yields a single 0.0 prediction."""
    client = TestClient(vetiver.VetiverAPI(build_sm).app)
    record = {"B": 0, "C": 0, "D": 0}

    response = vetiver.predict(endpoint=client, data=[record])

    assert response.iloc[0, 0] == 0.0
    assert len(response) == 1


def test_batch(build_sm):
    """A 100-row DataFrame sent through the API yields 100 predictions."""
    client = TestClient(vetiver.VetiverAPI(build_sm).app)

    values = np.random.randint(0, 100, size=(100, 4))
    frame = pd.DataFrame(values, columns=list("ABCD"))

    response = vetiver.predict(endpoint=client, data=frame)

    assert len(response) == 100


def test_no_ptype(build_sm):
    """With check_ptype=False, a bare list of zeros yields a single 0.0 prediction."""
    app = vetiver.VetiverAPI(build_sm, check_ptype=False).app
    client = TestClient(app)

    response = vetiver.predict(endpoint=client, data=[0, 0, 0])

    assert response.iloc[0, 0] == 0.0
    assert len(response) == 1


def test_serialize(build_sm):
    """Pin the model to a temporary board and check the "glm" pin reads back as GLM results."""
    import pins

    board = pins.board_temp(allow_pickle_read=True)
    vetiver.vetiver_pin_write(board=board, model=build_sm)

    restored = board.pin_read("glm")
    expected_type = statsmodels.genmod.generalized_linear_model.GLMResultsWrapper
    assert isinstance(restored, expected_type)

    # Remove the pin again once the round trip has been verified.
    board.pin_delete("glm")