Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,20 @@ jobs:
- name: Run tests
run: |
pytest vetiver -m 'rsc_test'

test-no-torch:
name: "Test no-torch"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .[dev]

- name: Run tests
run: |
pytest vetiver/tests/test_sklearn.py
52 changes: 28 additions & 24 deletions docs/source/advancedusage/custom_handler.md
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
# Custom Handlers

There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with [new model types,](#new-model-type) where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but [in a different way](#different-model-implementation). In either case, you *must* make a custom handler from the base `VetiverHandler`. A minimal custom handler could look like the following:
There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with new model types where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but in a different way. In either case, you should create a custom handler from vetiver's `BaseHandler()`. At a minimum, you must give the type of your model via `model_type` how predictions should be made, via the method `handler_predict()`. Then, initialize your handler with your model, and pass the object into `VetiverModel`.

This example shows a custom handler of `newmodeltype` type.

```python
from vetiver.handlers.base import VetiverHandler
from vetiver.handlers.base import BaseHandler

class SampleCustomHandler(VetiverHandler):
class CustomHandler(BaseHandler):
def __init__(model, ptype_data):
super().__init__(model, ptype_data)

def handler_predict(self, input_data, check_ptype):
model_type = staticmethod(lambda: newmodeltype)

def handler_predict(self, input_data, check_ptype: bool):
"""
handler_predict defines how to make predictions from your model
Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed at each API call.

Parameters
----------
input_data:
Test data
check_ptype: bool
Whether the ptype should be enforced

Returns
-------
prediction
Prediction from model
"""
# your code here
```
prediction = model.fancy_new_predict(input_data)

## New model type
If your model type is not supported by vetiver, you should create and then register the handler using [single dispatch](https://docs.python.org/3/library/functools.html#functools.singledispatch). Once the new type is registered, you are able to use `VetiverModel()` as normal. Here is a template for such a function:

```python
from vetiver.handlers._interface import create_handler
return prediction

@create_handler.register
def _(model: {_model_type}, ptype_data):
return SampleCustomHandler(model, ptype_data)
new_model = CustomHandler(model, ptype_data)

VetiverModel(your_model, "your_model")
VetiverModel(new_model, "custom_model")
```

If your model is a common type, please consider [submitting a pull request](https://github.com/rstudio/vetiver-python/pulls).

## Different model implementation
If your model's prediction function is different than vetiver's, you should create a custom handler with a `handler_predict` method to make predictions. Then, initialize your handler with your model, and pass the object into `VetiverModel`.

```python
new_model = SampleCustomHandler(your_model, your_ptype_data)

VetiverModel(new_model, "your_model")
```
3 changes: 1 addition & 2 deletions vetiver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from .meta import * # noqa
from .write_docker import write_docker # noqa
from .write_fastapi import write_app # noqa
from .handlers._interface import create_handler, InvalidModelError # noqa
from .handlers.base import VetiverHandler # noqa
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
from .handlers.sklearn import SKLearnHandler # noqa
from .handlers.torch import TorchHandler # noqa
from .rsconnect import deploy_rsconnect # noqa
Expand Down
99 changes: 0 additions & 99 deletions vetiver/handlers/_interface.py

This file was deleted.

74 changes: 72 additions & 2 deletions vetiver/handlers/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,62 @@
from abc import ABCMeta
from vetiver.handlers import base
from functools import singledispatch
from contextlib import suppress

from ..ptype import vetiver_create_ptype
from ..meta import _model_meta


class VetiverHandler(metaclass=ABCMeta):
class InvalidModelError(Exception):
"""
Throw an error if `model` is not registered.
"""

def __init__(
self,
message="The `model` argument must be a scikit-learn or torch model.",
):
self.message = message
super().__init__(self.message)


@singledispatch
def create_handler(model, ptype_data):
"""check for model type to handle prediction

Parameters
----------
model: object
Description of parameter `x`.
ptype_data : object
An object with information (data) whose layout is to be determined.

Returns
-------
handler
Handler class for specified model type


Examples
--------
>>> import vetiver
>>> X, y = vetiver.mock.get_mock_data()
>>> model = vetiver.mock.get_mock_model()
>>> handler = vetiver.create_handler(model, X)
>>> handler.describe()
"Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model"
"""

raise InvalidModelError(
"Model must be an sklearn or torch model, or a \
custom handler must be used. See the docs for more info on custom handlers. \
https://rstudio.github.io/vetiver-python/advancedusage/custom_handler.html"
)


# BaseHandler uses create_handler to register subclasses based on model_class


class BaseHandler:
"""Base handler class for creating VetiverModel of different type.

Parameters
Expand All @@ -15,6 +67,12 @@ class VetiverHandler(metaclass=ABCMeta):
An object with information (data) whose layout is to be determined.
"""

@classmethod
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
with suppress(AttributeError, NameError):
create_handler.register(cls.model_class(), cls)

def __init__(self, model, ptype_data):
self.model = model
self.ptype_data = ptype_data
Expand Down Expand Up @@ -79,3 +137,15 @@ def handler_predict(self, input_data, check_ptype):
Prediction from model
"""
...


# BaseHandler for subclassing, Handler for new model types
Handler = BaseHandler


@create_handler.register
def _(model: base.BaseHandler, ptype_data):
if model.ptype_data is None and ptype_data is not None:
model.ptype_data = ptype_data

return model
8 changes: 4 additions & 4 deletions vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import sklearn

from ..meta import _model_meta
from .base import VetiverHandler
from .base import BaseHandler


class SKLearnHandler(VetiverHandler):
class SKLearnHandler(BaseHandler):
"""Handler class for creating VetiverModels with sklearn.

Parameters
Expand All @@ -14,7 +14,7 @@ class SKLearnHandler(VetiverHandler):
a trained sklearn model
"""

base_class = sklearn.base.BaseEstimator
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)

def __init__(self, model, ptype_data):
super().__init__(model, ptype_data)
Expand Down Expand Up @@ -54,7 +54,7 @@ def handler_predict(self, input_data, check_ptype):
Prediction from model
"""

if check_ptype == True:
if check_ptype:
if isinstance(input_data, pd.DataFrame):
prediction = self.model.predict(input_data)
else:
Expand Down
8 changes: 4 additions & 4 deletions vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from ..meta import _model_meta
from .base import VetiverHandler
from .base import BaseHandler

torch_exists = True
try:
Expand All @@ -10,7 +10,7 @@
torch_exists = False


class TorchHandler(VetiverHandler):
class TorchHandler(BaseHandler):
"""Handler class for creating VetiverModels with torch.

Parameters
Expand All @@ -19,7 +19,7 @@ class TorchHandler(VetiverHandler):
a trained torch model
"""

base_class = torch.nn.Module
model_class = staticmethod(lambda: torch.nn.Module)

def __init__(self, model, ptype_data):
super().__init__(model, ptype_data)
Expand Down Expand Up @@ -59,7 +59,7 @@ def handler_predict(self, input_data, check_ptype):
Prediction from model
"""
if torch_exists:
if check_ptype == True:
if check_ptype:
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
prediction = self.model(torch.from_numpy(input_data))

Expand Down
6 changes: 4 additions & 2 deletions vetiver/tests/test_custom_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import pydantic
import pandas as pd

from vetiver import mock, VetiverModel, VetiverHandler
from vetiver import mock, VetiverModel, BaseHandler


class CustomHandler(VetiverHandler):
class CustomHandler(BaseHandler):
def __init__(self, model, ptype_data):
super().__init__(model, ptype_data)

model_type = staticmethod(lambda: sklearn.dummy.DummyRegressor)

def handler_predict(self, input_data, check_ptype):
if check_ptype is True:
if isinstance(input_data, pd.DataFrame):
Expand Down
Loading