Skip to content

Commit

Permalink
chore(console): online eval in server instance support examples
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui committed Feb 16, 2023
1 parent 28a4741 commit 9443725
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 28 deletions.
73 changes: 53 additions & 20 deletions client/starwhale/api/_impl/service.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import os
import typing as t
import functools
from dataclasses import dataclass
Expand Down Expand Up @@ -28,10 +31,25 @@ def view_func(self, ins: t.Any = None) -> t.Callable:
return func


@dataclass
class Hijack:
"""
Hijack options for online evaluation, useless for local usage.
"""

# if hijack the submit button js logic
submit: bool = False
# the resource path serving on the server side
# used for example resource render for console
resource_path: t.Optional[str] = None


class Service:
def __init__(self) -> None:
def __init__(self, hijack: t.Optional[Hijack] = None) -> None:
self.apis: t.Dict[str, Api] = {}
self.api_instance: t.Any = None
self.example_resources: t.List[str] = []
self.hijack = hijack

# TODO: support function as input and output
def api(
Expand Down Expand Up @@ -70,19 +88,20 @@ def get_spec(self) -> t.Any:
# fast path
if not self.apis:
return {}
# hijack_submit set to True for generating config for console (On-Premises)
server = self._gen_gradio_server(hijack_submit=True)
# hijack set to True for generating config for console (On-Premises)
server = self._gen_gradio_server()
return server.get_config_file()

def get_openapi_spec(self) -> t.Any:
server = self._gen_gradio_server(hijack_submit=True)
server = self._gen_gradio_server()
return server.app.openapi()

def _render_api(self, _api: Api, hijack_submit: bool) -> None:
def _render_api(self, _api: Api) -> None:
import gradio
from gradio.components import File, Image, Video, Changeable, IOComponent

js_func: t.Optional[str] = None
if hijack_submit:
if self.hijack and self.hijack.submit:
js_func = "async(...x) => { typeof wait === 'function' && await wait(); return x; }"
with gradio.Row():
with gradio.Column():
Expand All @@ -91,33 +110,47 @@ def _render_api(self, _api: Api, hijack_submit: bool) -> None:
comp = gradio.components.get_component_instance(
i, render=False
).render()
if isinstance(comp, gradio.components.Changeable):
if isinstance(comp, Changeable):
comp.change(fn=fn, inputs=i, outputs=_api.output, _js=js_func)
if _api.examples:
gradio.Examples(
# do not serve the useless examples in server instances
# the console will render them even the models are not serving
if _api.examples and not in_production():
example = gradio.Examples(
examples=_api.examples,
inputs=[
i
for i in _api.input
if isinstance(i, gradio.components.IOComponent)
],
fn=fn,
inputs=[i for i in _api.input if isinstance(i, IOComponent)],
)
if any(
isinstance(i, (File, Image, Video))
for i in example.dataset.components
):
# examples should be a list of file path
# use flatten list
to_copy = [i for j in example.examples for i in j]
self.example_resources.extend(to_copy)
# change example resource path for online evaluation
# e.g. /path/to/example.png -> /workdir/src/.starwhale/examples/example.png
if self.hijack and self.hijack.resource_path:
for i in range(len(example.dataset.samples)):
for j in range(len(example.dataset.samples[i])):
origin = example.dataset.samples[i][j]
if origin in to_copy:
name = os.path.basename(origin)
example.dataset.samples[i][j] = os.path.join(
self.hijack.resource_path, name
)
with gradio.Column():
for i in _api.output:
gradio.components.get_component_instance(i, render=False).render()

def _gen_gradio_server(
self, hijack_submit: bool, title: t.Optional[str] = None
) -> "Blocks":
def _gen_gradio_server(self, title: t.Optional[str] = None) -> Blocks:
import gradio

apis = self.apis.values()
with gradio.Blocks() as app:
with gradio.Tabs():
for _api in apis:
with gradio.TabItem(label=_api.uri):
self._render_api(_api, hijack_submit)
self._render_api(_api)
app.title = title or "starwhale"
return app

Expand All @@ -129,7 +162,7 @@ def serve(self, addr: str, port: int, title: t.Optional[str] = None) -> None:
:param title webpage title
:return: None
"""
server = self._gen_gradio_server(hijack_submit=in_production(), title=title)
server = self._gen_gradio_server(title=title)
server.launch(server_name=addr, server_port=port)


Expand Down
1 change: 1 addition & 0 deletions client/starwhale/consts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
LOCAL_CONFIG_VERSION = "2.0"

SW_AUTO_DIRNAME = ".starwhale"
SW_EVALUATION_EXAMPLE_DIR = "examples"

# used by the versions before 2.0
# SW_LOCAL_STORAGE = HOMEDIR / ".cache/starwhale"
Expand Down
44 changes: 38 additions & 6 deletions client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import copy
import json
import shutil
import typing as t
import tarfile
from abc import ABCMeta
Expand All @@ -25,10 +26,12 @@
SWMP_SRC_FNAME,
DefaultYAMLName,
EvalHandlerType,
SW_AUTO_DIRNAME,
DEFAULT_PAGE_IDX,
DEFAULT_PAGE_SIZE,
DEFAULT_COPY_WORKERS,
DEFAULT_MANIFEST_NAME,
SW_EVALUATION_EXAMPLE_DIR,
DEFAULT_EVALUATION_PIPELINE,
DEFAULT_EVALUATION_JOBS_FNAME,
DEFAULT_STARWHALE_API_VERSION,
Expand Down Expand Up @@ -57,6 +60,7 @@
from starwhale.core.eval.store import EvaluationStorage
from starwhale.core.model.copy import ModelCopy
from starwhale.core.model.store import ModelStorage
from starwhale.api._impl.service import Hijack
from starwhale.core.job.scheduler import Scheduler


Expand Down Expand Up @@ -230,19 +234,40 @@ def remove_tags(self, tags: t.List[str], ignore_errors: bool = False) -> None:
self.tag.remove(tags, ignore_errors)

def _gen_steps(self, typ: str, ppl: str, workdir: Path) -> None:
d = self.store.src_dir
svc = self._get_service(ppl, workdir)
_f = d / DEFAULT_EVALUATION_SVC_META_FNAME
ensure_file(_f, json.dumps(svc.get_spec(), indent=4))
if typ == EvalHandlerType.DEFAULT:
# use default
ppl = DEFAULT_EVALUATION_PIPELINE
_f = d / DEFAULT_EVALUATION_JOBS_FNAME
_f = self.store.src_dir / DEFAULT_EVALUATION_JOBS_FNAME
logger.debug(f"job ppl path:{_f}, ppl is {ppl}")
Parser.generate_job_yaml(ppl, workdir, _f)

def _gen_model_serving(self, ppl: str, workdir: Path) -> None:
rc_dir = (
f"{self.store.src_dir_name}/{SW_AUTO_DIRNAME}/{SW_EVALUATION_EXAMPLE_DIR}"
)
# render spec
svc = self._get_service(ppl, workdir, hijack=Hijack(True, rc_dir))
file = self.store.src_dir / DEFAULT_EVALUATION_SVC_META_FNAME
ensure_file(file, json.dumps(svc.get_spec(), indent=4))

if len(svc.example_resources) == 0:
return

# check duplicate file names, do not support using examples with same name in different dir
names = set([os.path.basename(i) for i in svc.example_resources])
if len(names) != len(svc.example_resources):
raise NoSupportError("duplicate file names in examples")

# copy example resources for online evaluation in server instance
dst = self.store.src_dir / SW_AUTO_DIRNAME / SW_EVALUATION_EXAMPLE_DIR
ensure_dir(dst)
for f in svc.example_resources:
shutil.copy2(f, dst)

@staticmethod
def _get_service(module: str, pkg: Path) -> Service:
def _get_service(
module: str, pkg: Path, hijack: t.Optional[Hijack] = None
) -> Service:
module, _, attr = module.partition(":")
m = load_module(module, pkg)
apis = dict()
Expand Down Expand Up @@ -284,6 +309,7 @@ def _get_service(module: str, pkg: Path) -> Service:
for api in apis.values():
svc.add_api_instance(api)
svc.api_instance = ins
svc.hijack = hijack
return svc

@classmethod
Expand Down Expand Up @@ -597,6 +623,12 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: # type: ignore[overrid
workdir=workdir,
),
),
(
self._gen_model_serving,
10,
"generate model serving",
dict(ppl=_model_config.run.handler, workdir=workdir),
),
(
self._make_meta_tar,
20,
Expand Down
2 changes: 2 additions & 0 deletions client/tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def test_build_workflow(

svc = MagicMock(spec=Service)
svc.get_spec.return_value = {}
svc.example_resources = []
m_get_service.return_value = svc

model_uri = URI(self.name, expected_type=URIType.MODEL)
Expand Down Expand Up @@ -456,6 +457,7 @@ def test_build_with_custom_config_file(

svc = MagicMock(spec=Service)
svc.get_spec.return_value = {}
svc.example_resources = []
m_get_service.return_value = svc

name = "foo"
Expand Down
23 changes: 23 additions & 0 deletions console/src/domain/project/schemas/gradio.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
export interface ICompomentProps {
components?: string[]
samples?: string[][]
}

export interface IComponent {
id: number
type: string
props?: ICompomentProps
}

export interface IDependency {
targets: number[]
trigger: string
backend_fn: boolean
js: string
}

export interface IGradioConfig {
version: string
components: IComponent[]
dependencies: IDependency[]
}
60 changes: 58 additions & 2 deletions console/src/pages/Project/OnlineEval.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ import css from '@/assets/GradioWidget/es/style.css'
// eslint-disable-next-line import/extensions
import '@/assets/GradioWidget/es/app.es.js'
import qs from 'qs'
import { IComponent, IGradioConfig } from '@project/schemas/gradio'

declare global {
interface Window {
// eslint-disable-next-line @typescript-eslint/ban-types
wait: Function | null
fetchExample: Function | null
gradio_config: any
}
}
Expand Down Expand Up @@ -96,7 +98,7 @@ export default function OnlineEval() {
if (!resp.data?.baseUri) return

// eslint-disable-next-line no-restricted-globals
window.gradio_config.root = `http://${location.host}${resp.data?.baseUri}/run/`
window.gradio_config.root = `${location.protocol}//${location.host}${resp.data?.baseUri}/run/`

await new Promise((resolve) => {
const check = () => {
Expand All @@ -120,6 +122,18 @@ export default function OnlineEval() {
window.wait = null
}
}, [formRef, projectId])
useEffect(() => {
if (window.fetchExample) return undefined
window.fetchExample = async (url: string): Promise<string> => {
const { data } = await axios.get(url, { responseType: 'arraybuffer' })
const base64 = btoa(new Uint8Array(data).reduce((i, byte) => i + String.fromCharCode(byte), ''))
return `data:;base64,${base64}`
}

return () => {
window.fetchExample = null
}
}, [formRef, projectId])

useEffect(() => {
if (modelInfo.isSuccess || modelVersionInfo.isSuccess) {
Expand All @@ -129,14 +143,56 @@ export default function OnlineEval() {
return
}

const api = `/api/v1/project/${project?.name}/model/${modelName}/version/${versionName}/file`
fetch(
`/api/v1/project/${project?.name}/model/${modelName}/version/${versionName}/file?${qs.stringify({
`${api}?${qs.stringify({
Authorization: getToken(),
partName: 'svc.json',
signature: '',
})}`
)
.then((res) => res.json())
.then((conf) => {
// patch examples params
const gradioConfig = conf as IGradioConfig
const datasets: IComponent[] = []
gradioConfig.components.forEach((compnent) => {
if (compnent.type === 'dataset') {
datasets.push(compnent)
let append = false
for (let i = 0; i < compnent.props?.components?.length; i++) {
const tp = compnent.props?.components[i]
if (tp && ['image', 'video', 'file'].includes(tp)) {
append = true
break
}
}
if (append && compnent.props?.samples) {
compnent.props.samples = compnent.props.samples?.map((parts) => {
return parts.map(
(i) => `${api}?${qs.stringify({ Authorization: getToken(), path: i })}`
)
})
}
}
})
datasets.forEach((ds) => {
gradioConfig.dependencies.forEach((dep) => {
if (!dep.targets.includes(ds.id)) {
return
}
// do not request the builtin backend, the model service is not ready
dep.backend_fn = false
const fn = `async function (...x) {
console.log(x)
return window.fetchExample('${ds.props?.samples?.[0]?.[0]}')
}`
dep.js = fn
})
})

return conf
})
.then((data) => {
window.gradio_config = data
window.gradio_config.css = css
Expand Down

0 comments on commit 9443725

Please sign in to comment.