Skip to content

Commit

Permalink
Merge pull request csteinmetz1#45 from csteinmetz1/tests
Browse files Browse the repository at this point in the history
Setting simple tests with Actions
  • Loading branch information
csteinmetz1 authored Jan 9, 2023
2 parents 6ae5cef + 70eef0b commit ffeedec
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 131 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: pyloudnorm

on: [push]

jobs:
build:

runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
sudo apt-get install libsndfile1-dev
pip install .[all]
pip install pytest
- name: Test with pytest
run: |
pytest
File renamed without changes.
17 changes: 0 additions & 17 deletions tests/adv_test.py

This file was deleted.

19 changes: 0 additions & 19 deletions tests/extras_test.py

This file was deleted.

9 changes: 0 additions & 9 deletions tests/simple_test.py

This file was deleted.

29 changes: 0 additions & 29 deletions tests/stft_test.py

This file was deleted.

89 changes: 89 additions & 0 deletions tests/test_auraloss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import torch
import auraloss


def test_mrstft():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss = auraloss.freq.MultiResolutionSTFTLoss()
res = loss(pred, target)
assert res is not None


def test_stft():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss = auraloss.freq.STFTLoss()
res = loss(pred, target)
assert res is not None


def test_stft_weights_a():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
# test difference weights
loss = auraloss.freq.STFTLoss(
w_log_mag=1.0,
w_lin_mag=0.0,
w_sc=1.0,
reduction="mean",
)
res = loss(pred, target)
assert res is not None


def test_stft_reduction():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
# test the reduction
loss = auraloss.freq.STFTLoss(
w_log_mag=1.0,
w_lin_mag=1.0,
w_sc=0.0,
reduction="none",
)
res = loss(pred, target)
print(res.shape)
assert len(res.shape) > 1


def test_sum_and_difference():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
loss = auraloss.freq.SumAndDifferenceSTFTLoss()
res = loss(pred, target)
assert res is not None


def test_melstft():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
# test MelSTFT
loss = auraloss.freq.MelSTFTLoss(44100)
res = loss(pred, target)
assert res is not None


def test_melstft_reduction():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
# test MelSTFT with no reduction
loss = auraloss.freq.MelSTFTLoss(44100, reduction="none")
res = loss(pred, target)
assert len(res) > 1


def test_multires_mel():
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
sample_rate = 44100
loss = auraloss.freq.MultiResolutionSTFTLoss(
scale="mel",
n_bins=64,
sample_rate=sample_rate,
)
res = loss(pred, target)
assert res is not None
57 changes: 0 additions & 57 deletions tests/test_nn.py

This file was deleted.

0 comments on commit ffeedec

Please sign in to comment.