diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a168db7b486d..f49707b296f1b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
 
 
+- Fixed importing `torch_xla.debug` for `torch-xla<1.8` ([#10836](https://github.com/PyTorchLightning/pytorch-lightning/pull/10836))
+
+
 - Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
 
 
diff --git a/pytorch_lightning/profiler/xla.py b/pytorch_lightning/profiler/xla.py
index e30f06f84e952..c89685bcad0be 100644
--- a/pytorch_lightning/profiler/xla.py
+++ b/pytorch_lightning/profiler/xla.py
@@ -42,9 +42,10 @@ from typing import Dict
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TPU_AVAILABLE:
+if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
     import torch_xla.debug.profiler as xp
 
 log = logging.getLogger(__name__)
 
@@ -65,6 +66,10 @@ class XLAProfiler(BaseProfiler):
     def __init__(self, port: int = 9012) -> None:
         """This Profiler will help you debug and optimize training workload performance for your models using Cloud TPU
         performance tools."""
+        if not _TPU_AVAILABLE:
+            raise MisconfigurationException("`XLAProfiler` is only supported on TPUs")
+        if not _TORCH_GREATER_EQUAL_1_8:
+            raise MisconfigurationException("`XLAProfiler` is only supported with `torch-xla >= 1.8`")
         super().__init__(dirpath=None, filename=None)
         self.port = port
         self._recording_map: Dict = {}
diff --git a/tests/profiler/test_xla_profiler.py b/tests/profiler/test_xla_profiler.py
index 2afbf69a6d0b0..7f460ea11d322 100644
--- a/tests/profiler/test_xla_profiler.py
+++ b/tests/profiler/test_xla_profiler.py
@@ -18,14 +18,16 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.profiler import XLAProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 if _TPU_AVAILABLE:
-    import torch_xla.debug.profiler as xp
     import torch_xla.utils.utils as xu
 
+    if _TORCH_GREATER_EQUAL_1_8:
+        import torch_xla.debug.profiler as xp
+
 
 @RunIf(tpu=True)
 def test_xla_profiler_instance(tmpdir):
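
The substance of the patch is a two-level guard: `torch_xla.debug.profiler` only exists for `torch-xla >= 1.8`, so the import is gated behind both `_TPU_AVAILABLE` and `_TORCH_GREATER_EQUAL_1_8`, and `XLAProfiler.__init__` raises a `MisconfigurationException` up front instead of failing later with a `NameError` on the unbound `xp`. Below is a minimal, standalone sketch of the same pattern; the helper names (`_package_available`, `_torch_greater_equal`, `GuardedProfiler`) and the flag computations are hypothetical simplifications, not Lightning's actual `_TPU_AVAILABLE` / `_TORCH_GREATER_EQUAL_1_8` implementations.

```python
# Sketch only: simplified stand-ins for Lightning's availability flags.
import importlib.util
from importlib.metadata import PackageNotFoundError, version


def _package_available(name: str) -> bool:
    # True when `name` is importable in the current environment.
    return importlib.util.find_spec(name) is not None


def _torch_greater_equal(major: int, minor: int) -> bool:
    # Compare only the leading "major.minor" release segment of torch's version
    # string (e.g. "1.8.1+cu111" -> (1, 8)).
    try:
        installed = version("torch")
    except PackageNotFoundError:
        return False
    parts = installed.split("+")[0].split(".")
    return (int(parts[0]), int(parts[1])) >= (major, minor)


_TPU_AVAILABLE = _package_available("torch_xla")
_TORCH_GREATER_EQUAL_1_8 = _torch_greater_equal(1, 8)

# Import the optional module only when every guard holds, mirroring the patch;
# otherwise the name `xp` is simply never bound.
if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
    import torch_xla.debug.profiler as xp  # noqa: F401


class GuardedProfiler:
    """Fails fast in the constructor instead of with a NameError at first use."""

    def __init__(self, port: int = 9012) -> None:
        if not _TPU_AVAILABLE:
            raise RuntimeError("`GuardedProfiler` is only supported on TPUs")
        if not _TORCH_GREATER_EQUAL_1_8:
            raise RuntimeError("`GuardedProfiler` is only supported with `torch-xla >= 1.8`")
        self.port = port
```

Raising in the constructor turns a confusing late failure into an actionable configuration error at instantiation time, which is also why the test module gates its `torch_xla.debug.profiler` import behind the same version flag.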