[TPU] Rename classes to use XLA instead of TPU #17383

Merged · 23 commits · Apr 28, 2023
2 changes: 1 addition & 1 deletion docs/source-fabric/api/accelerators.rst
@@ -19,4 +19,4 @@ Accelerators
CPUAccelerator
CUDAAccelerator
MPSAccelerator
- TPUAccelerator
+ XLAAccelerator
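The rename is user-visible wherever the accelerator class is referenced directly. A minimal sketch of using the renamed class with Fabric (assumes Lightning 2.x with this PR applied; the registry string stays `"tpu"`, per the `register_accelerators` code further down in this diff):

```python
# Sketch only: select the renamed accelerator with Fabric, either by instance
# or by the unchanged "tpu" registry alias.
from lightning.fabric import Fabric
from lightning.fabric.accelerators import XLAAccelerator

if XLAAccelerator.is_available():
    fabric = Fabric(accelerator=XLAAccelerator(), devices=1)  # explicit instance
else:
    fabric = Fabric(accelerator="cpu")  # fall back when no XLA device is present
```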
4 changes: 2 additions & 2 deletions docs/source-fabric/api/precision.rst
@@ -20,6 +20,6 @@ Precision
Precision
DoublePrecision
MixedPrecision
- TPUPrecision
- TPUBf16Precision
+ XLAPrecision
+ XLABf16Precision
FSDPPrecision
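For reference, a hedged sketch of passing the renamed precision plugins to Fabric explicitly instead of a precision string (the `plugins=` argument and the plugin exports are taken from the connector imports later in this diff; an XLA/TPU runtime is assumed):

```python
# Sketch: pass the renamed plugin instances directly.
from lightning.fabric import Fabric
from lightning.fabric.plugins import XLABf16Precision, XLAPrecision

fabric_bf16 = Fabric(accelerator="tpu", plugins=XLABf16Precision())
fabric_fp32 = Fabric(accelerator="tpu", plugins=XLAPrecision())
```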
2 changes: 1 addition & 1 deletion docs/source-fabric/api/strategies.rst
@@ -23,4 +23,4 @@ Strategies
FSDPStrategy
ParallelStrategy
SingleDeviceStrategy
- SingleTPUStrategy
+ SingleDeviceXLAStrategy
8 changes: 4 additions & 4 deletions docs/source-pytorch/api_references.rst
@@ -13,7 +13,7 @@ accelerators
Accelerator
CPUAccelerator
CUDAAccelerator
- TPUAccelerator
+ XLAAccelerator

callbacks
---------
@@ -117,8 +117,8 @@ precision
FSDPMixedPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
- TPUBf16PrecisionPlugin
- TPUPrecisionPlugin
+ XLABf16PrecisionPlugin
+ XLAPrecisionPlugin

environments
""""""""""""
@@ -212,7 +212,7 @@ strategies
FSDPStrategy
ParallelStrategy
SingleDeviceStrategy
- SingleTPUStrategy
+ SingleDeviceXLAStrategy
Strategy
XLAStrategy

2 changes: 1 addition & 1 deletion docs/source-pytorch/extensions/accelerator.rst
@@ -129,4 +129,4 @@ Accelerator API
CPUAccelerator
CUDAAccelerator
MPSAccelerator
- TPUAccelerator
+ XLAAccelerator
4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/plugins.rst
@@ -57,8 +57,8 @@ The full list of built-in precision plugins is listed below.
FSDPMixedPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
- TPUBf16PrecisionPlugin
- TPUPrecisionPlugin
+ XLABf16PrecisionPlugin
+ XLAPrecisionPlugin

More information regarding precision with Lightning can be found :ref:`here <precision>`

6 changes: 3 additions & 3 deletions docs/source-pytorch/extensions/strategy.rst
@@ -93,9 +93,9 @@ The below table lists all relevant strategies available in Lightning with their
* - xla
- :class:`~lightning.pytorch.strategies.XLAStrategy`
- Strategy for training on multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method. :doc:`Learn more. <../accelerators/tpu>`
- * - single_tpu
- - :class:`~lightning.pytorch.strategies.SingleTPUStrategy`
- - Strategy for training on a single TPU device. :doc:`Learn more. <../accelerators/tpu>`
+ * - single_xla
+ - :class:`~lightning.pytorch.strategies.SingleXLAStrategy`
+ - Strategy for training on a single XLA device, like TPUs. :doc:`Learn more. <../accelerators/tpu>`

----

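A short, hedged example of selecting the strategies from the table above by their registry names (Trainer API otherwise assumed unchanged; requires a TPU host):

```python
# Sketch: "single_xla" replaces "single_tpu"; "xla" covers the multi-device case.
from lightning.pytorch import Trainer

trainer_multi = Trainer(accelerator="tpu", devices=8, strategy="xla")
trainer_single = Trainer(accelerator="tpu", devices=1, strategy="single_xla")
```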
2 changes: 1 addition & 1 deletion docs/source-pytorch/upgrade/sections/1_9_devel.rst
@@ -47,7 +47,7 @@
- `PR16703`_

* - used any function from ``pl.utilities.xla_device``
- - switch to ``pl.accelerators.TPUAccelerator.is_available()``
+ - switch to ``pl.accelerators.XLAAccelerator.is_available()``
- `PR14514`_ `PR14550`_

* - imported functions from ``pl.utilities.device_parser.*``
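The replacement suggested in the row above, sketched out (assuming `pl` refers to `lightning.pytorch`, as in the upgrade guide):

```python
# Sketch of the migration: query availability through the accelerator class
# instead of the removed pl.utilities.xla_device helpers.
import lightning.pytorch as pl

if pl.accelerators.XLAAccelerator.is_available():
    trainer = pl.Trainer(accelerator="tpu", devices="auto")
```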
14 changes: 14 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -58,6 +58,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))


- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `TPUAccelerator` in favor of `XLAAccelerator` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `TPUPrecision` in favor of `XLAPrecision` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `TPUBf16Precision` in favor of `XLABf16Precision` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `Fabric.sharded_model()` context manager in favor of `Fabric.init_module()` ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))

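Taken together, the deprecations above amount to a rename-only migration. An illustrative sketch, not part of the diff (assumes an XLA/TPU runtime):

```python
# Old spellings keep working for now but emit deprecation warnings; new code
# should use the XLA-prefixed names and the "single_xla" strategy alias.
from lightning.fabric import Fabric

fabric = Fabric(accelerator="tpu", devices=1, strategy="single_xla")   # new alias
# Fabric(accelerator="tpu", devices=1, strategy="single_tpu")          # deprecated, still accepted
```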
3 changes: 3 additions & 0 deletions src/lightning/fabric/__init__.py
@@ -30,6 +30,9 @@
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.fabric.wrappers import is_wrapped # noqa: E402

# this import needs to go last as it will patch other modules
import lightning.fabric._graveyard # noqa: E402, F401 # isort: skip

__all__ = ["Fabric", "seed_everything", "is_wrapped"]

# for compatibility with namespace packages
14 changes: 14 additions & 0 deletions src/lightning/fabric/_graveyard/__init__.py
@@ -0,0 +1,14 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import lightning.fabric._graveyard.tpu # noqa: F401
101 changes: 101 additions & 0 deletions src/lightning/fabric/_graveyard/tpu.py
@@ -0,0 +1,101 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any

import lightning.fabric as fabric
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.plugins.precision import XLABf16Precision, XLAPrecision
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.single_xla import SingleDeviceXLAStrategy
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation


def _patch_sys_modules() -> None:
self = sys.modules[__name__]
sys.modules["lightning.fabric.strategies.single_tpu"] = self
sys.modules["lightning.fabric.accelerators.tpu"] = self
sys.modules["lightning.fabric.plugins.precision.tpu"] = self
sys.modules["lightning.fabric.plugins.precision.tpu_bf16"] = self


class SingleTPUStrategy(SingleDeviceXLAStrategy):
"""Legacy class.

Use :class:`~lightning.fabric.strategies.single_xla.SingleDeviceXLAStrategy` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation("The 'single_tpu' strategy is deprecated. Use 'single_xla' instead.")
super().__init__(*args, **kwargs)

@classmethod
def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
if "single_tpu" not in strategy_registry:
strategy_registry.register("single_tpu", cls, description="Legacy class. Use `single_xla` instead.")


class TPUAccelerator(XLAAccelerator):
"""Legacy class.

Use :class:`~lightning.fabric.accelerators.xla.XLAAccelerator` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUAccelerator` class is deprecated. Use `lightning.fabric.accelerators.XLAAccelerator` instead."
)
super().__init__(*args, **kwargs)


class TPUPrecision(XLAPrecision):
"""Legacy class.

Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead."
)
super().__init__(*args, **kwargs)


class TPUBf16Precision(XLABf16Precision):
"""Legacy class.

Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
rank_zero_deprecation(
"The `TPUBf16Precision` class is deprecated. Use"
" `lightning.fabric.plugins.precision.XLABf16Precision` instead."
)
super().__init__(*args, **kwargs)


def _patch_classes() -> None:
setattr(fabric.strategies, "SingleTPUStrategy", SingleTPUStrategy)
setattr(fabric.accelerators, "TPUAccelerator", TPUAccelerator)
setattr(fabric.plugins, "TPUPrecision", TPUPrecision)
setattr(fabric.plugins.precision, "TPUPrecision", TPUPrecision)
setattr(fabric.plugins, "TPUBf16Precision", TPUBf16Precision)
setattr(fabric.plugins.precision, "TPUBf16Precision", TPUBf16Precision)


_patch_sys_modules()
_patch_classes()

SingleTPUStrategy.register_strategies(fabric.strategies.STRATEGY_REGISTRY)
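To make the shim behavior concrete, here is a small sketch of what the patching enables (assumes `lightning.fabric` is importable; the deprecation warnings fire on instantiation, not on import):

```python
# The sys.modules patch keeps legacy import paths resolvable, and the setattr
# calls keep the old class names reachable from their previous namespaces.
import lightning.fabric  # importing the package applies the _graveyard patches

from lightning.fabric.accelerators.tpu import TPUAccelerator  # legacy path, still works
from lightning.fabric.accelerators import XLAAccelerator      # new path

assert issubclass(TPUAccelerator, XLAAccelerator)  # the shim subclasses the new class
```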
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/__init__.py
@@ -15,7 +15,7 @@
from lightning.fabric.accelerators.cuda import CUDAAccelerator, find_usable_cuda_devices # noqa: F401
from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
- from lightning.fabric.accelerators.tpu import TPUAccelerator # noqa: F401
+ from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401

_ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
ACCELERATOR_REGISTRY = _AcceleratorRegistry()
src/lightning/fabric/accelerators/xla.py (renamed from src/lightning/fabric/accelerators/tpu.py)
@@ -13,17 +13,18 @@
# limitations under the License.
import functools
from multiprocessing import Process, Queue
- from typing import Any, Callable, Dict, List, Union
+ from typing import Any, Callable, List, Union

import torch
from lightning_utilities.core.imports import RequirementCache

+ from lightning.fabric.accelerators import _AcceleratorRegistry
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.utilities.device_parser import _check_data_type


- class TPUAccelerator(Accelerator):
- """Accelerator for TPU devices.
+ class XLAAccelerator(Accelerator):
+ """Accelerator for XLA devices, normally TPUs.

.. warning:: Use of this accelerator beyond import and instantiation is experimental.
"""
@@ -99,12 +100,8 @@ def is_available() -> bool:
return queue.get_nowait()

@classmethod
- def register_accelerators(cls, accelerator_registry: Dict) -> None:
- accelerator_registry.register(
- "tpu",
- cls,
- description=cls.__class__.__name__,
- )
+ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
+ accelerator_registry.register("tpu", cls, description=cls.__class__.__name__)


# define TPU availability timeout in seconds
@@ -160,7 +157,7 @@ def _parse_tpu_devices(devices: Union[int, str, List[int]]) -> Union[int, List[i


def _check_tpu_devices_valid(devices: object) -> None:
- device_count = TPUAccelerator.auto_device_count()
+ device_count = XLAAccelerator.auto_device_count()
if (
# support number of devices
isinstance(devices, int)
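As a usage note, a sketch based on `_parse_tpu_devices` / `_check_tpu_devices_valid` above (an assumption, not an exhaustive spec; requires an XLA runtime): the accepted device specifications are a single device, all devices of the host, or a one-element list selecting a specific index.

```python
# Sketch of device specs accepted for the XLA accelerator.
from lightning.fabric import Fabric

Fabric(accelerator="tpu", devices=1)    # one XLA device
Fabric(accelerator="tpu", devices=8)    # all devices, when auto_device_count() is 8
Fabric(accelerator="tpu", devices=[3])  # a single, specific device index
```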
34 changes: 17 additions & 17 deletions src/lightning/fabric/connector.py
@@ -22,15 +22,15 @@
from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.cuda import CUDAAccelerator
from lightning.fabric.accelerators.mps import MPSAccelerator
- from lightning.fabric.accelerators.tpu import TPUAccelerator
+ from lightning.fabric.accelerators.xla import XLAAccelerator
from lightning.fabric.plugins import (
CheckpointIO,
DeepSpeedPrecision,
HalfPrecision,
MixedPrecision,
Precision,
- TPUBf16Precision,
- TPUPrecision,
+ XLABf16Precision,
+ XLAPrecision,
)
from lightning.fabric.plugins.environments import (
ClusterEnvironment,
@@ -54,7 +54,7 @@
DeepSpeedStrategy,
ParallelStrategy,
SingleDeviceStrategy,
- SingleTPUStrategy,
+ SingleDeviceXLAStrategy,
Strategy,
STRATEGY_REGISTRY,
XLAStrategy,
@@ -311,7 +311,7 @@ def _check_device_config_and_set_final_flags(self, devices: Union[List[int], str

def _choose_auto_accelerator(self) -> str:
"""Choose the accelerator type (str) based on availability when ``accelerator='auto'``."""
- if TPUAccelerator.is_available():
+ if XLAAccelerator.is_available():
return "tpu"
if MPSAccelerator.is_available():
return "mps"
@@ -373,12 +373,12 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
return LightningEnvironment()

def _choose_strategy(self) -> Union[Strategy, str]:
if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, TPUAccelerator):
if self._accelerator_flag == "tpu" or isinstance(self._accelerator_flag, XLAAccelerator):
if self._parallel_devices and len(self._parallel_devices) > 1:
return "xla"
else:
- # TODO: lazy initialized device, then here could be self._strategy_flag = "single_tpu_device"
- return SingleTPUStrategy(device=self._parallel_devices[0])
+ # TODO: lazy initialized device, then here could be self._strategy_flag = "single_xla"
+ return SingleDeviceXLAStrategy(device=self._parallel_devices[0])
if self._num_nodes_flag > 1:
return "ddp"
if len(self._parallel_devices) <= 1:
@@ -434,16 +434,16 @@ def _check_and_init_precision(self) -> Precision:
if isinstance(self._precision_instance, Precision):
return self._precision_instance

- if isinstance(self.accelerator, TPUAccelerator):
+ if isinstance(self.accelerator, XLAAccelerator):
if self._precision_input == "32-true":
- return TPUPrecision()
+ return XLAPrecision()
elif self._precision_input in ("16-mixed", "bf16-mixed"):
if self._precision_input == "16-mixed":
rank_zero_warn(
"You passed `Fabric(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported with TPUs. Using `precision='bf16-mixed'` instead."
)
- return TPUBf16Precision()
+ return XLABf16Precision()
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecision(self._precision_input) # type: ignore

@@ -477,16 +477,16 @@ def _check_and_init_precision(self) -> Precision:

def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, and accelerator."""
- if isinstance(self.accelerator, TPUAccelerator):
+ if isinstance(self.accelerator, XLAAccelerator):
if self._precision_input == "64-true":
raise NotImplementedError(
"`Fabric(accelerator='tpu', precision='64-true')` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
" requesting this feature."
)
- if self._precision_instance and not isinstance(self._precision_instance, (TPUPrecision, TPUBf16Precision)):
+ if self._precision_instance and not isinstance(self._precision_instance, (XLAPrecision, XLABf16Precision)):
raise ValueError(
f"The `TPUAccelerator` can only be used with a `TPUPrecision` plugin,"
f"The `XLAAccelerator` can only be used with a `XLAPrecision` plugin,"
f" found: {self._precision_instance}."
)

@@ -523,11 +523,11 @@ def _lazy_init_strategy(self) -> None:

# TODO: should be moved to _check_strategy_and_fallback().
# Current test check precision first, so keep this check here to meet error order
- if isinstance(self.accelerator, TPUAccelerator) and not isinstance(
- self.strategy, (SingleTPUStrategy, XLAStrategy)
+ if isinstance(self.accelerator, XLAAccelerator) and not isinstance(
+ self.strategy, (SingleDeviceXLAStrategy, XLAStrategy)
):
raise ValueError(
"The `TPUAccelerator` can only be used with a `SingleTPUStrategy` or `XLAStrategy`,"
"The `XLAAccelerator` can only be used with a `SingleDeviceXLAStrategy` or `XLAStrategy`,"
f" found {self.strategy.__class__.__name__}."
)

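To summarize the connector changes in this file, a hedged sketch of how the flags now resolve (mirroring `_choose_strategy` and `_check_and_init_precision` above; assumes an XLA/TPU runtime):

```python
# Sketch only: multiple XLA devices resolve to the "xla" strategy (XLAStrategy),
# a single device to SingleDeviceXLAStrategy, and the default precisions map to
# XLAPrecision / XLABf16Precision.
from lightning.fabric import Fabric

fabric_multi = Fabric(accelerator="tpu", devices=8)            # -> XLAStrategy
fabric_single = Fabric(accelerator="tpu", devices=1)           # -> SingleDeviceXLAStrategy
fabric_bf16 = Fabric(accelerator="tpu", precision="16-mixed")  # warns, uses XLABf16Precision
```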