@@ -203,79 +203,83 @@ def run(
203
203
launcher_port = get_open_port ()
204
204
world_size = len (hostnames ) + 1
205
205
206
- # start logging server
206
+ log_receiver = None
207
+ log_process = None
208
+ launcher_agent_group = None
207
209
208
- log_receiver = build_logging_server (
209
- log_handlers = self .log_handlers ,
210
- launcher_hostname = launcher_hostname ,
211
- hostnames = hostnames ,
212
- workers_per_host = workers_per_host ,
213
- log_dir = Path (os .environ .get ("TORCHRUNX_LOG_DIR" , "torchrunx_logs" )),
214
- log_level = logging ._nameToLevel [os .environ .get ("TORCHRUNX_LOG_LEVEL" , "INFO" )], # noqa: SLF001
215
- )
216
-
217
- log_process = Process (
218
- target = log_receiver .serve_forever ,
219
- daemon = True ,
220
- )
210
+ try :
211
+ # start logging server
212
+
213
+ log_receiver = build_logging_server (
214
+ log_handlers = self .log_handlers ,
215
+ launcher_hostname = launcher_hostname ,
216
+ hostnames = hostnames ,
217
+ workers_per_host = workers_per_host ,
218
+ log_dir = Path (os .environ .get ("TORCHRUNX_LOG_DIR" , "torchrunx_logs" )),
219
+ log_level = logging ._nameToLevel [os .environ .get ("TORCHRUNX_LOG_LEVEL" , "INFO" )], # noqa: SLF001
220
+ )
221
221
222
- log_process .start ()
223
-
224
- # start agents on each node
225
-
226
- for i , hostname in enumerate (hostnames ):
227
- execute_command (
228
- command = build_command (
229
- launcher_hostname = launcher_hostname ,
230
- launcher_port = launcher_port ,
231
- logger_port = log_receiver .port ,
232
- world_size = world_size ,
233
- rank = i + 1 ,
234
- env_vars = self .env_vars ,
235
- env_file = self .env_file ,
236
- ),
237
- hostname = hostname ,
238
- ssh_config_file = self .ssh_config_file ,
222
+ log_process = Process (
223
+ target = log_receiver .serve_forever ,
224
+ daemon = True ,
239
225
)
240
226
241
- # initialize launcher-agent process group
242
- # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])
227
+ log_process .start ()
243
228
244
- launcher_agent_group = LauncherAgentGroup (
245
- launcher_hostname = launcher_hostname ,
246
- launcher_port = launcher_port ,
247
- world_size = world_size ,
248
- rank = 0 ,
249
- )
229
+ # start agents on each node
250
230
251
- # build and sync payloads between launcher and agents
231
+ for i , hostname in enumerate (hostnames ):
232
+ execute_command (
233
+ command = build_command (
234
+ launcher_hostname = launcher_hostname ,
235
+ launcher_port = launcher_port ,
236
+ logger_port = log_receiver .port ,
237
+ world_size = world_size ,
238
+ rank = i + 1 ,
239
+ env_vars = self .env_vars ,
240
+ env_file = self .env_file ,
241
+ ),
242
+ hostname = hostname ,
243
+ ssh_config_file = self .ssh_config_file ,
244
+ )
252
245
253
- _cumulative_workers = [0 , * itertools .accumulate (workers_per_host )]
246
+ # initialize launcher-agent process group
247
+ # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])
254
248
255
- worker_global_ranks = [
256
- list (range (_cumulative_workers [n ], _cumulative_workers [n + 1 ]))
257
- for n in range (len (hostnames ))
258
- ]
249
+ launcher_agent_group = LauncherAgentGroup (
250
+ launcher_hostname = launcher_hostname ,
251
+ launcher_port = launcher_port ,
252
+ world_size = world_size ,
253
+ rank = 0 ,
254
+ )
259
255
260
- payload = LauncherPayload (
261
- fn = partial (func , * (func_args or ()), ** (func_kwargs or {})),
262
- hostnames = hostnames ,
263
- worker_global_ranks = worker_global_ranks ,
264
- worker_world_size = sum (workers_per_host ),
265
- backend = self .backend ,
266
- timeout = self .timeout ,
267
- )
256
+ # build and sync payloads between launcher and agents
268
257
269
- launcher_payload , agent_payloads = launcher_agent_group . sync_payloads ( payload = payload )
258
+ _cumulative_workers = [ 0 , * itertools . accumulate ( workers_per_host )]
270
259
271
- # loop to monitor agent statuses (until failed or done)
260
+ worker_global_ranks = [
261
+ list (range (_cumulative_workers [n ], _cumulative_workers [n + 1 ]))
262
+ for n in range (len (hostnames ))
263
+ ]
264
+
265
+ payload = LauncherPayload (
266
+ fn = partial (func , * (func_args or ()), ** (func_kwargs or {})),
267
+ hostnames = hostnames ,
268
+ worker_global_ranks = worker_global_ranks ,
269
+ worker_world_size = sum (workers_per_host ),
270
+ backend = self .backend ,
271
+ timeout = self .timeout ,
272
+ )
273
+
274
+ launcher_payload , agent_payloads = launcher_agent_group .sync_payloads (payload = payload )
275
+
276
+ # loop to monitor agent statuses (until failed or done)
272
277
273
- try :
274
278
while True :
275
- # raises exception if communication timeout due to death of any agent
279
+ # raises RuntimeError if communication timeout due to death of any agent
276
280
agent_statuses = launcher_agent_group .sync_agent_statuses (status = None )
277
281
278
- # raises exception if any agent failed
282
+ # raises specific exception if any agent fails
279
283
for s in agent_statuses :
280
284
for value in s .return_values .values ():
281
285
if isinstance (value , WorkerException ):
@@ -294,10 +298,13 @@ def run(
294
298
)
295
299
raise
296
300
finally :
297
- log_receiver .shutdown ()
298
- log_receiver .server_close ()
299
- log_process .kill ()
300
- dist .destroy_process_group ()
301
+ if log_receiver is not None :
302
+ log_receiver .shutdown ()
303
+ log_receiver .server_close ()
304
+ if log_process is not None :
305
+ log_process .kill ()
306
+ if launcher_agent_group is not None :
307
+ launcher_agent_group .shutdown ()
301
308
302
309
return {
303
310
hostname : agent_status .return_values
0 commit comments