refactor(client): refactor handler args with @starwhale.argument decorator #3100

Merged · 1 commit · Jan 2, 2024
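What this PR changes, in short: handler arguments stop being described with the `HandlerInput` helper classes plus inspect-based CLI signature generation inside `Handler.register`, and are instead handled by the dataclass-driven argument decorator in `client/starwhale/api/_impl/argument.py`. The snippet below is only a minimal, self-contained sketch of that pattern, not Starwhale's actual API: the decorator name `argument`, the `ScoreArgs` dataclass, and the `extra_cli_args` keyword are illustrative assumptions. It shows the core idea — describe arguments as dataclass fields, build a CLI parser from them, and inject a populated instance into the decorated handler.

```python
import argparse
import dataclasses
import functools
import typing as t


@dataclasses.dataclass
class ScoreArgs:
    # Field names and types double as CLI option names and types.
    batch_size: int = 32
    threshold: float = 0.5
    debug: bool = False


def argument(dataclass_type: t.Type, inject_name: str = "argument") -> t.Callable:
    """Toy decorator: parse extra CLI args into `dataclass_type` and inject it."""

    def _register(func: t.Callable) -> t.Callable:
        # Build a parser once, from the dataclass fields.
        parser = argparse.ArgumentParser(add_help=False)
        for f in dataclasses.fields(dataclass_type):
            if f.type is bool or f.type == "bool":
                parser.add_argument(f"--{f.name}", action="store_true", default=f.default)
            else:
                parser.add_argument(f"--{f.name}", type=f.type, default=f.default)

        @functools.wraps(func)
        def _wrapper(*args: t.Any, extra_cli_args: t.Optional[t.List[str]] = None, **kw: t.Any) -> t.Any:
            ns, _ = parser.parse_known_args(extra_cli_args or [])
            # Injected under `inject_name` ("argument" by default in this sketch).
            kw[inject_name] = dataclass_type(**vars(ns))
            return func(*args, **kw)

        return _wrapper

    return _register


@argument(ScoreArgs)
def score(argument: ScoreArgs) -> None:
    print(argument.batch_size, argument.threshold, argument.debug)


score(extra_cli_args=["--batch_size", "64", "--threshold", "0.9"])  # -> 64 0.9 False
```

In the real module the unparsed CLI remainder comes from `ExtraCliArgsRegistry` and the parser is built by `get_parser_from_dataclasses`, as the diff below shows.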
19 changes: 1 addition & 18 deletions client/starwhale/__init__.py
@@ -1,15 +1,5 @@
from starwhale.api import model, track, evaluation
from starwhale.api.job import (
Job,
Handler,
IntInput,
BoolInput,
ListInput,
FloatInput,
ContextInput,
DatasetInput,
HandlerInput,
)
from starwhale.api.job import Job, Handler
from starwhale.version import STARWHALE_VERSION as __version__
from starwhale.api.metric import multi_classification
from starwhale.api.dataset import Dataset
@@ -72,13 +62,6 @@
"Text",
"Line",
"Point",
"DatasetInput",
"HandlerInput",
"ListInput",
"BoolInput",
"IntInput",
"FloatInput",
"ContextInput",
"Polygon",
"Audio",
"Video",
12 changes: 10 additions & 2 deletions client/starwhale/api/_impl/argument.py
@@ -72,10 +72,17 @@ def _register_wrapper(func: t.Callable) -> t.Any:
# TODO: dump parser to json file when model building
# TODO: `@handler` decorator function supports @argument decorator
parser = get_parser_from_dataclasses(dataclass_types)
lock = threading.Lock()
parsed_cache: t.Any = None

@wraps(func)
def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
dataclass_values = init_dataclasses_values(parser, dataclass_types)
nonlocal parsed_cache, lock
with lock:
if parsed_cache is None:
parsed_cache = init_dataclasses_values(parser, dataclass_types)

dataclass_values = parsed_cache
if inject_name in kw:
raise RuntimeError(
f"{inject_name} has been used as a keyword argument in the decorated function, please use another name by the `inject_name` option."
@@ -91,7 +98,8 @@ def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
def init_dataclasses_values(
parser: click.OptionParser, dataclass_types: t.Any
) -> t.Any:
args_map, _, params = parser.parse_args(ExtraCliArgsRegistry.get())
# avoid modifying the ExtraCliArgsRegistry args values in place
args_map, _, params = parser.parse_args(ExtraCliArgsRegistry.get().copy())
param_map = {p.name: p for p in params}

ret = []
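The change above makes the decorator parse the extra CLI args at most once: the first call to `_run_wrapper` fills `parsed_cache` under a `threading.Lock`, later calls reuse it, and `parse_args` now receives a copy of the registry list so the shared state is not drained by parsing. A self-contained sketch of that parse-once pattern, with illustrative names rather than Starwhale's code:

```python
import threading
import typing as t


def parse_once(parse: t.Callable[[], t.Any]) -> t.Callable[[], t.Any]:
    """Call `parse` at most once across threads and cache its result."""
    lock = threading.Lock()
    cache: t.Any = None

    def _get() -> t.Any:
        nonlocal cache
        with lock:
            # Same check the PR uses: the first caller fills the cache,
            # every later (or concurrent) call reuses it.
            if cache is None:
                cache = parse()
        return cache

    return _get


calls = 0


def expensive_parse() -> list:
    global calls
    calls += 1
    return ["--learning-rate", "0.01"]


get_args = parse_once(expensive_parse)
threads = [threading.Thread(target=get_args) for _ in range(8)]
for th in threads:
    th.start()
for th in threads:
    th.join()

assert calls == 1  # parsed exactly once despite 8 concurrent callers
print(get_args())
```

Checking `cache is None` mirrors the PR; a parser that returned None would be re-parsed, which appears harmless here since `init_dataclasses_values` builds and returns a list (`ret = []` above).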
2 changes: 0 additions & 2 deletions client/starwhale/api/_impl/evaluation/__init__.py
@@ -108,7 +98,6 @@ def _register_predict(
predict_log_dataset_features=log_dataset_features,
dataset_uris=datasets,
),
built_in=True,
)(func)


@@ -172,7 +171,6 @@ def _register_evaluate(
extra_kwargs=dict(
predict_auto_log=use_predict_auto_log,
),
built_in=True,
)(func)


1 change: 0 additions & 1 deletion client/starwhale/api/_impl/experiment.py
@@ -138,7 +138,6 @@ def _register_ft(
extra_kwargs=dict(
auto_build_model=auto_build_model,
),
built_in=True,
fine_tune=FineTune(
require_train_datasets=require_train_datasets,
require_validation_datasets=require_validation_datasets,
157 changes: 3 additions & 154 deletions client/starwhale/api/_impl/job/handler.py
@@ -4,7 +4,6 @@
import typing as t
import inspect
import threading
from abc import ABC, abstractmethod
from pathlib import Path
from functools import partial
from collections import defaultdict
@@ -18,11 +17,7 @@
from starwhale.utils.error import NoSupportError
from starwhale.base.models.model import StepSpecClient
from starwhale.api._impl.evaluation import PipelineHandler
from starwhale.base.client.models.models import (
FineTune,
RuntimeResource,
ParameterSignature,
)
from starwhale.base.client.models.models import FineTune, RuntimeResource


class Handler:
@@ -103,7 +98,6 @@ def register(
name: str = "",
expose: int = 0,
require_dataset: bool = False,
built_in: bool = False,
fine_tune: FineTune | None = None,
) -> t.Callable:
"""Register a function as a handler. Enable the function execute by needs handler, run with gpu/cpu/mem resources in server side,
Expand All @@ -122,8 +116,6 @@ def register(
require_dataset: [bool, optional] Whether datasets are required when executing the handler.
Default is False, which means there is no need to select datasets when executing this handler on the server or cloud instance.
If True, you must select datasets when executing on the server or cloud instance.
built_in: [bool, optional] A special flag to distinguish user defined args in handler function from the StarWhale ones.
This should always be False unless you know what it does.
fine_tune: [FineTune, optional] The fine-tune config for the handler. Default is None.

Example:
@@ -166,26 +158,7 @@ def decorator(func: t.Callable) -> t.Callable:

key_name_needs.append(f"{n.__module__}:{n.__qualname__}")

ext_cmd_args = ""
parameters_sig = []
# user defined handlers i.e. not predict/evaluate/fine_tune
if not built_in:
sig = inspect.signature(func)
parameters_sig = [
ParameterSignature(
name=p_name,
required=_p.default is inspect._empty
or (
isinstance(_p.default, HandlerInput) and _p.default.required
),
multiple=isinstance(_p.default, ListInput),
)
for idx, (p_name, _p) in enumerate(sig.parameters.items())
if idx != 0 or not cls_name
]
ext_cmd_args = " ".join(
[f"--{p.name}" for p in parameters_sig if p.required]
)
# TODO: check arguments, then dump for Starwhale Console
_handler = StepSpecClient(
name=key_name,
show_name=name or func_name,
@@ -199,76 +172,12 @@ def decorator(func: t.Callable) -> t.Callable:
extra_kwargs=extra_kwargs,
expose=expose,
require_dataset=require_dataset,
parameters_sig=parameters_sig,
ext_cmd_args=ext_cmd_args,
fine_tune=fine_tune,
)

cls._register(_handler, func)
setattr(func, DecoratorInjectAttr.Step, True)
import functools

if built_in:
return func
else:

@functools.wraps(func)
def wrapper(*args: t.Any, **kwargs: t.Any) -> None:
if "handler_args" in kwargs:
import click
from click.parser import OptionParser

handler_args: t.List[str] = kwargs.pop("handler_args")

parser = OptionParser()
sig = inspect.signature(func)
for idx, (p_name, _p) in enumerate(sig.parameters.items()):
if (
idx == 0
and args
and callable(getattr(args[0], func.__name__))
): # if the first argument has a function with the same name it is considered as self
continue
required = _p.default is inspect._empty or (
isinstance(_p.default, HandlerInput)
and _p.default.required
)
click.Option(
[f"--{p_name}", f"-{p_name}"],
is_flag=False,
multiple=isinstance(_p.default, ListInput),
required=required,
).add_to_parser(
parser, None # type:ignore
)
hargs, _, _ = parser.parse_args(handler_args)

for idx, (p_name, _p) in enumerate(sig.parameters.items()):
if (
idx == 0
and args
and callable(getattr(args[0], func.__name__))
):
continue
parsed_args = {
p_name: fetch_real_args(
(p_name, _p), hargs.get(p_name, None)
)
}
kwargs.update(
{k: v for k, v in parsed_args.items() if v is not None}
)
func(*args, **kwargs)

def fetch_real_args(
parameter: t.Tuple[str, inspect.Parameter], user_input: t.Any
) -> t.Any:
if isinstance(parameter[1].default, HandlerInput):
return parameter[1].default.parse(user_input)
else:
return user_input

return wrapper
return func

return decorator

@@ -346,14 +255,12 @@ def _preload_registering_handlers(
Handler.register,
name="predict",
require_dataset=True,
built_in=True,
)

evaluate_register = partial(
Handler.register,
needs=[predict_func],
name="evaluate",
built_in=True,
replicas=1,
)

@@ -415,61 +322,3 @@ def generate_jobs_yaml(
),
parents=True,
)


class HandlerInput(ABC):
def __init__(self, required: bool = False) -> None:
self.required = required

@abstractmethod
def parse(self, user_input: t.Any) -> t.Any:
raise NotImplementedError


class ListInput(HandlerInput):
def __init__(
self, member_type: t.Any | None = None, required: bool = False
) -> None:
super().__init__(required)
self.member_type = member_type or None

def parse(self, user_input: t.List) -> t.Any:
if not user_input:
return user_input
if isinstance(self.member_type, HandlerInput):
return [self.member_type.parse(item) for item in user_input]
elif inspect.isclass(self.member_type) and issubclass(
self.member_type, HandlerInput
):
return [self.member_type().parse(item) for item in user_input]
else:
return user_input


class DatasetInput(HandlerInput):
def parse(self, user_input: str) -> t.Any:
from starwhale import dataset

return dataset(user_input) if user_input else None


class BoolInput(HandlerInput):
def parse(self, user_input: str) -> t.Any:
return "false" != str(user_input).lower()


class IntInput(HandlerInput):
def parse(self, user_input: str) -> t.Any:
return int(user_input)


class FloatInput(HandlerInput):
def parse(self, user_input: str) -> t.Any:
return float(user_input)


class ContextInput(HandlerInput):
def parse(self, user_input: str) -> t.Any:
from starwhale import Context

return Context.get_runtime_context()
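Two pieces leave `handler.py` together: the inspect-based generation of `parameters_sig`/`ext_cmd_args` inside `Handler.register` (and with it the `built_in` flag that exempted the built-in predict/evaluate/fine-tune handlers from it), and the `HandlerInput` family (`DatasetInput`, `ListInput`, `BoolInput`, ...) that parsed raw CLI strings according to a parameter's default value. Both responsibilities move to the dataclass-based argument decorator. For reference, here is a small, self-contained sketch of the deleted technique — deriving CLI option metadata from a function signature with `inspect.signature`. It is not Starwhale's code: the "skip self" rule is simplified, and a list/tuple default stands in for the removed `ListInput`.

```python
import inspect
import typing as t


def cli_signature(func: t.Callable, skip_self: bool = True) -> t.Tuple[t.List[dict], str]:
    """Derive per-parameter CLI metadata from a function signature.

    Mirrors the deleted logic in Handler.register: a parameter without a
    default is required, required parameters become mandatory --flags, and a
    list/tuple default marks a repeatable option.
    """
    params = []
    for idx, (name, p) in enumerate(inspect.signature(func).parameters.items()):
        if skip_self and idx == 0 and name == "self":
            continue  # the original skipped the first parameter of bound handler methods
        params.append(
            {
                "name": name,
                "required": p.default is inspect.Parameter.empty,
                "multiple": isinstance(p.default, (list, tuple)),
            }
        )
    ext_cmd_args = " ".join(f"--{p['name']}" for p in params if p["required"])
    return params, ext_cmd_args


def train(dataset: str, epochs: int = 1, labels: tuple = ()) -> None:
    ...


print(cli_signature(train))
# ([{'name': 'dataset', 'required': True, 'multiple': False},
#   {'name': 'epochs', 'required': False, 'multiple': False},
#   {'name': 'labels', 'required': False, 'multiple': True}],
#  '--dataset')
```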
21 changes: 1 addition & 20 deletions client/starwhale/api/job.py
@@ -1,22 +1,3 @@
from ._impl.job import Job, Handler
from ._impl.job.handler import (
IntInput,
BoolInput,
ListInput,
FloatInput,
ContextInput,
DatasetInput,
HandlerInput,
)

__all__ = [
"Handler",
"Job",
"DatasetInput",
"HandlerInput",
"ListInput",
"BoolInput",
"IntInput",
"FloatInput",
"ContextInput",
]
__all__ = ["Handler", "Job"]
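For callers, the visible effect of this file together with the `__init__.py` change at the top is a narrower public surface: only `Job` and `Handler` remain exported, and the `*Input` helpers are gone from both `starwhale` and `starwhale.api.job`. A quick before/after of the import surface (assuming the post-merge layout):

```python
# Still valid after this PR:
from starwhale.api.job import Job, Handler

# Removed in this PR -- these imports now raise ImportError:
# from starwhale import IntInput, BoolInput, ListInput, FloatInput, ContextInput
# from starwhale.api.job import DatasetInput, HandlerInput
```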
3 changes: 0 additions & 3 deletions client/starwhale/base/scheduler/__init__.py
@@ -22,7 +22,6 @@ def __init__(
workdir: Path,
dataset_uris: t.List[str],
steps: t.List[Step],
handler_args: t.List[str] | None = None,
run_project: t.Optional[Project] = None,
log_project: t.Optional[Project] = None,
dataset_head: int = 0,
@@ -36,7 +35,6 @@
self.dataset_uris = dataset_uris
self.workdir = workdir
self.version = version
self.handler_args = handler_args or []
self.dataset_head = dataset_head
self.finetune_val_dataset_uris = finetune_val_dataset_uris
self.model_name = model_name
@@ -80,7 +78,6 @@ def _schedule_all(self) -> t.List[StepResult]:
dataset_uris=self.dataset_uris,
workdir=self.workdir,
version=self.version,
handler_args=self.handler_args,
dataset_head=self.dataset_head,
finetune_val_dataset_uris=self.finetune_val_dataset_uris,
model_name=self.model_name,
3 changes: 0 additions & 3 deletions client/starwhale/base/scheduler/step.py
@@ -157,7 +157,6 @@ def __init__(
workdir: Path,
dataset_uris: t.List[str],
task_num: int = 0,
handler_args: t.List[str] | None = None,
dataset_head: int = 0,
finetune_val_dataset_uris: t.List[str] | None = None,
model_name: str = "",
@@ -169,7 +168,6 @@
self.dataset_uris = dataset_uris
self.workdir = workdir
self.version = version
self.handler_args = handler_args or []
self.dataset_head = dataset_head
self.finetune_val_dataset_uris = finetune_val_dataset_uris
self.model_name = model_name
@@ -199,7 +197,6 @@ def execute(self) -> StepResult:
finetune_val_dataset_uris=self.finetune_val_dataset_uris,
model_name=self.model_name,
),
handler_args=self.handler_args,
step=self.step,
workdir=self.workdir,
)
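With `handler_args` no longer threaded through `Scheduler` and `Step`, the extra CLI arguments reach decorated handlers through the process-wide `ExtraCliArgsRegistry` that `argument.py` reads (and now copies before parsing). Below is a minimal sketch of that set-once/read-anywhere registry pattern; it is illustrative only, not the actual `ExtraCliArgsRegistry` implementation.

```python
import threading
import typing as t


class ExtraArgsRegistry:
    """Process-wide holder for leftover CLI args (illustrative stand-in)."""

    _lock = threading.Lock()
    _args: t.List[str] = []

    @classmethod
    def set(cls, args: t.List[str]) -> None:
        with cls._lock:
            cls._args = list(args)

    @classmethod
    def get(cls) -> t.List[str]:
        with cls._lock:
            # Return a copy so callers (e.g. a click parser that pops values
            # while parsing) cannot drain the shared list -- the same reason
            # the PR adds `.copy()` before `parser.parse_args(...)`.
            return list(cls._args)


# The CLI entry point stores the unparsed remainder once...
ExtraArgsRegistry.set(["--learning-rate", "0.01", "--debug"])

# ...and any decorated handler, however deep in the call graph, can read it
# without the scheduler passing `handler_args` through every constructor.
print(ExtraArgsRegistry.get())
```

The trade-off of a registry over constructor plumbing is fewer pass-through parameters at the cost of an implicit global, which is why handing out copies from `get()` matters.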