diff --git a/tests/test_auraloss.py b/tests/test_auraloss.py index 9377f72..0569755 100644 --- a/tests/test_auraloss.py +++ b/tests/test_auraloss.py @@ -126,7 +126,6 @@ def test_multires_mel(): assert res is not None - def test_perceptual_multires_mel(): target = torch.rand(8, 2, 44100) pred = torch.rand(8, 2, 44100) @@ -143,6 +142,7 @@ def test_perceptual_multires_mel(): res = loss(pred, target) assert res is not None + def test_stft_l2(): N = 32 n = torch.arange(N) @@ -196,4 +196,3 @@ def test_multires_l2(): expected_loss = ((N // 2) ** 2) / (N // 2 + 1) torch.testing.assert_close(res, torch.tensor(expected_loss), rtol=1e-3, atol=1e-3) -