|
14 | 14 | import contextlib |
15 | 15 | import copy |
16 | 16 | import gc |
| 17 | +import inspect |
17 | 18 | import json |
18 | 19 | import logging |
19 | 20 | import os |
|
26 | 27 | Callable, |
27 | 28 | ContextManager, |
28 | 29 | Dict, |
| 30 | + get_args, |
| 31 | + get_origin, |
29 | 32 | List, |
30 | 33 | Optional, |
31 | 34 | Set, |
@@ -467,47 +470,60 @@ def set_embedding_config( |
467 | 470 |
|
468 | 471 |
|
469 | 472 | # pyre-ignore [24] |
470 | | -def cmd_conf(*configs: Any) -> Callable: |
471 | | - support_classes: List[Any] = [int, str, bool, float, Enum] # pyre-ignore[33] |
472 | | - |
473 | | - # pyre-ignore [24] |
474 | | - def wrapper(func: Callable) -> Callable: |
475 | | - for config in configs: |
476 | | - assert is_dataclass(config), f"{config} should be a dataclass" |
477 | | - |
478 | | - # pyre-ignore |
479 | | - def rtf(**kwargs): |
480 | | - loglevel = logging._nameToLevel[kwargs["loglevel"].upper()] |
481 | | - logger.setLevel(logging.INFO) |
482 | | - input_configs = [] |
483 | | - for config in configs: |
484 | | - params = {} |
485 | | - for field in fields(config): |
486 | | - params[field.name] = kwargs.get(field.name, field.default) |
487 | | - conf = config(**params) |
488 | | - logger.info(conf) |
489 | | - input_configs.append(conf) |
490 | | - logger.setLevel(loglevel) |
491 | | - return func(*input_configs) |
492 | | - |
493 | | - names: Set[str] = set() |
494 | | - for config in configs: |
495 | | - for field in fields(config): |
496 | | - if not isinstance(field.default, tuple(support_classes)): |
497 | | - continue |
498 | | - if field.name not in names: |
499 | | - names.add(field.name) |
| 473 | +def cmd_conf(func: Callable) -> Callable: |
| 474 | + |
| 475 | + # pyre-ignore [3] |
| 476 | + def wrapper() -> Any: |
| 477 | + sig = inspect.signature(func) |
| 478 | + parser = argparse.ArgumentParser(func.__doc__) |
| 479 | + |
| 480 | + seen_args = set() # track all --<name> we've added |
| 481 | + |
| 482 | + for _name, param in sig.parameters.items(): |
| 483 | + cls = param.annotation |
| 484 | + if not is_dataclass(cls): |
| 485 | + continue |
| 486 | + |
| 487 | + for f in fields(cls): |
| 488 | + arg_name = f.name |
| 489 | + if arg_name in seen_args: |
| 490 | + parser.error(f"Duplicate argument {arg_name}") |
| 491 | + seen_args.add(arg_name) |
| 492 | + |
| 493 | + ftype = f.type |
| 494 | + origin = get_origin(ftype) |
| 495 | + |
| 496 | + # Unwrapping Optional[X] to X |
| 497 | + if origin is Union and type(None) in get_args(ftype): |
| 498 | + non_none = [t for t in get_args(ftype) if t is not type(None)] |
| 499 | + if len(non_none) == 1: |
| 500 | + ftype = non_none[0] |
| 501 | + origin = get_origin(ftype) |
| 502 | + |
| 503 | + arg_kwargs = { |
| 504 | + "default": f.default, |
| 505 | + "help": f"({cls.__name__}) {arg_name}", |
| 506 | + } |
| 507 | + |
| 508 | + if origin in (list, List): |
| 509 | + elem_type = get_args(ftype)[0] |
| 510 | + arg_kwargs.update(nargs="*", type=elem_type) |
500 | 511 | else: |
501 | | - logger.warn(f"WARNING: duplicate argument {field.name}") |
502 | | - continue |
503 | | - rtf = click.option( |
504 | | - f"--{field.name}", type=field.type, default=field.default |
505 | | - )(rtf) |
506 | | - return click.option( |
507 | | - "--loglevel", |
508 | | - type=click.Choice(list(logging._nameToLevel.keys()), case_sensitive=False), |
509 | | - default=logging._levelToName[logger.level], |
510 | | - )(rtf) |
| 512 | + arg_kwargs.update(type=ftype) |
| 513 | + |
| 514 | + parser.add_argument(f"--{arg_name}", **arg_kwargs) |
| 515 | + |
| 516 | + args = parser.parse_args() |
| 517 | + |
| 518 | + # Build the dataclasses |
| 519 | + kwargs = {} |
| 520 | + for name, param in sig.parameters.items(): |
| 521 | + cls = param.annotation |
| 522 | + if is_dataclass(cls): |
| 523 | + data = {f.name: getattr(args, f.name) for f in fields(cls)} |
| 524 | + kwargs[name] = cls(**data) # pyre-ignore [29] |
| 525 | + |
| 526 | + return func(**kwargs) |
511 | 527 |
|
512 | 528 | return wrapper |
513 | 529 |
|
|
0 commit comments