Skip to content

Commit 5e585f9

Browse files
committed
format
1 parent 6507dcf commit 5e585f9

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

src/torchrunx/agent.py

Lines changed: 25 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -119,39 +119,40 @@ def main(launcher_agent_group: LauncherAgentGroup):
119119
worker_log_files = launcher_payload.worker_log_files[agent_rank]
120120
num_workers = len(worker_global_ranks)
121121

122-
if torch.__version__ >= '2.3':
122+
if torch.__version__ >= "2.3":
123123
# DefaultLogsSpecs only exists in torch >= 2.3
124124
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
125+
125126
log_arg = DefaultLogsSpecs(log_dir=tempfile.mkdtemp())
126127
else:
127128
log_arg = tempfile.mkdtemp()
128129

129130
# spawn workers
130-
131+
131132
ctx = start_processes(
132-
f"{hostname}_",
133-
entrypoint,
134-
{
135-
i: (
136-
WorkerArgs(
137-
function=launcher_payload.fn,
138-
master_hostname=main_agent_payload.hostname,
139-
master_port=main_agent_payload.port,
140-
backend=launcher_payload.backend,
141-
rank=worker_global_ranks[i],
142-
local_rank=i,
143-
local_world_size=num_workers,
144-
world_size=worker_world_size,
145-
log_file=worker_log_files[i],
146-
timeout=launcher_payload.timeout,
147-
).to_bytes(),
148-
)
149-
for i in range(num_workers)
150-
},
151-
{i: {} for i in range(num_workers)},
152-
log_arg # type: ignore
133+
f"{hostname}_",
134+
entrypoint,
135+
{
136+
i: (
137+
WorkerArgs(
138+
function=launcher_payload.fn,
139+
master_hostname=main_agent_payload.hostname,
140+
master_port=main_agent_payload.port,
141+
backend=launcher_payload.backend,
142+
rank=worker_global_ranks[i],
143+
local_rank=i,
144+
local_world_size=num_workers,
145+
world_size=worker_world_size,
146+
log_file=worker_log_files[i],
147+
timeout=launcher_payload.timeout,
148+
).to_bytes(),
153149
)
154-
150+
for i in range(num_workers)
151+
},
152+
{i: {} for i in range(num_workers)},
153+
log_arg, # type: ignore
154+
)
155+
155156
try:
156157
status = AgentStatus()
157158
while True:

0 commit comments

Comments (0)