-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tested apis; Added benchmark results to README
- Loading branch information
1 parent
22fa794
commit d123d15
Showing
7 changed files
with
180 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import sys | ||
|
||
sys.path.append("..") | ||
|
||
import requests | ||
import threading | ||
|
||
from ml_pool.utils import timer | ||
|
||
|
||
URL = "http://127.0.0.1:8000/iris" | ||
CLIENTS = 20 | ||
REQUESTS_PER_CLIENT = 50 | ||
|
||
|
||
def client(index, features): | ||
for i in range(REQUESTS_PER_CLIENT): | ||
response = requests.post(url=URL, json={"features": features}) | ||
print( | ||
f"Client {index} got {i} / {REQUESTS_PER_CLIENT} " | ||
f"response {response.json()}" | ||
) | ||
|
||
|
||
@timer | ||
def main(): | ||
threads = [ | ||
threading.Thread(target=client, args=(i, [6.2, 2.2, 4.5, 1.5])) | ||
for i in range(CLIENTS) | ||
] | ||
for thread in threads: | ||
thread.start() | ||
|
||
for thread in threads: | ||
thread.join() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import sys | ||
|
||
sys.path.append("..") | ||
|
||
from functools import partial | ||
|
||
from fastapi import FastAPI | ||
import pydantic | ||
import xgboost | ||
import numpy as np | ||
import uvicorn | ||
|
||
from ml_pool import MLPool | ||
from ml_pool.logger import get_logger | ||
|
||
|
||
logger = get_logger("api") | ||
|
||
app = FastAPI() | ||
|
||
|
||
def load_model(model_path: str): | ||
model = xgboost.Booster() | ||
model.load_model(model_path) | ||
return model | ||
|
||
|
||
def score_model(model, features): | ||
# Imitates a heavy model that takes time to score + feature engineering | ||
# could also be unloaded to the worker pool | ||
sum_ = 0 | ||
for i in range(10_000_000): | ||
sum_ += 1 | ||
|
||
features = xgboost.DMatrix([features]) | ||
return np.argmax(model.predict(features)) | ||
|
||
|
||
class Request(pydantic.BaseModel): | ||
features: list[float] | ||
|
||
|
||
class Response(pydantic.BaseModel): | ||
prediction: int | ||
|
||
|
||
@app.get("/") | ||
def health_check(): | ||
return {"Message": "Up and running"} | ||
|
||
|
||
@app.post("/iris") | ||
def score(request: Request) -> Response: | ||
logger.info(f"Got request for features: {request}") | ||
job_id = pool.schedule_model_scoring(features=request.features) | ||
result = pool.get_scoring_result(job_id, wait_if_not_available=True) | ||
return Response(prediction=result) | ||
|
||
|
||
if __name__ == "__main__": | ||
with MLPool( | ||
load_model_func=partial(load_model, "iris_xgb.json"), | ||
score_model_func=score_model, | ||
) as pool: | ||
uvicorn.run(app, workers=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import sys | ||
|
||
sys.path.append("..") | ||
|
||
from fastapi import FastAPI | ||
import pydantic | ||
import xgboost | ||
import numpy as np | ||
import uvicorn | ||
|
||
from ml_pool.logger import get_logger | ||
|
||
|
||
logger = get_logger("api") | ||
|
||
app = FastAPI() | ||
|
||
|
||
def load_model(model_path: str): | ||
model = xgboost.Booster() | ||
model.load_model(model_path) | ||
logger.info("Model loaded") | ||
return model | ||
|
||
|
||
model = load_model("iris_xgb.json") | ||
|
||
|
||
def score_model(model, features): | ||
# Imitates a heavy model that takes time to score + feature engineering | ||
# could also be unloaded to the worker pool | ||
sum_ = 0 | ||
for i in range(10_000_000): | ||
sum_ += 1 | ||
|
||
features = xgboost.DMatrix([features]) | ||
return np.argmax(model.predict(features)) | ||
|
||
|
||
class Request(pydantic.BaseModel): | ||
features: list[float] | ||
|
||
|
||
class Response(pydantic.BaseModel): | ||
prediction: int | ||
|
||
|
||
@app.get("/") | ||
def health_check(): | ||
return {"Message": "Up and running"} | ||
|
||
|
||
@app.post("/iris") | ||
def score(request: Request) -> Response: | ||
logger.info(f"Got request for features: {request}") | ||
result = score_model(model, request.features) | ||
return Response(prediction=result) | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run(app, workers=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters