Commits (54)
ba66f96
Basic skeletal code for XLAFSDP support for PyTorch Trainer
gkroiz Oct 8, 2023
26238d8
fix import issue
gkroiz Oct 8, 2023
53703e2
Fixed leaking env vars in tests + minor cleanup
gkroiz Oct 8, 2023
bbed985
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 8, 2023
0132427
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2023
d4ae668
accidentally deleted line in merge
gkroiz Oct 8, 2023
b538f43
fix mypy error
gkroiz Oct 8, 2023
9778505
Minor fixes
carmocca Oct 9, 2023
4098016
Reuse code
carmocca Oct 9, 2023
5ca74df
More minor fixes
carmocca Oct 9, 2023
4bcb2e9
Reorder methods to reduce diff with Fabric
carmocca Oct 9, 2023
957866b
Go through precision plugin
carmocca Oct 9, 2023
222896c
Consistent with regular FSDP
carmocca Oct 9, 2023
cd70e0d
minor fixes, limited checkpointing support, testing for trainer.test …
gkroiz Oct 9, 2023
d7d78f1
mypy fix
gkroiz Oct 9, 2023
2f51e4e
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 10, 2023
7a680e3
consistent setup optimizer
carmocca Oct 10, 2023
46927f4
Merge branch 'master' into pytorch_xla_fsdp
carmocca Oct 10, 2023
c4e295a
Import pbar force fn
carmocca Oct 10, 2023
0cd0cb6
Launched and overrides
carmocca Oct 10, 2023
e0b353f
Multiple optimizers should be fine?
carmocca Oct 10, 2023
60eecd1
Test fixes
carmocca Oct 10, 2023
1150e90
Minor FSDP fixes
carmocca Oct 10, 2023
1ae4957
Fix
carmocca Oct 10, 2023
ef05276
Other xla strategies too
carmocca Oct 10, 2023
1cc6d84
More fixes
carmocca Oct 10, 2023
64ed7fb
Fixes
carmocca Oct 10, 2023
42ea9ac
Restore minor snippet with full checkpointing
gkroiz Oct 10, 2023
be1b509
ignores
carmocca Oct 10, 2023
f9ba3f0
Fixes
carmocca Oct 10, 2023
196960c
mypy
carmocca Oct 10, 2023
fdf21d3
_is_sharded
carmocca Oct 10, 2023
483b93d
Bring over changes from #18774
carmocca Oct 10, 2023
f4c1ba7
Optimizer fix
carmocca Oct 10, 2023
22629f4
Restore checkpoint after setup
carmocca Oct 10, 2023
10253f7
load_checkpoint
carmocca Oct 10, 2023
e391d38
Warning
carmocca Oct 11, 2023
f92de54
CHANGELOG
carmocca Oct 11, 2023
e6fd2ce
Additional changes for Trainer XLAFSDP strategy ckpting
gkroiz Oct 11, 2023
1305faa
add assertions for mypy
gkroiz Oct 11, 2023
bd8fb12
test manual wrap separately from ckpting test
gkroiz Oct 11, 2023
58d43e2
syntax changes in strings from FSDP to XLAFSDP
gkroiz Oct 11, 2023
c264801
Minor fix to tests
gkroiz Oct 11, 2023
3263d98
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 11, 2023
82f9f2f
Merge branch 'master' into pytorch_xla_fsdp
carmocca Oct 11, 2023
80c3faf
Add manual wrapping guard for fabric xlafsdp
carmocca Oct 11, 2023
8f775f4
mypy
carmocca Oct 11, 2023
e5b695a
Merge branch 'master' into pytorch_xla_fsdp
carmocca Oct 11, 2023
7bf9904
Apply suggestions from code review
gkroiz Oct 11, 2023
a00f8fe
Merge branch 'master' into pytorch_xla_fsdp
gkroiz Oct 11, 2023
4280cf4
Apply formatting suggestions
gkroiz Oct 11, 2023
ddda55f
[XLAFSDP] add test for automatic strategy selection
gkroiz Oct 11, 2023
904b50d
fix `test_tpu_invalid_raises` test
gkroiz Oct 11, 2023
0c79282
only run test_xla_fsdp_automatic_strategy_selection when on TPU
gkroiz Oct 11, 2023
15 changes: 1 addition & 14 deletions src/lightning/fabric/connector.py
@@ -454,10 +454,9 @@ def _init_strategy(self) -> None:
self.strategy = self._strategy_flag

def _check_and_init_precision(self) -> Precision:
self._validate_precision_choice()
if isinstance(self._precision_instance, Precision):
return self._precision_instance
if isinstance(self.accelerator, XLAAccelerator):
if isinstance(self.strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)):
return XLAPrecision(self._precision_input) # type: ignore
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore
@@ -492,18 +491,6 @@ def _check_and_init_precision(self) -> Precision:

raise RuntimeError("No precision set")

def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
if (
isinstance(self.accelerator, XLAAccelerator)
and self._precision_instance
and not isinstance(self._precision_instance, XLAPrecision)
):
raise ValueError(
f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
f" found: {self._precision_instance}."
)

def _lazy_init_strategy(self) -> None:
"""Lazily set missing attributes on the previously instantiated strategy."""
self.strategy.accelerator = self.accelerator
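Net effect of this hunk, as a small standalone sketch (a hypothetical helper, not the real connector code): the up-front `_validate_precision_choice` guard is dropped, and `XLAPrecision` is now selected from the strategy type rather than from the accelerator. The type checking moves into the strategies' own `precision`/`checkpoint_io` setters shown further down.

```python
# Hypothetical helper mirroring only the new XLA branch of `_check_and_init_precision`.
from typing import Optional

from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.strategies import SingleDeviceXLAStrategy, XLAFSDPStrategy, XLAStrategy


def xla_precision_for(strategy: object, precision_input: str = "32-true") -> Optional[XLAPrecision]:
    if isinstance(strategy, (SingleDeviceXLAStrategy, XLAStrategy, XLAFSDPStrategy)):
        return XLAPrecision(precision_input)  # instantiating requires torch_xla to be installed
    return None  # the real method goes on to handle DeepSpeed, FSDP, and the remaining plugins
```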
24 changes: 7 additions & 17 deletions src/lightning/fabric/strategies/fsdp.py
@@ -305,25 +305,15 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
flattened parameters.

"""
if _TORCH_GREATER_EQUAL_2_0:
return optimizer

from torch.distributed.fsdp import FlatParameter

num_groups = len(optimizer.param_groups)
if num_groups > 1:
if self._fsdp_kwargs.get("use_orig_params"):
return super().setup_optimizer(optimizer)
if not _optimizer_has_flat_params(optimizer):
# We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
raise ValueError(
"An optimizer used with an FSDP model does not support multiple param groups."
f" Found {num_groups} parameter groups."
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
" after setting up the model."
)

if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
return optimizer

raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
" after setting up the model."
)
return optimizer

def module_to_device(self, module: Module) -> None:
pass
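Because removed and added lines are interleaved in the hunk above, here is a standalone sketch of the validation that `setup_optimizer` performs after this change (a hypothetical helper, not Lightning's code; the real method defers to `super().setup_optimizer()` when `use_orig_params` is set):

```python
from torch.optim import Optimizer


def validate_fsdp_optimizer(optimizer: Optimizer, use_orig_params: bool) -> Optimizer:
    """Sketch: with `use_orig_params=True` any optimizer is accepted; otherwise the optimizer
    must already reference FSDP's flattened parameters, i.e. it was created after setup."""
    from torch.distributed.fsdp import FlatParameter  # the class the flat-param check looks for

    if use_orig_params:
        return optimizer
    has_flat_params = any(
        isinstance(p, FlatParameter) for group in optimizer.param_groups for p in group["params"]
    )
    if not has_flat_params:
        raise ValueError(
            "The optimizer does not seem to reference any FSDP parameters."
            " HINT: Make sure to create the optimizer after setting up the model."
        )
    return optimizer
```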
37 changes: 27 additions & 10 deletions src/lightning/fabric/strategies/single_xla.py
@@ -17,9 +17,8 @@

from lightning.fabric.accelerators import Accelerator
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.single_device import SingleDeviceStrategy
from lightning.fabric.utilities.types import _DEVICE
@@ -32,8 +31,8 @@ def __init__(
self,
device: _DEVICE,
accelerator: Optional[Accelerator] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
checkpoint_io: Optional[XLACheckpointIO] = None,
precision: Optional[XLAPrecision] = None,
):
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
@@ -50,16 +49,34 @@ def __init__(
precision=precision,
)

@property
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = XLACheckpointIO()
return self._checkpoint_io
@property # type: ignore[override]
def checkpoint_io(self) -> XLACheckpointIO:
plugin = self._checkpoint_io
if plugin is not None:
assert isinstance(plugin, XLACheckpointIO)
return plugin
return XLACheckpointIO()

@checkpoint_io.setter
def checkpoint_io(self, io: CheckpointIO) -> None:
def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
if io is not None and not isinstance(io, XLACheckpointIO):
raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
self._checkpoint_io = io

@property # type: ignore[override]
def precision(self) -> XLAPrecision:
plugin = self._precision
if plugin is not None:
assert isinstance(plugin, XLAPrecision)
return plugin
return XLAPrecision("32-true")

@precision.setter
def precision(self, precision: Optional[XLAPrecision]) -> None:
if precision is not None and not isinstance(precision, XLAPrecision):
raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
self._precision = precision

@classmethod
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
strategy_registry.register("single_xla", cls, description=cls.__name__)
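The same typed-plugin pattern is repeated below for `XLAStrategy` and `XLAFSDPStrategy`. A usage sketch (assumes torch_xla is installed; the plugin classes are the ones imported in the diff):

```python
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import XLAStrategy

strategy = XLAStrategy(precision=XLAPrecision("32-true"))
assert isinstance(strategy.checkpoint_io, XLACheckpointIO)  # a default is built on access
assert isinstance(strategy.precision, XLAPrecision)

try:
    strategy.checkpoint_io = object()  # type: ignore[assignment]
except TypeError as err:
    print(err)  # the setter now rejects anything that is not an `XLACheckpointIO`
```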
38 changes: 27 additions & 11 deletions src/lightning/fabric/strategies/xla.py
@@ -22,10 +22,9 @@

from lightning.fabric.accelerators import Accelerator
from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, _using_pjrt
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
from lightning.fabric.strategies.launchers.xla import _XLALauncher
from lightning.fabric.strategies.strategy import TBroadcast
@@ -44,8 +43,8 @@ def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision: Optional[Precision] = None,
checkpoint_io: Optional[XLACheckpointIO] = None,
precision: Optional[XLAPrecision] = None,
sync_module_states: bool = True,
) -> None:
super().__init__(
@@ -55,7 +54,6 @@
checkpoint_io=checkpoint_io,
precision=precision,
)
self._checkpoint_io: Optional[CheckpointIO]
self._backward_sync_control = None # XLA synchronizes gradients in the optimizer.step() call
self._launched = False
self._sync_module_states = sync_module_states
@@ -72,16 +70,34 @@ def root_device(self) -> torch.device:
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

@property
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = XLACheckpointIO()
return self._checkpoint_io
@property # type: ignore[override]
def checkpoint_io(self) -> XLACheckpointIO:
plugin = self._checkpoint_io
if plugin is not None:
assert isinstance(plugin, XLACheckpointIO)
return plugin
return XLACheckpointIO()

@checkpoint_io.setter
def checkpoint_io(self, io: CheckpointIO) -> None:
def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
if io is not None and not isinstance(io, XLACheckpointIO):
raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
self._checkpoint_io = io

@property # type: ignore[override]
def precision(self) -> XLAPrecision:
plugin = self._precision
if plugin is not None:
assert isinstance(plugin, XLAPrecision)
return plugin
return XLAPrecision("32-true")

@precision.setter
def precision(self, precision: Optional[XLAPrecision]) -> None:
if precision is not None and not isinstance(precision, XLAPrecision):
raise TypeError(f"The XLA strategy can only work with the `XLAPrecision` plugin, found {precision}")
self._precision = precision

@property
def global_rank(self) -> int:
return super().global_rank if self._launched else 0
62 changes: 32 additions & 30 deletions src/lightning/fabric/strategies/xla_fsdp.py
@@ -24,10 +24,9 @@
from torch.utils.data import DataLoader

from lightning.fabric.accelerators import Accelerator
from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, _using_pjrt
from lightning.fabric.plugins import XLAPrecision
from lightning.fabric.plugins.environments import XLAEnvironment
from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
from lightning.fabric.plugins.io.xla import XLACheckpointIO
from lightning.fabric.strategies import ParallelStrategy, _StrategyRegistry
from lightning.fabric.strategies.fsdp import _apply_filter
@@ -85,22 +84,23 @@ def __init__(
self,
accelerator: Optional[Accelerator] = None,
parallel_devices: Optional[List[torch.device]] = None,
checkpoint_io: Optional[CheckpointIO] = None,
checkpoint_io: Optional[XLACheckpointIO] = None,
precision: Optional[XLAPrecision] = None,
auto_wrap_policy: Optional[_POLICY] = None,
activation_checkpointing_policy: Optional[_POLICY_SET] = None,
state_dict_type: Literal["full", "sharded"] = "sharded",
sequential_save: bool = False,
**kwargs: Any,
) -> None:
if not _XLA_AVAILABLE:
raise ModuleNotFoundError(str(_XLA_AVAILABLE))
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=XLAEnvironment(),
checkpoint_io=checkpoint_io,
precision=precision,
)
self._checkpoint_io: Optional[CheckpointIO]
self._backward_sync_control = _XLAFSDPBackwardSyncControl()

self._auto_wrap_policy = auto_wrap_policy
@@ -122,16 +122,34 @@ def root_device(self) -> torch.device:
def num_processes(self) -> int:
return len(self.parallel_devices) if self.parallel_devices is not None else 0

@property
def checkpoint_io(self) -> CheckpointIO:
if self._checkpoint_io is None:
self._checkpoint_io = XLACheckpointIO()
return self._checkpoint_io
@property # type: ignore[override]
def checkpoint_io(self) -> XLACheckpointIO:
plugin = self._checkpoint_io
if plugin is not None:
assert isinstance(plugin, XLACheckpointIO)
return plugin
return XLACheckpointIO()

@checkpoint_io.setter
def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
def checkpoint_io(self, io: Optional[XLACheckpointIO]) -> None:
if io is not None and not isinstance(io, XLACheckpointIO):
raise TypeError(f"The XLA strategy can only work with the `XLACheckpointIO` plugin, found {io}")
self._checkpoint_io = io

@property # type: ignore[override]
def precision(self) -> XLAPrecision:
plugin = self._precision
if plugin is not None:
assert isinstance(plugin, XLAPrecision)
return plugin
return XLAPrecision("32-true")

@precision.setter
def precision(self, precision: Optional[XLAPrecision]) -> None:
if precision is not None and not isinstance(precision, XLAPrecision):
raise TypeError(f"The XLA FSDP strategy can only work with the `XLAPrecision` plugin, found {precision}")
self._precision = precision

@property
def global_rank(self) -> int:
return super().global_rank if self._launched else 0
@@ -227,21 +245,8 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
flattened parameters.

"""
if _TORCH_GREATER_EQUAL_2_0:
return optimizer

from torch_xla.distributed.fsdp.xla_flatten_params_wrapper import FlatParameter

num_groups = len(optimizer.param_groups)
if num_groups > 1:
raise ValueError(
"An optimizer used with an XLAFSDP model does not support multiple param groups."
f" Found {num_groups} parameter groups."
)

if any(isinstance(param, FlatParameter) for param in optimizer.param_groups[0]["params"]):
if any(getattr(p, "_is_sharded", False) for group in optimizer.param_groups for p in group["params"]):
return optimizer

raise ValueError(
"The optimizer does not seem to reference any XLAFSDP parameters. HINT: Make sure to create the optimizer"
" after setting up the model."
@@ -470,7 +475,8 @@ def _save_checkpoint_shard(
# convert the state
if isinstance(obj, Module) and isinstance(obj, XLAFSDP):
converted = obj.state_dict()
# add shard_metadata to state
# add shard_metadata to state. this format is defined by
# https://github.com/pytorch/xla/blob/v2.1.0/torch_xla/distributed/fsdp/state_dict_utils.py#L122-L125
Contributor:

I wonder what happens if you save a state where the key for the model is not "model"

fabric.save(path, {"banana": model})

This would be totally valid in any other setting, but I think here it would fail since the XLA format expects these keys.

Contributor Author:

iirc this won't work when trying to consolidate the checkpoint using consolidate_sharded_model_checkpoints https://github.com/Lightning-AI/lightning/pull/18746/files/e5b695aff64bc37ae1a67fba4aac4981200eecfd#diff-3908a573abf00ae5f37061f214f2a3c2616b6591e0c96206b9f48b4c7ab49ea4R457. I'm not sure how this works for individual shards.

converted_state["shard_metadata"] = obj.get_shard_metadata()
elif isinstance(obj, Optimizer):
converted = obj.state_dict()
@@ -561,11 +567,7 @@ def load_checkpoint(
if len(loaded_metadata_keys):
for key in loaded_metadata_keys:
metadata[key] = sharded_ckpt[key]

# remove "shard_metadata" that is loaded in
if "shard_metadata" in metadata:
metadata.pop("shard_metadata")

metadata.pop("shard_metadata", None)
return metadata

if self._state_dict_type == "full":
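Two behavioural notes on this file, illustrated with a standalone sketch (a hypothetical helper, not Lightning's code): the restriction on multiple parameter groups is gone, and instead of looking for torch_xla's `FlatParameter` class, `setup_optimizer` now accepts any optimizer whose parameters carry the `_is_sharded` attribute that XLAFSDP sets on the parameters it manages.

```python
from torch.optim import Optimizer


def references_xla_fsdp_params(optimizer: Optimizer) -> bool:
    """Sketch of the new check: True if any parameter in any param group was sharded by XLAFSDP."""
    return any(
        getattr(p, "_is_sharded", False)
        for group in optimizer.param_groups
        for p in group["params"]
    )
```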
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for the `max_size_cycle|max_size|min_size` iteration modes during evaluation ([#17163](https://github.com/Lightning-AI/lightning/pull/17163))
- Added support for the TPU-v4 architecture ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
- Added support for XLA's new PJRT runtime ([#17352](https://github.com/Lightning-AI/lightning/pull/17352))
- Added support for Fully Sharded Data Parallel (FSDP) training with XLA ([#18746](https://github.com/Lightning-AI/lightning/pull/18746))
- Check for invalid TPU device inputs ([#17227](https://github.com/Lightning-AI/lightning/pull/17227))
- Added `XLAStrategy(sync_module_states=bool)` to control whether to broadcast the parameters to all devices ([#17522](https://github.com/Lightning-AI/lightning/pull/17522))
- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
24 changes: 21 additions & 3 deletions src/lightning/pytorch/plugins/precision/xla.py
@@ -13,9 +13,10 @@
# limitations under the License.
Contributor:

Could we extend the tests in plugins/precision/test_xla.py in a meaningful way?

import os
from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Union

import torch
from torch.optim import Optimizer
from typing_extensions import get_args

import lightning.pytorch as pl
@@ -59,6 +60,9 @@ def __init__(self, precision: _PRECISION_INPUT = "32-true") -> None:
else:
self._desired_dtype = torch.float32

# boolean flag for simplicity over an entirely new class
self._using_fsdp = False

def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
@@ -68,7 +72,8 @@ def optimizer_step( # type: ignore[override]
) -> Any:
import torch_xla.core.xla_model as xm

closure = partial(self._xla_wrap_closure, optimizer, closure)
if not self._using_fsdp:
closure = partial(self._reduce_gradients, optimizer, closure)
closure = partial(self._wrap_closure, model, optimizer, closure)
closure_result = optimizer.step(closure=closure, **kwargs)
xm.mark_step()
@@ -87,9 +92,22 @@ def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
os.environ.pop("XLA_USE_F16", None)

def _xla_wrap_closure(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
def _reduce_gradients(self, optimizer: Optimizable, closure: Callable[[], Any]) -> Any:
import torch_xla.core.xla_model as xm

closure_result = closure()
xm.reduce_gradients(optimizer)
return closure_result

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
if self._using_fsdp:
# Not supported by us because we need a module reference; this would need to go through the Strategy,
# as in Fabric
raise NotImplementedError("XLA's FSDP strategy does not support clipping gradients by norm.")
return super().clip_grad_by_norm(optimizer, clip_val)

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
if self._using_fsdp:
# Not supported by XLA
raise NotImplementedError("XLA's FSDP strategy does not support to clip gradients by value.")
return super().clip_grad_by_value(optimizer, clip_val)
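A simplified sketch of the resulting control flow in `optimizer_step` (assumes torch_xla is installed; the real method additionally wraps the closure through the precision plugin): when `_using_fsdp` is set, the extra `xm.reduce_gradients()` call is skipped because the XLAFSDP-wrapped module already reduces its own gradients.

```python
from typing import Any, Callable

from torch.optim import Optimizer


def run_optimizer_step(optimizer: Optimizer, closure: Callable[[], Any], using_fsdp: bool) -> Any:
    import torch_xla.core.xla_model as xm

    def reducing_closure() -> Any:
        result = closure()              # forward/backward
        xm.reduce_gradients(optimizer)  # all-reduce gradients across replicas
        return result

    step_closure = closure if using_fsdp else reducing_closure
    output = optimizer.step(closure=step_closure)
    xm.mark_step()  # flush the lazy XLA graph
    return output
```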
1 change: 1 addition & 0 deletions src/lightning/pytorch/strategies/__init__.py
@@ -23,6 +23,7 @@
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy # noqa: F401
from lightning.pytorch.strategies.strategy import Strategy
from lightning.pytorch.strategies.xla import XLAStrategy # noqa: F401
from lightning.pytorch.strategies.xla_fsdp import XLAFSDPStrategy # noqa: F401

StrategyRegistry = _StrategyRegistry()
_register_classes(StrategyRegistry, "register_strategies", sys.modules[__name__], Strategy)
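With the strategy exported and registered, it can be selected from the Trainer. A usage sketch (assumes a TPU host with torch_xla installed, and that the Trainer strategy is registered under the same "xla_fsdp" alias used by Fabric; constructor arguments are assumed to mirror the Fabric strategy shown above):

```python
import lightning.pytorch as pl
from lightning.pytorch.strategies import XLAFSDPStrategy

# by registry alias
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy="xla_fsdp", precision="bf16-true")

# or with an explicit instance
trainer = pl.Trainer(accelerator="tpu", devices=8, strategy=XLAFSDPStrategy())
```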