Skip to content

Commit

Permalink
add argument decorator test
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Dec 27, 2023
1 parent b62096e commit b9500a0
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 12 deletions.
42 changes: 30 additions & 12 deletions client/starwhale/api/_impl/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,19 @@ def init_dataclasses_values(
ret = []
for dtype in dataclass_types:
keys = {f.name for f in dataclasses.fields(dtype) if f.init}
inputs = {k: param_map[k].type(v) for k, v in args_map.items() if k in keys}
inputs = {}
for k, v in args_map.items():
if k not in keys:
continue

# TODO: support dict type convert
# handle multiple args for list type
if isinstance(v, list):
v = [param_map[k].type(i) for i in v]
else:
v = param_map[k].type(v)
inputs[k] = v

for k in inputs:
del args_map[k]
ret.append(dtype(**inputs))
Expand Down Expand Up @@ -118,8 +130,11 @@ def get_parser_from_dataclasses(dataclass_types: t.Any) -> click.OptionParser:

def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field) -> None:
# TODO: field.name need format for click option?
decls = [f"--{field.name}"]
if "_" in field.name:
decls.append(f"--{field.name.replace('_', '-')}")
kw: t.Dict[str, t.Any] = {
"param_decls": [f"--{field.name}"],
"param_decls": decls,
"help": field.metadata.get("help"),
"show_default": True,
"hidden": field.metadata.get("hidden", False),
Expand All @@ -129,9 +144,9 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field)
# only support: Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type
origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is t.Union:
if (str not in field.type.__args and type(None) not in field.type.__args__) or (
len(field.type.__args__) != 2
):
if (
str not in field.type.__args__ and type(None) not in field.type.__args__
) or (len(field.type.__args__) != 2):
raise ValueError(
f"{field.type} is not supported."
"Only support Union[xxx, None] or Union[EnumType, str] or [List[EnumType], str] type"
Expand Down Expand Up @@ -171,14 +186,17 @@ def add_field_into_parser(parser: click.OptionParser, field: dataclasses.Field)
kw["is_flag"] = True
kw["type"] = bool
kw["default"] = False if field.default is dataclasses.MISSING else field.default
elif inspect.isclass(origin_type) and issubclass(origin_type, list):
kw["type"] = field.type.__args__[0]
kw["multiple"] = True
if field.default is not dataclasses.MISSING:
kw["default"] = field.default
elif field.default_factory is not dataclasses.MISSING:
elif inspect.isclass(origin_type) and issubclass(origin_type, (list, dict)):
if issubclass(origin_type, list):
kw["type"] = field.type.__args__[0]
kw["multiple"] = True
elif issubclass(origin_type, dict):
kw["type"] = dict

# list and dict types both need default_factory
if field.default_factory is not dataclasses.MISSING:
kw["default"] = field.default_factory()
else:
elif field.default is dataclasses.MISSING:
kw["required"] = True
else:
kw["type"] = field.type
Expand Down
196 changes: 196 additions & 0 deletions client/tests/sdk/test_argument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
from __future__ import annotations

import typing as t
import dataclasses
from enum import Enum

import click
from pyfakefs.fake_filesystem_unittest import TestCase

from starwhale.api._impl.argument import argument as argument_decorator
from starwhale.api._impl.argument import (
ExtraCliArgsRegistry,
get_parser_from_dataclasses,
)


class IntervalStrategy(Enum):
NO = "no"
STEPS = "steps"
EPOCH = "epoch"


class DebugOption(Enum):
UNDERFLOW_OVERFLOW = "underflow_overflow"
TPU_METRICS_DEBUG = "tpu_metrics_debug"


@dataclasses.dataclass
class ScalarArguments:
no_field = 1
batch: int = dataclasses.field(default=64, metadata={"help": "batch size"})
overwrite: bool = dataclasses.field(default=False, metadata={"help": "overwrite"})
learning_rate: float = dataclasses.field(
default=0.01, metadata={"help": "learning rate"}
)
half_precision_backend: str = dataclasses.field(
default="auto", metadata={"help": "half precision backend"}
)
epoch: int = dataclasses.field(default_factory=lambda: 1)


@dataclasses.dataclass
class ComposeArguments:
# simply huggingface transformers TrainingArguments for test
debug: t.Union[str, t.List[DebugOption]] = dataclasses.field(
default="", metadata={"help": "debug mode"}
)

lr_scheduler_kwargs: t.Optional[t.Dict] = dataclasses.field(
default_factory=dict, metadata={"help": "lr scheduler kwargs"}
)
evaluation_strategy: t.Union[IntervalStrategy, str] = dataclasses.field(
default="no", metadata={"help": "evaluation strategy"}
)
per_gpu_train_batch_size: t.Optional[int] = dataclasses.field(default=None)
eval_delay: t.Optional[float] = dataclasses.field(
default=0, metadata={"help": "evaluation delay"}
)
label_names: t.Optional[t.List[str]] = dataclasses.field(
default=None, metadata={"help": "label names"}
)


class ArgumentTestCase(TestCase):
def setUp(self) -> None:
self.setUpPyfakefs()

def tearDown(self) -> None:
ExtraCliArgsRegistry._args = None

def test_argument_exceptions(self) -> None:
@argument_decorator(ScalarArguments)
def no_argument_func():
...

@argument_decorator(ScalarArguments)
def argument_keyword_func(argument):
...

with self.assertRaisesRegex(TypeError, "got an unexpected keyword argument"):
no_argument_func()

with self.assertRaisesRegex(RuntimeError, "argument is a reserved keyword"):
argument_keyword_func(argument=1)

def test_argument_decorator(self) -> None:
@argument_decorator((ScalarArguments, ComposeArguments))
def assert_func(argument: t.Tuple) -> None:
scalar_argument, compose_argument = argument
assert isinstance(scalar_argument, ScalarArguments)
assert isinstance(compose_argument, ComposeArguments)

assert scalar_argument.batch == 128
assert scalar_argument.overwrite is True
assert scalar_argument.learning_rate == 0.02
assert scalar_argument.half_precision_backend == "auto"
assert scalar_argument.epoch == 1

assert compose_argument.label_names == ["a", "b", "c"]
assert compose_argument.eval_delay == 0
assert compose_argument.per_gpu_train_batch_size == 8
assert compose_argument.evaluation_strategy == "steps"
assert compose_argument.debug == [DebugOption.UNDERFLOW_OVERFLOW]

ExtraCliArgsRegistry.set(
[
"--batch",
"128",
"--overwrite",
"--learning-rate=0.02",
"--debug",
"underflow_overflow",
"--evaluation_strategy",
"steps",
"--per_gpu_train_batch_size",
"8",
"--label_names",
"a",
"--label_names",
"b",
"--label_names",
"c",
"--no-defined-arg=1",
]
)
assert_func()

def test_parser_exceptions(self) -> None:
with self.assertRaisesRegex(ValueError, "is not a dataclass type"):
get_parser_from_dataclasses([None])

def test_scalar_parser(self) -> None:
scalar_parser = get_parser_from_dataclasses([ScalarArguments])
assert scalar_parser.ignore_unknown_options

assert "--no_field" not in scalar_parser._long_opt

batch = scalar_parser._long_opt["--batch"].obj
assert batch.type == click.INT
assert not batch.required
assert batch.help == "batch size"
assert not batch.is_flag
assert batch.default == 64
overwrite = scalar_parser._long_opt["--overwrite"].obj
assert overwrite.type == click.BOOL
assert overwrite.is_flag
assert overwrite.default is False
assert scalar_parser._long_opt["--learning-rate"].obj.type == click.FLOAT
assert (
scalar_parser._long_opt["--half_precision_backend"].obj.type == click.STRING
)
assert scalar_parser._long_opt["--epoch"].obj.type == click.INT
assert scalar_parser._long_opt["--epoch"].obj.default == 1

def test_compose_parser(self) -> None:
compose_parser = get_parser_from_dataclasses([ComposeArguments])

dict_obj = compose_parser._long_opt["--lr-scheduler-kwargs"].obj
assert not dict_obj.required
assert dict_obj.default == {}
assert not dict_obj.multiple
assert isinstance(dict_obj.type, click.types.FuncParamType)
assert dict_obj.type.func == dict

union_enum_obj = compose_parser._long_opt["--evaluation_strategy"].obj
assert not union_enum_obj.required
assert union_enum_obj.default == "no"
assert isinstance(union_enum_obj.type, click.Choice)
assert union_enum_obj.type.choices == ["no", "steps", "epoch"]
assert union_enum_obj.show_choices
assert not union_enum_obj.multiple

union_list_obj = compose_parser._long_opt["--debug"].obj
assert isinstance(union_list_obj.type, click.types.FuncParamType)
assert union_list_obj.type.func == DebugOption
assert not union_list_obj.required
assert union_list_obj.default is None
assert union_list_obj.multiple

optional_int_obj = compose_parser._long_opt["--per_gpu_train_batch_size"].obj
assert optional_int_obj.type == click.INT
assert not optional_int_obj.required
assert optional_int_obj.default is None
assert not optional_int_obj.multiple

optional_float_obj = compose_parser._long_opt["--eval_delay"].obj
assert optional_float_obj.type == click.FLOAT
assert not optional_float_obj.required
assert optional_float_obj.default == 0
assert not optional_float_obj.multiple

optional_list_obj = compose_parser._long_opt["--label_names"].obj
assert optional_list_obj.type == click.STRING
assert not optional_list_obj.required
assert optional_list_obj.multiple
assert optional_list_obj.default is None

0 comments on commit b9500a0

Please sign in to comment.