diff --git a/client/starwhale/api/_impl/service.py b/client/starwhale/api/_impl/service.py index 3bd72cd1cc..3f45084b35 100644 --- a/client/starwhale/api/_impl/service.py +++ b/client/starwhale/api/_impl/service.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import os import typing as t import functools from dataclasses import dataclass @@ -28,10 +31,25 @@ def view_func(self, ins: t.Any = None) -> t.Callable: return func +@dataclass +class Hijack: + """ + Hijack options for online evaluation, useless for local usage. + """ + + # if hijack the submit button js logic + submit: bool = False + # the resource path serving on the server side + # used for example resource render for console + resource_path: t.Optional[str] = None + + class Service: - def __init__(self) -> None: + def __init__(self, hijack: t.Optional[Hijack] = None) -> None: self.apis: t.Dict[str, Api] = {} self.api_instance: t.Any = None + self.example_resources: t.List[str] = [] + self.hijack = hijack # TODO: support function as input and output def api( @@ -70,19 +88,20 @@ def get_spec(self) -> t.Any: # fast path if not self.apis: return {} - # hijack_submit set to True for generating config for console (On-Premises) - server = self._gen_gradio_server(hijack_submit=True) + # hijack set to True for generating config for console (On-Premises) + server = self._gen_gradio_server() return server.get_config_file() def get_openapi_spec(self) -> t.Any: - server = self._gen_gradio_server(hijack_submit=True) + server = self._gen_gradio_server() return server.app.openapi() - def _render_api(self, _api: Api, hijack_submit: bool) -> None: + def _render_api(self, _api: Api) -> None: import gradio + from gradio.components import File, Image, Video, Changeable, IOComponent js_func: t.Optional[str] = None - if hijack_submit: + if self.hijack and self.hijack.submit: js_func = "async(...x) => { typeof wait === 'function' && await wait(); return x; }" with gradio.Row(): with gradio.Column(): @@ -91,25 +110,39 @@ def _render_api(self, _api: Api, hijack_submit: bool) -> None: comp = gradio.components.get_component_instance( i, render=False ).render() - if isinstance(comp, gradio.components.Changeable): + if isinstance(comp, Changeable): comp.change(fn=fn, inputs=i, outputs=_api.output, _js=js_func) - if _api.examples: - gradio.Examples( + # do not serve the useless examples in server instances + # the console will render them even the models are not serving + if _api.examples and not in_production(): + example = gradio.Examples( examples=_api.examples, - inputs=[ - i - for i in _api.input - if isinstance(i, gradio.components.IOComponent) - ], - fn=fn, + inputs=[i for i in _api.input if isinstance(i, IOComponent)], ) + if any( + isinstance(i, (File, Image, Video)) + for i in example.dataset.components + ): + # examples should be a list of file path + # use flatten list + to_copy = [i for j in example.examples for i in j] + self.example_resources.extend(to_copy) + # change example resource path for online evaluation + # e.g. /path/to/example.png -> /workdir/src/.starwhale/examples/example.png + if self.hijack and self.hijack.resource_path: + for i in range(len(example.dataset.samples)): + for j in range(len(example.dataset.samples[i])): + origin = example.dataset.samples[i][j] + if origin in to_copy: + name = os.path.basename(origin) + example.dataset.samples[i][j] = os.path.join( + self.hijack.resource_path, name + ) with gradio.Column(): for i in _api.output: gradio.components.get_component_instance(i, render=False).render() - def _gen_gradio_server( - self, hijack_submit: bool, title: t.Optional[str] = None - ) -> "Blocks": + def _gen_gradio_server(self, title: t.Optional[str] = None) -> Blocks: import gradio apis = self.apis.values() @@ -117,7 +150,7 @@ def _gen_gradio_server( with gradio.Tabs(): for _api in apis: with gradio.TabItem(label=_api.uri): - self._render_api(_api, hijack_submit) + self._render_api(_api) app.title = title or "starwhale" return app @@ -129,7 +162,7 @@ def serve(self, addr: str, port: int, title: t.Optional[str] = None) -> None: :param title webpage title :return: None """ - server = self._gen_gradio_server(hijack_submit=in_production(), title=title) + server = self._gen_gradio_server(title=title) server.launch(server_name=addr, server_port=port) diff --git a/client/starwhale/consts/__init__.py b/client/starwhale/consts/__init__.py index d807db8190..0c88b029fe 100644 --- a/client/starwhale/consts/__init__.py +++ b/client/starwhale/consts/__init__.py @@ -25,6 +25,7 @@ LOCAL_CONFIG_VERSION = "2.0" SW_AUTO_DIRNAME = ".starwhale" +SW_EVALUATION_EXAMPLE_DIR = "examples" # used by the versions before 2.0 # SW_LOCAL_STORAGE = HOMEDIR / ".cache/starwhale" diff --git a/client/starwhale/core/model/model.py b/client/starwhale/core/model/model.py index 432d2ded9e..287cef1868 100644 --- a/client/starwhale/core/model/model.py +++ b/client/starwhale/core/model/model.py @@ -3,6 +3,7 @@ import os import copy import json +import shutil import typing as t import tarfile from abc import ABCMeta @@ -25,10 +26,12 @@ SWMP_SRC_FNAME, DefaultYAMLName, EvalHandlerType, + SW_AUTO_DIRNAME, DEFAULT_PAGE_IDX, DEFAULT_PAGE_SIZE, DEFAULT_COPY_WORKERS, DEFAULT_MANIFEST_NAME, + SW_EVALUATION_EXAMPLE_DIR, DEFAULT_EVALUATION_PIPELINE, DEFAULT_EVALUATION_JOBS_FNAME, DEFAULT_STARWHALE_API_VERSION, @@ -57,6 +60,7 @@ from starwhale.core.eval.store import EvaluationStorage from starwhale.core.model.copy import ModelCopy from starwhale.core.model.store import ModelStorage +from starwhale.api._impl.service import Hijack from starwhale.core.job.scheduler import Scheduler @@ -230,19 +234,40 @@ def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None: self.tag.remove(tags, ignore_errors) def _gen_steps(self, typ: str, ppl: str, workdir: Path) -> None: - d = self.store.src_dir - svc = self._get_service(ppl, workdir) - _f = d / DEFAULT_EVALUATION_SVC_META_FNAME - ensure_file(_f, json.dumps(svc.get_spec(), indent=4)) if typ == EvalHandlerType.DEFAULT: # use default ppl = DEFAULT_EVALUATION_PIPELINE - _f = d / DEFAULT_EVALUATION_JOBS_FNAME + _f = self.store.src_dir / DEFAULT_EVALUATION_JOBS_FNAME logger.debug(f"job ppl path:{_f}, ppl is {ppl}") Parser.generate_job_yaml(ppl, workdir, _f) + def _gen_model_serving(self, ppl: str, workdir: Path) -> None: + rc_dir = ( + f"{self.store.src_dir_name}/{SW_AUTO_DIRNAME}/{SW_EVALUATION_EXAMPLE_DIR}" + ) + # render spec + svc = self._get_service(ppl, workdir, hijack=Hijack(True, rc_dir)) + file = self.store.src_dir / DEFAULT_EVALUATION_SVC_META_FNAME + ensure_file(file, json.dumps(svc.get_spec(), indent=4)) + + if len(svc.example_resources) == 0: + return + + # check duplicate file names, do not support using examples with same name in different dir + names = set([os.path.basename(i) for i in svc.example_resources]) + if len(names) != len(svc.example_resources): + raise NoSupportError("duplicate file names in examples") + + # copy example resources for online evaluation in server instance + dst = self.store.src_dir / SW_AUTO_DIRNAME / SW_EVALUATION_EXAMPLE_DIR + ensure_dir(dst) + for f in svc.example_resources: + shutil.copy2(f, dst) + @staticmethod - def _get_service(module: str, pkg: Path) -> Service: + def _get_service( + module: str, pkg: Path, hijack: t.Optional[Hijack] = None + ) -> Service: module, _, attr = module.partition(":") m = load_module(module, pkg) apis = dict() @@ -284,6 +309,7 @@ def _get_service(module: str, pkg: Path) -> Service: for api in apis.values(): svc.add_api_instance(api) svc.api_instance = ins + svc.hijack = hijack return svc @classmethod @@ -597,6 +623,12 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: # type: ignore[overrid workdir=workdir, ), ), + ( + self._gen_model_serving, + 10, + "generate model serving", + dict(ppl=_model_config.run.handler, workdir=workdir), + ), ( self._make_meta_tar, 20, diff --git a/client/tests/core/test_model.py b/client/tests/core/test_model.py index e44718faa7..996ef16ceb 100644 --- a/client/tests/core/test_model.py +++ b/client/tests/core/test_model.py @@ -79,6 +79,7 @@ def test_build_workflow( svc = MagicMock(spec=Service) svc.get_spec.return_value = {} + svc.example_resources = [] m_get_service.return_value = svc model_uri = URI(self.name, expected_type=URIType.MODEL) @@ -456,6 +457,7 @@ def test_build_with_custom_config_file( svc = MagicMock(spec=Service) svc.get_spec.return_value = {} + svc.example_resources = [] m_get_service.return_value = svc name = "foo" diff --git a/console/src/domain/project/schemas/gradio.ts b/console/src/domain/project/schemas/gradio.ts new file mode 100644 index 0000000000..91adce239a --- /dev/null +++ b/console/src/domain/project/schemas/gradio.ts @@ -0,0 +1,23 @@ +export interface ICompomentProps { + components?: string[] + samples?: string[][] +} + +export interface IComponent { + id: number + type: string + props?: ICompomentProps +} + +export interface IDependency { + targets: number[] + trigger: string + backend_fn: boolean + js: string +} + +export interface IGradioConfig { + version: string + components: IComponent[] + dependencies: IDependency[] +} diff --git a/console/src/pages/Project/OnlineEval.tsx b/console/src/pages/Project/OnlineEval.tsx index ad0324864e..fa9ebf10b2 100644 --- a/console/src/pages/Project/OnlineEval.tsx +++ b/console/src/pages/Project/OnlineEval.tsx @@ -18,11 +18,13 @@ import css from '@/assets/GradioWidget/es/style.css' // eslint-disable-next-line import/extensions import '@/assets/GradioWidget/es/app.es.js' import qs from 'qs' +import { IComponent, IGradioConfig } from '@project/schemas/gradio' declare global { interface Window { // eslint-disable-next-line @typescript-eslint/ban-types wait: Function | null + fetchExample: Function | null gradio_config: any } } @@ -96,7 +98,7 @@ export default function OnlineEval() { if (!resp.data?.baseUri) return // eslint-disable-next-line no-restricted-globals - window.gradio_config.root = `http://${location.host}${resp.data?.baseUri}/run/` + window.gradio_config.root = `${location.protocol}//${location.host}${resp.data?.baseUri}/run/` await new Promise((resolve) => { const check = () => { @@ -120,6 +122,18 @@ export default function OnlineEval() { window.wait = null } }, [formRef, projectId]) + useEffect(() => { + if (window.fetchExample) return undefined + window.fetchExample = async (url: string): Promise => { + const { data } = await axios.get(url, { responseType: 'arraybuffer' }) + const base64 = btoa(new Uint8Array(data).reduce((i, byte) => i + String.fromCharCode(byte), '')) + return `data:;base64,${base64}` + } + + return () => { + window.fetchExample = null + } + }, [formRef, projectId]) useEffect(() => { if (modelInfo.isSuccess || modelVersionInfo.isSuccess) { @@ -129,14 +143,56 @@ export default function OnlineEval() { return } + const api = `/api/v1/project/${project?.name}/model/${modelName}/version/${versionName}/file` fetch( - `/api/v1/project/${project?.name}/model/${modelName}/version/${versionName}/file?${qs.stringify({ + `${api}?${qs.stringify({ Authorization: getToken(), partName: 'svc.json', signature: '', })}` ) .then((res) => res.json()) + .then((conf) => { + // patch examples params + const gradioConfig = conf as IGradioConfig + const datasets: IComponent[] = [] + gradioConfig.components.forEach((compnent) => { + if (compnent.type === 'dataset') { + datasets.push(compnent) + let append = false + for (let i = 0; i < compnent.props?.components?.length; i++) { + const tp = compnent.props?.components[i] + if (tp && ['image', 'video', 'file'].includes(tp)) { + append = true + break + } + } + if (append && compnent.props?.samples) { + compnent.props.samples = compnent.props.samples?.map((parts) => { + return parts.map( + (i) => `${api}?${qs.stringify({ Authorization: getToken(), path: i })}` + ) + }) + } + } + }) + datasets.forEach((ds) => { + gradioConfig.dependencies.forEach((dep) => { + if (!dep.targets.includes(ds.id)) { + return + } + // do not request the builtin backend, the model service is not ready + dep.backend_fn = false + const fn = `async function (...x) { + console.log(x) + return window.fetchExample('${ds.props?.samples?.[0]?.[0]}') + }` + dep.js = fn + }) + }) + + return conf + }) .then((data) => { window.gradio_config = data window.gradio_config.css = css