From 30cc913562f868842e6cbc3dc2303c5f7b601856 Mon Sep 17 00:00:00 2001 From: Jialei Date: Tue, 19 Dec 2023 14:44:41 +0800 Subject: [PATCH] refactor(client): use auto generated arg_types --- .../starwhale/api/_impl/service/types/llm.py | 17 +---- .../api/_impl/service/types/text_to_img.py | 21 ++----- .../api/_impl/service/types/types.py | 62 ++++++++++++++++++- client/tests/sdk/test_types.py | 34 ++++++++++ 4 files changed, 104 insertions(+), 30 deletions(-) diff --git a/client/starwhale/api/_impl/service/types/llm.py b/client/starwhale/api/_impl/service/types/llm.py index be60036a1a..c6e4e5153f 100644 --- a/client/starwhale/api/_impl/service/types/llm.py +++ b/client/starwhale/api/_impl/service/types/llm.py @@ -1,17 +1,16 @@ from __future__ import annotations import inspect -from typing import Any, Dict, List, Callable, Optional +from typing import Any, List, Callable, Optional from pydantic import BaseModel from starwhale.base.client.models.models import ( ComponentValueSpecInt, - ComponentSpecValueType, ComponentValueSpecFloat, ) -from .types import ServiceType +from .types import ServiceType, generate_type_definition class MessageItem(BaseModel): @@ -31,17 +30,7 @@ class Query(BaseModel): class LLMChat(ServiceType): name = "llm_chat" args = {} - - # TODO use pydantic model annotations generated arg_types - arg_types: Dict[str, ComponentSpecValueType] = { - "user_input": ComponentSpecValueType.string, - "history": ComponentSpecValueType.list, # list of Message - "top_k": ComponentSpecValueType.int, - "top_p": ComponentSpecValueType.float, - "temperature": ComponentSpecValueType.float, - "max_new_tokens": ComponentSpecValueType.int, - } - + arg_types = generate_type_definition(Query) Message = MessageItem def __init__( diff --git a/client/starwhale/api/_impl/service/types/text_to_img.py b/client/starwhale/api/_impl/service/types/text_to_img.py index 4bd1a79d1c..935de69b6a 100644 --- a/client/starwhale/api/_impl/service/types/text_to_img.py +++ b/client/starwhale/api/_impl/service/types/text_to_img.py @@ -1,16 +1,18 @@ from __future__ import annotations import inspect -from typing import Any, Dict, Callable, Optional +from typing import Any, Callable, Optional from pydantic import BaseModel from starwhale.base.client.models.models import ( ComponentValueSpecInt, - ComponentSpecValueType, ComponentValueSpecFloat, ) -from starwhale.api._impl.service.types.types import ServiceType +from starwhale.api._impl.service.types.types import ( + ServiceType, + generate_type_definition, +) class Query(BaseModel): @@ -28,18 +30,7 @@ class Query(BaseModel): class TextToImage(ServiceType): name = "text_to_image" args = {} - - arg_types: Dict[str, ComponentSpecValueType] = { - "prompt": ComponentSpecValueType.string, - "negative_prompt": ComponentSpecValueType.string, - "sampling_steps": ComponentSpecValueType.int, - "width": ComponentSpecValueType.int, - "height": ComponentSpecValueType.int, - "seed": ComponentSpecValueType.int, - "batch_size": ComponentSpecValueType.int, - "batch_count": ComponentSpecValueType.int, - "guidance_scale": ComponentSpecValueType.float, - } + arg_types = generate_type_definition(Query) def __init__( self, diff --git a/client/starwhale/api/_impl/service/types/types.py b/client/starwhale/api/_impl/service/types/types.py index a030a4bf4b..fa6c53a9dc 100644 --- a/client/starwhale/api/_impl/service/types/types.py +++ b/client/starwhale/api/_impl/service/types/types.py @@ -2,8 +2,9 @@ import abc import inspect -from typing import Any, Dict, List, Union, Callable +from typing import Any, Dict, List, Type, Union, Callable +from pydantic import BaseModel from typing_extensions import Protocol from starwhale.utils import console @@ -164,3 +165,62 @@ def all_components_are_gradio( all([isinstance(out, gradio.components.Component) for out in outputs]), ] ) + + +def generate_type_definition( + model: Type[BaseModel], +) -> Dict[str, ComponentSpecValueType]: + """ + Generate the type definition for a given model. + + Args: + model (Type[BaseModel]): The model for which to generate the type definition. + + Returns: + Dict[str, ComponentSpecValueType]: The generated type definition. + + Raises: + ValueError: If the field type is not supported. + """ + type_definition: Dict[str, ComponentSpecValueType] = {} + # we will get multiple type field_type, e.g. 'str', typing.Optional[str], typing.List[str], + # if we use from __future__ import annotations, we will get 'str' + # else we will get typing.Optional[str] or + for field_name, field_type in model.__annotations__.items(): + field_type = str(field_type) + + while True: + # remove the typing. prefix + if field_type.startswith("typing."): + field_type = field_type[7:] + continue + # remove the typing.Optional[] wrapper + if field_type.startswith("Optional["): + field_type = field_type[9:-1] + continue + # remove the typing[] wrapper + if field_type.startswith("typing["): + field_type = field_type[7:-1] + continue + # remove the wrapper + if field_type.startswith("