Changes from all commits (54 commits)
ba66f96
Basic skeletal code for XLAFSDP support for PyTorch Trainer
gkroiz Oct 8, 2023
26238d8
fix import issue
gkroiz Oct 8, 2023
53703e2
Fixed leaking env vars in tests + minor cleanup
gkroiz Oct 8, 2023
bbed985
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 8, 2023
0132427
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2023
d4ae668
accidentally deleted line in merge
gkroiz Oct 8, 2023
b538f43
fix mypy error
gkroiz Oct 8, 2023
9778505
Minor fixes
carmocca Oct 9, 2023
4098016
Reuse code
carmocca Oct 9, 2023
5ca74df
More minor fixes
carmocca Oct 9, 2023
4bcb2e9
Reorder methods to reduce diff with Fabric
carmocca Oct 9, 2023
957866b
Go through precision plugin
carmocca Oct 9, 2023
222896c
Consistent with regular FSDP
carmocca Oct 9, 2023
cd70e0d
minor fixes, limited checkpointing support, testing for trainer.test …
gkroiz Oct 9, 2023
d7d78f1
mypy fix
gkroiz Oct 9, 2023
2f51e4e
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 10, 2023
7a680e3
consistent setup optimizer
carmocca Oct 10, 2023
46927f4
Merge branch 'master' into pytorch_xla_fsdp
carmocca Oct 10, 2023
c4e295a
Import pbar force fn
carmocca Oct 10, 2023
0cd0cb6
Launched and overrides
carmocca Oct 10, 2023
e0b353f
Multiple optimizers should be fine?
carmocca Oct 10, 2023
60eecd1
Test fixes
carmocca Oct 10, 2023
1150e90
Minor FSDP fixes
carmocca Oct 10, 2023
1ae4957
Fix
carmocca Oct 10, 2023
ef05276
Other xla strategies too
carmocca Oct 10, 2023
1cc6d84
More fixes
carmocca Oct 10, 2023
64ed7fb
Fixes
carmocca Oct 10, 2023
42ea9ac
Restore minor snippet with full checkpointing
gkroiz Oct 10, 2023
be1b509
ignores
carmocca Oct 10, 2023
f9ba3f0
Fixes
carmocca Oct 10, 2023
196960c
mypy
carmocca Oct 10, 2023
fdf21d3
_is_sharded
carmocca Oct 10, 2023
483b93d
Bring over changes from #18774
carmocca Oct 10, 2023
f4c1ba7
Optimizer fix
carmocca Oct 10, 2023
22629f4
Restore checkpoint after setup
carmocca Oct 10, 2023
10253f7
load_checkpoint
carmocca Oct 10, 2023
e391d38
Warning
carmocca Oct 11, 2023
f92de54
CHANGELOG
carmocca Oct 11, 2023
e6fd2ce
Additional changes for Trainer XLAFSDP strategy ckpting
gkroiz Oct 11, 2023
1305faa
add assertions for mypy
gkroiz Oct 11, 2023
bd8fb12
test manual wrap separately from ckpting test
gkroiz Oct 11, 2023
58d43e2
syntax changes in strings from FSDP to XLAFSDP
gkroiz Oct 11, 2023
c264801
Minor fix to tests
gkroiz Oct 11, 2023
3263d98
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 11, 2023
82f9f2f
Merge branch 'master' into pytorch_xla_fsdp
carmocca Oct 11, 2023
80c3faf
Add manual wrapping guard for fabric xlafsdp
carmocca Oct 11, 2023
8f775f4
mypy
carmocca Oct 11, 2023
e5b695a
Merge branch 'master' into pytorch_xla_fsdp
carmocca Oct 11, 2023
7bf9904
Apply suggestions from code review
gkroiz Oct 11, 2023
a00f8fe
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 11, 2023
4280cf4
Apply formatting suggestions
gkroiz Oct 11, 2023
ddda55f
[XLAFSDP] add test for automatic strategy selection
gkroiz Oct 11, 2023
904b50d
fix `test_tpu_invalid_raises` test
gkroiz Oct 11, 2023
0c79282
only run test_xla_fsdp_automatic_strategy_selection when on TPU
gkroiz Oct 11, 2023
14 changes: 8 additions & 6 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -475,7 +475,8 @@ def _save_checkpoint_shard(
# convert the state
if isinstance(obj, Module) and isinstance(obj, XLAFSDP):
converted = obj.state_dict()
- # add shard_metadata to state
+ # add shard_metadata to state. this format is defined by
+ # https://github.com/pytorch/xla/blob/v2.1.0/torch_xla/distributed/fsdp/state_dict_utils.py#L122-L125
Review comment (Contributor):
I wonder what happens if you save a state where the key for the model is not "model"

fabric.save(path, {"banana": model})

This would be totally valid in any other setting, but I think here it would fail since the XLA format expects these keys.

Reply (Contributor Author):
iirc this won't work when trying to consolidate the checkpoint using consolidate_sharded_model_checkpoints https://github.com/Lightning-AI/lightning/pull/18746/files/e5b695aff64bc37ae1a67fba4aac4981200eecfd#diff-3908a573abf00ae5f37061f214f2a3c2616b6591e0c96206b9f48b4c7ab49ea4R457. I'm not sure how this works for individual shards.

converted_state["shard_metadata"] = obj.get_shard_metadata()
elif isinstance(obj, Optimizer):
converted = obj.state_dict()
@@ -566,11 +567,7 @@ def load_checkpoint(
if len(loaded_metadata_keys):
for key in loaded_metadata_keys:
metadata[key] = sharded_ckpt[key]

# remove "shard_metadata" that is loaded in
if "shard_metadata" in metadata:
metadata.pop("shard_metadata")

metadata.pop("shard_metadata", None)
return metadata

if self._state_dict_type == "full":
@@ -591,6 +588,11 @@
)
if "model" not in state or not isinstance(model := state["model"], torch.nn.Module):
raise NotImplementedError("XLAFSDP only supports a single model instance with 'model' as the key.")
+ if any(isinstance(mod, XLAFSDP) for mod in model.modules()):
+     raise ValueError(
+         "`XLAFSDPStrategy` does not support loading full model checkpoint"
+         " if the model or any submodules are manually wrapped."
+     )
full_ckpt = torch.load(path)
model.load_state_dict(full_ckpt.pop("model"), strict=strict)
return full_ckpt
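The thread above concerns the save format for sharded checkpoints: each rank's file must carry the converted module state under "model" plus the "shard_metadata" entry added in this hunk, because torch_xla's consolidation helper looks those keys up when merging shards. Below is a minimal sketch of that downstream step, assuming torch_xla v2.1; the file paths and suffix pattern are illustrative, and this is not Lightning code.

```python
# Sketch only: merging per-rank XLAFSDP shard files into one full checkpoint.
# Assumes each shard file contains the "model" and "shard_metadata" keys referenced above.
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# Expected layout of every per-rank shard file (simplified):
# {"model": <sharded state_dict>, "shard_metadata": <module.get_shard_metadata()>, ...}

consolidate_sharded_model_checkpoints(
    ckpt_prefix="checkpoints/epoch=0-step=100",  # hypothetical common prefix of the rank files
    ckpt_suffix="_rank-*-of-*.pth",              # pattern matching the per-rank suffixes
    save_path="checkpoints/consolidated.ckpt",   # single merged checkpoint written here
)
```

Saving the module under another key (the `"banana"` case above) would break this lookup, which is why the strategy pins the `"model"` key.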
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for the `max_size_cycle|max_size|min_size` iteration modes during evaluation ([#17163](https://github.com/Lightning-AI/lightning/pull/17163))
- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))
+ - Added support for Fully Sharded Data Parallel (FSDP) training with XLA ([#18746](https://github.com/Lightning-AI/lightning/pull/18746))
- Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
- Added `XLAStrategy(sync_module_states=bool)` to control whether to broadcast the parameters to all devices ([#17522](https://github.com/Lightning-AI/lightning/pull/17522))
- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
24 changes: 21 additions & 3 deletions src/lightning/pytorch/plugins/precision/xla.py
@@ -13,9 +13,10 @@
# limitations under the License.
Review comment (Contributor):
Could we extend the tests in plugins/precision/test_xla.py in a meaningful way?

import os
from functools import partial
- from typing import Any, Callable
+ from typing import Any, Callable, Union

import torch
+ from torch.optim import Optimizer
from typing_extensions import get_args

import lightning.pytorch as pl
@@ -59,6 +60,9 @@ def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
else:
self._desired_dtype = torch.float32

+ # boolean flag for simplicity over an entirely new class
+ self._using_fsdp = False
+
def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
@@ -68,7 +72,8 @@ def optimizer_step( # type: ignore[override]
) -> Any:
import torch_xla.core.xla_model as xm

- closure = partial(self._xla_wrap_closure, optimizer, closure)
+ if not self._using_fsdp:
+     closure = partial(self._reduce_gradients, optimizer, closure)
closure = partial(self._wrap_closure, model, optimizer, closure)
closure_result = optimizer.step(closure=closure, **kwargs)
xm.mark_step()
@@ -87,9 +92,22 @@ def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
os.environ.pop("XLA_USE_F16", None)

- def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
+ def _reduce_gradients(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
import torch_xla.core.xla_model as xm

closure_result = closure()
xm.reduce_gradients(optimizer)
return closure_result

+ def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+     if self._using_fsdp:
+         # Not supported by us because we need a module reference, this would need to go through the Strategy
+         # as in Fabric
+         raise NotImplementedError("XLA's FSDP strategy does not support to clip gradients by norm.")
+     return super().clip_grad_by_value(optimizer, clip_val)
+
+ def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
+     if self._using_fsdp:
+         # Not supported by XLA
+         raise NotImplementedError("XLA's FSDP strategy does not support to clip gradients by value.")
+     return super().clip_grad_by_value(optimizer, clip_val)
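The `optimizer_step` change above wraps the training closure so that, on plain XLA, gradients are all-reduced across replicas before the optimizer consumes the closure result; the new `_using_fsdp` flag skips that wrapping because XLAFSDP reduces gradients during its own backward pass. The following standalone sketch illustrates the pattern under the assumption that torch_xla is installed; the function and argument names are illustrative, not the plugin's actual code.

```python
# Sketch of the closure-wrapping pattern from the diff above (illustrative only).
from typing import Any, Callable

import torch_xla.core.xla_model as xm
from torch.optim import Optimizer


def xla_optimizer_step(optimizer: Optimizer, closure: Callable[[], Any], using_fsdp: bool) -> Any:
    def reduce_then_run() -> Any:
        result = closure()               # forward + backward, populates the .grad buffers
        xm.reduce_gradients(optimizer)   # all-reduce gradients across XLA replicas
        return result

    # XLAFSDP already reduces gradients internally, so only wrap the closure for plain XLA
    step_closure = closure if using_fsdp else reduce_then_run
    output = optimizer.step(closure=step_closure)
    xm.mark_step()                       # flush the lazily recorded XLA graph
    return output
```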
1 change: 1 addition & 0 deletions src/lightning/pytorch/strategies/__init__.py
@@ -23,6 +23,7 @@
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy # noqa: F401
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.strategies.xla import XLAStrategy # noqa: F401
+ from lightning.pytorch.strategies.xla_fsdp import XLAFSDPStrategy  # noqa: F401

StrategyRegistry = _StrategyRegistry()
_register_classes(StrategyRegistry, "register_strategies", sys.modules[__name__], Strategy)
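With `XLAFSDPStrategy` exported here, the strategy becomes selectable on the Trainer. The hedged usage sketch below assumes the `"xla_fsdp"` alias and the `state_dict_type` argument carry over from the Fabric counterpart; treat both as assumptions rather than the final Trainer API.

```python
# Hedged usage sketch for the newly exported strategy (arguments assumed from the Fabric version).
import lightning.pytorch as pl
from lightning.pytorch.strategies import XLAFSDPStrategy

# select by registered alias (assumed to be "xla_fsdp", matching Fabric)
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla_fsdp")

# or construct the strategy explicitly to control the checkpoint format
trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,
    strategy=XLAFSDPStrategy(state_dict_type="sharded"),  # kwarg assumed from the Fabric strategy
)
```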
8 changes: 5 additions & 3 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -58,6 +58,7 @@
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH, ReduceOp
+ from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.core.optimizer import LightningOptimizer
from lightning.pytorch.plugins.precision import PrecisionPlugin
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecisionPlugin
@@ -66,7 +67,7 @@
from lightning.pytorch.strategies.strategy import TBroadcast
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.model_helpers import is_overridden
- from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
+ from lightning.pytorch.utilities.rank_zero import rank_zero_only, rank_zero_warn

if TYPE_CHECKING:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
@@ -311,9 +312,10 @@ def setup(self, trainer: "pl.Trainer") -> None:

if is_overridden("configure_sharded_model", self.lightning_module):
# legacy: we don't skip setup with the `configure_model` alternative
- rank_zero_info(
+ rank_zero_warn(
"You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers"
" are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`."
" are already wrapped for sharding and won't wrap the entire model using `FullyShardedDataParallel`.",
category=PossibleUserWarning,
)
else:
self.model = self._setup_model(self.model)
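The upgraded warning above targets LightningModules that override the legacy `configure_sharded_model` hook and wrap submodules themselves, in which case the strategy leaves the top-level module unwrapped. The hedged sketch below shows the kind of module that would now trigger the `PossibleUserWarning`; the layer names and sizes are illustrative and it is not taken from the PR's tests.

```python
# Illustrative module for the warning path above.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

import lightning.pytorch as pl


class ManuallyWrappedModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)

    def configure_sharded_model(self) -> None:  # legacy hook checked in the diff above
        # wrapping a submodule manually signals that all sharding is handled here,
        # so the strategy will not wrap the whole model again
        self.backbone = FSDP(self.backbone)
```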
2 changes: 1 addition & 1 deletion src/lightning/pytorch/strategies/launchers/xla.py
@@ -49,7 +49,7 @@ class _XLALauncher(_MultiProcessingLauncher):

"""

def __init__(self, strategy: "pl.strategies.XLAStrategy") -> None:
def __init__(self, strategy: Union["pl.strategies.XLAStrategy", "pl.strategies.XLAFSDPStrategy"]) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(strategy=strategy, start_method="fork")