Skip to content

Commit 6ae5931

Browse files
authored
Merge pull request #13 from Spijkervet/pitch_tests
End-to-end PitchShift transform tests
2 parents dad94b1 + cf2cf41 commit 6ae5931

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
# What packages are required for this module to be executed?
2424
REQUIRED = ["numpy", "torch", "torchaudio", "julius", "wavaugment", "torch-pitch-shift"]
25-
TEST_REQUIRED = ["pytest", "black"]
25+
TEST_REQUIRED = ["pytest", "black", "librosa"]
2626

2727
# What packages are optional?
2828
EXTRAS = {

tests/test_augmentations.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import numpy as np
1+
import librosa
22
import torch
3-
import torchaudio
43
import pytest
4+
import numpy as np
55

66
from torchaudio_augmentations import (
77
Compose,
@@ -162,14 +162,45 @@ def test_pitch_shift_fast_ratios():
162162

163163
def test_pitch_shift_no_fast_ratios():
164164
with pytest.raises(ValueError):
165-
ps = PitchShift(
165+
_ = PitchShift(
166166
n_samples=num_samples,
167167
sample_rate=sample_rate,
168168
pitch_shift_min=4,
169169
pitch_shift_max=4,
170170
)
171171

172172

173+
def test_pitch_shift_transform_with_pitch_detection():
174+
"""To check semi-tone values, check: http://www.homepages.ucl.ac.uk/~sslyjjt/speech/semitone.html"""
175+
176+
source_frequency = 440
177+
max_semitone_shift = 4
178+
expected_frequency_shift = 554
179+
180+
num_channels = 1
181+
audio = generate_waveform(
182+
sample_rate, num_samples, num_channels, frequency=source_frequency
183+
)
184+
pitch_shift = PitchShift(
185+
n_samples=num_samples,
186+
sample_rate=sample_rate,
187+
pitch_shift_min=max_semitone_shift,
188+
pitch_shift_max=max_semitone_shift + 1,
189+
)
190+
191+
t_audio = pitch_shift(audio)
192+
librosa_audio = t_audio[0].numpy()
193+
f0_hz, _, _ = librosa.pyin(librosa_audio, fmin=10, fmax=1000)
194+
195+
# remove nan values:
196+
f0_hz = f0_hz[~np.isnan(f0_hz)]
197+
198+
detected_f0_hz = np.max(f0_hz)
199+
200+
# the difference between the detected and expected frequency should be smaller than 20 Hz.
201+
assert abs(detected_f0_hz - expected_frequency_shift) < 20
202+
203+
173204
@pytest.mark.parametrize("num_channels", [1, 2])
174205
def test_reverb(num_channels):
175206
audio = generate_waveform(sample_rate, num_samples, num_channels)

tests/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44

55
def generate_waveform(
6-
sample_rate: int, num_samples: int, num_channels: int
6+
sample_rate: int,
7+
num_samples: int,
8+
num_channels: int,
9+
frequency: int = 440,
710
) -> torch.Tensor:
811

912
# Dividing x length value into three parts: 1/10, 1/2, 4/10.
@@ -18,8 +21,7 @@ def generate_waveform(
1821
sustain = np.ones(sustain_length) * sustain_value
1922
attack_decay_sustain = np.concatenate((attack, decay, sustain))
2023

21-
freq = 440
22-
wavedata = np.sin(2 * np.pi * np.arange(num_samples) * freq / sample_rate)
24+
wavedata = np.sin(2 * np.pi * np.arange(num_samples) * frequency / sample_rate)
2325

2426
wavedata = wavedata * attack_decay_sustain
2527

0 commit comments

Comments
 (0)