Flexible and easy to use HSDP setting #19504

Merged: 29 commits merged into master from hybrid_fsdp_stage on Jun 6, 2024
Commits (29); the diff below shows changes from 24 of them:
312caee  Add fsdp_size for FSDPStrategy (Liyang90, Jan 17, 2024)
45c1123  fix import (Liyang90, Jan 17, 2024)
0ddc51d  Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage (Liyang90, Feb 20, 2024)
c952536  Add flexible HSDP in fabric (Liyang90, Feb 20, 2024)
8fc2404  minor update (Liyang90, Feb 20, 2024)
da3900f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 20, 2024)
8311be1  Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage (Liyang90, Mar 1, 2024)
d1d719a  Use device_mesh arg to set flexible HSDP with a Tuple (Liyang90, Mar 4, 2024)
3315893  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 5, 2024)
4652b74  Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage (Liyang90, Mar 5, 2024)
4049f60  minor fix (Liyang90, Mar 5, 2024)
9c14afe  add simple docs (awaelchli, Mar 8, 2024)
1f2c3ff  correct doc string (Liyang90, Apr 1, 2024)
07f7c1b  set as explicit args in FSDPStrategy (Liyang90, Apr 4, 2024)
2ab0423  Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage (Liyang90, Apr 4, 2024)
899e032  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 4, 2024)
4259df2  update fsdp tests (Liyang90, Apr 18, 2024)
dbe22f3  Type check error (Liyang90, Apr 18, 2024)
2320a4e  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 18, 2024)
b0d4783  merge (Liyang90, Apr 18, 2024)
9d7dfbe  type check (Liyang90, Apr 18, 2024)
483f745  Merge branch 'master' into hybrid_fsdp_stage (Liyang90, May 16, 2024)
ba0b10b  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 16, 2024)
d2d9fe8  Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage (Liyang90, Jun 5, 2024)
bd03b05  simplify imports (awaelchli, Jun 5, 2024)
11bc4ee  extend test (awaelchli, Jun 5, 2024)
c6a052c  add changelog (awaelchli, Jun 5, 2024)
00efbcf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jun 5, 2024)
949d36f  Merge branch 'master' into hybrid_fsdp_stage (awaelchli, Jun 5, 2024)

23 changes: 22 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
@@ -80,6 +80,11 @@
_POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]

if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed._tensor import DeviceMesh
else:
DeviceMesh = None # type: ignore

_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")


@@ -117,10 +122,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.

Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -146,6 +155,7 @@ def __init__(
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "sharded",
device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -163,6 +173,11 @@ def __init__(
# Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
self._fsdp_kwargs.setdefault("use_orig_params", True)

if device_mesh is not None:
if not _TORCH_GREATER_EQUAL_2_2:
raise ValueError("The device_mesh argument is only supported in torch >= 2.2.")
self._fsdp_kwargs["device_mesh"] = device_mesh

self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
)
@@ -244,6 +259,12 @@ def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()

# if `device_mesh` in `_fsdp_kwargs` was provided as a tuple, convert it into a `DeviceMesh` object here
if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
from torch.distributed.device_mesh import init_device_mesh

self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])

@override
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
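For context, the `setup_environment` change above only expands a tuple into a real mesh before FSDP ever sees it. A rough sketch of what that expansion amounts to (assuming torch >= 2.2 and a world size of 8; following the `(replication size, sharding size)` convention in the docstring, the first mesh dimension replicates and the second one shards):

# Illustrative sketch of the tuple-to-DeviceMesh expansion done in setup_environment.
from torch.distributed.device_mesh import init_device_mesh

# (2, 4) on 8 ranks: a 2 x 4 mesh; dim 0 = replication groups, dim 1 = sharding groups.
mesh = init_device_mesh("cuda", (2, 4))
# The resulting mesh is what FSDP receives via the `device_mesh` keyword argument.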
45 changes: 42 additions & 3 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -16,7 +16,21 @@
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
)

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -53,7 +67,10 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
)
from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -76,6 +93,11 @@
_POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]

if _TORCH_GREATER_EQUAL_2_2:
from torch.distributed._tensor import DeviceMesh
else:
DeviceMesh = None # type: ignore


log = logging.getLogger(__name__)

@@ -114,10 +136,14 @@ class FSDPStrategy(ParallelStrategy):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.

Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.

device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.

state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.

- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
@@ -147,6 +173,7 @@ def __init__(
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "full",
device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -162,6 +189,12 @@ def __init__(
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.mixed_precision = mixed_precision
self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)

if device_mesh is not None:
if not _TORCH_GREATER_EQUAL_2_2:
raise ValueError("The device_mesh argument is only supported in torch >= 2.2.")
self.kwargs["device_mesh"] = device_mesh

self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)

# Avoids the need for user to reference params in `configure_optimizers` via
@@ -242,6 +275,12 @@ def setup_environment(self) -> None:
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

# if `device_mesh` in `kwargs` was provided as a tuple, convert it into a `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
from torch.distributed.device_mesh import init_device_mesh

self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

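Because the argument is annotated as `Optional[Union[Tuple[int], "DeviceMesh"]]`, a pre-built mesh can also be passed instead of a tuple, in which case the `setup_environment` conversion above is skipped. A hedged sketch (assumes torch >= 2.2 and a launcher such as torchrun, so the process group can be created up front):

# Illustrative sketch: passing an already-constructed DeviceMesh instead of a tuple.
from torch.distributed.device_mesh import init_device_mesh
from lightning.pytorch.strategies import FSDPStrategy

mesh = init_device_mesh("cuda", (2, 4))  # 2 replica groups x 4 shards per group
strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=mesh)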
15 changes: 8 additions & 7 deletions tests/tests_fabric/strategies/test_fsdp.py
@@ -85,13 +85,14 @@ def test_hybrid_shard_configuration(sharding_strategy):
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["process_group"] is process_group

device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["device_mesh"] is device_mesh

with pytest.raises(ValueError, match="process_group.* device_mesh=.* are mutually exclusive"):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
with mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True):
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["device_mesh"] is device_mesh

with pytest.raises(ValueError, match="process_group.* device_mesh=.* are mutually exclusive"):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)


def test_checkpoint_io_unsupported():
15 changes: 8 additions & 7 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -514,13 +514,14 @@ def test_hybrid_sharding_strategy(sharding_strategy):
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy.kwargs["process_group"] is process_group

device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy.kwargs["device_mesh"] is device_mesh

with pytest.raises(ValueError, match="process_group.* device_mesh=.* are mutually exclusive"):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True):
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy.kwargs["device_mesh"] is device_mesh

with pytest.raises(ValueError, match="process_group.* device_mesh=.* are mutually exclusive"):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)


def test_use_orig_params():