Add Source-Aggregated Signal-to-Distortion Ratio (SA-SDR) #1882

Merged
merged 36 commits into from
Jul 11, 2023
Changes from 33 commits
Commits
fc1e239
+SA-SDR
quancs May 24, 2023
07526f2
+sasdr
quancs Jun 9, 2023
aa05e7a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2023
1cc28c2
add SA-SDR functional
quancs Jul 3, 2023
b25476b
update
quancs Jul 4, 2023
21a23f1
Merge branch 'sa-sdr' of https://github.com/quancs/torchmetrics into …
quancs Jul 4, 2023
b767f8e
Merge branch 'master' into sa-sdr
quancs Jul 3, 2023
c142517
update
quancs Jul 4, 2023
dd20ee5
add SourceAggregatedSignalDistortionRatio
quancs Jul 4, 2023
3f243f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 4, 2023
d839d7b
Merge branch 'master' into sa-sdr
SkafteNicki Jul 4, 2023
5dd3608
update
quancs Jul 5, 2023
3c750a8
Merge branch 'master' into sa-sdr
quancs Jul 5, 2023
ae6f864
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
b4aeb81
update
quancs Jul 5, 2023
b1211e5
fix
quancs Jul 5, 2023
660df0d
Merge branch 'sa-sdr' of https://github.com/quancs/torchmetrics into …
quancs Jul 5, 2023
aa943f0
fix
quancs Jul 5, 2023
ef26d3e
fix
quancs Jul 6, 2023
e35c9fa
Update CHANGELOG.md
quancs Jul 6, 2023
4b48b58
Update src/torchmetrics/audio/sdr.py
quancs Jul 6, 2023
95c7fc7
Update src/torchmetrics/audio/sdr.py
quancs Jul 6, 2023
5e1bc2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2023
acbe5a2
fix
quancs Jul 6, 2023
73bc564
Merge branch 'master' into sa-sdr
Borda Jul 6, 2023
512cdee
fix
quancs Jul 6, 2023
a5e28b6
fix
quancs Jul 6, 2023
d5a12f5
fix unrelated error
SkafteNicki Jul 9, 2023
1b204de
Merge branch 'master' into sa-sdr
SkafteNicki Jul 10, 2023
d4f47c6
Update src/torchmetrics/audio/sdr.py
quancs Jul 11, 2023
d41287e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2023
5650c8f
add unit
quancs Jul 11, 2023
7ed428c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2023
60785f9
NUM_SAMPLES
Borda Jul 11, 2023
5eecebc
Merge branch 'master' into sa-sdr
mergify[bot] Jul 11, 2023
a9c3b9f
Merge branch 'master' into sa-sdr
SkafteNicki Jul 11, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added source aggregated signal-to-distortion ratio (SA-SDR) metric ([#1882](https://github.com/Lightning-AI/torchmetrics/pull/1882))


### Changed
23 changes: 23 additions & 0 deletions docs/source/audio/source_aggregated_signal_distortion_ratio.rst
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Source Aggregated Signal-to-Distortion Ratio (SA-SDR)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

#####################################################
Source Aggregated Signal-to-Distortion Ratio (SA-SDR)
#####################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.sdr.SourceAggregatedSignalDistortionRatio
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.sdr.source_aggregated_signal_distortion_ratio
:noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -117,6 +117,7 @@
.. _sdr ref2: https://arxiv.org/abs/2110.06440
.. _Scale-invariant signal-to-distortion ratio: https://arxiv.org/abs/1811.02508
.. _Scale-invariant signal-to-noise ratio: https://arxiv.org/abs/1711.00541
.. _Source-aggregated signal-to-distortion ratio: https://arxiv.org/abs/2110.15581
.. _Complex scale-invariant signal-to-noise ratio: https://arxiv.org/abs/2011.09162
.. _Signal-to-noise ratio: https://arxiv.org/abs/1811.02508
.. _Speech-to-Reverberation Modulation Energy Ratio: https://ieeexplore.ieee.org/document/5547575
7 changes: 6 additions & 1 deletion src/torchmetrics/audio/__init__.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.audio.pit import PermutationInvariantTraining
- from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio
+ from torchmetrics.audio.sdr import (
+ ScaleInvariantSignalDistortionRatio,
+ SignalDistortionRatio,
+ SourceAggregatedSignalDistortionRatio,
+ )
from torchmetrics.audio.snr import (
ComplexScaleInvariantSignalNoiseRatio,
ScaleInvariantSignalNoiseRatio,
@@ -30,6 +34,7 @@
"PermutationInvariantTraining",
"ScaleInvariantSignalDistortionRatio",
"SignalDistortionRatio",
"SourceAggregatedSignalDistortionRatio",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"ComplexScaleInvariantSignalNoiseRatio",
131 changes: 129 additions & 2 deletions src/torchmetrics/audio/sdr.py
@@ -15,15 +15,23 @@

from torch import Tensor, tensor

- from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
+ from torchmetrics.functional.audio.sdr import (
+ scale_invariant_signal_distortion_ratio,
+ signal_distortion_ratio,
+ source_aggregated_signal_distortion_ratio,
+ )
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

__doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]}

if not _MATPLOTLIB_AVAILABLE:
- __doctest_skip__ = ["SignalDistortionRatio.plot", "ScaleInvariantSignalDistortionRatio.plot"]
+ __doctest_skip__ = [
+ "SignalDistortionRatio.plot",
+ "ScaleInvariantSignalDistortionRatio.plot",
+ "SourceAggregatedSignalDistortionRatio.plot",
+ ]


class SignalDistortionRatio(Metric):
@@ -265,3 +273,122 @@ def plot(
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)


class SourceAggregatedSignalDistortionRatio(Metric):
r"""`Source-aggregated signal-to-distortion ratio`_ (SA-SDR).

SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where
one-speaker and multiple-speaker scenes coexist.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(..., spk, time)``
- ``target`` (:class:`~torch.Tensor`): float tensor with shape ``(..., spk, time)``

As output of ``forward`` and ``compute`` the metric returns the following output:

- ``sa_sdr`` (:class:`~torch.Tensor`): float scalar tensor with average SA-SDR value over samples

Args:
preds: float tensor with shape ``(..., spk, time)``
target: float tensor with shape ``(..., spk, time)``
scale_invariant: if ``True``, scale the targets of all speakers with the same alpha
zero_mean: if ``True``, subtract the time-axis mean from ``target`` and ``preds`` before computing the metric
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000) # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> sasdr = SourceAggregatedSignalDistortionRatio()
>>> sasdr(preds, target)
tensor(-41.6579)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(source_aggregated_signal_distortion_ratio,
... mode="permutation-wise", eval_func="max")
>>> pit(preds, target)
tensor(-41.2790)
"""

msum: Tensor
mnum: Tensor
full_state_update: bool = False
is_differentiable: bool = True
higher_is_better: bool = True
plot_lower_bound: Optional[float] = None
plot_upper_bound: Optional[float] = None

def __init__(
self,
scale_invariant: bool = True,
zero_mean: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

if not isinstance(scale_invariant, bool):
raise ValueError(f"Expected argument `scale_invariant` to be a bool, but got {scale_invariant}")
self.scale_invariant = scale_invariant
if not isinstance(zero_mean, bool):
raise ValueError(f"Expected argument `zero_mean` to be a bool, but got {zero_mean}")
self.zero_mean = zero_mean

self.add_state("msum", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("mnum", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
mbatch = source_aggregated_signal_distortion_ratio(preds, target, self.scale_invariant, self.zero_mean)

self.msum += mbatch.sum()
self.mnum += mbatch.numel()

def compute(self) -> Tensor:
"""Compute metric."""
return self.msum / self.mnum

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, the plot is added to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> metric.update(torch.rand(2,8000), torch.rand(2,8000))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SourceAggregatedSignalDistortionRatio
>>> metric = SourceAggregatedSignalDistortionRatio()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(2,8000), torch.rand(2,8000)))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
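
Review note on the `msum` / `mnum` states above: because `update` adds `mbatch.sum()` and `mbatch.numel()`, `compute` returns the mean over all samples seen, not the mean of per-batch means. The following is a minimal dependency-free sketch of that accumulation pattern. The class name `RunningMean` is hypothetical and is not part of torchmetrics; it only mirrors the two states with plain Python numbers.

```python
from typing import Iterable


class RunningMean:
    """Hypothetical stand-in for the msum/mnum states of the metric."""

    def __init__(self) -> None:
        self.msum = 0.0  # running sum of per-sample metric values
        self.mnum = 0  # running count of samples

    def update(self, batch_values: Iterable[float]) -> None:
        # accumulate the sum and the count separately, as update() does
        values = list(batch_values)
        self.msum += sum(values)
        self.mnum += len(values)

    def compute(self) -> float:
        # sample-weighted mean over everything seen so far
        return self.msum / self.mnum


running = RunningMean()
running.update([-40.0, -42.0])  # batch of two samples
running.update([-44.0])  # batch of one sample
mean_over_samples = running.compute()
```

With batches of unequal size the two averages differ: here the mean over samples is -42.0, while averaging the two batch means (-41.0 and -44.0) would give -42.5. Keeping the sum and the count as separate "sum"-reduced states is also what lets the result stay a sample-weighted mean when states are combined across processes.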
10 changes: 6 additions & 4 deletions src/torchmetrics/classification/group_fairness.py
@@ -291,23 +291,25 @@ def plot(
.. plot::
:scale: 75

- >>> from torch import rand, randint
+ >>> import torch
+ >>> _ = torch.manual_seed(42)
>>> # Example plotting a single value
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
- >>> metric.update(rand(20), randint(2,(20,)), randint(2,(20,)))
+ >>> metric.update(torch.rand(20), torch.randint(2,(20,)), torch.randint(2,(20,)))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

- >>> from torch import rand, randint, ones
+ >>> import torch
+ >>> _ = torch.manual_seed(42)
>>> # Example plotting multiple values
>>> from torchmetrics.classification import BinaryFairness
>>> metric = BinaryFairness(2)
>>> values = [ ]
>>> for _ in range(10):
- ... values.append(metric(rand(20), randint(2,(20,)), ones(20).long()))
+ ... values.append(metric(torch.rand(20), torch.randint(2,(20,)), torch.ones(20).long()))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)
7 changes: 6 additions & 1 deletion src/torchmetrics/functional/audio/__init__.py
@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
- from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
+ from torchmetrics.functional.audio.sdr import (
+ scale_invariant_signal_distortion_ratio,
+ signal_distortion_ratio,
+ source_aggregated_signal_distortion_ratio,
+ )
from torchmetrics.functional.audio.snr import (
complex_scale_invariant_signal_noise_ratio,
scale_invariant_signal_noise_ratio,
@@ -30,6 +34,7 @@
"permutation_invariant_training",
"pit_permutate",
"scale_invariant_signal_distortion_ratio",
"source_aggregated_signal_distortion_ratio",
"signal_distortion_ratio",
"scale_invariant_signal_noise_ratio",
"signal_noise_ratio",
65 changes: 65 additions & 0 deletions src/torchmetrics/functional/audio/sdr.py
@@ -242,3 +242,68 @@ def scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_

val = (torch.sum(target_scaled**2, dim=-1) + eps) / (torch.sum(noise**2, dim=-1) + eps)
return 10 * torch.log10(val)


def source_aggregated_signal_distortion_ratio(
preds: Tensor,
target: Tensor,
scale_invariant: bool = True,
zero_mean: bool = False,
) -> Tensor:
"""`Source-aggregated signal-to-distortion ratio`_ (SA-SDR).

SA-SDR was proposed to provide a stable gradient for meeting-style source separation, where
one-speaker and multiple-speaker scenes coexist.

Args:
preds: float tensor with shape ``(..., spk, time)``
target: float tensor with shape ``(..., spk, time)``
scale_invariant: if ``True``, scale the targets of all speakers with the same alpha
zero_mean: if ``True``, subtract the time-axis mean from ``target`` and ``preds`` before computing the metric

Returns:
SA-SDR with shape ``(...)``

Example:
>>> import torch
>>> from torchmetrics.functional.audio import source_aggregated_signal_distortion_ratio
>>> g = torch.manual_seed(1)
>>> preds = torch.randn(2, 8000) # [..., spk, time]
>>> target = torch.randn(2, 8000)
>>> source_aggregated_signal_distortion_ratio(preds, target)
tensor(-41.6579)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = torch.randn(4, 2, 8000) # [batch, spk, time]
>>> target = torch.randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target,
... source_aggregated_signal_distortion_ratio, mode="permutation-wise")
>>> best_metric
tensor([-37.9511, -41.9124, -42.7369, -42.5155])
>>> best_perm
tensor([[1, 0],
[1, 0],
[0, 1],
[1, 0]])
"""
_check_same_shape(preds, target)
if preds.ndim < 2:
raise RuntimeError(f"The preds and target should have the shape (..., spk, time), but {preds.shape} found")

eps = torch.finfo(preds.dtype).eps

if zero_mean:
target = target - torch.mean(target, dim=-1, keepdim=True)
preds = preds - torch.mean(preds, dim=-1, keepdim=True)

if scale_invariant:
# scale the targets of different speakers with the same alpha (shape [..., 1, 1])
alpha = ((preds * target).sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True) + eps) / (
(target**2).sum(dim=-1, keepdim=True).sum(dim=-2, keepdim=True) + eps
)
target = alpha * target

distortion = target - preds

val = ((target**2).sum(dim=-1).sum(dim=-1) + eps) / ((distortion**2).sum(dim=-1).sum(dim=-1) + eps)
return 10 * torch.log10(val)
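
Review note: to make the aggregation in the function above easier to check by hand, here is a minimal pure-Python sketch of the same steps for a single example of shape `(spk, time)` given as nested lists. The helper name `sa_sdr_reference` and the fixed `eps` default are assumptions made for illustration; the PR itself uses `torch.finfo(preds.dtype).eps` and operates on tensors with arbitrary leading dimensions.

```python
import math
from typing import List


def sa_sdr_reference(
    preds: List[List[float]],
    target: List[List[float]],
    scale_invariant: bool = True,
    zero_mean: bool = False,
    eps: float = 1e-8,  # assumption: stands in for torch.finfo(dtype).eps
) -> float:
    """Hypothetical reference for one (spk, time) example, not the library API."""
    if zero_mean:
        # remove each speaker's mean along the time axis
        preds = [[x - sum(row) / len(row) for x in row] for row in preds]
        target = [[x - sum(row) / len(row) for x in row] for row in target]

    if scale_invariant:
        # a single alpha shared by all speakers: both sums run over spk and time
        dot = sum(p * t for pr, tr in zip(preds, target) for p, t in zip(pr, tr))
        energy = sum(t * t for tr in target for t in tr)
        alpha = (dot + eps) / (energy + eps)
        target = [[alpha * t for t in tr] for tr in target]

    # aggregate signal and distortion energy over all speakers before taking the ratio
    signal = sum(t * t for tr in target for t in tr)
    distortion = sum((t - p) ** 2 for pr, tr in zip(preds, target) for p, t in zip(pr, tr))
    return 10 * math.log10((signal + eps) / (distortion + eps))


example_preds = [[1.0, 0.0], [0.0, 1.0]]
example_target = [[1.0, 0.0], [0.0, 2.0]]
plain = sa_sdr_reference(example_preds, example_target, scale_invariant=False)
scaled = sa_sdr_reference(example_preds, example_target, scale_invariant=True)
```

For this toy input the unscaled energy ratio is 5 / 1, and with the shared alpha of 0.6 it becomes 1.8 / 0.2 = 9, so `plain` and `scaled` are approximately `10 * log10(5)` and `10 * log10(9)`. The key difference from a per-speaker SDR is that the energies are summed over speakers before the ratio is formed, so a silent target speaker does not produce its own ill-defined ratio.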