Skip to content

Commit

Permalink
Strictly type everything in .github and tools (pytorch#59117)
Browse files Browse the repository at this point in the history
Summary:
This PR greatly simplifies `mypy-strict.ini` by strictly typing everything in `.github` and `tools`, rather than picking and choosing only specific files in those two dirs. It also removes `warn_unused_ignores` from `mypy-strict.ini`, for reasons described in pytorch#56402 (comment): basically, that setting makes life more difficult depending on what libraries you have installed locally vs in CI (e.g. `ruamel`).

Pull Request resolved: pytorch#59117

Test Plan:
```
flake8
mypy --config mypy-strict.ini
```

Reviewed By: malfet

Differential Revision: D28765386

Pulled By: samestep

fbshipit-source-id: 3e744e301c7a464f8a2a2428fcdbad534e231f2e
  • Loading branch information
samestep authored and facebook-github-bot committed Jun 7, 2021
1 parent 6ff001c commit 737d920
Show file tree
Hide file tree
Showing 43 changed files with 463 additions and 312 deletions.
6 changes: 3 additions & 3 deletions .github/scripts/ensure_actions_will_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
WORKFLOWS = REPO_ROOT / ".github" / "workflows"


def concurrency_key(filename):
def concurrency_key(filename: Path) -> str:
workflow_name = filename.with_suffix("").name.replace("_", "-")
return f"{workflow_name}-${{{{ github.event.pull_request.number || github.sha }}}}"


def should_check(filename):
def should_check(filename: Path) -> bool:
with open(filename, "r") as f:
content = f.read()

Expand All @@ -31,7 +31,7 @@ def should_check(filename):
)
args = parser.parse_args()

files = WORKFLOWS.glob("*.yml")
files = list(WORKFLOWS.glob("*.yml"))

errors_found = False
files = [f for f in files if should_check(f)]
Expand Down
21 changes: 13 additions & 8 deletions .github/scripts/generate_pytorch_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
class NoGitTagException(Exception):
pass

def get_pytorch_root():
def get_pytorch_root() -> Path:
return Path(subprocess.check_output(
['git', 'rev-parse', '--show-toplevel']
).decode('ascii').strip())

def get_tag():
def get_tag() -> str:
root = get_pytorch_root()
# We're on a tag
am_on_tag = (
Expand All @@ -46,37 +46,42 @@ def get_tag():
tag = re.sub(TRAILING_RC_PATTERN, "", tag)
return tag

def get_base_version():
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = open(root / 'version.txt', 'r').read().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
# first place
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)

class PytorchVersion:
def __init__(self, gpu_arch_type, gpu_arch_version, no_build_suffix):
def __init__(
self,
gpu_arch_type: str,
gpu_arch_version: str,
no_build_suffix: bool,
) -> None:
self.gpu_arch_type = gpu_arch_type
self.gpu_arch_version = gpu_arch_version
self.no_build_suffix = no_build_suffix

def get_post_build_suffix(self):
def get_post_build_suffix(self) -> str:
if self.gpu_arch_type == "cuda":
return f"+cu{self.gpu_arch_version.replace('.', '')}"
return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

def get_release_version(self):
def get_release_version(self) -> str:
if not get_tag():
raise NoGitTagException(
"Not on a git tag, are you sure you want a release version?"
)
return f"{get_tag()}{self.get_post_build_suffix()}"

def get_nightly_version(self):
def get_nightly_version(self) -> str:
date_str = datetime.today().strftime('%Y%m%d')
build_suffix = self.get_post_build_suffix()
return f"{get_base_version()}.dev{date_str}{build_suffix}"

def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate pytorch version for binary builds"
)
Expand Down
6 changes: 3 additions & 3 deletions .github/scripts/lint_native_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@
the YAML, not to be prescriptive about it.
'''

import ruamel.yaml
import ruamel.yaml # type: ignore[import]
import difflib
import sys
from pathlib import Path
from io import StringIO

def fn(base):
def fn(base: str) -> str:
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))

with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
contents = f.read()

yaml = ruamel.yaml.YAML()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
yaml.preserve_quotes = True
yaml.width = 1000
yaml.boolean_representation = ['False', 'True']
Expand Down
6 changes: 3 additions & 3 deletions .github/scripts/run_torchbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
timeout: 720
tests:"""

def gen_abtest_config(control: str, treatment: str, models: List[str]):
def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
d = {}
d["control"] = control
d["treatment"] = treatment
Expand All @@ -43,7 +43,7 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]):
config = config + "\n"
return config

def deploy_torchbench_config(output_dir: str, config: str):
def deploy_torchbench_config(output_dir: str, config: str) -> None:
# Create test dir if needed
pathlib.Path(output_dir).mkdir(exist_ok=True)
# TorchBench config file name
Expand Down Expand Up @@ -71,7 +71,7 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> List[str]:
return []
return model_list

def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str):
def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None:
# Copy system environment so that we will not override
env = dict(os.environ)
command = ["python", "bisection.py", "--work-dir", output_dir,
Expand Down
2 changes: 1 addition & 1 deletion caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TOOLS_PATH}/autograd/templates/python_linalg_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
"${TOOLS_PATH}/autograd/deprecated.yaml"
"${TOOLS_PATH}/autograd/derivatives.yaml"
"${TOOLS_PATH}/autograd/gen_autograd_functions.py"
Expand Down
28 changes: 2 additions & 26 deletions mypy-strict.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
# this config file to be used to ENFORCE that people are using mypy on codegen
# files.

# For now, only code_template.py and benchmark utils Timer are covered this way

[mypy]
python_version = 3.6
plugins = mypy_plugins/check_mypy_version.py
Expand All @@ -30,36 +28,14 @@ check_untyped_defs = True
disallow_untyped_decorators = True
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
implicit_reexport = False
strict_equality = True

files =
.github/scripts/generate_binary_build_matrix.py,
.github/scripts/generate_ci_workflows.py,
.github/scripts/parse_ref.py,
.github,
benchmarks/instruction_counts,
tools/actions_local_runner.py,
tools/autograd/*.py,
tools/clang_tidy.py,
tools/codegen,
tools/explicit_ci_jobs.py,
tools/extract_scripts.py,
tools/mypy_wrapper.py,
tools/print_test_stats.py,
tools/pyi,
tools/stats_utils,
tools/test_history.py,
tools/test/test_actions_local_runner.py,
tools/test/test_extract_scripts.py,
tools/test/test_mypy_wrapper.py,
tools/test/test_test_history.py,
tools/test/test_trailing_newlines.py,
tools/test/test_translate_annotations.py,
tools/trailing_newlines.py,
tools/translate_annotations.py,
tools/vscode_settings.py,
tools,
torch/testing/_internal/framework_utils.py,
torch/utils/_pytree.py,
torch/utils/benchmark/utils/common.py,
Expand Down
4 changes: 2 additions & 2 deletions tools/amd_build/build_amd.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
'torch',
'utils')))

from hipify import hipify_python
from hipify import hipify_python # type: ignore[import]

parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters')
parser.add_argument(
Expand Down Expand Up @@ -115,7 +115,7 @@
]

# Check if the compiler is hip-clang.
def is_hip_clang():
def is_hip_clang() -> bool:
try:
hip_path = os.getenv('HIP_PATH', '/opt/rocm/hip')
return 'HIP_COMPILER=clang' in open(hip_path + '/lib/.hipInfo').read()
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/gen_annotated_fn_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:

template_path = os.path.join(autograd_dir, 'templates')
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py', lambda: {
fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py.in', lambda: {
'annotated_args': textwrap.indent('\n'.join(annotated_args), ' '),
})

Expand Down
20 changes: 14 additions & 6 deletions tools/build_pytorch_libs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import os
from glob import glob
import shutil
from typing import Dict, Optional

from .setup_helpers.env import IS_64BIT, IS_WINDOWS, check_negative_env_flag
from .setup_helpers.cmake import USE_NINJA
from .setup_helpers.cmake import USE_NINJA, CMake

from setuptools import distutils
from setuptools import distutils # type: ignore[import]

def _overlay_windows_vcvars(env):
def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
vc_arch = 'x64' if IS_64BIT else 'x86'
vc_env = distutils._msvccompiler._get_vc_env(vc_arch)
vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
# Keys in `_get_vc_env` are always lowercase.
# We turn them into uppercase before overlaying vcvars
# because OS environ keys are always uppercase on Windows.
Expand All @@ -22,7 +23,7 @@ def _overlay_windows_vcvars(env):
return vc_env


def _create_build_env():
def _create_build_env() -> Dict[str, str]:
# XXX - our cmake file sometimes looks at the system environment
# and not cmake flags!
# you should NEVER add something to this list. It is bad practice to
Expand All @@ -44,7 +45,14 @@ def _create_build_env():
return my_env


def build_caffe2(version, cmake_python_library, build_python, rerun_cmake, cmake_only, cmake):
def build_caffe2(
version: Optional[str],
cmake_python_library: Optional[str],
build_python: bool,
rerun_cmake: bool,
cmake_only: bool,
cmake: CMake,
) -> None:
my_env = _create_build_env()
build_test = not check_negative_env_flag('BUILD_TEST')
cmake.generate(version,
Expand Down
28 changes: 21 additions & 7 deletions tools/clang_format_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import re
import os
import sys
from clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH
from typing import List, Set

from .clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH

# Allowlist of directories to check. All files that in that directory
# (recursively) will be checked.
Expand All @@ -28,7 +30,7 @@
CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$")


def get_allowlisted_files():
def get_allowlisted_files() -> Set[str]:
"""
Parse CLANG_FORMAT_ALLOWLIST and resolve all directories.
Returns the set of allowlist cpp source files.
Expand All @@ -42,7 +44,11 @@ def get_allowlisted_files():
return set(matches)


async def run_clang_format_on_file(filename, semaphore, verbose=False):
async def run_clang_format_on_file(
filename: str,
semaphore: asyncio.Semaphore,
verbose: bool = False,
) -> None:
"""
Run clang-format on the provided file.
"""
Expand All @@ -55,7 +61,11 @@ async def run_clang_format_on_file(filename, semaphore, verbose=False):
print("Formatted {}".format(filename))


async def file_clang_formatted_correctly(filename, semaphore, verbose=False):
async def file_clang_formatted_correctly(
filename: str,
semaphore: asyncio.Semaphore,
verbose: bool = False,
) -> bool:
"""
Checks if a file is formatted correctly and returns True if so.
"""
Expand All @@ -80,7 +90,11 @@ async def file_clang_formatted_correctly(filename, semaphore, verbose=False):
return ok


async def run_clang_format(max_processes, diff=False, verbose=False):
async def run_clang_format(
max_processes: int,
diff: bool = False,
verbose: bool = False,
) -> bool:
"""
Run clang-format to all files in CLANG_FORMAT_ALLOWLIST that match CPP_FILE_REGEX.
"""
Expand Down Expand Up @@ -114,7 +128,7 @@ async def run_clang_format(max_processes, diff=False, verbose=False):

return ok

def parse_args(args):
def parse_args(args: List[str]) -> argparse.Namespace:
"""
Parse and return command-line arguments.
"""
Expand All @@ -134,7 +148,7 @@ def parse_args(args):
return parser.parse_args(args)


def main(args):
def main(args: List[str]) -> bool:
# Parse arguments.
options = parse_args(args)
# Get clang-format and make sure it is the right binary and it is in the right place.
Expand Down
6 changes: 3 additions & 3 deletions tools/clang_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def compute_file_sha256(path: str) -> str:
return hash.hexdigest()


def report_download_progress(chunk_number, chunk_size, file_size):
def report_download_progress(chunk_number: int, chunk_size: int, file_size: int) -> None:
"""
Pretty printer for file download progress.
"""
Expand All @@ -55,7 +55,7 @@ def report_download_progress(chunk_number, chunk_size, file_size):
sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))


def download_clang_format(path):
def download_clang_format(path: str) -> bool:
"""
Downloads a clang-format binary appropriate for the host platform and stores it at the given location.
"""
Expand All @@ -81,7 +81,7 @@ def download_clang_format(path):
return True


def get_and_check_clang_format(verbose=False):
def get_and_check_clang_format(verbose: bool = False) -> bool:
"""
Download a platform-appropriate clang-format binary if one doesn't already exist at the expected location and verify
that it is the right binary by checking its SHA256 hash against the expected hash.
Expand Down
Loading

0 comments on commit 737d920

Please sign in to comment.