Skip to content

Commit

Permalink
Doc checks (#25408)
Browse files Browse the repository at this point in the history
* Document check_dummies

* Type hints and doc in other files

* Document check inits

* Add documentation to

* Address review comments
  • Loading branch information
sgugger authored Aug 10, 2023
1 parent b14d464 commit 16edf4d
Show file tree
Hide file tree
Showing 6 changed files with 462 additions and 227 deletions.
139 changes: 111 additions & 28 deletions utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import glob
import os
import re
from typing import List, Optional, Tuple

import black
from doc_builder.style_doc import style_docstrings_in_code
Expand Down Expand Up @@ -125,14 +126,22 @@
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)


def _should_continue(line, indent):
def _should_continue(line: str, indent: str) -> bool:
# Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
# function definition
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None


def find_code_in_transformers(object_name):
"""Find and return the code source code of `object_name`."""
def find_code_in_transformers(object_name: str) -> str:
"""
Find and return the source code of an object.
Args:
object_name (`str`): The name of the object we want the source code of.
Returns:
`str`: The source code of the object.
"""
parts = object_name.split(".")
i = 0

Expand Down Expand Up @@ -181,7 +190,16 @@ def find_code_in_transformers(object_name):
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")


def get_indent(code):
def get_indent(code: str) -> str:
"""
Find the indent in the first non empty line in a code sample.
Args:
code (`str`): The code to inspect.
Returns:
`str`: The indent looked at (as string).
"""
lines = code.split("\n")
idx = 0
while idx < len(lines) and len(lines[idx]) == 0:
Expand All @@ -191,9 +209,15 @@ def get_indent(code):
return ""


def blackify(code):
def blackify(code: str) -> str:
"""
Applies the black part of our `make style` command to `code`.
Applies the black part of our `make style` command to some code.
Args:
code (`str`): The code to format.
Returns:
`str`: The formatted code.
"""
has_indent = len(get_indent(code)) > 0
if has_indent:
Expand All @@ -204,14 +228,22 @@ def blackify(code):
return result[len("class Bla:\n") :] if has_indent else result


def check_codes_match(observed_code, theoretical_code):
def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
"""
Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
Checks if two version of a code match with the exception of the class/function name.
Args:
observed_code (`str`): The code found.
theoretical_code (`str`): The code to match.
Returns:
`Optional[int]`: The index of the first line where there is a difference (if any) and `None` if the codes
match.
"""
observed_code_header = observed_code.split("\n")[0]
theoretical_code_header = theoretical_code.split("\n")[0]

# Catch the function/class name: it is expected that those do not match.
_re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
_re_func_match = re.compile(r"def\s+([^\(]+)\(")
for re_pattern in [_re_class_match, _re_func_match]:
Expand All @@ -220,6 +252,7 @@ def check_codes_match(observed_code, theoretical_code):
theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)

# Find the first diff. Line 0 is special since we need to compare with the function/class names ignored.
diff_index = 0
if theoretical_code_header != observed_code_header:
return 0
Expand All @@ -231,11 +264,19 @@ def check_codes_match(observed_code, theoretical_code):
diff_index += 1


def is_copy_consistent(filename, overwrite=False):
def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]:
"""
Check if the code commented as a copy in `filename` matches the original.
Check if the code commented as a copy in a file matches the original.
Return the differences or overwrites the content depending on `overwrite`.
Args:
filename (`str`):
The name of the file to check.
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
Returns:
`Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
with the name of the object having a diff and the line number where theere is the first diff.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
Expand Down Expand Up @@ -308,8 +349,12 @@ def is_copy_consistent(filename, overwrite=False):

def check_copies(overwrite: bool = False):
"""
Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
model list in the main README and other READMEs/index.md are consistent.
Check every file is copy-consistent with the original. Also check the model list in the main README and other
READMEs/index.md are consistent.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = []
Expand All @@ -328,8 +373,11 @@ def check_copies(overwrite: bool = False):

def check_full_copies(overwrite: bool = False):
"""
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
`overwrite` to fix issues.
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
"""
diffs = []
for target, source in FULL_COPIES.items():
Expand All @@ -354,8 +402,18 @@ def check_full_copies(overwrite: bool = False):
)


def get_model_list(filename, start_prompt, end_prompt):
"""Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
def get_model_list(filename: str, start_prompt: str, end_prompt: str) -> str:
"""
Extracts the model list from a README.
Args:
filename (`str`): The name of the README file to check.
start_prompt (`str`): The string to look for that introduces the model list.
end_prompt (`str`): The string to look for that ends the model list.
Returns:
`str`: The model list.
"""
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
Expand All @@ -368,6 +426,7 @@ def get_model_list(filename, start_prompt, end_prompt):
current_line = ""
end_index = start_index

# Keep going until the end of the list.
while not lines[end_index].startswith(end_prompt):
if lines[end_index].startswith("1."):
if len(current_line) > 1:
Expand All @@ -382,7 +441,7 @@ def get_model_list(filename, start_prompt, end_prompt):
return "".join(result)


def convert_to_localized_md(model_list, localized_model_list, format_str):
def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> Tuple[bool, str]:
"""
Compare the model list from the main README to the one in a localized README.
Expand Down Expand Up @@ -458,19 +517,33 @@ def _rep(match):
return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"


def convert_readme_to_index(model_list):
def convert_readme_to_index(model_list: str) -> str:
"""
Converts the model list of the README to the index.md format.
Converts the model list of the README to the index.md format (adapting links to the doc to relative links).
Args:
model_list (`str`): The model list of the main README.
Returns:
`str`: The model list in the format for the index.
"""
# We need to replce both link to the main doc and stable doc (the order of the next two instructions is important).
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
return model_list.replace("https://huggingface.co/docs/transformers/", "")


def _find_text_in_file(filename, start_prompt, end_prompt):
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
"""
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
lines.
Find the text in a file between two prompts.
Args:
filename (`str`): The name of the file to look into.
start_prompt (`str`): The string to look for that introduces the content looked for.
end_prompt (`str`): The string to look for that ends the content looked for.
Returns:
Tuple[str, int, int, List[str]]: The content between the two prompts, the index of the start line in the
original file, the index of the end line in the original file and the list of lines of that file.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
Expand All @@ -493,9 +566,13 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
return "".join(lines[start_index:end_index]), start_index, end_index, lines


def check_model_list_copy(overwrite=False, max_per_line=119):
def check_model_list_copy(overwrite: bool = False):
"""
Check the model lists in the README is consistent with the ones in the other READMES and also with `index.nmd`.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
"""
# Fix potential doc links in the README
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
Expand Down Expand Up @@ -526,6 +603,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"],
)

# Buld the converted Markdown.
converted_md_lists = []
for filename, value in LOCALIZED_READMES.items():
_start_prompt = value["start_prompt"]
Expand All @@ -537,6 +615,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):

converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))

# Build the converted index and compare it.
converted_md_list = convert_readme_to_index(md_list)
if converted_md_list != index_list:
if overwrite:
Expand All @@ -548,6 +627,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
"`make fix-copies` to fix this."
)

# Compare the converted Markdowns
for converted_md_list in converted_md_lists:
filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list

Expand Down Expand Up @@ -606,10 +686,13 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
)


def check_readme(overwrite=False):
def check_readme(overwrite: bool = False):
"""
Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
missing models using `README_TEMPLATE`.
Check if the main README contains all the models in the library or not.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to add an entry for the missing models using `README_TEMPLATE`.
"""
info = LOCALIZED_READMES["README.md"]
models, start_index, end_index, lines = _find_text_in_file(
Expand Down
5 changes: 3 additions & 2 deletions utils/check_doc_toc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@

import argparse
from collections import defaultdict
from typing import List

import yaml


PATH_TO_TOC = "docs/source/en/_toctree.yml"


def clean_model_doc_toc(model_doc):
def clean_model_doc_toc(model_doc: List[dict]) -> List[dict]:
"""
Cleans a section of the table of content of the model documentation (one specific modality) by removing duplicates
and sorting models alphabetically.
Expand Down Expand Up @@ -77,7 +78,7 @@ def clean_model_doc_toc(model_doc):
return sorted(new_doc, key=lambda s: s["title"].lower())


def check_model_doc(overwrite=False):
def check_model_doc(overwrite: bool = False):
"""
Check that the content of the table of content in `_toctree.yml` is clean (no duplicates and sorted for the model
API doc) and potentially auto-cleans it.
Expand Down
11 changes: 10 additions & 1 deletion utils/check_doctest_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,16 @@
DOCTEST_FILE_PATHS = ["documentation_tests.txt", "slow_documentation_tests.txt"]


def clean_doctest_list(doctest_file, overwrite=False):
def clean_doctest_list(doctest_file: str, overwrite: bool = False):
"""
Cleans the doctest in a given file.
Args:
doctest_file (`str`):
The path to the doctest file to check or clean.
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to fix problems. If `False`, will error when the file is not clean.
"""
non_existent_paths = []
all_paths = []
with open(doctest_file, "r", encoding="utf-8") as f:
Expand Down
Loading

0 comments on commit 16edf4d

Please sign in to comment.