diff --git a/src/lightning/fabric/accelerators/xla.py b/src/lightning/fabric/accelerators/xla.py
index bd7014d009325..fbe29eee641ec 100644
--- a/src/lightning/fabric/accelerators/xla.py
+++ b/src/lightning/fabric/accelerators/xla.py
@@ -76,6 +76,8 @@ def auto_device_count() -> int:
                 from torch_xla._internal import tpu

                 return tpu.num_available_devices()
+            from torch_xla.experimental import tpu
+
             device_count_on_version = {2: 8, 3: 8, 4: 4}
             return device_count_on_version.get(tpu.version(), 8)
         return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)
@@ -83,7 +85,12 @@ def auto_device_count() -> int:
     @staticmethod
     @functools.lru_cache(maxsize=1)
     def is_available() -> bool:
-        return XLAAccelerator.auto_device_count() > 0
+        try:
+            return XLAAccelerator.auto_device_count() > 0
+        except (ValueError, AssertionError, OSError):
+            # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
+            # when `torch_xla` is imported but not used
+            return False

     @classmethod
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py
index d65cd018f7ed7..a16f4521c8129 100644
--- a/tests/tests_fabric/plugins/environments/test_xla.py
+++ b/tests/tests_fabric/plugins/environments/test_xla.py
@@ -18,6 +18,7 @@
 import torch

 import lightning.fabric
+from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.plugins.environments import XLAEnvironment
 from tests_fabric.helpers.runif import RunIf

@@ -27,13 +28,15 @@
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_default_attributes(monkeypatch):
     """Test the default attributes when no environment variables are set."""
-    from torch_xla.experimental import pjrt
-
-    if pjrt.using_pjrt():
+    if _using_pjrt():
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as module
+        else:
+            from torch_xla.experimental import pjrt as module
         # calling these creates side effects in other tests
-        monkeypatch.setattr(pjrt, "world_size", lambda: 1)
-        monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
-        monkeypatch.setattr(pjrt, "local_ordinal", lambda: 0)
+        monkeypatch.setattr(module, "world_size", lambda: 1)
+        monkeypatch.setattr(module, "global_ordinal", lambda: 0)
+        monkeypatch.setattr(module, "local_ordinal", lambda: 0)
     else:
         from torch_xla import _XLAC

@@ -57,10 +60,9 @@ def test_default_attributes(monkeypatch):
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_attributes_from_environment_variables(monkeypatch):
     """Test that the default cluster environment takes the attributes from the environment variables."""
-    from torch_xla.experimental import pjrt
-
     os.environ["XRT_HOST_ORDINAL"] = "3"
-    if not pjrt.using_pjrt():
+
+    if not _using_pjrt():
         os.environ.update(
             {
                 "XRT_SHARD_WORLD_SIZE": "1",
@@ -69,10 +71,15 @@ def test_attributes_from_environment_variables(monkeypatch):
             }
         )
     else:
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla import runtime as module
+        else:
+            from torch_xla.experimental import pjrt as module
+
         # PJRT doesn't pull these from envvars
-        monkeypatch.setattr(pjrt, "world_size", lambda: 1)
-        monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
-        monkeypatch.setattr(pjrt, "local_ordinal", lambda: 2)
+        monkeypatch.setattr(module, "world_size", lambda: 1)
+        monkeypatch.setattr(module, "global_ordinal", lambda: 0)
+        monkeypatch.setattr(module, "local_ordinal", lambda: 2)

     env = XLAEnvironment()
     with pytest.raises(NotImplementedError):
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index befe22620a775..06069e5db4624 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -24,6 +24,7 @@
 from torch import nn
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset

+from lightning.fabric.accelerators.xla import _using_pjrt
 from lightning.fabric.fabric import Fabric
 from lightning.fabric.plugins import Precision
 from lightning.fabric.strategies import (
@@ -533,11 +534,8 @@ def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_to_device(accelerator, expected):
     """Test that the to_device method can move various objects to the device determined by the accelerator."""
-    if accelerator == "tpu":
-        from torch_xla.experimental import pjrt
-
-        if not pjrt.using_pjrt():
-            expected = "xla:1"
+    if accelerator == "tpu" and not _using_pjrt():
+        expected = "xla:1"

     fabric = Fabric(accelerator=accelerator, devices=1)
     fabric.launch()
diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py
index 659854db81644..d8961abec4537 100644
--- a/tests/tests_pytorch/accelerators/test_xla.py
+++ b/tests/tests_pytorch/accelerators/test_xla.py
@@ -22,6 +22,7 @@
 from torch import nn
 from torch.utils.data import DataLoader

+import lightning.fabric
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator
@@ -314,9 +315,7 @@ def test_warning_if_tpus_not_used(tpu_available):
 @pytest.mark.parametrize("runtime", ["xrt", "pjrt"])
 @RunIf(min_python="3.9")  # mocking issue
 def test_trainer_config_device_ids(devices, expected_device_ids, runtime, tpu_available, monkeypatch):
-    from torch_xla.experimental import pjrt
-
-    monkeypatch.setattr(pjrt, "using_pjrt", lambda: runtime == "pjrt")
+    monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", lambda: runtime == "pjrt")

     mock = DeviceMock()
     monkeypatch.setattr(torch, "device", mock)
diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
index 9f6c6633d1707..d0b8e99e89525 100644
--- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
+++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py
@@ -21,6 +21,7 @@
 import pytest
 import torch

+from lightning.fabric.accelerators.xla import _using_pjrt
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats
 from lightning.pytorch.callbacks import DeviceStatsMonitor
@@ -129,9 +130,7 @@ def test_device_stats_monitor_tpu(tmpdir):
     try:
         trainer.fit(model)
     except RuntimeError as e:
-        from torch_xla.experimental import pjrt
-
-        if pjrt.using_pjrt() and "GetMemoryInfo not implemented" in str(e):
+        if _using_pjrt() and "GetMemoryInfo not implemented" in str(e):
             pytest.xfail("`xm.get_memory_info` is not implemented with PJRT")
         raise e

diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py
index 4842f853e8eee..107c0947f1c13 100644
--- a/tests/tests_pytorch/models/test_tpu.py
+++ b/tests/tests_pytorch/models/test_tpu.py
@@ -20,6 +20,7 @@
 from torch.utils.data import DataLoader

 import tests_pytorch.helpers.pipelines as tpipes
+from lightning.fabric.accelerators.xla import _using_pjrt
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators import XLAAccelerator
 from lightning.pytorch.callbacks import EarlyStopping
@@ -75,9 +76,8 @@ def test_model_tpu_index(tmpdir, tpu_core):
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model, with_hpc=False)
     import torch_xla
-    from torch_xla.experimental import pjrt

-    expected = tpu_core if pjrt.using_pjrt() else tpu_core + 1
+    expected = tpu_core if _using_pjrt() else tpu_core + 1
     assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"


@@ -138,9 +138,8 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     model = BoringModel()
     tpipes.run_model_test(trainer_options, model)
     import torch_xla
-    from torch_xla.experimental import pjrt

-    expected = tpu_core if pjrt.using_pjrt() else tpu_core + 1
+    expected = tpu_core if _using_pjrt() else tpu_core + 1
     assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"


@@ -348,9 +347,7 @@ def on_train_start(self):
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def test_tpu_host_world_size(tmpdir):
     """Test Host World size env setup on TPU."""
-    from torch_xla.experimental import pjrt
-
-    if pjrt.using_pjrt():
+    if _using_pjrt():
         pytest.skip("PJRT doesn't set 'XRT_HOST_WORLD_SIZE'")

     trainer_options = {
diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py
index 7857a393b25f4..87d617ccb68c6 100644
--- a/tests/tests_pytorch/strategies/test_xla.py
+++ b/tests/tests_pytorch/strategies/test_xla.py
@@ -16,6 +16,7 @@

 import torch

+from lightning.fabric.accelerators.xla import _using_pjrt
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators import XLAAccelerator
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -25,9 +26,7 @@

 class BoringModelTPU(BoringModel):
     def on_train_start(self) -> None:
-        from torch_xla.experimental import pjrt
-
-        index = 0 if pjrt.using_pjrt() else 1
+        index = 0 if _using_pjrt() else 1
         # assert strategy attributes for device setting
         assert self.device == torch.device("xla", index=index)
         assert os.environ.get("PT_XLA_DEBUG") == "1"
@@ -38,10 +37,8 @@ def on_train_start(self) -> None:
 def test_xla_strategy_debug_state():
     """Tests if device/debug flag is set correctly when training and after teardown for XLAStrategy."""
     model = BoringModelTPU()
-    from torch_xla.experimental import pjrt
-
     trainer_kwargs = {}
-    if not pjrt.using_pjrt():
+    if not _using_pjrt():
         # only XRT supports XLA with a single process
         trainer_kwargs["devices"] = 1
     trainer = Trainer(fast_dev_run=True, strategy=XLAStrategy(debug=True), **trainer_kwargs)
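
Note (not part of the patch above): the tests now import `_using_pjrt` and
`_XLA_GREATER_EQUAL_2_1` from `lightning.fabric.accelerators.xla`, but the
helper's definition is not shown in this diff. A minimal sketch of what it
presumably does, inferred from how the tests pick `torch_xla.runtime` on newer
torch_xla versions and fall back to `torch_xla.experimental.pjrt` otherwise
(the real implementation in `xla.py` may differ):

    # Sketch only: the version flag is assumed to be a module-level check,
    # e.g. via lightning_utilities' RequirementCache.
    from lightning_utilities.core.imports import RequirementCache

    _XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")

    def _using_pjrt() -> bool:
        # PJRT detection moved from `torch_xla.experimental.pjrt` to
        # `torch_xla.runtime` in torch_xla 2.1, hence the version switch.
        if _XLA_GREATER_EQUAL_2_1:
            from torch_xla import runtime as xr

            return xr.using_pjrt()
        from torch_xla.experimental import pjrt

        return pjrt.using_pjrt()

Separately, `is_available()` now wraps `auto_device_count()` in a try/except so
that an environment where `torch_xla` is importable but not properly configured
reports the accelerator as unavailable instead of raising ValueError,
AssertionError, or OSError.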