@@ -119,39 +119,40 @@ def main(launcher_agent_group: LauncherAgentGroup):
     worker_log_files = launcher_payload.worker_log_files[agent_rank]
     num_workers = len(worker_global_ranks)
 
-    if torch.__version__ >= '2.3':
+    if torch.__version__ >= "2.3":
         # DefaultLogsSpecs only exists in torch >= 2.3
         from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
+
         log_arg = DefaultLogsSpecs(log_dir=tempfile.mkdtemp())
     else:
         log_arg = tempfile.mkdtemp()
 
     # spawn workers
-
+
     ctx = start_processes(
-        f"{hostname}_",
-        entrypoint,
-        {
-            i: (
-                WorkerArgs(
-                    function=launcher_payload.fn,
-                    master_hostname=main_agent_payload.hostname,
-                    master_port=main_agent_payload.port,
-                    backend=launcher_payload.backend,
-                    rank=worker_global_ranks[i],
-                    local_rank=i,
-                    local_world_size=num_workers,
-                    world_size=worker_world_size,
-                    log_file=worker_log_files[i],
-                    timeout=launcher_payload.timeout,
-                ).to_bytes(),
-            )
-            for i in range(num_workers)
-        },
-        {i: {} for i in range(num_workers)},
-        log_arg  # type: ignore
+        f"{hostname}_",
+        entrypoint,
+        {
+            i: (
+                WorkerArgs(
+                    function=launcher_payload.fn,
+                    master_hostname=main_agent_payload.hostname,
+                    master_port=main_agent_payload.port,
+                    backend=launcher_payload.backend,
+                    rank=worker_global_ranks[i],
+                    local_rank=i,
+                    local_world_size=num_workers,
+                    world_size=worker_world_size,
+                    log_file=worker_log_files[i],
+                    timeout=launcher_payload.timeout,
+                ).to_bytes(),
             )
-
+            for i in range(num_workers)
+        },
+        {i: {} for i in range(num_workers)},
+        log_arg,  # type: ignore
+    )
+
     try:
         status = AgentStatus()
         while True:
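Context on the version gate in this hunk: from torch 2.3 onward, the elastic `start_processes` API takes a `LogsSpecs` object (e.g. `DefaultLogsSpecs`) in the positional slot that older releases used for a plain `log_dir` path, which is why `log_arg` changes type with the installed torch version. Below is a minimal, self-contained sketch of that call shape; the worker function, name prefix, and the single-rank args/env dicts are illustrative placeholders, not taken from this diff.

```python
import tempfile

import torch
from torch.distributed.elastic.multiprocessing import start_processes


def demo_worker() -> None:
    print("hello from a spawned worker")


if __name__ == "__main__":
    if torch.__version__ >= "2.3":
        # torch >= 2.3: the logging argument is a LogsSpecs object
        from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs

        log_arg = DefaultLogsSpecs(log_dir=tempfile.mkdtemp())
    else:
        # torch < 2.3: the same positional slot takes a log directory path
        log_arg = tempfile.mkdtemp()

    ctx = start_processes(
        "demo_",      # name prefix for the worker processes
        demo_worker,  # entrypoint (a picklable module-level callable)
        {0: ()},      # positional args per local rank
        {0: {}},      # extra environment per local rank
        log_arg,      # type: ignore[arg-type]
    )
    ctx.wait()  # block until the worker exits
```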