Skip to content

Commit 1514090

Browse files
committed
handle older PyTorch versions
1 parent 0bd7964 commit 1514090

File tree

1 file changed

+34
-31
lines changed

1 file changed

+34
-31
lines changed

src/torchrunx/agent.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import os
55
import socket
66
import sys
7+
import tempfile
78
from dataclasses import dataclass
89
from typing import Callable, Literal
910

1011
import cloudpickle
1112
import torch
1213
import torch.distributed as dist
13-
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
14-
from torch.distributed.elastic.multiprocessing.api import MultiprocessContext, Std
14+
from torch.distributed.elastic.multiprocessing import start_processes
1515
from typing_extensions import Self
1616

1717
from .utils import (
@@ -108,7 +108,7 @@ def main(launcher_agent_group: LauncherAgentGroup):
108108
port=get_open_port(),
109109
process_id=os.getpid(),
110110
)
111-
111+
# DefaultLogsSpecs(log_dir=None, tee=Std.ALL, local_ranks_filter={0}),
112112
all_payloads = launcher_agent_group.sync_payloads(payload=payload)
113113
launcher_payload: LauncherPayload = all_payloads[0] # pyright: ignore[reportAssignmentType]
114114
main_agent_payload: AgentPayload = all_payloads[1] # pyright: ignore[reportAssignmentType]
@@ -119,36 +119,40 @@ def main(launcher_agent_group: LauncherAgentGroup):
119119
worker_log_files = launcher_payload.worker_log_files[agent_rank]
120120
num_workers = len(worker_global_ranks)
121121

122-
# spawn workers
122+
if torch.__version__ > '2.2':
123+
# DefaultLogsSpecs only exists in torch >= 2.3
124+
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
125+
log_arg = DefaultLogsSpecs(log_dir=tempfile.mkdtemp())
126+
else:
127+
log_arg = tempfile.mkdtemp()
123128

124-
ctx = MultiprocessContext(
125-
name=f"{hostname}_",
126-
entrypoint=entrypoint,
127-
args={
128-
i: (
129-
WorkerArgs(
130-
function=launcher_payload.fn,
131-
master_hostname=main_agent_payload.hostname,
132-
master_port=main_agent_payload.port,
133-
backend=launcher_payload.backend,
134-
rank=worker_global_ranks[i],
135-
local_rank=i,
136-
local_world_size=num_workers,
137-
world_size=worker_world_size,
138-
log_file=worker_log_files[i],
139-
timeout=launcher_payload.timeout,
140-
).to_bytes(),
129+
# spawn workers
130+
131+
ctx = start_processes(
132+
f"{hostname}_",
133+
entrypoint,
134+
{
135+
i: (
136+
WorkerArgs(
137+
function=launcher_payload.fn,
138+
master_hostname=main_agent_payload.hostname,
139+
master_port=main_agent_payload.port,
140+
backend=launcher_payload.backend,
141+
rank=worker_global_ranks[i],
142+
local_rank=i,
143+
local_world_size=num_workers,
144+
world_size=worker_world_size,
145+
log_file=worker_log_files[i],
146+
timeout=launcher_payload.timeout,
147+
).to_bytes(),
148+
)
149+
for i in range(num_workers)
150+
},
151+
{i: {} for i in range(num_workers)},
152+
log_arg # type: ignore
141153
)
142-
for i in range(num_workers)
143-
},
144-
envs={i: {} for i in range(num_workers)},
145-
logs_specs=DefaultLogsSpecs(log_dir=None, tee=Std.ALL, local_ranks_filter={0}),
146-
start_method="spawn",
147-
)
148-
154+
149155
try:
150-
ctx.start()
151-
152156
status = AgentStatus()
153157
while True:
154158
if status.is_running():
@@ -163,7 +167,6 @@ def main(launcher_agent_group: LauncherAgentGroup):
163167

164168
if any(s.is_failed() for s in agent_statuses):
165169
raise RuntimeError()
166-
167170
except:
168171
raise
169172
finally:

0 commit comments

Comments
 (0)