Skip to content

Commit

Permalink
dump arguments to json file for model building
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Jan 3, 2024
1 parent 2ce80aa commit 1c92d53
Show file tree
Hide file tree
Showing 5 changed files with 180 additions and 20 deletions.
77 changes: 62 additions & 15 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from collections import defaultdict

import click
from pydantic import BaseModel, validator
from typing_extensions import Literal

from starwhale.utils import console

Expand All @@ -29,13 +31,44 @@ def get(cls) -> t.List[str]:
return cls._args or []


# Current supported types(ref: (click types)[https://github.com/pallets/click/blob/main/src/click/types.py]):
# 1. primitive types: INT,FLOAT,BOOL,STRING
# 2. Func: FuncParamType, such as: debug: t.Union[str, t.List[DebugOption]] = dataclasses.field(default="", metadata={"help": "debug mode"})
# we will convert FuncParamType to STRING type to simplify the input implementation. We ignore `func` field.
# 3. Choice: click.Choice type, add choices and case_sensitive options.
class OptionType(BaseModel):
name: str
param_type: str
# for Choice type
choices: t.Optional[t.List[str]] = None
case_sensitive: bool = False

@validator("param_type", pre=True)
def parse_param_type(cls, value: str) -> str:
value = value.upper()
return "STRING" if value == "FUNC" else value


class OptionField(BaseModel):
name: str
opts: t.List[str]
type: OptionType
required: bool = False
multiple: bool = False
default: t.Any = None
help: t.Optional[str] = None
is_flag: bool = False
hidden: bool = False


class ArgumentContext:
_instance = None
_lock = threading.Lock()

def __init__(self) -> None:
self._click_ctx = click.Context(click.Command("Starwhale Argument Decorator"))
self._options: t.Dict[str, list] = defaultdict(list)
self._func_related_dataclasses: t.Dict[str, list] = defaultdict(list)

@classmethod
def get_current_context(cls) -> ArgumentContext:
Expand All @@ -44,9 +77,25 @@ def get_current_context(cls) -> ArgumentContext:
cls._instance = ArgumentContext()
return cls._instance

def add_option(self, option: click.Option, group: str) -> None:
def _key(self, o: t.Any) -> str:
return f"{o.__module__}:{o.__qualname__}"

def add_dataclass_type(self, func: t.Callable, dtype: t.Any) -> None:
with self._lock:
self._func_related_dataclasses[self._key(func)].append(self._key(dtype))

def add_option(self, option: click.Option, dtype: t.Any) -> None:
with self._lock:
self._options[group].append(option)
self._options[self._key(dtype)].append(option)

def asdict(self) -> t.Dict[str, t.Any]:
r: t.Dict = defaultdict(lambda: defaultdict(dict))
for func, dtypes in self._func_related_dataclasses.items():
for dtype in dtypes:
for option in self._options[dtype]:
info = OptionField(**option.to_info_dict())
r[func][dtype][option.name] = info.model_dump(mode="json")
return r

def echo_help(self) -> None:
if not self._options:
Expand Down Expand Up @@ -110,7 +159,7 @@ def evaluate_summary(predict_result_iter, starwhale_arguments: EvaluationArgumen

def _register_wrapper(func: t.Callable) -> t.Any:
# TODO: dump parser to json file when model building
parser = get_parser_from_dataclasses(dataclass_types)
parser = get_parser_from_dataclasses(dataclass_types, func)
lock = threading.Lock()
parsed_cache: t.Any = None

Expand Down Expand Up @@ -166,23 +215,27 @@ def init_dataclasses_values(
return ret


def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:
def get_parser_from_dataclasses(
dataclass_types: t.List, deco_func: t.Callable | None = None
) -> click.OptionParser:
argument_ctx = ArgumentContext.get_current_context()

parser = click.OptionParser()
for dtype in dataclass_types:
if not dataclasses.is_dataclass(dtype):
raise ValueError(f"{dtype} is not a dataclass type")

if deco_func:
argument_ctx.add_dataclass_type(func=deco_func, dtype=dtype)

type_hints: t.Dict[str, type] = t.get_type_hints(dtype)
for field in dataclasses.fields(dtype):
if not field.init:
continue
field.type = type_hints[field.name]
option = convert_field_to_option(field)
option.add_to_parser(parser=parser, ctx=parser.ctx) # type: ignore
argument_ctx.add_option(
option=option, group=f"{dtype.__module__}.{dtype.__qualname__}"
)
argument_ctx.add_option(option=option, dtype=dtype)

parser.ignore_unknown_options = True
return parser
Expand Down Expand Up @@ -229,16 +282,10 @@ def convert_field_to_option(field: dataclasses.Field) -> click.Option:
)
origin_type = getattr(field.type, "__origin__", field.type)

try:
# typing.Literal is only supported in python3.8+
literal_type = t.Literal # type: ignore[attr-defined]
except AttributeError:
literal_type = None

if (literal_type and origin_type is literal_type) or (
if (origin_type is Literal) or (
isinstance(field.type, type) and issubclass(field.type, Enum)
):
if literal_type and origin_type is literal_type:
if origin_type is Literal:
kw["type"] = click.Choice(field.type.__args__)
else:
kw["type"] = click.Choice([e.value for e in field.type])
Expand Down
2 changes: 2 additions & 0 deletions client/starwhale/consts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

# evaluation related constants
DEFAULT_JOBS_FILE_NAME = "jobs.yaml"
# dump @starwhale.argument received dataclasses to json file
ARGUMENTS_DUMPED_JSON_FILE_NAME = "arguments.json"
# auto generated evaluation panel layout file name from yaml or local console
EVALUATION_PANEL_LAYOUT_JSON_FILE_NAME = "eval_panel_layout.json"
# user defined evaluation panel layout file name
Expand Down
16 changes: 16 additions & 0 deletions client/starwhale/core/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
DEFAULT_RESOURCE_POOL,
DEFAULT_JOBS_FILE_NAME,
DEFAULT_STARWHALE_API_VERSION,
ARGUMENTS_DUMPED_JSON_FILE_NAME,
EVALUATION_PANEL_LAYOUT_JSON_FILE_NAME,
EVALUATION_PANEL_LAYOUT_YAML_FILE_NAME,
DEFAULT_FILE_SIZE_THRESHOLD_TO_TAR_IN_MODEL,
Expand Down Expand Up @@ -287,6 +288,16 @@ def _gen_model_serving(self, search_modules: t.List[str], workdir: Path) -> None
)
Handler._register(h, func)

def _gen_arguments_json(self) -> None:
from starwhale.api._impl.argument import ArgumentContext

ctx = ArgumentContext.get_current_context()
ensure_file(
self.store.src_dir / SW_AUTO_DIRNAME / ARGUMENTS_DUMPED_JSON_FILE_NAME,
json.dumps(ctx.asdict(), indent=4),
parents=True,
)

def _render_eval_layout(self, workdir: Path) -> None:
# render eval layout
eval_layout = workdir / SW_AUTO_DIRNAME / EVALUATION_PANEL_LAYOUT_YAML_FILE_NAME
Expand Down Expand Up @@ -664,6 +675,11 @@ def buildImpl(self, workdir: Path, **kw: t.Any) -> None: # type: ignore[overrid
/ DEFAULT_JOBS_FILE_NAME,
),
),
(
self._gen_arguments_json,
5,
"generate arguments json",
),
(
self._render_eval_layout,
1,
Expand Down
7 changes: 7 additions & 0 deletions client/tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
RESOURCE_FILES_NAME,
DEFAULT_MANIFEST_NAME,
DEFAULT_JOBS_FILE_NAME,
ARGUMENTS_DUMPED_JSON_FILE_NAME,
EVALUATION_PANEL_LAYOUT_JSON_FILE_NAME,
EVALUATION_PANEL_LAYOUT_YAML_FILE_NAME,
)
Expand Down Expand Up @@ -235,6 +236,12 @@ def test_build_workflow(
],
}

argument_json_path = (
bundle_path / "src" / SW_AUTO_DIRNAME / ARGUMENTS_DUMPED_JSON_FILE_NAME
)
assert argument_json_path.exists()
assert json.loads(argument_json_path.read_text()) == {}

_manifest = load_yaml(bundle_path / DEFAULT_MANIFEST_NAME)
assert "name" not in _manifest
assert _manifest["version"] == build_version
Expand Down
98 changes: 93 additions & 5 deletions client/tests/sdk/test_argument.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import typing as t
import dataclasses
from enum import Enum
Expand Down Expand Up @@ -162,7 +163,7 @@ def test_scalar_parser(self) -> None:

argument_ctx = ArgumentContext.get_current_context()
assert len(argument_ctx._options) == 1
options = argument_ctx._options["tests.sdk.test_argument.ScalarArguments"]
options = argument_ctx._options["tests.sdk.test_argument:ScalarArguments"]
assert len(options) == 5
assert options[0].name == "batch"
assert options[-1].name == "epoch"
Expand Down Expand Up @@ -213,27 +214,27 @@ def test_compose_parser(self) -> None:

argument_ctx = ArgumentContext.get_current_context()
assert len(argument_ctx._options) == 1
options = argument_ctx._options["tests.sdk.test_argument.ComposeArguments"]
options = argument_ctx._options["tests.sdk.test_argument:ComposeArguments"]
assert len(options) == 6
assert options[0].name == "debug"
argument_ctx.echo_help()

@patch("click.echo")
def test_argument_help_output(self, mock_echo: MagicMock):
def test_argument_help_output(self, mock_echo: MagicMock) -> None:
@argument_decorator((ScalarArguments, ComposeArguments))
def mock_func(starwhale_argument: t.Tuple) -> None:
...

ArgumentContext.get_current_context().echo_help()
help_output = mock_echo.call_args[0][0]
cases = [
"tests.sdk.test_argument.ScalarArguments:",
"tests.sdk.test_argument:ScalarArguments:",
"--batch INTEGER",
"--overwrite",
"--learning_rate, --learning-rate FLOAT",
"--half_precision_backend, --half-precision-backend TEXT",
"--epoch INTEGER",
"tests.sdk.test_argument.ComposeArguments:",
"tests.sdk.test_argument:ComposeArguments:",
"--debug DEBUGOPTION",
"--lr_scheduler_kwargs, --lr-scheduler-kwargs DICT",
"--evaluation_strategy, --evaluation-strategy [no|steps|epoch]",
Expand All @@ -243,3 +244,90 @@ def mock_func(starwhale_argument: t.Tuple) -> None:
]
for case in cases:
assert case in help_output

def test_argument_dict(self) -> None:
@argument_decorator((ScalarArguments, ComposeArguments))
def mock_f1(starwhale_argument: t.Tuple) -> None:
...

@argument_decorator(ScalarArguments)
def mock_f2(starwhale_argument: t.Tuple) -> None:
...

@argument_decorator(ComposeArguments)
def mock_f3(starwhale_argument: t.Tuple) -> None:
...

info = ArgumentContext.get_current_context().asdict()
assert len(info) == 3
assert list(
info[
"tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1"
].keys()
) == [
"tests.sdk.test_argument:ScalarArguments",
"tests.sdk.test_argument:ComposeArguments",
]
batch = info[
"tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1"
]["tests.sdk.test_argument:ScalarArguments"]["batch"]
assert batch == {
"name": "batch",
"opts": ["--batch"],
"type": {
"name": "integer",
"param_type": "INT",
"case_sensitive": False,
"choices": None,
},
"required": False,
"multiple": False,
"default": 64,
"help": "batch size",
"is_flag": False,
"hidden": False,
}

evaluation_strategy = info[
"tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1"
]["tests.sdk.test_argument:ComposeArguments"]["evaluation_strategy"]

assert evaluation_strategy == {
"default": "no",
"help": "evaluation strategy",
"hidden": False,
"is_flag": False,
"multiple": False,
"name": "evaluation_strategy",
"opts": ["--evaluation_strategy", "--evaluation-strategy"],
"required": False,
"type": {
"case_sensitive": True,
"choices": ["no", "steps", "epoch"],
"name": "choice",
"param_type": "CHOICE",
},
}

debug = info[
"tests.sdk.test_argument:ArgumentTestCase.test_argument_dict.<locals>.mock_f1"
]["tests.sdk.test_argument:ComposeArguments"]["debug"]

assert debug == {
"default": None,
"help": "debug mode",
"hidden": False,
"is_flag": False,
"multiple": True,
"name": "debug",
"opts": ["--debug"],
"required": False,
"type": {
"name": "DebugOption",
"param_type": "STRING",
"case_sensitive": False,
"choices": None,
},
}

assert json.loads(json.dumps(info)) == info

0 comments on commit 1c92d53

Please sign in to comment.