Skip to content

Commit 79fb0a8

Browse files
committed
LaunchResult, first draft
1 parent f913221 commit 79fb0a8

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

src/torchrunx/launcher.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,14 @@
1414
from logging import Handler
1515
from multiprocessing import Process
1616
from pathlib import Path
17-
from typing import Any, Callable, Literal
17+
from typing import Any, Callable, Literal, overload
1818

1919
import fabric
2020
import torch.distributed as dist
2121

2222
from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
2323
from .logging_utils import LogRecordSocketReceiver, default_handlers
24-
from .utils import (
25-
LauncherAgentGroup,
26-
LauncherPayload,
27-
WorkerException,
28-
get_open_port,
29-
)
24+
from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
3025

3126

3227
def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
@@ -180,7 +175,7 @@ def run( # noqa: C901, PLR0912
180175
func: Callable,
181176
func_args: tuple[Any] | None = None,
182177
func_kwargs: dict[str, Any] | None = None,
183-
) -> dict[str, dict[int, Any]]:
178+
) -> LaunchResult:
184179
"""
185180
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
186181
@@ -309,10 +304,7 @@ def run( # noqa: C901, PLR0912
309304
ssh_config_file=self.ssh_config_file,
310305
)
311306

312-
return {
313-
hostname: agent_status.return_values
314-
for hostname, agent_status in zip(hostnames, agent_statuses)
315-
}
307+
return LaunchResult(hostnames=hostnames, agent_statuses=agent_statuses)
316308

317309

318310
def launch(
@@ -336,7 +328,7 @@ def launch(
336328
),
337329
env_file: str | os.PathLike | None = None,
338330
timeout: int = 600,
339-
) -> dict[str, dict[int, Any]]:
331+
) -> LaunchResult:
340332
"""
341333
Launch a distributed PyTorch function on the specified nodes.
342334
@@ -378,3 +370,26 @@ def launch(
378370
env_file=env_file,
379371
timeout=timeout,
380372
).run(func=func, func_args=func_args, func_kwargs=func_kwargs)
373+
374+
375+
class LaunchResult:
376+
def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None:
377+
self.results = {
378+
hostname: agent_status.return_values
379+
for hostname, agent_status in zip(hostnames, agent_statuses)
380+
}
381+
382+
def all(self) -> dict[str, list[Any]]:
383+
return self.results
384+
385+
# all(by='rank')
386+
387+
# value(rank: int)
388+
389+
@overload
390+
def value(self, hostname: str) -> list[Any]:
391+
return list(self.results[hostname].values())
392+
393+
@overload
394+
def value(self, hostname: str, rank: int) -> Any:
395+
return self.results[hostname][rank]

0 commit comments

Comments
 (0)