Skip to content

Commit

Permalink
[Wav2Vec2ProcessorWithLM] Fix auto processor with lm (huggingface#15683)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten authored and Steven committed Feb 18, 2022
1 parent fe8df53 commit 5840911
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
else:
# BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None)
# snapshot_download has no `trust_remote_code` flag
kwargs.pop("trust_remote_code", None)

# make sure that only relevant filenames are downloaded
language_model_filenames = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*")
Expand Down
20 changes: 20 additions & 0 deletions tests/test_processor_wav2vec2_with_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np

from transformers import AutoProcessor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME, is_pyctcdecode_available
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
Expand Down Expand Up @@ -330,3 +331,22 @@ def test_decoder_local_files(self):

# test that both decoder form hub and local files in cache are the same
self.assertListEqual(local_decoder_files, expected_decoder_files)

def test_processor_from_auto_processor(self):
processor_wav2vec2 = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")
processor_auto = AutoProcessor.from_pretrained("hf-internal-testing/processor_with_lm")

raw_speech = floats_list((3, 1000))

input_wav2vec2 = processor_wav2vec2(raw_speech, return_tensors="np")
input_auto = processor_auto(raw_speech, return_tensors="np")

for key in input_wav2vec2.keys():
self.assertAlmostEqual(input_wav2vec2[key].sum(), input_auto[key].sum(), delta=1e-2)

logits = self._get_dummy_logits()

decoded_wav2vec2 = processor_wav2vec2.batch_decode(logits)
decoded_auto = processor_auto.batch_decode(logits)

self.assertListEqual(decoded_wav2vec2.text, decoded_auto.text)

0 comments on commit 5840911

Please sign in to comment.