
Commit f913221

Update README.md
1 parent c17be1f commit f913221


README.md

Lines changed: 27 additions & 25 deletions

@@ -20,6 +20,32 @@ Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0

Shared filesystem & SSH access if using multiple machines

+## Minimal example
+
+Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
+
+```python
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
+```
+
+```python
+import torchrunx as trx
+
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
+    workers_per_host=2
+)["localhost"][0]  # return from rank 0 (first worker on "localhost")
+```
+
## Why should I use this?

[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
@@ -48,31 +74,7 @@ Why not?

- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.

-## Usage
-
-Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
-
-```python
-def train_model(model, dataset):
-    trained_model = train(model, dataset)
-
-    if int(os.environ["RANK"]) == 0:
-        torch.save(trained_model, 'model.pt')
-        return 'model.pt'
-
-    return None
-```
-
-```python
-import torchrunx as trx
-
-model_path = trx.launch(
-    func=train_model,
-    func_kwargs={'model': my_model, 'dataset': mnist_train},
-    hostnames=["localhost", "other_node"],
-    workers_per_host=2
-)["localhost"][0]  # return from rank 0 (first worker on "localhost")
-```
+## More complicated example

We could also launch multiple functions, with different GPUs:

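The body of that example falls outside this hunk. As a rough sketch of how two functions might be launched with the `trx.launch` API shown above: `evaluate_model`, `mnist_test`, and the host/worker split below are illustrative assumptions, not the README's actual example.

```python
import torchrunx as trx

# Sketch only: train_model is the function from the minimal example above;
# evaluate_model, my_model, mnist_train, and mnist_test are assumed to be
# user-defined objects, as in the README's placeholders.

# Train across 2 GPUs on each of two hosts (4 workers total).
model_path = trx.launch(
    func=train_model,
    func_kwargs={'model': my_model, 'dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)["localhost"][0]

# Evaluate the saved model on a single GPU of "localhost".
accuracy = trx.launch(
    func=evaluate_model,
    func_kwargs={'model_path': model_path, 'dataset': mnist_test},
    hostnames=["localhost"],
    workers_per_host=1
)["localhost"][0]
```

Because `trx.launch` returns each worker's result (as in the minimal example), the second launch can consume the model path produced by the first.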