Skip to content

Commit 08ab956

Browse files
authored
Merge pull request #108 from PaddlePaddle/new_augmentation
add 3 augmentor class
2 parents 1c6cefc + 123d1a3 commit 08ab956

File tree

7 files changed

+148
-11
lines changed

7 files changed

+148
-11
lines changed

deep_speech_2/data_utils/audio.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import io
88
import soundfile
9-
import scikits.samplerate
9+
import resampy
1010
from scipy import signal
1111
import random
1212
import copy
@@ -308,7 +308,7 @@ def normalize_online_bayesian(self,
308308
prior_mean_squared = 10.**(prior_db / 10.)
309309
prior_sum_of_squares = prior_mean_squared * prior_samples
310310
cumsum_of_squares = np.cumsum(self.samples**2)
311-
sample_count = np.arange(len(self.num_samples)) + 1
311+
sample_count = np.arange(self.num_samples) + 1
312312
if startup_sample_idx > 0:
313313
cumsum_of_squares[:startup_sample_idx] = \
314314
cumsum_of_squares[startup_sample_idx]
@@ -321,21 +321,19 @@ def normalize_online_bayesian(self,
321321
gain_db = target_db - rms_estimate_db
322322
self.gain_db(gain_db)
323323

324-
def resample(self, target_sample_rate, quality='sinc_medium'):
324+
def resample(self, target_sample_rate, filter='kaiser_best'):
325325
"""Resample the audio to a target sample rate.
326326
327327
Note that this is an in-place transformation.
328328
329329
:param target_sample_rate: Target sample rate.
330330
:type target_sample_rate: int
331-
:param quality: One of {'sinc_fastest', 'sinc_medium', 'sinc_best'}.
332-
Sets resampling speed/quality tradeoff.
333-
See http://www.mega-nerd.com/SRC/api_misc.html#Converters
334-
:type quality: str
331+
:param filter: The resampling filter to use; one of {'kaiser_best',
332+
'kaiser_fast'}.
333+
:type filter: str
335334
"""
336-
resample_ratio = target_sample_rate / self._sample_rate
337-
self._samples = scikits.samplerate.resample(
338-
self._samples, r=resample_ratio, type=quality)
335+
self._samples = resampy.resample(
336+
self.samples, self.sample_rate, target_sample_rate, filter=filter)
339337
self._sample_rate = target_sample_rate
340338

341339
def pad_silence(self, duration, sides='both'):

deep_speech_2/data_utils/augmentor/augmentation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
import random
88
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
99
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
10+
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
11+
from data_utils.augmentor.resample import ResampleAugmentor
12+
from data_utils.augmentor.online_bayesian_normalization import \
13+
OnlineBayesianNormalizationAugmentor
1014

1115

1216
class AugmentationPipeline(object):
@@ -79,5 +83,11 @@ def _get_augmentor(self, augmentor_type, params):
7983
return VolumePerturbAugmentor(self._rng, **params)
8084
elif augmentor_type == "shift":
8185
return ShiftPerturbAugmentor(self._rng, **params)
86+
elif augmentor_type == "speed":
87+
return SpeedPerturbAugmentor(self._rng, **params)
88+
elif augmentor_type == "resample":
89+
return ResampleAugmentor(self._rng, **params)
90+
elif augmentor_type == "bayesian_normal":
91+
return OnlineBayesianNormalizationAugmentor(self._rng, **params)
8292
else:
8393
raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Contain the online bayesian normalization augmentation model."""
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
from data_utils.augmentor.base import AugmentorBase
7+
8+
9+
class OnlineBayesianNormalizationAugmentor(AugmentorBase):
    """Augmentation model that applies online Bayesian normalization.

    Every constructor argument except ``rng`` is forwarded unchanged to the
    ``normalize_online_bayesian`` method of the audio segment being
    transformed.

    :param rng: Random generator object. Stored on the instance but not
                read by this class.
    :type rng: random.Random
    :param target_db: Target RMS value in decibels.
    :type target_db: float
    :param prior_db: Prior RMS estimate in decibels.
    :type prior_db: float
    :param prior_samples: Prior strength in number of samples.
    :type prior_samples: int
    :param startup_delay: Default 0.0s. If provided, statistics are accrued
                          for the first startup_delay seconds before online
                          normalization is applied.
    :type startup_delay: float
    """

    def __init__(self,
                 rng,
                 target_db,
                 prior_db,
                 prior_samples,
                 startup_delay=0.0):
        self._rng = rng
        self._target_db = target_db
        self._prior_db = prior_db
        self._prior_samples = prior_samples
        self._startup_delay = startup_delay

    def transform_audio(self, audio_segment):
        """Normalize the given audio with the online Bayesian approach.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        # Positional order must match the call made by the segment API:
        # target_db, prior_db, prior_samples, startup_delay.
        normalization_args = (self._target_db, self._prior_db,
                              self._prior_samples, self._startup_delay)
        audio_segment.normalize_online_bayesian(*normalization_args)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Contain the resample augmentation model."""
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
from data_utils.augmentor.base import AugmentorBase
7+
8+
9+
class ResampleAugmentor(AugmentorBase):
    """Augmentation model that resamples audio to a fixed sample rate.

    See more info here:
    https://ccrma.stanford.edu/~jos/resample/index.html

    :param rng: Random generator object. Stored on the instance but not
                read by this class.
    :type rng: random.Random
    :param new_sample_rate: New sample rate in Hz.
    :type new_sample_rate: int
    """

    def __init__(self, rng, new_sample_rate):
        self._rng = rng
        self._new_sample_rate = new_sample_rate

    def transform_audio(self, audio_segment):
        """Resample the given audio segment to the configured sample rate.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        target_rate = self._new_sample_rate
        audio_segment.resample(target_rate)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Contain the speed perturbation augmentation model."""
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
from data_utils.augmentor.base import AugmentorBase
7+
8+
9+
class SpeedPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding speed perturbation.

    See reference paper here:
    http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_speed_rate: Lower bound of new speed rate to sample and should
                           not be smaller than 0.9.
    :type min_speed_rate: float
    :param max_speed_rate: Upper bound of new speed rate to sample and should
                           not be larger than 1.1.
    :type max_speed_rate: float
    :raises ValueError: If min_speed_rate is below 0.9, max_speed_rate is
                        above 1.1, or min_speed_rate exceeds max_speed_rate.
    """

    def __init__(self, rng, min_speed_rate, max_speed_rate):
        if min_speed_rate < 0.9:
            raise ValueError(
                "Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError(
                "Sampling speed above 1.1 can cause unnatural effects")
        # The two checks above only bound each argument on one side. Without
        # this ordering check an inverted pair such as (1.5, 1.0) would pass,
        # and uniform() would then sample rates outside [0.9, 1.1].
        if min_speed_rate > max_speed_rate:
            raise ValueError(
                "min_speed_rate [%s] must not be larger than "
                "max_speed_rate [%s]." % (min_speed_rate, max_speed_rate))
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Sample a new speed rate from the configured range and change the
        speed of the given audio clip accordingly.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        sampled_speed = self._rng.uniform(self._min_speed_rate,
                                          self._max_speed_rate)
        audio_segment.change_speed(sampled_speed)

deep_speech_2/data_utils/augmentor/volume_perturb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@ def transform_audio(self, audio_segment):
3737
:type audio_segment: AudioSegment|SpeechSegment
3838
"""
3939
gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
40-
audio_segment.apply_gain(gain)
40+
audio_segment.gain_db(gain)

deep_speech_2/requirements.txt

100644100755
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
wget==3.2
22
scipy==0.13.1
3+
resampy==0.1.5

0 commit comments

Comments
 (0)