diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 54a38205b80a4..750001f527954 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a callback for spike-detection ([#18014](https://github.com/Lightning-AI/lightning/pull/18014)) +- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` ([#18087](https://github.com/Lightning-AI/lightning/pull/18087)) + + ### Changed - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331)) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 7f54eaa078c14..85291b95d0b9b 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -71,7 +71,12 @@ from lightning.fabric.utilities.types import _PATH if TYPE_CHECKING: - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + CPUOffload, + FullyShardedDataParallel, + MixedPrecision, + ShardingStrategy, + ) from lightning.fabric.wrappers import _FabricModule @@ -82,6 +87,8 @@ else: _POLICY = Union[Set, Callable[[Module, bool, int], bool]] # type: ignore[misc] + _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] + _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload") _METADATA_FILENAME = "meta.pt" @@ -115,6 +122,17 @@ class FSDPStrategy(ParallelStrategy, _Sharded): want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. For convenience, this also accepts a set of the layer classes to wrap. + sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of + them. Available values are: + + - ``"FULL_SHARD"``: Shards model parameters, gradients, and optimizer states (default). + - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated. + - ``"NO_SHARD"``: No sharding (identical to regular DDP). + - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but + replicates across machines. + + Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value. + state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint. - ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file. 
@@ -138,6 +156,7 @@ def __init__( auto_wrap_policy: Optional["_POLICY"] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, + sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", state_dict_type: Literal["full", "sharded"] = "sharded", **kwargs: Any, ) -> None: @@ -165,6 +184,7 @@ def __init__( activation_checkpointing, activation_checkpointing_policy ) self._state_dict_type = state_dict_type + self.sharding_strategy = _init_sharding_strategy(sharding_strategy) self.cpu_offload = _init_cpu_offload(cpu_offload) self.mixed_precision = mixed_precision @@ -249,6 +269,7 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel": module=module, cpu_offload=self.cpu_offload, mixed_precision=self.mixed_precision_config, + sharding_strategy=self.sharding_strategy, device_id=self.root_device.index, **self._fsdp_kwargs, ) @@ -306,6 +327,7 @@ def module_sharded_context(self) -> Generator: wrapper_cls=FullyShardedDataParallel, cpu_offload=self.cpu_offload, mixed_precision=self.mixed_precision_config, + sharding_strategy=self.sharding_strategy, device_id=self.root_device.index, **self._fsdp_kwargs, ): @@ -707,6 +729,12 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUO return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload)) +def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY") -> "ShardingStrategy": + from torch.distributed.fsdp import ShardingStrategy + + return ShardingStrategy[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy + + def _optimizer_has_flat_params(optimizer: Optimizer) -> bool: _FSDP_FLATTENED = "_fsdp_flattened" if _TORCH_GREATER_EQUAL_1_13: diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 0f0b220a981c7..86a5dc71989de 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a callback for spike-detection ([#18014](https://github.com/Lightning-AI/lightning/pull/18014)) +- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` ([#18087](https://github.com/Lightning-AI/lightning/pull/18087)) + + ### Changed - Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309)) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index fd89907928cd2..cb2df2908fad8 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -14,7 +14,7 @@ import logging from contextlib import contextmanager, nullcontext from datetime import timedelta -from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Set, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, TYPE_CHECKING, Union import torch from torch import Tensor @@ -30,6 +30,7 @@ _auto_wrap_policy_kwargs, _get_full_state_dict_context, _init_cpu_offload, + _init_sharding_strategy, _optimizer_has_flat_params, _setup_activation_checkpointing, ) @@ -60,7 +61,12 @@ from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only if TYPE_CHECKING: - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + CPUOffload, + FullyShardedDataParallel, + MixedPrecision, + ShardingStrategy, + ) if _TORCH_GREATER_EQUAL_2_0: from torch.distributed.fsdp.wrap import _FSDPPolicy @@ -69,6 +75,9 @@ else: _POLICY = Union[Set, Callable[[Module, bool, int], bool]] # type: ignore[misc] + _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]] + + log = logging.getLogger(__name__) @@ -101,6 +110,17 @@ class FSDPStrategy(ParallelStrategy): want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. For convenience, this also accepts a set of the layer classes to wrap. + sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of + them. Available values are: + + - ``"FULL_SHARD"``: Shards model parameters, gradients, and optimizer states (default). + - ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated. + - ``"NO_SHARD"``: No sharding (identical to regular DDP). + - ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but + replicates across machines. + + Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value. + \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`. 
""" @@ -121,6 +141,7 @@ def __init__( auto_wrap_policy: Optional["_POLICY"] = None, activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None, activation_checkpointing_policy: Optional["_POLICY"] = None, + sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD", **kwargs: Any, ) -> None: if not _TORCH_GREATER_EQUAL_1_12: @@ -138,6 +159,7 @@ def __init__( self._process_group_backend = process_group_backend self._timeout: Optional[timedelta] = timeout self.cpu_offload = _init_cpu_offload(cpu_offload) + self.sharding_strategy = _init_sharding_strategy(sharding_strategy) self.mixed_precision = mixed_precision self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs) @@ -250,6 +272,7 @@ def _setup_model(self, model: Module) -> "FullyShardedDataParallel": process_group=self.process_group, cpu_offload=self.cpu_offload, mixed_precision=self.mixed_precision_config, + sharding_strategy=self.sharding_strategy, device_id=self.root_device.index, **self.kwargs, ) @@ -330,6 +353,7 @@ def model_sharded_context(self) -> Generator[None, None, None]: process_group=self.process_group, cpu_offload=self.cpu_offload, mixed_precision=self.mixed_precision_config, + sharding_strategy=self.sharding_strategy, device_id=self.root_device.index, **self.kwargs, ): diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index 5724a183e773e..32765531535bc 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -65,6 +65,26 @@ def test_fsdp_cpu_offload(): assert strategy.cpu_offload == config +@RunIf(min_torch="1.12") +def test_fsdp_sharding_strategy(): + """Test the different ways the sharding strategy can be set.""" + from torch.distributed.fsdp import ShardingStrategy + + # default + strategy = FSDPStrategy() + assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD + + # enum + strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP) + assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP + + # string + strategy = FSDPStrategy(sharding_strategy="NO_SHARD") + assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD + strategy = FSDPStrategy(sharding_strategy="no_shard") + assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD + + @RunIf(min_torch="1.12") @pytest.mark.parametrize("torch_ge_2_0", [False, True]) def test_fsdp_setup_optimizer_validation(torch_ge_2_0): diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 2f82c8fc96aa5..67ff595026b69 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -486,6 +486,26 @@ def test_fsdp_strategy_cpu_offload(): assert strategy.cpu_offload == config +@RunIf(min_torch="1.12") +def test_fsdp_sharding_strategy(): + """Test the different ways the sharding strategy can be set.""" + from torch.distributed.fsdp import ShardingStrategy + + # default + strategy = FSDPStrategy() + assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD + + # enum + strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP) + assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP + + # string + strategy = FSDPStrategy(sharding_strategy="NO_SHARD") + assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD + strategy = FSDPStrategy(sharding_strategy="no_shard") + assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD + + @RunIf(min_torch="1.12") def 
test_fsdp_use_orig_params(): """Test that Lightning enables `use_orig_params` in PyTorch >= 2.0."""
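
Usage sketch (not part of the patch): the snippet below mirrors the new tests and shows how the `sharding_strategy` argument added here can be passed either as a case-insensitive string or as a `torch.distributed.fsdp.ShardingStrategy` enum member. The final `Fabric(...)` line is illustrative only and assumes a machine with at least two CUDA devices; everything else runs on CPU with torch >= 1.12.

```python
from torch.distributed.fsdp import ShardingStrategy

from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

# String alias: resolved to the enum via the new `_init_sharding_strategy` helper
# (case-insensitive, as covered by the tests above).
strategy = FSDPStrategy(sharding_strategy="shard_grad_op")
assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP

# Passing the enum member directly is equivalent.
strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.NO_SHARD)
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD

# The configured strategy is then handed to Fabric (or the Trainer) as usual.
# Illustrative only: requires >= 2 CUDA devices.
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
```

The same argument exists on `lightning.pytorch.strategies.FSDPStrategy`, so `Trainer(strategy=FSDPStrategy(sharding_strategy="SHARD_GRAD_OP"))` behaves identically; `"FULL_SHARD"` remains the default when nothing is passed.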