Commit

mypy
carmocca committed Apr 27, 2023
1 parent a78d73d commit b3a43b6
Showing 24 changed files with 55 additions and 32 deletions.
6 changes: 4 additions & 2 deletions src/lightning/fabric/accelerators/accelerator.py
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any

 import torch

+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
+

 class Accelerator(ABC):
     """The Accelerator base class.
@@ -54,5 +56,5 @@ def is_available() -> bool:
         """Detect if the hardware is available."""

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         pass
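For illustration, a minimal sketch of what the typed hook looks like from a subclass (the accelerator name and class are hypothetical and the other abstract methods are elided; the register() call mirrors the usage in cpu.py and cuda.py below):

    from lightning.fabric.accelerators.accelerator import Accelerator
    from lightning.fabric.accelerators.registry import _AcceleratorRegistry


    class MyAccelerator(Accelerator):
        ...  # the remaining abstract methods are elided in this sketch

        @classmethod
        def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
            # With a typed registry argument instead of a bare Dict, mypy can check
            # this call against the registry's actual register() signature.
            accelerator_registry.register("my_accelerator", cls, description=cls.__name__)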
5 changes: 3 additions & 2 deletions src/lightning/fabric/accelerators/cpu.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Union
+from typing import List, Union

 import torch

 from lightning.fabric.accelerators.accelerator import Accelerator
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry


 class CPUAccelerator(Accelerator):
@@ -56,7 +57,7 @@ def is_available() -> bool:
         return True

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "cpu",
             cls,
5 changes: 3 additions & 2 deletions src/lightning/fabric/accelerators/cuda.py
@@ -15,12 +15,13 @@
 import warnings
 from contextlib import contextmanager
 from functools import lru_cache
-from typing import cast, Dict, Generator, List, Optional, Union
+from typing import cast, Generator, List, Optional, Union

 import torch
 from lightning_utilities.core.rank_zero import rank_zero_info

 from lightning.fabric.accelerators.accelerator import Accelerator
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0


@@ -63,7 +64,7 @@ def is_available() -> bool:
         return num_cuda_devices() > 0

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "cuda",
             cls,
5 changes: 3 additions & 2 deletions src/lightning/fabric/accelerators/mps.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 import platform
 from functools import lru_cache
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union

 import torch

 from lightning.fabric.accelerators.accelerator import Accelerator
+from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12


@@ -69,7 +70,7 @@ def is_available() -> bool:
         )

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "mps",
             cls,
3 changes: 2 additions & 1 deletion src/lightning/fabric/accelerators/registry.py
@@ -15,7 +15,6 @@
 from inspect import getmembers, isclass
 from typing import Any, Callable, Dict, List, Optional

-from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.utilities.exceptions import MisconfigurationException
 from lightning.fabric.utilities.registry import _is_register_method_overridden

@@ -114,6 +113,8 @@ def __str__(self) -> str:

 def call_register_accelerators(registry: _AcceleratorRegistry, base_module: str) -> None:
     module = importlib.import_module(base_module)
+    from lightning.fabric.accelerators.accelerator import Accelerator
+
     for _, mod in getmembers(module, isclass):
         if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
             mod.register_accelerators(registry)
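The Accelerator import moves from module scope into call_register_accelerators, presumably because accelerator.py now imports _AcceleratorRegistry from this module at import time, and a top-level import in both directions would be circular. A standalone sketch of the deferred-import pattern, with hypothetical module names:

    # registry_mod.py -- stands in for registry.py
    from typing import Dict


    class Registry(Dict[str, type]):
        """Maps names to registered classes."""


    def populate(registry: Registry) -> None:
        # Deferred import: base_mod imports Registry at module scope, so the
        # reverse import has to happen at call time to avoid an import cycle.
        from base_mod import Base

        for cls in Base.__subclasses__():
            registry[cls.__name__] = cls


    # base_mod.py -- stands in for accelerator.py
    from registry_mod import Registry


    class Base:
        @classmethod
        def register(cls, registry: Registry) -> None:
            registry[cls.__name__] = cls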
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/ddp.py
@@ -29,6 +29,7 @@
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.fabric.strategies.parallel import ParallelStrategy
+from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _BackwardSyncControl, TBroadcast
 from lightning.fabric.utilities.distributed import (
     _get_default_process_group_backend_for_device,
@@ -163,7 +164,7 @@ def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]
         return super().get_module_state_dict(module)

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         entries = (
             ("ddp", "popen"),
             ("ddp_spawn", "spawn"),
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
@@ -30,6 +30,7 @@
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
+from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _Sharded
 from lightning.fabric.utilities.distributed import log
 from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
@@ -512,7 +513,7 @@ def clip_gradients_value(
         )

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
         strategy_registry.register("deepspeed_stage_1", cls, description="DeepSpeed with ZeRO Stage 1 enabled", stage=1)
         strategy_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
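The DeepSpeed entries illustrate why register() accepts extra keyword arguments: the same class is registered under several names, each bound to different constructor parameters (stage=1, stage=2, ...). A toy sketch of that mechanism, assuming, as the registry docstring further down in registry.py states, that get() instantiates the registered class with the stored parameters; all names here are hypothetical:

    from typing import Any, Callable


    class MiniRegistry(dict):
        """Toy registry: each name maps to a class plus the kwargs used to build it."""

        def register(self, name: str, cls: Callable, description: str = "", **init_params: Any) -> None:
            self[name] = {"cls": cls, "description": description, "init_params": init_params}

        def get(self, name: str) -> Any:
            entry = self[name]
            # Instantiate with the kwargs captured at registration time.
            return entry["cls"](**entry["init_params"])


    class ToyDeepSpeed:
        def __init__(self, stage: int = 2) -> None:
            self.stage = stage


    registry = MiniRegistry()
    registry.register("toy_stage_1", ToyDeepSpeed, description="ZeRO Stage 1", stage=1)
    registry.register("toy_stage_3", ToyDeepSpeed, description="ZeRO Stage 3", stage=3)

    assert registry.get("toy_stage_1").stage == 1
    assert registry.get("toy_stage_3").stage == 3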
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/dp.py
@@ -21,6 +21,7 @@
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.parallel import ParallelStrategy
+from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import TBroadcast, TReduce
 from lightning.fabric.utilities.apply_func import apply_to_collection
 from lightning.fabric.utilities.distributed import ReduceOp
@@ -89,5 +90,5 @@ def get_module_state_dict(self, module: Module) -> Dict[str, Union[Any, Tensor]]
         return super().get_module_state_dict(module)

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("dp", cls, description=cls.__class__.__name__)
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -29,6 +29,7 @@
 from lightning.fabric.plugins.precision.fsdp import FSDPPrecision
 from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning.fabric.strategies.parallel import ParallelStrategy
+from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.strategies.strategy import _BackwardSyncControl, _Sharded, TBroadcast
 from lightning.fabric.utilities.distributed import (
     _get_default_process_group_backend_for_device,
@@ -455,7 +456,7 @@ def load_checkpoint(
         return metadata

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
             return

5 changes: 3 additions & 2 deletions src/lightning/fabric/strategies/registry.py
@@ -15,7 +15,6 @@
 from inspect import getmembers, isclass
 from typing import Any, Callable, Dict, List, Optional

-from lightning.fabric.strategies.strategy import Strategy
 from lightning.fabric.utilities.registry import _is_register_method_overridden


@@ -82,7 +81,7 @@ def do_register(strategy: Callable) -> Callable:

         return do_register

-    def get(self, name: str, default: Optional[Strategy] = None) -> Strategy:  # type: ignore[override]
+    def get(self, name: str, default: Optional[Any] = None) -> Any:  # type: ignore[override]
         """Calls the registered strategy with the required parameters and returns the strategy object.

         Args:
@@ -113,6 +112,8 @@ def __str__(self) -> str:

 def _call_register_strategies(registry: _StrategyRegistry, base_module: str) -> None:
     module = importlib.import_module(base_module)
+    from lightning.fabric.strategies.strategy import Strategy
+
     for _, mod in getmembers(module, isclass):
         if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
             mod.register_strategies(registry)
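With get() now returning Any instead of Strategy, this module no longer needs Strategy at import time, and callers that want a concrete type narrow the result themselves. A hedged usage sketch (whether STRATEGY_REGISTRY is the public registry instance exported from lightning.fabric.strategies is an assumption here):

    from lightning.fabric.strategies import STRATEGY_REGISTRY  # assumed public registry instance
    from lightning.fabric.strategies.ddp import DDPStrategy

    strategy = STRATEGY_REGISTRY.get("ddp")  # returns Any: the instantiated registered strategy
    assert isinstance(strategy, DDPStrategy)  # narrow back to a concrete type under strict typing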
3 changes: 2 additions & 1 deletion src/lightning/fabric/strategies/strategy.py
@@ -27,6 +27,7 @@
 from lightning.fabric.plugins.io.torch_io import TorchCheckpointIO
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.launchers.launcher import _Launcher
+from lightning.fabric.strategies.registry import _StrategyRegistry
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
 from lightning.fabric.utilities.types import _PATH, _Stateful, Optimizable, ReduceOp
@@ -339,7 +340,7 @@ def clip_gradients_value(self, module: torch.nn.Module, optimizer: Optimizer, cl
         return torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         pass

     def _err_msg_joint_setup_required(self) -> str:
4 changes: 2 additions & 2 deletions src/lightning/fabric/strategies/xla.py
@@ -26,7 +26,7 @@
 from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning.fabric.plugins.io.xla import XLACheckpointIO
 from lightning.fabric.plugins.precision import Precision
-from lightning.fabric.strategies import ParallelStrategy
+from lightning.fabric.strategies import _StrategyRegistry, ParallelStrategy
 from lightning.fabric.strategies.launchers.xla import _XLALauncher
 from lightning.fabric.strategies.strategy import TBroadcast
 from lightning.fabric.utilities.rank_zero import rank_zero_only
@@ -234,5 +234,5 @@ def remove_checkpoint(self, filepath: _PATH) -> None:
         self.checkpoint_io.remove_checkpoint(filepath)

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("xla", cls, description=cls.__class__.__name__)
2 changes: 1 addition & 1 deletion src/lightning/pytorch/_graveyard/tpu.py
@@ -100,4 +100,4 @@ def _patch_classes() -> None:
 _patch_sys_modules()
 _patch_classes()

-SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)
+SingleTPUStrategy.register_strategies(pl.strategies.StrategyRegistry)  # type: ignore[has-type]
5 changes: 3 additions & 2 deletions src/lightning/pytorch/accelerators/cpu.py
@@ -16,6 +16,7 @@
 import torch
 from lightning_utilities.core.imports import RequirementCache

+from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.cpu import _parse_cpu_cores
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -64,11 +65,11 @@ def is_available() -> bool:
         return True

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "cpu",
             cls,
-            description=f"{cls.__class__.__name__}",
+            description=cls.__class__.__name__,
         )


3 changes: 2 additions & 1 deletion src/lightning/pytorch/accelerators/cuda.py
@@ -20,6 +20,7 @@
 import torch

 import lightning.pytorch as pl
+from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.cuda import _check_cuda_matmul_precision, _clear_cuda_memory, num_cuda_devices
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.types import _DEVICE
@@ -94,7 +95,7 @@ def is_available() -> bool:
         return num_cuda_devices() > 0

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "cuda",
             cls,
3 changes: 2 additions & 1 deletion src/lightning/pytorch/accelerators/ipu.py
@@ -16,6 +16,7 @@
 import torch
 from lightning_utilities.core.imports import package_available

+from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator

@@ -68,7 +69,7 @@ def is_available() -> bool:
         return _IPU_AVAILABLE

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "ipu",
             cls,
3 changes: 2 additions & 1 deletion src/lightning/pytorch/accelerators/mps.py
@@ -15,6 +15,7 @@

 import torch

+from lightning.fabric.accelerators import _AcceleratorRegistry
 from lightning.fabric.accelerators.mps import MPSAccelerator as _MPSAccelerator
 from lightning.fabric.utilities.device_parser import _parse_gpu_ids
 from lightning.fabric.utilities.types import _DEVICE
@@ -70,7 +71,7 @@ def is_available() -> bool:
         return _MPSAccelerator.is_available()

     @classmethod
-    def register_accelerators(cls, accelerator_registry: Dict) -> None:
+    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register(
             "mps",
             cls,
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/ddp.py
@@ -26,6 +26,7 @@
 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
+from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.distributed import (
     _get_default_process_group_backend_for_device,
     _init_dist_connection,
@@ -362,7 +363,7 @@ def post_training_step(self) -> None:
         self.model.require_backward_grad_sync = True  # type: ignore[assignment]

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         entries = (
             ("ddp", "popen"),
             ("ddp_spawn", "spawn"),
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
@@ -29,6 +29,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
+from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
@@ -862,7 +863,7 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
         pass

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
         strategy_registry.register("deepspeed_stage_1", cls, description="DeepSpeed with ZeRO Stage 1 enabled", stage=1)
         strategy_registry.register("deepspeed_stage_2", cls, description="DeepSpeed with ZeRO Stage 2 enabled", stage=2)
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -21,6 +21,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.fsdp import (
     _init_cpu_offload,
     _optimizer_has_flat_params,
@@ -395,7 +396,7 @@ def get_registered_strategies(cls) -> List[str]:
         return cls._registered_strategies

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         if not _fsdp_available:
             return
         strategy_registry.register(
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/ipu.py
@@ -22,6 +22,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
+from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.pytorch.accelerators.ipu import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
 from lightning.pytorch.overrides.base import _LightningModuleWrapperBase
@@ -371,7 +372,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj

     @classmethod
-    def register_strategies(cls, strategy_registry: Dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register(
             cls.strategy_name,
             cls,
5 changes: 3 additions & 2 deletions src/lightning/pytorch/strategies/single_device.py
@@ -20,6 +20,7 @@

 import lightning.pytorch as pl
 from lightning.fabric.plugins import CheckpointIO
+from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.plugins.precision import PrecisionPlugin
 from lightning.pytorch.strategies.strategy import Strategy, TBroadcast
@@ -86,9 +87,9 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         return obj

     @classmethod
-    def register_strategies(cls, strategy_registry: dict) -> None:
+    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
         strategy_registry.register(
             cls.strategy_name,
             cls,
-            description=f"{cls.__class__.__name__}",
+            description=cls.__class__.__name__,
         )