diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index ea4820f61ec7c..fe10088b19944 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -63,11 +63,6 @@
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
-if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
-
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
@@ -75,9 +70,7 @@
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-if _TORCH_GREATER_EQUAL_1_10:
-    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-    import torch.distributed.algorithms.model_averaging.averagers as averagers
+
 
 log = logging.getLogger(__name__)
 
@@ -324,12 +317,11 @@ def _register_ddp_hooks(self) -> None:
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )
 
-            if (
-                _TORCH_GREATER_EQUAL_1_10
-                and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
-                and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-            ):
-                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+            if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+                import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+                if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                    self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
 
     def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         optimizers = self.lightning_module.trainer.optimizers
@@ -337,6 +329,12 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
            )
+        if _TORCH_GREATER_EQUAL_1_10:
+            if not _IS_WINDOWS:
+                from torch.distributed.optim import DistributedOptimizer
+            import torch.distributed.algorithms.model_averaging.averagers as averagers
+            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
         averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
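For context, the diff above moves the optional `torch.distributed` imports out of module scope and into the functions that actually use them, so that importing the DDP plugin never fails on platforms or torch builds where those submodules are missing. Below is a minimal sketch of that deferred-import pattern, not Lightning's actual code: the flag `_TORCH_GREATER_EQUAL_1_10` mirrors the guard used in the diff, and the helper `_build_averager` is a hypothetical name for illustration only.

```python
# Sketch of the deferred-import pattern applied in the diff above.
# _TORCH_GREATER_EQUAL_1_10 and _build_averager are illustrative stand-ins.
import torch
from packaging.version import Version

# Version guard evaluated once at import time; cheap and always safe to run.
_TORCH_GREATER_EQUAL_1_10 = Version(torch.__version__.split("+")[0]) >= Version("1.10.0")


def _build_averager(period: int, warmup_steps: int):
    """Create a post-localSGD model averager, importing torch.distributed lazily."""
    if not _TORCH_GREATER_EQUAL_1_10:
        raise RuntimeError("Post-localSGD model averaging requires torch>=1.10.")
    # Imported here rather than at module level, so merely importing this module
    # never touches torch.distributed.algorithms (unavailable on some platforms).
    import torch.distributed.algorithms.model_averaging.averagers as averagers

    return averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)
```

The trade-off is a tiny per-call import lookup (cached after the first call) in exchange for a module that can be imported everywhere, which is the same choice the diff makes for `post_localSGD`, `averagers`, and the `torch.distributed.optim` optimizers.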