Skip to content

Commit 6736b19

Browse files
authored
Merge pull request #92 from apoorvkh/agent-timeout-arg
Added argument for agent timeout
2 parents 0fd04ed + 920951f commit 6736b19

File tree

6 files changed

+32
-10
lines changed

6 files changed

+32
-10
lines changed

src/torchrunx/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
parser.add_argument("--world-size", type=int)
1313
parser.add_argument("--rank", type=int)
1414
parser.add_argument("--hostname", type=str)
15+
parser.add_argument("--agent-timeout", type=int, default=30)
1516
args = parser.parse_args()
1617

1718
main(
@@ -22,4 +23,5 @@
2223
logger_hostname=args.launcher_hostname,
2324
logger_port=args.logger_port,
2425
hostname=args.hostname,
26+
agent_timeout=args.agent_timeout,
2527
)

src/torchrunx/agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def main(
3131
logger_hostname: str,
3232
logger_port: int,
3333
hostname: str,
34+
agent_timeout: int = 30,
3435
) -> None:
3536
"""Main function for agent processes (started on each node).
3637
@@ -46,6 +47,7 @@ def main(
4647
logger_hostname: Hostname of the logging server.
4748
logger_port: Port for the logging server.
4849
hostname: Hostname of this agent.
50+
agent_timeout: Agent communication timeout (seconds).
4951
"""
5052
# Setup logging & stream logs to server
5153

@@ -63,6 +65,7 @@ def main(
6365
launcher_port=launcher_port,
6466
world_size=world_size,
6567
rank=rank,
68+
agent_timeout=agent_timeout,
6669
)
6770

6871
agent_rank = launcher_agent_group.rank - 1
@@ -102,7 +105,7 @@ def main(
102105
local_world_size=num_workers,
103106
world_size=worker_world_size,
104107
hostname=launcher_payload.hostnames[agent_rank],
105-
timeout=launcher_payload.timeout,
108+
timeout=launcher_payload.worker_timeout,
106109
).serialize(),
107110
)
108111
for i in range(num_workers)

src/torchrunx/integrations/cli.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,19 @@ def add_torchrunx_argument_group(parser: ArgumentParser) -> None:
4646
)
4747

4848
group.add_argument(
49-
"--timeout",
49+
"--worker-timeout",
5050
type=int,
5151
default=600,
5252
help="Worker process group timeout in seconds. Default: 600.",
5353
)
5454

55+
group.add_argument(
56+
"--agent-timeout",
57+
type=int,
58+
default=180,
59+
help="Agent communication timeout in seconds. Default: 180.",
60+
)
61+
5562
group.add_argument(
5663
"--copy-env-vars",
5764
type=str,
@@ -105,7 +112,8 @@ def launcher_from_args(args: Namespace) -> Launcher:
105112
else:
106113
backend = _backend # pyright: ignore [reportAssignmentType]
107114

108-
timeout: int = args.timeout
115+
worker_timeout: int = args.worker_timeout
116+
agent_timeout: int = args.agent_timeout
109117

110118
copy_env_vars: tuple[str, ...] = tuple(args.copy_env_vars)
111119

@@ -123,7 +131,8 @@ def launcher_from_args(args: Namespace) -> Launcher:
123131
workers_per_host=workers_per_host,
124132
ssh_config_file=ssh_config_file,
125133
backend=backend,
126-
timeout=timeout,
134+
worker_timeout=worker_timeout,
135+
agent_timeout=agent_timeout,
127136
copy_env_vars=copy_env_vars,
128137
extra_env_vars=extra_env_vars,
129138
env_file=env_file,

src/torchrunx/launcher.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ class Launcher:
6161
"""`Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_
6262
for worker process group. By default, NCCL (GPU backend).
6363
Use GLOO for CPU backend. ``None`` for no process group."""
64-
timeout: int = 600
64+
worker_timeout: int = 600
6565
"""Worker process group timeout (seconds)."""
66+
agent_timeout: int = 180
67+
"""Agent communication timeout (seconds)."""
6668
copy_env_vars: tuple[str, ...] = DEFAULT_ENV_VARS_FOR_COPY
6769
"""Environment variables to copy from the launcher process to workers.
6870
Supports Unix pattern matching syntax."""
@@ -117,7 +119,8 @@ def run( # noqa: C901, PLR0912, PLR0915
117119
)
118120
ssh_config_file = self.ssh_config_file
119121
backend = self.backend
120-
timeout = self.timeout
122+
worker_timeout = self.worker_timeout
123+
agent_timeout = self.agent_timeout
121124

122125
env_vars = {
123126
k: v
@@ -159,7 +162,7 @@ def handler_factory() -> list[logging.Handler]:
159162
worker_global_ranks=worker_global_ranks,
160163
worker_world_size=sum(workers_per_host),
161164
backend=backend,
162-
timeout=timeout,
165+
worker_timeout=worker_timeout,
163166
)
164167
agent_payloads = None
165168

@@ -199,6 +202,7 @@ def handler_factory() -> list[logging.Handler]:
199202
env_vars=env_vars,
200203
env_file=env_file,
201204
hostname=hostname,
205+
agent_timeout=agent_timeout,
202206
),
203207
hostname=hostname,
204208
ssh_config_file=ssh_config_file,
@@ -214,6 +218,7 @@ def handler_factory() -> list[logging.Handler]:
214218
launcher_port=launcher_port,
215219
world_size=world_size,
216220
rank=0,
221+
agent_timeout=agent_timeout,
217222
)
218223

219224
# Sync initial payloads between launcher and agents

src/torchrunx/utils/comm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class LauncherAgentGroup(Generic[FunctionR]):
4646
launcher_port: int
4747
world_size: int
4848
rank: int
49+
agent_timeout: int = 30
4950

5051
def __post_init__(self) -> None:
5152
"""Initialize process group.
@@ -63,7 +64,7 @@ def __post_init__(self) -> None:
6364
world_size=self.world_size,
6465
is_master=(self.rank == 0),
6566
),
66-
timeout=datetime.timedelta(seconds=30),
67+
timeout=datetime.timedelta(seconds=self.agent_timeout),
6768
)
6869

6970
def _all_gather(self, obj: ObjectT) -> list[ObjectT]:
@@ -120,7 +121,7 @@ class LauncherPayload:
120121
worker_global_ranks: list[list[int]]
121122
worker_world_size: int
122123
backend: Literal["nccl", "gloo", "mpi", "ucc"] | None
123-
timeout: int
124+
worker_timeout: int
124125

125126

126127
@dataclass

src/torchrunx/utils/environment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def build_launch_command(
121121
env_vars: dict[str, str],
122122
env_file: str | os.PathLike | None,
123123
hostname: str,
124+
agent_timeout: int,
124125
) -> str:
125126
"""Generator for command to launch torchrunx on an agent."""
126127
# shlex.quote prevents shell injection here (resolves S602 in execute_command)
@@ -147,7 +148,8 @@ def build_launch_command(
147148
f"--logger-port {logger_port} "
148149
f"--world-size {world_size} "
149150
f"--rank {rank} "
150-
f"--hostname {hostname}",
151+
f"--hostname {hostname} "
152+
f"--agent-timeout {agent_timeout}",
151153
)
152154

153155
return " && ".join(commands)

0 commit comments

Comments
 (0)