SRMR: update note and change default parameters #1872

Merged: 11 commits, Jul 3, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792))
+- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792), [#1872](https://github.com/Lightning-AI/torchmetrics/pull/1872))


- Added new global arg `compute_with_cache` to control caching behaviour after `compute` method ([#1754](https://github.com/Lightning-AI/torchmetrics/pull/1754))
8 changes: 7 additions & 1 deletion src/torchmetrics/audio/srmr.py
@@ -53,6 +53,12 @@ class SpeechReverberationModulationEnergyRatio(Metric):
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
and ``pip install git+https://github.com/detly/gammatone``.

+.. note::
+This implementation is experimental and might not be consistent with the MATLAB
+implementation `SRMRToolbox`_, especially the fast implementation.
+The slow versions, a) fast=False, norm=False, max_cf=128 and b) fast=False, norm=True, max_cf=30, show
+a relatively small inconsistency.

Args:
fs: the sampling rate
n_cochlear_filters: Number of filters in the acoustic filterbank
@@ -93,7 +99,7 @@ def __init__(
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
-max_cf: Optional[float] = 128,
+max_cf: Optional[float] = None,
norm: bool = False,
fast: bool = False,
**kwargs: Any,
8 changes: 7 additions & 1 deletion src/torchmetrics/functional/audio/srmr.py
@@ -183,7 +183,7 @@ def speech_reverberation_modulation_energy_ratio(
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
-max_cf: Optional[float] = 128,
+max_cf: Optional[float] = None,
norm: bool = False,
fast: bool = False,
) -> Tensor:
@@ -210,6 +210,12 @@ def speech_reverberation_modulation_energy_ratio(
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
and ``pip install git+https://github.com/detly/gammatone``.

+.. note::
+This implementation is experimental and might not be consistent with the MATLAB
+implementation `SRMRToolbox`_, especially the fast implementation.
+The slow versions, a) fast=False, norm=False, max_cf=128 and b) fast=False, norm=True, max_cf=30, show
+a relatively small inconsistency.

Returns:
Tensor: srmr value, shape ``(...)``

2 changes: 1 addition & 1 deletion tests/unittests/audio/test_srmr.py
@@ -39,7 +39,7 @@ def _ref_metric_batch(preds: Tensor, target: Tensor, fs: int, fast: bool, norm:
preds = preds.detach().cpu().numpy()
score = []
for b in range(preds.shape[0]):
-val, _ = srmrpy_srmr(preds[b, ...], fs=fs, fast=fast, norm=norm)
+val, _ = srmrpy_srmr(preds[b, ...], fs=fs, fast=fast, norm=norm, max_cf=128 if not norm else 30)
score.append(val)
score = torch.tensor(score)
return score.reshape(*shape[:-1])
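
For readers following the default change, the sketch below shows one way an unset `max_cf` could be resolved. It assumes `None` maps to the same values the updated test passes to the reference (`128` without normalization, `30` with it). The helper name `resolve_max_cf` and the `slow_configs` list are hypothetical and are not part of this PR or of torchmetrics.

```python
from typing import Optional


def resolve_max_cf(max_cf: Optional[float], norm: bool) -> float:
    """Hypothetical helper: choose the centre frequency of the last modulation filter.

    Assumption: an unset ``max_cf`` follows the values used in the updated test,
    i.e. 128 Hz when ``norm`` is False and 30 Hz when ``norm`` is True.
    """
    if max_cf is not None:
        # An explicit value always wins over the norm-dependent default.
        return float(max_cf)
    return 30.0 if norm else 128.0


# The two "slow" configurations named in the docstring note, written as kwargs.
slow_configs = [
    {"fast": False, "norm": False, "max_cf": resolve_max_cf(None, norm=False)},
    {"fast": False, "norm": True, "max_cf": resolve_max_cf(None, norm=True)},
]
```

Under this assumption, leaving `max_cf` unset lets callers toggle `norm` without also having to change `max_cf`, while an explicit value is still honoured.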