Skip to content

Commit

Permalink
enhance(client): embed model serving spec into job spec (#3032)
Browse files Browse the repository at this point in the history
  • Loading branch information
jialeicui authored Nov 27, 2023
1 parent 39bc0bc commit 94fe64d
Show file tree
Hide file tree
Showing 13 changed files with 434 additions and 275 deletions.
2 changes: 1 addition & 1 deletion client/starwhale/api/_impl/job/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def generate_jobs_yaml(
yaml_path,
yaml.safe_dump(
{
name: [h.dict() for h in handlers]
name: [h.to_dict() for h in handlers]
for name, handlers in expanded_handlers.items()
},
default_flow_style=False,
Expand Down
17 changes: 2 additions & 15 deletions client/starwhale/api/_impl/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typing as t
import functools

from .types.types import ComponentSpec
from starwhale.base.client.models.models import ApiSpec, ServiceSpec

if sys.version_info >= (3, 9):
from importlib.resources import files
Expand All @@ -27,19 +27,6 @@
)


class ApiSpec(SwBaseModel):
uri: str
inference_type: str
components_hint: t.List[ComponentSpec] = Field(default_factory=list)


class ServiceSpec(SwBaseModel):
title: t.Optional[str] = None
description: t.Optional[str] = None
version: str
apis: t.List[ApiSpec]


class Api(SwBaseModel):
func: t.Callable
uri: str
Expand Down Expand Up @@ -68,7 +55,7 @@ def to_spec(self) -> ApiSpec | None:
return ApiSpec(
uri=self.uri,
inference_type=self.inference_type.name,
components_hint=self.inference_type.components_spec(),
components=self.inference_type.components_spec(),
)


Expand Down
20 changes: 11 additions & 9 deletions client/starwhale/api/_impl/service/types/llm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import inspect
from typing import Any, Set, List, Callable, Optional
from typing import Any, Set, Dict, List, Callable, Optional

from pydantic import BaseModel
from pydantic.dataclasses import dataclass

from starwhale.base.client.models.models import ComponentSpecValueType

from .types import ServiceType, ComponentSpec


Expand All @@ -28,13 +30,13 @@ class LLMChat(ServiceType):
name = "llm_chat"

# TODO use pydantic model annotations generated arg_types
arg_types = {
"user_input": str,
"history": list, # list of Message
"top_k": int,
"top_p": float,
"temperature": float,
"max_new_tokens": int,
arg_types: Dict[str, ComponentSpecValueType] = {
"user_input": ComponentSpecValueType.string,
"history": ComponentSpecValueType.list, # list of Message
"top_k": ComponentSpecValueType.int,
"top_p": ComponentSpecValueType.float,
"temperature": ComponentSpecValueType.float,
"max_new_tokens": ComponentSpecValueType.int,
}

def __init__(self, args: Set | None = None) -> None:
Expand All @@ -50,7 +52,7 @@ def __init__(self, args: Set | None = None) -> None:

def components_spec(self) -> List[ComponentSpec]:
return [
ComponentSpec(name=arg, type=self.arg_types[arg].__name__)
ComponentSpec(name=arg, component_spec_value_type=self.arg_types[arg])
for arg in self.args
]

Expand Down
12 changes: 2 additions & 10 deletions client/starwhale/api/_impl/service/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,18 @@
from typing import Any, Dict, List, Callable

from starwhale.utils import console
from starwhale.base.models.base import SwBaseModel
from starwhale.base.client.models.models import ComponentSpec, ComponentSpecValueType

Inputs = Any
Outputs = Any


class ComponentSpec(SwBaseModel):
name: str
type: str

def __hash__(self) -> int:
return hash((self.name, self.type))


class ServiceType(abc.ABC):
"""Protocol for service types."""

@property
@abc.abstractmethod
def arg_types(self) -> Dict[str, Any]:
def arg_types(self) -> Dict[str, ComponentSpecValueType]:
...

@property
Expand Down
Loading

0 comments on commit 94fe64d

Please sign in to comment.