4
4
import os
5
5
import socket
6
6
import sys
7
+ import tempfile
7
8
from dataclasses import dataclass
8
9
from typing import Callable , Literal
9
10
10
11
import cloudpickle
11
12
import torch
12
13
import torch .distributed as dist
13
- from torch .distributed .elastic .multiprocessing import DefaultLogsSpecs
14
- from torch .distributed .elastic .multiprocessing .api import MultiprocessContext , Std
14
+ from torch .distributed .elastic .multiprocessing import start_processes
15
15
from typing_extensions import Self
16
16
17
17
from .utils import (
@@ -108,7 +108,7 @@ def main(launcher_agent_group: LauncherAgentGroup):
108
108
port = get_open_port (),
109
109
process_id = os .getpid (),
110
110
)
111
-
111
+ # DefaultLogsSpecs(log_dir=None, tee=Std.ALL, local_ranks_filter={0}),
112
112
all_payloads = launcher_agent_group .sync_payloads (payload = payload )
113
113
launcher_payload : LauncherPayload = all_payloads [0 ] # pyright: ignore[reportAssignmentType]
114
114
main_agent_payload : AgentPayload = all_payloads [1 ] # pyright: ignore[reportAssignmentType]
@@ -119,36 +119,40 @@ def main(launcher_agent_group: LauncherAgentGroup):
119
119
worker_log_files = launcher_payload .worker_log_files [agent_rank ]
120
120
num_workers = len (worker_global_ranks )
121
121
122
- # spawn workers
122
+ if torch .__version__ > '2.2' :
123
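+ # DefaultLogsSpecs only exists in torch >= 2.3. NOTE(review): the check above compares version strings lexicographically, so '2.2.0' > '2.2' is True (torch 2.2.x would attempt this import) and '2.10.0' > '2.2' is False; compare parsed version numbers instead.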
124
+ from torch .distributed .elastic .multiprocessing import DefaultLogsSpecs
125
+ log_arg = DefaultLogsSpecs (log_dir = tempfile .mkdtemp ())
126
+ else :
127
+ log_arg = tempfile .mkdtemp ()
123
128
124
- ctx = MultiprocessContext (
125
- name = f"{ hostname } _" ,
126
- entrypoint = entrypoint ,
127
- args = {
128
- i : (
129
- WorkerArgs (
130
- function = launcher_payload .fn ,
131
- master_hostname = main_agent_payload .hostname ,
132
- master_port = main_agent_payload .port ,
133
- backend = launcher_payload .backend ,
134
- rank = worker_global_ranks [i ],
135
- local_rank = i ,
136
- local_world_size = num_workers ,
137
- world_size = worker_world_size ,
138
- log_file = worker_log_files [i ],
139
- timeout = launcher_payload .timeout ,
140
- ).to_bytes (),
129
+ # spawn workers
130
+
131
+ ctx = start_processes (
132
+ f"{ hostname } _" ,
133
+ entrypoint ,
134
+ {
135
+ i : (
136
+ WorkerArgs (
137
+ function = launcher_payload .fn ,
138
+ master_hostname = main_agent_payload .hostname ,
139
+ master_port = main_agent_payload .port ,
140
+ backend = launcher_payload .backend ,
141
+ rank = worker_global_ranks [i ],
142
+ local_rank = i ,
143
+ local_world_size = num_workers ,
144
+ world_size = worker_world_size ,
145
+ log_file = worker_log_files [i ],
146
+ timeout = launcher_payload .timeout ,
147
+ ).to_bytes (),
148
+ )
149
+ for i in range (num_workers )
150
+ },
151
+ {i : {} for i in range (num_workers )},
152
+ log_arg # type: ignore
141
153
)
142
- for i in range (num_workers )
143
- },
144
- envs = {i : {} for i in range (num_workers )},
145
- logs_specs = DefaultLogsSpecs (log_dir = None , tee = Std .ALL , local_ranks_filter = {0 }),
146
- start_method = "spawn" ,
147
- )
148
-
154
+
149
155
try :
150
- ctx .start ()
151
-
152
156
status = AgentStatus ()
153
157
while True :
154
158
if status .is_running ():
@@ -163,7 +167,6 @@ def main(launcher_agent_group: LauncherAgentGroup):
163
167
164
168
if any (s .is_failed () for s in agent_statuses ):
165
169
raise RuntimeError ()
166
-
167
170
except :
168
171
raise
169
172
finally :
0 commit comments