Skip to content

Commit

Permalink
chore(sdk): hijack submit on premises
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui committed Jan 3, 2023
1 parent cf594ad commit 016fced
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 25 deletions.
59 changes: 42 additions & 17 deletions client/starwhale/api/_impl/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import gradio
from gradio.components import Component

from starwhale.utils import in_production

Input = t.Union[Component, t.List[Component]]
Output = t.Union[Component, t.List[Component]]

Expand Down Expand Up @@ -40,6 +42,10 @@ def decorator(func: t.Any) -> t.Any:
def add_api(
self, input_: Input, output: Output, func: t.Callable, uri: str
) -> None:
if not isinstance(input_, list):
input_ = [input_]
if not isinstance(output, list):
output = [output]
_api = Api(input_, output, func, uri)
self.apis[uri] = _api

Expand All @@ -50,35 +56,54 @@ def get_spec(self) -> t.Any:
# fast path
if not self.apis:
return {}
server = self._gen_gradio_server()
# hijack_submit set to True for generating config for console (On-Premises)
server = self._gen_gradio_server(hijack_submit=True)
return server.get_config_file()

def get_openapi_spec(self) -> t.Any:
server = self._gen_gradio_server()
server = self._gen_gradio_server(hijack_submit=True)
return server.app.openapi()

def _gen_gradio_server(self) -> gradio.Blocks:
def _render_api(self, _api: Api, hijack_submit: bool) -> None:
js_func = "x => { wait(); return x; }" if hijack_submit else ""
with gradio.Row():
with gradio.Column():
for i in _api.input:
comp = gradio.components.get_component_instance(
i, render=False
).render()
if isinstance(comp, gradio.components.Changeable):
comp.change(
_api.view_func(self.api_instance),
i,
_api.output,
_js=js_func,
)
with gradio.Column():
for i in _api.output:
gradio.components.get_component_instance(i, render=False).render()

def _gen_gradio_server(
self, hijack_submit: bool, title: t.Optional[str] = None
) -> gradio.Blocks:
apis = self.apis.values()
return gradio.TabbedInterface(
interface_list=[
gradio.Interface(
fn=api_.view_func(self.api_instance),
inputs=api_.input,
outputs=api_.output,
)
for api_ in apis
],
tab_names=[api_.uri for api_ in apis],
)

def serve(self, addr: str, port: int) -> None:
with gradio.Blocks() as app:
with gradio.Tabs():
for _api in apis:
with gradio.TabItem(label=_api.uri):
self._render_api(_api, hijack_submit)
app.title = title or "starwhale"
return app

def serve(self, addr: str, port: int, title: t.Optional[str] = None) -> None:
"""
Default serve implementation, users can override this method
:param addr
:param port
:param title webpage title
:return: None
"""
server = self._gen_gradio_server()
server = self._gen_gradio_server(hijack_submit=in_production(), title=title)
server.launch(server_name=addr, server_port=port)


Expand Down
2 changes: 1 addition & 1 deletion client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@ def serve(
) -> None:
_model_config = cls.load_model_config(workdir / model_yaml)
svc = cls._get_service(_model_config.run.handler, workdir)
svc.serve(host, port)
svc.serve(host, port, _model_config.name)


class CloudModel(CloudBundleModelMixin, Model):
Expand Down
11 changes: 5 additions & 6 deletions client/tests/sdk/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ def test_custom_class(self):
assert list(svc.apis.keys()) == ["foo", "bar"]

for i in svc.apis.values():
assert i.input.__class__.__name__ == "CustomInput"
assert i.output.__class__.__name__ == "CustomOutput"
assert i.input.__class__.__name__ == "list"
assert i.input[0].__class__.__name__ == "CustomInput"
assert i.output.__class__.__name__ == "list"
assert i.output[0].__class__.__name__ == "CustomOutput"

spec = svc.get_spec()
assert list(filter(bool, [i["api_name"] for i in spec["dependencies"]])) == [
"predict",
"predict_1",
]
assert len(spec["dependencies"]) == 2

def test_default_class(self):
svc = StandaloneModel._get_service("default_class:MyDefaultClass", self.root)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ private void deploy(RuntimeVersionEntity runtime, ModelVersionEntity model, Stri
"SW_PYPI_INDEX_URL", runTimeProperties.getPypi().getIndexUrl(),
"SW_PYPI_EXTRA_INDEX_URL", runTimeProperties.getPypi().getExtraIndexUrl(),
"SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost(),
"SW_MODEL_SERVING_BASE_URI", String.format("/gateway/%s/%d", MODEL_SERVICE_PREFIX, id)
"SW_MODEL_SERVING_BASE_URI", String.format("/gateway/%s/%d", MODEL_SERVICE_PREFIX, id),
// see https://github.com/star-whale/starwhale/blob/c1d85ab98045a95ab3c75a89e7af56a17e966714/client/starwhale/utils/__init__.py#L51
"SW_PRODUCTION", "1"
);
var ss = k8sJobTemplate.renderModelServingOrch(envs, image, name);
k8sClient.deployStatefulSet(ss);
Expand Down

0 comments on commit 016fced

Please sign in to comment.