Don't import torch_xla.debug for torch-xla<1.8 (#10836)
kaushikb11 authored Dec 6, 2021
1 parent 3d59a2f commit 6599ced
Showing 3 changed files with 14 additions and 4 deletions.
CHANGELOG.md (3 additions, 0 deletions)
@@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
 
 
+- Fixed importing `torch_xla.debug` for `torch-xla<1.8` ([#10836](https://github.com/PyTorchLightning/pytorch-lightning/pull/10836))
+
+
 - Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
 
 
pytorch_lightning/profiler/xla.py (7 additions, 2 deletions)
@@ -42,9 +42,10 @@
 from typing import Dict
 
 from pytorch_lightning.profiler.base import BaseProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TPU_AVAILABLE:
+if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
     import torch_xla.debug.profiler as xp
 
 log = logging.getLogger(__name__)
@@ -65,6 +66,10 @@ class XLAProfiler(BaseProfiler):
     def __init__(self, port: int = 9012) -> None:
         """This Profiler will help you debug and optimize training workload performance for your models using Cloud
         TPU performance tools."""
+        if not _TPU_AVAILABLE:
+            raise MisconfigurationException("`XLAProfiler` is only supported on TPUs")
+        if not _TORCH_GREATER_EQUAL_1_8:
+            raise MisconfigurationException("`XLAProfiler` is only supported with `torch-xla >= 1.8`")
         super().__init__(dirpath=None, filename=None)
         self.port = port
         self._recording_map: Dict = {}
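For context, the fix works because Python only executes the `import torch_xla.debug.profiler` line when both guards are true, so a `torch-xla<1.8` install never touches the missing submodule; what used to be an import-time crash becomes a clear `MisconfigurationException` when the profiler is instantiated. A minimal sketch of this guarded-import pattern follows; the flag and availability definitions here are illustrative assumptions, not the exact `pytorch_lightning.utilities` implementation:

```python
# Sketch of a version-gated import guard. The real _TORCH_GREATER_EQUAL_1_8
# and _TPU_AVAILABLE values live in pytorch_lightning.utilities; these
# definitions are illustrative stand-ins.
from packaging.version import Version

import torch

# torch-xla releases track torch releases, so comparing torch's version is
# a proxy for whether torch_xla.debug exists (it was added in 1.8).
_TORCH_GREATER_EQUAL_1_8 = Version(torch.__version__.split("+")[0]) >= Version("1.8.0")

try:
    import torch_xla  # noqa: F401

    _TPU_AVAILABLE = True  # simplification: the real check inspects XLA devices
except ImportError:
    _TPU_AVAILABLE = False

# Both conditions must hold before the submodule is imported; on
# torch-xla < 1.8 this import would raise ModuleNotFoundError.
if _TPU_AVAILABLE and _TORCH_GREATER_EQUAL_1_8:
    import torch_xla.debug.profiler as xp  # noqa: F401
```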
tests/profiler/test_xla_profiler.py (4 additions, 2 deletions)
@@ -18,14 +18,16 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.profiler import XLAProfiler
-from pytorch_lightning.utilities import _TPU_AVAILABLE
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, _TPU_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 if _TPU_AVAILABLE:
-    import torch_xla.debug.profiler as xp
     import torch_xla.utils.utils as xu
 
+if _TORCH_GREATER_EQUAL_1_8:
+    import torch_xla.debug.profiler as xp
+
 
 @RunIf(tpu=True)
 def test_xla_profiler_instance(tmpdir):
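A usage sketch for the profiler after this change, assuming a TPU host with `torch-xla >= 1.8`; `tpu_cores` and `profiler` are standard `Trainer` arguments of this release line, and `BoringModel` is the test helper imported above:

```python
# Usage sketch: attach XLAProfiler to a short Trainer run. On a non-TPU
# machine, or with torch-xla < 1.8, XLAProfiler() itself now raises
# MisconfigurationException instead of the module failing to import.
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import XLAProfiler

from tests.helpers import BoringModel

model = BoringModel()
trainer = Trainer(
    tpu_cores=8,  # the profiler is only supported on TPUs
    profiler=XLAProfiler(port=9012),  # traces served on the default port
    max_epochs=1,
)
trainer.fit(model)
```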
