From 76c502b0d6cbdabb6262b1600b5f8c06bfac8ba5 Mon Sep 17 00:00:00 2001
From: philgzl
Date: Wed, 30 Oct 2024 21:36:48 +0100
Subject: [PATCH] Add NISQA metric (#2792)

Co-authored-by: Nicki Skafte Detlefsen
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 CHANGELOG.md                                  |   3 +
 ...on_intrusive_speech_quality_assessment.rst |  21 +
 docs/source/links.rst                         |   1 +
 src/torchmetrics/audio/__init__.py            |   6 +
 src/torchmetrics/audio/nisqa.py               | 152 +++++++
 src/torchmetrics/functional/audio/__init__.py |   6 +
 src/torchmetrics/functional/audio/nisqa.py    | 396 ++++++++++++++++++
 tests/unittests/audio/test_nisqa.py           | 202 +++++++++
 8 files changed, 787 insertions(+)
 create mode 100644 docs/source/audio/non_intrusive_speech_quality_assessment.rst
 create mode 100644 src/torchmetrics/audio/nisqa.py
 create mode 100644 src/torchmetrics/functional/audio/nisqa.py
 create mode 100644 tests/unittests/audio/test_nisqa.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5add85f280f..196076fdbe7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))
 
+- Added a new audio metric `NISQA` ([#2792](https://github.com/Lightning-AI/torchmetrics/pull/2792))
+
+
 - Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
 
diff --git a/docs/source/audio/non_intrusive_speech_quality_assessment.rst b/docs/source/audio/non_intrusive_speech_quality_assessment.rst
new file mode 100644
index 00000000000..0e4d3b8bf2f
--- /dev/null
+++ b/docs/source/audio/non_intrusive_speech_quality_assessment.rst
@@ -0,0 +1,21 @@
.. customcarditem::
   :header: Non-Intrusive Speech Quality Assessment (NISQA v2.0)
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
   :tags: Audio

.. include:: ../links.rst

####################################################
Non-Intrusive Speech Quality Assessment (NISQA v2.0)
####################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.nisqa.NonIntrusiveSpeechQualityAssessment
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.nisqa.non_intrusive_speech_quality_assessment
diff --git a/docs/source/links.rst b/docs/source/links.rst
index b7a4f63565e..ed01989b3cc 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -111,6 +111,7 @@
 .. _Perceptual Evaluation of Speech Quality: https://en.wikipedia.org/wiki/Perceptual_Evaluation_of_Speech_Quality
 .. _pesq package: https://github.com/ludlows/python-pesq
 .. _Deep Noise Suppression performance evaluation based on Mean Opinion Score: https://arxiv.org/abs/2010.15258
+.. _Non-Intrusive Speech Quality Assessment: https://arxiv.org/abs/2104.09494
 .. _Cees Taal's website: http://www.ceestaal.nl/code/
 .. _pystoi package: https://github.com/mpariente/pystoi
 .. _stoi ref1: https://ieeexplore.ieee.org/abstract/document/5495701
diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py
index 14b987a7113..24ff9e737e8 100644
--- a/src/torchmetrics/audio/__init__.py
+++ b/src/torchmetrics/audio/__init__.py
@@ -28,6 +28,7 @@
     _ONNXRUNTIME_AVAILABLE,
     _PESQ_AVAILABLE,
     _PYSTOI_AVAILABLE,
+    _REQUESTS_AVAILABLE,
     _SCIPI_AVAILABLE,
     _TORCHAUDIO_AVAILABLE,
 )
@@ -68,3 +69,8 @@
     from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore
 
     __all__ += ["DeepNoiseSuppressionMeanOpinionScore"]
+
+if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
+    from torchmetrics.audio.nisqa import NonIntrusiveSpeechQualityAssessment
+
+    __all__ += ["NonIntrusiveSpeechQualityAssessment"]
diff --git a/src/torchmetrics/audio/nisqa.py b/src/torchmetrics/audio/nisqa.py
new file mode 100644
index 00000000000..d722b192bb5
--- /dev/null
+++ b/src/torchmetrics/audio/nisqa.py
@@ -0,0 +1,152 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor

from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
    _LIBROSA_AVAILABLE,
    _MATPLOTLIB_AVAILABLE,
    _REQUESTS_AVAILABLE,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

__doctest_requires__ = {"NonIntrusiveSpeechQualityAssessment": ["librosa", "requests"]}

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["NonIntrusiveSpeechQualityAssessment.plot"]


class NonIntrusiveSpeechQualityAssessment(Metric):
    """`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2].

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``nisqa`` (:class:`~torch.Tensor`): float tensor reduced across the batch with shape ``(5,)`` corresponding to
      overall MOS, noisiness, discontinuity, coloration and loudness in that order

    .. note:: Using this metric requires you to have ``librosa`` and ``requests`` installed. Install as
        ``pip install librosa requests``.

    .. note:: The ``forward`` and ``compute`` methods in this class return values reduced across the batch. To obtain
        values for each sample, you may use the functional counterpart
        :func:`~torchmetrics.functional.audio.nisqa.non_intrusive_speech_quality_assessment`.
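    As a minimal sketch of the per-sample usage (the batch shape below is purely illustrative):

    .. code-block:: python

        import torch
        from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment

        preds = torch.randn(8, 16000)  # batch of 8 one-second waveforms at 16 kHz
        per_sample = non_intrusive_speech_quality_assessment(preds, 16000)  # shape (8, 5)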
+ + Args: + fs: sampling frequency of input + + Raises: + ModuleNotFoundError: + If ``librosa`` or ``requests`` are not installed + + Example: + >>> import torch + >>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment + >>> _ = torch.manual_seed(42) + >>> preds = torch.randn(16000) + >>> nisqa = NonIntrusiveSpeechQualityAssessment(16000) + >>> nisqa(preds) + tensor([1.0433, 1.9545, 2.6087, 1.3460, 1.7117]) + + References: + - [1] G. Mittag and S. Möller, "Non-intrusive speech quality assessment for super-wideband speech communication + networks", in Proc. ICASSP, 2019. + - [2] G. Mittag, B. Naderi, A. Chehadi and S. Möller, "NISQA: A deep CNN-self-attention model for + multidimensional speech quality prediction with crowdsourced datasets", in Proc. INTERSPEECH, 2021. + + """ + + sum_nisqa: Tensor + total: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = True + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 5.0 + + def __init__(self, fs: int, **kwargs: Any) -> None: + super().__init__(**kwargs) + if not _LIBROSA_AVAILABLE or not _REQUESTS_AVAILABLE: + raise ModuleNotFoundError( + "NISQA metric requires that librosa and requests are installed. " + "Install as `pip install librosa requests`." + ) + if not isinstance(fs, int) or fs <= 0: + raise ValueError(f"Argument `fs` expected to be a positive integer, but got {fs}") + self.fs = fs + + self.add_state("sum_nisqa", default=tensor([0.0, 0.0, 0.0, 0.0, 0.0]), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor) -> None: + """Update state with predictions.""" + nisqa_batch = non_intrusive_speech_quality_assessment( + preds, + self.fs, + ).to(self.sum_nisqa.device) + + nisqa_batch = nisqa_batch.reshape(-1, 5) + self.sum_nisqa += nisqa_batch.sum(dim=0) + self.total += nisqa_batch.shape[0] + + def compute(self) -> Tensor: + """Compute metric.""" + return self.sum_nisqa / self.total + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling ``metric.forward`` or ``metric.compute`` or a list of these + results. If no value is provided, will automatically call ``metric.compute`` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If ``matplotlib`` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment + >>> metric = NonIntrusiveSpeechQualityAssessment(16000) + >>> metric.update(torch.randn(16000)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment + >>> metric = NonIntrusiveSpeechQualityAssessment(16000) + >>> values = [] + >>> for _ in range(10): + ... 
values.append(metric(torch.randn(16000)))
        >>> fig_, ax_ = metric.plot(values)

        """
        return self._plot(val, ax)
diff --git a/src/torchmetrics/functional/audio/__init__.py b/src/torchmetrics/functional/audio/__init__.py
index ac228cf671a..edd9e9283e0 100644
--- a/src/torchmetrics/functional/audio/__init__.py
+++ b/src/torchmetrics/functional/audio/__init__.py
@@ -28,6 +28,7 @@
     _ONNXRUNTIME_AVAILABLE,
     _PESQ_AVAILABLE,
     _PYSTOI_AVAILABLE,
+    _REQUESTS_AVAILABLE,
     _SCIPI_AVAILABLE,
     _TORCHAUDIO_AVAILABLE,
 )
@@ -69,3 +70,8 @@
     from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score
 
     __all__ += ["deep_noise_suppression_mean_opinion_score"]
+
+if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
+    from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
+
+    __all__ += ["non_intrusive_speech_quality_assessment"]
diff --git a/src/torchmetrics/functional/audio/nisqa.py b/src/torchmetrics/functional/audio/nisqa.py
new file mode 100644
index 00000000000..6696f93965d
--- /dev/null
+++ b/src/torchmetrics/functional/audio/nisqa.py
@@ -0,0 +1,396 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Code related to the main NISQA model definition is under the following copyright

# Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
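
# The code below mirrors the NISQA v2.0 inference pipeline: the waveform is
# turned into a mel spectrogram (librosa), segmented into overlapping windows,
# encoded framewise by a CNN (_AdaptCNN), aggregated over time with
# self-attention (_SelfAttention) and reduced by five attention-pooling heads
# (_PoolAttFF), one per predicted dimension: overall MOS, noisiness,
# discontinuity, coloration and loudness.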

import copy
import math
import os
import warnings
from functools import lru_cache
from typing import Any, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import adaptive_max_pool2d, relu, softmax
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from torchmetrics.utilities import rank_zero_info
from torchmetrics.utilities.imports import _LIBROSA_AVAILABLE, _REQUESTS_AVAILABLE

if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
    import librosa
    import requests
else:
    librosa, requests = None, None  # type:ignore

__doctest_requires__ = {("non_intrusive_speech_quality_assessment",): ["librosa", "requests"]}

NISQA_DIR = "~/.torchmetrics/NISQA"


def non_intrusive_speech_quality_assessment(preds: Tensor, fs: int) -> Tensor:
    """`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2].

    .. note:: Using this metric requires you to have ``librosa`` and ``requests`` installed. Install as
        ``pip install librosa requests``.

    Args:
        preds: float tensor with shape ``(...,time)``
        fs: sampling frequency of input

    Returns:
        Float tensor with shape ``(...,5)`` corresponding to overall MOS, noisiness, discontinuity, coloration and
        loudness in that order

    Raises:
        ModuleNotFoundError:
            If ``librosa`` or ``requests`` are not installed
        ValueError:
            If ``fs`` is not a positive integer
        RuntimeError:
            If the input is too short, causing the number of mel spectrogram windows to be zero
        RuntimeError:
            If the input is too long, causing the number of mel spectrogram windows to exceed the maximum allowed

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
        >>> _ = torch.manual_seed(42)
        >>> preds = torch.randn(16000)
        >>> non_intrusive_speech_quality_assessment(preds, 16000)
        tensor([1.0433, 1.9545, 2.6087, 1.3460, 1.7117])

    References:
        - [1] G. Mittag and S. Möller, "Non-intrusive speech quality assessment for super-wideband speech
          communication networks", in Proc. ICASSP, 2019.
        - [2] G. Mittag, B. Naderi, A. Chehadi and S. Möller, "NISQA: A deep CNN-self-attention model for
          multidimensional speech quality prediction with crowdsourced datasets", in Proc. INTERSPEECH, 2021.

    """
    if not _LIBROSA_AVAILABLE or not _REQUESTS_AVAILABLE:
        raise ModuleNotFoundError(
            "NISQA metric requires that librosa and requests are installed. Install as `pip install librosa requests`."
        )
    if not isinstance(fs, int) or fs <= 0:
        raise ValueError(f"Argument `fs` expected to be a positive integer, but got {fs}")
    model, args = _load_nisqa_model()
    model.eval()
    x = preds.reshape(-1, preds.shape[-1])
    x = _get_librosa_melspec(x.cpu().numpy(), fs, args)
    x, n_wins = _segment_specs(torch.from_numpy(x), args)
    with torch.no_grad():
        x = model(x, n_wins.expand(x.shape[0]))
    # ["mos_pred", "noi_pred", "dis_pred", "col_pred", "loud_pred"]
    # the dimensions are always listed in the papers as MOS, noisiness, coloration, discontinuity and loudness,
    # but based on the original code the actual model output order is MOS, noisiness, discontinuity, coloration,
    # loudness
    return x.reshape(preds.shape[:-1] + (5,))


@lru_cache
def _load_nisqa_model() -> Tuple[nn.Module, Dict[str, Any]]:
    """Load NISQA model and its parameters.
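
    The weights are downloaded from the original `NISQA repository <https://github.com/gabrielmittag/NISQA>`_ to
    ``~/.torchmetrics/NISQA`` on the first call and reused from there afterwards.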
+ + Returns: + Tuple ``(model,args)`` where ``model`` is the NISQA model and ``args`` is a dictionary with all its parameters + + """ + model_path = os.path.expanduser(os.path.join(NISQA_DIR, "nisqa.tar")) + if not os.path.exists(model_path): + _download_weights() + checkpoint = torch.load(model_path, map_location="cpu", weights_only=True) + args = checkpoint["args"] + model = _NISQADIM(args) + model.load_state_dict(checkpoint["model_state_dict"], strict=True) + return model, args + + +def _download_weights() -> None: + """Download NISQA model weights.""" + url = "https://github.com/gabrielmittag/NISQA/raw/refs/heads/master/weights/nisqa.tar" + nisqa_dir = os.path.expanduser(NISQA_DIR) + os.makedirs(nisqa_dir, exist_ok=True) + saveto = os.path.join(nisqa_dir, "nisqa.tar") + if os.path.exists(saveto): + return + rank_zero_info(f"downloading {url} to {saveto}") + myfile = requests.get(url) + with open(saveto, "wb") as f: + f.write(myfile.content) + + +class _NISQADIM(nn.Module): + # main NISQA model definition + # ported from https://github.com/gabrielmittag/NISQA + # Copyright (c) 2021 Gabriel Mittag, Quality and Usability Lab + # MIT License + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.cnn = _Framewise(args) + self.time_dependency = _TimeDependency(args) + pool = _Pooling(args) + self.pool_layers = _get_clones(pool, 5) + + def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: + x = self.cnn(x, n_wins) + x, n_wins = self.time_dependency(x, n_wins) + out = [mod(x, n_wins) for mod in self.pool_layers] + return torch.cat(out, dim=1) + + +class _Framewise(nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.model = _AdaptCNN(args) + + def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: + x_packed = pack_padded_sequence(x, n_wins, batch_first=True, enforce_sorted=False) + x = self.model(x_packed.data.unsqueeze(1)) + x = x_packed._replace(data=x) + x, _ = pad_packed_sequence(x, batch_first=True, padding_value=0.0, total_length=int(n_wins.max())) + return x + + +class _AdaptCNN(nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.pool_1 = args["cnn_pool_1"] + self.pool_2 = args["cnn_pool_2"] + self.pool_3 = args["cnn_pool_3"] + self.dropout = nn.Dropout2d(p=args["cnn_dropout"]) + cnn_pad = (1, 0) if args["cnn_kernel_size"][0] == 1 else (1, 1) + self.conv1 = nn.Conv2d(1, args["cnn_c_out_1"], args["cnn_kernel_size"], padding=cnn_pad) + self.bn1 = nn.BatchNorm2d(self.conv1.out_channels) + self.conv2 = nn.Conv2d(self.conv1.out_channels, args["cnn_c_out_2"], args["cnn_kernel_size"], padding=cnn_pad) + self.bn2 = nn.BatchNorm2d(self.conv2.out_channels) + self.conv3 = nn.Conv2d(self.conv2.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad) + self.bn3 = nn.BatchNorm2d(self.conv3.out_channels) + self.conv4 = nn.Conv2d(self.conv3.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad) + self.bn4 = nn.BatchNorm2d(self.conv4.out_channels) + self.conv5 = nn.Conv2d(self.conv4.out_channels, args["cnn_c_out_3"], args["cnn_kernel_size"], padding=cnn_pad) + self.bn5 = nn.BatchNorm2d(self.conv5.out_channels) + self.conv6 = nn.Conv2d( + self.conv5.out_channels, + args["cnn_c_out_3"], + (args["cnn_kernel_size"][0], args["cnn_pool_3"][1]), + padding=(1, 0), + ) + self.bn6 = nn.BatchNorm2d(self.conv6.out_channels) + + def forward(self, x: Tensor) -> Tensor: + x = 
relu(self.bn1(self.conv1(x))) + x = adaptive_max_pool2d(x, output_size=(self.pool_1)) + x = relu(self.bn2(self.conv2(x))) + x = adaptive_max_pool2d(x, output_size=(self.pool_2)) + x = self.dropout(x) + x = relu(self.bn3(self.conv3(x))) + x = self.dropout(x) + x = relu(self.bn4(self.conv4(x))) + x = adaptive_max_pool2d(x, output_size=(self.pool_3)) + x = self.dropout(x) + x = relu(self.bn5(self.conv5(x))) + x = self.dropout(x) + x = relu(self.bn6(self.conv6(x))) + return x.view(-1, self.conv6.out_channels * self.pool_3[0]) + + +class _TimeDependency(nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.model = _SelfAttention(args) + + def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: + return self.model(x, n_wins) + + +class _SelfAttention(nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + encoder_layer = _SelfAttentionLayer(args) + self.norm1 = nn.LayerNorm(args["td_sa_d_model"]) + self.linear = nn.Linear(args["cnn_c_out_3"] * args["cnn_pool_3"][0], args["td_sa_d_model"]) + self.layers = _get_clones(encoder_layer, args["td_sa_num_layers"]) + self._reset_parameters() + + def _reset_parameters(self) -> None: + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: + src = self.linear(src) + output = src.transpose(1, 0) + output = self.norm1(output) + for mod in self.layers: + output, n_wins = mod(output, n_wins) + return output.transpose(1, 0), n_wins + + +class _SelfAttentionLayer(nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.self_attn = nn.MultiheadAttention(args["td_sa_d_model"], args["td_sa_nhead"], args["td_sa_dropout"]) + self.linear1 = nn.Linear(args["td_sa_d_model"], args["td_sa_h"]) + self.dropout = nn.Dropout(args["td_sa_dropout"]) + self.linear2 = nn.Linear(args["td_sa_h"], args["td_sa_d_model"]) + self.norm1 = nn.LayerNorm(args["td_sa_d_model"]) + self.norm2 = nn.LayerNorm(args["td_sa_d_model"]) + self.dropout1 = nn.Dropout(args["td_sa_dropout"]) + self.dropout2 = nn.Dropout(args["td_sa_dropout"]) + self.activation = relu + + def forward(self, src: Tensor, n_wins: Tensor) -> Tuple[Tensor, Tensor]: + mask = torch.arange(src.shape[0])[None, :] < n_wins[:, None] + src2 = self.self_attn(src, src, src, key_padding_mask=~mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src, n_wins + + +class _Pooling(nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.model = _PoolAttFF(args) + + def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: + return self.model(x, n_wins) + + +class _PoolAttFF(torch.nn.Module): + # part of NISQA model definition + def __init__(self, args: Dict[str, Any]) -> None: + super().__init__() + self.linear1 = nn.Linear(args["td_sa_d_model"], args["pool_att_h"]) + self.linear2 = nn.Linear(args["pool_att_h"], 1) + self.linear3 = nn.Linear(args["td_sa_d_model"], 1) + self.activation = relu + self.dropout = nn.Dropout(args["pool_att_dropout"]) + + def forward(self, x: Tensor, n_wins: Tensor) -> Tensor: + att = self.linear2(self.dropout(self.activation(self.linear1(x)))) + att = att.transpose(2, 1) + mask = 
torch.arange(att.shape[2])[None, :] < n_wins[:, None]
        att[~mask.unsqueeze(1)] = float("-inf")
        att = softmax(att, dim=2)
        x = torch.bmm(att, x)
        x = x.squeeze(1)
        return self.linear3(x)


def _get_librosa_melspec(y: np.ndarray, sr: int, args: Dict[str, Any]) -> np.ndarray:
    """Compute mel spectrogram from waveform using librosa.

    Args:
        y: waveform with shape ``(batch_size,time)``
        sr: sampling rate
        args: dictionary with all NISQA parameters

    Returns:
        Mel spectrogram with shape ``(batch_size,n_mels,n_frames)``

    """
    hop_length = int(sr * args["ms_hop_length"])
    win_length = int(sr * args["ms_win_length"])
    with warnings.catch_warnings():
        # ignore empty mel filter warning since this is expected when input signal is not fullband
        # see https://github.com/gabrielmittag/NISQA/issues/6#issuecomment-838157571
        warnings.filterwarnings("ignore", message="Empty filters detected in mel frequency basis")
        melspec = librosa.feature.melspectrogram(
            y=y,
            sr=sr,
            S=None,
            n_fft=args["ms_n_fft"],
            hop_length=hop_length,
            win_length=win_length,
            window="hann",
            center=True,
            pad_mode="reflect",
            power=1.0,
            n_mels=args["ms_n_mels"],
            fmin=0.0,
            fmax=args["ms_fmax"],
            htk=False,
            norm="slaney",
        )
    # batch processing of librosa.amplitude_to_db is not equivalent to individual processing due to top_db being
    # relative to the max value, so process individually and then stack
    return np.stack([librosa.amplitude_to_db(m, ref=1.0, amin=1e-4, top_db=80.0) for m in melspec])


def _segment_specs(x: Tensor, args: Dict[str, Any]) -> Tuple[Tensor, Tensor]:
    """Segment mel spectrogram into overlapping windows.

    Args:
        x: mel spectrogram with shape ``(batch_size,n_mels,n_frames)``
        args: dictionary with all NISQA parameters

    Returns:
        Tuple ``(x_padded,n_wins)``, where ``x_padded`` is the segmented mel spectrogram with shape
        ``(batch_size,max_length,n_mels,seg_length)``, whose second dimension is the number of windows padded to
        ``max_length``, and ``n_wins`` is the number of windows as a 0-dimensional tensor

    """
    seg_length = args["ms_seg_length"]
    seg_hop = args["ms_seg_hop_length"]
    max_length = args["ms_max_segments"]
    n_wins = x.shape[2] - (seg_length - 1)
    if n_wins < 1:
        raise RuntimeError("Input signal is too short.")
    idx1 = torch.arange(seg_length)
    idx2 = torch.arange(n_wins)
    idx3 = idx1.unsqueeze(0) + idx2.unsqueeze(1)
    x = x.transpose(2, 1)[:, idx3, :].transpose(3, 2)
    x = x[:, ::seg_hop]
    n_wins = math.ceil(n_wins / seg_hop)
    if max_length < n_wins:
        raise RuntimeError("Maximum number of mel spectrogram windows exceeded. Use shorter audio.")
    x_padded = torch.zeros((x.shape[0], max_length, x.shape[2], x.shape[3]))
    x_padded[:, :n_wins] = x
    return x_padded, torch.tensor(n_wins)


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    """Create ``n`` copies of a module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
diff --git a/tests/unittests/audio/test_nisqa.py b/tests/unittests/audio/test_nisqa.py
new file mode 100644
index 00000000000..a36ad0db1c6
--- /dev/null
+++ b/tests/unittests/audio/test_nisqa.py
@@ -0,0 +1,202 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Any, Dict, Tuple + +import pytest +import torch +from torch import Tensor +from torchmetrics.audio.nisqa import NonIntrusiveSpeechQualityAssessment +from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment + +from unittests._helpers.testers import MetricTester + +# reference values below were calculated using the method described in https://github.com/gabrielmittag/NISQA/blob/master/README.md +inputs = [ + { + "preds": torch.rand(2, 2, 16000, generator=torch.Generator().manual_seed(42)), # uniform noise + "fs": 16000, + "reference": torch.tensor([ + [ + [0.8105150461, 1.8459059000, 2.4780223370, 1.0402423143, 1.5687377453], + [0.8629049063, 1.7767801285, 2.3915612698, 1.0460783243, 1.6212222576], + ], + [ + [0.8608418703, 1.9113740921, 2.5213730335, 1.0900889635, 1.6314117908], + [0.8071692586, 1.7834275961, 2.4235677719, 1.0236976147, 1.5617829561], + ], + ]), + }, + { + "preds": torch.rand(2, 2, 48000, generator=torch.Generator().manual_seed(42)), # uniform noise + "fs": 48000, + "reference": torch.tensor([ + [ + [0.7670641541, 1.1634330750, 2.6056811810, 1.4002652168, 1.5218108892], + [0.7974857688, 1.1845922470, 2.6476621628, 1.4282002449, 1.5324314833], + ], + [ + [0.8114687800, 1.1764185429, 2.6281285286, 1.4396891594, 1.5460423231], + [0.6779640913, 1.1818346977, 2.5106279850, 1.2842310667, 1.4014176130], + ], + ]), + }, + { + "preds": torch.stack([ + torch.stack([ + torch.sin(2 * 3.14159 * 440 / 16000 * torch.arange(16000)), # 1 s 440 Hz tone @ 16 kHz + torch.sin(2 * 3.14159 * 1000 / 16000 * torch.arange(16000)), # 1 s 1000 Hz tone @ 16 kHz + ]), + torch.stack([ + torch.sign(torch.sin(2 * 3.14159 * 200 / 16000 * torch.arange(16000))), # 1 s 200 Hz square @ 16 kHz + (1 + 2 * 200 / 16000 * torch.arange(16000)) % 2 - 1, # 1 s 200 Hz sawtooth @ 16 kHz + ]), + ]), + "fs": 16000, + "reference": torch.tensor([ + [ + [1.1243989468, 2.1237702370, 3.6184809208, 1.2584471703, 1.8518198729], + [1.2761806250, 1.8802671432, 3.3731021881, 1.2554246187, 1.6879540682], + ], + [ + [0.9259074330, 2.7644648552, 3.1585879326, 1.4163932800, 1.5672523975], + [0.8493731022, 2.6398222446, 3.0776870251, 1.1348335743, 1.6034533978], + ], + ]), + }, + { + "preds": torch.stack([ + torch.stack([ + torch.sin(2 * 3.14159 * 440 / 48000 * torch.arange(48000)), # 1 s 440 Hz tone @ 48 kHz + torch.sin(2 * 3.14159 * 1000 / 48000 * torch.arange(48000)), # 1 s 1000 Hz tone @ 48 kHz + ]), + torch.stack([ + torch.sign(torch.sin(2 * 3.14159 * 200 / 48000 * torch.arange(48000))), # 1 s 200 Hz square @ 48 kHz + (1 + 2 * 200 / 48000 * torch.arange(48000)) % 2 - 1, # 1 s 200 Hz sawtooth @ 48 kHz + ]), + ]), + "fs": 48000, + "reference": torch.tensor([ + [ + [1.1263639927, 2.1246092319, 3.6191856861, 1.2572505474, 1.8531025648], + [1.2741736174, 1.8896869421, 3.3755991459, 1.2591584921, 1.6720581055], + ], + [ + [0.8731431961, 1.6447117329, 2.8125579357, 1.6197175980, 1.2627843618], + [1.2543514967, 2.0644433498, 3.1744530201, 1.8767380714, 1.9447042942], + ], + ]), + }, +] + + +def _reference_metric_batch(preds, target, mean): + def 
_reference_metric(preds): + for pred, ref in zip(*[ + [x for i in inputs for x in i[which].reshape(-1, i[which].shape[-1])] for which in ["preds", "reference"] + ]): + if torch.equal(preds, pred): + return ref + raise NotImplementedError + + out = torch.stack([_reference_metric(pred) for pred in preds.reshape(-1, preds.shape[-1])]) + return out.mean(dim=0) if mean else out.reshape(*preds.shape[:-1], 5) + + +def _nisqa_cheat(preds, target, **kwargs: Dict[str, Any]): + # cheat the MetricTester as non_intrusive_speech_quality_assessment does not need a target + return non_intrusive_speech_quality_assessment(preds, **kwargs) + + +class _NISQACheat(NonIntrusiveSpeechQualityAssessment): + # cheat the MetricTester as NonIntrusiveSpeechQualityAssessment does not need a target + def update(self, preds: Tensor, target: Tensor) -> None: + super().update(preds=preds) + + +@pytest.mark.parametrize("preds, fs, reference", [(i["preds"], i["fs"], i["reference"]) for i in inputs]) +class TestNISQA(MetricTester): + """Test class for `NonIntrusiveSpeechQualityAssessment` metric.""" + + atol = 1e-4 + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_nisqa(self, preds: Tensor, reference: Tensor, fs: int, ddp: bool, device=None): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + preds=preds, + target=preds, + metric_class=_NISQACheat, + reference_metric=partial(_reference_metric_batch, mean=True), + metric_args={"fs": fs}, + ) + + def test_nisqa_functional(self, preds: Tensor, reference: Tensor, fs: int, device="cpu"): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=preds, + metric_functional=_nisqa_cheat, + reference_metric=partial(_reference_metric_batch, mean=False), + metric_args={"fs": fs}, + ) + + +@pytest.mark.parametrize("shape", [(3000,), (2, 3000), (1, 2, 3000), (2, 3, 1, 3000)]) +def test_shape(shape: Tuple[int]): + """Test output shape.""" + preds = torch.rand(*shape) + out = non_intrusive_speech_quality_assessment(preds, 16000) + assert out.shape == (*shape[:-1], 5) + metric = NonIntrusiveSpeechQualityAssessment(16000) + out = metric(preds) + assert out.shape == (5,) + + +def test_batched_vs_unbatched(): + """Test batched versus unbatched processing.""" + preds = torch.rand(2, 2, 16000, generator=torch.Generator().manual_seed(42)) + out_batched = non_intrusive_speech_quality_assessment(preds, 16000) + out_unbatched = torch.stack([ + non_intrusive_speech_quality_assessment(x, 16000) for x in preds.reshape(-1, 16000) + ]).reshape(2, 2, 5) + assert torch.allclose(out_batched, out_unbatched) + + +def test_error_on_short_input(): + """Test error on short input.""" + preds = torch.rand(3000) + non_intrusive_speech_quality_assessment(preds, 16000) + with pytest.raises(RuntimeError, match="Input signal is too short."): + non_intrusive_speech_quality_assessment(preds, 48000) + preds = torch.rand(2000) + with pytest.raises(RuntimeError, match="Input signal is too short."): + non_intrusive_speech_quality_assessment(preds, 16000) + with pytest.raises(RuntimeError, match="Input signal is too short."): + non_intrusive_speech_quality_assessment(preds, 48000) + + +def test_error_on_long_input(): + """Test error on long input.""" + preds = torch.rand(834240) + with pytest.raises(RuntimeError, match="Maximum number of mel spectrogram windows exceeded. 
Use shorter audio."): + non_intrusive_speech_quality_assessment(preds, 16000) + non_intrusive_speech_quality_assessment(preds, 48000) + preds = torch.rand(2502720) + with pytest.raises(RuntimeError, match="Maximum number of mel spectrogram windows exceeded. Use shorter audio."): + non_intrusive_speech_quality_assessment(preds, 16000) + with pytest.raises(RuntimeError, match="Maximum number of mel spectrogram windows exceeded. Use shorter audio."): + non_intrusive_speech_quality_assessment(preds, 48000)
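

# A further check one might add here (a sketch, not part of the reference test suite): the functional
# interface validates `fs` and raises a ValueError for non-positive values.
def test_error_on_invalid_fs():
    """Test error on non-positive sampling frequency."""
    preds = torch.rand(16000)
    with pytest.raises(ValueError, match="Argument `fs` expected to be a positive integer"):
        non_intrusive_speech_quality_assessment(preds, -16000)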