
Fabric: drop FairScale's sharded implementation #16329

Merged
15 commits merged on Jan 11, 2023
1 change: 0 additions & 1 deletion docs/source-pytorch/fabric/api/api_reference.rst
@@ -124,7 +124,6 @@ Strategies
Strategy
DDPStrategy
DataParallelStrategy
DDPShardedStrategy
FSDPStrategy
ParallelStrategy
SingleDeviceStrategy
7 changes: 2 additions & 5 deletions docs/source-pytorch/fabric/api/fabric_args.rst
@@ -8,7 +8,7 @@ Fabric Arguments
accelerator
===========

Choose one of ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"`` (IPU support is coming soon).
Choose one of ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``.

.. code-block:: python

@@ -35,7 +35,7 @@ The ``"auto"`` option recognizes the machine you are on and selects the availabl
strategy
========

Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"tpu_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``, or ``"ddp_sharded_spawn"``.
Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"xla"``, ``"deepspeed"``, or ``"fsdp"``.

.. code-block:: python

@@ -55,9 +55,6 @@ Additionally, you can pass in your custom strategy by configuring additional par
fabric = Fabric(strategy=DeepSpeedStrategy(stage=2), accelerator="gpu", devices=2)


Support for Fully Sharded training strategies are coming soon.
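For orientation, here is a minimal sketch of strategy selection with the names listed above; the import path is an assumption based on the `lightning_fabric` package shown elsewhere in this diff.

```python
# Minimal sketch (assumed import path); each call selects a different built-in strategy.
from lightning_fabric import Fabric

fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")        # distributed data parallel
fabric = Fabric(accelerator="gpu", devices=4, strategy="deepspeed")  # requires the deepspeed package
fabric = Fabric(accelerator="gpu", devices=4, strategy="fsdp")       # replaces the removed "ddp_sharded"
```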


devices
=======

2 changes: 0 additions & 2 deletions requirements/fabric/strategies.txt
@@ -1,5 +1,3 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

fairscale>=0.4.5, <0.4.13
deepspeed>=0.6.0, <=0.7.0
2 changes: 1 addition & 1 deletion src/lightning_fabric/CHANGELOG.md
@@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
- Removed support for FairScale's sharded training (`strategy='ddp_sharded'|'ddp_sharded_spawn'`). Use Fully-Sharded Data Parallel instead (`strategy='fsdp'`) ([#16329](https://github.com/Lightning-AI/lightning/pull/16329))

### Fixed

2 changes: 0 additions & 2 deletions src/lightning_fabric/connector.py
@@ -43,7 +43,6 @@
from lightning_fabric.plugins.precision.fsdp import FSDPPrecision
from lightning_fabric.plugins.precision.precision import _PRECISION_INPUT, _PRECISION_INPUT_INT, _PRECISION_INPUT_STR
from lightning_fabric.strategies import (
DDPShardedStrategy,
DDPStrategy,
DeepSpeedStrategy,
SingleDeviceStrategy,
@@ -557,7 +556,6 @@ def is_distributed(self) -> bool:
return self.strategy.is_distributed
distributed_strategy = (
DDPStrategy,
DDPShardedStrategy,
DeepSpeedStrategy,
XLAStrategy,
)
20 changes: 3 additions & 17 deletions src/lightning_fabric/fabric.py
@@ -32,14 +32,7 @@
from lightning_fabric.plugins import Precision # avoid circular imports: # isort: split
from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning_fabric.strategies import (
DDPShardedStrategy,
DeepSpeedStrategy,
FSDPStrategy,
SingleDeviceStrategy,
Strategy,
XLAStrategy,
)
from lightning_fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
from lightning_fabric.strategies.strategy import _Sharded, TBroadcast
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, convert_to_tensors
@@ -69,7 +62,7 @@ class Fabric:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
@@ -713,15 +706,8 @@ def _validate_setup_module(self, module: nn.Module) -> None:
if isinstance(module, _FabricModule):
raise ValueError("A model should be passed only once to the `setup_module` method.")

if isinstance(self._strategy, DDPShardedStrategy):
raise RuntimeError(
f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
" through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example"
" `ddp`."
)

def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, XLAStrategy)):
if isinstance(self._strategy, (DeepSpeedStrategy, XLAStrategy)):
raise RuntimeError(
f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
" through `.setup(model, optimizer, ...)`."
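To illustrate the joint-setup requirement enforced in the hunk above, a hedged sketch (the model, optimizer, and strategy choice are placeholders):

```python
# Sketch: DeepSpeed and XLA shard or wrap optimizer state, so Fabric requires the model and
# optimizer(s) to be set up together rather than via separate setup_module/setup_optimizers calls.
import torch
from lightning_fabric import Fabric

fabric = Fabric(accelerator="gpu", devices=2, strategy="deepspeed")
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Joint setup: allowed for every strategy, required for DeepSpeedStrategy/XLAStrategy
model, optimizer = fabric.setup(model, optimizer)

# Calling fabric.setup_optimizers(optimizer) on its own would raise the RuntimeError shown above.
```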
1 change: 0 additions & 1 deletion src/lightning_fabric/strategies/__init__.py
@@ -14,7 +14,6 @@
from lightning_fabric.strategies.ddp import DDPStrategy # noqa: F401
from lightning_fabric.strategies.deepspeed import DeepSpeedStrategy # noqa: F401
from lightning_fabric.strategies.dp import DataParallelStrategy # noqa: F401
from lightning_fabric.strategies.fairscale import DDPShardedStrategy # noqa: F401
from lightning_fabric.strategies.fsdp import FSDPStrategy # noqa: F401
from lightning_fabric.strategies.parallel import ParallelStrategy # noqa: F401
from lightning_fabric.strategies.registry import _call_register_strategies, _StrategyRegistry
146 changes: 0 additions & 146 deletions src/lightning_fabric/strategies/fairscale.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/lightning_fabric/strategies/fsdp.py
@@ -279,12 +279,12 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"fsdp",
cls,
description="Fully Sharded Data Parallel training from torch.distributed.",
description="Fully Sharded Data Parallel",
)
strategy_registry.register(
"fsdp_full_shard_offload",
cls,
description="Native FSDP with Full Sharding and CPU Offloading",
description="Fully Sharded Data Parallel and CPU Offloading",
cpu_offload=True,
)

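As a usage note, the two registry entries above can be selected by name; a small sketch under the assumption that the aliases behave as registered here:

```python
# Sketch: both aliases resolve to FSDPStrategy; the second one enables cpu_offload=True.
from lightning_fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric_offloaded = Fabric(accelerator="cuda", devices=2, strategy="fsdp_full_shard_offload")
```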
2 changes: 1 addition & 1 deletion src/lightning_fabric/strategies/strategy.py
@@ -239,7 +239,7 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
Allows for syncing/collating optimizer state from processes in custom plugins.
"""
if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
# there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their
# states, and to avoid OOM we consolidate the full state on rank 0 only
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}
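For context, a hedged sketch of the consolidation pattern `get_optimizer_state` relies on, using PyTorch's `ZeroRedundancyOptimizer` (assumes an initialized process group and a running training job):

```python
# Sketch: each rank holds only a shard of the optimizer state; consolidate on rank 0
# before calling state_dict() so the full state is materialized on a single rank only.
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

model = torch.nn.Linear(32, 2)
optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
# ... training steps ...
optimizer.consolidate_state_dict(to=0)                          # gather the sharded state onto rank 0
state = optimizer.state_dict() if dist.get_rank() == 0 else {}  # mirrors the rank-zero guard above
```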
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for `LightningCLI(seed_everything_default=None)` ([#16131](https://github.com/Lightning-AI/lightning/pull/16131))


- Removed support in LightningLite for FairScale's sharded training (`strategy='ddp_sharded'|'ddp_sharded_spawn'`). Use Fully-Sharded Data Parallel instead (`strategy='fsdp'`) ([#16329](https://github.com/Lightning-AI/lightning/pull/16329))


### Fixed

- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))
45 changes: 16 additions & 29 deletions src/pytorch_lightning/lite/lite.py
@@ -26,7 +26,6 @@
from lightning_fabric.plugins import TPUBf16Precision as LiteTPUBf16Precision
from lightning_fabric.plugins import TPUPrecision as LiteTPUPrecision
from lightning_fabric.strategies import DataParallelStrategy as LiteDataParallelStrategy
from lightning_fabric.strategies import DDPShardedStrategy as LiteDDPShardedStrategy
from lightning_fabric.strategies import DDPStrategy as LiteDDPStrategy
from lightning_fabric.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_fabric.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
@@ -75,7 +74,7 @@ class LightningLite(Fabric, ABC):
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
@@ -109,8 +108,7 @@ def __init__(

rank_zero_deprecation(
"The `pytorch_lightning.lite.LightningLite` class was deprecated in v1.9.0 and will be renamed to"
" `lightning.fabric.Fabric` in v2.0.0. It is no longer part of the pure `pytorch_lightning` package, and"
" now lives in the main `lightning` package."
" `lightning_fabric.Fabric` in v2.0.0."
)

if gpus is not None or tpu_cores is not None:
@@ -131,6 +129,20 @@ def __init__(
else:
lite_plugins = plugins

if type(strategy) in (PLDDPShardedStrategy, PLDDPSpawnShardedStrategy) or strategy in (
"ddp_sharded",
"ddp_sharded_spawn",
):
spawn_message = ""
if type(strategy) is PLDDPSpawnShardedStrategy or strategy == "ddp_sharded_spawn":
spawn_message = ", start_method='spawn'"
raise RuntimeError(
"LightningLite's sharded implementation using FairScale has been removed in favor of PyTorch's FSDP."
" You can try"
f" `Fabric(strategy=FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP{spawn_message}))`"
" which implements optimizer-only sharding a-la ZeRO-2. Or full sharding with `Fabric(strategy='fsdp')`"
)

super().__init__(
accelerator=accelerator,
strategy=(_to_lite_strategy(strategy) if isinstance(strategy, PLStrategy) else strategy),
@@ -245,31 +257,6 @@ def _to_lite_strategy(strategy: PLStrategy) -> LiteStrategy:
precision=_to_lite_precision(strategy.precision_plugin),
)

if strategy_cls is PLDDPShardedStrategy:
return LiteDDPShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
**strategy._ddp_kwargs,
)

if strategy_cls is PLDDPSpawnShardedStrategy:
return LiteDDPShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
**strategy._ddp_kwargs,
)

if strategy_cls is PLSingleDeviceStrategy:
return LiteSingleDeviceStrategy(
device=strategy.root_device,
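To make the migration path in the new RuntimeError above concrete, a hedged sketch (assuming PyTorch's `ShardingStrategy` enum and the `FSDPStrategy` exported in this diff):

```python
# Sketch of the replacement suggested by the error message: ZeRO-2-style sharding via
# FSDP's SHARD_GRAD_OP, or full parameter/gradient/optimizer sharding selected by name.
from torch.distributed.fsdp import ShardingStrategy

from lightning_fabric import Fabric
from lightning_fabric.strategies import FSDPStrategy

# Roughly what the removed `strategy="ddp_sharded"` provided (gradient/optimizer sharding)
fabric = Fabric(
    strategy=FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP),
    accelerator="gpu",
    devices=2,
)

# Full sharding (ZeRO-3-style)
fabric_full = Fabric(strategy="fsdp", accelerator="gpu", devices=2)
```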