PoC for a ProcessorMixin class #15549
@@ -0,0 +1,146 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processing saving/loading class for common processors.
"""

import importlib.util
from pathlib import Path


# Comment to write
spec = importlib.util.spec_from_file_location(
    "transformers", Path(__file__).parent / "__init__.py", submodule_search_locations=[Path(__file__).parent]
)
transformers_module = spec.loader.load_module()


AUTO_TO_BASE_CLASS_MAPPING = {
    "AutoTokenizer": "PreTrainedTokenizerBase",
    "AutoFeatureExtractor": "FeatureExtractionMixin",
}


class ProcessorMixin:
    """
    This is a mixin used to provide saving/loading functionality for all processor classes.
    """

    attributes = ["feature_extractor", "tokenizer"]
    # Names need to be attr_class for attr in attributes
    feature_extractor_class = None
    tokenizer_class = None

    # args have to match the attributes class attribute
    def __init__(self, *args):
        if len(args) != len(self.attributes):
            raise ValueError(
                f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got "
                f"{len(args)} arguments instead."
            )

        # Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
        for arg, attribute_name in zip(args, self.attributes):
            class_name = getattr(self, f"{attribute_name}_class")
            # Nothing is ever going to be an instance of "AutoXxx"; in that case, we check the base class.
            class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
            if isinstance(class_name, tuple):
                proper_class = tuple(getattr(transformers_module, n) for n in class_name if n is not None)
            else:
                proper_class = getattr(transformers_module, class_name)

            if not isinstance(arg, proper_class):
                raise ValueError(
                    f"Received a {type(arg)} for argument {attribute_name}, but a {class_name} was expected."
                )

Review comment: Great error message

            setattr(self, attribute_name, arg)

    def __repr__(self):
        attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes]
        attributes_repr = "\n".join(attributes_repr)
        return f"{self.__class__.__name__}:\n{attributes_repr}"

Review comment: cool!

    def save_pretrained(self, save_directory):
        """
        Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
        can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.

        <Tip>

        This method simply calls [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
        [`~tokenization_utils_base.PreTrainedTokenizer.save_pretrained`]. Please refer to the docstrings of the methods
        above for more information.

        </Tip>

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
                be created if it does not exist).
        """
        for attribute_name in self.attributes:
            attribute = getattr(self, attribute_name)
            # Include the processor class in the attribute config so this processor can then be reloaded with the
            # `AutoProcessor` API.
            if hasattr(attribute, "_set_processor_class"):
                attribute._set_processor_class(self.__class__.__name__)
Review comment: Should we add a test to make sure that every tokenizer & feature extractor has this function in a follow-up PR?
Reply: It is defined by the base classes.
Reply: Ah yeah true - this makes sense!
            attribute.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        r"""
        Instantiate a processor associated with a pretrained model.

        <Tip>

        This class method is simply calling the feature extractor
        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and the tokenizer
        [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the
        methods above for more information.

        </Tip>

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
                  huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                  namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                - a path to a *directory* containing a feature extractor file saved using the
                  [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved feature extractor JSON *file*, e.g.,
                  `./my_model_directory/preprocessor_config.json`.
            **kwargs
                Additional keyword arguments passed along to both
                [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
                [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
        """
        args = []
        for attribute_name in cls.attributes:
            class_name = getattr(cls, f"{attribute_name}_class")
            if isinstance(class_name, tuple):
                classes = tuple(getattr(transformers_module, n) if n is not None else None for n in class_name)
                use_fast = kwargs.get("use_fast", True)
                if use_fast and classes[1] is not None:
                    attribute_class = classes[1]
                else:
                    attribute_class = classes[0]
            else:
                attribute_class = getattr(transformers_module, class_name)

            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
        return cls(*args)
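As a usage sketch of the save/load round trip the mixin provides (the checkpoint name, the output directory, and the MyProcessor class from the sketch above are placeholders, not from this PR):

```python
# Each attribute is loaded via its own from_pretrained, then bundled into the processor.
processor = MyProcessor.from_pretrained("some/checkpoint")

print(processor)  # __repr__ lists each wrapped attribute on its own line

# Each attribute saves itself into the directory, and the processor class name is
# recorded in the attribute configs so the checkpoint can later be reloaded through
# the AutoProcessor API.
processor.save_pretrained("./my_processor")
reloaded = MyProcessor.from_pretrained("./my_processor")
```

Note that the `**kwargs` passed to `from_pretrained` are forwarded to both the feature extractor and the tokenizer, which is why `use_fast` can be read from them before the tokenizer class is chosen.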
Review comment: Why do you need to duplicate the `transformers` module in memory? This code executes it again, and insists on how it is loaded, for no obvious reason. I don't see how this differs from `import transformers as transformers_module`. Can you explain?
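For context on this question, the alternative the reviewer describes is a plain import of the already-initialized package. A minimal sketch follows; whether it would work in this file (for example, without running into a circular import while the transformers package itself is still being imported) is an assumption left open by this excerpt:

```python
# Reviewer's suggested alternative (sketch): reuse the module object that the normal
# import machinery has already created, instead of re-executing __init__.py via importlib.
import transformers as transformers_module

# Attribute lookups elsewhere in the file would stay unchanged, e.g.:
base_class = getattr(transformers_module, "PreTrainedTokenizerBase")
```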