diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 447c750001..77ee36b625 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -28,7 +28,7 @@ def get(cls) -> t.List[str]: return cls._args or [] -def argument(dataclass_types: t.Any) -> t.Any: +def argument(dataclass_types: t.Any, inject_name: str = "argument") -> t.Any: """argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune). The decorated function will receive the instances of the dataclass types as the arguments. @@ -40,6 +40,8 @@ def argument(dataclass_types: t.Any) -> t.Any: Argument: dataclass_types: [required] The dataclass type of the arguments. A list of dataclass types or a single dataclass type is supported. + inject_name: [optional] The name of the keyword argument that will be passed to the decorated function. + Default is "argument". Examples: ```python @@ -53,6 +55,11 @@ class EvaluationArguments: @evaluation.predict def predict_image(data, argument: EvaluationArguments): ... + + @argument(EvaluationArguments, inject_name="starwhale_arguments") + @evaluation.evaluate(needs=[]) + def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArguments): + ... ``` """ is_sequence = True @@ -69,11 +76,11 @@ def _register_wrapper(func: t.Callable) -> t.Any: @wraps(func) def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any: dataclass_values = init_dataclasses_values(parser, dataclass_types) - if "argument" in kw: + if inject_name in kw: raise RuntimeError( - "argument is a reserved keyword for @starwhale.argument decorator in the " + f"{inject_name} has been used as a keyword argument in the decorated function, please use another name by the `inject_name` option." ) - kw["argument"] = dataclass_values if is_sequence else dataclass_values[0] + kw[inject_name] = dataclass_values if is_sequence else dataclass_values[0] return func(*args, **kw) return _run_wrapper diff --git a/client/tests/sdk/test_argument.py b/client/tests/sdk/test_argument.py index 7bf5590647..9dfe80fa32 100644 --- a/client/tests/sdk/test_argument.py +++ b/client/tests/sdk/test_argument.py @@ -80,13 +80,18 @@ def argument_keyword_func(argument): with self.assertRaisesRegex(TypeError, "got an unexpected keyword argument"): no_argument_func() - with self.assertRaisesRegex(RuntimeError, "argument is a reserved keyword"): + with self.assertRaisesRegex( + RuntimeError, + "has been used as a keyword argument in the decorated function", + ): argument_keyword_func(argument=1) def test_argument_decorator(self) -> None: - @argument_decorator((ScalarArguments, ComposeArguments)) - def assert_func(argument: t.Tuple) -> None: - scalar_argument, compose_argument = argument + @argument_decorator( + (ScalarArguments, ComposeArguments), inject_name="starwhale_argument" + ) + def assert_func(starwhale_argument: t.Tuple) -> None: + scalar_argument, compose_argument = starwhale_argument assert isinstance(scalar_argument, ScalarArguments) assert isinstance(compose_argument, ComposeArguments)