Fabric: drop FairScale's sharded implementation (#16329)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
carmocca and awaelchli authored Jan 11, 2023
1 parent 3c3bff5 commit 428844d
Showing 32 changed files with 115 additions and 484 deletions.
1 change: 0 additions & 1 deletion docs/source-pytorch/fabric/api/api_reference.rst
@@ -124,7 +124,6 @@ Strategies
Strategy
DDPStrategy
DataParallelStrategy
DDPShardedStrategy
FSDPStrategy
ParallelStrategy
SingleDeviceStrategy
7 changes: 2 additions & 5 deletions docs/source-pytorch/fabric/api/fabric_args.rst
@@ -8,7 +8,7 @@ Fabric Arguments
accelerator
===========

Choose one of ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"`` (IPU support is coming soon).
Choose one of ``"cpu"``, ``"gpu"``, ``"tpu"``, ``"auto"``.

.. code-block:: python
@@ -35,7 +35,7 @@ The ``"auto"`` option recognizes the machine you are on and selects the availabl
strategy
========

Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"tpu_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``, or ``"ddp_sharded_spawn"``.
Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"xla"``, ``"deepspeed"``, or ``"fsdp"``.

.. code-block:: python
@@ -55,9 +55,6 @@ Additionally, you can pass in your custom strategy by configuring additional par
fabric = Fabric(strategy=DeepSpeedStrategy(stage=2), accelerator="gpu", devices=2)
Support for Fully Sharded training strategies are coming soon.


devices
=======

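Since the ``strategy`` docs above now list ``"fsdp"`` in place of the removed sharded aliases, here is a minimal migration sketch (the 2-GPU configuration is assumed for illustration only):

```python
from lightning_fabric import Fabric

# Before this change: Fabric(strategy="ddp_sharded", accelerator="gpu", devices=2)
# After this change, PyTorch's FSDP is the supported sharded strategy:
fabric = Fabric(strategy="fsdp", accelerator="gpu", devices=2)
fabric.launch()
```
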
2 changes: 0 additions & 2 deletions requirements/fabric/strategies.txt
@@ -1,5 +1,3 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

fairscale>=0.4.5, <0.4.13
deepspeed>=0.6.0, <=0.7.0
2 changes: 1 addition & 1 deletion src/lightning_fabric/CHANGELOG.md
@@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

-
- Removed support for FairScale's sharded training (`strategy='ddp_sharded'|'ddp_sharded_spawn'`). Use Fully-Sharded Data Parallel instead (`strategy='fsdp'`) ([#16329](https://github.com/Lightning-AI/lightning/pull/16329))

### Fixed

20 changes: 3 additions & 17 deletions src/lightning_fabric/fabric.py
@@ -32,14 +32,7 @@
from lightning_fabric.plugins import Precision # avoid circular imports: # isort: split
from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning_fabric.strategies import (
DDPShardedStrategy,
DeepSpeedStrategy,
FSDPStrategy,
SingleDeviceStrategy,
Strategy,
XLAStrategy,
)
from lightning_fabric.strategies import DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy, XLAStrategy
from lightning_fabric.strategies.strategy import _Sharded, TBroadcast
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.apply_func import convert_tensors_to_scalars, convert_to_tensors
@@ -69,7 +62,7 @@ class Fabric:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``, ``"ddp_sharded"``.
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
@@ -713,15 +706,8 @@ def _validate_setup_module(self, module: nn.Module) -> None:
if isinstance(module, _FabricModule):
raise ValueError("A model should be passed only once to the `setup_module` method.")

if isinstance(self._strategy, DDPShardedStrategy):
raise RuntimeError(
f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
" through `.setup(model, optimizer, ...)`. For inference, choose a different strategy, for example"
" `ddp`."
)

def _validate_setup_optimizers(self, optimizers: Sequence[Optimizer]) -> None:
if isinstance(self._strategy, (DeepSpeedStrategy, DDPShardedStrategy, XLAStrategy)):
if isinstance(self._strategy, (DeepSpeedStrategy, XLAStrategy)):
raise RuntimeError(
f"The `{type(self._strategy).__name__}` requires the model and optimizer(s) to be set up jointly"
" through `.setup(model, optimizer, ...)`."
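The updated `_validate_setup_optimizers` check above still requires joint setup for strategies that wrap or shard the model and optimizer together. A sketch of the allowed call pattern, using a hypothetical toy model and optimizer:

```python
import torch
from lightning_fabric import Fabric

fabric = Fabric(strategy="deepspeed", accelerator="gpu", devices=2)
fabric.launch()

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# DeepSpeed and XLA set up model and optimizer jointly; calling
# `setup_optimizers()` on its own raises the RuntimeError shown in the diff above.
model, optimizer = fabric.setup(model, optimizer)
```
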
1 change: 0 additions & 1 deletion src/lightning_fabric/strategies/__init__.py
@@ -14,7 +14,6 @@
from lightning_fabric.strategies.ddp import DDPStrategy # noqa: F401
from lightning_fabric.strategies.deepspeed import DeepSpeedStrategy # noqa: F401
from lightning_fabric.strategies.dp import DataParallelStrategy # noqa: F401
from lightning_fabric.strategies.fairscale import DDPShardedStrategy # noqa: F401
from lightning_fabric.strategies.fsdp import FSDPStrategy # noqa: F401
from lightning_fabric.strategies.parallel import ParallelStrategy # noqa: F401
from lightning_fabric.strategies.registry import _call_register_strategies, _StrategyRegistry
146 changes: 0 additions & 146 deletions src/lightning_fabric/strategies/fairscale.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/lightning_fabric/strategies/fsdp.py
@@ -275,12 +275,12 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"fsdp",
cls,
description="Fully Sharded Data Parallel training from torch.distributed.",
description="Fully Sharded Data Parallel",
)
strategy_registry.register(
"fsdp_full_shard_offload",
cls,
description="Native FSDP with Full Sharding and CPU Offloading",
description="Fully Sharded Data Parallel and CPU Offloading",
cpu_offload=True,
)

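For reference, the two registry entries above can be selected either by alias or by constructing the strategy directly; a sketch, assuming a 2-GPU machine:

```python
from lightning_fabric import Fabric
from lightning_fabric.strategies import FSDPStrategy

# Registered alias with CPU offloading enabled:
fabric = Fabric(strategy="fsdp_full_shard_offload", accelerator="gpu", devices=2)

# Equivalent explicit construction, passing the same keyword the registry uses:
fabric = Fabric(strategy=FSDPStrategy(cpu_offload=True), accelerator="gpu", devices=2)
```
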
2 changes: 1 addition & 1 deletion src/lightning_fabric/strategies/strategy.py
@@ -239,7 +239,7 @@ def get_optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
Allows for syncing/collating optimizer state from processes in custom plugins.
"""
if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
# there are optimizers like PyTorch's ZeroRedundancyOptimizer that shard their
# states, and to avoid OOM we consolidate the full state on rank 0 only
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}
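The rewritten comment above refers to sharded optimizers such as PyTorch's `ZeroRedundancyOptimizer`. A standalone sketch of the consolidation pattern that `get_optimizer_state` relies on (assumes an already-initialized process group, e.g. under `torchrun`, and a hypothetical toy model):

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

model = torch.nn.Linear(32, 2).cuda()
optimizer = ZeroRedundancyOptimizer(
    model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3
)

# ... training steps ...

# Each rank only holds its shard of the optimizer state, so gather it on
# rank 0 before reading the full state dict, mirroring the strategy code above.
optimizer.consolidate_state_dict(to=0)
state = optimizer.state_dict() if dist.get_rank() == 0 else {}
```
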
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -155,6 +155,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed support for `LightningCLI(seed_everything_default=None)` ([#16131](https://github.com/Lightning-AI/lightning/pull/16131))


- Removed support in LightningLite for FairScale's sharded training (`strategy='ddp_sharded'|'ddp_sharded_spawn'`). Use Fully-Sharded Data Parallel instead (`strategy='fsdp'`) ([#16329](https://github.com/Lightning-AI/lightning/pull/16329))


### Fixed

- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))
45 changes: 16 additions & 29 deletions src/pytorch_lightning/lite/lite.py
@@ -26,7 +26,6 @@
from lightning_fabric.plugins import TPUBf16Precision as LiteTPUBf16Precision
from lightning_fabric.plugins import TPUPrecision as LiteTPUPrecision
from lightning_fabric.strategies import DataParallelStrategy as LiteDataParallelStrategy
from lightning_fabric.strategies import DDPShardedStrategy as LiteDDPShardedStrategy
from lightning_fabric.strategies import DDPStrategy as LiteDDPStrategy
from lightning_fabric.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_fabric.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
@@ -75,7 +74,7 @@ class LightningLite(Fabric, ABC):
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
@@ -109,8 +108,7 @@ def __init__(

rank_zero_deprecation(
"The `pytorch_lightning.lite.LightningLite` class was deprecated in v1.9.0 and will be renamed to"
" `lightning.fabric.Fabric` in v2.0.0. It is no longer part of the pure `pytorch_lightning` package, and"
" now lives in the main `lightning` package."
" `lightning_fabric.Fabric` in v2.0.0."
)

if gpus is not None or tpu_cores is not None:
@@ -131,6 +129,20 @@
else:
lite_plugins = plugins

if type(strategy) in (PLDDPShardedStrategy, PLDDPSpawnShardedStrategy) or strategy in (
"ddp_sharded",
"ddp_sharded_spawn",
):
spawn_message = ""
if type(strategy) is PLDDPSpawnShardedStrategy or strategy == "ddp_sharded_spawn":
spawn_message = ", start_method='spawn'"
raise RuntimeError(
"LightningLite's sharded implementation using FairScale has been removed in favor of PyTorch's FSDP."
" You can try"
f" `Fabric(strategy=FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP{spawn_message}))`"
" which implements optimizer-only sharding à la ZeRO-2. Or full sharding with `Fabric(strategy='fsdp')`"
)

super().__init__(
accelerator=accelerator,
strategy=(_to_lite_strategy(strategy) if isinstance(strategy, PLStrategy) else strategy),
@@ -245,31 +257,6 @@ def _to_lite_strategy(strategy: PLStrategy) -> LiteStrategy:
precision=_to_lite_precision(strategy.precision_plugin),
)

if strategy_cls is PLDDPShardedStrategy:
return LiteDDPShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
**strategy._ddp_kwargs,
)

if strategy_cls is PLDDPSpawnShardedStrategy:
return LiteDDPShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
**strategy._ddp_kwargs,
)

if strategy_cls is PLSingleDeviceStrategy:
return LiteSingleDeviceStrategy(
device=strategy.root_device,
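The new RuntimeError in `lite.py` points users at FSDP's `SHARD_GRAD_OP` mode as the replacement for FairScale's ZeRO-2-style sharding. A sketch of both suggested replacements (device counts are illustrative):

```python
from torch.distributed.fsdp import ShardingStrategy

from lightning_fabric import Fabric
from lightning_fabric.strategies import FSDPStrategy

# Gradient- and optimizer-state sharding (ZeRO-2-style), as the error message suggests:
fabric = Fabric(
    strategy=FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP),
    accelerator="gpu",
    devices=2,
)

# Or full parameter, gradient, and optimizer-state sharding:
fabric = Fabric(strategy="fsdp", accelerator="gpu", devices=2)
```
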

