From 327cba5202d77c14fc8faf73e3dff3dabd36883d Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 17 Dec 2021 12:36:53 +0000 Subject: [PATCH] Remove partitioning of model in ZeRO 3 (#10655) (cherry picked from commit c66cd12445481357cb4e29d69a85d021d5b876ea) --- .azure-pipelines/gpu-tests.yml | 2 +- CHANGELOG.md | 3 ++ dockers/base-cuda/Dockerfile | 2 +- .../plugins/training_type/deepspeed.py | 19 +------ tests/plugins/test_deepspeed_plugin.py | 52 +++++-------------- 5 files changed, 20 insertions(+), 58 deletions(-) diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml index 8752e8584439a..ca8c54a61479e 100644 --- a/.azure-pipelines/gpu-tests.yml +++ b/.azure-pipelines/gpu-tests.yml @@ -51,7 +51,7 @@ jobs: - bash: | python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" pip install fairscale==0.4.0 - pip install deepspeed==0.5.4 + pip install deepspeed==0.5.7 pip install . --requirement requirements/devel.txt pip list displayName: 'Install dependencies' diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cfd858a126cc..8bf2ee6664270 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue when torch-scripting a `LightningModule` after training with `Trainer(sync_batchnorm=True)` ([#11078](https://github.com/PyTorchLightning/pytorch-lightning/pull/11078)) - Fixed an `AttributeError` occuring when using a `CombinedLoader` (multiple dataloaders) for prediction ([#11111](https://github.com/PyTorchLightning/pytorch-lightning/pull/11111)) +### Changed + +- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655)) ## [1.5.6] - 2021-12-15 diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 99e8d018f2884..d70761cbdd37a 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -112,7 +112,7 @@ RUN \ RUN \ # install DeepSpeed - pip install deepspeed==0.5.4 + pip install deepspeed==0.5.7 RUN \ # Show what we have diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 4b08c8dc8b039..3359e7776d6e5 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -129,7 +129,6 @@ def __init__( contiguous_memory_optimization: bool = False, synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, - partition_module: bool = True, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -259,12 +258,6 @@ def __init__( load_full_weights: True when loading a single checkpoint file containing the model state dict when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards per worker. - - partition_module: When True, partitions the ``LightningModule`` across devices when using ZeRO Stage 3. - This is the default behaviour to ensure that the entire module is appropriately initialized - for DeepSpeed. When False we do not explicitly convert the model, which is fine if NO layers - or ALL layers are defined in ``configure_sharded_model``. This is useful for layers such as - ``torch.nn.RNN`` which do internal logic when moving to device. 
""" if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -317,7 +310,6 @@ def __init__( self.remote_device = remote_device self.load_full_weights = load_full_weights - self.partition_module = partition_module # default FP16 parameters. self.loss_scale = loss_scale @@ -463,13 +455,6 @@ def init_deepspeed(self): precision = self.lightning_module.trainer.accelerator.precision model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) - if self.zero_stage_3 and self.partition_module: - # Ensure the entire model has been moved to the appropriate device - dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 - deepspeed.zero.Init( - module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype - ) - if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) else: @@ -524,7 +509,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: assert self._config_initialized dtype = torch.float16 if self.precision in (16, "mixed") else torch.float32 model_parallel_context = deepspeed.zero.Init( - remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype + remote_device=self.remote_device, pin_memory=True, config_dict_or_path=self.config, dtype=dtype ) else: model_parallel_context = super().model_sharded_context() @@ -554,7 +539,7 @@ def _initialize_deepspeed_inference(self, model): optimizer, lr_scheduler, _ = self._init_optimizers() scheduler = lr_scheduler["scheduler"] inference_config = { - # todo: this is required for DeepSpeed throughput timers, or throughput timers will be incorrect + # todo: this is required for DeepSpeed throughput timers "train_micro_batch_size_per_gpu": 1 } if "fp16" in self.config: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index bb65c61e057fd..d2205e59773d4 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -595,7 +595,9 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config _assert_save_model_is_equal(model, tmpdir, trainer) -def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumulate_grad_batches: int = 2): +@pytest.mark.parametrize(("accumulate_grad_batches", "automatic_optimization"), [(1, False), (2, True)]) +@RunIf(min_gpus=2, deepspeed=True, standalone=True) +def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir, automatic_optimization, accumulate_grad_batches): seed_everything(1) if automatic_optimization: model = ModelParallelClassificationModel() @@ -630,13 +632,6 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu assert results[0]["test_acc"] > 0.7 -@RunIf(min_gpus=2, deepspeed=True, standalone=True) -def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): - """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, and - see convergence.""" - run_checkpoint_test(tmpdir) - - @RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the @@ -718,24 +713,9 @@ def on_train_batch_start( trainer.fit(model, datamodule=dm, ckpt_path=ck.best_model_path) +@pytest.mark.parametrize("offload_optimizer", [False, True]) @RunIf(min_gpus=2, deepspeed=True, standalone=True) -def 
test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): - """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, - where we save the full weights to one file.""" - run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1) - - -@RunIf(min_gpus=2, deepspeed=True, standalone=True) -def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir): - _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer=False) - - -@RunIf(min_gpus=2, deepspeed=True, standalone=True) -def test_deepspeed_multigpu_stage_2_accumulated_grad_batches_offload_optimizer(tmpdir): - _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer=True) - - -def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer): +def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer): """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works.""" seed_everything(42) @@ -781,6 +761,8 @@ def test_deepspeed_multigpu_test(tmpdir): trainer.test(model) +# TODO(Sean): Once partial parameter partitioning is supported this test should be re-enabled +@pytest.mark.skip("Partial parameter partitioning for DeepSpeed is currently broken.") @RunIf(min_gpus=1, deepspeed=True, standalone=True) def test_deepspeed_multigpu_partial_partition_parameters(tmpdir): """Test to ensure that a module that defines a layer inside the ``__init__`` and ``configure_sharded_model`` @@ -824,7 +806,7 @@ def on_train_epoch_start(self) -> None: model = TestModel() trainer = Trainer( default_root_dir=tmpdir, - strategy=DeepSpeedPlugin(stage=3, partition_module=False), + strategy=DeepSpeedPlugin(stage=3), gpus=1, fast_dev_run=True, precision=16, @@ -941,22 +923,14 @@ def test_dataloader(self): @mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True) +@pytest.mark.parametrize("interval", ["step", "epoch"]) +@pytest.mark.parametrize("max_epoch", [2]) +@pytest.mark.parametrize("limit_train_batches", [2]) @RunIf(min_gpus=1, deepspeed=True, standalone=True) -def test_deepspeed_scheduler_step_count(mock_step): +def test_scheduler_step_count(mock_step, max_epoch, limit_train_batches, interval): """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is - set to step.""" - _run_scheduler_test(mock_step, max_epoch=2, limit_train_batches=2, interval="step") - - -@mock.patch("torch.optim.lr_scheduler.StepLR.step", autospec=True) -@RunIf(min_gpus=1, deepspeed=True, standalone=True) -def test_deepspeed_scheduler_step_count_epoch(mock_step): - """Test to ensure that the scheduler is called the correct amount of times during training when scheduler is - set to epoch.""" - _run_scheduler_test(mock_step, max_epoch=2, limit_train_batches=2, interval="epoch") - + set to step or epoch.""" -def _run_scheduler_test(mock_step, max_epoch, limit_train_batches, interval): class TestModel(BoringModel): def configure_optimizers(self): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
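# A minimal, hypothetical sketch (not part of the patch above) of the usage pattern
# this change standardizes on: with the explicit ``partition_module`` flag removed,
# ZeRO Stage 3 users are expected to create their layers inside
# ``configure_sharded_model`` so the parameters are instantiated under the plugin's
# ``deepspeed.zero.Init`` context (see ``model_sharded_context`` in the diff).
# Assumes the 1.5.x ``Trainer``/``DeepSpeedPlugin`` API used by the tests; the module
# and dataloader below are illustrative only and are not taken from this diff.
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class ShardedExampleModel(LightningModule):
    def configure_sharded_model(self) -> None:
        # Created under the sharded-model context, so DeepSpeed partitions the
        # parameters across workers as they are instantiated.
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(
        # ``partition_module`` is no longer an accepted argument after this change.
        strategy=DeepSpeedPlugin(stage=3),
        gpus=1,
        precision=16,
        fast_dev_run=True,
    )
    trainer.fit(ShardedExampleModel())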