Skip to content
Merged
18 changes: 18 additions & 0 deletions src/transformers/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
Expand Down Expand Up @@ -460,6 +461,12 @@ def spectrogram(
onesided (`bool`, *optional*, defaults to `True`):
If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 4.0 to add dithering with a normal distribution centered
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
Comment on lines +465 to +467
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add the comment about "this can help for hard audio in ASR" 😉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, added the explanatory comments, thanks!

Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
preemphasis (`float`, *optional*)
Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
Expand Down Expand Up @@ -540,6 +547,9 @@ def spectrogram(
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]

if dither != 0.0:
buffer[:frame_length] += dither * np.random.randn(frame_length)

if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

Expand Down Expand Up @@ -591,6 +601,7 @@ def spectrogram_batch(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
dither: float = 0.0,
preemphasis: Optional[float] = None,
mel_filters: Optional[np.ndarray] = None,
mel_floor: float = 1e-10,
Expand Down Expand Up @@ -653,6 +664,10 @@ def spectrogram_batch(
The padding strategy when `center` is `True`.
onesided (`bool`, *optional*, defaults to `True`):
If True, returns a one-sided spectrogram for real input signals.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 4.0 to add dithering with a normal distribution centered
around 0.0 with standard deviation 4.0, 0.0 means no dithering.
preemphasis (`float`, *optional*):
Applies a pre-emphasis filter to each frame.
mel_filters (`np.ndarray`, *optional*):
Expand Down Expand Up @@ -741,6 +756,9 @@ def spectrogram_batch(
timestep = frame_idx * hop_length
buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]

if dither != 0.0:
buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)

if remove_dc_offset:
buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
Number of Mel-frequency bins.
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding vectors.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 4.0 to add dithering with a normal distribution centered
around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform).
The value 0.0 means no dithering.
Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
values for signals with hard-zero sections, when VAD cutoff is present in the signal.
do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
normalize_means (`bool`, *optional*, defaults to `True`):
Expand All @@ -68,13 +75,15 @@ def __init__(
sampling_rate=16000,
num_mel_bins=80,
padding_value=0.0,
dither=0.0,
do_ceptral_normalize=True,
normalize_means=True,
normalize_vars=True,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.num_mel_bins = num_mel_bins
self.dither = dither
self.do_ceptral_normalize = do_ceptral_normalize
self.normalize_means = normalize_means
self.normalize_vars = normalize_vars
Expand Down Expand Up @@ -106,7 +115,12 @@ def _extract_fbank_features(
waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
if is_speech_available():
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
features = ta_kaldi.fbank(
waveform,
dither=self.dither,
num_mel_bins=self.num_mel_bins,
sample_frequency=self.sampling_rate,
)
features = features.numpy()
else:
waveform = np.squeeze(waveform)
Expand All @@ -118,6 +132,7 @@ def _extract_fbank_features(
fft_length=512,
power=2.0,
center=False,
dither=self.dither,
preemphasis=0.97,
mel_filters=self.mel_filters,
log_mel="log",
Expand Down
17 changes: 17 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
Size of the Fourier transform.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
dither (`float`, *optional*, defaults to 0.0):
Adds dithering. In other words, adds a small Gaussian noise to each frame.
E.g. use 0.0001 to add dithering with a normal distribution centered
around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
The value 0.0 means no dithering.
Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
the high log_mel_fbank values for signals with hard-zero sections,
when VAD cutoff is present in the signal.
"""

model_input_names = ["input_features"]
Expand All @@ -69,6 +77,7 @@ def __init__(
chunk_length=30,
n_fft=400,
padding_value=0.0,
dither=0.0,
return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
**kwargs,
):
Expand All @@ -85,6 +94,7 @@ def __init__(
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.sampling_rate = sampling_rate
self.dither = dither
self.mel_filters = mel_filter_bank(
num_frequency_bins=1 + n_fft // 2,
num_mel_filters=feature_size,
Expand Down Expand Up @@ -114,6 +124,7 @@ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> n
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
dither=self.dither,
mel_filters=self.mel_filters,
log_mel="log10",
)
Expand All @@ -132,6 +143,12 @@ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu")
waveform = torch.from_numpy(waveform).to(device, torch.float32)
window = torch.hann_window(self.n_fft, device=device)

# Note: it would be better to dither the chunked waveform,
# so overlapping signal does not get the same dithering.
# But, chunking is happening inside pytorch, so it is here.
if self.dither != 0.0:
waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)

stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,40 @@ def test_call(self):
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

def test_dither(self):
np.random.seed(42) # seed the dithering randn()

# Tests that features with and without little dithering are similar, but not the same
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's set the seed here, to ensure reproducibility.

dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_no_dither["dither"] = 0.0

dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_dither["dither"] = 1.0

feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
feature_extractor_dither = self.feature_extraction_class(**dict_dither)

# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

# compute features
input_features_no_dither = feature_extractor_no_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features
input_features_dither = feature_extractor_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features

# test there is a difference between features (there's added noise to input signal)
diff = input_features_dither - input_features_no_dither

# features are not identical
self.assertTrue(np.abs(diff).mean() > 1e-5)
# features are not too different
self.assertTrue(np.abs(diff).mean() <= 1e-3)
self.assertTrue(np.abs(diff).max() <= 1e-2)
Comment on lines +178 to +179
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!


def test_cepstral_mean_and_variance_normalization(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
Expand Down
34 changes: 34 additions & 0 deletions tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,40 @@ def test_call(self):
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))

def test_dither(self):
np.random.seed(42) # seed the dithering randn()

# Tests that features with and without little dithering are similar, but not the same
dict_no_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_no_dither["dither"] = 0.0

dict_dither = self.feat_extract_tester.prepare_feat_extract_dict()
dict_dither["dither"] = 0.00003 # approx. 1/32k

feature_extractor_no_dither = self.feature_extraction_class(**dict_no_dither)
feature_extractor_dither = self.feature_extraction_class(**dict_dither)

# create three inputs of length 800, 1000, and 1200
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]

# compute features
input_features_no_dither = feature_extractor_no_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features
input_features_dither = feature_extractor_dither(
np_speech_inputs, padding=True, return_tensors="np"
).input_features

# test there is a difference between features (there's added noise to input signal)
diff = input_features_dither - input_features_no_dither

# features are not identical
self.assertTrue(np.abs(diff).mean() > 1e-6)
# features are not too different
self.assertTrue(np.abs(diff).mean() <= 1e-4)
self.assertTrue(np.abs(diff).max() <= 1e-3)

@require_torch
def test_double_precision_pad(self):
import torch
Expand Down