From 24e80df43e2f6630380595006ed732d1b3fd38db Mon Sep 17 00:00:00 2001
From: LWprogramming
Date: Tue, 23 May 2023 11:37:35 -0700
Subject: [PATCH] is_batched fix for remaining 2-D numpy arrays (#23309)

* Fix is_batched code to allow 2-D numpy arrays for audio

* Tests

* Fix typo

* Incorporate comments from PR #23223

---
 ...eature_extraction_audio_spectrogram_transformer.py | 11 +++++++----
 .../models/clap/feature_extraction_clap.py            | 11 +++++++----
 .../models/mctct/feature_extraction_mctct.py          | 11 +++++++----
 .../feature_extraction_speech_to_text.py              | 11 +++++++----
 .../models/speecht5/feature_extraction_speecht5.py    | 11 +++++++----
 .../models/tvlt/feature_extraction_tvlt.py            | 11 +++++++----
 .../models/whisper/feature_extraction_whisper.py      | 11 +++++++----
 ...eature_extraction_audio_spectrogram_transformer.py |  8 ++++++++
 tests/models/clap/test_feature_extraction_clap.py     |  8 ++++++++
 tests/models/mctct/test_feature_extraction_mctct.py   |  8 ++++++++
 .../test_feature_extraction_speech_to_text.py         |  8 ++++++++
 .../speecht5/test_feature_extraction_speecht5.py      |  8 ++++++++
 tests/models/tvlt/test_feature_extraction_tvlt.py     |  9 +++++++++
 .../models/whisper/test_feature_extraction_whisper.py |  8 ++++++++
 14 files changed, 106 insertions(+), 28 deletions(-)

diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
index deda2fc7781b28..786548fd2336e9 100644
--- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
+++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
@@ -135,7 +135,8 @@ def __call__(
         Args:
             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values.
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
             sampling_rate (`int`, *optional*):
                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
                 `sampling_rate` at the forward call to prevent silent errors.
@@ -160,9 +161,11 @@ def __call__(
                 "Failing to do so can result in silent errors that might be hard to debug."
             )
 
-        is_batched = bool(
-            isinstance(raw_speech, (list, tuple))
-            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
         )
 
         if is_batched:
diff --git a/src/transformers/models/clap/feature_extraction_clap.py b/src/transformers/models/clap/feature_extraction_clap.py
index 6edd739fa16d60..d33307ffbd22fc 100644
--- a/src/transformers/models/clap/feature_extraction_clap.py
+++ b/src/transformers/models/clap/feature_extraction_clap.py
@@ -272,7 +272,8 @@ def __call__(
         Args:
             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values.
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
             truncation (`str`, *optional*):
                 Truncation pattern for long audio inputs. Two patterns are available:
                     - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
@@ -312,9 +313,11 @@ def __call__(
                 "Failing to do so can result in silent errors that might be hard to debug."
             )
 
-        is_batched = bool(
-            isinstance(raw_speech, (list, tuple))
-            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
         )
 
         if is_batched:
diff --git a/src/transformers/models/mctct/feature_extraction_mctct.py b/src/transformers/models/mctct/feature_extraction_mctct.py
index 467e654244b993..9e9e276c168ca1 100644
--- a/src/transformers/models/mctct/feature_extraction_mctct.py
+++ b/src/transformers/models/mctct/feature_extraction_mctct.py
@@ -180,7 +180,8 @@ def __call__(
         Args:
             raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
-                of float values, a list of tensors, a list of numpy arrays or a list of list of float values.
+                of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be
+                mono channel audio, not stereo, i.e. single float per timestep.
             padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
                 index) among:
@@ -231,9 +232,11 @@ def __call__(
                 "Failing to do so can result in silent errors that might be hard to debug."
             )
 
-        is_batched = bool(
-            isinstance(raw_speech, (list, tuple))
-            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
         )
 
         if is_batched:
diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
index a5e6b0d4004264..81f2ea4e99be22 100644
--- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -141,7 +141,8 @@ def __call__(
         Args:
             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values.
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
             padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
                 index) among:
@@ -200,9 +201,11 @@ def __call__(
                 "Failing to do so can result in silent errors that might be hard to debug."
             )
 
-        is_batched = bool(
-            isinstance(raw_speech, (list, tuple))
-            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
         )
 
         if is_batched:
diff --git a/src/transformers/models/speecht5/feature_extraction_speecht5.py b/src/transformers/models/speecht5/feature_extraction_speecht5.py
index 5fe6ca39765c1f..dd5ff4c8a1afae 100644
--- a/src/transformers/models/speecht5/feature_extraction_speecht5.py
+++ b/src/transformers/models/speecht5/feature_extraction_speecht5.py
@@ -201,7 +201,8 @@ def __call__(
         Args:
             audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*):
                 The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values. This outputs waveform features.
+                values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must
+                be mono channel audio, not stereo, i.e. single float per timestep.
             audio_target (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*):
                 The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a
                 list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel
@@ -307,9 +308,11 @@ def _process_audio(
         return_tensors: Optional[Union[str, TensorType]] = None,
         **kwargs,
     ) -> BatchFeature:
-        is_batched = bool(
-            isinstance(speech, (list, tuple))
-            and (isinstance(speech[0], np.ndarray) or isinstance(speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1
+        if is_batched_numpy and len(speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list)))
         )
 
         if is_batched:
diff --git a/src/transformers/models/tvlt/feature_extraction_tvlt.py b/src/transformers/models/tvlt/feature_extraction_tvlt.py
index 6d919550cf558d..d5beba76bd6986 100644
--- a/src/transformers/models/tvlt/feature_extraction_tvlt.py
+++ b/src/transformers/models/tvlt/feature_extraction_tvlt.py
@@ -129,7 +129,8 @@ def __call__(
         Args:
             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values.
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors instead of list of python integers. Acceptable values are:
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
@@ -176,9 +177,11 @@ def __call__(
                 "Failing to do so can result in silent errors that might be hard to debug."
             )
 
-        is_batched = bool(
-            isinstance(raw_speech, (list, tuple))
-            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
         )
         if is_batched:
             raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index e0b772216205fe..70eb8bd94e7676 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -152,7 +152,8 @@ def __call__(
         Args:
             raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
                 The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values.
+                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+                stereo, i.e. single float per timestep.
             truncation (`bool`, *optional*, default to `True`):
                 Activates truncation to cut input sequences longer than *max_length* to *max_length*.
             pad_to_multiple_of (`int`, *optional*, defaults to None):
@@ -203,9 +204,11 @@ def __call__(
                 "Failing to do so can result in silent errors that might be hard to debug."
            )
 
-        is_batched = bool(
-            isinstance(raw_speech, (list, tuple))
-            and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
+        is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
+        if is_batched_numpy and len(raw_speech.shape) > 2:
+            raise ValueError(f"Only mono-channel audio is supported for input to {self}")
+        is_batched = is_batched_numpy or (
+            isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
         )
 
         if is_batched:
diff --git a/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py b/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py
index a7a81dceb15314..69a1bddc825080 100644
--- a/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py
+++ b/tests/models/audio_spectrogram_transformer/test_feature_extraction_audio_spectrogram_transformer.py
@@ -125,6 +125,14 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_values
+        encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_values
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     @require_torch
     def test_double_precision_pad(self):
         import torch
diff --git a/tests/models/clap/test_feature_extraction_clap.py b/tests/models/clap/test_feature_extraction_clap.py
index 733dd66681c473..c49d045ba87407 100644
--- a/tests/models/clap/test_feature_extraction_clap.py
+++ b/tests/models/clap/test_feature_extraction_clap.py
@@ -139,6 +139,14 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
+        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     def test_double_precision_pad(self):
         import torch
 
diff --git a/tests/models/mctct/test_feature_extraction_mctct.py b/tests/models/mctct/test_feature_extraction_mctct.py
index cab2911fdd40f4..f3d8f0fea940e9 100644
--- a/tests/models/mctct/test_feature_extraction_mctct.py
+++ b/tests/models/mctct/test_feature_extraction_mctct.py
@@ -134,6 +134,14 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
+        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     def test_cepstral_mean_and_variance_normalization(self):
         feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
diff --git a/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
index aedd445e5d6393..293b33fde80e3a 100644
--- a/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
+++ b/tests/models/speech_to_text/test_feature_extraction_speech_to_text.py
@@ -136,6 +136,14 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
+        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     def test_cepstral_mean_and_variance_normalization(self):
         feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
         speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
diff --git a/tests/models/speecht5/test_feature_extraction_speecht5.py b/tests/models/speecht5/test_feature_extraction_speecht5.py
index 11ed50de4bbc3a..a09bf7f8ae58d2 100644
--- a/tests/models/speecht5/test_feature_extraction_speecht5.py
+++ b/tests/models/speecht5/test_feature_extraction_speecht5.py
@@ -275,6 +275,14 @@ def test_call_target(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_values
+        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_values
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
     def test_batch_feature_target(self):
         speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
         feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
diff --git a/tests/models/tvlt/test_feature_extraction_tvlt.py b/tests/models/tvlt/test_feature_extraction_tvlt.py
index a76f3c9dca08f7..051708a306981f 100644
--- a/tests/models/tvlt/test_feature_extraction_tvlt.py
+++ b/tests/models/tvlt/test_feature_extraction_tvlt.py
@@ -189,6 +189,15 @@ def test_call(self):
         self.assertTrue(encoded_audios.shape[-2] <= feature_extractor.spectrogram_length)
         self.assertTrue(encoded_audios.shape[-3] == feature_extractor.num_channels)
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_audios = feature_extractor(np_speech_inputs, return_tensors="np", sampling_rate=44100).audio_values
+        self.assertTrue(encoded_audios.ndim == 4)
+        self.assertTrue(encoded_audios.shape[-1] == feature_extractor.feature_size)
+        self.assertTrue(encoded_audios.shape[-2] <= feature_extractor.spectrogram_length)
+        self.assertTrue(encoded_audios.shape[-3] == feature_extractor.num_channels)
+
     def _load_datasamples(self, num_samples):
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         # automatic decoding with librispeech
diff --git a/tests/models/whisper/test_feature_extraction_whisper.py b/tests/models/whisper/test_feature_extraction_whisper.py
index 31ea28b9ad628c..90cbfc21c04f35 100644
--- a/tests/models/whisper/test_feature_extraction_whisper.py
+++ b/tests/models/whisper/test_feature_extraction_whisper.py
@@ -173,6 +173,14 @@ def test_call(self):
         for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
             self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
 
+        # Test 2-D numpy arrays are batched.
+        speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
+        np_speech_inputs = np.asarray(speech_inputs)
+        encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
+        encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
+        for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
+            self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
+
         # Test truncation required
         speech_inputs = [floats_list((1, x))[0] for x in range(200, (feature_extractor.n_samples + 500), 200)]
         np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
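
Usage note (not part of the patch): the sketch below shows the new `is_batched` behavior end to end. It is a minimal illustration, assuming a transformers build that includes this change; `WhisperFeatureExtractor` stands in for any of the seven extractors touched above, and the zero-filled arrays are placeholder audio.

    import numpy as np
    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor()

    # A 2-D array is now accepted as a batch of mono waveforms, shape
    # (batch, samples), and handled the same as a list of 1-D arrays.
    batch = np.zeros((3, 16000), dtype=np.float32)
    features = feature_extractor(batch, sampling_rate=16000, return_tensors="np")
    print(features.input_features.shape)  # one feature array per batch item

    # A 3-D array (e.g. stereo audio, shape (batch, channels, samples)) now
    # fails fast instead of producing silently wrong features.
    stereo = np.zeros((3, 2, 16000), dtype=np.float32)
    try:
        feature_extractor(stereo, sampling_rate=16000)
    except ValueError as err:
        print(err)  # Only mono-channel audio is supported for input to ...

Because `is_batched_numpy` is evaluated before the list/tuple checks, a plain 1-D numpy array still takes the unbatched path, so existing single-clip callers are unaffected.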