
Commit cfd967a

Merge pull request #93 from apoorvkh/update-docs
Updated example scripts and README
2 parents (6736b19 + e32695e) · commit cfd967a

14 files changed: +149 −128 lines

README.md

Lines changed: 15 additions & 56 deletions
@@ -13,7 +13,7 @@ By [Apoorv Khandelwal](https://apoorvkh.com) and [Peter Curtin](https://github.c
 
 ---
 
-**`torchrunx`** is a *functional* utility for distributing PyTorch code across devices. This is a [more convenient, robust, and featureful](#torchrunx-uniquely-offers) alternative to CLI-based launchers, like `torchrun`, `accelerate launch`, and `deepspeed`.
+**`torchrunx`** is a *functional* utility for distributing PyTorch code across devices. This is a [more convenient, robust, and featureful](https://torchrun.xyz/features.html) alternative to CLI-based launchers, like `torchrun`, `accelerate launch`, and `deepspeed`.
 
 It enables complex workflows within a single script and has useful features even if only using 1 GPU.
 
@@ -29,20 +29,13 @@ Requires: Linux. If using multiple machines: SSH & shared filesystem.
 
 Suppose we have some distributed training function (which needs to run on every GPU):
 
-```python
-def distributed_training(model: nn.Module, num_steps: int) -> nn.Module: ...
-```
-
-<details>
-<summary><b>Implementation of <code>distributed_training</code> (click to expand)</b></summary>
-
 ```python
 from __future__ import annotations
 import os
 import torch
 import torch.nn as nn
 
-def distributed_training(num_steps: int = 10) -> nn.Module | None:
+def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
     rank = int(os.environ['RANK'])
     local_rank = int(os.environ['LOCAL_RANK'])
 
@@ -62,10 +55,13 @@ def distributed_training(num_steps: int = 10) -> nn.Module | None:
         optimizer.step()
 
     if rank == 0:
-        return model.cpu()
-```
+        os.makedirs(output_dir, exist_ok=True)
+        checkpoint_path = os.path.join(output_dir, "model.pt")
+        torch.save(model, checkpoint_path)
+        return checkpoint_path
 
-</details>
+    return None
+```
 
 We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using **`torchrunx`**!
 
@@ -82,18 +78,20 @@ launcher = torchrunx.Launcher(
 
 results = launcher.run(
     distributed_training,
-    num_steps = 10
+    output_dir = "outputs",
+    num_steps = 10,
 )
 ```
 
 Once completed, you can retrieve the results and process them as you wish.
 
 ```python
-trained_model: nn.Module = results.rank(0)
-# or: results.index(hostname="localhost", local_rank=0)
+checkpoint_path: str = results.rank(0)
+# or: results.index(hostname="localhost", local_rank=0)
 
 # and continue your script
-torch.save(trained_model.state_dict(), "outputs/model.pth")
+model = torch.load(checkpoint_path, weights_only=False)
+model.eval()
 ```
 
 **See more examples where we fine-tune LLMs using:**
@@ -102,43 +100,4 @@ torch.save(trained_model.state_dict(), "outputs/model.pth")
 - [PyTorch Lightning](https://torchrun.xyz/examples/lightning.html)
 - [Accelerate](https://torchrun.xyz/examples/accelerate.html)
 
-**Refer to our [API](https://torchrun.xyz/api.html) and [Usage](https://torchrun.xyz/usage/general.html) for many more capabilities!**
-
----
-
-## `torchrunx` uniquely offers
-
-1. **An automatic launcher that "just works" for everyone** 🚀
-
-> `torchrunx` is an SSH-based, pure-Python library that is universally easy to install.<br>
-> No system-specific dependencies and orchestration for *automatic* multi-node distribution.
-
-2. **Conventional CLI commands** 🖥️
-
-> Run familiar commands, like `python my_script.py ...`, and customize arguments as you wish.
->
-> Other launchers override `python` in a cumbersome way: e.g. `torchrun --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr=100.43.331.111 --master_port=1234 my_script.py ...`.
-
-3. **Support for more complex workflows in a single script** 🎛️
-
-> Your workflow may have steps that are complex (e.g. pre-train, fine-tune, test) or may different parallelizations (e.g. training on 8 GPUs, testing on 1 GPU). In these cases, CLI-based launchers require each step to live in its own script. Our library treats these steps in a modular way, so they can cleanly fit together in a single script!
->
->
-> We clean memory leaks as we go, so previous steps won't crash or adversely affect future steps.
-
-4. **Better handling of system failures. No more zombies!** 🧟
-
-> With `torchrun`, your "work" is inherently coupled to your main Python process. If the system kills one of your workers (e.g. due to RAM OOM or segmentation faults), there is no way to fail gracefully in Python. Your processes might hang for 10 minutes (the NCCL timeout) or become perpetual zombies.
->
->
-> `torchrunx` decouples "launcher" and "worker" processes. If the system kills a worker, our launcher immediately raises a `WorkerFailure` exception, which users can handle as they wish. We always clean up all nodes, so no more zombies!
-
-5. **Bonus features** 🎁
-
-> - Return objects from distributed functions.
-> - [Automatic detection of SLURM environments.](https://torchrun.xyz/usage/slurm.html)
-> - Start multi-node training from Python notebooks!
-> - Our library is fully typed!
-> - Custom, fine-grained handling of [logging](https://torchrun.xyz/usage/logging.html), [environment variables](https://torchrun.xyz/usage/general.html#environment-variables), and [exception propagation](https://torchrun.xyz/usage/general.html#exceptions). We have nice defaults too: no more interleaved logs and irrelevant exceptions!
-
-**On our [roadmap](https://github.com/apoorvkh/torchrunx/issues?q=is%3Aopen+is%3Aissue+label%3Aenhancement): higher-order parallelism, support for debuggers, and more!**
+**Refer to our [API](https://torchrun.xyz/api.html), [Features](https://torchrun.xyz/features.html), and [Usage](https://torchrun.xyz/usage/general.html) for many more capabilities!**
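
Assembled from the hunks above, the updated quick-start example reads roughly as follows. This is a sketch, not the verbatim README: the model, optimizer, training loop, and `Launcher` hostnames are not shown in this diff, so the minimal versions below (a linear model wrapped in DDP, SGD, random data, illustrative hosts, CUDA assumed) are placeholders. The checkpoint saving, launch call, and result handling follow the diff.

```python
# Consolidated sketch of the updated example; placeholder parts are noted inline.
from __future__ import annotations
import os
import torch
import torch.nn as nn
import torchrunx

def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Placeholder model/optimizer/loop: torchrunx sets up the worker process
    # group (NCCL by default), so DDP can be constructed directly.
    model = nn.Linear(8, 1).to(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)

    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = ddp_model(torch.randn(4, 8, device=f"cuda:{local_rank}")).sum()
        loss.backward()
        optimizer.step()

    # As in the diff: rank 0 saves the model and returns the checkpoint path.
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "model.pt")
        torch.save(model, checkpoint_path)
        return checkpoint_path

    return None

if __name__ == "__main__":
    launcher = torchrunx.Launcher(
        hostnames=["localhost", "second_machine"],  # illustrative hosts
        workers_per_host=2,
    )

    results = launcher.run(
        distributed_training,
        output_dir="outputs",
        num_steps=10,
    )

    # Rank 0 returned the checkpoint path; other ranks returned None.
    checkpoint_path: str = results.rank(0)
    model = torch.load(checkpoint_path, weights_only=False)
    model.eval()
```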

docs/source/artifacts/accelerate_help.txt

Lines changed: 5 additions & 2 deletions
@@ -22,12 +22,15 @@ usage: accelerate_train.py [-h] [OPTIONS]
 │ ``"/etc/ssh/ssh_config"``. (default: None) │
 │ --launcher.backend {None,nccl,gloo,mpi,ucc} │
 │ `Backend │
-│ <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
+│ <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Ba │
+│ ckend>`_ │
 │ for worker process group. By default, NCCL (GPU backend). │
 │ Use GLOO for CPU backend. ``None`` for no process group. │
 │ (default: nccl) │
-│ --launcher.timeout INT │
+│ --launcher.worker-timeout INT │
 │ Worker process group timeout (seconds). (default: 600) │
+│ --launcher.agent-timeout INT │
+│ Agent communication timeout (seconds). (default: 180) │
 │ --launcher.copy-env-vars [STR [STR ...]] │
 │ Environment variables to copy from the launcher process to workers. │
 │ Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY │

docs/source/artifacts/argparse_cli_help.txt

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,8 @@
 usage: -c [-h] [--hostnames HOSTNAMES [HOSTNAMES ...]]
           [--workers-per-host WORKERS_PER_HOST [WORKERS_PER_HOST ...]]
           [--ssh-config-file SSH_CONFIG_FILE]
-          [--backend {nccl,gloo,mpi,ucc,None}] [--timeout TIMEOUT]
+          [--backend {nccl,gloo,mpi,ucc,None}]
+          [--worker-timeout WORKER_TIMEOUT] [--agent-timeout AGENT_TIMEOUT]
           [--copy-env-vars COPY_ENV_VARS [COPY_ENV_VARS ...]]
           [--extra-env-vars [EXTRA_ENV_VARS ...]] [--env-file ENV_FILE]
 
@@ -21,7 +22,10 @@ torchrunx:
   --backend {nccl,gloo,mpi,ucc,None}
                         For worker process group. Default: 'nccl'. Use 'gloo'
                         for CPU. 'None' to disable.
-  --timeout TIMEOUT     Worker process group timeout in seconds. Default: 600.
+  --worker-timeout WORKER_TIMEOUT
+                        Worker process group timeout in seconds. Default: 600.
+  --agent-timeout AGENT_TIMEOUT
+                        Agent communication timeout in seconds. Default: 180.
   --copy-env-vars COPY_ENV_VARS [COPY_ENV_VARS ...]
                         Environment variables to copy to workers. Supports
                         Unix pattern matching.
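
The help-text changes above reflect splitting the old `--timeout` option into `--worker-timeout` and `--agent-timeout`. A minimal sketch of setting the same options from Python, assuming the `Launcher` keyword arguments mirror these CLI flags (`worker_timeout` / `agent_timeout` are inferred names, not quoted from this diff):

```python
import torchrunx

# Assumes Launcher accepts worker_timeout / agent_timeout keyword arguments
# corresponding to the --worker-timeout / --agent-timeout flags (inferred).
launcher = torchrunx.Launcher(
    hostnames=["localhost"],
    workers_per_host=2,
    worker_timeout=600,  # worker process group timeout (seconds)
    agent_timeout=180,   # agent communication timeout (seconds)
)
```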

docs/source/artifacts/deepspeed_help.txt

Lines changed: 5 additions & 3 deletions
@@ -1,5 +1,5 @@
-[2025-02-23 16:02:38,914] [WARNING] [real_accelerator.py:181:get_accelerator] Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.
-[2025-02-23 16:02:38,930] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cpu (auto detect)
+[2025-06-25 15:33:02,489] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to cuda (auto detect)
+Warning: The cache directory for DeepSpeed Triton autotune, /users/akhand10/.triton/autotune, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path.
 usage: deepspeed_train.py [-h] [OPTIONS]
 
 ╭─ options ──────────────────────────────────────────────────────────────────╮
@@ -42,8 +42,10 @@ usage: deepspeed_train.py [-h] [OPTIONS]
 │ for worker process group. By default, NCCL (GPU backend). │
 │ Use GLOO for CPU backend. ``None`` for no process group. │
 │ (default: nccl) │
-│ --launcher.timeout INT │
+│ --launcher.worker-timeout INT │
 │ Worker process group timeout (seconds). (default: 600) │
+│ --launcher.agent-timeout INT │
+│ Agent communication timeout (seconds). (default: 180) │
 │ --launcher.copy-env-vars [STR [STR ...]] │
 │ Environment variables to copy from the launcher process to workers. │
 │ Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY │

docs/source/artifacts/lightning_help.txt

Lines changed: 5 additions & 2 deletions
@@ -18,12 +18,15 @@ usage: lightning_train.py [-h] [OPTIONS]
 │ ``"/etc/ssh/ssh_config"``. (default: None) │
 │ --launcher.backend {None,nccl,gloo,mpi,ucc} │
 │ `Backend │
-│ <https://pytorch.org/docs/stable/distributed.html#torch.distributed.B… │
+│ <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Ba │
+│ ckend>`_ │
 │ for worker process group. By default, NCCL (GPU backend). │
 │ Use GLOO for CPU backend. ``None`` for no process group. │
 │ (default: nccl) │
-│ --launcher.timeout INT │
+│ --launcher.worker-timeout INT │
 │ Worker process group timeout (seconds). (default: 600) │
+│ --launcher.agent-timeout INT │
+│ Agent communication timeout (seconds). (default: 180) │
 │ --launcher.copy-env-vars [STR [STR ...]] │
 │ Environment variables to copy from the launcher process to workers. │
 │ Supports Unix pattern matching syntax. (default: PATH LD_LIBRARY │
