diff --git a/client/starwhale/api/_impl/argument.py b/client/starwhale/api/_impl/argument.py index 96c1d309ae..4f5de98a02 100644 --- a/client/starwhale/api/_impl/argument.py +++ b/client/starwhale/api/_impl/argument.py @@ -9,6 +9,8 @@ from collections import defaultdict import click +from pydantic import BaseModel, validator +from typing_extensions import Literal from starwhale.utils import console @@ -29,6 +31,36 @@ def get(cls) -> t.List[str]: return cls._args or [] +# Current supported types(ref: (click types)[https://github.com/pallets/click/blob/main/src/click/types.py]): +# 1. primitive types: INT,FLOAT,BOOL,STRING +# 2. Func: FuncParamType, such as: debug: t.Union[str, t.List[DebugOption]] = dataclasses.field(default="", metadata={"help": "debug mode"}) +# we will convert FuncParamType to STRING type to simplify the input implementation. We ignore `func` field. +# 3. Choice: click.Choice type, add choices and case_sensitive options. +class OptionType(BaseModel): + name: str + param_type: str + # for Choice type + choices: t.Optional[t.List[str]] = None + case_sensitive: bool = False + + @validator("param_type", pre=True) + def parse_param_type(cls, value: str) -> str: + value = value.upper() + return "STRING" if value == "FUNC" else value + + +class OptionField(BaseModel): + name: str + opts: t.List[str] + type: OptionType + required: bool = False + multiple: bool = False + default: t.Any = None + help: t.Optional[str] = None + is_flag: bool = False + hidden: bool = False + + class ArgumentContext: _instance = None _lock = threading.Lock() @@ -36,6 +68,7 @@ class ArgumentContext: def __init__(self) -> None: self._click_ctx = click.Context(click.Command("Starwhale Argument Decorator")) self._options: t.Dict[str, list] = defaultdict(list) + self._func_related_dataclasses: t.Dict[str, list] = defaultdict(list) @classmethod def get_current_context(cls) -> ArgumentContext: @@ -44,9 +77,25 @@ def get_current_context(cls) -> ArgumentContext: cls._instance = 
ArgumentContext() return cls._instance - def add_option(self, option: click.Option, group: str) -> None: + def _key(self, o: t.Any) -> str: + return f"{o.__module__}:{o.__qualname__}" + + def add_dataclass_type(self, func: t.Callable, dtype: t.Any) -> None: + with self._lock: + self._func_related_dataclasses[self._key(func)].append(self._key(dtype)) + + def add_option(self, option: click.Option, dtype: t.Any) -> None: with self._lock: - self._options[group].append(option) + self._options[self._key(dtype)].append(option) + + def asdict(self) -> t.Dict[str, t.Any]: + r: t.Dict = defaultdict(lambda: defaultdict(dict)) + for func, dtypes in self._func_related_dataclasses.items(): + for dtype in dtypes: + for option in self._options[dtype]: + info = OptionField(**option.to_info_dict()) + r[func][dtype][option.name] = info.model_dump(mode="json") + return r def echo_help(self) -> None: if not self._options: @@ -110,7 +159,7 @@ def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArgumen def _register_wrapper(func: t.Callable) -> t.Any: # TODO: dump parser to json file when model building - parser = get_parser_from_dataclasses(dataclass_types) + parser = get_parser_from_dataclasses(dataclass_types, func) lock = threading.Lock() parsed_cache: t.Any = None @@ -166,13 +215,19 @@ def init_dataclasses_values( return ret -def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: +def get_parser_from_dataclasses( + dataclass_types: t.List, deco_func: t.Callable | None = None +) -> click.OptionParser: argument_ctx = ArgumentContext.get_current_context() + parser = click.OptionParser() for dtype in dataclass_types: if not dataclasses.is_dataclass(dtype): raise ValueError(f"{dtype} is not a dataclass type") + if deco_func: + argument_ctx.add_dataclass_type(func=deco_func, dtype=dtype) + type_hints: t.Dict[str, type] = t.get_type_hints(dtype) for field in dataclasses.fields(dtype): if not field.init: @@ -180,9 +235,7 @@ def 
get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser: field.type = type_hints[field.name] option = convert_field_to_option(field) option.add_to_parser(parser=parser, ctx=parser.ctx) # type: ignore - argument_ctx.add_option( - option=option, group=f"{dtype.__module__}.{dtype.__qualname__}" - ) + argument_ctx.add_option(option=option, dtype=dtype) parser.ignore_unknown_options = True return parser @@ -229,16 +282,10 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option: ) origin_type = getattr(field.type, "__origin__", field.type) - try: - # typing.Literal is only supported in python3.8+ - literal_type = t.Literal # type: ignore[attr-defined] - except AttributeError: - literal_type = None - - if (literal_type and origin_type is literal_type) or ( + if (origin_type is Literal) or ( isinstance(field.type, type) and issubclass(field.type, Enum) ): - if literal_type and origin_type is literal_type: + if origin_type is Literal: kw["type"] = click.Choice(field.type.__args__) else: kw["type"] = click.Choice([e.value for e in field.type]) diff --git a/client/starwhale/consts/__init__.py b/client/starwhale/consts/__init__.py index 35f7232441..16eafe5a59 100644 --- a/client/starwhale/consts/__init__.py +++ b/client/starwhale/consts/__init__.py @@ -21,6 +21,8 @@ # evaluation related constants DEFAULT_JOBS_FILE_NAME = "jobs.yaml" +# dump @starwhale.argument received dataclasses to json file +ARGUMENTS_DUMPED_JSON_FILE_NAME = "arguments.json" # auto generated evaluation panel layout file name from yaml or local console EVALUATION_PANEL_LAYOUT_JSON_FILE_NAME = "eval_panel_layout.json" # user defined evaluation panel layout file name diff --git a/client/starwhale/core/model/model.py b/client/starwhale/core/model/model.py index 8b36da44fa..4c5011db06 100644 --- a/client/starwhale/core/model/model.py +++ b/client/starwhale/core/model/model.py @@ -40,6 +40,7 @@ DEFAULT_RESOURCE_POOL, DEFAULT_JOBS_FILE_NAME, DEFAULT_STARWHALE_API_VERSION, + 
ARGUMENTS_DUMPED_JSON_FILE_NAME, EVALUATION_PANEL_LAYOUT_JSON_FILE_NAME, EVALUATION_PANEL_LAYOUT_YAML_FILE_NAME, DEFAULT_FILE_SIZE_THRESHOLD_TO_TAR_IN_MODEL, @@ -287,6 +288,16 @@ def _gen_model_serving(self, search_modules: t.List[str], workdir: Path) -> None ) Handler._register(h, func) + def _gen_arguments_json(self) -> None: + from starwhale.api._impl.argument import ArgumentContext + + ctx = ArgumentContext.get_current_context() + ensure_file( + self.store.src_dir / SW_AUTO_DIRNAME / ARGUMENTS_DUMPED_JSON_FILE_NAME, + json.dumps(ctx.asdict(), indent=4), + parents=True, + ) + def _render_eval_layout(self, workdir: Path) -> None: # render eval layout eval_layout = workdir / SW_AUTO_DIRNAME / EVALUATION_PANEL_LAYOUT_YAML_FILE_NAME @@ -664,6 +675,11 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: # type: ignore[overrid / DEFAULT_JOBS_FILE_NAME, ), ), + ( + self._gen_arguments_json, + 5, + "generate arguments json", + ), ( self._render_eval_layout, 1, diff --git a/client/tests/core/test_model.py b/client/tests/core/test_model.py index 6c6e5fe782..0df96c6090 100644 --- a/client/tests/core/test_model.py +++ b/client/tests/core/test_model.py @@ -25,6 +25,7 @@ RESOURCE_FILES_NAME, DEFAULT_MANIFEST_NAME, DEFAULT_JOBS_FILE_NAME, + ARGUMENTS_DUMPED_JSON_FILE_NAME, EVALUATION_PANEL_LAYOUT_JSON_FILE_NAME, EVALUATION_PANEL_LAYOUT_YAML_FILE_NAME, ) @@ -235,6 +236,12 @@ def test_build_workflow( ], } + argument_json_path = ( + bundle_path / "src" / SW_AUTO_DIRNAME / ARGUMENTS_DUMPED_JSON_FILE_NAME + ) + assert argument_json_path.exists() + assert json.loads(argument_json_path.read_text()) == {} + _manifest = load_yaml(bundle_path / DEFAULT_MANIFEST_NAME) assert "name" not in _manifest assert _manifest["version"] == build_version diff --git a/client/tests/sdk/test_argument.py b/client/tests/sdk/test_argument.py index aae43faa51..91799d30cd 100644 --- a/client/tests/sdk/test_argument.py +++ b/client/tests/sdk/test_argument.py @@ -1,5 +1,6 @@ from __future__ import 
annotations +import json import typing as t import dataclasses from enum import Enum @@ -162,7 +163,7 @@ def test_scalar_parser(self) -> None: argument_ctx = ArgumentContext.get_current_context() assert len(argument_ctx._options) == 1 - options = argument_ctx._options["tests.sdk.test_argument.ScalarArguments"] + options = argument_ctx._options["tests.sdk.test_argument:ScalarArguments"] assert len(options) == 5 assert options[0].name == "batch" assert options[-1].name == "epoch" @@ -213,13 +214,13 @@ def test_compose_parser(self) -> None: argument_ctx = ArgumentContext.get_current_context() assert len(argument_ctx._options) == 1 - options = argument_ctx._options["tests.sdk.test_argument.ComposeArguments"] + options = argument_ctx._options["tests.sdk.test_argument:ComposeArguments"] assert len(options) == 6 assert options[0].name == "debug" argument_ctx.echo_help() @patch("click.echo") - def test_argument_help_output(self, mock_echo: MagicMock): + def test_argument_help_output(self, mock_echo: MagicMock) -> None: @argument_decorator((ScalarArguments, ComposeArguments)) def mock_func(starwhale_argument: t.Tuple) -> None: ... 
@@ -227,13 +228,13 @@ def mock_func(starwhale_argument: t.Tuple) -> None: ArgumentContext.get_current_context().echo_help() help_output = mock_echo.call_args[0][0] cases = [ - "tests.sdk.test_argument.ScalarArguments:", + "tests.sdk.test_argument:ScalarArguments:", "--batch INTEGER", "--overwrite", "--learning_rate, --learning-rate FLOAT", "--half_precision_backend, --half-precision-backend TEXT", "--epoch INTEGER", - "tests.sdk.test_argument.ComposeArguments:", + "tests.sdk.test_argument:ComposeArguments:", "--debug DEBUGOPTION", "--lr_scheduler_kwargs, --lr-scheduler-kwargs DICT", "--evaluation_strategy, --evaluation-strategy [no|steps|epoch]", @@ -243,3 +244,90 @@ def mock_func(starwhale_argument: t.Tuple) -> None: ] for case in cases: assert case in help_output + + def test_argument_dict(self) -> None: + @argument_decorator((ScalarArguments, ComposeArguments)) + def mock_f1(starwhale_argument: t.Tuple) -> None: + ... + + @argument_decorator(ScalarArguments) + def mock_f2(starwhale_argument: t.Tuple) -> None: + ... + + @argument_decorator(ComposeArguments) + def mock_f3(starwhale_argument: t.Tuple) -> None: + ... 
+ + info = ArgumentContext.get_current_context().asdict() + assert len(info) == 3 + assert list( + info[ + "tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1" + ].keys() + ) == [ + "tests.sdk.test_argument:ScalarArguments", + "tests.sdk.test_argument:ComposeArguments", + ] + batch = info[ + "tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1" + ]["tests.sdk.test_argument:ScalarArguments"]["batch"] + assert batch == { + "name": "batch", + "opts": ["--batch"], + "type": { + "name": "integer", + "param_type": "INT", + "case_sensitive": False, + "choices": None, + }, + "required": False, + "multiple": False, + "default": 64, + "help": "batch size", + "is_flag": False, + "hidden": False, + } + + evaluation_strategy = info[ + "tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1" + ]["tests.sdk.test_argument:ComposeArguments"]["evaluation_strategy"] + + assert evaluation_strategy == { + "default": "no", + "help": "evaluation strategy", + "hidden": False, + "is_flag": False, + "multiple": False, + "name": "evaluation_strategy", + "opts": ["--evaluation_strategy", "--evaluation-strategy"], + "required": False, + "type": { + "case_sensitive": True, + "choices": ["no", "steps", "epoch"], + "name": "choice", + "param_type": "CHOICE", + }, + } + + debug = info[ + "tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1" + ]["tests.sdk.test_argument:ComposeArguments"]["debug"] + + assert debug == { + "default": None, + "help": "debug mode", + "hidden": False, + "is_flag": False, + "multiple": True, + "name": "debug", + "opts": ["--debug"], + "required": False, + "type": { + "name": "DebugOption", + "param_type": "STRING", + "case_sensitive": False, + "choices": None, + }, + } + + assert json.loads(json.dumps(info)) == info