diff --git a/utils/check_copies.py b/utils/check_copies.py index 0352b6419e3098..563f88a5ec130a 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -40,6 +40,7 @@ import glob import os import re +from typing import List, Optional, Tuple import black from doc_builder.style_doc import style_docstrings_in_code @@ -125,14 +126,22 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH) -def _should_continue(line, indent): +def _should_continue(line: str, indent: str) -> bool: # Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a # function definition return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None -def find_code_in_transformers(object_name): - """Find and return the code source code of `object_name`.""" +def find_code_in_transformers(object_name: str) -> str: + """ + Find and return the source code of an object. + + Args: + object_name (`str`): The name of the object we want the source code of. + + Returns: + `str`: The source code of the object. + """ parts = object_name.split(".") i = 0 @@ -181,7 +190,16 @@ def find_code_in_transformers(object_name): _re_fill_pattern = re.compile(r"]*>") -def get_indent(code): +def get_indent(code: str) -> str: + """ + Find the indent in the first non empty line in a code sample. + + Args: + code (`str`): The code to inspect. + + Returns: + `str`: The indent looked at (as string). + """ lines = code.split("\n") idx = 0 while idx < len(lines) and len(lines[idx]) == 0: @@ -191,9 +209,15 @@ def get_indent(code): return "" -def blackify(code): +def blackify(code: str) -> str: """ - Applies the black part of our `make style` command to `code`. + Applies the black part of our `make style` command to some code. + + Args: + code (`str`): The code to format. + + Returns: + `str`: The formatted code. """ has_indent = len(get_indent(code)) > 0 if has_indent: @@ -204,14 +228,22 @@ def blackify(code): return result[len("class Bla:\n") :] if has_indent else result -def check_codes_match(observed_code, theoretical_code): +def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]: """ - Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name. - Returns the index of the first line where there is a difference (if any) and `None` if the codes match. + Checks if two version of a code match with the exception of the class/function name. + + Args: + observed_code (`str`): The code found. + theoretical_code (`str`): The code to match. + + Returns: + `Optional[int]`: The index of the first line where there is a difference (if any) and `None` if the codes + match. """ observed_code_header = observed_code.split("\n")[0] theoretical_code_header = theoretical_code.split("\n")[0] + # Catch the function/class name: it is expected that those do not match. _re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)") _re_func_match = re.compile(r"def\s+([^\(]+)\(") for re_pattern in [_re_class_match, _re_func_match]: @@ -220,6 +252,7 @@ def check_codes_match(observed_code, theoretical_code): theoretical_name = re_pattern.search(theoretical_code_header).groups()[0] theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name) + # Find the first diff. Line 0 is special since we need to compare with the function/class names ignored. diff_index = 0 if theoretical_code_header != observed_code_header: return 0 @@ -231,11 +264,19 @@ def check_codes_match(observed_code, theoretical_code): diff_index += 1 -def is_copy_consistent(filename, overwrite=False): +def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]: """ - Check if the code commented as a copy in `filename` matches the original. + Check if the code commented as a copy in a file matches the original. - Return the differences or overwrites the content depending on `overwrite`. + Args: + filename (`str`): + The name of the file to check. + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to overwrite the copies when they don't match. + + Returns: + `Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)` + with the name of the object having a diff and the line number where theere is the first diff. """ with open(filename, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() @@ -308,8 +349,12 @@ def is_copy_consistent(filename, overwrite=False): def check_copies(overwrite: bool = False): """ - Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the - model list in the main README and other READMEs/index.md are consistent. + Check every file is copy-consistent with the original. Also check the model list in the main README and other + READMEs/index.md are consistent. + + Args: + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to overwrite the copies when they don't match. """ all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) diffs = [] @@ -328,8 +373,11 @@ def check_copies(overwrite: bool = False): def check_full_copies(overwrite: bool = False): """ - Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe - `overwrite` to fix issues. + Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent. + + Args: + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to overwrite the copies when they don't match. """ diffs = [] for target, source in FULL_COPIES.items(): @@ -354,8 +402,18 @@ def check_full_copies(overwrite: bool = False): ) -def get_model_list(filename, start_prompt, end_prompt): - """Extracts the model list from a README, between `start_prompt` and `end_prompt`.""" +def get_model_list(filename: str, start_prompt: str, end_prompt: str) -> str: + """ + Extracts the model list from a README. + + Args: + filename (`str`): The name of the README file to check. + start_prompt (`str`): The string to look for that introduces the model list. + end_prompt (`str`): The string to look for that ends the model list. + + Returns: + `str`: The model list. + """ with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() # Find the start of the list. @@ -368,6 +426,7 @@ def get_model_list(filename, start_prompt, end_prompt): current_line = "" end_index = start_index + # Keep going until the end of the list. while not lines[end_index].startswith(end_prompt): if lines[end_index].startswith("1."): if len(current_line) > 1: @@ -382,7 +441,7 @@ def get_model_list(filename, start_prompt, end_prompt): return "".join(result) -def convert_to_localized_md(model_list, localized_model_list, format_str): +def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> Tuple[bool, str]: """ Compare the model list from the main README to the one in a localized README. @@ -458,19 +517,33 @@ def _rep(match): return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n" -def convert_readme_to_index(model_list): +def convert_readme_to_index(model_list: str) -> str: """ - Converts the model list of the README to the index.md format. + Converts the model list of the README to the index.md format (adapting links to the doc to relative links). + + Args: + model_list (`str`): The model list of the main README. + + Returns: + `str`: The model list in the format for the index. """ # We need to replce both link to the main doc and stable doc (the order of the next two instructions is important). model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "") return model_list.replace("https://huggingface.co/docs/transformers/", "") -def _find_text_in_file(filename, start_prompt, end_prompt): +def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]: """ - Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty - lines. + Find the text in a file between two prompts. + + Args: + filename (`str`): The name of the file to look into. + start_prompt (`str`): The string to look for that introduces the content looked for. + end_prompt (`str`): The string to look for that ends the content looked for. + + Returns: + Tuple[str, int, int, List[str]]: The content between the two prompts, the index of the start line in the + original file, the index of the end line in the original file and the list of lines of that file. """ with open(filename, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() @@ -493,9 +566,13 @@ def _find_text_in_file(filename, start_prompt, end_prompt): return "".join(lines[start_index:end_index]), start_index, end_index, lines -def check_model_list_copy(overwrite=False, max_per_line=119): +def check_model_list_copy(overwrite: bool = False): """ Check the model lists in the README is consistent with the ones in the other READMES and also with `index.nmd`. + + Args: + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to overwrite the copies when they don't match. """ # Fix potential doc links in the README with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f: @@ -526,6 +603,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119): end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"], ) + # Buld the converted Markdown. converted_md_lists = [] for filename, value in LOCALIZED_READMES.items(): _start_prompt = value["start_prompt"] @@ -537,6 +615,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119): converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt)) + # Build the converted index and compare it. converted_md_list = convert_readme_to_index(md_list) if converted_md_list != index_list: if overwrite: @@ -548,6 +627,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119): "`make fix-copies` to fix this." ) + # Compare the converted Markdowns for converted_md_list in converted_md_lists: filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list @@ -606,10 +686,13 @@ def check_model_list_copy(overwrite=False, max_per_line=119): ) -def check_readme(overwrite=False): +def check_readme(overwrite: bool = False): """ - Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the - missing models using `README_TEMPLATE`. + Check if the main README contains all the models in the library or not. + + Args: + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to add an entry for the missing models using `README_TEMPLATE`. """ info = LOCALIZED_READMES["README.md"] models, start_index, end_index, lines = _find_text_in_file( diff --git a/utils/check_doc_toc.py b/utils/check_doc_toc.py index 83c6be4795362b..ccbff5e0b648ee 100644 --- a/utils/check_doc_toc.py +++ b/utils/check_doc_toc.py @@ -34,6 +34,7 @@ import argparse from collections import defaultdict +from typing import List import yaml @@ -41,7 +42,7 @@ PATH_TO_TOC = "docs/source/en/_toctree.yml" -def clean_model_doc_toc(model_doc): +def clean_model_doc_toc(model_doc: List[dict]) -> List[dict]: """ Cleans a section of the table of content of the model documentation (one specific modality) by removing duplicates and sorting models alphabetically. @@ -77,7 +78,7 @@ def clean_model_doc_toc(model_doc): return sorted(new_doc, key=lambda s: s["title"].lower()) -def check_model_doc(overwrite=False): +def check_model_doc(overwrite: bool = False): """ Check that the content of the table of content in `_toctree.yml` is clean (no duplicates and sorted for the model API doc) and potentially auto-cleans it. diff --git a/utils/check_doctest_list.py b/utils/check_doctest_list.py index 3815a2bda0bab2..ee751bc279c4d8 100644 --- a/utils/check_doctest_list.py +++ b/utils/check_doctest_list.py @@ -40,7 +40,16 @@ DOCTEST_FILE_PATHS = ["documentation_tests.txt", "slow_documentation_tests.txt"] -def clean_doctest_list(doctest_file, overwrite=False): +def clean_doctest_list(doctest_file: str, overwrite: bool = False): + """ + Cleans the doctest in a given file. + + Args: + doctest_file (`str`): + The path to the doctest file to check or clean. + overwrite (`bool`, *optional*, defaults to `False`): + Whether or not to fix problems. If `False`, will error when the file is not clean. + """ non_existent_paths = [] all_paths = [] with open(doctest_file, "r", encoding="utf-8") as f: diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 39869e87fb67e7..a3ab6ebfa77b92 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -12,10 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +This script is responsible for making sure the dummies in utils/dummies_xxx.py are up to date with the main init. + +Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't +have the necessary extra libs installed. Those objects will then raise helpful error message whenever the user tries +to access one of their methods. + +Usage (from the root of the repo): + +Check that the dummy files are up to date (used in `make repo-consistency`): + +```bash +python utils/check_dummies.py +``` +Update the dummy files if needed (used in `make fix-copies`): + +```bash +python utils/check_dummies.py --fix_and_overwrite +``` +""" import argparse import os import re +from typing import Dict, List, Optional # All paths are set with the intent you should run this script from the root of the repo with the command @@ -26,13 +47,16 @@ _re_backend = re.compile(r"is\_([a-z_]*)_available()") # Matches from xxx import bla _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") +# Matches if not is_xxx_available() _re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)") +# Template for the dummy objects. DUMMY_CONSTANT = """ {0} = None """ + DUMMY_CLASS = """ class {0}(metaclass=DummyObject): _backends = {1} @@ -48,8 +72,18 @@ def {0}(*args, **kwargs): """ -def find_backend(line): - """Find one (or multiple) backend in a code line of the init.""" +def find_backend(line: str) -> Optional[str]: + """ + Find one (or multiple) backend in a code line of the init. + + Args: + line (`str`): A code line in an init file. + + Returns: + Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line + contains `if is_xxx_available() and `is_yyy_available()`) returns all backends joined on `_and_` (so + `xxx_and_yyy` for instance). + """ if _re_test_backend.search(line) is None: return None backends = [b[0] for b in _re_backend.findall(line)] @@ -57,8 +91,13 @@ def find_backend(line): return "_and_".join(backends) -def read_init(): - """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects.""" +def read_init() -> Dict[str, List[str]]: + """ + Read the init and extract backend-specific objects. + + Returns: + Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend. + """ with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() @@ -83,8 +122,10 @@ def read_init(): line = lines[line_index] single_line_import_search = _re_single_line_import.search(line) if single_line_import_search is not None: + # Single-line imports objects.extend(single_line_import_search.groups()[0].split(", ")) elif line.startswith(" " * 12): + # Multiple-line imports (with 3 indent level) objects.append(line[12:-2]) line_index += 1 @@ -95,8 +136,17 @@ def read_init(): return backend_specific_objects -def create_dummy_object(name, backend_name): - """Create the code for the dummy object corresponding to `name`.""" +def create_dummy_object(name: str, backend_name: str) -> str: + """ + Create the code for a dummy object. + + Args: + name (`str`): The name of the object. + backend_name (`str`): The name of the backend required for that object. + + Returns: + `str`: The code of the dummy object. + """ if name.isupper(): return DUMMY_CONSTANT.format(name) elif name.islower(): @@ -105,11 +155,21 @@ def create_dummy_object(name, backend_name): return DUMMY_CLASS.format(name, backend_name) -def create_dummy_files(backend_specific_objects=None): - """Create the content of the dummy files.""" +def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]: + """ + Create the content of the dummy files. + + Args: + backend_specific_objects (`Dict[str, List[str]]`, *optional*): + The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling + `read_init()`. + + Returns: + `Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file. + """ if backend_specific_objects is None: backend_specific_objects = read_init() - # For special correspondence backend to module name as used in the function requires_modulename + dummy_files = {} for backend, objects in backend_specific_objects.items(): @@ -122,10 +182,17 @@ def create_dummy_files(backend_specific_objects=None): return dummy_files -def check_dummies(overwrite=False): - """Check if the dummy files are up to date and maybe `overwrite` with the right content.""" +def check_dummies(overwrite: bool = False): + """ + Check if the dummy files are up to date and maybe `overwrite` with the right content. + + Args: + overwrite (`bool`, *optional*, default to `False`): + Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date + when `overwrite=False`. + """ dummy_files = create_dummy_files() - # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py + # For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py short_names = {"torch": "pt"} # Locate actual dummy modules and read their content. @@ -143,6 +210,7 @@ def check_dummies(overwrite=False): else: actual_dummies[backend] = "" + # Compare actual with what they should be. for backend in dummy_files.keys(): if dummy_files[backend] != actual_dummies[backend]: if overwrite: diff --git a/utils/check_inits.py b/utils/check_inits.py index 12b61223e42a85..43361adbf8f553 100644 --- a/utils/check_inits.py +++ b/utils/check_inits.py @@ -12,13 +12,37 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the +import of an object to when it's actually needed. This is to avoid the main init importing all models, which would +make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with +delayed imports have two halves: one definining a dictionary `_import_structure` which maps modules to the name of the +objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this +script is to check the objects defined in both halves are the same. + +This also checks the main init properly references all submodules, even if it doesn't import anything from them: every +submodule should be defined as a key of `_import_structure`, with an empty list as value potentially, or the submodule +won't be importable. + +Use from the root of the repo with: + +```bash +python utils/check_inits.py +``` + +for a check that will error in case of inconsistencies (used by `make repo-consistency`). + +There is no auto-fix possible here sadly :-( +""" import collections import os import re from pathlib import Path +from typing import Dict, List, Optional, Tuple +# Path is set with the intent you should run this script from the root of the repo. PATH_TO_TRANSFORMERS = "src/transformers" @@ -46,8 +70,18 @@ _re_else = re.compile(r"^\s*else:") -def find_backend(line): - """Find one (or multiple) backend in a code line of the init.""" +def find_backend(line: str) -> Optional[str]: + """ + Find one (or multiple) backend in a code line of the init. + + Args: + line (`str`): A code line of the main init. + + Returns: + Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line + contains `if is_xxx_available() and `is_yyy_available()`) returns all backends joined on `_and_` (so + `xxx_and_yyy` for instance). + """ if _re_test_backend.search(line) is None: return None backends = [b[0] for b in _re_backend.findall(line)] @@ -55,14 +89,23 @@ def find_backend(line): return "_and_".join(backends) -def parse_init(init_file): +def parse_init(init_file) -> Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]: """ - Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects - defined + Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects + defined. + + Args: + init_file (`str`): Path to the init file to inspect. + + Returns: + `Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends to list of + imported objects, one for the `_import_structure` part of the init and one for the `TYPE_CHECKING` part of the + init. Returns `None` if the init is not a custom init. """ with open(init_file, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() + # Get the to `_import_structure` definition. line_index = 0 while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"): line_index += 1 @@ -91,7 +134,9 @@ def parse_init(init_file): objects.append(line[9:-3]) line_index += 1 + # Those are stored with the key "none". import_dict_objects = {"none": objects} + # Let's continue with backend-specific objects in _import_structure while not lines[line_index].startswith("if TYPE_CHECKING"): # If the line is an if not is_backend_available, we grab all objects associated. @@ -151,6 +196,7 @@ def parse_init(init_file): line_index += 1 type_hint_objects = {"none": objects} + # Let's continue with backend-specific objects while line_index < len(lines): # If the line is an if is_backend_available, we grab all objects associated. @@ -186,19 +232,33 @@ def parse_init(init_file): return import_dict_objects, type_hint_objects -def analyze_results(import_dict_objects, type_hint_objects): +def analyze_results(import_dict_objects: Dict[str, List[str]], type_hint_objects: Dict[str, List[str]]) -> List[str]: """ Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init. + + Args: + import_dict_objects (`Dict[str, List[str]]`): + A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to + list of imported objects. + type_hint_objects (`Dict[str, List[str]]`): + A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to + list of imported objects. + + Returns: + `List[str]`: The list of errors corresponding to mismatches. """ def find_duplicates(seq): return [k for k, v in collections.Counter(seq).items() if v > 1] + # If one backend is missing from the other part of the init, error early. if list(import_dict_objects.keys()) != list(type_hint_objects.keys()): return ["Both sides of the init do not have the same backends!"] errors = [] + # Find all errors. for key in import_dict_objects.keys(): + # Duplicate imports in any half. duplicate_imports = find_duplicates(import_dict_objects[key]) if duplicate_imports: errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}") @@ -206,6 +266,7 @@ def find_duplicates(seq): if duplicate_type_hints: errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}") + # Missing imports in either part of the init. if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])): name = "base imports" if key == "none" else f"{key} backend" errors.append(f"Differences for {name}:") @@ -237,7 +298,7 @@ def check_all_inits(): raise ValueError("\n\n".join(failures)) -def get_transformers_submodules(): +def get_transformers_submodules() -> List[str]: """ Returns the list of Transformers submodules. """ @@ -272,6 +333,9 @@ def get_transformers_submodules(): def check_submodules(): + """ + Check all submodules of Transformers are properly registered in the main init. Error otherwise. + """ # This is to make sure the transformers module imported is the one in the repo. from transformers.utils import direct_transformers_import diff --git a/utils/check_repo.py b/utils/check_repo.py index 7af69519c68fea..8be8469465d56f 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -12,15 +12,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Utility that performs several consistency checks on the repo. This includes: +- checking all models are properly defined in the __init__ of models/ +- checking all models are in the main __init__ +- checking all models are properly tested +- checking all object in the main __init__ are documented +- checking all models are in at least one auto class +- checking all the auto mapping are properly defined (no typos, importable) +- checking the list of deprecated models is up to date + +Use from the root of the repo with (as used in `make repo-consistency`): + +```bash +python utils/check_repo.py +``` + +It has no auto-fix mode. +""" import inspect import os import re import sys +import types import warnings from collections import OrderedDict from difflib import get_close_matches from pathlib import Path +from typing import List, Tuple from transformers import is_flax_available, is_tf_available, is_torch_available from transformers.models.auto import get_values @@ -60,91 +79,25 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested "InstructBlipQFormerModel", # Building part of bigger (tested) model. - "NllbMoeDecoder", - "NllbMoeEncoder", "UMT5EncoderModel", # Building part of bigger (tested) model. - "LlamaDecoder", # Building part of bigger (tested) model. "Blip2QFormerModel", # Building part of bigger (tested) model. - "DetaEncoder", # Building part of bigger (tested) model. - "DetaDecoder", # Building part of bigger (tested) model. "ErnieMForInformationExtraction", - "GraphormerEncoder", # Building part of bigger (tested) model. "GraphormerDecoderHead", # Building part of bigger (tested) model. - "CLIPSegDecoder", # Building part of bigger (tested) model. - "TableTransformerEncoder", # Building part of bigger (tested) model. - "TableTransformerDecoder", # Building part of bigger (tested) model. - "TimeSeriesTransformerEncoder", # Building part of bigger (tested) model. - "TimeSeriesTransformerDecoder", # Building part of bigger (tested) model. - "InformerEncoder", # Building part of bigger (tested) model. - "InformerDecoder", # Building part of bigger (tested) model. - "AutoformerEncoder", # Building part of bigger (tested) model. - "AutoformerDecoder", # Building part of bigger (tested) model. "JukeboxVQVAE", # Building part of bigger (tested) model. "JukeboxPrior", # Building part of bigger (tested) model. - "DeformableDetrEncoder", # Building part of bigger (tested) model. - "DeformableDetrDecoder", # Building part of bigger (tested) model. - "OPTDecoder", # Building part of bigger (tested) model. - "FlaxWhisperDecoder", # Building part of bigger (tested) model. - "FlaxWhisperEncoder", # Building part of bigger (tested) model. - "WhisperDecoder", # Building part of bigger (tested) model. - "WhisperEncoder", # Building part of bigger (tested) model. "DecisionTransformerGPT2Model", # Building part of bigger (tested) model. "SegformerDecodeHead", # Building part of bigger (tested) model. - "PLBartEncoder", # Building part of bigger (tested) model. - "PLBartDecoder", # Building part of bigger (tested) model. - "PLBartDecoderWrapper", # Building part of bigger (tested) model. - "BigBirdPegasusEncoder", # Building part of bigger (tested) model. - "BigBirdPegasusDecoder", # Building part of bigger (tested) model. - "BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model. - "DetrEncoder", # Building part of bigger (tested) model. - "DetrDecoder", # Building part of bigger (tested) model. - "DetrDecoderWrapper", # Building part of bigger (tested) model. - "ConditionalDetrEncoder", # Building part of bigger (tested) model. - "ConditionalDetrDecoder", # Building part of bigger (tested) model. - "M2M100Encoder", # Building part of bigger (tested) model. - "M2M100Decoder", # Building part of bigger (tested) model. - "MCTCTEncoder", # Building part of bigger (tested) model. "MgpstrModel", # Building part of bigger (tested) model. - "Speech2TextEncoder", # Building part of bigger (tested) model. - "Speech2TextDecoder", # Building part of bigger (tested) model. - "LEDEncoder", # Building part of bigger (tested) model. - "LEDDecoder", # Building part of bigger (tested) model. - "BartDecoderWrapper", # Building part of bigger (tested) model. - "BartEncoder", # Building part of bigger (tested) model. "BertLMHeadModel", # Needs to be setup as decoder. - "BlenderbotSmallEncoder", # Building part of bigger (tested) model. - "BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model. - "BlenderbotEncoder", # Building part of bigger (tested) model. - "BlenderbotDecoderWrapper", # Building part of bigger (tested) model. - "MBartEncoder", # Building part of bigger (tested) model. - "MBartDecoderWrapper", # Building part of bigger (tested) model. "MegatronBertLMHeadModel", # Building part of bigger (tested) model. - "MegatronBertEncoder", # Building part of bigger (tested) model. - "MegatronBertDecoder", # Building part of bigger (tested) model. - "MegatronBertDecoderWrapper", # Building part of bigger (tested) model. - "MusicgenDecoder", # Building part of bigger (tested) model. - "MvpDecoderWrapper", # Building part of bigger (tested) model. - "MvpEncoder", # Building part of bigger (tested) model. - "PegasusEncoder", # Building part of bigger (tested) model. - "PegasusDecoderWrapper", # Building part of bigger (tested) model. - "PegasusXEncoder", # Building part of bigger (tested) model. - "PegasusXDecoder", # Building part of bigger (tested) model. - "PegasusXDecoderWrapper", # Building part of bigger (tested) model. - "DPREncoder", # Building part of bigger (tested) model. - "ProphetNetDecoderWrapper", # Building part of bigger (tested) model. "RealmBertModel", # Building part of bigger (tested) model. "RealmReader", # Not regular model. "RealmScorer", # Not regular model. "RealmForOpenQA", # Not regular model. "ReformerForMaskedLM", # Needs to be setup as decoder. - "Speech2Text2DecoderWrapper", # Building part of bigger (tested) model. - "TFDPREncoder", # Building part of bigger (tested) model. "TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?) "TFRobertaForMultipleChoice", # TODO: fix "TFRobertaPreLayerNormForMultipleChoice", # TODO: fix - "TrOCRDecoderWrapper", # Building part of bigger (tested) model. - "TFWhisperEncoder", # Building part of bigger (tested) model. - "TFWhisperDecoder", # Building part of bigger (tested) model. "SeparableConv1D", # Building part of bigger (tested) model. "FlaxBartForCausalLM", # Building part of bigger (tested) model. "FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM. @@ -155,18 +108,6 @@ "TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models "BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model. "BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model. - "SpeechT5Decoder", # Building part of bigger (tested) model. - "SpeechT5DecoderWithoutPrenet", # Building part of bigger (tested) model. - "SpeechT5DecoderWithSpeechPrenet", # Building part of bigger (tested) model. - "SpeechT5DecoderWithTextPrenet", # Building part of bigger (tested) model. - "SpeechT5Encoder", # Building part of bigger (tested) model. - "SpeechT5EncoderWithoutPrenet", # Building part of bigger (tested) model. - "SpeechT5EncoderWithSpeechPrenet", # Building part of bigger (tested) model. - "SpeechT5EncoderWithTextPrenet", # Building part of bigger (tested) model. - "SpeechT5SpeechDecoder", # Building part of bigger (tested) model. - "SpeechT5SpeechEncoder", # Building part of bigger (tested) model. - "SpeechT5TextDecoder", # Building part of bigger (tested) model. - "SpeechT5TextEncoder", # Building part of bigger (tested) model. "BarkCausalModel", # Building part of bigger (tested) model. "BarkModel", # Does not have a forward signature - generation tested with integration tests ] @@ -236,12 +177,6 @@ "AutoformerForPrediction", "JukeboxVQVAE", "JukeboxPrior", - "PegasusXEncoder", - "PegasusXDecoder", - "PegasusXDecoderWrapper", - "PegasusXEncoder", - "PegasusXDecoder", - "PegasusXDecoderWrapper", "SamModel", "DPTForDepthEstimation", "DecisionTransformerGPT2Model", @@ -250,17 +185,11 @@ "ViltForImageAndTextRetrieval", "ViltForTokenClassification", "ViltForMaskedLM", - "XGLMEncoder", - "XGLMDecoder", - "XGLMDecoderWrapper", "PerceiverForMultimodalAutoencoding", "PerceiverForOpticalFlow", "SegformerDecodeHead", "TFSegformerDecodeHead", "FlaxBeitForMaskedImageModeling", - "PLBartEncoder", - "PLBartDecoder", - "PLBartDecoderWrapper", "BeitForMaskedImageModeling", "ChineseCLIPTextModel", "ChineseCLIPVisionModel", @@ -347,7 +276,7 @@ ] # DO NOT edit this list! -# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove) +# (The corresponding pytorch objects should never have been in the main `__init__`, but it's too late to remove) OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [ "FlaxBertLayer", "FlaxBigBirdLayer", @@ -361,8 +290,7 @@ "TFViTMAELayer", ] -# Update this list for models that have multiple model types for the same -# model doc +# Update this list for models that have multiple model types for the same model doc. MODEL_TYPE_TO_DOC_MAPPING = OrderedDict( [ ("data2vec-text", "data2vec"), @@ -378,6 +306,10 @@ def check_missing_backends(): + """ + Checks if all backends are installed (otherwise the check of this script is incomplete). Will error in the CI if + that's not the case but only throw a warning for users running this. + """ missing_backends = [] if not is_torch_available(): missing_backends.append("PyTorch") @@ -402,7 +334,9 @@ def check_missing_backends(): def check_model_list(): - """Check the model list inside the transformers library.""" + """ + Checks the model listed as subfolders of `models` match the models available in `transformers.models`. + """ # Get the models from the directory structure of `src/transformers/models/` models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models") _models = [] @@ -413,7 +347,7 @@ def check_model_list(): if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir): _models.append(model) - # Get the models from the directory structure of `src/transformers/models/` + # Get the models in the submodule `transformers.models` models = [model for model in dir(transformers.models) if not model.startswith("__")] missing_models = sorted(set(_models).difference(models)) @@ -425,8 +359,8 @@ def check_model_list(): # If some modeling modules should be ignored for all checks, they should be added in the nested list # _ignore_modules of this function. -def get_model_modules(): - """Get the model modules inside the transformers library.""" +def get_model_modules() -> List[str]: + """Get all the model modules inside the transformers library (except deprecated models).""" _ignore_modules = [ "modeling_auto", "modeling_encoder_decoder", @@ -454,21 +388,32 @@ def get_model_modules(): ] modules = [] for model in dir(transformers.models): - if model == "deprecated": - continue # There are some magic dunder attributes in the dir, we ignore them - if not model.startswith("__"): - model_module = getattr(transformers.models, model) - for submodule in dir(model_module): - if submodule.startswith("modeling") and submodule not in _ignore_modules: - modeling_module = getattr(model_module, submodule) - if inspect.ismodule(modeling_module): - modules.append(modeling_module) + if model == "deprecated" or model.startswith("__"): + continue + + model_module = getattr(transformers.models, model) + for submodule in dir(model_module): + if submodule.startswith("modeling") and submodule not in _ignore_modules: + modeling_module = getattr(model_module, submodule) + if inspect.ismodule(modeling_module): + modules.append(modeling_module) return modules -def get_models(module, include_pretrained=False): - """Get the objects in module that are models.""" +def get_models(module: types.ModuleType, include_pretrained: bool = False) -> List[Tuple[str, type]]: + """ + Get the objects in a module that are models. + + Args: + module (`types.ModuleType`): + The module from which we are extracting models. + include_pretrained (`bool`, *optional*, defaults to `False`): + Whether or not to include the `PreTrainedModel` subclass (like `BertPreTrainedModel`) or not. + + Returns: + List[Tuple[str, type]]: List of models as tuples (class name, actual class). + """ models = [] model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel) for attr_name in dir(module): @@ -480,12 +425,10 @@ def get_models(module, include_pretrained=False): return models -def is_a_private_model(model): - """Returns True if the model should not be in the main init.""" - if model in PRIVATE_MODELS: - return True - - # Wrapper, Encoder and Decoder are all privates +def is_building_block(model: str) -> bool: + """ + Returns `True` if a model is a building block part of a bigger model. + """ if model.endswith("Wrapper"): return True if model.endswith("Encoder"): @@ -494,7 +437,13 @@ def is_a_private_model(model): return True if model.endswith("Prenet"): return True - return False + + +def is_a_private_model(model: str) -> bool: + """Returns `True` if the model should not be in the main init.""" + if model in PRIVATE_MODELS: + return True + return is_building_block(model) def check_models_are_in_init(): @@ -514,11 +463,14 @@ def check_models_are_in_init(): # If some test_modeling files should be ignored when checking models are all tested, they should be added in the # nested list _ignore_files of this function. -def get_model_test_files(): - """Get the model test files. +def get_model_test_files() -> List[str]: + """ + Get the model test files. - The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be - considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files. + Returns: + `List[str]`: The list of test files. The returned files will NOT contain the `tests` (i.e. `PATH_TO_TESTS` + defined in this script). They will be considered as paths relative to `tests`. A caller has to use + `os.path.join(PATH_TO_TESTS, ...)` to access the files. """ _ignore_files = [ @@ -531,7 +483,6 @@ def get_model_test_files(): "test_modeling_tf_encoder_decoder", ] test_files = [] - # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models` model_test_root = os.path.join(PATH_TO_TESTS, "models") model_test_dirs = [] for x in os.listdir(model_test_root): @@ -553,9 +504,17 @@ def get_model_test_files(): # This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class # for the all_model_classes variable. -def find_tested_models(test_file): - """Parse the content of test_file to detect what's in all_model_classes""" - # This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class +def find_tested_models(test_file: str) -> List[str]: + """ + Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from + the common test class. + + Args: + test_file (`str`): The path to the test file to check + + Returns: + `List[str]`: The list of models tested in that file. + """ with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f: content = f.read() all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content) @@ -571,8 +530,25 @@ def find_tested_models(test_file): return model_tested -def check_models_are_tested(module, test_file): - """Check models defined in module are tested in test_file.""" +def should_be_tested(model_name: str) -> bool: + """ + Whether or not a model should be tested. + """ + if model_name in IGNORE_NON_TESTED: + return False + return not is_building_block(model_name) + + +def check_models_are_tested(module: types.ModuleType, test_file: str) -> List[str]: + """Check models defined in a module are all tested in a given file. + + Args: + module (`types.ModuleType`): The module in which we get the models. + test_file (`str`): The path to the file where the module is tested. + + Returns: + `List[str]`: The list of error messages corresponding to models not tested. + """ # XxxPreTrainedModel are not tested defined_models = get_models(module) tested_models = find_tested_models(test_file) @@ -586,7 +562,7 @@ def check_models_are_tested(module, test_file): ] failures = [] for model_name, _ in defined_models: - if model_name not in tested_models and model_name not in IGNORE_NON_TESTED: + if model_name not in tested_models and should_be_tested(model_name): failures.append( f"{model_name} is defined in {module.__name__} but is not tested in " + f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file." @@ -602,6 +578,7 @@ def check_all_models_are_tested(): test_files = get_model_test_files() failures = [] for module in modules: + # Matches a module to its test file. test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file] if len(test_file) == 0: failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.") @@ -616,7 +593,7 @@ def check_all_models_are_tested(): raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) -def get_all_auto_configured_models(): +def get_all_auto_configured_models() -> List[str]: """Return the list of all models in at least one auto class.""" result = set() # To avoid duplicates we concatenate all model classes in a set. if is_torch_available(): @@ -634,8 +611,8 @@ def get_all_auto_configured_models(): return list(result) -def ignore_unautoclassed(model_name): - """Rules to determine if `name` should be in an auto class.""" +def ignore_unautoclassed(model_name: str) -> bool: + """Rules to determine if a model should be in an auto class.""" # Special white list if model_name in IGNORE_NON_AUTO_CONFIGURED: return True @@ -645,8 +622,19 @@ def ignore_unautoclassed(model_name): return False -def check_models_are_auto_configured(module, all_auto_models): - """Check models defined in module are each in an auto class.""" +def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: List[str]) -> List[str]: + """ + Check models defined in module are each in an auto class. + + Args: + module (`types.ModuleType`): + The module in which we get the models. + all_auto_models (`List[str]`): + The list of all models in an auto class (as obtained with `get_all_auto_configured_models()`). + + Returns: + `List[str]`: The list of error messages corresponding to models not tested. + """ defined_models = get_models(module) failures = [] for model_name, _ in defined_models: @@ -661,6 +649,7 @@ def check_models_are_auto_configured(module, all_auto_models): def check_all_models_are_auto_configured(): """Check all models are each in an auto class.""" + # This is where we need to check we have all backends or the check is incomplete. check_missing_backends() modules = get_model_modules() all_auto_models = get_all_auto_configured_models() @@ -675,6 +664,7 @@ def check_all_models_are_auto_configured(): def check_all_auto_object_names_being_defined(): """Check all names defined in auto (name) mappings exist in the library.""" + # This is where we need to check we have all backends or the check is incomplete. check_missing_backends() failures = [] @@ -695,7 +685,7 @@ def check_all_auto_object_names_being_defined(): mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) for name, mapping in mappings_to_check.items(): - for model_type, class_names in mapping.items(): + for _, class_names in mapping.items(): if not isinstance(class_names, tuple): class_names = (class_names,) for class_name in class_names: @@ -716,6 +706,7 @@ def check_all_auto_object_names_being_defined(): def check_all_auto_mapping_names_in_config_mapping_names(): """Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`.""" + # This is where we need to check we have all backends or the check is incomplete. check_missing_backends() failures = [] @@ -736,7 +727,7 @@ def check_all_auto_mapping_names_in_config_mapping_names(): mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) for name, mapping in mappings_to_check.items(): - for model_type, class_names in mapping.items(): + for model_type in mapping: if model_type not in CONFIG_MAPPING_NAMES: failures.append( f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of " @@ -747,7 +738,8 @@ def check_all_auto_mapping_names_in_config_mapping_names(): def check_all_auto_mappings_importable(): - """Check all auto mappings could be imported.""" + """Check all auto mappings can be imported.""" + # This is where we need to check we have all backends or the check is incomplete. check_missing_backends() failures = [] @@ -761,7 +753,7 @@ def check_all_auto_mappings_importable(): mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")] mappings_to_check.update({name: getattr(module, name) for name in mapping_names}) - for name, _ in mappings_to_check.items(): + for name in mappings_to_check: name = name.replace("_MAPPING_NAMES", "_MAPPING") if not hasattr(transformers, name): failures.append(f"`{name}`") @@ -770,44 +762,46 @@ def check_all_auto_mappings_importable(): def check_objects_being_equally_in_main_init(): - """Check if an object is in the main __init__ if its counterpart in PyTorch is.""" + """ + Check if a (TensorFlow or Flax) object is in the main __init__ iif its counterpart in PyTorch is. + """ attrs = dir(transformers) failures = [] for attr in attrs: obj = getattr(transformers, attr) - if hasattr(obj, "__module__"): - module_path = obj.__module__ - if "models.deprecated" in module_path: - continue - module_name = module_path.split(".")[-1] - module_dir = ".".join(module_path.split(".")[:-1]) - if ( - module_name.startswith("modeling_") - and not module_name.startswith("modeling_tf_") - and not module_name.startswith("modeling_flax_") - ): - parent_module = sys.modules[module_dir] - - frameworks = [] - if is_tf_available(): - frameworks.append("TF") - if is_flax_available(): - frameworks.append("Flax") - - for framework in frameworks: - other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_") - if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"): - other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_") - other_module = getattr(parent_module, other_module_name) - if hasattr(other_module, f"{framework}{attr}"): - if not hasattr(transformers, f"{framework}{attr}"): - if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: - failures.append(f"{framework}{attr}") - if hasattr(other_module, f"{framework}_{attr}"): - if not hasattr(transformers, f"{framework}_{attr}"): - if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: - failures.append(f"{framework}_{attr}") + if not hasattr(obj, "__module__") or "models.deprecated" in obj.__module__: + continue + + module_path = obj.__module__ + module_name = module_path.split(".")[-1] + module_dir = ".".join(module_path.split(".")[:-1]) + if ( + module_name.startswith("modeling_") + and not module_name.startswith("modeling_tf_") + and not module_name.startswith("modeling_flax_") + ): + parent_module = sys.modules[module_dir] + + frameworks = [] + if is_tf_available(): + frameworks.append("TF") + if is_flax_available(): + frameworks.append("Flax") + + for framework in frameworks: + other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_") + if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"): + other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_") + other_module = getattr(parent_module, other_module_name) + if hasattr(other_module, f"{framework}{attr}"): + if not hasattr(transformers, f"{framework}{attr}"): + if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: + failures.append(f"{framework}{attr}") + if hasattr(other_module, f"{framework}_{attr}"): + if not hasattr(transformers, f"{framework}_{attr}"): + if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK: + failures.append(f"{framework}_{attr}") if len(failures) > 0: raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures)) @@ -815,8 +809,16 @@ def check_objects_being_equally_in_main_init(): _re_decorator = re.compile(r"^\s*@(\S+)\s+$") -def check_decorator_order(filename): - """Check that in the test file `filename` the slow decorator is always last.""" +def check_decorator_order(filename: str) -> List[int]: + """ + Check that in a given test file, the slow decorator is always last. + + Args: + filename (`str`): The path to a test file to check. + + Returns: + `List[int]`: The list of failures as a list of indices where there are problems. + """ with open(filename, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() decorator_before = None @@ -849,8 +851,13 @@ def check_all_decorator_order(): ) -def find_all_documented_objects(): - """Parse the content of all doc files to detect which classes and functions it documents""" +def find_all_documented_objects() -> List[str]: + """ + Parse the content of all doc files to detect which classes and functions it documents. + + Returns: + `List[str]`: The list of all object names being documented. + """ documented_obj = [] for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"): with open(doc_file, "r", encoding="utf-8", newline="\n") as f: @@ -959,8 +966,8 @@ def find_all_documented_objects(): ] -def ignore_undocumented(name): - """Rules to determine if `name` should be undocumented.""" +def ignore_undocumented(name: str) -> bool: + """Rules to determine if `name` should be undocumented (returns `True` if it should not be documented).""" # NOT DOCUMENTED ON PURPOSE. # Constants uppercase are not documented. if name.isupper(): @@ -1047,7 +1054,7 @@ def check_model_type_doc_match(): _re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE) -def is_rst_docstring(docstring): +def is_rst_docstring(docstring: str) -> True: """ Returns `True` if `docstring` is written in rst. """ @@ -1061,7 +1068,7 @@ def is_rst_docstring(docstring): def check_docstrings_are_in_md(): - """Check all docstrings are in md""" + """Check all docstrings are written in md and nor rst.""" files_with_rst = [] for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"): with open(file, encoding="utf-8") as f: @@ -1084,6 +1091,9 @@ def check_docstrings_are_in_md(): def check_deprecated_constant_is_up_to_date(): + """ + Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date. + """ deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated") deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")]