More XLA fixes for nightly support (#18085)
carmocca authored Jul 14, 2023
1 parent 356f5d0 commit e9c42ed
Showing 7 changed files with 41 additions and 37 deletions.
src/lightning/fabric/accelerators/xla.py (8 additions, 1 deletion)
@@ -76,14 +76,21 @@ def auto_device_count() -> int:
from torch_xla._internal import tpu

return tpu.num_available_devices()
from torch_xla.experimental import tpu

device_count_on_version = {2: 8, 3: 8, 4: 4}
return device_count_on_version.get(tpu.version(), 8)
return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)

@staticmethod
@functools.lru_cache(maxsize=1)
def is_available() -> bool:
return XLAAccelerator.auto_device_count() > 0
try:
return XLAAccelerator.auto_device_count() > 0
except (ValueError, AssertionError, OSError):
# XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
# when `torch_xla` is imported but not used
return False

@classmethod
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
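The test files below stop importing `torch_xla.experimental.pjrt` directly and instead use a `_using_pjrt` helper imported from this module, together with the `_XLA_GREATER_EQUAL_2_1` flag. The helper's body is outside the hunks shown here; a rough sketch of what it presumably looks like, offered only as an assumption to make the later hunks easier to follow:

def _using_pjrt() -> bool:
    # torch_xla >= 2.1 exposes the runtime query under `torch_xla.runtime`;
    # older releases keep it in `torch_xla.experimental.pjrt`.
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as xr
        return xr.using_pjrt()
    from torch_xla.experimental import pjrt
    return pjrt.using_pjrt()

Centralizing the check in one module also lets tests monkeypatch a single name (see tests/tests_pytorch/accelerators/test_xla.py below) instead of patching torch_xla itself.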
tests/tests_fabric/plugins/environments/test_xla.py (19 additions, 12 deletions)
@@ -18,6 +18,7 @@
import torch

import lightning.fabric
from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins.environments import XLAEnvironment
from tests_fabric.helpers.runif import RunIf

@@ -27,13 +28,15 @@
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_default_attributes(monkeypatch):
"""Test the default attributes when no environment variables are set."""
from torch_xla.experimental import pjrt

if pjrt.using_pjrt():
if _using_pjrt():
if _XLA_GREATER_EQUAL_2_1:
from torch_xla import runtime as module
else:
from torch_xla.experimental import pjrt as module
# calling these creates side effects in other tests
monkeypatch.setattr(pjrt, "world_size", lambda: 1)
monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
monkeypatch.setattr(pjrt, "local_ordinal", lambda: 0)
monkeypatch.setattr(module, "world_size", lambda: 1)
monkeypatch.setattr(module, "global_ordinal", lambda: 0)
monkeypatch.setattr(module, "local_ordinal", lambda: 0)
else:
from torch_xla import _XLAC

@@ -57,10 +60,9 @@ def test_default_attributes(monkeypatch):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_attributes_from_environment_variables(monkeypatch):
"""Test that the default cluster environment takes the attributes from the environment variables."""
from torch_xla.experimental import pjrt

os.environ["XRT_HOST_ORDINAL"] = "3"
if not pjrt.using_pjrt():

if not _using_pjrt():
os.environ.update(
{
"XRT_SHARD_WORLD_SIZE": "1",
@@ -69,10 +71,15 @@ def test_attributes_from_environment_variables(monkeypatch):
}
)
else:
if _XLA_GREATER_EQUAL_2_1:
from torch_xla import runtime as module
else:
from torch_xla.experimental import pjrt as module

# PJRT doesn't pull these from envvars
monkeypatch.setattr(pjrt, "world_size", lambda: 1)
monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
monkeypatch.setattr(pjrt, "local_ordinal", lambda: 2)
monkeypatch.setattr(module, "world_size", lambda: 1)
monkeypatch.setattr(module, "global_ordinal", lambda: 0)
monkeypatch.setattr(module, "local_ordinal", lambda: 2)

env = XLAEnvironment()
with pytest.raises(NotImplementedError):
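The runtime-versus-pjrt module selection is now repeated in both tests above. A hedged sketch of how the same selection could be factored into a tiny test helper (the helper name is hypothetical and not part of this commit):

def _pjrt_patch_target():
    # Pick the torch_xla module that exposes world_size/global_ordinal/local_ordinal.
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime as module
    else:
        from torch_xla.experimental import pjrt as module
    return module

With such a helper, each test would reduce its setup to calls like `monkeypatch.setattr(_pjrt_patch_target(), "world_size", lambda: 1)`.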
tests/tests_fabric/test_fabric.py (3 additions, 5 deletions)
@@ -24,6 +24,7 @@
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.strategies import (
@@ -533,11 +534,8 @@ def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_to_device(accelerator, expected):
"""Test that the to_device method can move various objects to the device determined by the accelerator."""
if accelerator == "tpu":
from torch_xla.experimental import pjrt

if not pjrt.using_pjrt():
expected = "xla:1"
if accelerator == "tpu" and not _using_pjrt():
expected = "xla:1"

fabric = Fabric(accelerator=accelerator, devices=1)
fabric.launch()
tests/tests_pytorch/accelerators/test_xla.py (2 additions, 3 deletions)
@@ -22,6 +22,7 @@
from torch import nn
from torch.utils.data import DataLoader

import lightning.fabric
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator
@@ -314,9 +315,7 @@ def test_warning_if_tpus_not_used(tpu_available):
@pytest.mark.parametrize("runtime", ["xrt", "pjrt"])
@RunIf(min_python="3.9") # mocking issue
def test_trainer_config_device_ids(devices, expected_device_ids, runtime, tpu_available, monkeypatch):
from torch_xla.experimental import pjrt

monkeypatch.setattr(pjrt, "using_pjrt", lambda: runtime == "pjrt")
monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", lambda: runtime == "pjrt")

mock = DeviceMock()
monkeypatch.setattr(torch, "device", mock)
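The patch target changes from `pjrt.using_pjrt` on torch_xla to the `_using_pjrt` helper on the lightning module that provides it, following the usual monkeypatching rule of patching a name where the code under test resolves it. A minimal sketch of the pattern in isolation (the test name and assertion are hypothetical):

from lightning.fabric.accelerators import xla as xla_accelerator

def test_forced_xrt_path(monkeypatch):
    # Force the XRT code path regardless of the installed torch_xla runtime.
    monkeypatch.setattr(xla_accelerator, "_using_pjrt", lambda: False)
    assert not xla_accelerator._using_pjrt()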
tests/tests_pytorch/callbacks/test_device_stats_monitor.py (2 additions, 3 deletions)
@@ -21,6 +21,7 @@
import pytest
import torch

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats
from lightning.pytorch.callbacks import DeviceStatsMonitor
@@ -129,9 +130,7 @@ def test_device_stats_monitor_tpu(tmpdir):
try:
trainer.fit(model)
except RuntimeError as e:
from torch_xla.experimental import pjrt

if pjrt.using_pjrt() and "GetMemoryInfo not implemented" in str(e):
if _using_pjrt() and "GetMemoryInfo not implemented" in str(e):
pytest.xfail("`xm.get_memory_info` is not implemented with PJRT")
raise e

tests/tests_pytorch/models/test_tpu.py (4 additions, 7 deletions)
@@ -20,6 +20,7 @@
from torch.utils.data import DataLoader

import tests_pytorch.helpers.pipelines as tpipes
from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.callbacks import EarlyStopping
@@ -75,9 +76,8 @@ def test_model_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model, with_hpc=False)
import torch_xla
from torch_xla.experimental import pjrt

expected = tpu_core if pjrt.using_pjrt() else tpu_core + 1
expected = tpu_core if _using_pjrt() else tpu_core + 1
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"


@@ -138,9 +138,8 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model)
import torch_xla
from torch_xla.experimental import pjrt

expected = tpu_core if pjrt.using_pjrt() else tpu_core + 1
expected = tpu_core if _using_pjrt() else tpu_core + 1
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"


@@ -348,9 +347,7 @@ def on_train_start(self):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_tpu_host_world_size(tmpdir):
"""Test Host World size env setup on TPU."""
from torch_xla.experimental import pjrt

if pjrt.using_pjrt():
if _using_pjrt():
pytest.skip("PJRT doesn't set 'XRT_HOST_WORLD_SIZE'")

trainer_options = {
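Both index assertions above encode the same runtime difference: for the same requested core, the XRT default device string is shifted by one relative to PJRT, hence the `tpu_core + 1` offset on the XRT path. A small illustration with an assumed value:

tpu_core = 0  # hypothetical: test pinned to the first core
expected = tpu_core if _using_pjrt() else tpu_core + 1  # "xla:0" under PJRT, "xla:1" under XRT
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"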
tests/tests_pytorch/strategies/test_xla.py (3 additions, 6 deletions)
@@ -16,6 +16,7 @@

import torch

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -25,9 +26,7 @@

class BoringModelTPU(BoringModel):
def on_train_start(self) -> None:
from torch_xla.experimental import pjrt

index = 0 if pjrt.using_pjrt() else 1
index = 0 if _using_pjrt() else 1
# assert strategy attributes for device setting
assert self.device == torch.device("xla", index=index)
assert os.environ.get("PT_XLA_DEBUG") == "1"
@@ -38,10 +37,8 @@ def on_train_start(self) -> None:
def test_xla_strategy_debug_state():
"""Tests if device/debug flag is set correctly when training and after teardown for XLAStrategy."""
model = BoringModelTPU()
from torch_xla.experimental import pjrt

trainer_kwargs = {}
if not pjrt.using_pjrt():
if not _using_pjrt():
# only XRT supports XLA with a single process
trainer_kwargs["devices"] = 1
trainer = Trainer(fast_dev_run=True, strategy=XLAStrategy(debug=True), **trainer_kwargs)
