From db0a0ecb601510a0f6edf661a3a0e859a521bda2 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Tue, 28 May 2024 10:47:00 -0700
Subject: [PATCH] [FSDP2] Added test for N-way TP and 1-way FSDP with CPU
 offloading (#127024)

This PR shows that we can use FSDP solely for CPU offloading when composing
with N-way TP. Each FSDP mesh is just 1 rank. This was motivated by an ask on
Slack :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127024
Approved by: https://github.com/weifengpy, https://github.com/wanchaol
---
 .../fsdp/test_fully_shard_training.py   | 67 ++++++++++++++++++-
 torch/testing/_internal/common_fsdp.py  |  8 +--
 2 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index 392596549d7716..a7b97f8f7dd3fd 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -56,6 +56,7 @@
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 c10d_ops = torch.ops.c10d
+funcol = torch.ops.c10d_functional
 
 
 class TestFullyShardForwardInputs(FSDPTestMultiThread):
@@ -927,7 +928,10 @@ def _test_train_parity_2d_mlp(
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
         model.parallelize(
-            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing,
+            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
 
@@ -943,6 +947,62 @@ def _test_train_parity_2d_mlp(
                 _optim.step()
             self.assertEqual(losses[0], losses[1])
 
+    @skip_if_lt_x_gpu(2)
+    @skipIfRocm
+    def test_tp_with_fsdp_offloading(self):
+        global_mesh = init_device_mesh(
+            "cuda", (1, self.world_size), mesh_dim_names=("dp", "tp")
+        )
+        dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
+        torch.manual_seed(42)
+        mlp_dim = 16
+        model = MLPStack(mlp_dim)
+        ref_model = copy.deepcopy(model).cuda()
+        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
+        # Parallelize with N-way TP and 1-way FSDP
+        model.parallelize(
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing=False,
+            reshard_after_forward=True,
+            offload_policy=CPUOffloadPolicy(),
+        )
+        for param in model.parameters():
+            self.assertEqual(param.device.type, "cpu")
+        num_mlps = sum(isinstance(module, MLP) for module in model.modules())
+        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
+
+        # NOTE: We still see the FSDP all-gather/reduce-scatter c10d ops
+        # called, but they will just be no-ops without issuing any kernels.
+        # We prefer to keep the no-op check at the c10d level, not in FSDP.
+        inp = torch.randn((4, mlp_dim), device="cuda")  # same on all ranks
+        for iter_idx in range(10):
+            ref_optim.zero_grad()
+            optim.zero_grad()
+
+            with CommDebugMode() as fwd_comm_mode:
+                loss = model(inp).sum()
+
+            fwd_comm_counts = fwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(fwd_comm_counts), 2)
+            self.assertEqual(fwd_comm_counts[funcol.all_reduce], num_mlps)
+            self.assertEqual(fwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            ref_loss = ref_model(inp).sum()
+            self.assertEqual(loss, ref_loss)
+
+            with CommDebugMode() as bwd_comm_mode:
+                loss.backward()
+            bwd_comm_counts = bwd_comm_mode.get_comm_counts()
+            self.assertEqual(len(bwd_comm_counts), 3)
+            # First MLP's input gradient does not need to be all-reduced
+            self.assertEqual(bwd_comm_counts[funcol.all_reduce], num_mlps - 1)
+            self.assertEqual(bwd_comm_counts[c10d_ops._allgather_base_], num_mlps)
+            self.assertEqual(bwd_comm_counts[c10d_ops._reduce_scatter_base_], num_mlps)
+            ref_loss.backward()
+
+            optim.step()
+            ref_optim.step()
+
     @skip_if_lt_x_gpu(2)
     @with_temp_dir
     def test_train_parity_2d_transformer_checkpoint_resume(self):
@@ -1103,7 +1163,10 @@ def _test_2d_mlp_with_nd_mesh(
         replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
         model.parallelize(
-            tp_mesh, dp_mesh, use_activation_checkpointing, reshard_after_forward
+            tp_mesh,
+            dp_mesh,
+            use_activation_checkpointing,
+            reshard_after_forward=reshard_after_forward,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
 
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 94b6a68f931c68..4e266117c13b72 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -893,7 +893,7 @@ def parallelize(
         tp_mesh: DeviceMesh,
         dp_mesh: DeviceMesh,
         use_activation_checkpointing: bool,
-        reshard_after_forward: bool,
+        **fsdp_kwargs,
     ) -> "MLPStack":
         parallelize_plan = {
             # Pass `use_local_output=False` to keep as DTensor to preserve
@@ -915,10 +915,8 @@ def parallelize(
                 continue
             if use_activation_checkpointing:
                 checkpoint(module)
-            fully_shard(
-                module, mesh=dp_mesh, reshard_after_forward=reshard_after_forward
-            )
-        fully_shard(self, mesh=dp_mesh, reshard_after_forward=reshard_after_forward)
+            fully_shard(module, mesh=dp_mesh, **fsdp_kwargs)
+        fully_shard(self, mesh=dp_mesh, **fsdp_kwargs)
         return self
 
 
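
For readers who want the pattern the new test exercises outside the test harness, below is a minimal standalone sketch: N-way tensor parallelism combined with a 1-rank FSDP mesh that is used only for CPU offloading. It is not part of the patch; the toy MLP module, the dimensions, and the torchrun launch are illustrative assumptions, and the import paths (torch.distributed._composable.fsdp) are the ones this era of PyTorch uses and may differ in later releases.

# Minimal standalone sketch (not from the PR): N-way TP + 1-way FSDP used
# only for CPU offloading. Toy MLP, dims, and launch command are assumptions.
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class MLP(nn.Module):
    # Stand-in for the MLPStack blocks used by the test
    def __init__(self, dim: int):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim)
        self.out_proj = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


def main():
    torch.distributed.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    world_size = torch.distributed.get_world_size()

    # "dp" dim of size 1: FSDP does no real sharding, only CPU offloading
    mesh = init_device_mesh("cuda", (1, world_size), mesh_dim_names=("dp", "tp"))
    dp_mesh, tp_mesh = mesh["dp"], mesh["tp"]

    torch.manual_seed(42)  # same init and input on all ranks
    model = MLP(dim=16)

    # N-way TP over the "tp" mesh dimension
    parallelize_module(
        model,
        tp_mesh,
        {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
    )
    # 1-way FSDP over the "dp" mesh dimension: sharded parameters stay on CPU
    # between uses; the collectives FSDP issues are no-ops on a 1-rank mesh
    fully_shard(model, mesh=dp_mesh, offload_policy=CPUOffloadPolicy())
    assert all(p.device.type == "cpu" for p in model.parameters())

    optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
    inp = torch.randn(4, 16, device="cuda")
    model(inp).sum().backward()
    optim.step()

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()

Launched with, for example, torchrun --nproc_per_node=8 on a hypothetical script file, this gives 8-way TP while every FSDP group contains a single rank, so the only data-parallel cost is the host-device copies from offloading; this mirrors the NOTE in the test that the FSDP all-gather/reduce-scatter ops are still recorded at the c10d level but launch no kernels.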