Skip to content

Commit 0fd04ed

Browse files
authored
Update README.md
1 parent 4444634 commit 0fd04ed

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,11 @@ import os
4242
import torch
4343
import torch.nn as nn
4444

45-
def distributed_training(model: nn.Module, num_steps: int = 10) -> nn.Module | None:
45+
def distributed_training(num_steps: int = 10) -> nn.Module | None:
4646
rank = int(os.environ['RANK'])
4747
local_rank = int(os.environ['LOCAL_RANK'])
4848

49+
model = nn.Linear(10, 10)
4950
model.to(local_rank)
5051
ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
5152
optimizer = torch.optim.AdamW(ddp_model.parameters())
@@ -81,7 +82,6 @@ launcher = torchrunx.Launcher(
8182

8283
results = launcher.run(
8384
distributed_training,
84-
model = nn.Linear(10, 10),
8585
num_steps = 10
8686
)
8787
```

0 commit comments

Comments
 (0)