Commit 002d09a

amirafzali authored and meta-codesync[bot] committed

move monarch to HostMeshV1 (#285)

Summary:
Pull Request resolved: #285

1. We can build on SlurmJob to simplify the allocation logic
2. Some small modifications to ensure HostMeshV1 support

Reviewed By: d4l3k, colin2328

Differential Revision: D84853200

fbshipit-source-id: 35ae0f547fa2f7c9ac871dff406101007d58cae6
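
In short, allocation now flows through SlurmJob and HostMeshV1 instead of TorchX server handles plus RemoteAllocator. A minimal sketch of the new flow, condensed from the diff below (the mesh name and the node/GPU counts are illustrative, not from the commit):

```python
from monarch.job import SlurmJob

# After this commit, SlurmJob owns the Slurm allocation; its state exposes
# one HostMesh per requested mesh name, and processes are spawned off that.
job = SlurmJob(
    meshes={"trainers": 1},  # mesh name -> node count
    gpus_per_node=8,
    job_name="monarch-torchft-trainers",
)
job.apply()  # submit (or adopt) the Slurm job
host_mesh = getattr(job.state(cached_path=None), "trainers")
proc_mesh = host_mesh.spawn_procs({"gpus": 8})  # one process per GPU
```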
1 parent e4d99b5 commit 002d09a

File tree

2 files changed: +43 −70

examples/monarch/train_distributed.py

Lines changed: 41 additions & 68 deletions
@@ -14,78 +14,51 @@
 from typing import Dict

 import torch
-from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
-from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
-from monarch.actor import Actor, current_rank, endpoint, ProcMesh, this_host
-from monarch.tools import commands
-from monarch.tools.components import hyperactor
-from monarch.tools.config import Config
+from monarch.actor import Actor, current_rank, endpoint, HostMesh, ProcMesh, this_host
+from monarch.job import SlurmJob
 from monarch.utils import setup_env_for_distributed
 from torchtitan.config import ConfigManager, JobConfig
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.train import Trainer
 from utils.failure import Failure, FailureActor, FailureController


-# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
+# ==== Allocation boilerplate ====
 class MonarchSlurm:
-    # Cluster Configuration - update these values for your specific cluster
-    machine: str = "gpu.xlarge"
-    machine_memory: int = 2062607
     job_name_prefix: str = "monarch-torchft"

     def __init__(self):
-        self.job_handles: Dict[str, str] = {}
+        self.job_handles: Dict[str, SlurmJob] = {}
         atexit.register(self.kill_jobs)

-    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
-        mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        # to enable relative import of utils on actors
-        current_dir = os.path.dirname(os.path.abspath(__file__))
-        env = {"PYTHONPATH": current_dir}
-
-        appdef = hyperactor.host_mesh(meshes=mesh, env=env)
-
-        for role in appdef.roles:
-            role.resource.memMB = MonarchSlurm.machine_memory
-
-        return Config(scheduler="slurm", appdef=appdef)
-
-    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = self.get_config(mesh_name, nodes_per_mesh)
-        job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
-        server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        self.job_handles[mesh_name] = server_spec.name
+    async def get_or_create_job(
+        self, mesh_name: str, nodes_per_mesh: int = 1, gpus_per_node: int = 8
+    ) -> None:
+        job = SlurmJob(
+            meshes={mesh_name: nodes_per_mesh},
+            gpus_per_node=gpus_per_node,
+            job_name=f"{self.job_name_prefix}-{mesh_name}",
+        )
+        job.apply()
+        self.job_handles[mesh_name] = job

     def kill_jobs(self):
         for mesh_name in self.job_handles.keys():
             self.kill_job(mesh_name)

     def kill_job(self, mesh_name: str):
         try:
-            job_handle = self.job_handles[mesh_name]
+            job = self.job_handles[mesh_name]
             logger.info(f"Destroying job for mesh {mesh_name}")
-            commands.kill(f"slurm:///{job_handle}")
+            job.kill()
         except Exception as e:
-            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
-
-    def proc_mesh(
-        self,
-        mesh_name: str,
-        num_hosts: int = 1,
-        num_gpus: int = 8,
-    ) -> ProcMesh:
-        allocator = RemoteAllocator(
-            world_id=MonarchSlurm.job_name_prefix,
-            initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{self.job_handles[mesh_name]}"
-            ),
-        )
-        alloc = allocator.allocate(
-            AllocSpec(AllocConstraints(), hosts=num_hosts, gpus=num_gpus)
-        )
+            logger.exception(f"Failed to destroy job for {mesh_name}: {e}")

-        return ProcMesh.from_alloc(alloc)
+    def proc_mesh(self, mesh_name: str, num_procs: int) -> ProcMesh:
+        job = self.job_handles[mesh_name]
+        mesh: HostMesh = getattr(job.state(cached_path=None), mesh_name)
+        proc_mesh = mesh.spawn_procs({"gpus": num_procs})
+        return proc_mesh


 # ==== allocation boilerplate ====
@@ -177,13 +150,12 @@ async def start_replica(self) -> None:
         init_logger()
         logger.info(f"{self.uid} Spawning trainers")

-        trainers_proc_mesh: ProcMesh | None = None
-        try:
-            trainers_proc_mesh = self.scheduler.proc_mesh(
-                f"replica_{self.replica_id}",
-                self.spec.hosts_per_replica,
-                self.spec.gpus_per_node,
-            )
+        trainers_proc_mesh = self.scheduler.proc_mesh(
+            f"replica_{self.replica_id}",
+            num_procs=self.spec.gpus_per_node,
+        )
+
+        async with trainers_proc_mesh:
             await trainers_proc_mesh.logging_option(stream_to_client=True)
             await setup_env_for_distributed(trainers_proc_mesh)
@@ -200,11 +172,6 @@ async def start_replica(self) -> None:

             logger.info(f"{self.uid} Starting trainers")
             await training_actors.start_training.call(self.spec.lighthouse_address)
-            await trainers_proc_mesh.stop()
-        except Exception as e:
-            if trainers_proc_mesh:
-                await trainers_proc_mesh.stop()
-            raise e

     @endpoint
     async def inject_failure(self, failure_type: Failure):
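
The deletions in these two hunks are really one change: the hand-rolled `try`/`stop()`/`raise` cleanup is subsumed by entering the ProcMesh as an async context manager. The diff relies only on standard `async with` semantics, sketched here with a stand-in class (not Monarch's actual implementation):

```python
class MeshLike:
    """Stand-in showing the contract `async with trainers_proc_mesh:` relies on."""

    async def stop(self) -> None:
        ...  # tear down the processes backing the mesh

    async def __aenter__(self) -> "MeshLike":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        await self.stop()  # runs on normal exit and when an exception unwinds
        return False       # don't swallow the exception; it still propagates
```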
@@ -216,8 +183,7 @@ async def inject_failure(self, failure_type: Failure):

                 await self.failure_actors.fail.choose(failure_type)
             except Exception as e:
-                error_msg = f"{self.uid} Injected failure: {e}"
-                logger.error(error_msg)
+                logger.exception(f"{self.uid} Injected failure: {e}")
         else:
             error_msg = f"{self.uid} No failure actors available"
             logger.error(error_msg)
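
A pattern repeated throughout this commit: `except` blocks switch from `logger.error`/`logger.warning`/`logger.info` to `logger.exception`, which logs at ERROR level and appends the active traceback. A self-contained standard-library example:

```python
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)

try:
    1 / 0
except Exception as e:
    # logger.error(f"failed: {e}") would record only the message;
    # logger.exception additionally attaches the full traceback.
    logger.exception(f"failed: {e}")
```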
@@ -268,7 +234,7 @@ async def start_training(self) -> None:
     async def start_lighthouse(self) -> None:
         if self.spec.remote_lighthouse:
             await self.scheduler.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_procs=1)
         else:
             self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

@@ -287,7 +253,7 @@ async def stop_lighthouse(self) -> None:
             await self.lighthouse_mesh.stop()
             logger.info("[Controller] Lighthouse stopped")
         except Exception as e:
-            logger.warning(f"[Controller] Failed to stop lighthouse: {e}")
+            logger.exception(f"[Controller] Failed to stop lighthouse: {e}")

     async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
         if attempt_number >= MAX_ATTEMPT:
@@ -300,7 +266,7 @@ async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
             await self._teardown(replica_id)
         except Exception as e:
             await self._teardown(replica_id)
-            logger.info(f"[Controller] replica {replica_id} failed: {e}")
+            logger.exception(f"[Controller] replica {replica_id} failed: {e}")
             await self._run_replica(replica_id, attempt_number + 1)

     async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
@@ -332,11 +298,18 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
     async def _teardown(self, replica_id: int) -> None:
         try:
             replica = self.replicas[replica_id]
-            await replica.proc_mesh.stop()
+            try:
+                await replica.proc_mesh.stop()
+            except Exception as e:
+                logger.exception(
+                    f"[Controller] Failed to stop replica {replica_id}, it may already be stopped. {e}"
+                )
             del self.replicas[replica_id]
             del replica.proc_mesh
         except Exception as e:
-            logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")
+            logger.exception(
+                f"[Controller] Failed to teardown replica {replica_id}: {e}"
+            )


 # === CLI / CONFIG === #
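
The `_teardown` rework above distinguishes a mesh that will not stop (tolerated, since it may already be down) from failed bookkeeping (still an error), so the replica entry is dropped even when `stop()` raises. A stripped-down sketch of the pattern, with hypothetical names:

```python
import logging

logger = logging.getLogger(__name__)


async def teardown(replicas: dict, replica_id: int) -> None:
    replica = replicas[replica_id]
    try:
        await replica.proc_mesh.stop()  # may raise if the mesh already stopped
    except Exception:
        logger.exception(f"stop() failed for replica {replica_id}; may already be stopped")
    # bookkeeping proceeds regardless, so a retry never sees a stale entry
    del replicas[replica_id]
```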

examples/monarch/utils/failure.py

Lines changed: 2 additions & 2 deletions
@@ -126,11 +126,11 @@ async def execute_failures(
                     f"[FailureController] Failure injection ({last_failure}) sent to replica {last_replica.rid}"
                 )
             except Exception as e:
-                logger.info(
+                logger.exception(
                     f"[FailureController] Failed to inject failure in replica {last_replica.rid}: {e}"
                 )
             await asyncio.sleep(rest_time)
         except Exception as e:
-            logger.info(
+            logger.exception(
                 f"[FailureController] Something went wrong while injecting failure: {e}"
             )
