Skip to content

Commit

Permalink
Merge branch 'master' into working_branch
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Apr 17, 2023
2 parents cddf276 + 527afbe commit 63b3acf
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ coverage:
# https://codecov.readme.io/v1.0/docs/commit-status
project:
default:
informational: true
target: 95% # specify the target coverage for each commit status
threshold: 30% # allow this little decrease on project
# https://github.com/codecov/support/wiki/Filtering-Branches
Expand All @@ -28,6 +29,7 @@ coverage:
# https://github.com/codecov/support/wiki/Patch-Status
patch:
default:
informational: true
threshold: 50% # allow this much decrease on patch
changes: false

Expand Down
10 changes: 5 additions & 5 deletions tests/unittests/audio/test_sdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
Input = namedtuple("Input", ["preds", "target"])

inputs_1spk = Input(
preds=torch.rand(4, 1, 1, 1000),
target=torch.rand(4, 1, 1, 1000),
preds=torch.rand(2, 1, 1, 500),
target=torch.rand(2, 1, 1, 500),
)
inputs_2spk = Input(
preds=torch.rand(4, 1, 2, 1000),
target=torch.rand(4, 1, 2, 1000),
preds=torch.rand(2, 1, 2, 500),
target=torch.rand(2, 1, 2, 500),
)


Expand All @@ -52,7 +52,7 @@ def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
for b in range(preds.shape[0]):
sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation)
mss.append(sdr_val_np)
return torch.tensor(mss)
return torch.tensor(np.array(mss))


def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
Expand Down
8 changes: 2 additions & 6 deletions tests/unittests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,16 @@

from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.functional.audio import signal_noise_ratio
from unittests import NUM_BATCHES
from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester

seed_all(42)

TIME = 25

Input = namedtuple("Input", ["preds", "target"])

BATCH_SIZE = 2
inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
preds=torch.rand(2, 1, 1, 25),
target=torch.rand(2, 1, 1, 25),
)


Expand Down

0 comments on commit 63b3acf

Please sign in to comment.