 from typing import Dict

 import torch
-from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
-from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
-from monarch.actor import Actor, current_rank, endpoint, ProcMesh, this_host
-from monarch.tools import commands
-from monarch.tools.components import hyperactor
-from monarch.tools.config import Config
+from monarch.actor import Actor, current_rank, endpoint, HostMesh, ProcMesh, this_host
+from monarch.job import SlurmJob
 from monarch.utils import setup_env_for_distributed
 from torchtitan.config import ConfigManager, JobConfig
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.train import Trainer
 from utils.failure import Failure, FailureActor, FailureController


-# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
+# ==== Allocation boilerplate ====
 class MonarchSlurm:
-    # Cluster Configuration - update these values for your specific cluster
-    machine: str = "gpu.xlarge"
-    machine_memory: int = 2062607
     job_name_prefix: str = "monarch-torchft"

     def __init__(self):
-        self.job_handles: Dict[str, str] = {}
+        self.job_handles: Dict[str, SlurmJob] = {}
         atexit.register(self.kill_jobs)

-    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
-        mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        # to enable relative import of utils on actors
-        current_dir = os.path.dirname(os.path.abspath(__file__))
-        env = {"PYTHONPATH": current_dir}
-
-        appdef = hyperactor.host_mesh(meshes=mesh, env=env)
-
-        for role in appdef.roles:
-            role.resource.memMB = MonarchSlurm.machine_memory
-
-        return Config(scheduler="slurm", appdef=appdef)
-
-    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = self.get_config(mesh_name, nodes_per_mesh)
-        job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
-        server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        self.job_handles[mesh_name] = server_spec.name
+    async def get_or_create_job(
+        self, mesh_name: str, nodes_per_mesh: int = 1, gpus_per_node: int = 8
+    ) -> None:
+        job = SlurmJob(
+            meshes={mesh_name: nodes_per_mesh},
+            gpus_per_node=gpus_per_node,
+            job_name=f"{self.job_name_prefix}-{mesh_name}",
+        )
+        job.apply()
+        self.job_handles[mesh_name] = job

     def kill_jobs(self):
         for mesh_name in self.job_handles.keys():
             self.kill_job(mesh_name)

     def kill_job(self, mesh_name: str):
         try:
-            job_handle = self.job_handles[mesh_name]
+            job = self.job_handles[mesh_name]
             logger.info(f"Destroying job for mesh {mesh_name}")
-            commands.kill(f"slurm:///{job_handle}")
+            job.kill()
         except Exception as e:
-            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
-
-    def proc_mesh(
-        self,
-        mesh_name: str,
-        num_hosts: int = 1,
-        num_gpus: int = 8,
-    ) -> ProcMesh:
-        allocator = RemoteAllocator(
-            world_id=MonarchSlurm.job_name_prefix,
-            initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{self.job_handles[mesh_name]}"
-            ),
-        )
-        alloc = allocator.allocate(
-            AllocSpec(AllocConstraints(), hosts=num_hosts, gpus=num_gpus)
-        )
+            logger.exception(f"Failed to destroy job for {mesh_name}: {e}")

-        return ProcMesh.from_alloc(alloc)
+    def proc_mesh(self, mesh_name: str, num_procs: int) -> ProcMesh:
+        job = self.job_handles[mesh_name]
+        mesh: HostMesh = getattr(job.state(cached_path=None), mesh_name)
+        proc_mesh = mesh.spawn_procs({"gpus": num_procs})
+        return proc_mesh


 # ==== allocation boilerplate ====
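For orientation, a minimal usage sketch of the reworked helper, end to end. The mesh name and the node/GPU counts are illustrative assumptions, not values taken from this change:

import asyncio

async def demo() -> None:
    scheduler = MonarchSlurm()
    # Request (or reuse) a Slurm job backing a host mesh named "replica_0" (assumed name).
    await scheduler.get_or_create_job("replica_0", nodes_per_mesh=1, gpus_per_node=8)
    # Spawn one process per GPU on that mesh; the async context manager stops the
    # processes on exit, and the atexit-registered kill_jobs tears down the Slurm job.
    procs = scheduler.proc_mesh("replica_0", num_procs=8)
    async with procs:
        await procs.logging_option(stream_to_client=True)
        await setup_env_for_distributed(procs)

asyncio.run(demo())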
@@ -177,13 +150,12 @@ async def start_replica(self) -> None:
         init_logger()
         logger.info(f"{self.uid} Spawning trainers")

-        trainers_proc_mesh: ProcMesh | None = None
-        try:
-            trainers_proc_mesh = self.scheduler.proc_mesh(
-                f"replica_{self.replica_id}",
-                self.spec.hosts_per_replica,
-                self.spec.gpus_per_node,
-            )
+        trainers_proc_mesh = self.scheduler.proc_mesh(
+            f"replica_{self.replica_id}",
+            num_procs=self.spec.gpus_per_node,
+        )
+
+        async with trainers_proc_mesh:
             await trainers_proc_mesh.logging_option(stream_to_client=True)
             await setup_env_for_distributed(trainers_proc_mesh)

@@ -200,11 +172,6 @@ async def start_replica(self) -> None:

             logger.info(f"{self.uid} Starting trainers")
             await training_actors.start_training.call(self.spec.lighthouse_address)
-            await trainers_proc_mesh.stop()
-        except Exception as e:
-            if trainers_proc_mesh:
-                await trainers_proc_mesh.stop()
-            raise e

     @endpoint
     async def inject_failure(self, failure_type: Failure):
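The `async with trainers_proc_mesh:` block above takes over the cleanup that the deleted try/except/stop code did by hand. A sketch of the equivalent control flow, assuming the ProcMesh async context manager stops its processes whether the body completes or raises (the helper name and the trainer step are placeholders):

async def run_replica_sketch(scheduler: MonarchSlurm, replica_id: int, num_procs: int) -> None:
    procs = scheduler.proc_mesh(f"replica_{replica_id}", num_procs=num_procs)
    try:
        await procs.logging_option(stream_to_client=True)
        await setup_env_for_distributed(procs)
        # ... spawn trainer actors and run training here ...
    finally:
        # What "async with procs:" does implicitly on exit, success or failure.
        await procs.stop()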
@@ -216,8 +183,7 @@ async def inject_failure(self, failure_type: Failure):

                 await self.failure_actors.fail.choose(failure_type)
             except Exception as e:
-                error_msg = f"{self.uid} Injected failure: {e}"
-                logger.error(error_msg)
+                logger.exception(f"{self.uid} Injected failure: {e}")
         else:
             error_msg = f"{self.uid} No failure actors available"
             logger.error(error_msg)
@@ -268,7 +234,7 @@ async def start_training(self) -> None:
     async def start_lighthouse(self) -> None:
         if self.spec.remote_lighthouse:
             await self.scheduler.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_procs=1)
         else:
             self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

@@ -287,7 +253,7 @@ async def stop_lighthouse(self) -> None:
             await self.lighthouse_mesh.stop()
             logger.info("[Controller] Lighthouse stopped")
         except Exception as e:
-            logger.warning(f"[Controller] Failed to stop lighthouse: {e}")
+            logger.exception(f"[Controller] Failed to stop lighthouse: {e}")

     async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
         if attempt_number >= MAX_ATTEMPT:
@@ -300,7 +266,7 @@ async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
             await self._teardown(replica_id)
         except Exception as e:
             await self._teardown(replica_id)
-            logger.info(f"[Controller] replica {replica_id} failed: {e}")
+            logger.exception(f"[Controller] replica {replica_id} failed: {e}")
             await self._run_replica(replica_id, attempt_number + 1)

     async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
@@ -332,11 +298,18 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
     async def _teardown(self, replica_id: int) -> None:
         try:
             replica = self.replicas[replica_id]
-            await replica.proc_mesh.stop()
+            try:
+                await replica.proc_mesh.stop()
+            except Exception as e:
+                logger.exception(
+                    f"[Controller] Failed to stop replica {replica_id}, it may already be stopped. {e}"
+                )
             del self.replicas[replica_id]
             del replica.proc_mesh
         except Exception as e:
-            logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")
+            logger.exception(
+                f"[Controller] Failed to teardown replica {replica_id}: {e}"
+            )


 # === CLI / CONFIG === #