From 016fceda9861cd79e62492a924b110ad3d7113d4 Mon Sep 17 00:00:00 2001 From: Jialei Date: Tue, 3 Jan 2023 10:19:25 +0800 Subject: [PATCH] chore(sdk): hijack submit on premises --- client/starwhale/api/_impl/service.py | 59 +++++++++++++------ client/starwhale/core/model/model.py | 2 +- client/tests/sdk/test_service.py | 11 ++-- .../mlops/domain/job/ModelServingService.java | 4 +- 4 files changed, 51 insertions(+), 25 deletions(-) diff --git a/client/starwhale/api/_impl/service.py b/client/starwhale/api/_impl/service.py index a2df5e63dc..3c6635c5ad 100644 --- a/client/starwhale/api/_impl/service.py +++ b/client/starwhale/api/_impl/service.py @@ -5,6 +5,8 @@ import gradio from gradio.components import Component +from starwhale.utils import in_production + Input = t.Union[Component, t.List[Component]] Output = t.Union[Component, t.List[Component]] @@ -40,6 +42,10 @@ def decorator(func: t.Any) -> t.Any: def add_api( self, input_: Input, output: Output, func: t.Callable, uri: str ) -> None: + if not isinstance(input_, list): + input_ = [input_] + if not isinstance(output, list): + output = [output] _api = Api(input_, output, func, uri) self.apis[uri] = _api @@ -50,35 +56,54 @@ def get_spec(self) -> t.Any: # fast path if not self.apis: return {} - server = self._gen_gradio_server() + # hijack_submit set to True for generating config for console (On-Premises) + server = self._gen_gradio_server(hijack_submit=True) return server.get_config_file() def get_openapi_spec(self) -> t.Any: - server = self._gen_gradio_server() + server = self._gen_gradio_server(hijack_submit=True) return server.app.openapi() - def _gen_gradio_server(self) -> gradio.Blocks: + def _render_api(self, _api: Api, hijack_submit: bool) -> None: + js_func = "x => { wait(); return x; }" if hijack_submit else "" + with gradio.Row(): + with gradio.Column(): + for i in _api.input: + comp = gradio.components.get_component_instance( + i, render=False + ).render() + if isinstance(comp, gradio.components.Changeable): + comp.change( + _api.view_func(self.api_instance), + i, + _api.output, + _js=js_func, + ) + with gradio.Column(): + for i in _api.output: + gradio.components.get_component_instance(i, render=False).render() + + def _gen_gradio_server( + self, hijack_submit: bool, title: t.Optional[str] = None + ) -> gradio.Blocks: apis = self.apis.values() - return gradio.TabbedInterface( - interface_list=[ - gradio.Interface( - fn=api_.view_func(self.api_instance), - inputs=api_.input, - outputs=api_.output, - ) - for api_ in apis - ], - tab_names=[api_.uri for api_ in apis], - ) - - def serve(self, addr: str, port: int) -> None: + with gradio.Blocks() as app: + with gradio.Tabs(): + for _api in apis: + with gradio.TabItem(label=_api.uri): + self._render_api(_api, hijack_submit) + app.title = title or "starwhale" + return app + + def serve(self, addr: str, port: int, title: t.Optional[str] = None) -> None: """ Default serve implementation, users can override this method :param addr :param port + :param title webpage title :return: None """ - server = self._gen_gradio_server() + server = self._gen_gradio_server(hijack_submit=in_production(), title=title) server.launch(server_name=addr, server_port=port) diff --git a/client/starwhale/core/model/model.py b/client/starwhale/core/model/model.py index e2cd1e8039..79df774996 100644 --- a/client/starwhale/core/model/model.py +++ b/client/starwhale/core/model/model.py @@ -721,7 +721,7 @@ def serve( ) -> None: _model_config = cls.load_model_config(workdir / model_yaml) svc = cls._get_service(_model_config.run.handler, workdir) - svc.serve(host, port) + svc.serve(host, port, _model_config.name) class CloudModel(CloudBundleModelMixin, Model): diff --git a/client/tests/sdk/test_service.py b/client/tests/sdk/test_service.py index 9d4dc5bdb3..c119937bb7 100644 --- a/client/tests/sdk/test_service.py +++ b/client/tests/sdk/test_service.py @@ -18,14 +18,13 @@ def test_custom_class(self): assert list(svc.apis.keys()) == ["foo", "bar"] for i in svc.apis.values(): - assert i.input.__class__.__name__ == "CustomInput" - assert i.output.__class__.__name__ == "CustomOutput" + assert i.input.__class__.__name__ == "list" + assert i.input[0].__class__.__name__ == "CustomInput" + assert i.output.__class__.__name__ == "list" + assert i.output[0].__class__.__name__ == "CustomOutput" spec = svc.get_spec() - assert list(filter(bool, [i["api_name"] for i in spec["dependencies"]])) == [ - "predict", - "predict_1", - ] + assert len(spec["dependencies"]) == 2 def test_default_class(self): svc = StandaloneModel._get_service("default_class:MyDefaultClass", self.root) diff --git a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java index abab3f0777..6723a4cc16 100644 --- a/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java +++ b/server/controller/src/main/java/ai/starwhale/mlops/domain/job/ModelServingService.java @@ -171,7 +171,9 @@ private void deploy(RuntimeVersionEntity runtime, ModelVersionEntity model, Stri "SW_PYPI_INDEX_URL", runTimeProperties.getPypi().getIndexUrl(), "SW_PYPI_EXTRA_INDEX_URL", runTimeProperties.getPypi().getExtraIndexUrl(), "SW_PYPI_TRUSTED_HOST", runTimeProperties.getPypi().getTrustedHost(), - "SW_MODEL_SERVING_BASE_URI", String.format("/gateway/%s/%d", MODEL_SERVICE_PREFIX, id) + "SW_MODEL_SERVING_BASE_URI", String.format("/gateway/%s/%d", MODEL_SERVICE_PREFIX, id), + // see https://github.com/star-whale/starwhale/blob/c1d85ab98045a95ab3c75a89e7af56a17e966714/client/starwhale/utils/__init__.py#L51 + "SW_PRODUCTION", "1" ); var ss = k8sJobTemplate.renderModelServingOrch(envs, image, name); k8sClient.deployStatefulSet(ss);