
Commit 1a81d77

Add dithering to the Speech2TextFeatureExtractor API. (#34638)
* Add dithering to the `Speech2TextFeatureExtractor` API.
  - As in Kaldi: https://github.com/kaldi-asr/kaldi/blob/4a8b7f673275597fef8a15b160124bd0985b59bd/src/feat/feature-window.cc#L145
  - With dithering enabled and no fixed seed, the features become non-deterministic because a small Gaussian noise is added to the audio (i.e. two runs produce slightly different outputs).
* Update the PR:
  - also add dithering to `WhisperFeatureExtractor`
  - do not add it to `Wav2Vec2FeatureExtractor` (no FBANK computation)
* Add unit tests for dithering, fix docstrings
* ruff
* utils/check_copies.py --fix_and_overwrite
* Update code, add a seed to the unit test
* Add an explanation of dithering
1 parent 9f51dc2 commit 1a81d77
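For illustration only (not part of this commit), a minimal sketch of the non-determinism mentioned above, using the `dither` argument the commit adds to `Speech2TextFeatureExtractor`; the value 4.0 is the example from the new docstring, and the dummy audio is arbitrary:

    import numpy as np
    from transformers import Speech2TextFeatureExtractor

    # dither=4.0 matches the docstring example (noise std on the Kaldi [-32k, +32k] waveform scale)
    feature_extractor = Speech2TextFeatureExtractor(dither=4.0)
    audio = np.random.randn(16000).astype(np.float32)  # 1 second of dummy audio at 16 kHz

    features_1 = feature_extractor(audio, sampling_rate=16000, return_tensors="np").input_features
    features_2 = feature_extractor(audio, sampling_rate=16000, return_tensors="np").input_features

    # Fresh Gaussian noise is drawn on every call, so the two runs are close but not identical.
    print(np.allclose(features_1, features_2))  # expected: False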


5 files changed: +119 -1 lines changed


src/transformers/audio_utils.py

Lines changed: 18 additions & 0 deletions
@@ -390,6 +390,7 @@ def spectrogram(
     center: bool = True,
     pad_mode: str = "reflect",
     onesided: bool = True,
+    dither: float = 0.0,
     preemphasis: Optional[float] = None,
     mel_filters: Optional[np.ndarray] = None,
     mel_floor: float = 1e-10,
@@ -460,6 +461,12 @@ def spectrogram(
         onesided (`bool`, *optional*, defaults to `True`):
             If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
             frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering. In other words, adds a small Gaussian noise to each frame.
+            E.g. use 4.0 to add dithering with a normal distribution centered
+            around 0.0 with standard deviation 4.0, 0.0 means no dithering.
+            Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
+            values for signals with hard-zero sections, when VAD cutoff is present in the signal.
         preemphasis (`float`, *optional*)
             Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
         mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
@@ -540,6 +547,9 @@ def spectrogram(
     for frame_idx in range(num_frames):
         buffer[:frame_length] = waveform[timestep : timestep + frame_length]

+        if dither != 0.0:
+            buffer[:frame_length] += dither * np.random.randn(frame_length)
+
         if remove_dc_offset:
             buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

@@ -591,6 +601,7 @@ def spectrogram_batch(
     center: bool = True,
     pad_mode: str = "reflect",
     onesided: bool = True,
+    dither: float = 0.0,
     preemphasis: Optional[float] = None,
     mel_filters: Optional[np.ndarray] = None,
     mel_floor: float = 1e-10,
@@ -653,6 +664,10 @@ def spectrogram_batch(
             The padding strategy when `center` is `True`.
         onesided (`bool`, *optional*, defaults to `True`):
             If True, returns a one-sided spectrogram for real input signals.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering. In other words, adds a small Gaussian noise to each frame.
+            E.g. use 4.0 to add dithering with a normal distribution centered
+            around 0.0 with standard deviation 4.0, 0.0 means no dithering.
         preemphasis (`float`, *optional*):
             Applies a pre-emphasis filter to each frame.
         mel_filters (`np.ndarray`, *optional*):
@@ -741,6 +756,9 @@ def spectrogram_batch(
         timestep = frame_idx * hop_length
         buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

+        if dither != 0.0:
+            buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
+
         if remove_dc_offset:
             buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)


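As an aside (not part of the diff above), a small standalone NumPy sketch of why the new docstring compares dithering to `mel_floor`: on a hard-zero frame (e.g. after a VAD cutoff) the power spectrum is exactly zero and only the floor keeps the log finite, whereas a little Gaussian noise avoids log(0) on its own. The frame length of 400 and the floor of 1e-10 are illustrative values:

    import numpy as np

    frame = np.zeros(400)  # a hard-zero frame, e.g. produced by a VAD cutoff
    dither = 4.0           # std of the added Gaussian noise, as in the new `dither` argument

    # Same per-frame operation the patch adds inside spectrogram()
    noisy_frame = frame + dither * np.random.randn(frame.shape[0])

    power_clean = np.abs(np.fft.rfft(frame)) ** 2        # all zeros, log would be -inf without a floor
    power_noisy = np.abs(np.fft.rfft(noisy_frame)) ** 2  # strictly positive with overwhelming probability

    print(np.log(np.maximum(power_clean, 1e-10)).max())  # pinned to log(1e-10) by the floor
    print(np.log(power_noisy).mean())                    # finite, noise-dominated values instead
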
src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py

Lines changed: 16 additions & 1 deletion
@@ -52,6 +52,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
             Number of Mel-frequency bins.
         padding_value (`float`, *optional*, defaults to 0.0):
             The value that is used to fill the padding vectors.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering. In other words, adds a small Gaussian noise to each frame.
+            E.g. use 4.0 to add dithering with a normal distribution centered
+            around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform).
+            The value 0.0 means no dithering.
+            Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
+            values for signals with hard-zero sections, when VAD cutoff is present in the signal.
         do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
             Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
         normalize_means (`bool`, *optional*, defaults to `True`):
@@ -68,13 +75,15 @@ def __init__(
         sampling_rate=16000,
         num_mel_bins=80,
         padding_value=0.0,
+        dither=0.0,
         do_ceptral_normalize=True,
         normalize_means=True,
         normalize_vars=True,
         **kwargs,
     ):
         super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
         self.num_mel_bins = num_mel_bins
+        self.dither = dither
         self.do_ceptral_normalize = do_ceptral_normalize
         self.normalize_means = normalize_means
         self.normalize_vars = normalize_vars
@@ -106,7 +115,12 @@ def _extract_fbank_features(
         waveform = waveform * (2**15)  # Kaldi compliance: 16-bit signed integers
         if is_speech_available():
             waveform = torch.from_numpy(waveform).unsqueeze(0)
-            features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
+            features = ta_kaldi.fbank(
+                waveform,
+                dither=self.dither,
+                num_mel_bins=self.num_mel_bins,
+                sample_frequency=self.sampling_rate,
+            )
             features = features.numpy()
         else:
             waveform = np.squeeze(waveform)
@@ -118,6 +132,7 @@ def _extract_fbank_features(
                 fft_length=512,
                 power=2.0,
                 center=False,
+                dither=self.dither,
                 preemphasis=0.97,
                 mel_filters=self.mel_filters,
                 log_mel="log",

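A side note (not part of the commit): `_extract_fbank_features` multiplies the waveform by 2**15 for Kaldi compliance, so the suggested `dither` of 4.0 lives on the [-32768, +32767] scale; divided back down it is on the order of the 0.0001 suggested later for Whisper, whose input stays in [-1, +1]:

    # Relating the two suggested dither values across the two waveform scales
    kaldi_scale_dither = 4.0                        # Speech2Text docstring example, 16-bit Kaldi scale
    raw_scale_equivalent = kaldi_scale_dither / 2**15
    print(raw_scale_equivalent)                     # ~1.22e-4, comparable to Whisper's suggested 0.0001
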
src/transformers/models/whisper/feature_extraction_whisper.py

Lines changed: 17 additions & 0 deletions
@@ -57,6 +57,14 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
             Size of the Fourier transform.
         padding_value (`float`, *optional*, defaults to 0.0):
             Padding value used to pad the audio. Should correspond to silences.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering. In other words, adds a small Gaussian noise to each frame.
+            E.g. use 0.0001 to add dithering with a normal distribution centered
+            around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
+            The value 0.0 means no dithering.
+            Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
+            the high log_mel_fbank values for signals with hard-zero sections,
+            when VAD cutoff is present in the signal.
     """

     model_input_names = ["input_features"]
@@ -69,6 +77,7 @@ def __init__(
         chunk_length=30,
         n_fft=400,
         padding_value=0.0,
+        dither=0.0,
         return_attention_mask=False,  # pad inputs to max length with silence token (zero) and no attention mask
         **kwargs,
     ):
@@ -85,6 +94,7 @@ def __init__(
         self.n_samples = chunk_length * sampling_rate
         self.nb_max_frames = self.n_samples // hop_length
         self.sampling_rate = sampling_rate
+        self.dither = dither
         self.mel_filters = mel_filter_bank(
             num_frequency_bins=1 + n_fft // 2,
             num_mel_filters=feature_size,
@@ -114,6 +124,7 @@ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
             frame_length=self.n_fft,
             hop_length=self.hop_length,
             power=2.0,
+            dither=self.dither,
             mel_filters=self.mel_filters,
             log_mel="log10",
         )
@@ -132,6 +143,12 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
         waveform = torch.from_numpy(waveform).to(device, torch.float32)
         window = torch.hann_window(self.n_fft, device=device)

+        # Note: it would be better to dither the chunked waveform,
+        # so overlapping signal does not get the same dithering.
+        # But, chunking is happening inside pytorch, so it is here.
+        if self.dither != 0.0:
+            waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
+
         stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
         magnitudes = stft[..., :-1].abs() ** 2


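One more aside (an assumption-laden sketch, not from the commit): the torch path above draws noise with `torch.randn` while the NumPy path in `spectrogram()` uses `np.random.randn`, so reproducible dithered Whisper features require seeding the RNG that backs whichever path is used:

    import numpy as np
    import torch
    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor(dither=0.0001)  # noise std on the [-1, +1] raw_speech scale
    audio = np.random.randn(2 * 16000).astype(np.float32)       # 2 seconds of dummy audio at 16 kHz

    np.random.seed(42)     # seeds the NumPy spectrogram() path
    torch.manual_seed(42)  # seeds the torch STFT path
    features_a = feature_extractor(audio, sampling_rate=16000, return_tensors="np").input_features

    np.random.seed(42)
    torch.manual_seed(42)
    features_b = feature_extractor(audio, sampling_rate=16000, return_tensors="np").input_features

    print(np.allclose(features_a, features_b))  # expected: True once the relevant seed is fixed
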
tests/models/speech_to_text/test_feature_extraction_speech_to_text.py

Lines changed: 34 additions & 0 deletions
@@ -144,6 +144,40 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

+    def test_dither(self):
+        np.random.seed(42)  # seed the dithering randn()
+
+        # Tests that features with and without little dithering are similar, but not the same
+        dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
+        dict_no_dither["dither"] = 0.0
+
+        dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
+        dict_dither["dither"] = 1.0
+
+        feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
+        feature_extractor_dither = self.feature_extraction_class(**dict_dither)
+
+        # create three inputs of length 800, 1000, and 1200
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
+
+        # compute features
+        input_features_no_dither = feature_extractor_no_dither(
+            np_speech_inputs, padding=True, return_tensors="np"
+        ).input_features
+        input_features_dither = feature_extractor_dither(
+            np_speech_inputs, padding=True, return_tensors="np"
+        ).input_features
+
+        # test there is a difference between features (there's added noise to input signal)
+        diff = input_features_dither - input_features_no_dither
+
+        # features are not identical
+        self.assertTrue(np.abs(diff).mean() > 1e-5)
+        # features are not too different
+        self.assertTrue(np.abs(diff).mean() <= 1e-3)
+        self.assertTrue(np.abs(diff).max() <= 1e-2)
+
     def test_cepstral_mean_and_variance_normalization(self):
         feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]

tests/models/whisper/test_feature_extraction_whisper.py

Lines changed: 34 additions & 0 deletions
@@ -200,6 +200,40 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

+    def test_dither(self):
+        np.random.seed(42)  # seed the dithering randn()
+
+        # Tests that features with and without little dithering are similar, but not the same
+        dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
+        dict_no_dither["dither"] = 0.0
+
+        dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
+        dict_dither["dither"] = 0.00003  # approx. 1/32k
+
+        feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
+        feature_extractor_dither = self.feature_extraction_class(**dict_dither)
+
+        # create three inputs of length 800, 1000, and 1200
+        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
+        np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
+
+        # compute features
+        input_features_no_dither = feature_extractor_no_dither(
+            np_speech_inputs, padding=True, return_tensors="np"
+        ).input_features
+        input_features_dither = feature_extractor_dither(
+            np_speech_inputs, padding=True, return_tensors="np"
+        ).input_features
+
+        # test there is a difference between features (there's added noise to input signal)
+        diff = input_features_dither - input_features_no_dither
+
+        # features are not identical
+        self.assertTrue(np.abs(diff).mean() > 1e-6)
+        # features are not too different
+        self.assertTrue(np.abs(diff).mean() <= 1e-4)
+        self.assertTrue(np.abs(diff).max() <= 1e-3)
+
     @require_torch
     def test_double_precision_pad(self):
         import torch
