From 5b8ac743c10ea65ebc650ffae0cc8d8f870aae2d Mon Sep 17 00:00:00 2001 From: tianwei Date: Wed, 27 Dec 2023 19:23:45 +0800 Subject: [PATCH] refactor handler args with @starwhale.argument decorator --- client/starwhale/__init__.py | 19 +- client/starwhale/api/_impl/argument.py | 12 +- .../api/_impl/evaluation/__init__.py | 2 - client/starwhale/api/_impl/experiment.py | 1 - client/starwhale/api/_impl/job/handler.py | 157 +--------- client/starwhale/api/job.py | 21 +- client/starwhale/base/scheduler/__init__.py | 3 - client/starwhale/base/scheduler/step.py | 3 - client/starwhale/base/scheduler/task.py | 17 +- client/starwhale/core/model/cli.py | 1 - client/starwhale/core/model/model.py | 2 - client/starwhale/core/model/view.py | 2 - client/tests/sdk/test_job_handler.py | 278 +----------------- client/tests/sdk/test_model.py | 2 - scripts/client_test/cli_test.py | 38 +-- scripts/example/src/evaluator.py | 97 +++--- 16 files changed, 83 insertions(+), 572 deletions(-) diff --git a/client/starwhale/__init__.py b/client/starwhale/__init__.py index b20f6e07dd..89f532a374 100644 --- a/client/starwhale/__init__.py +++ b/client/starwhale/__init__.py @@ -1,15 +1,5 @@ from starwhale.api import model, track, evaluation -from starwhale.api.job import ( - Job, - Handler, - IntInput, - BoolInput, - ListInput, - FloatInput, - ContextInput, - DatasetInput, - HandlerInput, -) +from starwhale.api.job import Job, Handler from starwhale.version import STARWHALE_VERSION as __version__ from starwhale.api.metric import multi_classification from starwhale.api.dataset import Dataset @@ -72,13 +62,6 @@ "Text", "Line", "Point", - "DatasetInput", - "HandlerInput", - "ListInput", - "BoolInput", - "IntInput", - "FloatInput", - "ContextInput", "Polygon", "Audio", "Video", diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 77ee36b625..293be90a5d 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -72,10 +72,17 @@ def _register_wrapper(func: t.Callable) -> t.Any: # TODO: dump parser to json file when model building # TODO: `@handler` decorator function supports @argument decorator parser = get_parser_from_dataclasses(dataclass_types) + lock = threading.Lock() + parsed_cache: t.Any = None @wraps(func) def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: - dataclass_values = init_dataclasses_values(parser, dataclass_types) + nonlocal parsed_cache, lock + with lock: + if parsed_cache is None: + parsed_cache = init_dataclasses_values(parser, dataclass_types) + + dataclass_values = parsed_cache if inject_name in kw: raise RuntimeError( f"{inject_name} has been used as a keyword argument in the decorated function, please use another name by the `inject_name` option." @@ -91,7 +98,8 @@ def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: def init_dataclasses_values( parser: click.OptionParser, dataclass_types: t.Any ) -> t.Any: - args_map, _, params = parser.parse_args(ExtraCliArgsRegistry.get()) + # forbid to modify the ExtraCliArgsRegistry args values + args_map, _, params = parser.parse_args(ExtraCliArgsRegistry.get().copy()) param_map = {p.name: p for p in params} ret = [] diff --git a/client/starwhale/api/_impl/evaluation/__init__.py b/client/starwhale/api/_impl/evaluation/__init__.py index 552f26a51c..52e1b6e32c 100644 --- a/client/starwhale/api/_impl/evaluation/__init__.py +++ b/client/starwhale/api/_impl/evaluation/__init__.py @@ -108,7 +108,6 @@ def _register_predict( predict_log_dataset_features=log_dataset_features, dataset_uris=datasets, ), - built_in=True, )(func) @@ -172,7 +171,6 @@ def _register_evaluate( extra_kwargs=dict( predict_auto_log=use_predict_auto_log, ), - built_in=True, )(func) diff --git a/client/starwhale/api/_impl/experiment.py b/client/starwhale/api/_impl/experiment.py index 8efba4b668..f5d7fd8c7c 100644 --- a/client/starwhale/api/_impl/experiment.py +++ b/client/starwhale/api/_impl/experiment.py @@ -138,7 +138,6 @@ def _register_ft( extra_kwargs=dict( auto_build_model=auto_build_model, ), - built_in=True, fine_tune=FineTune( require_train_datasets=require_train_datasets, require_validation_datasets=require_validation_datasets, diff --git a/client/starwhale/api/_impl/job/handler.py b/client/starwhale/api/_impl/job/handler.py index 1f10d187b5..46e44f0e36 100644 --- a/client/starwhale/api/_impl/job/handler.py +++ b/client/starwhale/api/_impl/job/handler.py @@ -4,7 +4,6 @@ import typing as t import inspect import threading -from abc import ABC, abstractmethod from pathlib import Path from functools import partial from collections import defaultdict @@ -18,11 +17,7 @@ from starwhale.utils.error import NoSupportError from starwhale.base.models.model import StepSpecClient from starwhale.api._impl.evaluation import PipelineHandler -from starwhale.base.client.models.models import ( - FineTune, - RuntimeResource, - ParameterSignature, -) +from starwhale.base.client.models.models import FineTune, RuntimeResource class Handler: @@ -103,7 +98,6 @@ def register( name: str = "", expose: int = 0, require_dataset: bool = False, - built_in: bool = False, fine_tune: FineTune | None = None, ) -> t.Callable: """Register a function as a handler. Enable the function execute by needs handler, run with gpu/cpu/mem resources in server side, @@ -122,8 +116,6 @@ def register( require_dataset: [bool, optional] Whether you need datasets when execute the handler. Default is False, It means that there is no need to select datasets when executing this handler on the server or cloud instance. If True, You must select datasets when executing on the server or cloud instance. - built_in: [bool, optional] A special flag to distinguish user defined args in handler function from the StarWhale ones. - This should always be False unless you know what it does. fine_tune: [FineTune, optional The fine tune config for the handler. Default is None. Example: @@ -166,26 +158,7 @@ def decorator(func: t.Callable) -> t.Callable: key_name_needs.append(f"{n.__module__}:{n.__qualname__}") - ext_cmd_args = "" - parameters_sig = [] - # user defined handlers i.e. not predict/evaluate/fine_tune - if not built_in: - sig = inspect.signature(func) - parameters_sig = [ - ParameterSignature( - name=p_name, - required=_p.default is inspect._empty - or ( - isinstance(_p.default, HandlerInput) and _p.default.required - ), - multiple=isinstance(_p.default, ListInput), - ) - for idx, (p_name, _p) in enumerate(sig.parameters.items()) - if idx != 0 or not cls_name - ] - ext_cmd_args = " ".join( - [f"--{p.name}" for p in parameters_sig if p.required] - ) + # TODO: check arguments, then dump for Starwhale Console _handler = StepSpecClient( name=key_name, show_name=name or func_name, @@ -199,76 +172,12 @@ def decorator(func: t.Callable) -> t.Callable: extra_kwargs=extra_kwargs, expose=expose, require_dataset=require_dataset, - parameters_sig=parameters_sig, - ext_cmd_args=ext_cmd_args, fine_tune=fine_tune, ) cls._register(_handler, func) setattr(func, DecoratorInjectAttr.Step, True) - import functools - - if built_in: - return func - else: - - @functools.wraps(func) - def wrapper(*args: t.Any, **kwargs: t.Any) -> None: - if "handler_args" in kwargs: - import click - from click.parser import OptionParser - - handler_args: t.List[str] = kwargs.pop("handler_args") - - parser = OptionParser() - sig = inspect.signature(func) - for idx, (p_name, _p) in enumerate(sig.parameters.items()): - if ( - idx == 0 - and args - and callable(getattr(args[0], func.__name__)) - ): # if the first argument has a function with the same name it is considered as self - continue - required = _p.default is inspect._empty or ( - isinstance(_p.default, HandlerInput) - and _p.default.required - ) - click.Option( - [f"--{p_name}", f"-{p_name}"], - is_flag=False, - multiple=isinstance(_p.default, ListInput), - required=required, - ).add_to_parser( - parser, None # type:ignore - ) - hargs, _, _ = parser.parse_args(handler_args) - - for idx, (p_name, _p) in enumerate(sig.parameters.items()): - if ( - idx == 0 - and args - and callable(getattr(args[0], func.__name__)) - ): - continue - parsed_args = { - p_name: fetch_real_args( - (p_name, _p), hargs.get(p_name, None) - ) - } - kwargs.update( - {k: v for k, v in parsed_args.items() if v is not None} - ) - func(*args, **kwargs) - - def fetch_real_args( - parameter: t.Tuple[str, inspect.Parameter], user_input: t.Any - ) -> t.Any: - if isinstance(parameter[1].default, HandlerInput): - return parameter[1].default.parse(user_input) - else: - return user_input - - return wrapper + return func return decorator @@ -346,14 +255,12 @@ def _preload_registering_handlers( Handler.register, name="predict", require_dataset=True, - built_in=True, ) evaluate_register = partial( Handler.register, needs=[predict_func], name="evaluate", - built_in=True, replicas=1, ) @@ -415,61 +322,3 @@ def generate_jobs_yaml( ), parents=True, ) - - -class HandlerInput(ABC): - def __init__(self, required: bool = False) -> None: - self.required = required - - @abstractmethod - def parse(self, user_input: t.Any) -> t.Any: - raise NotImplementedError - - -class ListInput(HandlerInput): - def __init__( - self, member_type: t.Any | None = None, required: bool = False - ) -> None: - super().__init__(required) - self.member_type = member_type or None - - def parse(self, user_input: t.List) -> t.Any: - if not user_input: - return user_input - if isinstance(self.member_type, HandlerInput): - return [self.member_type.parse(item) for item in user_input] - elif inspect.isclass(self.member_type) and issubclass( - self.member_type, HandlerInput - ): - return [self.member_type().parse(item) for item in user_input] - else: - return user_input - - -class DatasetInput(HandlerInput): - def parse(self, user_input: str) -> t.Any: - from starwhale import dataset - - return dataset(user_input) if user_input else None - - -class BoolInput(HandlerInput): - def parse(self, user_input: str) -> t.Any: - return "false" != str(user_input).lower() - - -class IntInput(HandlerInput): - def parse(self, user_input: str) -> t.Any: - return int(user_input) - - -class FloatInput(HandlerInput): - def parse(self, user_input: str) -> t.Any: - return float(user_input) - - -class ContextInput(HandlerInput): - def parse(self, user_input: str) -> t.Any: - from starwhale import Context - - return Context.get_runtime_context() diff --git a/client/starwhale/api/job.py b/client/starwhale/api/job.py index 622446e547..cad41c3f64 100644 --- a/client/starwhale/api/job.py +++ b/client/starwhale/api/job.py @@ -1,22 +1,3 @@ from ._impl.job import Job, Handler -from ._impl.job.handler import ( - IntInput, - BoolInput, - ListInput, - FloatInput, - ContextInput, - DatasetInput, - HandlerInput, -) -__all__ = [ - "Handler", - "Job", - "DatasetInput", - "HandlerInput", - "ListInput", - "BoolInput", - "IntInput", - "FloatInput", - "ContextInput", -] +__all__ = ["Handler", "Job"] diff --git a/client/starwhale/base/scheduler/__init__.py b/client/starwhale/base/scheduler/__init__.py index 92e2f9fb97..792be108db 100644 --- a/client/starwhale/base/scheduler/__init__.py +++ b/client/starwhale/base/scheduler/__init__.py @@ -22,7 +22,6 @@ def __init__( workdir: Path, dataset_uris: t.List[str], steps: t.List[Step], - handler_args: t.List[str] | None = None, run_project: t.Optional[Project] = None, log_project: t.Optional[Project] = None, dataset_head: int = 0, @@ -36,7 +35,6 @@ def __init__( self.dataset_uris = dataset_uris self.workdir = workdir self.version = version - self.handler_args = handler_args or [] self.dataset_head = dataset_head self.finetune_val_dataset_uris = finetune_val_dataset_uris self.model_name = model_name @@ -80,7 +78,6 @@ def _schedule_all(self) -> t.List[StepResult]: dataset_uris=self.dataset_uris, workdir=self.workdir, version=self.version, - handler_args=self.handler_args, dataset_head=self.dataset_head, finetune_val_dataset_uris=self.finetune_val_dataset_uris, model_name=self.model_name, diff --git a/client/starwhale/base/scheduler/step.py b/client/starwhale/base/scheduler/step.py index d05c91298f..1cb91de635 100644 --- a/client/starwhale/base/scheduler/step.py +++ b/client/starwhale/base/scheduler/step.py @@ -157,7 +157,6 @@ def __init__( workdir: Path, dataset_uris: t.List[str], task_num: int = 0, - handler_args: t.List[str] | None = None, dataset_head: int = 0, finetune_val_dataset_uris: t.List[str] | None = None, model_name: str = "", @@ -169,7 +168,6 @@ def __init__( self.dataset_uris = dataset_uris self.workdir = workdir self.version = version - self.handler_args = handler_args or [] self.dataset_head = dataset_head self.finetune_val_dataset_uris = finetune_val_dataset_uris self.model_name = model_name @@ -199,7 +197,6 @@ def execute(self) -> StepResult: finetune_val_dataset_uris=self.finetune_val_dataset_uris, model_name=self.model_name, ), - handler_args=self.handler_args, step=self.step, workdir=self.workdir, ) diff --git a/client/starwhale/base/scheduler/task.py b/client/starwhale/base/scheduler/task.py index a5b6b7e440..30f1d2c534 100644 --- a/client/starwhale/base/scheduler/task.py +++ b/client/starwhale/base/scheduler/task.py @@ -32,7 +32,6 @@ def __init__( context: Context, workdir: Path, step: Step, - handler_args: t.List[str] | None = None, ): self.index = index self.context = context @@ -40,7 +39,6 @@ def __init__( self.exception: t.Optional[Exception] = None self.step = step self.__status = RunStatus.INIT - self.handler_args = handler_args or [] self._validate() def _validate(self) -> None: @@ -128,10 +126,7 @@ def _do_execute(self) -> None: elif getattr(func, DecoratorInjectAttr.Predict, False): self._run_in_pipeline_handler_cls(func, "predict") elif getattr(func, DecoratorInjectAttr.Step, False): - if self.handler_args: - func(**{"handler_args": self.handler_args}) - else: - func() + func() else: raise NoSupportError( f"func({self.step.module_name}.{self.step.func_name}) should use @handler, @predict, @evaluate or @finetune decorator" @@ -151,10 +146,7 @@ def _do_execute(self) -> None: elif getattr(func, DecoratorInjectAttr.Predict, False): self._run_in_pipeline_handler_cls(func, "predict") else: - if self.handler_args: - func(**{"handler_args": self.handler_args}) - else: - func() + func() else: func = getattr(cls_(), func_name) if getattr(func, DecoratorInjectAttr.Evaluate, False): @@ -162,10 +154,7 @@ def _do_execute(self) -> None: elif getattr(func, DecoratorInjectAttr.Predict, False): self._run_in_pipeline_handler_cls(func, "predict") else: - if self.handler_args: - func(**{"handler_args": self.handler_args}) - else: - func() + func() def execute(self) -> TaskResult: console.info( diff --git a/client/starwhale/core/model/cli.py b/client/starwhale/core/model/cli.py index 2d70d6834c..72ba0ba156 100644 --- a/client/starwhale/core/model/cli.py +++ b/client/starwhale/core/model/cli.py @@ -729,7 +729,6 @@ def _run( "task_num": override_task_num, }, force_generate_jobs_yaml=uri is None, - handler_args=ctx.args, ) diff --git a/client/starwhale/core/model/model.py b/client/starwhale/core/model/model.py index 5c523a5df8..8b36da44fa 100644 --- a/client/starwhale/core/model/model.py +++ b/client/starwhale/core/model/model.py @@ -362,7 +362,6 @@ def run( forbid_snapshot: bool = False, cleanup_snapshot: bool = True, force_generate_jobs_yaml: bool = False, - handler_args: t.List[str] | None = None, ) -> Resource: dataset_uris = dataset_uris or [] finetune_val_dataset_uris = finetune_val_dataset_uris or [] @@ -401,7 +400,6 @@ def run( workdir=snapshot_dir, dataset_uris=dataset_uris, steps=steps, - handler_args=handler_args or [], dataset_head=dataset_head, finetune_val_dataset_uris=finetune_val_dataset_uris, model_name=model_config.name, diff --git a/client/starwhale/core/model/view.py b/client/starwhale/core/model/view.py index 3424921dbb..5283354db8 100644 --- a/client/starwhale/core/model/view.py +++ b/client/starwhale/core/model/view.py @@ -263,7 +263,6 @@ def run_in_host( forbid_snapshot: bool = False, cleanup_snapshot: bool = True, force_generate_jobs_yaml: bool = False, - handler_args: t.List[str] | None = None, ) -> None: if runtime_uri: RuntimeProcess(uri=runtime_uri).run() @@ -282,7 +281,6 @@ def run_in_host( forbid_snapshot=forbid_snapshot, cleanup_snapshot=cleanup_snapshot, force_generate_jobs_yaml=force_generate_jobs_yaml, - handler_args=handler_args or [], ) @classmethod diff --git a/client/tests/sdk/test_job_handler.py b/client/tests/sdk/test_job_handler.py index 5e26f4439b..5dd51d1c06 100644 --- a/client/tests/sdk/test_job_handler.py +++ b/client/tests/sdk/test_job_handler.py @@ -13,11 +13,7 @@ from starwhale.base.models.base import obj_to_model from starwhale.base.uri.project import Project from starwhale.base.models.model import JobHandlers, StepSpecClient -from starwhale.base.client.models.models import ( - FineTune, - RuntimeResource, - ParameterSignature, -) +from starwhale.base.client.models.models import FineTune, RuntimeResource class JobTestCase(unittest.TestCase): @@ -248,8 +244,6 @@ def mock_predict_handler2(data, argument=None): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="", @@ -263,8 +257,6 @@ def mock_predict_handler2(data, argument=None): ... show_name="evaluate", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ] @@ -288,8 +280,6 @@ def mock_predict_handler2(data, argument=None): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="", @@ -303,8 +293,6 @@ def mock_predict_handler2(data, argument=None): ... show_name="evaluate", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ] @@ -352,8 +340,6 @@ def evaluate_handler(*args, **kwargs): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="", @@ -367,8 +353,6 @@ def evaluate_handler(*args, **kwargs): ... show_name="evaluate", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ], "mock_user_module:predict_handler": [ @@ -391,8 +375,6 @@ def evaluate_handler(*args, **kwargs): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ) ], } @@ -517,8 +499,6 @@ def evaluate(self, *args, **kwargs): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="MockHandler", @@ -534,8 +514,6 @@ def evaluate(self, *args, **kwargs): ... show_name="evaluate", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ] assert jobs_info["mock_user_module:MockHandler.predict"] == [ @@ -552,8 +530,6 @@ def evaluate(self, *args, **kwargs): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ) ] _, steps = Step.get_steps_from_yaml( @@ -620,8 +596,6 @@ def evaluate_handler(self, *args, **kwargs): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ) ] @@ -645,8 +619,6 @@ def evaluate_handler(self, *args, **kwargs): ... show_name="predict", expose=0, require_dataset=True, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="MockHandler", @@ -660,8 +632,6 @@ def evaluate_handler(self, *args, **kwargs): ... show_name="evaluate", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ] @@ -712,18 +682,22 @@ def run(): ... show_name="run", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ) ] } def test_handler_with_other_decorator(self) -> None: content = """ -from starwhale import handler +import dataclasses +from starwhale import handler, argument + +@dataclasses.dataclass +class TestArgument: + epoch: int = dataclasses.field(default=1) +@argument(TestArgument) @handler(replicas=2) -def handle(context): ... +def handle(argument): ... """ self._ensure_py_script(content) @@ -744,14 +718,6 @@ def handle(context): ... show_name="handle", expose=0, require_dataset=False, - parameters_sig=[ - ParameterSignature( - name="context", - required=True, - multiple=False, - ) - ], - ext_cmd_args="--context", ) ] } @@ -791,9 +757,7 @@ def ft2(): ... expose=0, replicas=1, resources=[], - parameters_sig=[], cls_name="", - ext_cmd_args="", extra_kwargs={ "auto_build_model": True, }, @@ -814,9 +778,7 @@ def ft2(): ... expose=0, replicas=1, resources=[], - parameters_sig=[], cls_name="", - ext_cmd_args="", extra_kwargs={ "auto_build_model": True, }, @@ -841,9 +803,7 @@ def ft2(): ... replicas=1, needs=["mock_user_module:ft1"], expose=0, - parameters_sig=[], cls_name="", - ext_cmd_args="", extra_kwargs={ "auto_build_model": False, }, @@ -914,9 +874,7 @@ def report_handler(self): ... needs=[], expose=0, resources=[], - parameters_sig=[], cls_name="", - ext_cmd_args="", ) in report_handler ) @@ -932,9 +890,7 @@ def report_handler(self): ... needs=["mock_user_module:prepare_handler"], expose=0, resources=[], - parameters_sig=[], cls_name="", - ext_cmd_args="", ) in report_handler ) @@ -950,9 +906,7 @@ def report_handler(self): ... show_name="evaluate", expose=0, replicas=1, - parameters_sig=[], cls_name="", - ext_cmd_args="", ) in report_handler ) @@ -972,8 +926,6 @@ def report_handler(self): ... show_name="report", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ) in report_handler ) @@ -989,8 +941,6 @@ def report_handler(self): ... show_name="predict", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ) assert jobs_info["mock_user_module:evaluate_handler"] == [ @@ -1005,8 +955,6 @@ def report_handler(self): ... show_name="prepare", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="", @@ -1019,8 +967,6 @@ def report_handler(self): ... show_name="evaluate", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ] assert jobs_info["mock_user_module:predict_handler"] == [ @@ -1035,8 +981,6 @@ def report_handler(self): ... show_name="prepare", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), StepSpecClient( cls_name="", @@ -1049,8 +993,6 @@ def report_handler(self): ... show_name="predict", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ), ] assert jobs_info["mock_user_module:prepare_handler"] == [ @@ -1065,8 +1007,6 @@ def report_handler(self): ... show_name="prepare", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ) ] @@ -1265,203 +1205,3 @@ def predict_handler(): ... yaml_path = self.workdir / "job.yaml" with self.assertRaisesRegex(RuntimeError, "dependency not found"): generate_jobs_yaml([self.module_name], self.workdir, yaml_path) - - def test_handler_args(self) -> None: - content = """ -from starwhale import ( - Context, - Dataset, - handler, - IntInput, - ListInput, - HandlerInput, - ContextInput, - DatasetInput, -) - -class MyInput(HandlerInput): - def parse(self, user_input): - - return f"MyInput {user_input}" - -class X: - def __init__(self) -> None: - self.a = 1 - - @handler() - def f( - self, x=ListInput(IntInput), y=2, mi=MyInput(), ds=DatasetInput(required=True), ctx=ContextInput() - ): - assert self.a + x[0] == 3 - assert self.a + x[1] == 2 - assert y == 2 - assert mi == "MyInput blab-la" - assert isinstance(ds, Dataset) - assert isinstance(ctx, Context) - - -@handler() -def f(x=ListInput(IntInput()), y=2, mi=MyInput(), ds=DatasetInput(required=True), ctx=ContextInput()): - assert x[0] == 2 - assert x[1] == 1 - assert y == 2 - assert mi == "MyInput blab-la" - - assert isinstance(ds, Dataset) - assert isinstance(ctx, Context) - -@handler() -def f_no_args(): - print("code here is executed") - - -""" - self._ensure_py_script(content) - yaml_path = self.workdir / "job.yaml" - generate_jobs_yaml( - [f"{self.module_name}:X", self.module_name], self.workdir, yaml_path - ) - jobs_info = obj_to_model(load_yaml(yaml_path), JobHandlers).data - assert jobs_info["mock_user_module:X.f"] == [ - StepSpecClient( - name="mock_user_module:X.f", - cls_name="X", - func_name="f", - module_name="mock_user_module", - show_name="f", - parameters_sig=[ - ParameterSignature( - name="x", - required=False, - multiple=True, - ), - ParameterSignature( - name="y", - required=False, - multiple=False, - ), - ParameterSignature( - name="mi", - required=False, - multiple=False, - ), - ParameterSignature( - name="ds", - required=True, - multiple=False, - ), - ParameterSignature( - name="ctx", - required=False, - multiple=False, - ), - ], - ext_cmd_args="--ds", - needs=[], - resources=[], - expose=0, - require_dataset=False, - ), - ] - assert jobs_info["mock_user_module:f"] == [ - StepSpecClient( - name="mock_user_module:f", - cls_name="", - func_name="f", - module_name="mock_user_module", - show_name="f", - parameters_sig=[ - ParameterSignature( - name="x", - required=False, - multiple=True, - ), - ParameterSignature( - name="y", - required=False, - multiple=False, - ), - ParameterSignature( - name="mi", - required=False, - multiple=False, - ), - ParameterSignature( - name="ds", - required=True, - multiple=False, - ), - ParameterSignature( - name="ctx", - required=False, - multiple=False, - ), - ], - ext_cmd_args="--ds", - needs=[], - resources=[], - expose=0, - require_dataset=False, - ), - ] - assert jobs_info["mock_user_module:f_no_args"] == [ - StepSpecClient( - name="mock_user_module:f_no_args", - cls_name="", - func_name="f_no_args", - module_name="mock_user_module", - show_name="f_no_args", - parameters_sig=[], - ext_cmd_args="", - needs=[], - resources=[], - expose=0, - require_dataset=False, - ), - ] - _, steps = Step.get_steps_from_yaml("mock_user_module:X.f", yaml_path) - context = Context( - workdir=self.workdir, - run_project=Project("test"), - version="123", - ) - task = TaskExecutor( - index=1, - context=context, - workdir=self.workdir, - step=steps[0], - handler_args=["--x", "2", "-x", "1", "--ds", "mnist", "-mi=blab-la"], - ) - result = task.execute() - assert result.status == "success" - _, steps = Step.get_steps_from_yaml("mock_user_module:f", yaml_path) - context = Context( - workdir=self.workdir, - run_project=Project("test"), - version="123", - ) - task = TaskExecutor( - index=1, - context=context, - workdir=self.workdir, - step=steps[0], - handler_args=["--x", "2", "-x", "1", "--ds", "mnist", "-mi=blab-la"], - ) - result = task.execute() - assert result.status == "success" - - _, steps = Step.get_steps_from_yaml("mock_user_module:f_no_args", yaml_path) - context = Context( - workdir=self.workdir, - run_project=Project("test"), - version="123", - ) - task = TaskExecutor( - index=1, - context=context, - workdir=self.workdir, - step=steps[0], - handler_args=[], - ) - result = task.execute() - assert result.status == "success" diff --git a/client/tests/sdk/test_model.py b/client/tests/sdk/test_model.py index ac367ba78e..76440ee72b 100644 --- a/client/tests/sdk/test_model.py +++ b/client/tests/sdk/test_model.py @@ -192,8 +192,6 @@ def _get_jobs_yaml(model_name: str) -> t.Dict[str, StepSpecClient]: show_name="handle", expose=0, require_dataset=False, - parameters_sig=[], - ext_cmd_args="", ) ] } diff --git a/scripts/client_test/cli_test.py b/scripts/client_test/cli_test.py index a3f499eea9..dc7544810d 100644 --- a/scripts/client_test/cli_test.py +++ b/scripts/client_test/cli_test.py @@ -358,7 +358,23 @@ def test_simple(self) -> None: dataset_uri = self.build_dataset("simple", workdir, DatasetExpl("", "")) remote_job_ids = [] + handler_args = [ + "--learning_rate", + "0.1", + "--epoch", + "100", + "--labels", + "l1", + "--labels", + "l2", + "--labels", + "l3", + "--debug", + "--evaluation_strategy", + "steps", + ] if self.server_url: + # TODO: support to specify model run arguments to server instance remote_job_ids = self.run_model_in_server( dataset_uris=[dataset_uri], model_uri=model_uri, @@ -380,26 +396,14 @@ def test_simple(self) -> None: ) ], run_handler=run_handler, + handler_args=handler_args, ) self.run_model_in_standalone( dataset_uris=[dataset_uri], model_uri=model_uri, run_handler=run_handler, - ) - - self.run_model_in_standalone( - dataset_uris=[], - model_uri=model_uri, - run_handler="src.evaluator:f", - handler_args=["--x", "2", "-x", "1", "--ds", "simple", "-mi=blab-la"], - ) - - self.run_model_in_standalone( - dataset_uris=[], - model_uri=model_uri, - run_handler="src.evaluator:X.f", - handler_args=["--x", "2", "-x", "1", "--ds", "simple", "-mi=blab-la"], + handler_args=handler_args, ) futures = [ @@ -597,18 +601,14 @@ def test_sdk(self) -> None: assert set(ctx_handle_info["handlers"]) == { "src.evaluator:evaluate", "src.evaluator:predict", - "src.evaluator:f", - "src.evaluator:X.f", "src.sdk_model_build:context_handle", "src.sdk_model_build:ft", - }, ctx_handle_info["basic"]["handlers"] + }, ctx_handle_info["handlers"] ctx_handle_no_modules_info = self.model_api.info("ctx_handle_no_modules") assert set(ctx_handle_no_modules_info["handlers"]) == { "src.evaluator:evaluate", "src.evaluator:predict", - "src.evaluator:f", - "src.evaluator:X.f", "src.sdk_model_build:context_handle", "src.sdk_model_build:ft", }, ctx_handle_no_modules_info["handlers"] diff --git a/scripts/example/src/evaluator.py b/scripts/example/src/evaluator.py index fe1aa5e95a..2e5de94a6c 100644 --- a/scripts/example/src/evaluator.py +++ b/scripts/example/src/evaluator.py @@ -1,27 +1,15 @@ +import os import time import random import typing as t import os.path as osp import dataclasses +from enum import Enum from functools import wraps import numpy -from starwhale import ( - Text, - Image, - Context, - Dataset, - handler, - argument, - IntInput, - ListInput, - evaluation, - ContextInput, - DatasetInput, - HandlerInput, - multi_classification, -) +from starwhale import Text, Image, Context, argument, evaluation, multi_classification from starwhale.utils import in_container try: @@ -30,9 +18,24 @@ from util import random_image +class IntervalStrategy(Enum): + NO = "no" + STEPS = "steps" + EPOCH = "epoch" + + @dataclasses.dataclass class TestArguments: + learning_rate: float = dataclasses.field(default=0.01, metadata={"help": "lr"}) epoch: int = dataclasses.field(default=10, metadata={"help": "epoch"}) + labels: t.Optional[t.List[str]] = dataclasses.field( + default=None, metadata={"help": "labels"} + ) + debug: bool = dataclasses.field(default=False, metadata={"help": "debug"}) + evaluation_strategy: t.Union[IntervalStrategy, str] = dataclasses.field( + default="no", metadata={"help": "evaluation strategy"} + ) + default_value: str = dataclasses.field(default="default value") def timing(func: t.Callable) -> t.Any: @@ -60,8 +63,8 @@ def predict(data: t.Dict, external: t.Dict, argument) -> t.Any: assert isinstance(external["context"], Context) assert external["dataset_uri"].name assert external["dataset_uri"].version - assert isinstance(argument, TestArguments) - assert argument.epoch == 10 + + _check_argument_values(argument) if in_container(): assert osp.exists("/tmp/runtime-command-run.flag") @@ -86,8 +89,7 @@ def predict(data: t.Dict, external: t.Dict, argument) -> t.Any: ) @argument(TestArguments) def evaluate(ppl_result: t.Iterator, argument: TestArguments) -> t.Any: - assert isinstance(argument, TestArguments) - assert argument.epoch == 10 + _check_argument_values(argument) result, label, pr = [], [], [] for _data in ppl_result: @@ -105,44 +107,19 @@ def evaluate(ppl_result: t.Iterator, argument: TestArguments) -> t.Any: return label, result, pr -class MyInput(HandlerInput): - def parse(self, user_input): - return f"MyInput {user_input}" - - -class X: - def __init__(self) -> None: - self.a = 1 - - @handler() - def f( - self, - x=ListInput(IntInput), - y=2, - mi=MyInput(), - ds=DatasetInput(required=True), - ctx=ContextInput(), - ): - assert self.a + x[0] == 3 - assert self.a + x[1] == 2 - assert y == 2 - assert mi == "MyInput blab-la" - assert isinstance(ds, Dataset) - assert isinstance(ctx, Context) - - -@handler() -def f( - x=ListInput(IntInput()), - y=2, - mi=MyInput(), - ds=DatasetInput(required=True), - ctx=ContextInput(), -): - assert x[0] == 2 - assert x[1] == 1 - assert y == 2 - assert mi == "MyInput blab-la" - - assert isinstance(ds, Dataset) - assert isinstance(ctx, Context) +def _check_argument_values(argument: TestArguments) -> None: + is_production = ( + os.environ.get("SW_PRODUCTION", "") == "1" + and os.environ.get("SW_CONTAINER", "") == "1" + ) + # TODO: support to specify model run arguments to server instance + if is_production: + return + assert isinstance(argument, TestArguments) + # the values are configured in `scripts/client_test/cli_test.py` + assert argument.default_value == "default value" + assert argument.learning_rate == 0.1 + assert argument.epoch == 100 + assert argument.labels == ["l1", "l2", "l3"] + assert argument.debug is True + assert argument.evaluation_strategy == "steps"