14
14
from logging import Handler
15
15
from multiprocessing import Process
16
16
from pathlib import Path
17
- from typing import Any , Callable , Literal
17
+ from typing import Any , Callable , Literal , overload
18
18
19
19
import fabric
20
20
import torch .distributed as dist
21
21
22
22
from .environment import auto_hosts , auto_workers , slurm_hosts , slurm_workers
23
23
from .logging_utils import LogRecordSocketReceiver , default_handlers
24
- from .utils import (
25
- LauncherAgentGroup ,
26
- LauncherPayload ,
27
- WorkerException ,
28
- get_open_port ,
29
- )
24
+ from .utils import AgentStatus , LauncherAgentGroup , LauncherPayload , WorkerException , get_open_port
30
25
31
26
32
27
def resolve_hostnames (hostnames : list [str ] | Literal ["auto" , "slurm" ]) -> list [str ]:
@@ -180,7 +175,7 @@ def run( # noqa: C901, PLR0912
180
175
func : Callable ,
181
176
func_args : tuple [Any ] | None = None ,
182
177
func_kwargs : dict [str , Any ] | None = None ,
183
- ) -> dict [ str , dict [ int , Any ]] :
178
+ ) -> LaunchResult :
184
179
"""
185
180
Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
186
181
@@ -309,10 +304,7 @@ def run( # noqa: C901, PLR0912
309
304
ssh_config_file = self .ssh_config_file ,
310
305
)
311
306
312
- return {
313
- hostname : agent_status .return_values
314
- for hostname , agent_status in zip (hostnames , agent_statuses )
315
- }
307
+ return LaunchResult (hostnames = hostnames , agent_statuses = agent_statuses )
316
308
317
309
318
310
def launch (
@@ -336,7 +328,7 @@ def launch(
336
328
),
337
329
env_file : str | os .PathLike | None = None ,
338
330
timeout : int = 600 ,
339
- ) -> dict [ str , dict [ int , Any ]] :
331
+ ) -> LaunchResult :
340
332
"""
341
333
Launch a distributed PyTorch function on the specified nodes.
342
334
@@ -378,3 +370,26 @@ def launch(
378
370
env_file = env_file ,
379
371
timeout = timeout ,
380
372
).run (func = func , func_args = func_args , func_kwargs = func_kwargs )
373
+
374
+
375
+ class LaunchResult :
376
+ def __init__ (self , hostnames : list [str ], agent_statuses : list [AgentStatus ]) -> None :
377
+ self .results = {
378
+ hostname : agent_status .return_values
379
+ for hostname , agent_status in zip (hostnames , agent_statuses )
380
+ }
381
+
382
+ def all (self ) -> dict [str , list [Any ]]:
383
+ return self .results
384
+
385
+ # all(by='rank')
386
+
387
+ # value(rank: int)
388
+
389
+ @overload
390
+ def value (self , hostname : str ) -> list [Any ]:
391
+ return list (self .results [hostname ].values ())
392
+
393
+ @overload
394
+ def value (self , hostname : str , rank : int ) -> Any :
395
+ return self .results [hostname ][rank ]
0 commit comments