Skip to content

Commit

Permalink
Merge branch 'master' into dependabot-pip-requirements-pytest-8.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Mar 5, 2024
2 parents 813a456 + 9d76f3f commit 1cbcfde
Show file tree
Hide file tree
Showing 130 changed files with 1,012 additions and 942 deletions.
2 changes: 1 addition & 1 deletion .azure/gpu-integrations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
echo "##vso[task.setvariable variable=CUDA_VERSION_MM]$CUDA_version_mm"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${CUDA_version_mm}/torch_stable.html"
# packages for running assistant
pip install -q packaging fire requests wget
pip install -q fire wget packaging
displayName: "set Env. vars"
- bash: |
Expand Down
1 change: 1 addition & 0 deletions .azure/gpu-unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
displayName: "set Env. vars for PRs"
- bash: |
pip install -q fire pyGithub
printf "PR: $PR_NUMBER \n"
focus=$(python .github/assistant.py changed-domains $PR_NUMBER)
printf "focus: $focus \n"
Expand Down
4 changes: 2 additions & 2 deletions .github/actions/pull-caches/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ runs:
using: "composite"
steps:
- name: install assistant's deps
run: pip install -q fire requests packaging wget
run: pip install -q packaging fire wget
shell: bash

- name: Set PyTorch version
Expand Down Expand Up @@ -90,5 +90,5 @@ runs:

- name: Restored References
continue-on-error: true
run: ls -lh tests/_cache-references/
run: py-tree tests/_cache-references/ --show_hidden
shell: bash
2 changes: 1 addition & 1 deletion .github/actions/push-caches/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,5 @@ runs:
key: cache-references

- name: Post References
run: ls -lh tests/_cache-references/
run: py-tree tests/_cache-references/ --show_hidden
shell: bash
28 changes: 5 additions & 23 deletions .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import logging
import os
import re
import sys
import traceback
from typing import List, Optional, Tuple, Union

import fire
import requests
from packaging.version import parse
from pkg_resources import parse_requirements

Expand All @@ -38,19 +35,6 @@
REQUIREMENTS_FILES = (*glob.glob(_path("requirements", "*.txt")), _path("requirements.txt"))


def request_url(url: str, auth_token: Optional[str] = None) -> Optional[dict]:
"""General request with checking if request limit was reached."""
auth_header = {"Authorization": f"token {auth_token}"} if auth_token else {}
try:
req = requests.get(url, headers=auth_header, timeout=_REQUEST_TIMEOUT)
except requests.exceptions.Timeout:
traceback.print_exc()
return None
if req.status_code == 403:
return None
return json.loads(req.content.decode(req.encoding))


class AssistantCLI:
"""CLI assistant for local CI."""

Expand Down Expand Up @@ -114,15 +98,13 @@ def changed_domains(
general_sub_pkgs: Tuple[str] = _PKG_WIDE_SUBPACKAGES,
) -> Union[str, List[str]]:
"""Determine what domains were changed in particular PR."""
import github

if not pr:
return "unittests"
url = f"https://api.github.com/repos/Lightning-AI/torchmetrics/pulls/{pr}/files"
logging.debug(url)
data = request_url(url, auth_token)
if not data:
logging.debug("WARNING: No data was received -> test everything.")
return "unittests"
files = [d["filename"] for d in data]
gh = github.Github()
pr = gh.get_repo("Lightning-AI/torchmetrics").get_pull(pr)
files = [f.filename for f in pr.get_files()]

# filter out all integrations as they run in separate suit
files = [fn for fn in files if not fn.startswith("tests/integrations")]
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/_focus-diff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ on:

jobs:
eval-diff:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
timeout-minutes: 5
# Map the job outputs to step outputs
outputs:
Expand All @@ -26,8 +26,9 @@ jobs:
env:
PR_NUMBER: "${{ github.event.pull_request.number }}"
run: |
set -e
echo $PR_NUMBER
pip install fire requests
pip install -q -U packaging fire pyGithub pyopenssl
# python .github/assistant.py changed-domains $PR_NUMBER
echo "focus=$(python .github/assistant.py changed-domains $PR_NUMBER)" >> $GITHUB_OUTPUT
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci-integrate.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
- { python-version: "3.10", os: "windows" } # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192
include:
- { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" }
- { python-version: "3.10", requires: "latest", os: "macOS-14" } # M1 machine
env:
PYTORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,19 @@ jobs:
- "2.1.2"
- "2.2.1"
include:
# cover additional python nad PR combinations
- { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.1" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.2.1" }
# standard mac machine, not the M1
- { os: "macOS-12", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-12", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "macOS-12", python-version: "3.11", pytorch-version: "2.2.1" }
# using the ARM based M1 machine
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "macOS-14", python-version: "3.11", pytorch-version: "2.2.1" }
# some windows
- { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "windows-2022", python-version: "3.11", pytorch-version: "2.2.1" }
Expand All @@ -75,6 +81,7 @@ jobs:
if: ${{ runner.os == 'macOS' }}
run: |
echo 'UNITTEST_TIMEOUT=--timeout=75' >> $GITHUB_ENV
brew install mecab # https://github.com/coqui-ai/TTS/issues/1533#issuecomment-1338662303
brew install gcc libomp ffmpeg # https://github.com/pytorch/pytorch/issues/20030
- name: Setup Linux
if: ${{ runner.os == 'Linux' }}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/publish-pkg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
- run: ls -lh dist/
# We do this, since failures on test.pypi aren't that bad
- name: Publish to Test PyPI
uses: pypa/gh-action-pypi-publish@v1.8.11
uses: pypa/gh-action-pypi-publish@v1.8.12
with:
user: __token__
password: ${{ secrets.test_pypi_password }}
Expand All @@ -94,7 +94,7 @@ jobs:
path: dist
- run: ls -lh dist/
- name: Publish distribution 📦 to PyPI
uses: pypa/gh-action-pypi-publish@v1.8.11
uses: pypa/gh-action-pypi-publish@v1.8.12
with:
user: __token__
password: ${{ secrets.pypi_password }}
Expand Down
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
tests/_data/
data.zip
tests/_cache-references/
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))


### Deprecated
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@ env:
pip install -e . -U -r requirements/_devel.txt

data:
python -c "from urllib.request import urlretrieve ; urlretrieve('https://pl-public-data.s3.amazonaws.com/metrics/data.zip', 'data.zip')"
pip install -q wget
python -m wget https://pl-public-data.s3.amazonaws.com/metrics/data.zip
unzip -o data.zip -d ./tests
4 changes: 2 additions & 2 deletions requirements/_doctest.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

pytest >=8.0.0, <8.1.0
pytest-doctestplus >=0.9.0, <=1.1.0
pytest-rerunfailures >=10.0, <14.0
pytest-doctestplus >1.0.0, <=1.2.0
pytest-rerunfailures >10.0, <14.0
4 changes: 2 additions & 2 deletions requirements/_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
coverage ==7.4.3
pytest ==8.0.0
pytest-cov ==4.1.0
pytest-doctestplus ==1.1.0
pytest-doctestplus ==1.2.0
pytest-rerunfailures ==13.0
pytest-timeout ==2.2.0
pytest-xdist ==3.5.0
phmdoctest ==1.4.0

psutil <5.10.0
requests <=2.31.0
pyGithub ==2.2.0
fire <=0.5.0

cloudpickle >1.3, <=3.0.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/text_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
jiwer >=2.3.0, <3.1.0
rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
huggingface-hub <0.21 # hotfix, failing SDR for latest PT 1.11
huggingface-hub <0.22
sacrebleu >=2.3.0, <2.5.0
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size:
>>> preds = torch.rand(4, 3, 16, 16, generator=gen)
>>> target = torch.rand(4, 3, 16, 16, generator=gen)
>>> _relative_average_spectral_error(preds, target)
tensor(5114.6641)
tensor(5114.66...)
"""
_deprecated_root_import_func("relative_average_spectral_error", "image")
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/rase.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def relative_average_spectral_error(preds: Tensor, target: Tensor, window_size:
>>> preds = torch.rand(4, 3, 16, 16)
>>> target = torch.rand(4, 3, 16, 16)
>>> relative_average_spectral_error(preds, target)
tensor(5114.6641)
tensor(5114.66...)
Raises:
ValueError: If ``window_size`` is not a positive integer.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class _RelativeAverageSpectralError(RelativeAverageSpectralError):
>>> target = torch.rand(4, 3, 16, 16)
>>> rase = _RelativeAverageSpectralError()
>>> rase(preds, target)
tensor(5114.6641)
tensor(5114.66...)
"""

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/rase.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class RelativeAverageSpectralError(Metric):
>>> target = torch.rand(4, 3, 16, 16)
>>> rase = RelativeAverageSpectralError()
>>> rase(preds, target)
tensor(5114.6641)
tensor(5114.66...)
Raises:
ValueError: If ``window_size`` is not a positive integer.
Expand Down
4 changes: 0 additions & 4 deletions src/torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
import sys

from lightning_utilities.core.imports import RequirementCache
from packaging.version import Version, parse

_PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
_PYTHON_LOWER_3_8 = parse(_PYTHON_VERSION) < Version("3.8")
_TORCH_LOWER_2_0 = RequirementCache("torch<2.0.0")
_TORCH_GREATER_EQUAL_1_11 = RequirementCache("torch>=1.11.0")
_TORCH_GREATER_EQUAL_1_12 = RequirementCache("torch>=1.12.0")
Expand All @@ -29,7 +27,6 @@
_TORCH_GREATER_EQUAL_2_1 = RequirementCache("torch>=2.1.0")
_TORCH_GREATER_EQUAL_2_2 = RequirementCache("torch>=2.2.0")

_JIWER_AVAILABLE = RequirementCache("jiwer")
_NLTK_AVAILABLE = RequirementCache("nltk")
_ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score")
_BERTSCORE_AVAILABLE = RequirementCache("bert_score")
Expand All @@ -49,7 +46,6 @@
_GAMMATONE_AVAILABLE = RequirementCache("gammatone")
_TORCHAUDIO_AVAILABLE = RequirementCache("torchaudio")
_TORCHAUDIO_GREATER_EQUAL_0_10 = RequirementCache("torchaudio>=0.10.0")
_SACREBLEU_AVAILABLE = RequirementCache("sacrebleu")
_REGEX_AVAILABLE = RequirementCache("regex")
_PYSTOI_AVAILABLE = RequirementCache("pystoi")
_FAST_BSS_EVAL_AVAILABLE = RequirementCache("fast_bss_eval")
Expand Down
23 changes: 23 additions & 0 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Any, Dict, List, Optional, Sequence, Union

from torch import Tensor
Expand All @@ -20,6 +21,9 @@
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.wrappers.abstract import WrapperMetric

if typing.TYPE_CHECKING:
from torch.nn import Module

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["ClasswiseWrapper.plot"]

Expand Down Expand Up @@ -209,3 +213,22 @@ def plot(
"""
return self._plot(val, ax)

def __getattr__(self, name: str) -> Union[Tensor, "Module"]:
"""Get attribute from classwise wrapper."""
if name == "metric" or (name in self.__dict__ and name not in self.metric.__dict__):
# we need this to prevent from infinite getattribute loop.
return super().__getattr__(name)

return getattr(self.metric, name)

def __setattr__(self, name: str, value: Any) -> None:
"""Set attribute to classwise wrapper."""
if hasattr(self, "metric") and name in self.metric._defaults:
setattr(self.metric, name, value)
else:
super().__setattr__(name, value)
if name == "metric":
self._defaults = self.metric._defaults
self._persistent = self.metric._persistent
self._reductions = self.metric._reductions
23 changes: 23 additions & 0 deletions tests/unittests/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
import os
from typing import Callable, Optional

from torch import Tensor

from unittests import _PATH_ALL_TESTS

_SAMPLE_AUDIO_SPEECH = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech.wav")
_SAMPLE_AUDIO_SPEECH_BAB_DB = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech_bab_0dB.wav")
_SAMPLE_NUMPY_ISSUE_895 = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "issue_895.npz")


def _average_metric_wrapper(
preds: Tensor, target: Tensor, metric_func: Callable, res_index: Optional[int] = None
) -> Tensor:
"""Average the metric values.
Args:
preds: predictions, shape[batch, spk, time]
target: targets, shape[batch, spk, time]
metric_func: a function which return best_metric and best_perm
res_index: if not None, return best_metric[res_index]
Returns:
the average of best_metric
"""
if res_index is not None:
return metric_func(preds, target)[res_index].mean()
return metric_func(preds, target).mean()
Loading

0 comments on commit 1cbcfde

Please sign in to comment.