Support sets for policies in FSDP (#18084)
carmocca authored Jul 15, 2023
1 parent e9c42ed commit c60f67e
Showing 6 changed files with 87 additions and 64 deletions.
20 changes: 2 additions & 18 deletions docs/source-pytorch/advanced/model_parallel.rst
@@ -158,16 +158,7 @@ You can customize the strategy configuration by adjusting the arguments of :class:
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
# configure the wrapping condition
if torch.__version__ >= "2.1":
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
my_policy = ModuleWrapPolicy({MyTransformerBlock})
else:
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
import functools
my_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda module: isinstance(module, torch.nn.Linear))
fsdp = FSDPStrategy(auto_wrap_policy=my_policy)
fsdp = FSDPStrategy(auto_wrap_policy={MyTransformerBlock})
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
@@ -241,14 +232,7 @@ Enable checkpointing on large layers (like Transformers) by providing a policy:
from lightning.pytorch.strategies import FSDPStrategy
if torch.__version__ >= "2.1":
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
my_policy = ModuleWrapPolicy({MyTransformerBlock})
fsdp = FSDPStrategy(activation_checkpointing_policy=my_policy)
else:
fsdp = FSDPStrategy(activation_checkpointing=MyTransformerBlock) # or pass a list with multiple types
fsdp = FSDPStrategy(activation_checkpointing_policy={MyTransformerBlock})
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
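Taken together, the two documentation hunks above replace the old torch-version branching with a single call per policy. A minimal end-to-end sketch of the resulting user-facing API, with a hypothetical `MyTransformerBlock` standing in for a real transformer layer:

```python
import lightning.pytorch as pl
import torch
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy


class MyTransformerBlock(nn.Module):
    """Hypothetical stand-in for a real transformer block (attention + feed-forward)."""

    def __init__(self) -> None:
        super().__init__()
        self.ff = nn.Linear(32, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ff(x)


# Both arguments now accept a plain set of layer classes; no torch version check is needed.
fsdp = FSDPStrategy(
    auto_wrap_policy={MyTransformerBlock},
    activation_checkpointing_policy={MyTransformerBlock},
)
trainer = pl.Trainer(strategy=fsdp, accelerator="gpu", devices=4)
```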
65 changes: 50 additions & 15 deletions src/lightning/fabric/strategies/fsdp.py
@@ -11,13 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import threading
from contextlib import _GeneratorContextManager, contextmanager, nullcontext
from datetime import timedelta
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Tuple, Type, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Literal,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

import torch
from torch import Tensor
@@ -64,9 +78,9 @@
if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.wrap import _FSDPPolicy

_POLICY = Union[Callable[[Module, bool, int], bool], _FSDPPolicy]
_POLICY = Union[Set, Callable[[Module, bool, int], bool], _FSDPPolicy]
else:
_POLICY = Callable[[Module, bool, int], bool] # type: ignore[misc]
_POLICY = Union[Set, Callable[[Module, bool, int], bool]] # type: ignore[misc]

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
_METADATA_FILENAME = "meta.pt"
@@ -92,13 +106,15 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
Arguments:
cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``. A single layer or a list of
layer classes for which you want to enable activation checkpointing. This is typically your transformer
block (including attention + feed-forward).
auto_wrap_policy: Same as ``auto_wrap_policy`` parameter in
:class:`torch.distributed.fsdp.FullyShardedDataParallel`. For convenience, this also accepts a set of the
layer classes to wrap.
activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``.
activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
:class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
cost of speed since activations in these layers need to be recomputed during backpropagation.
cost of speed since activations in these layers need to be recomputed during backpropagation. For
convenience, this also accepts a set of the layer classes to wrap.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -119,6 +135,7 @@ def __init__(
timeout: Optional[timedelta] = default_pg_timeout,
cpu_offload: Union[bool, "CPUOffload", None] = None,
mixed_precision: Optional["MixedPrecision"] = None,
auto_wrap_policy: Optional["_POLICY"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
activation_checkpointing_policy: Optional["_POLICY"] = None,
state_dict_type: Literal["full", "sharded"] = "sharded",
@@ -138,7 +155,7 @@ def __init__(
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
self._backward_sync_control = _FSDPBackwardSyncControl()
self._fsdp_kwargs = kwargs
self._fsdp_kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)

if _TORCH_GREATER_EQUAL_2_0:
# Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
@@ -615,16 +632,35 @@ def _activation_checkpointing_kwargs(
if _TORCH_GREATER_EQUAL_2_1:
rank_zero_deprecation(
f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
"`FSDPStrategy(activation_checkpointing_policy=torch.distributed.fsdp.wrap.ModuleWrapPolicy"
f"({set(classes)}))` instead."
f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
)
return {"check_fn": lambda submodule: isinstance(submodule, classes)}
assert activation_checkpointing_policy is not None
if isinstance(activation_checkpointing_policy, set):
if _TORCH_GREATER_EQUAL_2_1:
return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
return {"check_fn": lambda submodule: isinstance(submodule, tuple(activation_checkpointing_policy))}
if not _TORCH_GREATER_EQUAL_2_1:
raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. HINT: `pip install -U torch`")
return {"auto_wrap_policy": activation_checkpointing_policy}


def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
if policy is None:
return kwargs
if isinstance(policy, set):
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

policy = ModuleWrapPolicy(policy)
else:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# this is not transformer specific despite the name
policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
kwargs["auto_wrap_policy"] = policy
return kwargs


def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None:
if not activation_checkpointing_kwargs:
return
@@ -642,10 +678,9 @@ def _setup_activation_checkpointing(module: Module, activation_checkpointing_kwargs: Dict) -> None:
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
CheckpointWrapper,
)

wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
apply_activation_checkpointing(module, checkpoint_wrapper_fn=wrapper, **activation_checkpointing_kwargs)


@@ -800,7 +835,7 @@ def _opt_hook(h: FlatParamHandle, flat_param: FlatParameter, *_unused: Any) -> None:
h.prepare_gradient_for_optim = _no_op # type: ignore[method-assign]
maybe_step(flat_param._params or (), h._clear_grads_if_needed)

hook = functools.partial(_opt_hook, h, flat_param)
hook = partial(_opt_hook, h, flat_param)
hook_handles.append(fsdp_acc_grad.register_hook(hook))

yield
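The new `_auto_wrap_policy_kwargs` helper above is what makes the set shorthand work. As a rough standalone illustration of the same conversion (a sketch only: the real helper branches on Lightning's `_TORCH_GREATER_EQUAL_2_1` flag rather than a try/except, and `normalize_policy` is a made-up name):

```python
from functools import partial
from typing import Any, Dict, Set, Type

import torch.nn as nn


def normalize_policy(layer_classes: Set[Type[nn.Module]], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Turn a set of layer classes into an FSDP ``auto_wrap_policy`` keyword argument."""
    try:
        # torch >= 2.1: a class-based policy that wraps exactly the listed module types
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        kwargs["auto_wrap_policy"] = ModuleWrapPolicy(layer_classes)
    except ImportError:
        # older torch: transformer_auto_wrap_policy handles arbitrary module types despite its name
        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

        kwargs["auto_wrap_policy"] = partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_classes)
    return kwargs


print(normalize_policy({nn.TransformerEncoderLayer}, {}))
```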
22 changes: 14 additions & 8 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -14,7 +14,7 @@
import logging
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Type, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Set, Type, TYPE_CHECKING, Union

import torch
from torch import Tensor
@@ -27,6 +27,7 @@
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.fsdp import (
_activation_checkpointing_kwargs,
_auto_wrap_policy_kwargs,
_get_full_state_dict_context,
_init_cpu_offload,
_optimizer_has_flat_params,
@@ -64,9 +65,9 @@
if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.wrap import _FSDPPolicy

_POLICY = Union[Callable[[Module, bool, int], bool], _FSDPPolicy]
_POLICY = Union[Set, Callable[[Module, bool, int], bool], _FSDPPolicy]
else:
_POLICY = Callable[[Module, bool, int], bool] # type: ignore[misc]
_POLICY = Union[Set, Callable[[Module, bool, int], bool]] # type: ignore[misc]

log = logging.getLogger(__name__)

@@ -91,13 +92,15 @@ class FSDPStrategy(ParallelStrategy):
Arguments:
cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``. A single layer or a list of
layer classes for which you want to enable activation checkpointing. This is typically your transformer
block (including attention + feed-forward).
auto_wrap_policy: Same as ``auto_wrap_policy`` parameter in
:class:`torch.distributed.fsdp.FullyShardedDataParallel`. For convenience, this also accepts a set of the
layer classes to wrap.
activation_checkpointing: Deprecated. Use ``activation_checkpointing_policy``.
activation_checkpointing_policy: Same as ``auto_wrap_policy`` parameter in
:class:`torch.distributed.fsdp.FullyShardedDataParallel` but used when selecting the modules for which you
want to enable activation checkpointing. Enabling this can free up a significant amount of memory at the
cost of speed since activations in these layers need to be recomputed during backpropagation.
cost of speed since activations in these layers need to be recomputed during backpropagation. For
convenience, this also accepts a set of the layer classes to wrap.
\**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
"""

@@ -115,6 +118,7 @@ def __init__(
timeout: Optional[timedelta] = default_pg_timeout,
cpu_offload: Union[bool, "CPUOffload", None] = None,
mixed_precision: Optional["MixedPrecision"] = None,
auto_wrap_policy: Optional["_POLICY"] = None,
activation_checkpointing: Optional[Union[Type[Module], List[Type[Module]]]] = None,
activation_checkpointing_policy: Optional["_POLICY"] = None,
**kwargs: Any,
@@ -135,11 +139,13 @@ def __init__(
self._timeout: Optional[timedelta] = timeout
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.mixed_precision = mixed_precision
self.kwargs = kwargs
self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)

if _TORCH_GREATER_EQUAL_2_0:
# Avoids the need for user to reference params in `configure_optimizers` via
# `self.trainer.model.parameters()` and enables support for multiple parameter groups.
self.kwargs.setdefault("use_orig_params", True)

self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
)
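For readers comparing spellings, the set form is shorthand for building the torch policy object yourself. On torch >= 2.1 the two strategies below end up with equivalent activation-checkpointing kwargs (a sketch, with `Block` as a hypothetical user-defined module):

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # requires a recent torch
from lightning.pytorch.strategies import FSDPStrategy


class Block(nn.Module):
    """Hypothetical transformer-style block."""


# explicit policy object, still accepted after this change
explicit = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block}))

# set shorthand introduced by this commit; converted to a ModuleWrapPolicy internally
shorthand = FSDPStrategy(activation_checkpointing_policy={Block})
```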
19 changes: 10 additions & 9 deletions tests/tests_fabric/strategies/test_fsdp.py
@@ -13,7 +13,6 @@
# limitations under the License.
import contextlib
import datetime
import functools
import os
from datetime import timedelta
from re import escape
@@ -157,18 +156,26 @@ def __init__(self):
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1}))
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)

strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
else:
strategy = FSDPStrategy(activation_checkpointing=Block1)
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch(
"torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
@@ -402,11 +409,6 @@ def __del__(self) -> None:
[(Block,), (SubBlock,), (Block, SubBlock, nn.Linear), None],
)
def test_apply_optimizer_in_backward(checkpoint):
try:
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
except ImportError:
pytest.skip("Failed to import `lambda_auto_wrap_policy`")

from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles

num_gpus = 2
@@ -431,9 +433,8 @@ def test_apply_optimizer_in_backward(checkpoint):
upper_savings_bound = 4 * feature_dim**2 * 2 * (num_blocks - 1)
lower_savings_bound = upper_savings_bound / 3

auto_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda module: isinstance(module, Block))
strategy = FSDPStrategy(
auto_wrap_policy=auto_wrap_policy,
auto_wrap_policy={Block},
activation_checkpointing=checkpoint,
timeout=datetime.timedelta(seconds=10),
)
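On torch < 2.1 the tests above expect a `check_fn` rather than a policy object; the fallback built from a set boils down to this small sketch:

```python
import torch.nn as nn

# example set of layer classes to checkpoint
activation_checkpointing_policy = {nn.Linear}


def check_fn(submodule: nn.Module) -> bool:
    # pre-2.1 fallback: a predicate applied to every submodule
    return isinstance(submodule, tuple(activation_checkpointing_policy))


assert check_fn(nn.Linear(2, 2))  # Linear layers get activation checkpointing
assert not check_fn(nn.ReLU())    # everything else is left untouched
```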
15 changes: 2 additions & 13 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -23,11 +23,7 @@
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_1_12,
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.wrappers import _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
@@ -409,14 +405,7 @@ def test_fsdp_save_filter(tmp_path):
@RunIf(min_torch="1.13", min_cuda_gpus=1)
def test_fsdp_manual_activation_checkpointing():
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))

if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({torch.nn.Linear}))
else:
strategy = FSDPStrategy(activation_checkpointing=torch.nn.Linear)

strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
fabric.launch()
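Outside the test harness, the same set-based policy plugs into Fabric roughly as follows (a sketch; it assumes at least one CUDA device, as the test's `RunIf` guard does):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
fabric = Fabric(accelerator="cuda", devices=1, strategy=strategy)
fabric.launch()

model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
model = fabric.setup(model)  # FSDP wrapping and activation checkpointing are applied during setup
```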

10 changes: 9 additions & 1 deletion tests/tests_pytorch/strategies/test_fsdp.py
@@ -438,18 +438,26 @@ def __init__(self):
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1}))
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)

strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
else:
strategy = FSDPStrategy(activation_checkpointing=Block1)
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

model = Model()
strategy._parallel_devices = [torch.device("cuda", 0)]
strategy._lightning_module = model
