Skip to content

Commit

Permalink
bump: support torch>=2.0 (#2671)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
  • Loading branch information
3 people authored Oct 21, 2024
1 parent 7f579eb commit 346bcdc
Show file tree
Hide file tree
Showing 47 changed files with 125 additions and 232 deletions.
6 changes: 3 additions & 3 deletions .azure/gpu-integrations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ jobs:
- job: integrate_GPU
strategy:
matrix:
"torch | 1.x":
docker-image: "pytorchlightning/torchmetrics:ubuntu22.04-cuda11.8.0-py3.9-torch1.13"
torch-ver: "1.13"
"torch | 2.0":
docker-image: "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime"
torch-ver: "2.0"
requires: "oldest"
"torch | 2.x":
docker-image: "pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime"
Expand Down
15 changes: 6 additions & 9 deletions .azure/gpu-unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ jobs:
- job: unitest_GPU
strategy:
matrix:
"PyTorch | 1.10 oldest":
"PyTorch | 2.0 oldest":
# Torch does not have build wheels with old Torch versions for newer CUDA
docker-image: "ubuntu20.04-cuda11.3.1-py3.9-torch1.10"
torch-ver: "1.10"
"PyTorch | 1.X LTS":
docker-image: "ubuntu22.04-cuda11.8.0-py3.9-torch1.13"
torch-ver: "1.13"
docker-image: "ubuntu22.04-cuda11.8.0-py3.10-torch2.0"
torch-ver: "2.0"
"PyTorch | 2.X stable":
docker-image: "ubuntu22.04-cuda12.1.1-py3.11-torch2.4"
torch-ver: "2.4"
Expand Down Expand Up @@ -123,7 +120,7 @@ jobs:
- bash: |
python .github/assistant.py set-oldest-versions
condition: eq(variables['torch-ver'], '1.10')
condition: eq(variables['torch-ver'], '2.0')
displayName: "Setting oldest versions"
- bash: |
Expand Down Expand Up @@ -191,7 +188,7 @@ jobs:
workingDirectory: "tests/"
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
condition: and(succeeded(), ne(variables['TEST_DIRS'], ''))
timeoutInMinutes: "90"
timeoutInMinutes: "95"
displayName: "UnitTesting common"
- bash: |
Expand All @@ -203,7 +200,7 @@ jobs:
workingDirectory: "tests/"
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
condition: and(succeeded(), ne(variables['TEST_DIRS'], ''))
timeoutInMinutes: "90"
timeoutInMinutes: "95"
displayName: "UnitTesting DDP"
- bash: |
Expand Down
7 changes: 0 additions & 7 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,22 @@ jobs:
os: ["ubuntu-20.04"]
python-version: ["3.9"]
pytorch-version:
- "1.10.2"
- "1.11.0"
- "1.12.1"
- "1.13.1"
- "2.0.1"
- "2.1.2"
- "2.2.2"
- "2.3.1"
- "2.4.0"
include:
# cover additional python and PT combinations
- { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" }
# standard mac machine, not the M1
- { os: "macOS-13", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }
# using the ARM based M1 machine
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "macOS-14", python-version: "3.11", pytorch-version: "2.4.0" }
# some windows
- { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "windows-2022", python-version: "3.11", pytorch-version: "2.4.0" }
# Future released version
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ jobs:
include:
# These are the base images for PL release docker images,
# so include at least all the combinations in release-dockers.yml.
- { python: "3.9", pytorch: "1.10", cuda: "11.3.1", ubuntu: "20.04" }
#- { python: "3.9", pytorch: "1.11", cuda: "11.8.0", ubuntu: "22.04" }
- { python: "3.9", pytorch: "1.13", cuda: "11.8.0", ubuntu: "22.04" }
- { python: "3.10", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" }
- { python: "3.11", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" }
- { python: "3.11", pytorch: "2.3", cuda: "12.1.1", ubuntu: "22.04" }
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-
- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671))


### Fixed
Expand Down
1 change: 1 addition & 0 deletions requirements/_tests.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

codecov ==2.1.13
coverage ==7.6.*
codecov ==2.1.13
pytest ==8.3.*
Expand Down
4 changes: 2 additions & 2 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# this need to be the same as used inside speechmetrics
pesq >=0.0.4, <0.0.5
pystoi >=0.4.0, <0.5.0
torchaudio >=0.10.0, <2.5.0
torchaudio >=2.0.1, <2.5.0
gammatone >=1.0.0, <1.1.0
librosa >=0.9.0, <0.11.0
librosa >=0.10.0, <0.11.0
onnxruntime >=1.12.0, <1.20 # installing onnxruntime_gpu-gpu failed on macos
requests >=2.19.0, <2.33.0
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

numpy >1.20.0, <2.0 # strict, for compatibility reasons
packaging >17.1
torch >=1.10.0, <2.5.0
torch >=2.0.0, <2.5.0
typing-extensions; python_version < '3.9'
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/detection.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.8, <0.20.0
torchvision >=0.15.1, <0.20.0
pycocotools >2.0.0, <2.1.0
2 changes: 1 addition & 1 deletion requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
torchvision >=0.8, <0.20.0
torchvision >=0.15.1, <0.20.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
11 changes: 9 additions & 2 deletions src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@
_ONNXRUNTIME_AVAILABLE,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_SCIPI_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

if _SCIPI_AVAILABLE:
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"PermutationInvariantTraining",
"ScaleInvariantSignalDistortionRatio",
Expand All @@ -52,7 +59,7 @@

__all__ += ["ShortTimeObjectiveIntelligibility"]

if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10:
if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE:
from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio

__all__ += ["SpeechReverberationModulationEnergyRatio"]
Expand Down
5 changes: 2 additions & 3 deletions src/torchmetrics/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
_GAMMATONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHAUDIO_GREATER_EQUAL_0_10]):
if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]):
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]
Expand Down Expand Up @@ -105,7 +104,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE:
if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
raise ModuleNotFoundError(
"speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
" `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
Expand Down
23 changes: 11 additions & 12 deletions src/torchmetrics/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.detection.panoptic_qualities import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import (
_TORCHVISION_GREATER_EQUAL_0_8,
_TORCHVISION_GREATER_EQUAL_0_13,
)
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

__all__ = ["ModifiedPanopticQuality", "PanopticQuality"]

if _TORCHVISION_GREATER_EQUAL_0_8:
if _TORCHVISION_AVAILABLE:
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
from torchmetrics.detection.diou import DistanceIntersectionOverUnion
from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion
from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.detection.mean_ap import MeanAveragePrecision

__all__ += ["MeanAveragePrecision", "GeneralizedIntersectionOverUnion", "IntersectionOverUnion"]

if _TORCHVISION_GREATER_EQUAL_0_13:
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
from torchmetrics.detection.diou import DistanceIntersectionOverUnion

__all__ += ["CompleteIntersectionOverUnion", "DistanceIntersectionOverUnion"]
__all__ += [
"MeanAveragePrecision",
"GeneralizedIntersectionOverUnion",
"IntersectionOverUnion",
"CompleteIntersectionOverUnion",
"DistanceIntersectionOverUnion",
]
9 changes: 0 additions & 9 deletions src/torchmetrics/detection/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
from typing import Any, Collection

from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.prints import _deprecated_root_import_class

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = [
"_PanopticQuality",
"_PanopticQuality.*",
"_ModifiedPanopticQuality",
"_ModifiedPanopticQuality.*",
]


class _ModifiedPanopticQuality(ModifiedPanopticQuality):
"""Wrapper for deprecated import.
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/_mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import _cumsum
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MeanAveragePrecision.plot"]

if not _TORCHVISION_GREATER_EQUAL_0_8 or not _PYCOCOTOOLS_AVAILABLE:
if not _TORCHVISION_AVAILABLE or not _PYCOCOTOOLS_AVAILABLE:
__doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"]

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -327,10 +327,10 @@ def __init__(
"`MAP` metric requires that `pycocotools` installed."
" Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
)
if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
"`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed."
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
"`MeanAveragePrecision` metric requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)

allowed_box_formats = ("xyxy", "xywh", "cxcywh")
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.functional.detection.ciou import _ciou_compute, _ciou_update
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["CompleteIntersectionOverUnion", "CompleteIntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CompleteIntersectionOverUnion.plot"]
Expand Down Expand Up @@ -110,10 +110,10 @@ def __init__(
respect_labels: bool = True,
**kwargs: Any,
) -> None:
if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.functional.detection.diou import _diou_compute, _diou_update
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["DistanceIntersectionOverUnion", "DistanceIntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["DistanceIntersectionOverUnion.plot"]
Expand Down Expand Up @@ -110,10 +110,10 @@ def __init__(
respect_labels: bool = True,
**kwargs: Any,
) -> None:
if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.functional.detection.giou import _giou_compute, _giou_update
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["GeneralizedIntersectionOverUnion", "GeneralizedIntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["GeneralizedIntersectionOverUnion.plot"]
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from torchmetrics.functional.detection.iou import _iou_compute, _iou_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["IntersectionOverUnion", "IntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["IntersectionOverUnion.plot"]
Expand Down Expand Up @@ -146,10 +146,10 @@ def __init__(
) -> None:
super().__init__(**kwargs)

if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.8.0 or newer is installed."
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)

allowed_box_formats = ("xyxy", "xywh", "cxcywh")
Expand Down
Loading

0 comments on commit 346bcdc

Please sign in to comment.