diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 9550b7c2a9ef90..fef8b529957d20 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -181,9 +181,13 @@ def __call__( "Failing to do so can result in silent errors that might be hard to debug." ) - is_batched = bool( - isinstance(raw_speech, (list, tuple)) - and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) == 2 + if is_batched_numpy: + assert ( + len(raw_speech.shape) == 2 + ), f"Only mono-channel audio is supported for input to {self.__class__.__name__}" + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) ) # always return batch diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 1708dbf12512a4..a98a2f18818e15 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -820,9 +820,13 @@ def __call__( values, a list of numpy arrayr or a list of list of float values. """ - is_batched = bool( - isinstance(raw_speech, (list, tuple)) - and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) == 2 + if is_batched_numpy: + assert ( + len(raw_speech.shape) == 2 + ), f"Only mono-channel audio is supported for input to {self.__class__.__name__}" + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) ) # make sure input is in list format