Skip to content

Commit

Permalink
support inject_name option for argument decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Dec 28, 2023
1 parent a8b3f3e commit 29175f3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
15 changes: 11 additions & 4 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get(cls) -> t.List[str]:
return cls._args or []


def argument(dataclass_types: t.Any) -> t.Any:
def argument(dataclass_types: t.Any, inject_name: str = "argument") -> t.Any:
"""argument is a decorator function to define arguments for model running(predict, evaluate, serve and finetune).
The decorated function will receive the instances of the dataclass types as the arguments.
Expand All @@ -40,6 +40,8 @@ def argument(dataclass_types: t.Any) -> t.Any:
Argument:
dataclass_types: [required] The dataclass type of the arguments.
A list of dataclass types or a single dataclass type is supported.
inject_name: [optional] The name of the keyword argument that will be passed to the decorated function.
Default is "argument".
Examples:
```python
Expand All @@ -53,6 +55,11 @@ class EvaluationArguments:
@evaluation.predict
def predict_image(data, argument: EvaluationArguments):
...
@argument(EvaluationArguments, inject_name="starwhale_arguments")
@evaluation.evaluate(needs=[])
def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArguments):
...
```
"""
is_sequence = True
Expand All @@ -69,11 +76,11 @@ def _register_wrapper(func: t.Callable) -> t.Any:
@wraps(func)
def _run_wrapper(*args: t.Any, **kw: t.Any) -> t.Any:
dataclass_values = init_dataclasses_values(parser, dataclass_types)
if "argument" in kw:
if inject_name in kw:
raise RuntimeError(
"argument is a reserved keyword for @starwhale.argument decorator in the "
f"{inject_name} has been used as a keyword argument in the decorated function, please use another name by the `inject_name` option."
)
kw["argument"] = dataclass_values if is_sequence else dataclass_values[0]
kw[inject_name] = dataclass_values if is_sequence else dataclass_values[0]
return func(*args, **kw)

return _run_wrapper
Expand Down
13 changes: 9 additions & 4 deletions client/tests/sdk/test_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,18 @@ def argument_keyword_func(argument):
with self.assertRaisesRegex(TypeError, "got an unexpected keyword argument"):
no_argument_func()

with self.assertRaisesRegex(RuntimeError, "argument is a reserved keyword"):
with self.assertRaisesRegex(
RuntimeError,
"has been used as a keyword argument in the decorated function",
):
argument_keyword_func(argument=1)

def test_argument_decorator(self) -> None:
@argument_decorator((ScalarArguments, ComposeArguments))
def assert_func(argument: t.Tuple) -> None:
scalar_argument, compose_argument = argument
@argument_decorator(
(ScalarArguments, ComposeArguments), inject_name="starwhale_argument"
)
def assert_func(starwhale_argument: t.Tuple) -> None:
scalar_argument, compose_argument = starwhale_argument
assert isinstance(scalar_argument, ScalarArguments)
assert isinstance(compose_argument, ComposeArguments)

Expand Down

0 comments on commit 29175f3

Please sign in to comment.