diff --git a/.github/scripts/ensure_actions_will_cancel.py b/.github/scripts/ensure_actions_will_cancel.py index 14e90e2552061f..e161239a8c0e2c 100755 --- a/.github/scripts/ensure_actions_will_cancel.py +++ b/.github/scripts/ensure_actions_will_cancel.py @@ -11,12 +11,12 @@ WORKFLOWS = REPO_ROOT / ".github" / "workflows" -def concurrency_key(filename): +def concurrency_key(filename: Path) -> str: workflow_name = filename.with_suffix("").name.replace("_", "-") return f"{workflow_name}-${{{{ github.event.pull_request.number || github.sha }}}}" -def should_check(filename): +def should_check(filename: Path) -> bool: with open(filename, "r") as f: content = f.read() @@ -31,7 +31,7 @@ def should_check(filename): ) args = parser.parse_args() - files = WORKFLOWS.glob("*.yml") + files = list(WORKFLOWS.glob("*.yml")) errors_found = False files = [f for f in files if should_check(f)] diff --git a/.github/scripts/generate_pytorch_version.py b/.github/scripts/generate_pytorch_version.py index 42f26737184788..44ed5e5ba08d3a 100755 --- a/.github/scripts/generate_pytorch_version.py +++ b/.github/scripts/generate_pytorch_version.py @@ -16,12 +16,12 @@ class NoGitTagException(Exception): pass -def get_pytorch_root(): +def get_pytorch_root() -> Path: return Path(subprocess.check_output( ['git', 'rev-parse', '--show-toplevel'] ).decode('ascii').strip()) -def get_tag(): +def get_tag() -> str: root = get_pytorch_root() # We're on a tag am_on_tag = ( @@ -46,7 +46,7 @@ def get_tag(): tag = re.sub(TRAILING_RC_PATTERN, "", tag) return tag -def get_base_version(): +def get_base_version() -> str: root = get_pytorch_root() dirty_version = open(root / 'version.txt', 'r').read().strip() # Strips trailing a0 from version.txt, not too sure why it's there in the @@ -54,29 +54,34 @@ def get_base_version(): return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version) class PytorchVersion: - def __init__(self, gpu_arch_type, gpu_arch_version, no_build_suffix): + def __init__( + self, + gpu_arch_type: str, + gpu_arch_version: str, + no_build_suffix: bool, + ) -> None: self.gpu_arch_type = gpu_arch_type self.gpu_arch_version = gpu_arch_version self.no_build_suffix = no_build_suffix - def get_post_build_suffix(self): + def get_post_build_suffix(self) -> str: if self.gpu_arch_type == "cuda": return f"+cu{self.gpu_arch_version.replace('.', '')}" return f"+{self.gpu_arch_type}{self.gpu_arch_version}" - def get_release_version(self): + def get_release_version(self) -> str: if not get_tag(): raise NoGitTagException( "Not on a git tag, are you sure you want a release version?" ) return f"{get_tag()}{self.get_post_build_suffix()}" - def get_nightly_version(self): + def get_nightly_version(self) -> str: date_str = datetime.today().strftime('%Y%m%d') build_suffix = self.get_post_build_suffix() return f"{get_base_version()}.dev{date_str}{build_suffix}" -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Generate pytorch version for binary builds" ) diff --git a/.github/scripts/lint_native_functions.py b/.github/scripts/lint_native_functions.py index 9b4ef4e183acd0..2e6d4e3e767570 100755 --- a/.github/scripts/lint_native_functions.py +++ b/.github/scripts/lint_native_functions.py @@ -14,19 +14,19 @@ the YAML, not to be prescriptive about it. ''' -import ruamel.yaml +import ruamel.yaml # type: ignore[import] import difflib import sys from pathlib import Path from io import StringIO -def fn(base): +def fn(base: str) -> str: return str(base / Path("aten/src/ATen/native/native_functions.yaml")) with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f: contents = f.read() -yaml = ruamel.yaml.YAML() +yaml = ruamel.yaml.YAML() # type: ignore[attr-defined] yaml.preserve_quotes = True yaml.width = 1000 yaml.boolean_representation = ['False', 'True'] diff --git a/.github/scripts/run_torchbench.py b/.github/scripts/run_torchbench.py index 8baa6295cead8a..b3c7ee2be46054 100644 --- a/.github/scripts/run_torchbench.py +++ b/.github/scripts/run_torchbench.py @@ -31,7 +31,7 @@ timeout: 720 tests:""" -def gen_abtest_config(control: str, treatment: str, models: List[str]): +def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str: d = {} d["control"] = control d["treatment"] = treatment @@ -43,7 +43,7 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]): config = config + "\n" return config -def deploy_torchbench_config(output_dir: str, config: str): +def deploy_torchbench_config(output_dir: str, config: str) -> None: # Create test dir if needed pathlib.Path(output_dir).mkdir(exist_ok=True) # TorchBench config file name @@ -71,7 +71,7 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> List[str]: return [] return model_list -def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str): +def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None: # Copy system environment so that we will not override env = dict(os.environ) command = ["python", "bisection.py", "--work-dir", output_dir, diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 827aa29f4c94bb..7e3c87a1c80429 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -492,7 +492,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) "${TOOLS_PATH}/autograd/templates/python_linalg_functions.cpp" "${TOOLS_PATH}/autograd/templates/python_special_functions.cpp" "${TOOLS_PATH}/autograd/templates/variable_factories.h" - "${TOOLS_PATH}/autograd/templates/annotated_fn_args.py" + "${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in" "${TOOLS_PATH}/autograd/deprecated.yaml" "${TOOLS_PATH}/autograd/derivatives.yaml" "${TOOLS_PATH}/autograd/gen_autograd_functions.py" diff --git a/mypy-strict.ini b/mypy-strict.ini index 3ac5b6055f3bd0..cb8ef8f59c30ef 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -5,8 +5,6 @@ # this config file to be used to ENFORCE that people are using mypy on codegen # files. -# For now, only code_template.py and benchmark utils Timer are covered this way - [mypy] python_version = 3.6 plugins = mypy_plugins/check_mypy_version.py @@ -30,36 +28,14 @@ check_untyped_defs = True disallow_untyped_decorators = True no_implicit_optional = True warn_redundant_casts = True -warn_unused_ignores = True warn_return_any = True implicit_reexport = False strict_equality = True files = - .github/scripts/generate_binary_build_matrix.py, - .github/scripts/generate_ci_workflows.py, - .github/scripts/parse_ref.py, + .github, benchmarks/instruction_counts, - tools/actions_local_runner.py, - tools/autograd/*.py, - tools/clang_tidy.py, - tools/codegen, - tools/explicit_ci_jobs.py, - tools/extract_scripts.py, - tools/mypy_wrapper.py, - tools/print_test_stats.py, - tools/pyi, - tools/stats_utils, - tools/test_history.py, - tools/test/test_actions_local_runner.py, - tools/test/test_extract_scripts.py, - tools/test/test_mypy_wrapper.py, - tools/test/test_test_history.py, - tools/test/test_trailing_newlines.py, - tools/test/test_translate_annotations.py, - tools/trailing_newlines.py, - tools/translate_annotations.py, - tools/vscode_settings.py, + tools, torch/testing/_internal/framework_utils.py, torch/utils/_pytree.py, torch/utils/benchmark/utils/common.py, diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 870e1a8d0b710e..8cfecda82e3284 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -12,7 +12,7 @@ 'torch', 'utils'))) -from hipify import hipify_python +from hipify import hipify_python # type: ignore[import] parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters') parser.add_argument( @@ -115,7 +115,7 @@ ] # Check if the compiler is hip-clang. -def is_hip_clang(): +def is_hip_clang() -> bool: try: hip_path = os.getenv('HIP_PATH', '/opt/rocm/hip') return 'HIP_COMPILER=clang' in open(hip_path + '/lib/.hipInfo').read() diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 34bcfb760e8f40..a38918171c5886 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -48,7 +48,7 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None: template_path = os.path.join(autograd_dir, 'templates') fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) - fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py', lambda: { + fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py.in', lambda: { 'annotated_args': textwrap.indent('\n'.join(annotated_args), ' '), }) diff --git a/tools/autograd/templates/annotated_fn_args.py b/tools/autograd/templates/annotated_fn_args.py.in similarity index 100% rename from tools/autograd/templates/annotated_fn_args.py rename to tools/autograd/templates/annotated_fn_args.py.in diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index a29a7e0557c3b2..d795770c884442 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -1,15 +1,16 @@ import os from glob import glob import shutil +from typing import Dict, Optional from .setup_helpers.env import IS_64BIT, IS_WINDOWS, check_negative_env_flag -from .setup_helpers.cmake import USE_NINJA +from .setup_helpers.cmake import USE_NINJA, CMake -from setuptools import distutils +from setuptools import distutils # type: ignore[import] -def _overlay_windows_vcvars(env): +def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]: vc_arch = 'x64' if IS_64BIT else 'x86' - vc_env = distutils._msvccompiler._get_vc_env(vc_arch) + vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch) # Keys in `_get_vc_env` are always lowercase. # We turn them into uppercase before overlaying vcvars # because OS environ keys are always uppercase on Windows. @@ -22,7 +23,7 @@ def _overlay_windows_vcvars(env): return vc_env -def _create_build_env(): +def _create_build_env() -> Dict[str, str]: # XXX - our cmake file sometimes looks at the system environment # and not cmake flags! # you should NEVER add something to this list. It is bad practice to @@ -44,7 +45,14 @@ def _create_build_env(): return my_env -def build_caffe2(version, cmake_python_library, build_python, rerun_cmake, cmake_only, cmake): +def build_caffe2( + version: Optional[str], + cmake_python_library: Optional[str], + build_python: bool, + rerun_cmake: bool, + cmake_only: bool, + cmake: CMake, +) -> None: my_env = _create_build_env() build_test = not check_negative_env_flag('BUILD_TEST') cmake.generate(version, diff --git a/tools/clang_format_all.py b/tools/clang_format_all.py index e09659333d803f..7792f15a77d126 100755 --- a/tools/clang_format_all.py +++ b/tools/clang_format_all.py @@ -12,7 +12,9 @@ import re import os import sys -from clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH +from typing import List, Set + +from .clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH # Allowlist of directories to check. All files that in that directory # (recursively) will be checked. @@ -28,7 +30,7 @@ CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$") -def get_allowlisted_files(): +def get_allowlisted_files() -> Set[str]: """ Parse CLANG_FORMAT_ALLOWLIST and resolve all directories. Returns the set of allowlist cpp source files. @@ -42,7 +44,11 @@ def get_allowlisted_files(): return set(matches) -async def run_clang_format_on_file(filename, semaphore, verbose=False): +async def run_clang_format_on_file( + filename: str, + semaphore: asyncio.Semaphore, + verbose: bool = False, +) -> None: """ Run clang-format on the provided file. """ @@ -55,7 +61,11 @@ async def run_clang_format_on_file(filename, semaphore, verbose=False): print("Formatted {}".format(filename)) -async def file_clang_formatted_correctly(filename, semaphore, verbose=False): +async def file_clang_formatted_correctly( + filename: str, + semaphore: asyncio.Semaphore, + verbose: bool = False, +) -> bool: """ Checks if a file is formatted correctly and returns True if so. """ @@ -80,7 +90,11 @@ async def file_clang_formatted_correctly(filename, semaphore, verbose=False): return ok -async def run_clang_format(max_processes, diff=False, verbose=False): +async def run_clang_format( + max_processes: int, + diff: bool = False, + verbose: bool = False, +) -> bool: """ Run clang-format to all files in CLANG_FORMAT_ALLOWLIST that match CPP_FILE_REGEX. """ @@ -114,7 +128,7 @@ async def run_clang_format(max_processes, diff=False, verbose=False): return ok -def parse_args(args): +def parse_args(args: List[str]) -> argparse.Namespace: """ Parse and return command-line arguments. """ @@ -134,7 +148,7 @@ def parse_args(args): return parser.parse_args(args) -def main(args): +def main(args: List[str]) -> bool: # Parse arguments. options = parse_args(args) # Get clang-format and make sure it is the right binary and it is in the right place. diff --git a/tools/clang_format_utils.py b/tools/clang_format_utils.py index 36427ea149a127..a1f621ceb939ae 100644 --- a/tools/clang_format_utils.py +++ b/tools/clang_format_utils.py @@ -45,7 +45,7 @@ def compute_file_sha256(path: str) -> str: return hash.hexdigest() -def report_download_progress(chunk_number, chunk_size, file_size): +def report_download_progress(chunk_number: int, chunk_size: int, file_size: int) -> None: """ Pretty printer for file download progress. """ @@ -55,7 +55,7 @@ def report_download_progress(chunk_number, chunk_size, file_size): sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100))) -def download_clang_format(path): +def download_clang_format(path: str) -> bool: """ Downloads a clang-format binary appropriate for the host platform and stores it at the given location. """ @@ -81,7 +81,7 @@ def download_clang_format(path): return True -def get_and_check_clang_format(verbose=False): +def get_and_check_clang_format(verbose: bool = False) -> bool: """ Download a platform-appropriate clang-format binary if one doesn't already exist at the expected location and verify that it is the right binary by checking its SHA256 hash against the expected hash. diff --git a/tools/code_analyzer/gen_op_registration_allowlist.py b/tools/code_analyzer/gen_op_registration_allowlist.py index 04a58c8f522b9a..c4138c8212b015 100644 --- a/tools/code_analyzer/gen_op_registration_allowlist.py +++ b/tools/code_analyzer/gen_op_registration_allowlist.py @@ -12,14 +12,18 @@ import yaml from collections import defaultdict +from typing import Dict, List, Set -def canonical_name(opname): +def canonical_name(opname: str) -> str: # Skip the overload name part as it's not supported by code analyzer yet. return opname.split('.', 1)[0] -def load_op_dep_graph(fname): +DepGraph = Dict[str, Set[str]] + + +def load_op_dep_graph(fname: str) -> DepGraph: with open(fname, 'r') as stream: result = defaultdict(set) for op in yaml.safe_load(stream): @@ -27,10 +31,10 @@ def load_op_dep_graph(fname): for dep in op.get('depends', []): dep_name = canonical_name(dep['name']) result[op_name].add(dep_name) - return result + return dict(result) -def load_root_ops(fname): +def load_root_ops(fname: str) -> List[str]: result = [] with open(fname, 'r') as stream: for op in yaml.safe_load(stream): @@ -38,7 +42,11 @@ def load_root_ops(fname): return result -def gen_transitive_closure(dep_graph, root_ops, train=False): +def gen_transitive_closure( + dep_graph: DepGraph, + root_ops: List[str], + train: bool = False, +) -> List[str]: result = set(root_ops) queue = root_ops[:] @@ -64,7 +72,7 @@ def gen_transitive_closure(dep_graph, root_ops, train=False): return sorted(result) -def gen_transitive_closure_str(dep_graph, root_ops): +def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str: return ' '.join(gen_transitive_closure(dep_graph, root_ops)) diff --git a/tools/code_analyzer/op_deps_processor.py b/tools/code_analyzer/op_deps_processor.py index 6978ce75ec1789..9623dc673dc348 100644 --- a/tools/code_analyzer/op_deps_processor.py +++ b/tools/code_analyzer/op_deps_processor.py @@ -11,6 +11,7 @@ import argparse import yaml +from typing import Any, List from tools.codegen.code_template import CodeTemplate @@ -46,12 +47,12 @@ """) -def load_op_deps(fname): +def load_op_deps(fname: str) -> Any: with open(fname, 'r') as stream: return yaml.safe_load(stream) -def process_base_ops(graph, base_ops): +def process_base_ops(graph: Any, base_ops: List[str]) -> None: # remove base ops from all `depends` lists to compress the output graph for op in graph: op['depends'] = [ @@ -64,7 +65,13 @@ def process_base_ops(graph, base_ops): 'depends': [{'name': name} for name in base_ops]}) -def convert(fname, graph, output_template, op_template, op_dep_template): +def convert( + fname: str, + graph: Any, + output_template: CodeTemplate, + op_template: CodeTemplate, + op_dep_template: CodeTemplate, +) -> None: ops = [] for op in graph: op_name = op['name'] diff --git a/tools/code_coverage/package/oss/cov_json.py b/tools/code_coverage/package/oss/cov_json.py index 618075f97fd8bc..8b987d93a66a93 100644 --- a/tools/code_coverage/package/oss/cov_json.py +++ b/tools/code_coverage/package/oss/cov_json.py @@ -1,11 +1,11 @@ from ..tool import clang_coverage from ..util.setting import CompilerType, Option, TestList, TestPlatform from ..util.utils import check_compiler_type -from .init import detect_compiler_type +from .init import detect_compiler_type # type: ignore[attr-defined] from .run import clang_run, gcc_run -def get_json_report(test_list: TestList, options: Option): +def get_json_report(test_list: TestList, options: Option) -> None: cov_type = detect_compiler_type() check_compiler_type(cov_type) if cov_type == CompilerType.CLANG: diff --git a/tools/code_coverage/package/oss/init.py b/tools/code_coverage/package/oss/init.py index 33eceed125b900..beef653cb6ff3f 100644 --- a/tools/code_coverage/package/oss/init.py +++ b/tools/code_coverage/package/oss/init.py @@ -1,6 +1,6 @@ import argparse import os -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, cast from ..util.setting import ( JSON_FOLDER_BASE_DIR, @@ -129,7 +129,7 @@ def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]: return arg_interested_folder -def gcc_export_init(): +def gcc_export_init() -> None: remove_folder(JSON_FOLDER_BASE_DIR) create_folder(JSON_FOLDER_BASE_DIR) @@ -161,7 +161,7 @@ def print_init_info() -> None: print_log("pytorch folder: ", get_pytorch_folder()) print_log("cpp test binaries folder: ", get_oss_binary_folder(TestType.CPP)) print_log("python test scripts folder: ", get_oss_binary_folder(TestType.PY)) - print_log("compiler type: ", detect_compiler_type().value) + print_log("compiler type: ", cast(CompilerType, detect_compiler_type()).value) print_log( "llvm tool folder (only for clang, if you are using gcov please ignore it): ", get_llvm_tool_path(), diff --git a/tools/code_coverage/package/oss/utils.py b/tools/code_coverage/package/oss/utils.py index a285a14be37ba1..739aa6e3910c03 100644 --- a/tools/code_coverage/package/oss/utils.py +++ b/tools/code_coverage/package/oss/utils.py @@ -82,8 +82,7 @@ def get_gcda_files() -> List[str]: # TODO use glob # output = glob.glob(f"{folder_has_gcda}/**/*.gcda") output = subprocess.check_output(["find", folder_has_gcda, "-iname", "*.gcda"]) - output = output.decode("utf-8").split("\n") - return output + return output.decode("utf-8").split("\n") else: return [] diff --git a/tools/code_coverage/package/tool/clang_coverage.py b/tools/code_coverage/package/tool/clang_coverage.py index 6daca38b97146a..1d1ebff6ae1f96 100644 --- a/tools/code_coverage/package/tool/clang_coverage.py +++ b/tools/code_coverage/package/tool/clang_coverage.py @@ -148,7 +148,7 @@ def export(test_list: TestList, platform_type: TestPlatform) -> None: binary_file = "" shared_library_list = [] if platform_type == TestPlatform.FBCODE: - from caffe2.fb.code_coverage.tool.package.fbcode.utils import ( + from caffe2.fb.code_coverage.tool.package.fbcode.utils import ( # type: ignore[import] get_fbcode_binary_folder, ) diff --git a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py index 7980b73fbe498b..17d7c18975ff94 100644 --- a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py +++ b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py @@ -10,11 +10,11 @@ class LlvmCoverageSegment(NamedTuple): is_gap_entry: Optional[int] @property - def has_coverage(self): + def has_coverage(self) -> bool: return self.segment_count > 0 @property - def is_executable(self): + def is_executable(self) -> bool: return self.has_count > 0 def get_coverage( diff --git a/tools/code_coverage/package/tool/print_report.py b/tools/code_coverage/package/tool/print_report.py index 98f026b22bb77b..749773f8e42333 100644 --- a/tools/code_coverage/package/tool/print_report.py +++ b/tools/code_coverage/package/tool/print_report.py @@ -1,20 +1,22 @@ import os import subprocess -from typing import IO, Dict, List, Set +from typing import IO, Dict, List, Set, Tuple from ..oss.utils import get_pytorch_folder from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType +CoverageItem = Tuple[str, float, int, int] -def key_by_percentage(x): + +def key_by_percentage(x: CoverageItem) -> float: return x[1] -def key_by_name(x): +def key_by_name(x: CoverageItem) -> str: return x[0] -def is_intrested_file(file_path: str, interested_folders: List[str]): +def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool: if "cuda" in file_path: return False if "aten/gen_aten" in file_path or "aten/aten_" in file_path: @@ -34,7 +36,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool: def print_test_by_type( - tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO + tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str] ) -> None: print("Tests " + type_name + " to collect coverage:", file=summary_file) @@ -49,7 +51,7 @@ def print_test_condition( tests_type: TestStatusType, interested_folders: List[str], coverage_only: List[str], - summary_file: IO, + summary_file: IO[str], summary_type: str, ) -> None: print_test_by_type(tests, tests_type["success"], "fully success", summary_file) @@ -91,14 +93,8 @@ def line_oriented_report( "LINE SUMMARY", ) for file_name in covered_lines: - if len(covered_lines[file_name]) == 0: - covered = {} - else: - covered = covered_lines[file_name] - if len(uncovered_lines[file_name]) == 0: - uncovered = {} - else: - uncovered = uncovered_lines[file_name] + covered = covered_lines[file_name] + uncovered = uncovered_lines[file_name] print( f"{file_name}\n covered lines: {sorted(covered)}\n unconvered lines:{sorted(uncovered)}", file=report_file, @@ -106,7 +102,7 @@ def line_oriented_report( def print_file_summary( - covered_summary: int, total_summary: int, summary_file: IO + covered_summary: int, total_summary: int, summary_file: IO[str] ) -> float: # print summary first try: @@ -124,10 +120,10 @@ def print_file_summary( def print_file_oriented_report( tests_type: TestStatusType, - coverage, + coverage: List[CoverageItem], covered_summary: int, total_summary: int, - summary_file: IO, + summary_file: IO[str], tests: TestList, interested_folders: List[str], coverage_only: List[str], @@ -178,7 +174,7 @@ def file_oriented_report( except ZeroDivisionError: percentage = 0 # store information in a list to be sorted - coverage.append([file_name, percentage, covered_count, total_count]) + coverage.append((file_name, percentage, covered_count, total_count)) # update summary covered_summary = covered_summary + covered_count total_summary = total_summary + total_count @@ -202,7 +198,7 @@ def get_html_ignored_pattern() -> List[str]: return ["/usr/*", "*anaconda3/*", "*third_party/*"] -def html_oriented_report(): +def html_oriented_report() -> None: # use lcov to generate the coverage report build_folder = os.path.join(get_pytorch_folder(), "build") coverage_info_file = os.path.join(SUMMARY_FOLDER_DIR, "coverage.info") diff --git a/tools/code_coverage/package/tool/summarize_jsons.py b/tools/code_coverage/package/tool/summarize_jsons.py index f53ed107384267..0a9dbd1d72fd25 100644 --- a/tools/code_coverage/package/tool/summarize_jsons.py +++ b/tools/code_coverage/package/tool/summarize_jsons.py @@ -55,7 +55,7 @@ def transform_file_name( def is_intrested_file( file_path: str, interested_folders: List[str], platform: TestPlatform -): +) -> bool: ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"] if any([pattern in file_path for pattern in ignored_patterns]): return False diff --git a/tools/code_coverage/package/tool/utils.py b/tools/code_coverage/package/tool/utils.py index 6ecb2f61bff642..7cad7c063f82b8 100644 --- a/tools/code_coverage/package/tool/utils.py +++ b/tools/code_coverage/package/tool/utils.py @@ -12,12 +12,12 @@ def run_cpp_test(binary_file: str) -> None: print_error(f"Binary failed to run: {binary_file}") -def get_tool_path_by_platform(platform: TestPlatform): +def get_tool_path_by_platform(platform: TestPlatform) -> str: if platform == TestPlatform.FBCODE: - from caffe2.fb.code_coverage.tool.package.fbcode.utils import get_llvm_tool_path + from caffe2.fb.code_coverage.tool.package.fbcode.utils import get_llvm_tool_path # type: ignore[import] - return get_llvm_tool_path() + return get_llvm_tool_path() # type: ignore[no-any-return] else: - from ..oss.utils import get_llvm_tool_path + from ..oss.utils import get_llvm_tool_path # type: ignore[no-redef] - return get_llvm_tool_path() + return get_llvm_tool_path() # type: ignore[no-any-return] diff --git a/tools/code_coverage/package/util/utils.py b/tools/code_coverage/package/util/utils.py index 8c3ab48f5a7fd5..06a6ba0a72c517 100644 --- a/tools/code_coverage/package/util/utils.py +++ b/tools/code_coverage/package/util/utils.py @@ -2,7 +2,7 @@ import shutil import sys import time -from typing import Any, Optional +from typing import Any, NoReturn, Optional from .setting import ( LOG_DIR, @@ -71,7 +71,7 @@ def convert_to_relative_path(whole_path: str, base_path: str) -> str: return whole_path[len(base_path) + 1 :] -def replace_extension(filename, ext): +def replace_extension(filename: str, ext: str) -> str: return filename[: filename.rfind(".")] + ext @@ -89,11 +89,11 @@ def get_raw_profiles_folder() -> str: def detect_compiler_type(platform: TestPlatform) -> CompilerType: if platform == TestPlatform.OSS: - from package.oss.utils import detect_compiler_type + from package.oss.utils import detect_compiler_type # type: ignore[misc] - cov_type = detect_compiler_type() + cov_type = detect_compiler_type() # type: ignore[call-arg] else: - from caffe2.fb.code_coverage.tool.package.fbcode.utils import ( + from caffe2.fb.code_coverage.tool.package.fbcode.utils import ( # type: ignore[import] detect_compiler_type, ) @@ -138,7 +138,7 @@ def check_test_type(test_type: str, target: str) -> None: ) -def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str): +def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str) -> NoReturn: raise RuntimeError( f"No cpp and python tests found in folder **{cpp_binary_folder} and **{python_binary_folder}**" ) diff --git a/tools/coverage_plugins_package/setup.py b/tools/coverage_plugins_package/setup.py index d519e6f69e4612..c93f6129258daf 100644 --- a/tools/coverage_plugins_package/setup.py +++ b/tools/coverage_plugins_package/setup.py @@ -1,4 +1,4 @@ -import setuptools +import setuptools # type: ignore[import] with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() diff --git a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py index 3fdc61c828ef54..8dcd31397d2ad9 100644 --- a/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py +++ b/tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py @@ -8,9 +8,10 @@ marked as covered. ''' -from coverage import CoveragePlugin, CoverageData +from coverage import CoveragePlugin, CoverageData # type: ignore[import] from inspect import ismodule, isclass, ismethod, isfunction, iscode, getsourcefile, getsourcelines from time import time +from typing import Any # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with # `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link: @@ -18,17 +19,17 @@ cov_data = CoverageData(basename=f'.coverage.jit.{time()}') -def is_not_builtin_class(obj): +def is_not_builtin_class(obj: Any) -> bool: return isclass(obj) and not type(obj).__module__ == 'builtins' -class JitPlugin(CoveragePlugin): +class JitPlugin(CoveragePlugin): # type: ignore[misc, no-any-unimported] ''' dynamic_context is an overridden function that gives us access to every frame run during the coverage process. We look for when the function being run is `should_drop`, as all functions that get passed into `should_drop` will be compiled and thus should be marked as covered. ''' - def dynamic_context(self, frame): + def dynamic_context(self, frame: Any) -> None: if frame.f_code.co_name == 'should_drop': obj = frame.f_locals['fn'] # The many conditions in the if statement below are based on the accepted arguments to getsourcefile. Based @@ -54,5 +55,5 @@ def dynamic_context(self, frame): cov_data.add_lines(line_data) super().dynamic_context(frame) -def coverage_init(reg, options): +def coverage_init(reg: Any, options: Any) -> None: reg.add_dynamic_context(JitPlugin()) diff --git a/tools/download_mnist.py b/tools/download_mnist.py index 45ae4881cb4588..f9609573ca1bf5 100644 --- a/tools/download_mnist.py +++ b/tools/download_mnist.py @@ -18,14 +18,18 @@ ] -def report_download_progress(chunk_number, chunk_size, file_size): +def report_download_progress( + chunk_number: int, + chunk_size: int, + file_size: int, +) -> None: if file_size != -1: percent = min(1, (chunk_number * chunk_size) / file_size) bar = '#' * int(64 * percent) sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100))) -def download(destination_path, resource, quiet): +def download(destination_path: str, resource: str, quiet: bool) -> None: if os.path.exists(destination_path): if not quiet: print('{} already exists, skipping ...'.format(destination_path)) @@ -48,7 +52,7 @@ def download(destination_path, resource, quiet): raise RuntimeError('Error downloading resource!') -def unzip(zipped_path, quiet): +def unzip(zipped_path: str, quiet: bool) -> None: unzipped_path = os.path.splitext(zipped_path)[0] if os.path.exists(unzipped_path): if not quiet: @@ -61,7 +65,7 @@ def unzip(zipped_path, quiet): print('Unzipped {} ...'.format(zipped_path)) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description='Download the MNIST dataset from the internet') parser.add_argument( diff --git a/tools/export_slow_tests.py b/tools/export_slow_tests.py index df8a2dd3697baa..b4a9af35b3015c 100644 --- a/tools/export_slow_tests.py +++ b/tools/export_slow_tests.py @@ -17,7 +17,7 @@ def get_test_case_times() -> Dict[str, float]: # an entry will be like ("test_doc_examples (__main__.TestTypeHints)" -> [values])) test_names_to_times: DefaultDict[str, List[float]] = defaultdict(list) for report in reports: - if report.get('format_version', 1) != 2: + if report.get('format_version', 1) != 2: # type: ignore[misc] raise RuntimeError("S3 format currently handled is version 2 only") v2report = cast(Version2Report, report) for test_file in v2report['files'].values(): @@ -46,7 +46,7 @@ def export_slow_tests(filename: str) -> None: file.write('\n') -def parse_args(): +def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description='Export a JSON of slow test cases in PyTorch unit test suite') parser.add_argument( @@ -61,7 +61,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: options = parse_args() export_slow_tests(options.filename) diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py index 6412adf24fae32..308bd65e9d6588 100755 --- a/tools/fast_nvcc/fast_nvcc.py +++ b/tools/fast_nvcc/fast_nvcc.py @@ -14,7 +14,10 @@ import subprocess import sys import time +from typing import (Awaitable, DefaultDict, Dict, List, Match, Optional, Set, + cast) +from typing_extensions import TypedDict help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]... @@ -78,14 +81,14 @@ re_tmp = r'(? None: """ Warn the user about something regarding fast_nvcc. """ print(f'warning (fast_nvcc): {warning}', file=sys.stderr) -def warn_if_windows(): +def warn_if_windows() -> None: """ Warn the user that using fast_nvcc on Windows might not work. """ @@ -97,7 +100,7 @@ def warn_if_windows(): fast_nvcc_warn(url_vars) -def warn_if_tmpdir_flag(args): +def warn_if_tmpdir_flag(args: List[str]) -> None: """ Warn the user that using fast_nvcc with some flags might not work. """ @@ -121,11 +124,17 @@ def warn_if_tmpdir_flag(args): fast_nvcc_warn(f'{url_base}#{frag}') -def nvcc_dryrun_data(binary, args): +class DryunData(TypedDict): + env: Dict[str, str] + commands: List[str] + exit_code: int + + +def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData: """ Return parsed environment variables and commands from nvcc --dryrun. """ - result = subprocess.run( + result = subprocess.run( # type: ignore[call-overload] [binary, '--dryrun'] + args, capture_output=True, encoding='ascii', # this is just a guess @@ -148,7 +157,7 @@ def nvcc_dryrun_data(binary, args): return {'env': env, 'commands': commands, 'exit_code': result.returncode} -def warn_if_tmpdir_set(env): +def warn_if_tmpdir_set(env: Dict[str, str]) -> None: """ Warn the user that setting TMPDIR with fast_nvcc might not work. """ @@ -157,7 +166,7 @@ def warn_if_tmpdir_set(env): fast_nvcc_warn(url_vars) -def contains_non_executable(commands): +def contains_non_executable(commands: List[str]) -> bool: for command in commands: # This is to deal with special command dry-run result from NVCC such as: # ``` @@ -170,7 +179,7 @@ def contains_non_executable(commands): return False -def module_id_contents(command): +def module_id_contents(command: List[str]) -> str: """ Guess the contents of the .module_id file contained within command. """ @@ -187,7 +196,7 @@ def module_id_contents(command): return f'_{len(middle)}_{middle}_{suffix}' -def unique_module_id_files(commands): +def unique_module_id_files(commands: List[str]) -> List[str]: """ Give each command its own .module_id filename instead of sharing. """ @@ -196,7 +205,7 @@ def unique_module_id_files(commands): for i, line in enumerate(commands): arr = [] - def uniqueify(s): + def uniqueify(s: Match[str]) -> str: filename = re.sub(r'\-(\d+)', r'-\1-' + str(i), s.group(0)) arr.append(filename) return filename @@ -212,14 +221,19 @@ def uniqueify(s): return uniqueified -def make_rm_force(commands): +def make_rm_force(commands: List[str]) -> List[str]: """ Add --force to all rm commands. """ return [f'{c} --force' if c.startswith('rm ') else c for c in commands] -def print_verbose_output(*, env, commands, filename): +def print_verbose_output( + *, + env: Dict[str, str], + commands: List[List[str]], + filename: str, +) -> None: """ Human-readably write nvcc --dryrun data to stderr. """ @@ -234,21 +248,24 @@ def print_verbose_output(*, env, commands, filename): print(f'#{" "*len(prefix)}{part}', file=f) -def straight_line_dependencies(commands): +Graph = List[Set[int]] + + +def straight_line_dependencies(commands: List[str]) -> Graph: """ Return a straight-line dependency graph. """ return [({i - 1} if i > 0 else set()) for i in range(len(commands))] -def files_mentioned(command): +def files_mentioned(command: str) -> List[str]: """ Return fully-qualified names of all tmp files referenced by command. """ return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)] -def nvcc_data_dependencies(commands): +def nvcc_data_dependencies(commands: List[str]) -> Graph: """ Return a list of the set of dependencies for each command. """ @@ -261,8 +278,8 @@ def nvcc_data_dependencies(commands): # data dependency is sort of flipped, because the steps that use the # files generated by cicc need to wait for the fatbinary step to # finish first - tmp_files = {} - fatbins = collections.defaultdict(set) + tmp_files: Dict[str, int] = {} + fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set) graph = [] for i, line in enumerate(commands): deps = set() @@ -284,13 +301,13 @@ def nvcc_data_dependencies(commands): return graph -def is_weakly_connected(graph): +def is_weakly_connected(graph: Graph) -> bool: """ Return true iff graph is weakly connected. """ if not graph: return True - neighbors = [set() for _ in graph] + neighbors: List[Set[int]] = [set() for _ in graph] for node, predecessors in enumerate(graph): for pred in predecessors: neighbors[pred].add(node) @@ -307,7 +324,7 @@ def is_weakly_connected(graph): return len(found) == len(graph) -def warn_if_not_weakly_connected(graph): +def warn_if_not_weakly_connected(graph: Graph) -> None: """ Warn the user if the execution graph is not weakly connected. """ @@ -315,11 +332,16 @@ def warn_if_not_weakly_connected(graph): fast_nvcc_warn('execution graph is not (weakly) connected') -def print_dot_graph(*, commands, graph, filename): +def print_dot_graph( + *, + commands: List[List[str]], + graph: Graph, + filename: str, +) -> None: """ Print a DOT file displaying short versions of the commands in graph. """ - def name(k): + def name(k: int) -> str: return f'"{k} {os.path.basename(commands[k][0])}"' with open(filename, 'w') as f: print('digraph {', file=f) @@ -332,7 +354,24 @@ def name(k): print('}', file=f) -async def run_command(command, *, env, deps, gather_data, i, save): + +class Result(TypedDict, total=False): + exit_code: int + stdout: bytes + stderr: bytes + time: float + files: Dict[str, int] + + +async def run_command( + command: str, + *, + env: Dict[str, str], + deps: Set[Awaitable[Result]], + gather_data: bool, + i: int, + save: Optional[str], +) -> Result: """ Run the command with the given env after waiting for deps. """ @@ -350,8 +389,8 @@ async def run_command(command, *, env, deps, gather_data, i, save): stderr=asyncio.subprocess.PIPE, ) stdout, stderr = await proc.communicate() - code = proc.returncode - results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr} + code = cast(int, proc.returncode) + results: Result = {'exit_code': code, 'stdout': stdout, 'stderr': stderr} if gather_data: t2 = time.monotonic() results['time'] = t2 - t1 @@ -371,14 +410,21 @@ async def run_command(command, *, env, deps, gather_data, i, save): return results -async def run_graph(*, env, commands, graph, gather_data=False, save=None): +async def run_graph( + *, + env: Dict[str, str], + commands: List[str], + graph: Graph, + gather_data: bool = False, + save: Optional[str] = None, +) -> List[Result]: """ Return outputs/errors (and optionally time/file info) from commands. """ - tasks = [] + tasks: List[Awaitable[Result]] = [] for i, (command, indices) in enumerate(zip(commands, graph)): deps = {tasks[j] for j in indices} - tasks.append(asyncio.create_task(run_command( + tasks.append(asyncio.create_task(run_command( # type: ignore[attr-defined] command, env=env, deps=deps, @@ -389,7 +435,7 @@ async def run_graph(*, env, commands, graph, gather_data=False, save=None): return [await task for task in tasks] -def print_command_outputs(command_results): +def print_command_outputs(command_results: List[Result]) -> None: """ Print captured stdout and stderr from commands. """ @@ -398,11 +444,16 @@ def print_command_outputs(command_results): sys.stderr.write(result.get('stderr', b'').decode('ascii')) -def write_log_csv(command_parts, command_results, *, filename): +def write_log_csv( + command_parts: List[List[str]], + command_results: List[Result], + *, + filename: str, +) -> None: """ Write a CSV file of the times and /tmp file sizes from each command. """ - tmp_files = [] + tmp_files: List[str] = [] for result in command_results: tmp_files.extend(result.get('files', {}).keys()) with open(filename, 'w', newline='') as csvfile: @@ -415,7 +466,7 @@ def write_log_csv(command_parts, command_results, *, filename): writer.writerow({**row, **result.get('files', {})}) -def exit_code(results): +def exit_code(results: List[Result]) -> int: """ Aggregate individual exit codes into a single code. """ @@ -426,11 +477,18 @@ def exit_code(results): return 0 -def wrap_nvcc(args, config=default_config): +def wrap_nvcc( + args: List[str], + config: argparse.Namespace = default_config, +) -> int: return subprocess.call([config.nvcc] + args) -def fast_nvcc(args, *, config=default_config): +def fast_nvcc( + args: List[str], + *, + config: argparse.Namespace = default_config, +) -> int: """ Emulate the result of calling the given nvcc binary with args. @@ -465,7 +523,7 @@ def fast_nvcc(args, *, config=default_config): ) if config.sequential: graph = straight_line_dependencies(commands) - results = asyncio.run(run_graph( + results = asyncio.run(run_graph( # type: ignore[attr-defined] env=env, commands=commands, graph=graph, @@ -478,7 +536,7 @@ def fast_nvcc(args, *, config=default_config): return exit_code([dryrun_data] + results) -def our_arg(arg): +def our_arg(arg: str) -> bool: return arg != '--' diff --git a/tools/flake8_hook.py b/tools/flake8_hook.py index 633e5d8f032534..b9ebd5b4793123 100755 --- a/tools/flake8_hook.py +++ b/tools/flake8_hook.py @@ -2,7 +2,7 @@ import sys -from flake8.main import git +from flake8.main import git # type: ignore[import] if __name__ == '__main__': sys.exit( diff --git a/tools/gdb/pytorch-gdb.py b/tools/gdb/pytorch-gdb.py index 97d0ce3b5bbeb4..46cdcdec2de2b8 100644 --- a/tools/gdb/pytorch-gdb.py +++ b/tools/gdb/pytorch-gdb.py @@ -1,5 +1,6 @@ -import gdb +import gdb # type: ignore[import] import textwrap +from typing import Any class DisableBreakpoints: """ @@ -8,18 +9,18 @@ class DisableBreakpoints: commands """ - def __enter__(self): + def __enter__(self) -> None: self.disabled_breakpoints = [] for b in gdb.breakpoints(): if b.enabled: b.enabled = False self.disabled_breakpoints.append(b) - def __exit__(self, etype, evalue, tb): + def __exit__(self, etype: Any, evalue: Any, tb: Any) -> None: for b in self.disabled_breakpoints: b.enabled = True -class TensorRepr(gdb.Command): +class TensorRepr(gdb.Command): # type: ignore[misc, no-any-unimported] """ Print a human readable representation of the given at::Tensor. Usage: torch-tensor-repr EXP @@ -31,11 +32,11 @@ class TensorRepr(gdb.Command): """ __doc__ = textwrap.dedent(__doc__).strip() - def __init__(self): + def __init__(self) -> None: gdb.Command.__init__(self, 'torch-tensor-repr', gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION) - def invoke(self, args, from_tty): + def invoke(self, args: str, from_tty: bool) -> None: args = gdb.string_to_argv(args) if len(args) != 1: print('Usage: torch-tensor-repr EXP') diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index 2637e3b070fe05..61682c9c896344 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -2,7 +2,7 @@ import os import subprocess from pathlib import Path -from setuptools import distutils +from setuptools import distutils # type: ignore[import] from typing import Optional, Union def get_sha(pytorch_root: Union[str, Path]) -> str: diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index 569c19f837bdc0..bf28bf3c3a3f27 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -103,7 +103,7 @@ def write_selected_mobile_ops_with_all_dtypes( header_contents = "".join(body_parts) out_file.write(header_contents.encode("utf-8")) -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Generate selected_mobile_ops.h for selective build." ) diff --git a/tools/nightly.py b/tools/nightly.py index fca49133550bfb..0b387e3b32dcf1 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -40,10 +40,10 @@ import subprocess from ast import literal_eval from argparse import ArgumentParser -from typing import Dict, Optional, Iterator +from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List, + Optional, Sequence, Set, Tuple, TypeVar, cast) - -LOGGER = None +LOGGER: Optional[logging.Logger] = None URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2" DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss" SHA1_RE = re.compile("([0-9a-fA-F]{40})") @@ -133,7 +133,7 @@ def logging_rotate() -> None: @contextlib.contextmanager -def logging_manager(*, debug: bool = False) -> Iterator[None]: +def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]: """Setup logging. If a failure starts here we won't be able to save the user in a reasonable way. @@ -179,7 +179,7 @@ def logging_manager(*, debug: bool = False) -> Iterator[None]: sys.exit(1) -def check_in_repo(): +def check_in_repo() -> Optional[str]: """Ensures that we are in the PyTorch repo.""" if not os.path.isfile("setup.py"): return "Not in root-level PyTorch repo, no setup.py found" @@ -187,12 +187,13 @@ def check_in_repo(): s = f.read() if "PyTorch" not in s: return "Not in PyTorch repo, 'PyTorch' not found in setup.py" + return None -def check_branch(subcommand, branch): +def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]: """Checks that the branch name can be checked out.""" if subcommand != "checkout": - return + return None # first make sure actual branch name was given if branch is None: return "Branch name to checkout must be supplied with '-b' option" @@ -203,36 +204,44 @@ def check_branch(subcommand, branch): return "Need to have clean working tree to checkout!\n\n" + p.stdout # next check that the branch name doesn't already exist cmd = ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch] - p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) + p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) # type: ignore[assignment] if not p.returncode: return f"Branch {branch!r} already exists" + return None @contextlib.contextmanager -def timer(logger, prefix): +def timer(logger: logging.Logger, prefix: str) -> Iterator[None]: """Timed context manager""" start_time = time.time() yield logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]") -def timed(prefix): +F = TypeVar('F', bound=Callable[..., Any]) + + +def timed(prefix: str) -> Callable[[F], F]: """Decorator for timing functions""" - def dec(f): + def dec(f: F) -> F: @functools.wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: global LOGGER - LOGGER.info(prefix) - with timer(LOGGER, prefix): + logger = cast(logging.Logger, LOGGER) + logger.info(prefix) + with timer(logger, prefix): return f(*args, **kwargs) - return wrapper + return cast(F, wrapper) return dec -def _make_channel_args(channels=("pytorch-nightly",), override_channels=False): +def _make_channel_args( + channels: Iterable[str] = ("pytorch-nightly",), + override_channels: bool = False, +) -> List[str]: args = [] for channel in channels: args.append("--channel") @@ -244,8 +253,11 @@ def _make_channel_args(channels=("pytorch-nightly",), override_channels=False): @timed("Solving conda environment") def conda_solve( - name=None, prefix=None, channels=("pytorch-nightly",), override_channels=False -): + name: Optional[str] = None, + prefix: Optional[str] = None, + channels: Iterable[str] = ("pytorch-nightly",), + override_channels: bool = False, +) -> Tuple[List[str], str, str, bool, List[str]]: """Performs the conda solve and splits the deps from the package.""" # compute what environment to use if prefix is not None: @@ -299,7 +311,7 @@ def conda_solve( @timed("Installing dependencies") -def deps_install(deps, existing_env, env_opts): +def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None: """Install dependencies to deps environment""" if not existing_env: # first remove previous pytorch-deps env @@ -312,7 +324,7 @@ def deps_install(deps, existing_env, env_opts): @timed("Installing pytorch nightly binaries") -def pytorch_install(url): +def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]: """"Install pytorch into a temporary directory""" pytdir = tempfile.TemporaryDirectory() cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url] @@ -320,7 +332,7 @@ def pytorch_install(url): return pytdir -def _site_packages(dirname, platform): +def _site_packages(dirname: str, platform: str) -> str: if platform.startswith("win"): template = os.path.join(dirname, "Lib", "site-packages") else: @@ -329,7 +341,7 @@ def _site_packages(dirname, platform): return spdir -def _ensure_commit(git_sha1): +def _ensure_commit(git_sha1: str) -> None: """Make sure that we actually have the commit locally""" cmd = ["git", "cat-file", "-e", git_sha1 + "^{commit}"] p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) @@ -341,7 +353,7 @@ def _ensure_commit(git_sha1): p = subprocess.run(cmd, check=True) -def _nightly_version(spdir): +def _nightly_version(spdir: str) -> str: # first get the git version from the installed module version_fname = os.path.join(spdir, "torch", "version.py") with open(version_fname) as f: @@ -371,7 +383,7 @@ def _nightly_version(spdir): @timed("Checking out nightly PyTorch") -def checkout_nightly_version(branch, spdir): +def checkout_nightly_version(branch: str, spdir: str) -> None: """Get's the nightly version and then checks it out.""" nightly_version = _nightly_version(spdir) cmd = ["git", "checkout", "-b", branch, nightly_version] @@ -379,40 +391,40 @@ def checkout_nightly_version(branch, spdir): @timed("Pulling nightly PyTorch") -def pull_nightly_version(spdir): +def pull_nightly_version(spdir: str) -> None: """Fetches the nightly version and then merges it .""" nightly_version = _nightly_version(spdir) cmd = ["git", "merge", nightly_version] p = subprocess.run(cmd, check=True) -def _get_listing_linux(source_dir): +def _get_listing_linux(source_dir: str) -> List[str]: listing = glob.glob(os.path.join(source_dir, "*.so")) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so"))) return listing -def _get_listing_osx(source_dir): +def _get_listing_osx(source_dir: str) -> List[str]: # oddly, these are .so files even on Mac listing = glob.glob(os.path.join(source_dir, "*.so")) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib"))) return listing -def _get_listing_win(source_dir): +def _get_listing_win(source_dir: str) -> List[str]: listing = glob.glob(os.path.join(source_dir, "*.pyd")) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib"))) listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll"))) return listing -def _glob_pyis(d): +def _glob_pyis(d: str) -> Set[str]: search = os.path.join(d, "**", "*.pyi") pyis = {os.path.relpath(p, d) for p in glob.iglob(search)} return pyis -def _find_missing_pyi(source_dir, target_dir): +def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]: source_pyis = _glob_pyis(source_dir) target_pyis = _glob_pyis(target_dir) missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)] @@ -420,7 +432,7 @@ def _find_missing_pyi(source_dir, target_dir): return missing_pyis -def _get_listing(source_dir, target_dir, platform): +def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]: if platform.startswith("linux"): listing = _get_listing_linux(source_dir) elif platform.startswith("osx"): @@ -437,7 +449,7 @@ def _get_listing(source_dir, target_dir, platform): return listing -def _remove_existing(trg, is_dir): +def _remove_existing(trg: str, is_dir: bool) -> None: if os.path.exists(trg): if is_dir: shutil.rmtree(trg) @@ -445,7 +457,13 @@ def _remove_existing(trg, is_dir): os.remove(trg) -def _move_single(src, source_dir, target_dir, mover, verb): +def _move_single( + src: str, + source_dir: str, + target_dir: str, + mover: Callable[[str, str], None], + verb: str, +) -> None: is_dir = os.path.isdir(src) relpath = os.path.relpath(src, source_dir) trg = os.path.join(target_dir, relpath) @@ -469,18 +487,18 @@ def _move_single(src, source_dir, target_dir, mover, verb): mover(src, trg) -def _copy_files(listing, source_dir, target_dir): +def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None: for src in listing: _move_single(src, source_dir, target_dir, shutil.copy2, "Copying") -def _link_files(listing, source_dir, target_dir): +def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None: for src in listing: _move_single(src, source_dir, target_dir, os.link, "Linking") @timed("Moving nightly files into repo") -def move_nightly_files(spdir, platform): +def move_nightly_files(spdir: str, platform: str) -> None: """Moves PyTorch files from temporary installed location to repo.""" # get file listing source_dir = os.path.join(spdir, "torch") @@ -496,7 +514,7 @@ def move_nightly_files(spdir, platform): _copy_files(listing, source_dir, target_dir) -def _available_envs(): +def _available_envs() -> Dict[str, str]: cmd = ["conda", "env", "list"] p = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) lines = p.stdout.splitlines() @@ -513,7 +531,7 @@ def _available_envs(): @timed("Writing pytorch-nightly.pth") -def write_pth(env_opts, platform): +def write_pth(env_opts: List[str], platform: str) -> None: """Writes Python path file for this dir.""" env_type, env_dir = env_opts if env_type == "--name": @@ -533,17 +551,16 @@ def write_pth(env_opts, platform): def install( - subcommand="checkout", - branch=None, - name=None, - prefix=None, - channels=("pytorch-nightly",), - override_channels=False, - logger=None, -): + *, + logger: logging.Logger, + subcommand: str = "checkout", + branch: Optional[str] = None, + name: Optional[str] = None, + prefix: Optional[str] = None, + channels: Iterable[str] = ("pytorch-nightly",), + override_channels: bool = False, +) -> None: """Development install of PyTorch""" - global LOGGER - logger = logger or LOGGER deps, pytorch, platform, existing_env, env_opts = conda_solve( name=name, prefix=prefix, channels=channels, override_channels=override_channels ) @@ -552,7 +569,7 @@ def install( pytdir = pytorch_install(pytorch) spdir = _site_packages(pytdir.name, platform) if subcommand == "checkout": - checkout_nightly_version(branch, spdir) + checkout_nightly_version(cast(str, branch), spdir) elif subcommand == "pull": pull_nightly_version(spdir) else: @@ -566,7 +583,7 @@ def install( ) -def make_parser(): +def make_parser() -> ArgumentParser: p = ArgumentParser("nightly") # subcommands subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute") @@ -627,7 +644,7 @@ def make_parser(): return p -def main(args=None): +def main(args: Optional[Sequence[str]] = None) -> None: """Main entry point""" global LOGGER p = make_parser() diff --git a/tools/render_junit.py b/tools/render_junit.py index 83aee69ec863bb..eac873d321cab2 100644 --- a/tools/render_junit.py +++ b/tools/render_junit.py @@ -16,8 +16,8 @@ except ImportError: print("rich not found, for color output use 'pip install rich'") -def parse_junit_reports(path_to_reports: str) -> List[TestCase]: - def parse_file(path: str) -> List[TestCase]: +def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore[no-any-unimported] + def parse_file(path: str) -> List[TestCase]: # type: ignore[no-any-unimported] try: return convert_junit_to_testcases(JUnitXml.fromfile(path)) except Exception as err: @@ -37,7 +37,7 @@ def parse_file(path: str) -> List[TestCase]: return ret_xml -def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: +def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: # type: ignore[no-any-unimported] testcases = [] for item in xml: if isinstance(item, TestSuite): @@ -46,7 +46,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase testcases.append(item) return testcases -def render_tests(testcases: List[TestCase]) -> None: +def render_tests(testcases: List[TestCase]) -> None: # type: ignore[no-any-unimported] num_passed = 0 num_skipped = 0 num_failed = 0 diff --git a/tools/setup_helpers/__init__.py b/tools/setup_helpers/__init__.py index 78da7f45cfc3e0..fa892dfb6e6f29 100644 --- a/tools/setup_helpers/__init__.py +++ b/tools/setup_helpers/__init__.py @@ -1,8 +1,9 @@ import os import sys +from typing import Optional -def which(thefile): +def which(thefile: str) -> Optional[str]: path = os.environ.get("PATH", os.defpath).split(os.pathsep) for d in path: fname = os.path.join(d, thefile) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 9ce09c46ce7abd..d60dc36f13d0d2 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -8,14 +8,15 @@ from subprocess import check_call, check_output, CalledProcessError import sys import sysconfig -from setuptools import distutils +from setuptools import distutils # type: ignore[import] +from typing import IO, Any, Dict, List, Optional, Union from . import which from .env import (BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag) from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR -def _mkdir_p(d): +def _mkdir_p(d: str) -> None: try: os.makedirs(d) except OSError: @@ -28,7 +29,11 @@ def _mkdir_p(d): USE_NINJA = (not check_negative_env_flag('USE_NINJA') and which('ninja') is not None) -def convert_cmake_value_to_python_value(cmake_value, cmake_type): + +CMakeValue = Optional[Union[bool, str]] + + +def convert_cmake_value_to_python_value(cmake_value: str, cmake_type: str) -> CMakeValue: r"""Convert a CMake value in a string form to a Python value. Args: @@ -52,7 +57,7 @@ def convert_cmake_value_to_python_value(cmake_value, cmake_type): else: # Directly return the cmake_value. return cmake_value -def get_cmake_cache_variables_from_file(cmake_cache_file): +def get_cmake_cache_variables_from_file(cmake_cache_file: IO[str]) -> Dict[str, CMakeValue]: r"""Gets values in CMakeCache.txt into a dictionary. Args: @@ -93,12 +98,12 @@ def get_cmake_cache_variables_from_file(cmake_cache_file): class CMake: "Manages cmake." - def __init__(self, build_dir=BUILD_DIR): + def __init__(self, build_dir: str = BUILD_DIR) -> None: self._cmake_command = CMake._get_cmake_command() self.build_dir = build_dir @property - def _cmake_cache_file(self): + def _cmake_cache_file(self) -> str: r"""Returns the path to CMakeCache.txt. Returns: @@ -107,7 +112,7 @@ def _cmake_cache_file(self): return os.path.join(self.build_dir, 'CMakeCache.txt') @staticmethod - def _get_cmake_command(): + def _get_cmake_command() -> str: "Returns cmake command." cmake_command = 'cmake' @@ -124,7 +129,7 @@ def _get_cmake_command(): raise RuntimeError('no cmake or cmake3 with version >= 3.5.0 found') @staticmethod - def _get_version(cmd): + def _get_version(cmd: str) -> Any: "Returns cmake version." for line in check_output([cmd, '--version']).decode('utf-8').split('\n'): @@ -132,7 +137,7 @@ def _get_version(cmd): return distutils.version.LooseVersion(line.strip().split(' ')[2]) raise RuntimeError('no version found') - def run(self, args, env): + def run(self, args: List[str], env: Dict[str, str]) -> None: "Executes cmake with arguments and an environment." command = [self._cmake_command] + args @@ -146,13 +151,13 @@ def run(self, args, env): sys.exit(1) @staticmethod - def defines(args, **kwargs): + def defines(args: List[str], **kwargs: CMakeValue) -> None: "Adds definitions to a cmake argument list." for key, value in sorted(kwargs.items()): if value is not None: args.append('-D{}={}'.format(key, value)) - def get_cmake_cache_variables(self): + def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]: r"""Gets values in CMakeCache.txt into a dictionary. Returns: dict: A ``dict`` containing the value of cached CMake variables. @@ -160,7 +165,15 @@ def get_cmake_cache_variables(self): with open(self._cmake_cache_file) as f: return get_cmake_cache_variables_from_file(f) - def generate(self, version, cmake_python_library, build_python, build_test, my_env, rerun): + def generate( + self, + version: Optional[str], + cmake_python_library: Optional[str], + build_python: bool, + build_test: bool, + my_env: Dict[str, str], + rerun: bool, + ) -> None: "Runs cmake to generate native build files." if rerun and os.path.isfile(self._cmake_cache_file): @@ -215,7 +228,7 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e _mkdir_p(self.build_dir) # Store build options that are directly stored in environment variables - build_options = { + build_options: Dict[str, CMakeValue] = { # The default value cannot be easily obtained in CMakeLists.txt. We set it here. 'CMAKE_PREFIX_PATH': sysconfig.get_path('purelib') } @@ -335,7 +348,7 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e args.append(base_dir) self.run(args, env=my_env) - def build(self, my_env): + def build(self, my_env: Dict[str, str]) -> None: "Runs cmake to build binaries." from .env import build_type diff --git a/tools/setup_helpers/env.py b/tools/setup_helpers/env.py index f04f10cc287c55..d658acdb8d52df 100644 --- a/tools/setup_helpers/env.py +++ b/tools/setup_helpers/env.py @@ -3,6 +3,7 @@ import struct import sys from itertools import chain +from typing import Iterable, List, Optional, cast IS_WINDOWS = (platform.system() == 'Windows') @@ -17,19 +18,19 @@ BUILD_DIR = 'build' -def check_env_flag(name, default=''): +def check_env_flag(name: str, default: str = '') -> bool: return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] -def check_negative_env_flag(name, default=''): +def check_negative_env_flag(name: str, default: str = '') -> bool: return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N'] -def gather_paths(env_vars): +def gather_paths(env_vars: Iterable[str]) -> List[str]: return list(chain(*(os.getenv(v, '').split(os.pathsep) for v in env_vars))) -def lib_paths_from_base(base_path): +def lib_paths_from_base(base_path: str) -> List[str]: return [os.path.join(base_path, s) for s in ['lib/x64', 'lib', 'lib64']] @@ -49,7 +50,7 @@ class BuildType(object): """ - def __init__(self, cmake_build_type_env=None): + def __init__(self, cmake_build_type_env: Optional[str] = None) -> None: if cmake_build_type_env is not None: self.build_type_string = cmake_build_type_env return @@ -63,19 +64,19 @@ def __init__(self, cmake_build_type_env=None): # Normally it is anti-pattern to determine build type from CMAKE_BUILD_TYPE because it is not used for # multi-configuration build tools, such as Visual Studio and XCode. But since we always communicate with # CMake using CMAKE_BUILD_TYPE from our Python scripts, this is OK here. - self.build_type_string = cmake_cache_vars['CMAKE_BUILD_TYPE'] + self.build_type_string = cast(str, cmake_cache_vars['CMAKE_BUILD_TYPE']) else: self.build_type_string = os.environ.get('CMAKE_BUILD_TYPE', 'Release') - def is_debug(self): + def is_debug(self) -> bool: "Checks Debug build." return self.build_type_string == 'Debug' - def is_rel_with_deb_info(self): + def is_rel_with_deb_info(self) -> bool: "Checks RelWithDebInfo build." return self.build_type_string == 'RelWithDebInfo' - def is_release(self): + def is_release(self) -> bool: "Checks Release build." return self.build_type_string == 'Release' diff --git a/tools/setup_helpers/gen_version_header.py b/tools/setup_helpers/gen_version_header.py index 94ba264db83b34..963db1dad1f136 100644 --- a/tools/setup_helpers/gen_version_header.py +++ b/tools/setup_helpers/gen_version_header.py @@ -4,9 +4,12 @@ import argparse import os +from typing import Dict, Tuple, cast +Version = Tuple[int, int, int] -def parse_version(version: str) -> (int, int, int): + +def parse_version(version: str) -> Version: """ Parses a version string into (major, minor, patch) version numbers. @@ -24,10 +27,10 @@ def parse_version(version: str) -> (int, int, int): version_number_str = version[:i] break - return tuple([int(n) for n in version_number_str.split(".")]) + return cast(Version, tuple([int(n) for n in version_number_str.split(".")])) -def apply_replacements(replacements, text): +def apply_replacements(replacements: Dict[str, str], text: str) -> str: """ Applies the given replacements within the text. @@ -43,7 +46,7 @@ def apply_replacements(replacements, text): return text -def main(args): +def main(args: argparse.Namespace) -> None: with open(args.version_path) as f: version = f.read().strip() (major, minor, patch) = parse_version(version) diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index 8fb15be7565ff6..dcc5fb58f79798 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -2,12 +2,13 @@ import os import sys import yaml +from typing import Any, List, Optional, cast try: # use faster C loader if available from yaml import CSafeLoader as YamlLoader except ImportError: - from yaml import SafeLoader as YamlLoader + from yaml import SafeLoader as YamlLoader # type: ignore[misc] source_files = {'.py', '.cpp', '.h'} @@ -16,7 +17,7 @@ # TODO: This is a little inaccurate, because it will also pick # up setup_helper scripts which don't affect code generation -def all_generator_source(): +def all_generator_source() -> List[str]: r = [] for directory, _, filenames in os.walk('tools'): for f in filenames: @@ -26,15 +27,15 @@ def all_generator_source(): return sorted(r) -def generate_code(ninja_global=None, - declarations_path=None, - nn_path=None, - native_functions_path=None, - install_dir=None, - subset=None, - disable_autograd=False, - force_schema_registration=False, - operator_selector=None): +def generate_code(ninja_global: Optional[str] = None, + declarations_path: Optional[str] = None, + nn_path: Optional[str] = None, + native_functions_path: Optional[str] = None, + install_dir: Optional[str] = None, + subset: Optional[str] = None, + disable_autograd: bool = False, + force_schema_registration: bool = False, + operator_selector: Any = None) -> None: from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python from tools.autograd.gen_annotated_fn_args import gen_annotated from tools.codegen.selective_build.selector import SelectiveBuilder @@ -86,7 +87,7 @@ def generate_code(ninja_global=None, def get_selector_from_legacy_operator_selection_list( selected_op_list_path: str, -): +) -> Any: with open(selected_op_list_path, 'r') as f: # strip out the overload part # It's only for legacy config - do NOT copy this code! @@ -113,7 +114,10 @@ def get_selector_from_legacy_operator_selection_list( return selector -def get_selector(selected_op_list_path, operators_yaml_path): +def get_selector( + selected_op_list_path: Optional[str], + operators_yaml_path: Optional[str], +) -> Any: # cwrap depends on pyyaml, so we can't import it earlier root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, root) @@ -129,10 +133,10 @@ def get_selector(selected_op_list_path, operators_yaml_path): elif selected_op_list_path is not None: return get_selector_from_legacy_operator_selection_list(selected_op_list_path) else: - return SelectiveBuilder.from_yaml_path(operators_yaml_path) + return SelectiveBuilder.from_yaml_path(cast(str, operators_yaml_path)) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description='Autogenerate code') parser.add_argument('--declarations-path') parser.add_argument('--native-functions-path') diff --git a/tools/shared/cwrap_common.py b/tools/shared/cwrap_common.py index d35d7d1bde3d44..01ff97aabd9ba2 100644 --- a/tools/shared/cwrap_common.py +++ b/tools/shared/cwrap_common.py @@ -2,8 +2,11 @@ # for now, I have put it in one place but right now is copied out of cwrap import copy +from typing import Any, Dict, Iterable, List, Union -def parse_arguments(args): +Arg = Dict[str, Any] + +def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]: new_args = [] for arg in args: # Simple arg declaration of form " " @@ -20,7 +23,10 @@ def parse_arguments(args): return new_args -def set_declaration_defaults(declaration): +Declaration = Dict[str, Any] + + +def set_declaration_defaults(declaration: Declaration) -> None: if 'schema_string' not in declaration: # This happens for legacy TH bindings like # _thnn_conv_depthwise2d_backward @@ -70,19 +76,26 @@ def set_declaration_defaults(declaration): # TODO(zach): added option to remove keyword handling for C++ which cannot # support it. +Option = Dict[str, Any] + -def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self): - def exclude_arg(arg): - return arg['type'] == 'CONSTANT' +def filter_unique_options( + options: Iterable[Option], + allow_kwarg: bool, + type_to_signature: Dict[str, str], + remove_self: bool, +) -> List[Option]: + def exclude_arg(arg: Arg) -> bool: + return arg['type'] == 'CONSTANT' # type: ignore[no-any-return] - def exclude_arg_with_self_check(arg): + def exclude_arg_with_self_check(arg: Arg) -> bool: return exclude_arg(arg) or (remove_self and arg['name'] == 'self') - def signature(option, kwarg_only_count): - if kwarg_only_count == 0: + def signature(option: Option, num_kwarg_only: int) -> str: + if num_kwarg_only == 0: kwarg_only_count = None else: - kwarg_only_count = -kwarg_only_count + kwarg_only_count = -num_kwarg_only arg_signature = '#'.join( type_to_signature.get(arg['type'], arg['type']) for arg in option['arguments'][:kwarg_only_count] @@ -111,40 +124,40 @@ def signature(option, kwarg_only_count): return unique -def sort_by_number_of_args(declaration, reverse=True): - def num_args(option): +def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None: + def num_args(option: Option) -> int: return len(option['arguments']) declaration['options'].sort(key=num_args, reverse=reverse) class Function(object): - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - self.arguments = [] + self.arguments: List['Argument'] = [] - def add_argument(self, arg): + def add_argument(self, arg: 'Argument') -> None: assert isinstance(arg, Argument) self.arguments.append(arg) - def __repr__(self): + def __repr__(self) -> str: return self.name + '(' + ', '.join(a.__repr__() for a in self.arguments) + ')' class Argument(object): - def __init__(self, _type, name, is_optional): + def __init__(self, _type: str, name: str, is_optional: bool): self.type = _type self.name = name self.is_optional = is_optional - def __repr__(self): + def __repr__(self) -> str: return self.type + ' ' + self.name -def parse_header(path): +def parse_header(path: str) -> List[Function]: with open(path, 'r') as f: - lines = f.read().split('\n') + lines: Iterable[Any] = f.read().split('\n') # Remove empty lines and prebackend directives lines = filter(lambda l: l and not l.startswith('#'), lines) diff --git a/tools/shared/module_loader.py b/tools/shared/module_loader.py index 51c57aa161c931..7482047d4e8d9a 100644 --- a/tools/shared/module_loader.py +++ b/tools/shared/module_loader.py @@ -1,6 +1,11 @@ -def import_module(name, path): +from importlib.abc import Loader +from types import ModuleType +from typing import cast + + +def import_module(name: str, path: str) -> ModuleType: import importlib.util spec = importlib.util.spec_from_file_location(name, path) module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + cast(Loader, spec.loader).exec_module(module) return module diff --git a/tools/test/test_stats.py b/tools/test/test_stats.py index b85c38763cce08..b00fbda9d4ff11 100644 --- a/tools/test/test_stats.py +++ b/tools/test/test_stats.py @@ -1,13 +1,19 @@ # -*- coding: utf-8 -*- import unittest +from typing import Dict, List + from tools import print_test_stats +from tools.stats_utils.s3_stat_parser import (Commit, Report, ReportMetaMeta, + Status, Version1Case, + Version1Report, Version2Case, + Version2Report) -def fakehash(char): +def fakehash(char: str) -> str: return char * 40 -def dummy_meta_meta() -> print_test_stats.ReportMetaMeta: +def dummy_meta_meta() -> ReportMetaMeta: return { 'build_pr': '', 'build_tag': '', @@ -18,7 +24,14 @@ def dummy_meta_meta() -> print_test_stats.ReportMetaMeta: } -def makecase(name, seconds, *, errored=False, failed=False, skipped=False): +def makecase( + name: str, + seconds: float, + *, + errored: bool = False, + failed: bool = False, + skipped: bool = False, +) -> Version1Case: return { 'name': name, 'seconds': seconds, @@ -28,7 +41,7 @@ def makecase(name, seconds, *, errored=False, failed=False, skipped=False): } -def make_report_v1(tests) -> print_test_stats.Version1Report: +def make_report_v1(tests: Dict[str, List[Version1Case]]) -> Version1Report: suites = { suite_name: { 'total_seconds': sum(case['seconds'] for case in cases), @@ -37,20 +50,20 @@ def make_report_v1(tests) -> print_test_stats.Version1Report: for suite_name, cases in tests.items() } return { - **dummy_meta_meta(), + **dummy_meta_meta(), # type: ignore[misc] 'total_seconds': sum(s['total_seconds'] for s in suites.values()), 'suites': suites, } -def make_case_v2(seconds, status=None) -> print_test_stats.Version2Case: +def make_case_v2(seconds: float, status: Status = None) -> Version2Case: return { 'seconds': seconds, 'status': status, } -def make_report_v2(tests) -> print_test_stats.Version2Report: +def make_report_v2(tests: Dict[str, Dict[str, Dict[str, Version2Case]]]) -> Version2Report: files = {} for file_name, file_suites in tests.items(): suites = { @@ -65,7 +78,7 @@ def make_report_v2(tests) -> print_test_stats.Version2Report: 'total_seconds': sum(suite['total_seconds'] for suite in suites.values()), } return { - **dummy_meta_meta(), + **dummy_meta_meta(), # type: ignore[misc] 'format_version': 2, 'total_seconds': sum(s['total_seconds'] for s in files.values()), 'files': files, @@ -73,7 +86,7 @@ def make_report_v2(tests) -> print_test_stats.Version2Report: maxDiff = None class TestPrintTestStats(unittest.TestCase): - version1_report: print_test_stats.Version1Report = make_report_v1({ + version1_report: Version1Report = make_report_v1({ # input ordering of the suites is ignored 'Grault': [ # not printed: status same and time similar @@ -112,7 +125,7 @@ class TestPrintTestStats(unittest.TestCase): ], }) - version2_report: print_test_stats.Version2Report = make_report_v2( + version2_report: Version2Report = make_report_v2( { 'test_a': { 'Grault': { @@ -149,7 +162,7 @@ class TestPrintTestStats(unittest.TestCase): } }) - def test_simplify(self): + def test_simplify(self) -> None: self.assertEqual( { '': { @@ -222,10 +235,10 @@ def test_simplify(self): print_test_stats.simplify(self.version2_report), ) - def test_analysis(self): + def test_analysis(self) -> None: head_report = self.version1_report - base_reports = { + base_reports: Dict[Commit, List[Report]] = { # bbbb has no reports, so base is cccc instead fakehash('b'): [], fakehash('c'): [ @@ -391,7 +404,7 @@ class Qux: print_test_stats.anomalies(analysis), ) - def test_graph(self): + def test_graph(self) -> None: # HEAD is on master self.assertEqual( '''\ @@ -534,7 +547,7 @@ def test_graph(self): ) ) - def test_regression_info(self): + def test_regression_info(self) -> None: self.assertEqual( '''\ ----- Historic stats comparison result ------ @@ -588,7 +601,7 @@ def test_regression_info(self): ) ) - def test_regression_info_new_job(self): + def test_regression_info_new_job(self) -> None: self.assertEqual( '''\ ----- Historic stats comparison result ------ diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index c77e960ae65912..5e0a8de8b3e666 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -3,6 +3,7 @@ import inspect import sys import tempfile +from typing import Any, List, Optional, Tuple # this arbitrary-looking assortment of functionality is provided here # to have a central place for overrideable behavior. The motivating @@ -20,30 +21,33 @@ else: torch_parent = os.path.dirname(os.path.dirname(__file__)) -def get_file_path(*path_components): +def get_file_path(*path_components: str) -> str: return os.path.join(torch_parent, *path_components) -def get_file_path_2(*path_components): +def get_file_path_2(*path_components: str) -> str: return os.path.join(*path_components) -def get_writable_path(path): +def get_writable_path(path: str) -> str: if os.access(path, os.W_OK): return path return tempfile.mkdtemp(suffix=os.path.basename(path)) -def prepare_multiprocessing_environment(path): +def prepare_multiprocessing_environment(path: str) -> None: pass -def resolve_library_path(path): +def resolve_library_path(path: str) -> str: return os.path.realpath(path) -def get_source_lines_and_file(obj, error_msg=None): +def get_source_lines_and_file( + obj: Any, + error_msg: Optional[str] = None, +) -> Tuple[List[str], int, Optional[str]]: """ Wrapper around inspect.getsourcelines and inspect.getsourcefile.