from fastapi.requests import Request
from fastapi.responses import Response, HTMLResponse, StreamingResponse
from fastapi.openapi.docs import get_swagger_ui_html

from typing import AsyncIterator, Optional

from ..types import (
    MetadataModelResponse,
    MetadataServerResponse,
    InferenceRequest,
    InferenceResponse,
    RepositoryIndexRequest,
    RepositoryIndexResponse,
)
from ..handlers import DataPlane, ModelRepositoryHandlers
from ..utils import insert_headers, extract_headers
from .responses import ServerSentEvent
from .openapi import get_openapi_schema, get_model_schema_uri, get_model_schema
from .utils import to_status_code


class Endpoints:
    """
    Implementation of REST endpoints.

    These take care of the REST/HTTP-specific things and then delegate the
    business logic to the internal handlers.
    """

    def __init__(self, data_plane: DataPlane):
        self._data_plane = data_plane

    async def live(self) -> Response:
        is_live = await self._data_plane.live()
        return Response(status_code=to_status_code(is_live))

    async def ready(self) -> Response:
        is_ready = await self._data_plane.ready()
        return Response(status_code=to_status_code(is_ready))

    async def openapi(self) -> dict:
        return get_openapi_schema()

    async def docs(self) -> HTMLResponse:
        openapi_url = "/v2/docs/dataplane.json"
        title = "MLServer API Docs"
        return get_swagger_ui_html(openapi_url=openapi_url, title=title)
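
    # NOTE: `docs()` points Swagger UI at "/v2/docs/dataplane.json", so the
    # route wiring outside this module is presumably expected to serve the
    # schema returned by `openapi()` at that exact path.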

    async def model_openapi(
        self, model_name: str, model_version: Optional[str] = None
    ) -> dict:
        # NOTE: Right now, we use the `model_metadata` method to check that the
        # model exists.
        # In the future, we will use this metadata to fill in more model
        # details in the schema (e.g. expected inputs, etc.).
        await self._data_plane.model_metadata(model_name, model_version)
        return get_model_schema(model_name, model_version)

    async def model_docs(
        self, model_name: str, model_version: Optional[str] = None
    ) -> HTMLResponse:
        # NOTE: Right now, we use the `model_metadata` method to check that the
        # model exists.
        # In the future, we will use this metadata to fill in more model
        # details in the schema (e.g. expected inputs, etc.).
        await self._data_plane.model_metadata(model_name, model_version)
        openapi_url = get_model_schema_uri(model_name, model_version)
        title = f"MLServer API Docs - {model_name}"
        if model_version:
            title = f"{title} ({model_version})"
        return get_swagger_ui_html(openapi_url=openapi_url, title=title)

    async def model_ready(
        self, model_name: str, model_version: Optional[str] = None
    ) -> Response:
        is_ready = await self._data_plane.model_ready(model_name, model_version)
        return Response(status_code=to_status_code(is_ready))

    async def metadata(self) -> MetadataServerResponse:
        return await self._data_plane.metadata()

    async def model_metadata(
        self, model_name: str, model_version: Optional[str] = None
    ) -> MetadataModelResponse:
        return await self._data_plane.model_metadata(model_name, model_version)

    async def infer(
        self,
        raw_request: Request,
        raw_response: Response,
        payload: InferenceRequest,
        model_name: str,
        model_version: Optional[str] = None,
    ) -> InferenceResponse:
        request_headers = dict(raw_request.headers)
        insert_headers(payload, request_headers)

        inference_response = await self._data_plane.infer(
            payload, model_name, model_version
        )

        response_headers = extract_headers(inference_response)
        if response_headers:
            raw_response.headers.update(response_headers)

        return inference_response
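
    # NOTE: Illustrative sketch only, not part of this module. The route
    # wiring that exposes these handlers lives elsewhere; assuming a
    # ready-made `data_plane: DataPlane`, binding `infer` to the V2
    # inference path could look roughly like:
    #
    #     from fastapi import FastAPI
    #
    #     app = FastAPI()
    #     endpoints = Endpoints(data_plane)
    #     app.add_api_route(
    #         "/v2/models/{model_name}/infer",
    #         endpoints.infer,
    #         methods=["POST"],
    #     )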

    async def infer_stream(
        self,
        raw_request: Request,
        payload: InferenceRequest,
        model_name: str,
        model_version: Optional[str] = None,
    ) -> StreamingResponse:
        request_headers = dict(raw_request.headers)
        insert_headers(payload, request_headers)

        async def payloads_generator(
            payload: InferenceRequest,
        ) -> AsyncIterator[InferenceRequest]:
            yield payload

        payloads = payloads_generator(payload)
        infer_stream = self._data_plane.infer_stream(
            payloads, model_name, model_version
        )

        sse_stream = _as_sse(infer_stream)
        return StreamingResponse(content=sse_stream, media_type="text/event-stream")


async def _as_sse(
    infer_stream: AsyncIterator[InferenceResponse],
) -> AsyncIterator[bytes]:
    """
    Helper to convert each response coming out of a generator into an
    encoded Server-Sent Event.
    """
    async for inference_response in infer_stream:
        # TODO: How should we send headers back?
        # response_headers = extract_headers(inference_response)
        yield ServerSentEvent(inference_response).encode()
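
# NOTE: Illustrative only, and an assumption about `ServerSentEvent.encode()`:
# if it follows the standard SSE framing, each chunk yielded above would look
# on the wire roughly like
#
#     data: {"model_name": "my-model", "outputs": [...]}\n\n
#
# i.e. a single `data:` field carrying the JSON-encoded InferenceResponse,
# terminated by a blank line.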


class ModelRepositoryEndpoints:
    def __init__(self, handlers: ModelRepositoryHandlers):
        self._handlers = handlers

    async def index(self, payload: RepositoryIndexRequest) -> RepositoryIndexResponse:
        return await self._handlers.index(payload)

    async def load(self, model_name: str) -> Response:
        loaded = await self._handlers.load(name=model_name)
        return Response(status_code=to_status_code(loaded))

    async def unload(self, model_name: str) -> Response:
        unloaded = await self._handlers.unload(name=model_name)
        return Response(status_code=to_status_code(unloaded))
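

# NOTE: Illustrative sketch only, not part of this module. These repository
# handlers are typically mounted under the model repository extension paths;
# assuming an existing `handlers: ModelRepositoryHandlers` and the `app` from
# the earlier sketch, the wiring could look like:
#
#     repository = ModelRepositoryEndpoints(handlers)
#     app.add_api_route(
#         "/v2/repository/index", repository.index, methods=["POST"]
#     )
#     app.add_api_route(
#         "/v2/repository/models/{model_name}/load",
#         repository.load,
#         methods=["POST"],
#     )
#     app.add_api_route(
#         "/v2/repository/models/{model_name}/unload",
#         repository.unload,
#         methods=["POST"],
#     )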