From d4762fd476ef2e6871c94fa7fadb785f9dbc24ca Mon Sep 17 00:00:00 2001 From: Leon Wu Date: Mon, 8 May 2023 20:33:07 -0700 Subject: [PATCH] Fix wav2vec2 is_batched check to include 2-D numpy arrays --- .../models/wav2vec2/feature_extraction_wav2vec2.py | 10 +++++++--- .../models/wav2vec2/tokenization_wav2vec2.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index 9550b7c2a9ef90..fef8b529957d20 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -181,9 +181,13 @@ def __call__( "Failing to do so can result in silent errors that might be hard to debug." ) - is_batched = bool( - isinstance(raw_speech, (list, tuple)) - and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) == 2 + if is_batched_numpy: + assert ( + len(raw_speech.shape) == 2 + ), f"Only mono-channel audio is supported for input to {self.__class__.__name__}" + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) ) # always return batch diff --git a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py index 1708dbf12512a4..a98a2f18818e15 100644 --- a/src/transformers/models/wav2vec2/tokenization_wav2vec2.py +++ b/src/transformers/models/wav2vec2/tokenization_wav2vec2.py @@ -820,9 +820,13 @@ def __call__( values, a list of numpy arrayr or a list of list of float values. """ - is_batched = bool( - isinstance(raw_speech, (list, tuple)) - and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list))) + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) == 2 + if is_batched_numpy: + assert ( + len(raw_speech.shape) == 2 + ), f"Only mono-channel audio is supported for input to {self.__class__.__name__}" + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) ) # make sure input is in list format