[TPU] Rename classes to use XLA instead of TPU #17383

Merged
merged 23 commits into from
Apr 28, 2023
Changes from 16 commits
2 changes: 1 addition & 1 deletion docs/source-fabric/api/accelerators.rst
@@ -19,4 +19,4 @@ Accelerators
CPUAccelerator
CUDAAccelerator
MPSAccelerator
TPUAccelerator
XLAAccelerator
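Not part of the diff, but for orientation: a minimal sketch of how the renamed accelerator is reached from user code, assuming the existing `Fabric(accelerator=...)` shorthand and the `is_available()` check from the `Accelerator` interface shown further down in this PR.

```python
from lightning.fabric import Fabric
from lightning.fabric.accelerators import XLAAccelerator

# Fall back to CPU when no XLA device (e.g. a TPU) is detected.
if XLAAccelerator.is_available():
    fabric = Fabric(accelerator="tpu", devices=8)
else:
    fabric = Fabric(accelerator="cpu")
```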
4 changes: 2 additions & 2 deletions docs/source-fabric/api/precision.rst
@@ -20,6 +20,6 @@ Precision
Precision
DoublePrecision
MixedPrecision
TPUPrecision
TPUBf16Precision
XLAPrecision
XLABf16Precision
FSDPPrecision
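A hedged usage sketch for the renamed Fabric precision plugins: they are normally selected through the `precision` flag rather than instantiated directly (assumed behaviour, not shown in this diff).

```python
from lightning.fabric import Fabric

# bf16 on a TPU; this should resolve to the XLA precision plugins listed above.
fabric = Fabric(accelerator="tpu", precision="bf16-mixed")
```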
2 changes: 1 addition & 1 deletion docs/source-fabric/api/strategies.rst
@@ -23,4 +23,4 @@ Strategies
FSDPStrategy
ParallelStrategy
SingleDeviceStrategy
SingleTPUStrategy
SingleDeviceXLAStrategy
8 changes: 4 additions & 4 deletions docs/source-pytorch/api_references.rst
@@ -14,7 +14,7 @@ accelerators
CPUAccelerator
CUDAAccelerator
IPUAccelerator
TPUAccelerator
XLAAccelerator

callbacks
---------
@@ -119,8 +119,8 @@ precision
IPUPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
TPUBf16PrecisionPlugin
TPUPrecisionPlugin
XLABf16PrecisionPlugin
XLAPrecisionPlugin

environments
""""""""""""
@@ -215,7 +215,7 @@ strategies
IPUStrategy
ParallelStrategy
SingleDeviceStrategy
SingleTPUStrategy
SingleDeviceXLAStrategy
Strategy
XLAStrategy

2 changes: 1 addition & 1 deletion docs/source-pytorch/extensions/accelerator.rst
@@ -130,4 +130,4 @@ Accelerator API
CUDAAccelerator
IPUAccelerator
MPSAccelerator
TPUAccelerator
XLAAccelerator
4 changes: 2 additions & 2 deletions docs/source-pytorch/extensions/plugins.rst
@@ -58,8 +58,8 @@ The full list of built-in precision plugins is listed below.
IPUPrecisionPlugin
MixedPrecisionPlugin
PrecisionPlugin
TPUBf16PrecisionPlugin
TPUPrecisionPlugin
XLABf16PrecisionPlugin
XLAPrecisionPlugin

More information regarding precision with Lightning can be found :ref:`here <precision>`
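As a quick, assumed illustration of standard Trainer usage (nothing introduced by this PR): the renamed plugins are picked up through the usual `precision` argument.

```python
from lightning.pytorch import Trainer

# On a TPU/XLA accelerator this is expected to select the XLA bf16 plugin listed above.
trainer = Trainer(accelerator="tpu", precision="bf16-mixed")
```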

6 changes: 3 additions & 3 deletions docs/source-pytorch/extensions/strategy.rst
@@ -93,9 +93,9 @@ The below table lists all relevant strategies available in Lightning with their
* - xla
- :class:`~lightning.pytorch.strategies.XLAStrategy`
- Strategy for training on multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn` method. :doc:`Learn more. <../accelerators/tpu>`
* - single_tpu
- :class:`~lightning.pytorch.strategies.SingleTPUStrategy`
- Strategy for training on a single TPU device. :doc:`Learn more. <../accelerators/tpu>`
* - single_xla
- :class:`~lightning.pytorch.strategies.SingleDeviceXLAStrategy`
- Strategy for training on a single XLA device, like TPUs. :doc:`Learn more. <../accelerators/tpu>`

----
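A short sketch of how the registered names from the table above are passed to the Trainer; the `accelerator` and `devices` values are illustrative assumptions.

```python
from lightning.pytorch import Trainer

# Multi-device XLA training, e.g. a TPU v3-8 slice:
trainer = Trainer(accelerator="tpu", devices=8, strategy="xla")

# Single XLA device; "single_xla" replaces the deprecated "single_tpu" name.
trainer = Trainer(accelerator="tpu", devices=1, strategy="single_xla")
```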

2 changes: 1 addition & 1 deletion docs/source-pytorch/upgrade/sections/1_9_devel.rst
@@ -47,7 +47,7 @@
- `PR16703`_

* - used any function from ``pl.utilities.xla_device``
- switch to ``pl.accelerators.TPUAccelerator.is_available()``
- switch to ``pl.accelerators.XLAAccelerator.is_available()``
- `PR14514`_ `PR14550`_

* - imported functions from ``pl.utilities.device_parser.*``
15 changes: 14 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -55,7 +55,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

-
- Deprecated the `DDPStrategy.is_distributed` property. This strategy is distributed by definition ([#17381](https://github.com/Lightning-AI/lightning/pull/17381))


- Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `TPUAccelerator` in favor of `XLAAccelerator` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `TPUPrecision` in favor of `XLAPrecision` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


- Deprecated the `TPUBf16Precision` in favor of `XLABf16Precision` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))
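Taken together, the deprecations above boil down to import renames on the Fabric side; a minimal migration sketch (the old names keep resolving through the `_graveyard` shims added below, but warn when instantiated):

```python
# Before this PR (still importable, warns on instantiation):
from lightning.fabric.accelerators import TPUAccelerator
from lightning.fabric.plugins.precision import TPUBf16Precision, TPUPrecision

# After this PR:
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.plugins.precision import XLABf16Precision, XLAPrecision
```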



### Removed
3 changes: 3 additions & 0 deletions src/lightning/fabric/__init__.py
@@ -30,6 +30,9 @@
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.fabric.wrappers import is_wrapped # noqa: E402

# this import needs to go last as it will patch other modules
import lightning.fabric._graveyard # noqa: E402, F401 # isort: skip

__all__ = ["Fabric", "seed_everything", "is_wrapped"]

# for compatibility with namespace packages
14 changes: 14 additions & 0 deletions src/lightning/fabric/_graveyard/__init__.py
@@ -0,0 +1,14 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import lightning.fabric._graveyard.tpu # noqa: F401
101 changes: 101 additions & 0 deletions src/lightning/fabric/_graveyard/tpu.py
@@ -0,0 +1,101 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Any

import lightning.fabric as fabric
from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.plugins.precision import XLABf16Precision, XLAPrecision
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.single_xla import SingleDeviceXLAStrategy
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation


def _patch_sys_modules() -> None:
    self = sys.modules[__name__]
    sys.modules["lightning.fabric.strategies.single_tpu"] = self
    sys.modules["lightning.fabric.accelerators.tpu"] = self
    sys.modules["lightning.fabric.plugins.precision.tpu"] = self
    sys.modules["lightning.fabric.plugins.precision.tpu_bf16"] = self


class SingleTPUStrategy(SingleDeviceXLAStrategy):
    """Legacy class.

    Use :class:`~lightning.fabric.strategies.single_xla.SingleDeviceXLAStrategy` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation("The 'single_tpu' strategy is deprecated. Use 'single_xla' instead.")
        super().__init__(*args, **kwargs)

    @classmethod
    def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
        if "single_tpu" not in strategy_registry:
            strategy_registry.register("single_tpu", cls, description="Legacy class. Use `single_xla` instead.")


class TPUAccelerator(XLAAccelerator):
    """Legacy class.

    Use :class:`~lightning.fabric.accelerators.xla.XLAAccelerator` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUAccelerator` class is deprecated. Use `lightning.fabric.accelerators.XLAAccelerator` instead."
        )
        super().__init__(*args, **kwargs)


class TPUPrecision(XLAPrecision):
    """Legacy class.

    Use :class:`~lightning.fabric.plugins.precision.xla.XLAPrecision` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead."
        )
        super().__init__(*args, **kwargs)


class TPUBf16Precision(XLABf16Precision):
    """Legacy class.

    Use :class:`~lightning.fabric.plugins.precision.xlabf16.XLABf16Precision` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `TPUBf16Precision` class is deprecated. Use"
            " `lightning.fabric.plugins.precision.XLABf16Precision` instead."
        )
        super().__init__(*args, **kwargs)


def _patch_classes() -> None:
    setattr(fabric.strategies, "SingleTPUStrategy", SingleTPUStrategy)
    setattr(fabric.accelerators, "TPUAccelerator", TPUAccelerator)
    setattr(fabric.plugins, "TPUPrecision", TPUPrecision)
    setattr(fabric.plugins.precision, "TPUPrecision", TPUPrecision)
    setattr(fabric.plugins, "TPUBf16Precision", TPUBf16Precision)
    setattr(fabric.plugins.precision, "TPUBf16Precision", TPUBf16Precision)


_patch_sys_modules()
_patch_classes()

SingleTPUStrategy.register_strategies(fabric.strategies.STRATEGY_REGISTRY)
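A small check of what the shim module above provides, assuming `SingleDeviceXLAStrategy` is re-exported from `lightning.fabric.strategies` as the docs changes in this PR suggest:

```python
import lightning.fabric as fabric

# The sys.modules aliasing keeps the old module path importable.
from lightning.fabric.accelerators.tpu import TPUAccelerator as LegacyTPUAccelerator

# The legacy names are thin subclasses of the new XLA classes and warn when instantiated.
assert issubclass(LegacyTPUAccelerator, fabric.accelerators.XLAAccelerator)
assert issubclass(fabric.strategies.SingleTPUStrategy, fabric.strategies.SingleDeviceXLAStrategy)

# "single_tpu" stays registered as a legacy alias next to the new "single_xla" name.
assert "single_tpu" in fabric.strategies.STRATEGY_REGISTRY
```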
2 changes: 1 addition & 1 deletion src/lightning/fabric/accelerators/__init__.py
@@ -15,7 +15,7 @@
from lightning.fabric.accelerators.cuda import CUDAAccelerator, find_usable_cuda_devices # noqa: F401
from lightning.fabric.accelerators.mps import MPSAccelerator # noqa: F401
from lightning.fabric.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning.fabric.accelerators.tpu import TPUAccelerator # noqa: F401
from lightning.fabric.accelerators.xla import XLAAccelerator # noqa: F401

_ACCELERATORS_BASE_MODULE = "lightning.fabric.accelerators"
ACCELERATOR_REGISTRY = _AcceleratorRegistry()
6 changes: 4 additions & 2 deletions src/lightning/fabric/accelerators/accelerator.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict
from typing import Any

import torch

from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class Accelerator(ABC):
"""The Accelerator base class.
@@ -54,5 +56,5 @@ def is_available() -> bool:
"""Detect if the hardware is available."""

@classmethod
def register_accelerators(cls, accelerator_registry: Dict) -> None:
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
pass
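The change above only tightens the type hint from `Dict` to `_AcceleratorRegistry`. For context, a hedged sketch of a third-party accelerator implementing the hook against that interface; the class and its abstract-method set are illustrative assumptions mirroring the built-in accelerators further down.

```python
from typing import Any, List, Union

import torch

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class MyAccelerator(Accelerator):
    """Hypothetical accelerator, shown only to illustrate the registry hook."""

    def setup_device(self, device: torch.device) -> None:
        pass

    def teardown(self) -> None:
        pass

    @staticmethod
    def parse_devices(devices: Union[int, str, List[int]]) -> Any:
        return devices

    @staticmethod
    def get_parallel_devices(devices: Any) -> List[torch.device]:
        return [torch.device("cpu")]

    @staticmethod
    def auto_device_count() -> int:
        return 1

    @staticmethod
    def is_available() -> bool:
        return True

    @classmethod
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        # Same pattern as the cpu/cuda/mps accelerators in this diff.
        accelerator_registry.register("my_accel", cls, description="Hypothetical example accelerator")
```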
5 changes: 3 additions & 2 deletions src/lightning/fabric/accelerators/cpu.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Union
from typing import List, Union

import torch

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class CPUAccelerator(Accelerator):
@@ -56,7 +57,7 @@ def is_available() -> bool:
        return True

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        accelerator_registry.register(
            "cpu",
            cls,
5 changes: 3 additions & 2 deletions src/lightning/fabric/accelerators/cuda.py
@@ -15,12 +15,13 @@
import warnings
from contextlib import contextmanager
from functools import lru_cache
from typing import cast, Dict, Generator, List, Optional, Union
from typing import cast, Generator, List, Optional, Union

import torch
from lightning_utilities.core.rank_zero import rank_zero_info

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0


@@ -63,7 +64,7 @@ def is_available() -> bool:
        return num_cuda_devices() > 0

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        accelerator_registry.register(
            "cuda",
            cls,
5 changes: 3 additions & 2 deletions src/lightning/fabric/accelerators/mps.py
@@ -13,11 +13,12 @@
# limitations under the License.
import platform
from functools import lru_cache
from typing import Dict, List, Optional, Union
from typing import List, Optional, Union

import torch

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12


@@ -69,7 +70,7 @@ def is_available() -> bool:
        )

    @classmethod
    def register_accelerators(cls, accelerator_registry: Dict) -> None:
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        accelerator_registry.register(
            "mps",
            cls,
3 changes: 2 additions & 1 deletion src/lightning/fabric/accelerators/registry.py
@@ -15,7 +15,6 @@
from inspect import getmembers, isclass
from typing import Any, Callable, Dict, List, Optional

from lightning.fabric.accelerators.accelerator import Accelerator
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.registry import _is_register_method_overridden

@@ -114,6 +113,8 @@ def __str__(self) -> str:

def call_register_accelerators(registry: _AcceleratorRegistry, base_module: str) -> None:
    module = importlib.import_module(base_module)
    from lightning.fabric.accelerators.accelerator import Accelerator

    for _, mod in getmembers(module, isclass):
        if issubclass(mod, Accelerator) and _is_register_method_overridden(mod, Accelerator, "register_accelerators"):
            mod.register_accelerators(registry)
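Moving the `Accelerator` import inside `call_register_accelerators` avoids a circular import, since `accelerator.py` now imports `_AcceleratorRegistry` from this module. A tiny sketch of the result, assuming `_AcceleratorRegistry` is a `dict` subclass whose keys are the registered names:

```python
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY

# Populated at import time by call_register_accelerators() scanning the
# lightning.fabric.accelerators package for Accelerator subclasses.
print(sorted(ACCELERATOR_REGISTRY))  # roughly: ['cpu', 'cuda', 'mps', 'tpu', ...]
```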