We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4444634 commit 0fd04ed — Copy full SHA for 0fd04ed
README.md
@@ -42,10 +42,11 @@ import os
42
import torch
43
import torch.nn as nn
44
45
-def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | None:
+def distributed_training(num_steps: int = 10) -> nn.Module | None:
46
rank = int(os.environ['RANK'])
47
local_rank = int(os.environ['LOCAL_RANK'])
48
49
+ model = nn.Linear(10, 10)
50
model.to(local_rank)
51
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
52
optimizer = torch.optim.AdamW(ddp_model.parameters())
@@ -81,7 +82,6 @@ launcher = torchrunx.Launcher(
81
82
83
results = launcher.run(
84
distributed_training,
- model = nn.Linear(10, 10),
85
num_steps = 10
86
)
87
```
0 commit comments