Skip to content

Commit

Permalink
SRMR: update note and change default parameters (#1872)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 3, 2023
1 parent 5cc05c3 commit ba34076
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792))
- Added speech-to-reverberation modulation energy ratio (SRMR) metric ([#1792](https://github.com/Lightning-AI/torchmetrics/pull/1792), [#1872](https://github.com/Lightning-AI/torchmetrics/pull/1872))


- Added new global arg `compute_with_cache` to control caching behaviour after `compute` method ([#1754](https://github.com/Lightning-AI/torchmetrics/pull/1754))
Expand Down
8 changes: 7 additions & 1 deletion src/torchmetrics/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ class SpeechReverberationModulationEnergyRatio(Metric):
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
and ``pip install git+https://github.com/detly/gammatone``.
.. note::
This implementation is experimental, and might not be consistent with the matlab
implementation `SRMRToolbox`_, especially the fast implementation.
The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have
a relatively small inconsistence.
Args:
fs: the sampling rate
n_cochlear_filters: Number of filters in the acoustic filterbank
Expand Down Expand Up @@ -93,7 +99,7 @@ def __init__(
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
max_cf: Optional[float] = 128,
max_cf: Optional[float] = None,
norm: bool = False,
fast: bool = False,
**kwargs: Any,
Expand Down
8 changes: 7 additions & 1 deletion src/torchmetrics/functional/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def speech_reverberation_modulation_energy_ratio(
n_cochlear_filters: int = 23,
low_freq: float = 125,
min_cf: float = 4,
max_cf: Optional[float] = 128,
max_cf: Optional[float] = None,
norm: bool = False,
fast: bool = False,
) -> Tensor:
Expand All @@ -210,6 +210,12 @@ def speech_reverberation_modulation_energy_ratio(
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio``
and ``pip install git+https://github.com/detly/gammatone``.
.. note::
This implementation is experimental, and might not be consistent with the matlab
implementation `SRMRToolbox`_, especially the fast implementation.
The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have
a relatively small inconsistence.
Returns:
Tensor: srmr value, shape ``(...)``
Expand Down
2 changes: 1 addition & 1 deletion tests/unittests/audio/test_srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _ref_metric_batch(preds: Tensor, target: Tensor, fs: int, fast: bool, norm:
preds = preds.detach().cpu().numpy()
score = []
for b in range(preds.shape[0]):
val, _ = srmrpy_srmr(preds[b, ...], fs=fs, fast=fast, norm=norm)
val, _ = srmrpy_srmr(preds[b, ...], fs=fs, fast=fast, norm=norm, max_cf=128 if not norm else 30)
score.append(val)
score = torch.tensor(score)
return score.reshape(*shape[:-1])
Expand Down

0 comments on commit ba34076

Please sign in to comment.