Skip to content

Commit 4a04388

Browse files
committed
Renamed timeout argument to worker_timeout
1 parent 6de9197 commit 4a04388

File tree

4 files changed: +11 additions, -10 deletions (lines changed)

4 files changed

+11
-10
lines changed

src/torchrunx/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def main(
105105
local_world_size=num_workers,
106106
world_size=worker_world_size,
107107
hostname=launcher_payload.hostnames[agent_rank],
108-
timeout=launcher_payload.timeout,
108+
timeout=launcher_payload.worker_timeout,
109109
).serialize(),
110110
)
111111
for i in range(num_workers)

src/torchrunx/integrations/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def add_torchrunx_argument_group(parser: ArgumentParser) -> None:
4646
)
4747

4848
group.add_argument(
49-
"--timeout",
49+
"--worker-timeout",
5050
type=int,
5151
default=600,
5252
help="Worker process group timeout in seconds. Default: 600.",
@@ -112,7 +112,7 @@ def launcher_from_args(args: Namespace) -> Launcher:
112112
else:
113113
backend = _backend # pyright: ignore [reportAssignmentType]
114114

115-
timeout: int = args.timeout
115+
worker_timeout: int = args.worker_timeout
116116
agent_timeout: int = args.agent_timeout
117117

118118
copy_env_vars: tuple[str, ...] = tuple(args.copy_env_vars)
@@ -131,7 +131,7 @@ def launcher_from_args(args: Namespace) -> Launcher:
131131
workers_per_host=workers_per_host,
132132
ssh_config_file=ssh_config_file,
133133
backend=backend,
134-
timeout=timeout,
134+
worker_timeout=worker_timeout,
135135
agent_timeout=agent_timeout,
136136
copy_env_vars=copy_env_vars,
137137
extra_env_vars=extra_env_vars,

src/torchrunx/launcher.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class Launcher:
6161
"""`Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_
6262
for worker process group. By default, NCCL (GPU backend).
6363
Use GLOO for CPU backend. ``None`` for no process group."""
64-
timeout: int = 600
64+
worker_timeout: int = 600
6565
"""Worker process group timeout (seconds)."""
6666
agent_timeout: int = 30
6767
"""Agent communication timeout (seconds)."""
@@ -119,7 +119,8 @@ def run( # noqa: C901, PLR0912, PLR0915
119119
)
120120
ssh_config_file = self.ssh_config_file
121121
backend = self.backend
122-
timeout = self.timeout
122+
worker_timeout = self.worker_timeout
123+
agent_timeout = self.agent_timeout
123124

124125
env_vars = {
125126
k: v
@@ -161,7 +162,7 @@ def handler_factory() -> list[logging.Handler]:
161162
worker_global_ranks=worker_global_ranks,
162163
worker_world_size=sum(workers_per_host),
163164
backend=backend,
164-
timeout=timeout,
165+
worker_timeout=worker_timeout,
165166
)
166167
agent_payloads = None
167168

@@ -201,7 +202,7 @@ def handler_factory() -> list[logging.Handler]:
201202
env_vars=env_vars,
202203
env_file=env_file,
203204
hostname=hostname,
204-
agent_timeout=self.agent_timeout,
205+
agent_timeout=agent_timeout,
205206
),
206207
hostname=hostname,
207208
ssh_config_file=ssh_config_file,
@@ -217,7 +218,7 @@ def handler_factory() -> list[logging.Handler]:
217218
launcher_port=launcher_port,
218219
world_size=world_size,
219220
rank=0,
220-
agent_timeout=self.agent_timeout,
221+
agent_timeout=agent_timeout,
221222
)
222223

223224
# Sync initial payloads between launcher and agents

src/torchrunx/utils/comm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class LauncherPayload:
121121
worker_global_ranks: list[list[int]]
122122
worker_world_size: int
123123
backend: Literal["nccl", "gloo", "mpi", "ucc"] | None
124-
timeout: int
124+
worker_timeout: int
125125

126126

127127
@dataclass

0 commit comments

Comments (0)