Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ repos:
# line too long and line before binary operator (black is ok with these)
types:
- python
args:
- "--max-line-length=90"
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
Expand Down
29 changes: 18 additions & 11 deletions vetiver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import uvicorn
import requests
import pandas as pd
import numpy as np
from typing import Callable, Union, List

from . import __version__
Expand Down Expand Up @@ -65,7 +64,7 @@ async def ping():

if self.check_ptype is True:

@app.post("/predict/")
@app.post("/predict")
async def prediction(
input_data: Union[self.model.ptype, List[self.model.ptype]]
):
Expand All @@ -82,7 +81,7 @@ async def prediction(

elif self.check_ptype is False:

@app.post("/predict/")
@app.post("/predict")
async def prediction(input_data: Request):

y = await input_data.json()
Expand All @@ -99,17 +98,25 @@ async def rapidoc():
<!doctype html>
<html>
<head>
<meta name="viewport" content="width=device-width,minimum-scale=1,initial-scale=1,user-scalable=yes">
<meta name="viewport"
content="width=device-width,minimum-scale=1,initial-scale=1,user-scalable=yes">
<title>RapiDoc</title>
<script type="module" src="https://unpkg.com/rapidoc@9.1.3/dist/rapidoc-min.js"></script>
<script type="module"
src="https://unpkg.com/rapidoc@9.3.3/dist/rapidoc-min.js"></script>
</script></head>
<body>
<rapi-doc spec-url="{self.app.openapi_url[1:]}"
id="thedoc" render-style="read" schema-style="tree"
show-components="true" show-info="true" show-header="true"
id="thedoc"
render-style="read"
schema-style="tree"
show-components="true"
show-info="true"
show-header="true"
allow-search="true"
show-side-nav="false"
allow-authentication="false" update-route="false" match-type="regex"
allow-authentication="false"
update-route="false"
match-type="regex"
theme="light"
header-color="#F2C6AC"
primary-color = "#8C2D2D">
Expand Down Expand Up @@ -143,15 +150,15 @@ def vetiver_post(
"""
if self.check_ptype is True:

@self.app.post("/" + endpoint_name + "/")
@self.app.post("/" + endpoint_name)
async def custom_endpoint(input_data: self.model.ptype):
y = _prepare_data(input_data)
new = endpoint_fx(pd.Series(y))
return {endpoint_name: new.tolist()}

else:

@self.app.post("/" + endpoint_name + "/")
@self.app.post("/" + endpoint_name)
async def custom_endpoint(input_data: Request):
y = await input_data.json()
new = endpoint_fx(pd.Series(y))
Expand Down Expand Up @@ -204,7 +211,7 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
"""
if isinstance(endpoint, testclient.TestClient):
requester = endpoint
endpoint = "/predict/"
endpoint = "/predict"
else:
requester = requests

Expand Down
4 changes: 2 additions & 2 deletions vetiver/tests/test_add_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_endpoint_adds_ptype():

client = TestClient(app)
data = {"B": 0, "C": 0, "D": 0}
response = client.post("/sum/", json=data)
response = client.post("/sum", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"sum": 0}, response.json()

Expand All @@ -38,6 +38,6 @@ def test_endpoint_adds_no_ptype():

client = TestClient(app)
data = [0, 0, 0]
response = client.post("/sum/", json=data)
response = client.post("/sum", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"sum": 0}, response.json()
10 changes: 5 additions & 5 deletions vetiver/tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_torch_predict_ptype():

client = TestClient(v_api.app)
data = {"0": 3.3}
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)

assert response.status_code == 200, response.text
assert response.json() == {"prediction": [-4.060722351074219]}, response.text
Expand All @@ -77,7 +77,7 @@ def test_torch_predict_ptype_batch():

client = TestClient(v_api.app)
data = [{"0": 3.3}, {"0": 3.3}]
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)

assert response.status_code == 200, response.text
assert response.json() == {
Expand All @@ -93,7 +93,7 @@ def test_torch_predict_ptype_error():

client = TestClient(v_api.app)
data = {"0": "bad"}
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)

assert response.status_code == 422, response.text # value is not a valid float

Expand All @@ -106,7 +106,7 @@ def test_torch_predict_no_ptype_batch():

client = TestClient(v_api.app)
data = [[3.3], [3.3]]
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {
"prediction": [[-4.060722351074219], [-4.060722351074219]]
Expand All @@ -121,6 +121,6 @@ def test_torch_predict_no_ptype():

client = TestClient(v_api.app)
data = [[3.3]]
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [[-4.060722351074219]]}, response.text
12 changes: 6 additions & 6 deletions vetiver/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_predict_endpoint_ptype():
np.random.seed(500)
client = TestClient(_start_application().app)
data = {"B": 0, "C": 0, "D": 0}
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()

Expand All @@ -33,7 +33,7 @@ def test_predict_endpoint_ptype_batch():
np.random.seed(500)
client = TestClient(_start_application().app)
data = [{"B": 0, "C": 0, "D": 0}, {"B": 0, "C": 0, "D": 0}]
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()

Expand All @@ -42,15 +42,15 @@ def test_predict_endpoint_ptype_error():
np.random.seed(500)
client = TestClient(_start_application().app)
data = {"B": 0, "C": "a", "D": 0}
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 422, response.text # value is not a valid integer


def test_predict_endpoint_no_ptype():
np.random.seed(500)
client = TestClient(_start_application(save_ptype=False).app)
data = "0,0,0"
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47]}, response.json()

Expand All @@ -59,7 +59,7 @@ def test_predict_endpoint_no_ptype_batch():
np.random.seed(500)
client = TestClient(_start_application(save_ptype=False).app)
data = [["0,0,0"], ["0,0,0"]]
response = client.post("/predict/", json=data)
response = client.post("/predict", json=data)
assert response.status_code == 200, response.text
assert response.json() == {"prediction": [44.47, 44.47]}, response.json()

Expand All @@ -69,4 +69,4 @@ def test_predict_endpoint_no_ptype_error():
client = TestClient(_start_application(save_ptype=False).app)
data = {"hell0", 9, 32.0}
with pytest.raises(TypeError):
client.post("/predict/", json=data)
client.post("/predictt", json=data)