Skip to content

Commit 3b386f4

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Refactored cmd_conf in the benchmark_utils (#3079)
Summary: Pull Request resolved: #3079 * Refactored the cmd_conf function for easier use. You can now use the m3rlin45_conf decorator without needing to duplicate the parameters again. * Added the support of `Optional[List[Any]]` types, which will be used later for pooling factors configurations (`Optional[List[float]]`) Reviewed By: iamzainhuda Differential Revision: D76402139 fbshipit-source-id: 2d8d88427a0b5b9b3cedb39e7920cd617bf2c2df
1 parent 4e43395 commit 3b386f4

File tree

2 files changed

+57
-42
lines changed

2 files changed

+57
-42
lines changed

torchrec/distributed/benchmark/benchmark_train_sparsenn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,7 @@ def generate_pipeline(
229229
raise RuntimeError(f"unknown pipeline option {self.pipeline}")
230230

231231

232-
@click.command()
233-
@cmd_conf(RunOptions, EmbeddingTablesConfig, TestSparseNNInputConfig, PipelineConfig)
232+
@cmd_conf
234233
def main(
235234
run_option: RunOptions,
236235
table_config: EmbeddingTablesConfig,

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import contextlib
1515
import copy
1616
import gc
17+
import inspect
1718
import json
1819
import logging
1920
import os
@@ -26,6 +27,8 @@
2627
Callable,
2728
ContextManager,
2829
Dict,
30+
get_args,
31+
get_origin,
2932
List,
3033
Optional,
3134
Set,
@@ -467,47 +470,60 @@ def set_embedding_config(
467470

468471

469472
# pyre-ignore [24]
470-
def cmd_conf(*configs: Any) -> Callable:
471-
support_classes: List[Any] = [int, str, bool, float, Enum] # pyre-ignore[33]
472-
473-
# pyre-ignore [24]
474-
def wrapper(func: Callable) -> Callable:
475-
for config in configs:
476-
assert is_dataclass(config), f"{config} should be a dataclass"
477-
478-
# pyre-ignore
479-
def rtf(**kwargs):
480-
loglevel = logging._nameToLevel[kwargs["loglevel"].upper()]
481-
logger.setLevel(logging.INFO)
482-
input_configs = []
483-
for config in configs:
484-
params = {}
485-
for field in fields(config):
486-
params[field.name] = kwargs.get(field.name, field.default)
487-
conf = config(**params)
488-
logger.info(conf)
489-
input_configs.append(conf)
490-
logger.setLevel(loglevel)
491-
return func(*input_configs)
492-
493-
names: Set[str] = set()
494-
for config in configs:
495-
for field in fields(config):
496-
if not isinstance(field.default, tuple(support_classes)):
497-
continue
498-
if field.name not in names:
499-
names.add(field.name)
473+
def cmd_conf(func: Callable) -> Callable:
474+
475+
# pyre-ignore [3]
476+
def wrapper() -> Any:
477+
sig = inspect.signature(func)
478+
parser = argparse.ArgumentParser(func.__doc__)
479+
480+
seen_args = set() # track all --<name> we've added
481+
482+
for _name, param in sig.parameters.items():
483+
cls = param.annotation
484+
if not is_dataclass(cls):
485+
continue
486+
487+
for f in fields(cls):
488+
arg_name = f.name
489+
if arg_name in seen_args:
490+
parser.error(f"Duplicate argument {arg_name}")
491+
seen_args.add(arg_name)
492+
493+
ftype = f.type
494+
origin = get_origin(ftype)
495+
496+
# Unwrapping Optional[X] to X
497+
if origin is Union and type(None) in get_args(ftype):
498+
non_none = [t for t in get_args(ftype) if t is not type(None)]
499+
if len(non_none) == 1:
500+
ftype = non_none[0]
501+
origin = get_origin(ftype)
502+
503+
arg_kwargs = {
504+
"default": f.default,
505+
"help": f"({cls.__name__}) {arg_name}",
506+
}
507+
508+
if origin in (list, List):
509+
elem_type = get_args(ftype)[0]
510+
arg_kwargs.update(nargs="*", type=elem_type)
500511
else:
501-
logger.warn(f"WARNING: duplicate argument {field.name}")
502-
continue
503-
rtf = click.option(
504-
f"--{field.name}", type=field.type, default=field.default
505-
)(rtf)
506-
return click.option(
507-
"--loglevel",
508-
type=click.Choice(list(logging._nameToLevel.keys()), case_sensitive=False),
509-
default=logging._levelToName[logger.level],
510-
)(rtf)
512+
arg_kwargs.update(type=ftype)
513+
514+
parser.add_argument(f"--{arg_name}", **arg_kwargs)
515+
516+
args = parser.parse_args()
517+
518+
# Build the dataclasses
519+
kwargs = {}
520+
for name, param in sig.parameters.items():
521+
cls = param.annotation
522+
if is_dataclass(cls):
523+
data = {f.name: getattr(args, f.name) for f in fields(cls)}
524+
kwargs[name] = cls(**data) # pyre-ignore [29]
525+
526+
return func(**kwargs)
511527

512528
return wrapper
513529

0 commit comments

Comments
 (0)