(torchx/config) support builtin argument defaults from .torchxconfig
Summary:
**Summary:** Makes it possible to specify component parameter defaults in `.torchxconfig`.  See changes to `.torchxconfig` files included in this diff and the *Test Plan* section for example usage and config specification.

**Motivation:** Improves the UX for users of builtin components that have required parameters (for which no universal "global" default exists, and hence none can be declared in the component function signature) but that are effectively static for a particular user's or team's use of the builtin.

**Example:** `image` in `dist.ddp` is in most cases a constant for a given team, but no universal default exists (and hence it cannot be specified in the function declaration of `dist.ddp` itself), and it is cumbersome to pass it on the command line every time. A default could instead be set in `.torchxconfig`, as sketched below.
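
A hypothetical `.torchxconfig` sketch (the image value and the choice of `dist.ddp` as the `[cli:run]` default are illustrative, not part of this diff):

```ini
# Hypothetical sketch: the values below are illustrative, not real defaults.
[cli:run]
# default component used when `torchx run` is invoked without a component name
component = dist.ddp

[component:dist.ddp]
# default value for the `image` parameter of dist.ddp
image = my_team/trainer:latest
```

With this in place, the intent is that `torchx run` can omit both the component name and `--image`, picking them up from `.torchxconfig` instead.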

**Alternative:** Copy the builtin into a separate component and hardcode (or default in the function declaration) the desired fields. However, this requires the user to fork the builtin, which is suboptimal for those in the "exploration/dev" phase who are not yet interested in productionizing the component.

**Other Notes:** While working on this feature, I noticed a few improvements/cleanups that we need to work on, which I'm tracking in [issue-368](#368). We need to land this code in the interest of time, and I've done as much as possible to NOT change any major APIs until we address those issues properly through issue-368.
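
For reference, a minimal sketch (not part of this diff; the scheduler name, component args, and `cfg` are illustrative) of how the same component defaults could be wired into a programmatic `Runner`, mirroring what `CmdRun.run` does with `get_workspace_runner` in this change:

```python
from torchx.runner import get_runner
from torchx.runner.config import load_sections

# Collect every [component:<name>] section from .torchxconfig into a
# {component_name: {param_name: default_value_as_str}} mapping.
component_defaults = load_sections(prefix="component")

# get_runner() forwards the defaults to the Runner, which applies them when
# materializing the component (see dryrun_component in the api.py diff below).
with get_runner(component_defaults=component_defaults) as runner:
    app_handle = runner.run_component(
        "utils.echo",        # builtin component name
        ["--msg", "hello"],  # component args; params defaulted in config may be omitted
        "local_cwd",         # scheduler name -- illustrative
        cfg={},
    )
    print(app_handle)
```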

Reviewed By: aivanou

Differential Revision: D33576756

fbshipit-source-id: b65af48a570cc83c366df4eb71a8583a0be6018f
Kiuk Chung authored and facebook-github-bot committed Jan 15, 2022
1 parent 782f14d commit c37cfd7
Showing 11 changed files with 495 additions and 97 deletions.
92 changes: 71 additions & 21 deletions torchx/cli/cmd_run.py
@@ -11,14 +11,15 @@
import threading
from dataclasses import asdict
from pprint import pformat
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Tuple, Type

import torchx.specs as specs
from pyre_extensions import none_throws
from torchx.cli.cmd_base import SubCommand
from torchx.cli.cmd_log import get_logs
from torchx.runner import Runner, config
from torchx.runner.workspaces import get_workspace_runner, WorkspaceRunner
from torchx.runner.config import load_sections
from torchx.runner.workspaces import WorkspaceRunner, get_workspace_runner
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
from torchx.specs import CfgVal
from torchx.specs.finder import (
@@ -31,6 +32,10 @@
from torchx.util.types import to_dict


MISSING_COMPONENT_ERROR_MSG = (
"missing component name, either provide it from the CLI or in .torchxconfig"
)

logger: logging.Logger = logging.getLogger(__name__)


@@ -61,6 +66,54 @@ def _parse_run_config(arg: str, scheduler_opts: specs.runopts) -> Dict[str, CfgVal]:
return conf


def _parse_component_name_and_args(
component_name_and_args: List[str],
subparser: argparse.ArgumentParser,
dirs: Optional[List[str]] = None, # for testing only
) -> Tuple[str, List[str]]:
"""
Given a list of nargs parsed from commandline, parses out the component name
and component args. If component name is not found in the list, then
the default component is loaded from the [cli:run] component section in
.torchxconfig. If no default config is specified in .torchxconfig, then
this method errors out to the specified subparser.
This method deals with the following input list:
1. [$component_name, *$component_args]
- Example: ["utils.echo", "--msg", "hello"] or ["utils.echo"]
- Note: component name and args both in list
2. [*$component_args]
- Example: ["--msg", "hello"] or []
- Note: component name loaded from .torchxconfig, args in list
- Note: assumes list is only args if the first element
looks like an option (e.g. starts with "-")
"""
component = config.get_config(prefix="cli", name="run", key="component", dirs=dirs)
component_args = []

# make a copy of the input list to guard against side-effects
args = list(component_name_and_args)

if len(args) > 0:
# `--` is used to delimit between run's options and nargs which includes component args
# argparse returns the delimiter as part of the nargs so just ignore it if present
if args[0] == "--":
args = args[1:]

if args[0].startswith("-"):
component_args = args
else: # first element is NOT an option; then it must be a component name
component = args[0]
component_args = args[1:]

if not component:
subparser.error(MISSING_COMPONENT_ERROR_MSG)

return component, component_args


class CmdBuiltins(SubCommand):
def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
subparser.add_argument(
@@ -126,7 +179,7 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
help="Stream logs while waiting for app to finish.",
)
subparser.add_argument(
"conf_args",
"component_name_and_args",
nargs=argparse.REMAINDER,
)

@@ -143,33 +196,29 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
scheduler_opts = run_opts[args.scheduler]
cfg = _parse_run_config(args.scheduler_args, scheduler_opts)
config.apply(scheduler=args.scheduler, cfg=cfg)

config_files = config.find_configs()
workspace = (
"file://" + os.path.dirname(config_files[0]) if config_files else None
)
component, component_args = _parse_component_name_and_args(
args.component_name_and_args,
none_throws(self._subparser),
)

if len(args.conf_args) < 1:
none_throws(self._subparser).error(
"the following arguments are required: conf_file, conf_args"
)

# Python argparse would remove `--` if it was the first argument. This
# does not work well for torchx, since torchx.specs.api uses another argparser to
# parse component arguments.
conf_file, conf_args = args.conf_args[0], args.conf_args[1:]
try:
if args.dryrun:
if isinstance(runner, WorkspaceRunner):
dryrun_info = runner.dryrun_component(
conf_file,
conf_args,
component,
component_args,
args.scheduler,
workspace=workspace,
cfg=cfg,
)
else:
dryrun_info = runner.dryrun_component(
conf_file, conf_args, args.scheduler, cfg=cfg
component, component_args, args.scheduler, cfg=cfg
)
logger.info(
"\n=== APPLICATION ===\n"
@@ -180,16 +229,16 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
else:
if isinstance(runner, WorkspaceRunner):
app_handle = runner.run_component(
conf_file,
conf_args,
component,
component_args,
args.scheduler,
workspace=workspace,
cfg=cfg,
)
else:
app_handle = runner.run_component(
conf_file,
conf_args,
component,
component_args,
args.scheduler,
cfg=cfg,
)
@@ -208,7 +257,7 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
self._wait_and_exit(runner, app_handle, log=args.log)

except (ComponentValidationException, ComponentNotFoundException) as e:
error_msg = f"\nFailed to run component `{conf_file}` got errors: \n {e}"
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
logger.error(error_msg)
sys.exit(1)
except specs.InvalidRunConfigException as e:
@@ -223,7 +272,8 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:

def run(self, args: argparse.Namespace) -> None:
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
with get_workspace_runner() as runner:
component_defaults = load_sections(prefix="component")
with get_workspace_runner(component_defaults=component_defaults) as runner:
self._run(runner, args)

def _wait_and_exit(self, runner: Runner, app_handle: str, log: bool) -> None:
65 changes: 64 additions & 1 deletion torchx/cli/test/cmd_run_test.py
@@ -17,7 +17,13 @@
from typing import Generator, List
from unittest.mock import MagicMock, patch

from torchx.cli.cmd_run import CmdBuiltins, CmdRun, _parse_run_config, logger
from torchx.cli.cmd_run import (
CmdBuiltins,
CmdRun,
_parse_component_name_and_args,
_parse_run_config,
logger,
)
from torchx.schedulers.local_scheduler import SignalException
from torchx.specs import runopts

@@ -198,6 +204,63 @@ def test_parse_runopts(self) -> None:
for k, v in expected_args.items():
self.assertEqual(v, runconfig.get(k))

def test_parse_component_name_and_args_no_default(self) -> None:
sp = argparse.ArgumentParser(prog="test")
self.assertEqual(
("utils.echo", []),
_parse_component_name_and_args(["utils.echo"], sp),
)
self.assertEqual(
("utils.echo", []),
_parse_component_name_and_args(["--", "utils.echo"], sp),
)
self.assertEqual(
("utils.echo", ["--msg", "hello"]),
_parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp),
)

with self.assertRaises(SystemExit):
_parse_component_name_and_args(["--msg", "hello"], sp)

with self.assertRaises(SystemExit):
_parse_component_name_and_args(["-m", "hello"], sp)

def test_parse_component_name_and_args_with_default(self) -> None:
sp = argparse.ArgumentParser(prog="test")
dirs = [str(self.tmpdir)]

with open(Path(self.tmpdir) / ".torchxconfig", "w") as f:
f.write(
"""#
[cli:run]
component = custom.echo
"""
)

self.assertEqual(
("utils.echo", []), _parse_component_name_and_args(["utils.echo"], sp, dirs)
)
self.assertEqual(
("utils.echo", ["--msg", "hello"]),
_parse_component_name_and_args(["utils.echo", "--msg", "hello"], sp, dirs),
)
self.assertEqual(
("custom.echo", []),
_parse_component_name_and_args([], sp, dirs),
)
self.assertEqual(
("custom.echo", ["--msg", "hello"]),
_parse_component_name_and_args(["--", "--msg", "hello"], sp, dirs),
)
self.assertEqual(
("custom.echo", ["--msg", "hello"]),
_parse_component_name_and_args(["--msg", "hello"], sp, dirs),
)
self.assertEqual(
("custom.echo", ["-m", "hello"]),
_parse_component_name_and_args(["-m", "hello"], sp, dirs),
)


class CmdBuiltinTest(unittest.TestCase):
def test_run(self) -> None:
17 changes: 17 additions & 0 deletions torchx/examples/apps/.torchxconfig
@@ -20,3 +20,20 @@ enableGracefulPreemption = False
secure_group = pytorch_r2p
entitlement = default
proxy_workflow_image = None

[cli:run]
component = fb.dist.hpc

# TODO need to add hydra to bento_kernel_torchx and make that the default img
[component:fb.dist.ddp]
img = bento_kernel_pytorch_lightning
m = fb/compute_world_size/main.py

[component:fb.dist.ddp2]
img = bento_kernel_pytorch_lightning
m = fb/compute_world_size/main.py

[component:fb.dist.hpc]
img = bento_kernel_pytorch_lightning
m = fb/compute_world_size/main.py

18 changes: 15 additions & 3 deletions torchx/runner/api.py
@@ -50,6 +50,7 @@ def __init__(
self,
name: str,
schedulers: Dict[SchedulerBackend, Scheduler],
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
) -> None:
"""
Creates a new runner instance.
@@ -63,6 +64,9 @@ def __init__(
self._schedulers = schedulers
self._apps: Dict[AppHandle, AppDef] = {}

# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}

def __enter__(self) -> "Runner":
return self

@@ -147,7 +151,11 @@ def dryrun_component(
component, but just returns what "would" have run.
"""
component_def = get_component(component)
app = from_function(component_def.fn, component_args)
app = from_function(
component_def.fn,
component_args,
self._component_defaults.get(component, None),
)
return self.dryrun(app, scheduler, cfg)

def run(
Expand Down Expand Up @@ -521,7 +529,11 @@ def __repr__(self) -> str:
return f"Runner(name={self._name}, schedulers={self._schedulers}, apps={self._apps})"


def get_runner(name: Optional[str] = None, **scheduler_params: Any) -> Runner:
def get_runner(
name: Optional[str] = None,
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
**scheduler_params: Any,
) -> Runner:
"""
Convenience method to construct and get a Runner object. Usage:
@@ -554,4 +566,4 @@ def get_runner(name: Optional[str] = None, **scheduler_params: Any) -> Runner:
name = "torchx"

schedulers = get_schedulers(session_name=name, **scheduler_params)
return Runner(name, schedulers)
return Runner(name, schedulers, component_defaults)