Commit a355b68

Updated pitch shift backend to torch-pitch-shift
1 parent f3e0ad8 commit a355b68

File tree: 6 files changed, +147 -47 lines changed

setup.py (+5 -3)

@@ -18,14 +18,16 @@
 EMAIL = "janne.spijkervet@gmail.com"
 AUTHOR = "Janne Spijkervet"
 REQUIRES_PYTHON = ">=3.6.0"
-VERSION = "0.2.2"
+VERSION = "0.2.3"

 # What packages are required for this module to be executed?
-REQUIRED = ["numpy", "torch", "torchaudio", "julius", "wavaugment"]
+REQUIRED = ["numpy", "torch", "torchaudio", "julius", "wavaugment", "torch-pitch-shift"]
+TEST_REQUIRED = ["pytest"]

 # What packages are optional?
 EXTRAS = {
-    'fancy feature': [''],
+    "fancy feature": [""],
+    "test": TEST_REQUIRED,
 }

 # The rest you shouldn't have to touch too much :)
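With the new test extra, pytest can be installed together with the package, e.g. pip install -e ".[test]" from a checkout. This assumes REQUIRED and EXTRAS are wired into setup() further down in setup.py as the standard setuptools install_requires and extras_require arguments; that call sits outside this hunk.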

tests/test_augmentations.py (+71 -11)

@@ -33,11 +33,37 @@ def test_random_resized_crop(num_channels):
     assert audio.shape[1] == num_samples


+@pytest.mark.parametrize(
+    ["batch_size", "num_channels"],
+    [
+        (1, 1),
+        (4, 1),
+        (16, 1),
+        (1, 2),
+        (4, 2),
+        (16, 2),
+    ],
+)
+def test_random_resized_crop_batched(batch_size, num_channels):
+
+    num_samples = 22050 * 5
+    audio = generate_waveform(sample_rate, num_samples, num_channels)
+    audio = audio.repeat(batch_size, 1, 1)
+
+    transform = Compose([RandomResizedCrop(num_samples)])
+
+    audio = transform(audio)
+    assert audio.shape[0] == batch_size
+    assert audio.shape[1] == num_channels
+    assert audio.shape[2] == num_samples
+
+
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_polarity(num_channels):
-    audio = generate_waveform(sample_rate, num_samples,
-                              num_channels=num_channels)
-    transform = Compose([PolarityInversion()],)
+    audio = generate_waveform(sample_rate, num_samples, num_channels=num_channels)
+    transform = Compose(
+        [PolarityInversion()],
+    )

     t_audio = transform(audio)
     assert (t_audio == torch.neg(audio)).all()

@@ -47,7 +73,9 @@ def test_polarity(num_channels):
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_filter(num_channels):
     audio = generate_waveform(sample_rate, num_samples, num_channels)
-    transform = Compose([HighLowPass(sample_rate=sample_rate)],)
+    transform = Compose(
+        [HighLowPass(sample_rate=sample_rate)],
+    )
     t_audio = transform(audio)
     # torchaudio.save("tests/filter.wav", t_audio, sample_rate=sample_rate)
     assert t_audio.shape == audio.shape

@@ -56,7 +84,9 @@ def test_filter(num_channels):
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_delay(num_channels):
     audio = generate_waveform(sample_rate, num_samples, num_channels)
-    transform = Compose([Delay(sample_rate=sample_rate)],)
+    transform = Compose(
+        [Delay(sample_rate=sample_rate)],
+    )

     t_audio = transform(audio)
     # torchaudio.save("tests/delay.wav", t_audio, sample_rate=sample_rate)

@@ -66,7 +96,9 @@ def test_delay(num_channels):
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_gain(num_channels):
     audio = generate_waveform(sample_rate, num_samples, num_channels)
-    transform = Compose([Gain()],)
+    transform = Compose(
+        [Gain()],
+    )

     t_audio = transform(audio)
     # torchaudio.save("tests/gain.wav", t_audio, sample_rate=sample_rate)

@@ -76,7 +108,9 @@ def test_gain(num_channels):
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_noise(num_channels):
     audio = generate_waveform(sample_rate, num_samples, num_channels)
-    transform = Compose([Noise(min_snr=0.5, max_snr=1)],)
+    transform = Compose(
+        [Noise(min_snr=0.5, max_snr=1)],
+    )

     t_audio = transform(audio)
     # torchaudio.save("tests/noise.wav", t_audio, sample_rate=sample_rate)

@@ -87,17 +121,41 @@ def test_noise(num_channels):
 def test_pitch(num_channels):
     audio = generate_waveform(sample_rate, num_samples, num_channels)
     transform = Compose(
-        [PitchShift(n_samples=num_samples, sample_rate=sample_rate)],)
+        [PitchShift(n_samples=num_samples, sample_rate=sample_rate)],
+    )

     t_audio = transform(audio)
-    # torchaudio.save("tests/pitch.wav", t_audio, sample_rate=sample_rate)
+    # torchaudio.save("tests/pitch.wav", audio, sample_rate=sample_rate)
+    # torchaudio.save("tests/t_pitch.wav", t_audio, sample_rate=sample_rate)
     assert t_audio.shape == audio.shape


+def test_pitch_shift_fast_ratios():
+    ps = PitchShift(
+        n_samples=num_samples,
+        sample_rate=sample_rate,
+        pitch_shift_min=-5,
+        pitch_shift_max=5,
+    )
+    assert len(ps.fast_shifts) == 20
+
+
+def test_pitch_shift_no_fast_ratios():
+    with pytest.raises(ValueError):
+        ps = PitchShift(
+            n_samples=num_samples,
+            sample_rate=sample_rate,
+            pitch_shift_min=4,
+            pitch_shift_max=4,
+        )
+
+
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_reverb(num_channels):
     audio = generate_waveform(sample_rate, num_samples, num_channels)
-    transform = Compose([Reverb(sample_rate=sample_rate)],)
+    transform = Compose(
+        [Reverb(sample_rate=sample_rate)],
+    )

     t_audio = transform(audio)
     # torchaudio.save("tests/reverb.wav", t_audio, sample_rate=sample_rate)

@@ -107,7 +165,9 @@ def test_reverb(num_channels):
 @pytest.mark.parametrize("num_channels", [1, 2])
 def test_reverse(num_channels):
     stereo_audio = generate_waveform(sample_rate, num_samples, num_channels)
-    transform = Compose([Reverse()],)
+    transform = Compose(
+        [Reverse()],
+    )

     t_audio = transform(stereo_audio)
     # torchaudio.save("tests/reverse.wav", t_audio, sample_rate=sample_rate)
torchaudio_augmentations/augmentations/pitch_shift.py (+53 -30)

@@ -1,40 +1,63 @@
 import random
 import torch
-import augment
+from torchaudio_augmentations.utils import (
+    add_audio_batch_dimension,
+    remove_audio_batch_dimension,
+    tensor_has_valid_audio_batch_dimension,
+)
+from fractions import Fraction
+from typing import Optional
+from torch_pitch_shift import get_fast_shifts, pitch_shift, semitones_to_ratio


 class PitchShift:
     def __init__(
-        self, n_samples, sample_rate, pitch_cents_min=-700, pitch_cents_max=700
+        self,
+        n_samples,
+        sample_rate,
+        pitch_shift_min: int = -7.0,
+        pitch_shift_max: int = 7.0,
+        bins_per_octave: Optional[int] = 12,
     ):
         self.n_samples = n_samples
         self.sample_rate = sample_rate
-        self.pitch_cents_min = pitch_cents_min
-        self.pitch_cents_max = pitch_cents_max
-        self.src_info = {"rate": self.sample_rate}
-
-    def __call__(self, audio):
-        n_steps = random.randint(self.pitch_cents_min, self.pitch_cents_max)
-        effect_chain = augment.EffectChain().pitch(n_steps).rate(self.sample_rate)
-
-        num_channels = audio.shape[0]
-        target_info = {
-            "channels": num_channels,
-            "length": self.n_samples,
-            "rate": self.sample_rate,
-        }
-        y = effect_chain.apply(audio, src_info=self.src_info, target_info=target_info)
-
-        # sox might misbehave sometimes by giving nan/inf if sequences are too short (or silent)
-        # and the effect chain includes eg `pitch`
-        if torch.isnan(y).any() or torch.isinf(y).any():
-            return audio.clone()
-
-        if y.shape[1] != audio.shape[1]:
-            if y.shape[1] > audio.shape[1]:
-                y = y[:, : audio.shape[1]]
-            else:
-                y0 = torch.zeros(1, audio.shape[1]).to(y.device)
-                y0[:, : y.shape[1]] = y
-                y = y0
+        self.pitch_shift_min = pitch_shift_min
+        self.pitch_shift_max = pitch_shift_max
+        self.bins_per_octave = bins_per_octave
+
+        self._fast_shifts = get_fast_shifts(
+            sample_rate,
+            lambda x: x >= semitones_to_ratio(self.pitch_shift_min)
+            and x <= semitones_to_ratio(self.pitch_shift_max)
+            and x != 1,
+        )
+
+        if len(self._fast_shifts) == 0:
+            raise ValueError(
+                f"Could not compute any fast pitch-shift ratios for the given sample rate and pitch shift range: {self.pitch_shift_min} - {self.pitch_shift_max} (semitones)"
+            )
+
+    @property
+    def fast_shifts(self):
+        return self._fast_shifts
+
+    def draw_sample_uniform_from_fast_shifts(self) -> Fraction:
+        return random.choice(self.fast_shifts)
+
+    def __call__(self, audio: torch.Tensor) -> torch.Tensor:
+        is_batched = False
+        if not tensor_has_valid_audio_batch_dimension(audio):
+            audio = add_audio_batch_dimension(audio)
+            is_batched = True
+
+        fast_shift = self.draw_sample_uniform_from_fast_shifts()
+        y = pitch_shift(
+            input=audio,
+            shift=fast_shift,
+            sample_rate=self.sample_rate,
+            bins_per_octave=self.bins_per_octave,
+        )
+
+        if is_batched:
+            y = remove_audio_batch_dimension(y)
         return y
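A minimal usage sketch of the rewritten class (not part of the commit; the top-level import path is assumed):

import torch
from torchaudio_augmentations import PitchShift  # import path assumed

sample_rate = 22050
num_samples = sample_rate * 5

transform = PitchShift(
    n_samples=num_samples,
    sample_rate=sample_rate,
    pitch_shift_min=-7,
    pitch_shift_max=7,
)

# Unbatched [channels, samples] input: __call__ adds a batch dimension,
# runs torch_pitch_shift.pitch_shift, then removes it again.
audio = torch.rand(1, num_samples)
assert transform(audio).shape == audio.shape

# Batched [batch, channels, samples] input is processed as-is.
batch = torch.rand(8, 1, num_samples)
assert transform(batch).shape == batch.shape

One quirk when reading __call__: the is_batched flag is set to True when a batch dimension had to be added, i.e. when the input was not batched; the round trip is correct either way.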

torchaudio_augmentations/augmentations/random_resized_crop.py (+2 -2)

@@ -8,7 +8,7 @@ def __init__(self, n_samples):
         self.n_samples = n_samples

     def forward(self, audio):
-        max_samples = audio.shape[1]
+        max_samples = audio.shape[-1]
         start_idx = random.randint(0, max_samples - self.n_samples)
-        audio = audio[:, start_idx : start_idx + self.n_samples]
+        audio = audio[..., start_idx : start_idx + self.n_samples]
         return audio

torchaudio_augmentations/augmentations/reverse.py (+1 -1)

@@ -7,4 +7,4 @@ def __init__(self):
         super().__init__()

     def forward(self, audio):
-        return torch.flip(audio, dims=[1])
+        return torch.flip(audio, dims=[-1])
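The random_resized_crop.py and reverse.py changes are the same idea: index from the last dimension instead of dimension 1, so both ops accept unbatched [channels, samples] and batched [batch, channels, samples] tensors. A quick illustration (not from the repo):

import torch

mono = torch.arange(10.0).reshape(1, 10)  # [channels, samples]
batch = mono.repeat(4, 1, 1)              # [batch, channels, samples]

# audio[..., start:end] crops the time axis regardless of rank:
assert mono[..., 2:7].shape == (1, 5)
assert batch[..., 2:7].shape == (4, 1, 5)

# torch.flip(audio, dims=[-1]) reverses the time axis regardless of rank:
assert torch.flip(mono, dims=[-1])[0, 0] == 9.0
assert torch.flip(batch, dims=[-1])[0, 0, 0] == 9.0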

torchaudio_augmentations/utils.py (+15 -0)

@@ -0,0 +1,15 @@
+import torch
+
+
+def tensor_has_valid_audio_batch_dimension(tensor: torch.Tensor) -> torch.Tensor:
+    if tensor.ndim == 3:
+        return True
+    return False
+
+
+def add_audio_batch_dimension(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.unsqueeze(dim=0)
+
+
+def remove_audio_batch_dimension(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.squeeze(dim=0)
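These helpers are what the new PitchShift.__call__ uses to normalize its input. A round-trip sketch, assuming the three functions are imported from torchaudio_augmentations.utils:

import torch
from torchaudio_augmentations.utils import (
    add_audio_batch_dimension,
    remove_audio_batch_dimension,
    tensor_has_valid_audio_batch_dimension,
)

audio = torch.rand(2, 22050)                      # [channels, samples]: no batch dim
assert not tensor_has_valid_audio_batch_dimension(audio)

batched = add_audio_batch_dimension(audio)        # [1, channels, samples]
assert tensor_has_valid_audio_batch_dimension(batched)

restored = remove_audio_batch_dimension(batched)  # back to [channels, samples]
assert torch.equal(restored, audio)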
