|
1 | 1 | import random
|
2 | 2 | import torch
|
3 |
| -from fractions import Fraction |
4 |
| -from typing import Optional |
5 |
| -from torch_pitch_shift import get_fast_shifts, pitch_shift, semitones_to_ratio |
6 |
| -from torchaudio_augmentations.utils import ( |
7 |
| - add_audio_batch_dimension, |
8 |
| - remove_audio_batch_dimension, |
9 |
| - tensor_has_valid_audio_batch_dimension, |
10 |
| -) |
| 3 | +import augment |
11 | 4 |
|
12 | 5 |
|
13 |
class PitchShift:
    """Apply a random pitch shift to audio through an ``augment`` effect chain.

    A shift is drawn uniformly, in whole cents, between ``pitch_shift_min``
    and ``pitch_shift_max`` (given in semitones) each time a clip is
    processed. The result always has the same number of samples as the input.

    NOTE(review): ``augment`` is presumably the sox-backed WavAugment package;
    confirm against the project's requirements.
    """

    def __init__(
        self, n_samples, sample_rate, pitch_shift_min=-7.0, pitch_shift_max=7.0
    ):
        """Store the sampling configuration.

        Args:
            n_samples: Length passed to the effect chain as the target length.
            sample_rate: Sample rate used for both source and target audio.
            pitch_shift_min: Lower bound of the shift, in semitones.
            pitch_shift_max: Upper bound of the shift, in semitones.

        Raises:
            ValueError: If ``pitch_shift_min`` is greater than
                ``pitch_shift_max``.
        """
        # Fail at construction time; otherwise random.randint would raise on
        # every single call to process().
        if pitch_shift_min > pitch_shift_max:
            raise ValueError(
                f"pitch_shift_min ({pitch_shift_min}) must not exceed "
                f"pitch_shift_max ({pitch_shift_max})"
            )
        self.n_samples = n_samples
        self.sample_rate = sample_rate
        # round() rather than int(): int() truncates toward zero, so a value
        # such as 0.29 * 100 == 28.999999999999996 would become 28 cents.
        self.pitch_shift_cents_min = int(round(pitch_shift_min * 100))
        self.pitch_shift_cents_max = int(round(pitch_shift_max * 100))
        self.src_info = {"rate": self.sample_rate}

    @staticmethod
    def _fit_length(y, num_channels, length):
        """Truncate or zero-pad ``y`` along dim 1 to exactly ``length`` samples."""
        current = y.shape[1]
        if current == length:
            return y
        if current > length:
            return y[:, :length]
        # new_zeros keeps the dtype and device of ``y`` so padding never
        # silently casts the processed audio.
        padded = y.new_zeros(num_channels, length)
        padded[:, :current] = y
        return padded

    def process(self, x):
        """Pitch-shift a single clip.

        Args:
            x: Tensor indexed as ``(channels, samples)``.

        Returns:
            The shifted clip with ``x.shape[1]`` samples, or a copy of ``x``
            when the effect chain produced NaN or infinite values.
        """
        n_steps = random.randint(
            self.pitch_shift_cents_min, self.pitch_shift_cents_max
        )
        effect_chain = augment.EffectChain().pitch(n_steps).rate(self.sample_rate)
        num_channels = x.shape[0]
        target_info = {
            "channels": num_channels,
            "length": self.n_samples,
            "rate": self.sample_rate,
        }
        y = effect_chain.apply(x, src_info=self.src_info, target_info=target_info)

        # sox might misbehave sometimes by giving nan/inf if sequences are too
        # short (or silent) and the effect chain includes eg `pitch`
        if not torch.isfinite(y).all():
            return x.clone()

        return self._fit_length(y, num_channels, x.shape[1])

    def __call__(self, audio):
        """Pitch-shift one clip, or each clip of a 3-D batch independently.

        Args:
            audio: A single clip, or a 3-D tensor whose first dimension
                indexes clips.

        Returns:
            A tensor with the same shape as ``audio``. The input tensor is
            left unmodified.
        """
        if audio.ndim != 3:
            return self.process(audio)

        # Write into a fresh buffer instead of overwriting the caller's batch.
        shifted = torch.empty_like(audio)
        for index, clip in enumerate(audio):
            shifted[index] = self.process(clip)
        return shifted
0 commit comments