Skip to content

Commit 8ad727d

Browse files
committed
add GROUP_RANK as node rank
1 parent 45fe57c commit 8ad727d

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

src/torchrunx/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
8383
backend=launcher_payload.backend,
8484
rank=worker_global_ranks[i],
8585
local_rank=i,
86+
node_rank=agent_rank,
8687
local_world_size=num_workers,
8788
world_size=worker_world_size,
8889
hostname=launcher_payload.hostnames[agent_rank],

src/torchrunx/worker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class WorkerArgs:
3232
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
3333
rank: int
3434
local_rank: int
35+
node_rank: int
3536
local_world_size: int
3637
world_size: int
3738
hostname: str
@@ -79,6 +80,7 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc
7980

8081
os.environ["RANK"] = str(worker_args.rank)
8182
os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
83+
os.environ["GROUP_RANK"] = str(worker_args.node_rank)
8284
os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
8385
os.environ["WORLD_SIZE"] = str(worker_args.world_size)
8486
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname

0 commit comments

Comments
 (0)