Skip to content

Commit 23aaae6

Browse files
committed
larger try-catch wrapper in Launcher
1 parent b8c68d6 commit 23aaae6

File tree

2 files changed

+74
-62
lines changed

2 files changed

+74
-62
lines changed

src/torchrunx/launcher.py

Lines changed: 69 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -203,79 +203,83 @@ def run(
203203
launcher_port = get_open_port()
204204
world_size = len(hostnames) + 1
205205

206-
# start logging server
206+
log_receiver = None
207+
log_process = None
208+
launcher_agent_group = None
207209

208-
log_receiver = build_logging_server(
209-
log_handlers=self.log_handlers,
210-
launcher_hostname=launcher_hostname,
211-
hostnames=hostnames,
212-
workers_per_host=workers_per_host,
213-
log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")),
214-
log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001
215-
)
216-
217-
log_process = Process(
218-
target=log_receiver.serve_forever,
219-
daemon=True,
220-
)
210+
try:
211+
# start logging server
212+
213+
log_receiver = build_logging_server(
214+
log_handlers=self.log_handlers,
215+
launcher_hostname=launcher_hostname,
216+
hostnames=hostnames,
217+
workers_per_host=workers_per_host,
218+
log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")),
219+
log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001
220+
)
221221

222-
log_process.start()
223-
224-
# start agents on each node
225-
226-
for i, hostname in enumerate(hostnames):
227-
execute_command(
228-
command=build_command(
229-
launcher_hostname=launcher_hostname,
230-
launcher_port=launcher_port,
231-
logger_port=log_receiver.port,
232-
world_size=world_size,
233-
rank=i + 1,
234-
env_vars=self.env_vars,
235-
env_file=self.env_file,
236-
),
237-
hostname=hostname,
238-
ssh_config_file=self.ssh_config_file,
222+
log_process = Process(
223+
target=log_receiver.serve_forever,
224+
daemon=True,
239225
)
240226

241-
# initialize launcher-agent process group
242-
# ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])
227+
log_process.start()
243228

244-
launcher_agent_group = LauncherAgentGroup(
245-
launcher_hostname=launcher_hostname,
246-
launcher_port=launcher_port,
247-
world_size=world_size,
248-
rank=0,
249-
)
229+
# start agents on each node
250230

251-
# build and sync payloads between launcher and agents
231+
for i, hostname in enumerate(hostnames):
232+
execute_command(
233+
command=build_command(
234+
launcher_hostname=launcher_hostname,
235+
launcher_port=launcher_port,
236+
logger_port=log_receiver.port,
237+
world_size=world_size,
238+
rank=i + 1,
239+
env_vars=self.env_vars,
240+
env_file=self.env_file,
241+
),
242+
hostname=hostname,
243+
ssh_config_file=self.ssh_config_file,
244+
)
252245

253-
_cumulative_workers = [0, *itertools.accumulate(workers_per_host)]
246+
# initialize launcher-agent process group
247+
# ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])
254248

255-
worker_global_ranks = [
256-
list(range(_cumulative_workers[n], _cumulative_workers[n + 1]))
257-
for n in range(len(hostnames))
258-
]
249+
launcher_agent_group = LauncherAgentGroup(
250+
launcher_hostname=launcher_hostname,
251+
launcher_port=launcher_port,
252+
world_size=world_size,
253+
rank=0,
254+
)
259255

260-
payload = LauncherPayload(
261-
fn=partial(func, *(func_args or ()), **(func_kwargs or {})),
262-
hostnames=hostnames,
263-
worker_global_ranks=worker_global_ranks,
264-
worker_world_size=sum(workers_per_host),
265-
backend=self.backend,
266-
timeout=self.timeout,
267-
)
256+
# build and sync payloads between launcher and agents
268257

269-
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
258+
_cumulative_workers = [0, *itertools.accumulate(workers_per_host)]
270259

271-
# loop to monitor agent statuses (until failed or done)
260+
worker_global_ranks = [
261+
list(range(_cumulative_workers[n], _cumulative_workers[n + 1]))
262+
for n in range(len(hostnames))
263+
]
264+
265+
payload = LauncherPayload(
266+
fn=partial(func, *(func_args or ()), **(func_kwargs or {})),
267+
hostnames=hostnames,
268+
worker_global_ranks=worker_global_ranks,
269+
worker_world_size=sum(workers_per_host),
270+
backend=self.backend,
271+
timeout=self.timeout,
272+
)
273+
274+
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
275+
276+
# loop to monitor agent statuses (until failed or done)
272277

273-
try:
274278
while True:
275-
# raises exception if communication timeout due to death of any agent
279+
# raises RuntimeError if communication timeout due to death of any agent
276280
agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
277281

278-
# raises exception if any agent failed
282+
# raises specific exception if any agent fails
279283
for s in agent_statuses:
280284
for value in s.return_values.values():
281285
if isinstance(value, WorkerException):
@@ -294,10 +298,13 @@ def run(
294298
)
295299
raise
296300
finally:
297-
log_receiver.shutdown()
298-
log_receiver.server_close()
299-
log_process.kill()
300-
dist.destroy_process_group()
301+
if log_receiver is not None:
302+
log_receiver.shutdown()
303+
log_receiver.server_close()
304+
if log_process is not None:
305+
log_process.kill()
306+
if launcher_agent_group is not None:
307+
launcher_agent_group.shutdown()
301308

302309
return {
303310
hostname: agent_status.return_values

src/torchrunx/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LauncherAgentGroup:
7373
rank: int
7474

7575
def __post_init__(self) -> None:
76+
# timeout will raise torch.distributed.DistStoreError
7677
self.group = dist.init_process_group(
7778
backend="gloo",
7879
world_size=self.world_size,
@@ -96,6 +97,7 @@ def _all_gather(self, obj: Any) -> list:
9697
"""gather object from every rank to list on every rank"""
9798
object_bytes = self._serialize(obj)
9899
object_list = [b""] * self.world_size
100+
# raises RuntimeError if timeout
99101
dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
100102
return [self._deserialize(o) for o in object_list]
101103

@@ -110,3 +112,6 @@ def sync_payloads(
110112

111113
def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]:
112114
return self._all_gather(status)[1:] # [0] is launcher (status=None)
115+
116+
def shutdown(self) -> None:
117+
dist.destroy_process_group(group=self.group)

0 commit comments

Comments
 (0)