
Enable setting the sharding strategy as string in FSDP (#18087)
awaelchli authored Jul 15, 2023
1 parent c60f67e commit 080eaf3
Showing 6 changed files with 101 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -91,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a callback for spike-detection ([#18014](https://github.com/Lightning-AI/lightning/pull/18014))


- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` ([#18087](https://github.com/Lightning-AI/lightning/pull/18087))


### Changed

- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
30 changes: 29 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -71,7 +71,12 @@
from lightning.fabric.utilities.types import _PATH

if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel,
MixedPrecision,
ShardingStrategy,
)

from lightning.fabric.wrappers import _FabricModule

@@ -82,6 +87,8 @@
else:
_POLICY = Union[Set, Callable[[Module, bool, int], bool]] # type: ignore[misc]

_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
_METADATA_FILENAME = "meta.pt"

@@ -115,6 +122,17 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
cost of speed since activations in these layers need to be recomputed during backpropagation. For
convenience, this also accepts a set of the layer classes to wrap.
sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of
them. Available values are:
- ``"FULL_SHARD"``: Shards model parameters, gradients, and optimizer states (default).
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -138,6 +156,7 @@ def __init__(
auto_wrap_policy: Optional["_POLICY"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "sharded",
**kwargs: Any,
) -> None:
@@ -165,6 +184,7 @@ def __init__(
activation_checkpointing, activation_checkpointing_policy
)
self._state_dict_type = state_dict_type
self.sharding_strategy = _init_sharding_strategy(sharding_strategy)
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.mixed_precision = mixed_precision

@@ -249,6 +269,7 @@ def setup_module(self, module: Module) -> "FullyShardedDataParallel":
module=module,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
sharding_strategy=self.sharding_strategy,
device_id=self.root_device.index,
**self._fsdp_kwargs,
)
@@ -306,6 +327,7 @@ def module_sharded_context(self) -> Generator:
wrapper_cls=FullyShardedDataParallel,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
sharding_strategy=self.sharding_strategy,
device_id=self.root_device.index,
**self._fsdp_kwargs,
):
@@ -707,6 +729,12 @@ def _init_cpu_offload(cpu_offload: Optional[Union[bool, "CPUOffload"]]) -> "CPUOffload":
return cpu_offload if isinstance(cpu_offload, CPUOffload) else CPUOffload(offload_params=bool(cpu_offload))


def _init_sharding_strategy(sharding_strategy: "_SHARDING_STRATEGY") -> "ShardingStrategy":
from torch.distributed.fsdp import ShardingStrategy

return ShardingStrategy[sharding_strategy.upper()] if isinstance(sharding_strategy, str) else sharding_strategy


def _optimizer_has_flat_params(optimizer: Optimizer) -> bool:
_FSDP_FLATTENED = "_fsdp_flattened"
if _TORCH_GREATER_EQUAL_1_13:
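For reference, the conversion added in this file is a plain enum-name lookup. Below is a minimal, self-contained sketch of the same pattern (assuming a PyTorch build with FSDP available, i.e. `torch>=1.12`; the helper name `to_sharding_strategy` is illustrative, not part of the diff):

```python
from torch.distributed.fsdp import ShardingStrategy


def to_sharding_strategy(value):
    # Mirrors `_init_sharding_strategy` above: enum values pass through unchanged,
    # strings are looked up by their upper-cased member name.
    if isinstance(value, str):
        return ShardingStrategy[value.upper()]  # raises KeyError for unknown names
    return value


assert to_sharding_strategy("no_shard") is ShardingStrategy.NO_SHARD
assert to_sharding_strategy(ShardingStrategy.FULL_SHARD) is ShardingStrategy.FULL_SHARD
```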
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -82,6 +82,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a callback for spike-detection ([#18014](https://github.com/Lightning-AI/lightning/pull/18014))


- Added the ability to set the `torch.distributed.fsdp.ShardingStrategy` via string in `FSDPStrategy` ([#18087](https://github.com/Lightning-AI/lightning/pull/18087))


### Changed

- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
28 changes: 26 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -14,7 +14,7 @@
import logging
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Set, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, TYPE_CHECKING, Union

import torch
from torch import Tensor
@@ -30,6 +30,7 @@
_auto_wrap_policy_kwargs,
_get_full_state_dict_context,
_init_cpu_offload,
_init_sharding_strategy,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
)
@@ -60,7 +61,12 @@
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only

if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel,
MixedPrecision,
ShardingStrategy,
)

if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.wrap import _FSDPPolicy
@@ -69,6 +75,9 @@
else:
_POLICY = Union[Set, Callable[[Module, bool, int], bool]] # type: ignore[misc]

_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]


log = logging.getLogger(__name__)


@@ -101,6 +110,17 @@ class FSDPStrategy(ParallelStrategy):
want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
cost of speed since activations in these layers need to be recomputed during backpropagation. For
convenience, this also accepts a set of the layer classes to wrap.
sharding_strategy: Select whether to shard model parameters, gradients, optimizer states, or a combination of
them. Available values are:
- ``"FULL_SHARD"``: Shards model parameters, gradients, and optimizer states (default).
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
\**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
"""

@@ -121,6 +141,7 @@ def __init__(
auto_wrap_policy: Optional["_POLICY"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
**kwargs: Any,
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
@@ -138,6 +159,7 @@ def __init__(
self._process_group_backend = process_group_backend
self._timeout: Optional[timedelta] = timeout
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.sharding_strategy = _init_sharding_strategy(sharding_strategy)
self.mixed_precision = mixed_precision
self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)

@@ -250,6 +272,7 @@ def _setup_model(self, model: Module) -> "FullyShardedDataParallel":
process_group=self.process_group,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
sharding_strategy=self.sharding_strategy,
device_id=self.root_device.index,
**self.kwargs,
)
@@ -330,6 +353,7 @@ def model_sharded_context(self) -> Generator[None, None, None]:
process_group=self.process_group,
cpu_offload=self.cpu_offload,
mixed_precision=self.mixed_precision_config,
sharding_strategy=self.sharding_strategy,
device_id=self.root_device.index,
**self.kwargs,
):
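On the Trainer side, the new argument accepts either form. A usage sketch (the 2-GPU configuration below is assumed purely for illustration):

```python
from torch.distributed.fsdp import ShardingStrategy

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# String form; case-insensitive thanks to the `.upper()` call in `_init_sharding_strategy`.
trainer = Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(sharding_strategy="shard_grad_op"))

# Equivalent enum form.
trainer = Trainer(accelerator="gpu", devices=2, strategy=FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP))
```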
20 changes: 20 additions & 0 deletions tests/tests_fabric/strategies/test_fsdp.py
@@ -65,6 +65,26 @@ def test_fsdp_cpu_offload():
assert strategy.cpu_offload == config


@RunIf(min_torch="1.12")
def test_fsdp_sharding_strategy():
"""Test the different ways the sharding strategy can be set."""
from torch.distributed.fsdp import ShardingStrategy

# default
strategy = FSDPStrategy()
assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD

# enum
strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP

# string
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
strategy = FSDPStrategy(sharding_strategy="no_shard")
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD


@RunIf(min_torch="1.12")
@pytest.mark.parametrize("torch_ge_2_0", [False, True])
def test_fsdp_setup_optimizer_validation(torch_ge_2_0):
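One edge case the test above does not cover: since `_init_sharding_strategy` performs a bare enum lookup with no extra validation, an unrecognized name should surface as a `KeyError` at construction time. A hypothetical test sketch, not part of this commit, assuming `torch>=1.12` like the surrounding tests and that the error is not wrapped elsewhere in `__init__`:

```python
import pytest

from lightning.fabric.strategies import FSDPStrategy


def test_fsdp_sharding_strategy_invalid_string():
    # `ShardingStrategy[...]` raises KeyError for names that are not enum members.
    with pytest.raises(KeyError):
        FSDPStrategy(sharding_strategy="INVALID_STRATEGY")
```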
20 changes: 20 additions & 0 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -486,6 +486,26 @@ def test_fsdp_strategy_cpu_offload():
assert strategy.cpu_offload == config


@RunIf(min_torch="1.12")
def test_fsdp_sharding_strategy():
"""Test the different ways the sharding strategy can be set."""
from torch.distributed.fsdp import ShardingStrategy

# default
strategy = FSDPStrategy()
assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD

# enum
strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP

# string
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
strategy = FSDPStrategy(sharding_strategy="no_shard")
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD


@RunIf(min_torch="1.12")
def test_fsdp_use_orig_params():
"""Test that Lightning enables `use_orig_params` in PyTorch >= 2.0."""