@@ -20,6 +20,32 @@ Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0

Shared filesystem & SSH access if using multiple machines

+ ## Minimal example
+
+ Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
+
+ ```python
+ import os
+
+ import torch
+
+ def train_model(model, dataset):
+     trained_model = train(model, dataset)  # `train` stands in for your own training loop
+
+     if int(os.environ["RANK"]) == 0:
+         torch.save(trained_model, 'model.pt')
+         return 'model.pt'
+
+     return None
+ ```
+
+ ```python
+ import torchrunx as trx
+
+ model_path = trx.launch(
+     func=train_model,
+     func_kwargs={'model': my_model, 'training_dataset': mnist_train},
+     hostnames=["localhost", "other_node"],
+     workers_per_host=2
+ )["localhost"][0]  # return from rank 0 (first worker on "localhost")
+ ```
+
## Why should I use this?

[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
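
The added example above reads only the first worker's return value via `["localhost"][0]`. As a hedged sketch, assuming (as that indexing suggests) that `trx.launch` returns a hostname-to-list-of-worker-returns mapping, and reusing `train_model`, `my_model`, and `mnist_train` from the example, the results from every worker could be inspected like this:

```python
import torchrunx as trx

# Sketch only: assumes `trx.launch` returns a mapping of hostname -> list of
# per-worker return values, as the `["localhost"][0]` indexing above suggests.
results = trx.launch(
    func=train_model,  # function defined in the example above
    func_kwargs={'model': my_model, 'training_dataset': mnist_train},
    hostnames=["localhost", "other_node"],
    workers_per_host=2
)

for host in ["localhost", "other_node"]:
    for local_rank in range(2):  # 2 workers per host
        # Only global rank 0 (the first worker on "localhost") returned 'model.pt';
        # every other worker returned None.
        print(host, local_rank, results[host][local_rank])
```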
@@ -48,31 +74,7 @@ Why not?

- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.

- ## Usage
-
- Here's a simple example where we distribute `distributed_function` to two hosts (with 2 GPUs each):
-
- ```python
- def train_model(model, dataset):
-     trained_model = train(model, dataset)
-
-     if int(os.environ["RANK"]) == 0:
-         torch.save(learned_model, 'model.pt')
-         return 'model.pt'
-
-     return None
- ```
-
- ```python
- import torchrunx as trx
-
- model_path = trx.launch(
-     func=train_model,
-     func_kwargs={'model': my_model, 'training_dataset': mnist_train},
-     hostnames=["localhost", "other_node"],
-     workers_per_host=2
- )["localhost"][0]  # return from rank 0 (first worker on "localhost")
- ```
+ ## More complicated example

We could also launch multiple functions, with different GPUs: