@@ -19,6 +19,7 @@
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch._dynamo.testing import reduce_to_scalar_loss
+from torch._dynamo.utils import counters
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -53,6 +54,7 @@
     TrainPipelinePT2,
     TrainPipelineSemiSync,
     TrainPipelineSparseDist,
+    TrainPipelineSparseDistCompAutograd,
 )
 from torchrec.distributed.train_pipeline.utils import (
     DataLoadingThread,
@@ -393,7 +395,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]: |
             sharded_sparse_arch_pipeline.parameters(), lr=0.1
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             sharded_sparse_arch_pipeline,
             optimizer_pipeline,
             self.device,
@@ -441,7 +443,7 @@ def _setup_pipeline( |
             dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
             lambda params: optim.SGD(params, lr=0.1),
         )
-        return TrainPipelineSparseDist(
+        return self.pipeline_class(
             model=distributed_model,
             optimizer=optimizer_distributed,
             device=self.device,
@@ -508,7 +510,7 @@ def test_equal_to_non_pipelined( |
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -621,7 +623,7 @@ def test_model_detach_during_train(self) -> None: |
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -719,7 +721,7 @@ def test_model_detach_after_train(self) -> None: |
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -862,7 +864,7 @@ def _check_output_equal( |
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1116,7 +1118,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None: |
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1171,7 +1173,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive( |
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1217,7 +1219,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None: |
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1280,7 +1282,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None: |
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -2100,3 +2102,24 @@ def gpu_preproc(x: StageOut) -> StageOut: |
         self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
         for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
             torch.testing.assert_close(out, ref_out)
+
+
+class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
+    def setUp(self) -> None:
+        super().setUp()
+        self.pipeline_class = TrainPipelineSparseDistCompAutograd
+        torch._dynamo.reset()
+        counters["compiled_autograd"].clear()
+        # Compiled autograd does not work with anomaly mode.
+        torch.autograd.set_detect_anomaly(False)
+
+    def tearDown(self) -> None:
+        # Every test performs exactly two captures: one for forward, one for backward.
+        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
+        return super().tearDown()
+
+    @unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
+    # pyre-ignore[56]
+    @given(execute_all_batches=st.booleans())
+    def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
+        super().test_pipelining_fsdp_pre_trace()
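For reviewers unfamiliar with the pattern this change relies on, here is a minimal, self-contained sketch: the base suite constructs every pipeline through `self.pipeline_class`, so the new `TrainPipelineSparseDistCompAutogradTest` only reassigns that attribute in `setUp` and inherits every test. The `_Stub*`, `BasePipelineTest`, and `CompAutogradPipelineTest` names below are hypothetical stand-ins, not part of torchrec.

```python
import unittest


class _StubPipeline:
    """Illustrative stand-in for TrainPipelineSparseDist (not the real class)."""

    def __init__(self, model: object, optimizer: object, device: str) -> None:
        self.model = model
        self.optimizer = optimizer
        self.device = device


class _StubCompAutogradPipeline(_StubPipeline):
    """Illustrative stand-in for TrainPipelineSparseDistCompAutograd."""


class BasePipelineTest(unittest.TestCase):
    def setUp(self) -> None:
        # The base suite selects its pipeline implementation via an attribute...
        self.pipeline_class = _StubPipeline

    def test_builds_pipeline(self) -> None:
        pipeline = self.pipeline_class(model=object(), optimizer=object(), device="cpu")
        self.assertIsInstance(pipeline, _StubPipeline)


class CompAutogradPipelineTest(BasePipelineTest):
    def setUp(self) -> None:
        super().setUp()
        # ...so a subclass swaps only the attribute and reuses every inherited test.
        self.pipeline_class = _StubCompAutogradPipeline


if __name__ == "__main__":
    unittest.main()
```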