Skip to content

Commit

Permalink
[AutoProcessor] Add Wav2Vec2WithLM & small fix (huggingface#14675)
Browse files Browse the repository at this point in the history
* [AutoProcessor] Add Wav2Vec2WithLM & small fix

* revert line removal

* Update src/transformers/__init__.py

* add test

* up

* up

* small fix
  • Loading branch information
patrickvonplaten authored Dec 8, 2021
1 parent 2294071 commit ee4fa2e
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 16 deletions.
5 changes: 3 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@
"Wav2Vec2Processor",
"Wav2Vec2Tokenizer",
],
"models.wav2vec2_with_lm": [],
"models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"],
"models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"],
"models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
Expand Down Expand Up @@ -474,7 +475,7 @@
]

if is_pyctcdecode_available():
_import_structure["models.wav2vec2"].append("Wav2Vec2ProcessorWithLM")
_import_structure["models.wav2vec2_with_lm"].append("Wav2Vec2ProcessorWithLM")
else:
from .utils import dummy_pyctcdecode_objects

Expand Down Expand Up @@ -2470,7 +2471,7 @@
from .utils.dummy_speech_objects import *

if is_pyctcdecode_available():
from .models.wav2vec2 import Wav2Vec2ProcessorWithLM
from .models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else:
from .utils.dummy_pyctcdecode_objects import *

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
visual_bert,
vit,
wav2vec2,
wav2vec2_with_lm,
xlm,
xlm_prophetnet,
xlm_roberta,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
("speech_to_text_2", "Speech2Text2Processor"),
("trocr", "TrOCRProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2_with_lm", "Wav2Vec2ProcessorWithLM"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
]
)
Expand Down Expand Up @@ -145,6 +146,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
}
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
# strip to file name
model_files = [f.split("/")[-1] for f in model_files]

if FEATURE_EXTRACTOR_NAME in model_files:
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
if "processor_class" in config_dict:
Expand Down
7 changes: 1 addition & 6 deletions src/transformers/models/wav2vec2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_flax_available, is_pyctcdecode_available, is_tf_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available


_import_structure = {
Expand All @@ -27,8 +27,6 @@
"tokenization_wav2vec2": ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"],
}

if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]

if is_torch_available():
_import_structure["modeling_wav2vec2"] = [
Expand Down Expand Up @@ -64,9 +62,6 @@
from .processing_wav2vec2 import Wav2Vec2Processor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2Tokenizer

if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM

if is_torch_available():
from .modeling_wav2vec2 import (
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
Expand Down
36 changes: 36 additions & 0 deletions src/transformers/models/wav2vec2_with_lm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_pyctcdecode_available


_import_structure = {}


if is_pyctcdecode_available():
_import_structure["processing_wav2vec2_with_lm"] = ["Wav2Vec2ProcessorWithLM"]


if TYPE_CHECKING:
if is_pyctcdecode_available():
from .processing_wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from ...feature_extraction_utils import FeatureExtractionMixin
from ...file_utils import ModelOutput, requires_backends
from ...tokenization_utils import PreTrainedTokenizer
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
from ..wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from ..wav2vec2.tokenization_wav2vec2 import Wav2Vec2CTCTokenizer


@dataclass
Expand Down Expand Up @@ -159,6 +159,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
if os.path.isdir(pretrained_model_name_or_path):
decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path)
else:
# BeamSearchDecoderCTC has no auto class
kwargs.pop("_from_auto", None)

decoder = BeamSearchDecoderCTC.load_from_hf_hub(pretrained_model_name_or_path, **kwargs)

# set language model attributes
Expand Down
3 changes: 2 additions & 1 deletion tests/fixtures/dummy_feature_extractor_config.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"processor_class": "Wav2Vec2Processor"
}
17 changes: 14 additions & 3 deletions tests/test_processor_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,24 @@
import os
import tempfile
import unittest
from shutil import copyfile

from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME


SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")


class AutoFeatureExtractorTest(unittest.TestCase):
def test_processor_from_model_shortcut(self):
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(processor, Wav2Vec2Processor)

def test_processor_from_local_directory_from_config(self):
def test_processor_from_local_directory_from_repo(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model_config = Wav2Vec2Config()
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
Expand All @@ -44,3 +45,13 @@ def test_processor_from_local_directory_from_config(self):
processor = AutoProcessor.from_pretrained(tmpdirname)

self.assertIsInstance(processor, Wav2Vec2Processor)

def test_processor_from_local_directory_from_extractor_config(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# copy relevant files
copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))

processor = AutoProcessor.from_pretrained(tmpdirname)

self.assertIsInstance(processor, Wav2Vec2Processor)
2 changes: 1 addition & 1 deletion tests/test_processor_wav2vec2_with_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

if is_pyctcdecode_available():
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM


@require_pyctcdecode
Expand Down
6 changes: 5 additions & 1 deletion utils/check_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available
_re_test_backend = re.compile(r"^\s*if\s+is\_[a-z_]*\_available\(\)")
# Catches a line _import_struct["bla"] = ["foo"]
_re_import_struct_equal_one = re.compile(r'^\s*_import_structure\["\S*"\]\ = "\[(\S*)\]"')
# Catches a line _import_struct["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
Expand Down Expand Up @@ -88,7 +90,9 @@ def parse_init(init_file):
# Until we unindent, add backend objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
line = lines[line_index]
if _re_import_struct_add_one.search(line) is not None:
if _re_import_struct_equal_one.search(line) is not None:
objects.append(_re_import_struct_equal_one.search(line).groups()[0])
elif _re_import_struct_add_one.search(line) is not None:
objects.append(_re_import_struct_add_one.search(line).groups()[0])
elif _re_import_struct_add_many.search(line) is not None:
imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
Expand Down

0 comments on commit ee4fa2e

Please sign in to comment.