More XLA fixes for nightly support #18085

Merged · 2 commits · Jul 14, 2023
9 changes: 8 additions & 1 deletion src/lightning/fabric/accelerators/xla.py
@@ -76,14 +76,21 @@ def auto_device_count() -> int:
from torch_xla._internal import tpu

return tpu.num_available_devices()
from torch_xla.experimental import tpu

device_count_on_version = {2: 8, 3: 8, 4: 4}
return device_count_on_version.get(tpu.version(), 8)
return getenv_as(xenv.TPU_NUM_DEVICES, int, 8)

@staticmethod
@functools.lru_cache(maxsize=1)
def is_available() -> bool:
return XLAAccelerator.auto_device_count() > 0
try:
return XLAAccelerator.auto_device_count() > 0
except (ValueError, AssertionError, OSError):
    # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
    # when `torch_xla` is imported but not used
    return False

@classmethod
def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
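A note on the helper used throughout the test changes below: the tests now import `_using_pjrt` and `_XLA_GREATER_EQUAL_2_1` from lightning.fabric.accelerators.xla instead of reaching into torch_xla.experimental.pjrt directly. The helper's own definition is not shown in this diff, so the following is only a minimal sketch of the shape such a check could take, assuming PJRT is the default runtime from torch_xla 2.1 onwards.

from lightning_utilities.core.imports import RequirementCache

# Same meaning as the _XLA_GREATER_EQUAL_2_1 flag the tests import; the exact
# definition inside lightning.fabric.accelerators.xla may differ.
_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")


def _using_pjrt() -> bool:
    # Sketch only. Assumption: with torch_xla >= 2.1 the PJRT runtime is the default,
    # so no explicit query is made there; older releases still expose the experimental check.
    if _XLA_GREATER_EQUAL_2_1:
        return True
    from torch_xla.experimental import pjrt

    return pjrt.using_pjrt()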
31 changes: 19 additions & 12 deletions tests/tests_fabric/plugins/environments/test_xla.py
@@ -18,6 +18,7 @@
import torch

import lightning.fabric
from lightning.fabric.accelerators.xla import _using_pjrt, _XLA_GREATER_EQUAL_2_1
from lightning.fabric.plugins.environments import XLAEnvironment
from tests_fabric.helpers.runif import RunIf

@@ -27,13 +28,15 @@
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_default_attributes(monkeypatch):
"""Test the default attributes when no environment variables are set."""
from torch_xla.experimental import pjrt

if pjrt.using_pjrt():
if _using_pjrt():
if _XLA_GREATER_EQUAL_2_1:
from torch_xla import runtime as module
else:
from torch_xla.experimental import pjrt as module
# calling these creates side effects in other tests
monkeypatch.setattr(pjrt, "world_size", lambda: 1)
monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
monkeypatch.setattr(pjrt, "local_ordinal", lambda: 0)
monkeypatch.setattr(module, "world_size", lambda: 1)
monkeypatch.setattr(module, "global_ordinal", lambda: 0)
monkeypatch.setattr(module, "local_ordinal", lambda: 0)
else:
from torch_xla import _XLAC

@@ -57,10 +60,9 @@ def test_default_attributes(monkeypatch):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_attributes_from_environment_variables(monkeypatch):
"""Test that the default cluster environment takes the attributes from the environment variables."""
from torch_xla.experimental import pjrt

os.environ["XRT_HOST_ORDINAL"] = "3"
if not pjrt.using_pjrt():

if not _using_pjrt():
os.environ.update(
{
"XRT_SHARD_WORLD_SIZE": "1",
@@ -69,10 +71,15 @@ def test_attributes_from_environment_variables(monkeypatch):
}
)
else:
if _XLA_GREATER_EQUAL_2_1:
from torch_xla import runtime as module
else:
from torch_xla.experimental import pjrt as module

# PJRT doesn't pull these from envvars
monkeypatch.setattr(pjrt, "world_size", lambda: 1)
monkeypatch.setattr(pjrt, "global_ordinal", lambda: 0)
monkeypatch.setattr(pjrt, "local_ordinal", lambda: 2)
monkeypatch.setattr(module, "world_size", lambda: 1)
monkeypatch.setattr(module, "global_ordinal", lambda: 0)
monkeypatch.setattr(module, "local_ordinal", lambda: 2)

env = XLAEnvironment()
with pytest.raises(NotImplementedError):
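Both hunks above repeat the same selection step: patch whichever torch_xla module actually provides world_size/global_ordinal/local_ordinal, because the version check routes torch_xla >= 2.1 to torch_xla.runtime and older releases to torch_xla.experimental.pjrt. A small helper along these lines (hypothetical, not part of the PR) could factor that choice out of the tests:

from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1


def _ordinal_module():
    # Hypothetical test helper: return the torch_xla module whose world_size /
    # global_ordinal / local_ordinal functions the environment plugin ends up
    # calling, so monkeypatch targets the right object on every version.
    if _XLA_GREATER_EQUAL_2_1:
        from torch_xla import runtime

        return runtime
    from torch_xla.experimental import pjrt

    return pjrt

The setattr calls would then read monkeypatch.setattr(_ordinal_module(), "world_size", lambda: 1) and so on.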
8 changes: 3 additions & 5 deletions tests/tests_fabric/test_fabric.py
@@ -24,6 +24,7 @@
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.strategies import (
@@ -533,11 +534,8 @@ def test_setup_dataloaders_replace_standard_sampler(shuffle, strategy):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_to_device(accelerator, expected):
"""Test that the to_device method can move various objects to the device determined by the accelerator."""
if accelerator == "tpu":
from torch_xla.experimental import pjrt

if not pjrt.using_pjrt():
expected = "xla:1"
if accelerator == "tpu" and not _using_pjrt():
expected = "xla:1"

fabric = Fabric(accelerator=accelerator, devices=1)
fabric.launch()
5 changes: 2 additions & 3 deletions tests/tests_pytorch/accelerators/test_xla.py
@@ -22,6 +22,7 @@
from torch import nn
from torch.utils.data import DataLoader

import lightning.fabric
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator
@@ -314,9 +315,7 @@ def test_warning_if_tpus_not_used(tpu_available):
@pytest.mark.parametrize("runtime", ["xrt", "pjrt"])
@RunIf(min_python="3.9") # mocking issue
def test_trainer_config_device_ids(devices, expected_device_ids, runtime, tpu_available, monkeypatch):
from torch_xla.experimental import pjrt

monkeypatch.setattr(pjrt, "using_pjrt", lambda: runtime == "pjrt")
monkeypatch.setattr(lightning.fabric.accelerators.xla, "_using_pjrt", lambda: runtime == "pjrt")

mock = DeviceMock()
monkeypatch.setattr(torch, "device", mock)
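Routing the runtime check through lightning.fabric.accelerators.xla._using_pjrt also gives tests a single seam to fake either runtime, which is what the parametrized test above relies on. A minimal, illustrative sketch of that pattern (the test itself is not part of the PR):

import pytest

from lightning.fabric.accelerators import xla as xla_accelerator


@pytest.mark.parametrize("runtime", ["xrt", "pjrt"])
def test_fake_runtime(runtime, monkeypatch):
    # Patch the helper where the code under test looks it up, so no TPU or
    # special torch_xla configuration is needed to cover both branches.
    monkeypatch.setattr(xla_accelerator, "_using_pjrt", lambda: runtime == "pjrt")
    assert xla_accelerator._using_pjrt() is (runtime == "pjrt")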
5 changes: 2 additions & 3 deletions tests/tests_pytorch/callbacks/test_device_stats_monitor.py
@@ -21,6 +21,7 @@
import pytest
import torch

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats
from lightning.pytorch.callbacks import DeviceStatsMonitor
@@ -129,9 +130,7 @@ def test_device_stats_monitor_tpu(tmpdir):
try:
trainer.fit(model)
except RuntimeError as e:
from torch_xla.experimental import pjrt

if pjrt.using_pjrt() and "GetMemoryInfo not implemented" in str(e):
if _using_pjrt() and "GetMemoryInfo not implemented" in str(e):
pytest.xfail("`xm.get_memory_info` is not implemented with PJRT")
raise e

11 changes: 4 additions & 7 deletions tests/tests_pytorch/models/test_tpu.py
@@ -20,6 +20,7 @@
from torch.utils.data import DataLoader

import tests_pytorch.helpers.pipelines as tpipes
from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.callbacks import EarlyStopping
@@ -75,9 +76,8 @@ def test_model_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model, with_hpc=False)
import torch_xla
from torch_xla.experimental import pjrt

expected = tpu_core if pjrt.using_pjrt() else tpu_core + 1
expected = tpu_core if _using_pjrt() else tpu_core + 1
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"


@@ -138,9 +138,8 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
model = BoringModel()
tpipes.run_model_test(trainer_options, model)
import torch_xla
from torch_xla.experimental import pjrt

expected = tpu_core if pjrt.using_pjrt() else tpu_core + 1
expected = tpu_core if _using_pjrt() else tpu_core + 1
assert torch_xla._XLAC._xla_get_default_device() == f"xla:{expected}"


@@ -348,9 +347,7 @@ def on_train_start(self):
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_tpu_host_world_size(tmpdir):
"""Test Host World size env setup on TPU."""
from torch_xla.experimental import pjrt

if pjrt.using_pjrt():
if _using_pjrt():
pytest.skip("PJRT doesn't set 'XRT_HOST_WORLD_SIZE'")

trainer_options = {
9 changes: 3 additions & 6 deletions tests/tests_pytorch/strategies/test_xla.py
@@ -16,6 +16,7 @@

import torch

from lightning.fabric.accelerators.xla import _using_pjrt
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import XLAAccelerator
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -25,9 +26,7 @@

class BoringModelTPU(BoringModel):
def on_train_start(self) -> None:
from torch_xla.experimental import pjrt

index = 0 if pjrt.using_pjrt() else 1
index = 0 if _using_pjrt() else 1
# assert strategy attributes for device setting
assert self.device == torch.device("xla", index=index)
assert os.environ.get("PT_XLA_DEBUG") == "1"
@@ -38,10 +37,8 @@ def on_train_start(self) -> None:
def test_xla_strategy_debug_state():
"""Tests if device/debug flag is set correctly when training and after teardown for XLAStrategy."""
model = BoringModelTPU()
from torch_xla.experimental import pjrt

trainer_kwargs = {}
if not pjrt.using_pjrt():
if not _using_pjrt():
# only XRT supports XLA with a single process
trainer_kwargs["devices"] = 1
trainer = Trainer(fast_dev_run=True, strategy=XLAStrategy(debug=True), **trainer_kwargs)