[FSDP2] Added test for N-way TP and 1-way FSDP with CPU offloading (pytorch#127024)

This PR shows that we can use FSDP solely for CPU offloading when composing with N-way TP: each FSDP mesh contains just 1 rank, so FSDP performs no real sharding and serves only as the CPU offloading mechanism.

This was motivated by an ask on Slack :)
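
For context, here is a minimal sketch (not part of this PR) of the pattern the new test exercises: N-way TP across all ranks composed with a size-1 FSDP mesh used purely for CPU offloading. The nn.Sequential model, the TP plan, and the torchrun launch below are illustrative assumptions standing in for the test's MLPStack.parallelize helper:

import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Assumes a `torchrun --nproc-per-node=N` launch where each rank owns one GPU
# and the linear dim (16) is divisible by the TP world size.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
torch.distributed.init_process_group("nccl")
world_size = torch.distributed.get_world_size()

# The "dp" dim has size 1, so each FSDP mesh holds a single rank: FSDP shards
# nothing and is used only to offload parameters (and optimizer state) to CPU.
mesh = init_device_mesh("cuda", (1, world_size), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
parallelize_module(
    model, mesh["tp"], {"0": ColwiseParallel(), "2": RowwiseParallel()}
)
fully_shard(model, mesh=mesh["dp"], offload_policy=CPUOffloadPolicy())

# Parameters now live on CPU between uses and are copied to GPU only for compute.
assert all(p.device.type == "cpu" for p in model.parameters())

As the new test asserts, the FSDP all-gather/reduce-scatter c10d ops are still issued on the size-1 group but end up as no-ops without launching any kernels, so only the TP collectives actually move data between ranks.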

Pull Request resolved: pytorch#127024
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
awgu authored and pytorchmergebot committed May 28, 2024
1 parent 6b24155 commit db0a0ec
Showing 2 changed files with 68 additions and 7 deletions.
67 changes: 65 additions & 2 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -56,6 +56,7 @@
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

 c10d_ops = torch.ops.c10d
+funcol = torch.ops.c10d_functional


 class TestFullyShardForwardInputs(FSDPTestMultiThread):
@@ -927,7 +928,10 @@ def _test_train_parity_2d_mlp(
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
-            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing,
+            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)

@@ -943,6 +947,62 @@
                 _optim.step()
             self.assertEqual(losses[0], losses[1])

+    @skip_if_lt_x_gpu(2)
+    @skipIfRocm
+    def test_tp_with_fsdp_offloading(self):
+        global_mesh = init_device_mesh(
+            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+        )
+        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        torch.manual_seed(42)
+        mlp_dim = 16
+        model = MLPStack(mlp_dim)
+        ref_model = copy.deepcopy(model).cuda()
+        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
+        # Parallelize with N-way TP and 1-way FSDP
+        model.parallelize(
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing=False,
+            reshard_after_forward=True,
+            offload_policy=CPUOffloadPolicy(),
+        )
+        for param in model.parameters():
+            self.assertEqual(param.device.type, "cpu")
+        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
+        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
+
+        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
+        # called, but they will just be no-ops without issuing any kernels.
+        # We prefer to keep the no-op check at the c10d level, not in FSDP.
+        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        for iter_idx in range(10):
+            ref_optim.zero_grad()
+            optim.zero_grad()
+
+            with CommDebugMode() as fwd_comm_mode:
+                loss = model(inp).sum()
+
+            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(fwd_comm_counts), 2)
+            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
+            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            ref_loss = ref_model(inp).sum()
+            self.assertEqual(loss, ref_loss)
+
+            with CommDebugMode() as bwd_comm_mode:
+                loss.backward()
+            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(bwd_comm_counts), 3)
+            # First MLP's input gradient does not need to be all-reduced
+            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
+            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
+            ref_loss.backward()
+
+            optim.step()
+            ref_optim.step()
+
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
@@ -1103,7 +1163,10 @@ def _test_2d_mlp_with_nd_mesh(
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
         model.parallelize(
-            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing,
+            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)

8 changes: 3 additions & 5 deletions torch/testing/_internal/common_fsdp.py
@@ -893,7 +893,7 @@ def parallelize(
         tp_mesh: DeviceMesh,
         dp_mesh: DeviceMesh,
         use_activation_checkpointing: bool,
-        reshard_after_forward: bool,
+        **fsdp_kwargs,
     ) -> "MLPStack":
         parallelize_plan = {
             # Pass `use_local_output=False` to keep as DTensor to preserve
@@ -915,10 +915,8 @@
                 continue
             if use_activation_checkpointing:
                 checkpoint(module)
-            fully_shard(
-                module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
-            )
-        fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
+            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
+        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
         return self


