Skip to content

Commit 633d456

Browse files
committed
Fixed slow pitch transforms
1 parent e7b379b commit 633d456

File tree

2 files changed

+38
-74
lines changed

2 files changed

+38
-74
lines changed

tests/test_augmentations.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -206,26 +206,6 @@ def test_pitch(batch_size, num_channels):
206206
assert t_audio.shape == audio.shape
207207

208208

209-
def test_pitch_shift_fast_ratios():
210-
ps = PitchShift(
211-
n_samples=num_samples,
212-
sample_rate=sample_rate,
213-
pitch_shift_min=-5,
214-
pitch_shift_max=5,
215-
)
216-
assert len(ps.fast_shifts) == 20
217-
218-
219-
def test_pitch_shift_no_fast_ratios():
220-
with pytest.raises(ValueError):
221-
_ = PitchShift(
222-
n_samples=num_samples,
223-
sample_rate=sample_rate,
224-
pitch_shift_min=4,
225-
pitch_shift_max=4,
226-
)
227-
228-
229209
def test_pitch_shift_transform_with_pitch_detection():
230210
"""To check semi-tone values, check: http://www.homepages.ucl.ac.uk/~sslyjjt/speech/semitone.html"""
231211

Lines changed: 38 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,47 @@
11
import random
22
import torch
3-
from fractions import Fraction
4-
from typing import Optional
5-
from torch_pitch_shift import get_fast_shifts, pitch_shift, semitones_to_ratio
6-
from torchaudio_augmentations.utils import (
7-
add_audio_batch_dimension,
8-
remove_audio_batch_dimension,
9-
tensor_has_valid_audio_batch_dimension,
10-
)
3+
import augment
114

125

13-
class PitchShift(torch.nn.Module):
6+
class PitchShift:
147
def __init__(
15-
self,
16-
n_samples,
17-
sample_rate,
18-
pitch_shift_min: int = -7.0,
19-
pitch_shift_max: int = 7.0,
20-
bins_per_octave: Optional[int] = 12,
8+
self, n_samples, sample_rate, pitch_shift_min=-7.0, pitch_shift_max=7.0
219
):
2210
self.n_samples = n_samples
2311
self.sample_rate = sample_rate
24-
self.pitch_shift_min = pitch_shift_min
25-
self.pitch_shift_max = pitch_shift_max
26-
self.bins_per_octave = bins_per_octave
27-
28-
self._fast_shifts = get_fast_shifts(
29-
sample_rate,
30-
lambda x: x >= semitones_to_ratio(self.pitch_shift_min)
31-
and x <= semitones_to_ratio(self.pitch_shift_max)
32-
and x != 1,
33-
)
34-
35-
if len(self._fast_shifts) == 0:
36-
raise ValueError(
37-
f"Could not compute any fast pitch-shift ratios for the given sample rate and pitch shift range: {self.pitch_shift_min} - {self.pitch_shift_max} (semitones)"
38-
)
39-
40-
@property
41-
def fast_shifts(self):
42-
return self._fast_shifts
43-
44-
def draw_sample_uniform_from_fast_shifts(self) -> Fraction:
45-
return random.choice(self.fast_shifts)
46-
47-
def __call__(self, audio: torch.Tensor) -> torch.Tensor:
48-
is_batched = False
49-
if not tensor_has_valid_audio_batch_dimension(audio):
50-
audio = add_audio_batch_dimension(audio)
51-
is_batched = True
52-
53-
fast_shift = self.draw_sample_uniform_from_fast_shifts()
54-
y = pitch_shift(
55-
input=audio,
56-
shift=fast_shift,
57-
sample_rate=self.sample_rate,
58-
bins_per_octave=self.bins_per_octave,
59-
)
60-
61-
if is_batched:
62-
y = remove_audio_batch_dimension(y)
12+
self.pitch_shift_cents_min = int(pitch_shift_min * 100)
13+
self.pitch_shift_cents_max = int(pitch_shift_max * 100)
14+
self.src_info = {"rate": self.sample_rate}
15+
16+
def process(self, x):
17+
n_steps = random.randint(self.pitch_shift_cents_min, self.pitch_shift_cents_max)
18+
effect_chain = augment.EffectChain().pitch(n_steps).rate(self.sample_rate)
19+
num_channels = x.shape[0]
20+
target_info = {
21+
"channels": num_channels,
22+
"length": self.n_samples,
23+
"rate": self.sample_rate,
24+
}
25+
y = effect_chain.apply(x, src_info=self.src_info, target_info=target_info)
26+
27+
# sox can sometimes return nan/inf values if sequences are too short (or silent)
28+
# and the effect chain includes, e.g., `pitch`
29+
if torch.isnan(y).any() or torch.isinf(y).any():
30+
return x.clone()
31+
32+
if y.shape[1] != x.shape[1]:
33+
if y.shape[1] > x.shape[1]:
34+
y = y[:, : x.shape[1]]
35+
else:
36+
y0 = torch.zeros(num_channels, x.shape[1]).to(y.device)
37+
y0[:, : y.shape[1]] = y
38+
y = y0
6339
return y
40+
41+
def __call__(self, audio):
42+
if audio.ndim == 3:
43+
for b in range(audio.shape[0]):
44+
audio[b] = self.process(audio[b])
45+
return audio
46+
else:
47+
return self.process(audio)

0 commit comments

Comments
 (0)