Skip to content

Commit 22b7db6

Browse files
committed
updated how it works
1 parent c295819 commit 22b7db6

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ torch.save(trained_model.state_dict(), "output/model.pth")
112112
113113
5. **Bonus features** 🎁
114114

115-
> - Fine-grained, custom handling of logging, environment variables, and exception propagation. We have nice defaults too: no more interleaved logs and irrelevant exceptions!
116-
> - No need to manually set up a [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group)
115+
> - Typing for function arguments and return values.
116+
> - Custom, fine-grained handling of logging, environment variables, and exception propagation. We have nice defaults too: no more interleaved logs and irrelevant exceptions!
117+
> - No need to manually set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group)
117118
> - Automatic detection of SLURM environments.
118119
> - Start multi-node training from Python notebooks!
119120

docs/source/how_it_works.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
# How It Works
22

3-
If you want to (e.g.) train your model on several machines with **N** GPUs each, you should run your training function in **N** parallel processes on each machine. During training, each of these processes runs the same training code (i.e. your function) and communicate with each other (e.g. to synchronize gradients) using a [distributed process group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group).
3+
Suppose you want to run a script (`train.py`) on `N` machines (or "nodes") with `M` GPUs each.
44

5-
Your script can call our library (via `mod:torchrunx.launch`) and specify a function to distribute. The main process running your script is henceforth known as the **launcher** process.
5+
You'll need to start a new process for each GPU. Each process will execute your script in parallel and select its GPU based on the process rank. Your script will also form a [distributed group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) so the processes may communicate with each other (e.g. passing tensors).
66

7-
Our launcher process spawns an **agent** process (via SSH) on each machine. Each agent then spawns **N** processes (known as **workers**) on its machine. All workers form a process group (with the specified `mod:torchrunx.launch` `backend`) and run your function in parallel.
7+
Normally, you'd do this by running the `torchrun --node-rank {i} ... train.py ...` command on every machine. In short, you'll end up with a topology like:
88

9-
**Agent–Worker Communication.** Our agents poll their workers every second and time-out if unresponsive for 5 seconds. Upon polling, our agents receive `None` (if the worker is still running) or a [RunProcsResult](https://pytorch.org/docs/stable/elastic/multiprocessing.html#torch.distributed.elastic.multiprocessing.api.RunProcsResult), indicating that the workers have either completed (providing an object returned from or the exception raised by our function) or failed (e.g. due to segmentation fault or OS signal).
9+
>
1010
11-
**Launcher–Agent Communication.** The launcher and agents form a distributed group (with the CPU-based [GLOO backend](https://pytorch.org/docs/stable/distributed.html#backends)) for the communication purposes of our library. Our agents synchronize their own "statuses" with each other and the launcher. An agent's status can include whether it is running/failed/completed and the result of the function. If the launcher or any agent fails to synchronize, all raise a `mod:torchrunx.AgentFailedError` and terminate. If any worker fails or raises an exception, the launcher raises a `mod:torchrunx.WorkerFailedError` or that exception and terminates along with all the agents. If all agents succeed, the launcher returns the objects returned by each worker.
11+
As a side effect of this structure, every process will run until (1) script completion or (2) another process stops communicating (e.g. if killed by the system for abnormal reasons). The status of other processes is not actively communicated: so if some process is indeed killed, it would take 10 minutes (by default) for the remaining processes to time-out. Also, since this approach parallelizes the entire script, we can't catch and handle these system-level issues as exceptions.
12+
13+
`torchrunx` offers a functional interface, with a launcher–worker topology, instead.
14+
15+
>
16+
17+
{func}`torchrunx.Launcher.run` runs in the current, *launcher* process. It uses SSH to start an *agent* process on every node (specified in `hostnames`), which in turn spawn `M` *worker* processes. The workers form a distributed process group and each executes `func(*args, **kwargs)` in parallel. Once all workers are finished, all of their returned values are propagated to the initial launcher process. Our agents constantly communicate (over their own GLOO-backend distributed group), so any agent or worker failures are immediately propagated, and all launched processes are terminated. Worker exceptions and system failures are propagated to and raised by {func}`torchrunx.Launcher.run`.

0 commit comments

Comments
 (0)