Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vetiver/handlers/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class SKLearnHandler(BaseHandler):
pip_name = "scikit-learn"

def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
"""
Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed
Expand All @@ -30,7 +31,7 @@ def handler_predict(self, input_data, check_prototype):

Returns
-------
prediction
prediction:
Prediction from model
"""

Expand Down
7 changes: 5 additions & 2 deletions vetiver/handlers/spacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def construct_prototype(self):
return prototype

def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
"""
Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed
Expand All @@ -63,7 +64,9 @@ def handler_predict(self, input_data, check_prototype):
Parameters
----------
input_data:
Test data
Test data. The SpacyHandler expects an input of a 1 column DataFrame with
the same column names as the prototype data, or column name "text" if no
prototype was given.

Returns
-------
Expand Down
3 changes: 2 additions & 1 deletion vetiver/handlers/statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class StatsmodelsHandler(BaseHandler):
pip_name = "statsmodels"

def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
"""
Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed
Expand Down
3 changes: 2 additions & 1 deletion vetiver/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class TorchHandler(BaseHandler):
pip_name = "torch"

def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
"""
Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed
Expand Down
3 changes: 2 additions & 1 deletion vetiver/handlers/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class XGBoostHandler(BaseHandler):
pip_name = "xgboost"

def handler_predict(self, input_data, check_prototype):
"""Generates method for /predict endpoint in VetiverAPI
"""
Generates method for /predict endpoint in VetiverAPI

The `handler_predict` function executes at each API call. Use this
function for calling `predict()` and any other tasks that must be executed
Expand Down
15 changes: 13 additions & 2 deletions vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.responses import PlainTextResponse
from textwrap import dedent
from warnings import warn

from .utils import _jupyter_nb
Expand Down Expand Up @@ -105,10 +106,12 @@ def pin_url():

@app.get("/ping", include_in_schema=True)
async def ping():
"""Ping endpoint for health check"""
return {"ping": "pong"}

@app.get("/metadata")
async def get_metadata():
"""Get metadata from model"""
return self.model.metadata.to_dict()

self.vetiver_post(
Expand Down Expand Up @@ -183,13 +186,21 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
if not endpoint_name:
endpoint_name = endpoint_fx.__name__

if endpoint_fx.__doc__ is not None:
api_desc = dedent(endpoint_fx.__doc__)
else:
api_desc = None

if self.check_prototype is True:

@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
@self.app.post(
urljoin("/", endpoint_name),
name=endpoint_name,
description=api_desc,
)
async def custom_endpoint(input_data: List[self.model.prototype]):
_to_frame = api_data_to_frame(input_data)
predictions = endpoint_fx(_to_frame, **kw)

if isinstance(predictions, List):
return {endpoint_name: predictions}
else:
Expand Down