Skip to content

Commit

Permalink
enhance(client): support gradio components in api decorator (#2991)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Nov 20, 2023
1 parent 07e6c5b commit 4d2db61
Show file tree
Hide file tree
Showing 17 changed files with 190 additions and 38 deletions.
2 changes: 2 additions & 0 deletions client/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"filelock",
"fastapi",
"orjson", # for web server, e.g. sw job info --web
"starlette",
"protobuf>=3.19.0",
"types-protobuf>=3.19.0",
"lz4>=3.1.10",
Expand All @@ -45,6 +46,7 @@
"urllib3<1.27",
"pydantic<2.0.0", # current broker and fastapi lib code only work with pydantic < 2.0.0
"sortedcontainers",
"importlib_resources",
# workaround: email-validator 2.1.0 has a syntax error in python 3.7, but the email-validator is necessary for fastapi.
"email-validator <= 2.0.0; python_version < '3.8'",
]
Expand Down
118 changes: 97 additions & 21 deletions client/starwhale/api/_impl/service/service.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
from __future__ import annotations

import os
import sys
import typing as t
import functools

import pkg_resources
if sys.version_info >= (3, 9):
from importlib.resources import files
else:
from importlib_resources import files

from fastapi import FastAPI
from pydantic import BaseModel
from pydantic import Field, BaseModel
from starlette.responses import FileResponse
from starlette.staticfiles import StaticFiles

from starwhale.utils import console
from starwhale.base.models.base import SwBaseModel

from .types import ServiceType
from .types import Inputs, Outputs, ServiceType, all_components_are_gradio

STATIC_DIR_DEV = os.getenv("SW_SERVE_STATIC_DIR") or pkg_resources.resource_filename(
"starwhale", "web/ui"
STATIC_DIR_DEV = os.getenv("SW_SERVE_STATIC_DIR") or str(
files("starwhale").joinpath("web/ui")
)


Expand All @@ -26,7 +32,10 @@ class Query(BaseModel):
class Api(SwBaseModel):
func: t.Callable
uri: str
inference_type: ServiceType
inference_type: t.Optional[ServiceType] = None
# do not export inputs and outputs to spec for now
inputs: Inputs = Field(exclude=True)
outputs: Outputs = Field(exclude=True)

@staticmethod
def question_answering(func: t.Callable) -> t.Callable:
Expand All @@ -39,8 +48,15 @@ def view_func(self, ins: t.Any = None) -> t.Callable:
func = self.func
if ins is not None:
func = functools.partial(func, ins)
if self.inference_type is None:
return func
return getattr(self, self.inference_type.value)(func) # type: ignore

def all_gradio_components(self) -> bool:
if self.inference_type is not None:
return False
return all_components_are_gradio(inputs=self.inputs, outputs=self.outputs)


class ServiceSpec(SwBaseModel):
title: t.Optional[str]
Expand All @@ -54,35 +70,67 @@ def __init__(self) -> None:
self.apis: t.Dict[str, Api] = {}
self.api_within_instance_map: t.Dict[str, t.Any] = {}

def api(self, inference_type: ServiceType) -> t.Any:
def api(
self,
inference_type: ServiceType | None = None,
inputs: Inputs = None,
outputs: Outputs = None,
) -> t.Any:
def decorator(func: t.Any) -> t.Any:
self.add_api(func, func.__name__, inference_type=inference_type)
self.add_api(
func,
func.__name__,
inference_type=inference_type,
inputs=inputs,
outputs=outputs,
)
return func

return decorator

def get_spec(self) -> ServiceSpec:
return ServiceSpec(version="0.0.1", apis=list(self.apis.values()))

def add_api(self, func: t.Callable, uri: str, inference_type: ServiceType) -> None:
def add_api(
self,
func: t.Callable,
uri: str,
inference_type: ServiceType | None = None,
inputs: Inputs = None,
outputs: Outputs = None,
) -> None:
console.debug(f"add api {uri}")
if uri in self.apis:
raise ValueError(f"Duplicate api uri: {uri}")

_api = Api(func=func, uri=uri, inference_type=inference_type)
_api = Api(
func=func,
uri=f"{uri}",
inference_type=inference_type,
inputs=inputs,
outputs=outputs,
)
self.apis[uri] = _api

def add_api_instance(self, _api: Api) -> None:
self.apis[_api.uri] = _api

def serve(self, addr: str, port: int, title: t.Optional[str] = None) -> None:
"""
Default serve implementation, users can override this method
:param addr
:param port
:param title webpage title
:return: None
"""
app = FastAPI(title=title or "Starwhale Model Serving")
def serve(
self, addr: str, port: int, title: t.Optional[str] = None
) -> None: # pragma: no cover
title = title or "Starwhale Model Serving"
# check if all the api uses gradio components
# if so, use gradio to serve
# otherwise, use fastapi to serve
if all([_api.all_gradio_components() for _api in self.apis.values()]):
self._serve_gradio(addr, port, title=title)
else:
self._serve_builtin(addr, port, title=title)

def _serve_builtin(
self, addr: str, port: int, title: str
) -> None: # pragma: no cover
app = FastAPI(title=title)

@app.get("/api/spec")
def spec() -> ServiceSpec:
Expand All @@ -105,12 +153,40 @@ def index(opt: t.Any) -> FileResponse:

uvicorn.run(app, host=addr, port=port)

def _serve_gradio(
self, addr: str, port: int, title: str
) -> None: # pragma: no cover
import gradio

def api_to_component(_api: Api) -> gradio.Interface:
return gradio.Interface(
fn=_api.view_func(self.api_within_instance_map.get(_api.uri)),
inputs=_api.inputs,
outputs=_api.outputs,
title=_api.uri,
)

with gradio.blocks.Blocks(title=title) as app:
if len(self.apis) == 1:
# if only one api, use the main page
api_to_component(list(self.apis.values())[0])
else:
# one tab for each api
for _api in self.apis.values():
with gradio.Tab(label=_api.uri):
api_to_component(_api)
app.launch(server_name=addr, server_port=port)


_svc = Service()


def api(inference_type: ServiceType) -> t.Any:
return _svc.api(inference_type=inference_type)
def api(
inputs: Inputs = None,
outputs: Outputs = None,
inference_type: ServiceType | None = None,
) -> t.Any:
return _svc.api(inference_type=inference_type, inputs=inputs, outputs=outputs)


def internal_api_list() -> t.Dict[str, Api]:
Expand Down
31 changes: 31 additions & 0 deletions client/starwhale/api/_impl/service/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum
from typing import Any


class ServiceType(Enum):
Expand All @@ -9,3 +10,33 @@ class ServiceType(Enum):
TEXT_TO_AUDIO = "text_to_audio"
TEXT_TO_VIDEO = "text_to_video"
QUESTION_ANSWERING = "question_answering"


Inputs = Any
Outputs = Any


def all_components_are_gradio(
inputs: Inputs, outputs: Outputs
) -> bool: # pragma: no cover
"""Check if all components are Gradio components."""
if inputs is None and outputs is None:
return False

if not isinstance(inputs, list):
inputs = inputs is not None and [inputs] or []
if not isinstance(outputs, list):
outputs = outputs is not None and [outputs] or []

try:
import gradio
except ImportError:
gradio = None

return all(
[
gradio is not None,
all([isinstance(inp, gradio.components.Component) for inp in inputs]),
all([isinstance(out, gradio.components.Component) for out in outputs]),
]
)
4 changes: 1 addition & 3 deletions client/starwhale/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from urllib.parse import urlparse

import httpx
import pkg_resources
from fastapi import FastAPI, APIRouter
from fastapi.responses import ORJSONResponse
from typing_extensions import Protocol
Expand All @@ -15,8 +14,7 @@

from starwhale.web import user, panel, system, project, data_store
from starwhale.base.uri.instance import Instance

STATIC_DIR_DEV = pkg_resources.resource_filename("starwhale", "web/ui")
from starwhale.api._impl.service.service import STATIC_DIR_DEV


class Component(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion client/tests/data/sdk/service/default_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def ppl(self, data: bytes, **kw: t.Any) -> t.Any:
def handler_foo(self, data: t.Any) -> t.Any:
return

@service.api(ServiceType.QUESTION_ANSWERING)
@service.api(inference_type=ServiceType.QUESTION_ANSWERING)
def cmp(self, ppl_result: t.Iterator) -> t.Any:
pass
26 changes: 26 additions & 0 deletions client/tests/sdk/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest

from starwhale.api._impl.service.types import all_components_are_gradio


@pytest.mark.skip("enable this test when starwhale support pydantic 2.0+")
def test_all_components_are_gradio():
try:
import gradio
except ImportError:
gradio = None

assert all_components_are_gradio(None, None) is False

assert all_components_are_gradio(gradio.inputs.Textbox(), None) is True
assert all_components_are_gradio([gradio.inputs.Textbox()], None) is True
assert all_components_are_gradio(None, [gradio.outputs.Label()]) is True
assert all_components_are_gradio(None, gradio.outputs.Label()) is True
assert (
all_components_are_gradio(gradio.inputs.Textbox(), gradio.outputs.Label())
is True
)
assert (
all_components_are_gradio([gradio.inputs.Textbox()], [gradio.outputs.Label()])
is True
)
2 changes: 0 additions & 2 deletions example/PennFudanPed/pfp/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io
import os
import random

import torch
Expand Down Expand Up @@ -102,7 +101,6 @@ def cmp(ppl_result):
@api(
gradio.Image(type="filepath"),
[gradio.Image(type="pil"), gradio.Json()],
examples=[[os.path.join(os.path.dirname(__file__), "../FudanPed00001.png")]],
)
def handler(file: str):
with open(file, "rb") as f:
Expand Down
1 change: 0 additions & 1 deletion example/image-classification/models/cnn/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def _load_model(self, device):
@api(
gradio.Image(type="pil"),
gradio.Label(label="prediction"),
examples=[ROOTDIR / "kitty.jpeg"],
)
def online_eval(self, img: PILImage.Image):
buf = io.BytesIO()
Expand Down
1 change: 0 additions & 1 deletion example/image-classification/models/vit/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def predict_image(data: t.Dict) -> t.Dict:
@api(
gradio.Image(type="filepath"),
gradio.Json(title="Prediction"),
examples=[ROOT_DIR / "kitty.jpeg"],
)
def predict_image_view(file: t.Any) -> t.Any:
with open(file, "rb") as f:
Expand Down
1 change: 0 additions & 1 deletion example/image-segmentation/models/segment-anything/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def generate_mask(img: Image) -> t.List[t.Dict]:
@api(
gradio.Image(type="filepath", label="Input Image"),
[gradio.Image(type="numpy", label="Masked Image"), gradio.JSON(label="Masks Info")],
uri="/mask",
)
def generate_mask_view(file: t.Any) -> t.Any:
with open(file, "rb") as f:
Expand Down
6 changes: 6 additions & 0 deletions example/mnist/mnist/custom_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dill
import numpy as np
import torch
import gradio
from PIL import Image as PILImage
from torchvision import transforms

Expand All @@ -18,6 +19,7 @@
pass_context,
multi_classification,
)
from starwhale.api.service import api
from starwhale.base.uri.resource import Resource, ResourceType

from .model import Net
Expand Down Expand Up @@ -112,6 +114,9 @@ def batch_ppl(
output = self.model(data_tensor)
return output.argmax(1).flatten().tolist(), np.exp(output.tolist()).tolist()

@api(
inputs=gradio.Sketchpad(shape=(28, 28), image_mode="L"), outputs=gradio.Label()
)
def draw(self, data: np.ndarray) -> t.Any:
_image_array = PILImage.fromarray(data.astype("int8"), mode="L")
_image = transforms.Compose(
Expand All @@ -120,6 +125,7 @@ def draw(self, data: np.ndarray) -> t.Any:
output = self.model(torch.stack([_image]).to(self.device))
return {i: p for i, p in enumerate(np.exp(output.tolist()).tolist()[0])}

@api(inputs=gradio.File(), outputs=gradio.Label())
def upload_bin_file(self, file: t.Any) -> t.Any:
with open(file.name, "rb") as f:
data = Image(f.read(), shape=(28, 28, 1))
Expand Down
13 changes: 13 additions & 0 deletions example/mnist/mnist/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

import numpy as np
import torch
import gradio
from PIL import Image as PILImage
from torchvision import transforms

from starwhale import Image, PipelineHandler, multi_classification
from starwhale.api.service import api

try:
from .model import Net
Expand Down Expand Up @@ -72,3 +74,14 @@ def _load_model(self, device: torch.device) -> Net:
model.eval()
print("load mnist model, start to inference...")
return model

@api(
inputs=gradio.Sketchpad(shape=(28, 28), image_mode="L"), outputs=gradio.Label()
)
def draw(self, data: np.ndarray) -> t.Any:
_image_array = PILImage.fromarray(data.astype("int8"), mode="L")
_image = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)(_image_array)
output = self.model(torch.stack([_image]).to(self.device))
return {i: p for i, p in enumerate(np.exp(output.tolist()).tolist()[0])}
Loading

0 comments on commit 4d2db61

Please sign in to comment.