Skip to content

add C-SI-SNR #1785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 42 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
e50db6e
add C-SI-SNR
quancs May 15, 2023
b4d89d6
+ Module
quancs May 15, 2023
1909502
add test
quancs May 15, 2023
68babad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2023
1a58c7c
update
quancs May 15, 2023
2fd234c
Merge branch 'c-si-snr' of https://github.com/quancs/torchmetrics int…
quancs May 15, 2023
569528e
fix
quancs May 15, 2023
a237160
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2023
0aa7172
fix
quancs May 16, 2023
aa520cb
fix
quancs May 17, 2023
aeb396c
fix
quancs May 18, 2023
cf360fc
update
quancs May 18, 2023
3773bcf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2023
4f1e91d
fix
quancs May 19, 2023
fc0ac7e
add
quancs May 19, 2023
df60e87
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2023
e40be39
fix
quancs May 19, 2023
8e63fd2
fix
quancs May 19, 2023
50b2b89
fix
quancs May 19, 2023
c1f694e
fix
quancs May 19, 2023
d816199
fix
quancs May 21, 2023
00cdc2a
fix
quancs May 21, 2023
06962dc
Merge branch 'master' into c-si-snr
quancs May 21, 2023
31fe63c
fix
quancs May 22, 2023
8a55883
Merge branch 'c-si-snr' of https://github.com/quancs/torchmetrics int…
quancs May 22, 2023
b3440e7
Merge branch 'master' into c-si-snr
Borda May 22, 2023
85bc278
Merge branch 'master' into c-si-snr
SkafteNicki May 22, 2023
63c4ec5
add plot testing
SkafteNicki May 22, 2023
7211024
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2023
c019a5e
Apply suggestions from code review
SkafteNicki May 22, 2023
0694ea1
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
a8c0794
Merge branch 'master' into c-si-snr
SkafteNicki May 22, 2023
e969243
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2023
98b03b9
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
6e9cab8
fix imports
SkafteNicki May 22, 2023
80d0c0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2023
5983c80
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
0fb8c05
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
5eef250
fix tests
SkafteNicki May 22, 2023
1f00b8b
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
790b9a6
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
92af26e
Merge branch 'master' into c-si-snr
mergify[bot] May 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new global arg `compute_with_cache` to control caching behaviour after `compute` method ([#1754](https://github.com/Lightning-AI/torchmetrics/pull/1754))


- Added `ComplexScaleInvariantSignalNoiseRatio` for audio package ([#1785](https://github.com/Lightning-AI/torchmetrics/pull/1785))


- Added `Running` wrapper for calculate running statistics ([#1752](https://github.com/Lightning-AI/torchmetrics/pull/1752))


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Complex Scale-Invariant Signal-to-Noise Ratio (C-SI-SNR)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

########################################################
Complex Scale-Invariant Signal-to-Noise Ratio (C-SI-SNR)
########################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.ComplexScaleInvariantSignalNoiseRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.complex_scale_invariant_signal_noise_ratio
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
.. _sdr ref2: https://arxiv.org/abs/2110.06440
.. _Scale-invariant signal-to-distortion ratio: https://arxiv.org/abs/1811.02508
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Complex scale-invariant signal-to-noise ratio: https://arxiv.org/abs/2011.09162
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Permutation invariant training: https://arxiv.org/abs/1607.00325
.. _ranking ref1: https://link.springer.com/chapter/10.1007/978-0-387-09823-4_34
Expand Down
7 changes: 6 additions & 1 deletion src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
# limitations under the License.
from torchmetrics.audio.pit import PermutationInvariantTraining
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio
from torchmetrics.audio.snr import (
ComplexScaleInvariantSignalNoiseRatio,
ScaleInvariantSignalNoiseRatio,
SignalNoiseRatio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE

__all__ = [
Expand All @@ -22,6 +26,7 @@
"SignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"ComplexScaleInvariantSignalNoiseRatio",
]

if _PESQ_AVAILABLE:
Expand Down
122 changes: 120 additions & 2 deletions src/torchmetrics/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@

from torch import Tensor, tensor

from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.functional.audio.snr import (
complex_scale_invariant_signal_noise_ratio,
scale_invariant_signal_noise_ratio,
signal_noise_ratio,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SignalNoiseRatio.plot", "ScaleInvariantSignalNoiseRatio.plot"]
__doctest_skip__ = [
"SignalNoiseRatio.plot",
"ScaleInvariantSignalNoiseRatio.plot",
"ComplexScaleInvariantSignalNoiseRatio.plot",
]


class SignalNoiseRatio(Metric):
Expand Down Expand Up @@ -151,6 +159,7 @@ class ScaleInvariantSignalNoiseRatio(Metric):
if target and preds have a different shape

Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
>>> target = tensor([3.0, -0.5, 2.0, 7.0])
Expand Down Expand Up @@ -225,3 +234,112 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class ComplexScaleInvariantSignalNoiseRatio(Metric):
"""Calculate `Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR) metric for evaluating quality of audio.

As input to `forward` and `update` the metric accepts the following input

- ``preds`` (:class:`~torch.Tensor`): real/complex float tensor with shape ``(..., frequency, time, 2)``\
/ ``(..., frequency, time)``

- ``target`` (: :class:`~torch.Tensor`): real/complex float tensor with shape ``(..., frequency, time, 2)``\
/ ``(..., frequency, time)``

As output of `forward` and `compute` the metric returns the following output

- ``c_si_snr`` (: :class:`~torch.Tensor`): float scalar tensor with average C-SI-SNR value over samples

Args:
zero_mean: if to zero mean target and preds or not
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
If ``zero_mean`` is not an bool
TypeError:
If ``preds`` is not the shape (..., frequency, time, 2) (after being converted to real if it is complex).
If ``preds`` and ``target`` does not have the same shape.

Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> c_si_snr = ComplexScaleInvariantSignalNoiseRatio()
>>> c_si_snr(preds, target)
tensor(-63.4849)
"""

is_differentiable = True
sum: Tensor
num: Tensor
higher_is_better = True
plot_lower_bound: Optional[float] = None
plot_upper_bound: Optional[float] = None

def __init__(
self,
zero_mean: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not isinstance(zero_mean, bool):
raise ValueError(f"Expected argument `zero_mean` to be an bool, but got {zero_mean}")
self.zero_mean = zero_mean

self.add_state("sum", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("num", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
v = complex_scale_invariant_signal_noise_ratio(preds=preds, target=target, zero_mean=self.zero_mean)

self.sum += v.sum()
self.num += v.numel()

def compute(self) -> Tensor:
"""Compute metric."""
return self.sum / self.num

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> metric = ComplexScaleInvariantSignalNoiseRatio()
>>> metric.update(torch.rand(1,257,100,2), torch.rand(1,257,100,2))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
>>> metric = ComplexScaleInvariantSignalNoiseRatio()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(1,257,100,2), torch.rand(1,257,100,2)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
16 changes: 14 additions & 2 deletions src/torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@
# limitations under the License.
from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.functional.audio.snr import (
complex_scale_invariant_signal_noise_ratio,
scale_invariant_signal_noise_ratio,
signal_noise_ratio,
)
from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE

__all__ = []
__all__ = [
"permutation_invariant_training",
"pit_permutate",
"scale_invariant_signal_distortion_ratio",
"signal_distortion_ratio",
"scale_invariant_signal_noise_ratio",
"signal_noise_ratio",
"complex_scale_invariant_signal_noise_ratio",
]

if _PESQ_AVAILABLE:
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
Expand Down
42 changes: 42 additions & 0 deletions src/torchmetrics/functional/audio/snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,45 @@ def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
tensor(15.0918)
"""
return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=True)


def complex_scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
"""`Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR).

Args:
preds: real/complex float tensor with shape ``(..., frequency, time, 2)``/``(..., frequency, time)``
target: real/complex float tensor with shape ``(..., frequency, time, 2)``/``(..., frequency, time)``
zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics

Returns:
Float tensor with shape ``(...,)`` of C-SI-SNR values per sample

Raises:
RuntimeError:
If ``preds`` is not the shape (..., frequency, time, 2) (after being converted to real if it is complex).
If ``preds`` and ``target`` does not have the same shape.

Example:
>>> import torch
>>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn((1,257,100,2))
>>> target = torch.randn((1,257,100,2))
>>> complex_scale_invariant_signal_noise_ratio(preds, target)
tensor([-63.4849])
"""
if preds.is_complex():
preds = torch.view_as_real(preds)
if target.is_complex():
target = torch.view_as_real(target)

if (preds.ndim < 3 or preds.shape[-1] != 2) or (target.ndim < 3 or target.shape[-1] != 2):
raise RuntimeError(
"Predictions and targets are expected to have the shape (..., frequency, time, 2),"
" but got {preds.shape} and {target.shape}."
)

preds = preds.reshape(*preds.shape[:-3], -1)
target = target.reshape(*target.shape[:-3], -1)

return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean)
97 changes: 97 additions & 0 deletions tests/unittests/audio/test_c_si_snr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from scipy.io import wavfile
from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio
from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio

from unittests import BATCH_SIZE, NUM_BATCHES
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 129, 20, 2),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 129, 20, 2),
)


@pytest.mark.parametrize(
"preds, target, ref_metric, zero_mean",
[
(inputs.preds, inputs.target, None, True),
(inputs.preds, inputs.target, None, False),
],
)
class TestComplexSISNR(MetricTester):
"""Test class for `ComplexScaleInvariantSignalNoiseRatio` metric."""

atol = 1e-2

def test_c_si_snr_differentiability(self, preds, target, ref_metric, zero_mean):
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=ComplexScaleInvariantSignalNoiseRatio,
metric_functional=complex_scale_invariant_signal_noise_ratio,
metric_args={"zero_mean": zero_mean},
)

def test_c_si_sdr_half_cpu(self, preds, target, ref_metric, zero_mean):
"""Test dtype support of the metric on CPU."""
pytest.xfail("C-SI-SDR metric does not support cpu + half precision")

def test_c_si_sdr_half_gpu(self, preds, target, ref_metric, zero_mean):
"""Test dtype support of the metric on GPU."""
pytest.xfail("C-SI-SDR metric does not support gpu + half precision")


def test_on_real_audio():
"""Test that metric works as expected on real audio signals."""
rate, ref = wavfile.read(_SAMPLE_AUDIO_SPEECH)
rate, deg = wavfile.read(_SAMPLE_AUDIO_SPEECH_BAB_DB)
ref = torch.tensor(ref, dtype=torch.float32)
deg = torch.tensor(deg, dtype=torch.float32)
ref_stft = torch.stft(ref, n_fft=256, hop_length=128, return_complex=True)
deg_stft = torch.stft(deg, n_fft=256, hop_length=128, return_complex=True)

v = complex_scale_invariant_signal_noise_ratio(deg_stft, ref_stft, zero_mean=False)
assert torch.allclose(v, torch.tensor(0.03019072115421295, dtype=v.dtype), atol=1e-4), v
v = complex_scale_invariant_signal_noise_ratio(deg_stft, ref_stft, zero_mean=True)
assert torch.allclose(v, torch.tensor(0.030391741544008255, dtype=v.dtype), atol=1e-4), v


def test_error_on_incorrect_shape(metric_class=ComplexScaleInvariantSignalNoiseRatio):
"""Test that error is raised on incorrect shapes of input."""
metric = metric_class()
with pytest.raises(
RuntimeError,
match="Predictions and targets are expected to have the shape (..., frequency, time, 2)*",
):
metric(torch.randn(100), torch.randn(50))


def test_error_on_different_shape(metric_class=ComplexScaleInvariantSignalNoiseRatio):
"""Test that error is raised on different shapes of input."""
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape*"):
metric(torch.randn(129, 100, 2), torch.randn(129, 101, 2))
7 changes: 7 additions & 0 deletions tests/unittests/utilities/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchmetrics import MetricCollection
from torchmetrics.aggregation import MaxMetric, MeanMetric, MinMetric, SumMetric
from torchmetrics.audio import (
ComplexScaleInvariantSignalNoiseRatio,
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
ShortTimeObjectiveIntelligibility,
Expand Down Expand Up @@ -283,6 +284,12 @@
ScaleInvariantSignalDistortionRatio, _rand_input, _rand_input, id="scale_invariant_signal_distortion_ratio"
),
pytest.param(SignalNoiseRatio, _rand_input, _rand_input, id="signal_noise_ratio"),
pytest.param(
ComplexScaleInvariantSignalNoiseRatio,
lambda: torch.randn(10, 3, 5, 2),
lambda: torch.randn(10, 3, 5, 2),
id="complex scale invariant signal noise ratio",
),
pytest.param(ScaleInvariantSignalNoiseRatio, _rand_input, _rand_input, id="scale_invariant_signal_noise_ratio"),
pytest.param(
partial(ShortTimeObjectiveIntelligibility, fs=8000, extended=False),
Expand Down