Skip to content

Commit

Permalink
refactor(client): use auto generated arg_types
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui committed Dec 19, 2023
1 parent b19f8b2 commit 30cc913
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 30 deletions.
17 changes: 3 additions & 14 deletions client/starwhale/api/_impl/service/types/llm.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import annotations

import inspect
from typing import Any, Dict, List, Callable, Optional
from typing import Any, List, Callable, Optional

from pydantic import BaseModel

from starwhale.base.client.models.models import (
ComponentValueSpecInt,
ComponentSpecValueType,
ComponentValueSpecFloat,
)

from .types import ServiceType
from .types import ServiceType, generate_type_definition


class MessageItem(BaseModel):
Expand All @@ -31,17 +30,7 @@ class Query(BaseModel):
class LLMChat(ServiceType):
name = "llm_chat"
args = {}

# TODO use pydantic model annotations generated arg_types
arg_types: Dict[str, ComponentSpecValueType] = {
"user_input": ComponentSpecValueType.string,
"history": ComponentSpecValueType.list, # list of Message
"top_k": ComponentSpecValueType.int,
"top_p": ComponentSpecValueType.float,
"temperature": ComponentSpecValueType.float,
"max_new_tokens": ComponentSpecValueType.int,
}

arg_types = generate_type_definition(Query)
Message = MessageItem

def __init__(
Expand Down
21 changes: 6 additions & 15 deletions client/starwhale/api/_impl/service/types/text_to_img.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

import inspect
from typing import Any, Dict, Callable, Optional
from typing import Any, Callable, Optional

from pydantic import BaseModel

from starwhale.base.client.models.models import (
ComponentValueSpecInt,
ComponentSpecValueType,
ComponentValueSpecFloat,
)
from starwhale.api._impl.service.types.types import ServiceType
from starwhale.api._impl.service.types.types import (
ServiceType,
generate_type_definition,
)


class Query(BaseModel):
Expand All @@ -28,18 +30,7 @@ class Query(BaseModel):
class TextToImage(ServiceType):
name = "text_to_image"
args = {}

arg_types: Dict[str, ComponentSpecValueType] = {
"prompt": ComponentSpecValueType.string,
"negative_prompt": ComponentSpecValueType.string,
"sampling_steps": ComponentSpecValueType.int,
"width": ComponentSpecValueType.int,
"height": ComponentSpecValueType.int,
"seed": ComponentSpecValueType.int,
"batch_size": ComponentSpecValueType.int,
"batch_count": ComponentSpecValueType.int,
"guidance_scale": ComponentSpecValueType.float,
}
arg_types = generate_type_definition(Query)

def __init__(
self,
Expand Down
62 changes: 61 additions & 1 deletion client/starwhale/api/_impl/service/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import abc
import inspect
from typing import Any, Dict, List, Union, Callable
from typing import Any, Dict, List, Type, Union, Callable

from pydantic import BaseModel
from typing_extensions import Protocol

from starwhale.utils import console
Expand Down Expand Up @@ -164,3 +165,62 @@ def all_components_are_gradio(
all([isinstance(out, gradio.components.Component) for out in outputs]),
]
)


def generate_type_definition(
model: Type[BaseModel],
) -> Dict[str, ComponentSpecValueType]:
"""
Generate the type definition for a given model.
Args:
model (Type[BaseModel]): The model for which to generate the type definition.
Returns:
Dict[str, ComponentSpecValueType]: The generated type definition.
Raises:
ValueError: If the field type is not supported.
"""
type_definition: Dict[str, ComponentSpecValueType] = {}
# we will get multiple type field_type, e.g. 'str', typing.Optional[str], typing.List[str], <class 'str'>
# if we use from __future__ import annotations, we will get 'str'
# else we will get typing.Optional[str] or <class 'str'>
for field_name, field_type in model.__annotations__.items():
field_type = str(field_type)

while True:
# remove the typing. prefix
if field_type.startswith("typing."):
field_type = field_type[7:]
continue
# remove the typing.Optional[] wrapper
if field_type.startswith("Optional["):
field_type = field_type[9:-1]
continue
# remove the typing[] wrapper
if field_type.startswith("typing["):
field_type = field_type[7:-1]
continue
# remove the <class ''> wrapper
if field_type.startswith("<class '"):
field_type = field_type[8:-2]
continue
break

if field_type == "str":
type_definition[field_name] = ComponentSpecValueType.string
elif field_type == "int":
type_definition[field_name] = ComponentSpecValueType.int
elif field_type == "float":
type_definition[field_name] = ComponentSpecValueType.float
elif field_type == "bool":
type_definition[field_name] = ComponentSpecValueType.bool
elif field_type.startswith("typing.List[") or field_type.startswith("List["):
type_definition[field_name] = ComponentSpecValueType.list
else:
# we do not support union type for now:
raise ValueError(
f"Unsupported field type: {field_type} with name {field_name}"
)
return type_definition
34 changes: 34 additions & 0 deletions client/tests/sdk/test_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import List, Optional

import pytest
from pydantic import BaseModel

from starwhale.api._impl.service.types import all_components_are_gradio
from starwhale.base.client.models.models import (
ComponentValueSpecInt,
ComponentSpecValueType,
)
from starwhale.api._impl.service.types.types import generate_type_definition
from starwhale.api._impl.service.types.text_to_img import TextToImage


Expand Down Expand Up @@ -55,3 +59,33 @@ def test_text_to_image():
assert t.args == {
"sampling_steps": ComponentValueSpecInt(default_val=1),
}


def test_generate_type_definition():
class MyType(BaseModel):
a: int
b: str
c: float
d: bool
e: Optional[int] = None
f: Optional[str] = None
g: Optional[float] = None
h: Optional[bool] = None
i: List[int]
j: List[str]
k: Optional[List[int]] = None

types = generate_type_definition(MyType)
assert types == {
"a": ComponentSpecValueType.int,
"b": ComponentSpecValueType.string,
"c": ComponentSpecValueType.float,
"d": ComponentSpecValueType.bool,
"e": ComponentSpecValueType.int,
"f": ComponentSpecValueType.string,
"g": ComponentSpecValueType.float,
"h": ComponentSpecValueType.bool,
"i": ComponentSpecValueType.list,
"j": ComponentSpecValueType.list,
"k": ComponentSpecValueType.list,
}

0 comments on commit 30cc913

Please sign in to comment.