Only import PostLocalSGD related modules when it's needed (#10359)
* Only import PostLocalSGD related modules when it's needed

* Only import PostLocalSGD related modules when it's needed

* Only import PostLocalSGD related modules when it's needed
four4fish authored and lexierule committed Nov 9, 2021
1 parent c7be274 commit 297d549
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -63,21 +63,14 @@
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
-if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
-
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-if _TORCH_GREATER_EQUAL_1_10:
-    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-    import torch.distributed.algorithms.model_averaging.averagers as averagers
 
 
 log = logging.getLogger(__name__)

@@ -324,19 +317,24 @@ def _register_ddp_hooks(self) -> None:
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )
 
-        if (
-            _TORCH_GREATER_EQUAL_1_10
-            and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
-            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-        ):
-            self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+        if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
 
     def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         optimizers = self.lightning_module.trainer.optimizers
         if self._model_averaging_period is None:
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
             )
+        if _TORCH_GREATER_EQUAL_1_10:
+            if not _IS_WINDOWS:
+                from torch.distributed.optim import DistributedOptimizer
+            import torch.distributed.algorithms.model_averaging.averagers as averagers
+            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
         averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
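The change defers the PostLocalSGD imports from module scope into the methods that use them, so merely importing the DDP plugin no longer touches the torch.distributed.optim and post-localSGD hook modules; they are imported only when the post-localSGD path is actually exercised. A minimal, self-contained sketch of the same deferred-import pattern (illustrative names only, not the actual pytorch_lightning plugin code):

import torch

# Simplified stand-in for Lightning's _TORCH_GREATER_EQUAL_1_10 flag (a plain
# version-tuple check is assumed to be enough for illustration).
_TORCH_GREATER_EQUAL_1_10 = tuple(int(p) for p in torch.__version__.split(".")[:2]) >= (1, 10)


def build_post_local_sgd_averager(period: int, warmup_steps: int):
    """Hypothetical helper: pull in PostLocalSGD machinery only when it is requested."""
    if not _TORCH_GREATER_EQUAL_1_10:
        raise RuntimeError("Post-localSGD requires torch >= 1.10")
    # Deferred import: this line only runs when the feature is used, so simply
    # importing the surrounding module never fails on platforms where the
    # distributed machinery is unavailable.
    import torch.distributed.algorithms.model_averaging.averagers as averagers

    return averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)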
