Skip to content

Commit

Permalink
Fix 2D example to pass in data parallel pg (pytorch#1160)
Browse files Browse the repository at this point in the history
  • Loading branch information
fduwjj authored Jun 5, 2023
1 parent 55c663f commit 8c16e96
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions distributed/tensor_parallelism/two_d_parallel_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def demo_2d(rank, args):
assert (
enable_2d_with_fsdp()
), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0"
model = FSDP(model)
dp_pg = device_mesh.get_dim_groups()[0]
model = FSDP(model, process_group=dp_pg)

# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
Expand All @@ -94,7 +95,7 @@ def demo_2d(rank, args):
dp_rank = (
rank
if args.run_seq_parallel
else dist.get_rank(device_mesh.get_dim_groups()[0])
else dist.get_rank(dp_pg)
)
torch.manual_seed(i + dp_rank)
inp = torch.rand(20, 10).cuda(rank)
Expand Down

0 comments on commit 8c16e96

Please sign in to comment.