diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5135d1a21855..8a14b1642a3a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -146,6 +146,13 @@ repos:
           - --fuzzy-match-generates-todo
         files: >
           \.cfg$|\.conf$|\.ini$|\.ldif$|\.properties$|\.readthedocs$|\.service$|\.tf$|Dockerfile.*$
+  - repo: https://github.com/psf/black
+    rev: 22.12.0
+    hooks:
+      - id: black
+        name: Run black (python formatter)
+        args: [--config=./pyproject.toml]
+        exclude: ^airflow/_vendor/|^airflow/contrib/
   - repo: local
     hooks:
       - id: update-common-sql-api-stubs
@@ -175,15 +182,8 @@ repos:
         additional_dependencies: ['ruff==0.0.226']
         files: \.pyi?$
         exclude: ^airflow/_vendor/
-  - repo: https://github.com/psf/black
-    rev: 22.12.0
-    hooks:
-      - id: black
-        name: Run black (python formatter)
-        args: [--config=./pyproject.toml]
-        exclude: ^airflow/_vendor/|^airflow/contrib/
   - repo: https://github.com/asottile/blacken-docs
-    rev: v1.12.1
+    rev: 1.13.0
     hooks:
       - id: blacken-docs
         name: Run black on python code blocks in documentation files
@@ -237,7 +237,7 @@ repos:
         files: ^chart/values\.schema\.json$|^chart/values_schema\.schema\.json$
         pass_filenames: true
   - repo: https://github.com/pre-commit/pygrep-hooks
-    rev: v1.9.0
+    rev: v1.10.0
     hooks:
       - id: rst-backticks
         name: Check if RST files use double backticks for code
@@ -246,7 +246,7 @@ repos:
         name: Check if there are no deprecate log warn
         exclude: ^airflow/_vendor/
   - repo: https://github.com/adrienverge/yamllint
-    rev: v1.28.0
+    rev: v1.29.0
    hooks:
      - id: yamllint
        name: Check YAML files with yamllint
@@ -351,6 +351,7 @@ repos:
         language: python
         files: ^setup\.py$|^INSTALL$|^CONTRIBUTING\.rst$
         pass_filenames: false
+        additional_dependencies: ['rich>=12.4.4']
       - id: check-extras-order
         name: Check order of extras in Dockerfile
         entry: ./scripts/ci/pre_commit/pre_commit_check_order_dockerfile_extras.py
diff --git a/airflow/providers/common/sql/operators/sql.pyi b/airflow/providers/common/sql/operators/sql.pyi
index 956d9ffa0854..72d77a0e6ae7 100644
--- a/airflow/providers/common/sql/operators/sql.pyi
+++ b/airflow/providers/common/sql/operators/sql.pyi
@@ -31,11 +31,11 @@
 Definition of the public interface for airflow.providers.common.sql.operators.sql
 isort:skip_file
 """
-from _typeshed import Incomplete
+from _typeshed import Incomplete  # noqa: F401
 from airflow.models import BaseOperator, SkipMixin
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.utils.context import Context
-from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs
+from typing import Any, Callable, Iterable, Mapping, Sequence, SupportsAbs, Union

 def _parse_boolean(val: str) -> str | bool: ...
 def parse_boolean(val: str) -> str | bool: ...
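
Note: the two `.pyi` changes above are outputs of the stub post-processing added later in this diff (`post_process_line` in `pre_commit_update_common_sql_api_stubs.py`). A minimal sketch of those two line-level patches, with `patch_stub_line` as a hypothetical name used only for illustration:

```python
def patch_stub_line(line: str) -> str:
    """Sketch of the two import patches post_process_line applies to generated stubs."""
    stripped = line.strip()
    if stripped.startswith("from _typeshed import Incomplete"):
        # _typeshed exists only for type checkers; silence "unused import" warnings
        return line + "  # noqa: F401"
    if stripped.startswith("from typing import") and "Union" not in line:
        # stubgen can emit Union[...] without importing it (mypy#12929)
        return line + ", Union"
    return line
```
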
diff --git a/dev/deprecations/generate_deprecated_dicts.py b/dev/deprecations/generate_deprecated_dicts.py
deleted file mode 100644
index b705fee48ba7..000000000000
--- a/dev/deprecations/generate_deprecated_dicts.py
+++ /dev/null
@@ -1,217 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import ast
-import os
-from collections import defaultdict
-from functools import lru_cache
-from pathlib import Path
-from typing import NamedTuple
-
-from jinja2 import BaseLoader, Environment
-from rich.console import Console
-
-if __name__ not in ("__main__", "__mp_main__"):
-    raise SystemExit(
-        "This file is intended to be executed as an executable program. You cannot use it as a module_path."
-        f"To run this script, run the ./{__file__} command [FILE] ..."
-    )
-
-AIRFLOW_SOURCES_ROOT = Path(__file__).parents[2].resolve()
-CONTRIB_DIR = AIRFLOW_SOURCES_ROOT / "airflow" / "contrib"
-
-
-@lru_cache(maxsize=None)
-def black_mode():
-    from black import Mode, parse_pyproject_toml, target_version_option_callback
-
-    config = parse_pyproject_toml(os.path.join(AIRFLOW_SOURCES_ROOT, "pyproject.toml"))
-
-    target_versions = set(
-        target_version_option_callback(None, None, tuple(config.get("target_version", ()))),
-    )
-
-    return Mode(
-        target_versions=target_versions,
-        line_length=config.get("line_length", Mode.line_length),
-        is_pyi=bool(config.get("is_pyi", Mode.is_pyi)),
-        string_normalization=not bool(config.get("skip_string_normalization", not Mode.string_normalization)),
-        experimental_string_processing=bool(
-            config.get("experimental_string_processing", Mode.experimental_string_processing)
-        ),
-    )
-
-
-def black_format(content) -> str:
-    from black import format_str
-
-    return format_str(content, mode=black_mode())
-
-
-class Import(NamedTuple):
-    module_path: str
-    name: str
-    alias: str
-
-
-class ImportedClass(NamedTuple):
-    module_path: str
-    name: str
-
-
-def get_imports(path: Path):
-    root = ast.parse(path.read_text())
-    imports: dict[str, ImportedClass] = {}
-    for node in ast.iter_child_nodes(root):
-        if isinstance(node, ast.Import):
-            module_array: list[str] = []
-        elif isinstance(node, ast.ImportFrom) and node.module:
-            module_array = node.module.split(".")
-        elif isinstance(node, ast.ClassDef):
-            for base in node.bases:
-                res = imports.get(base.id)  # type: ignore[attr-defined]
-                if res:
-                    yield Import(module_path=res.module_path, name=res.name, alias=node.name)
-            continue
-        else:
-            continue
-        for n in node.names:  # type: ignore[attr-defined]
-            imported_as = n.asname if n.asname else n.name
-            module_path = ".".join(module_array)
-            imports[imported_as] = ImportedClass(module_path=module_path, name=n.name)
-            yield Import(module_path, n.name, imported_as)
-
-
-DEPRECATED_CLASSES_TEMPLATE = """
-__deprecated_classes = {
-{%- for module_path, package_imports in package_imports.items() %}
-    '{{module_path}}': {
-{%- for import_item in package_imports %}
-        '{{import_item.alias}}': '{{import_item.module_path}}.{{import_item.name}}',
-{%- endfor %}
-    },
-{%- endfor %}
-}
-"""
-
-DEPRECATED_MODULES = [
-    "airflow/hooks/base_hook.py",
-    "airflow/hooks/dbapi_hook.py",
-    "airflow/hooks/docker_hook.py",
-    "airflow/hooks/druid_hook.py",
-    "airflow/hooks/hdfs_hook.py",
-    "airflow/hooks/hive_hooks.py",
-    "airflow/hooks/http_hook.py",
-    "airflow/hooks/jdbc_hook.py",
-    "airflow/hooks/mssql_hook.py",
-    "airflow/hooks/mysql_hook.py",
"airflow/hooks/oracle_hook.py", - "airflow/hooks/pig_hook.py", - "airflow/hooks/postgres_hook.py", - "airflow/hooks/presto_hook.py", - "airflow/hooks/S3_hook.py", - "airflow/hooks/samba_hook.py", - "airflow/hooks/slack_hook.py", - "airflow/hooks/sqlite_hook.py", - "airflow/hooks/webhdfs_hook.py", - "airflow/hooks/zendesk_hook.py", - "airflow/operators/bash_operator.py", - "airflow/operators/branch_operator.py", - "airflow/operators/check_operator.py", - "airflow/operators/dagrun_operator.py", - "airflow/operators/docker_operator.py", - "airflow/operators/druid_check_operator.py", - "airflow/operators/dummy.py", - "airflow/operators/dummy_operator.py", - "airflow/operators/email_operator.py", - "airflow/operators/gcs_to_s3.py", - "airflow/operators/google_api_to_s3_transfer.py", - "airflow/operators/hive_operator.py", - "airflow/operators/hive_stats_operator.py", - "airflow/operators/hive_to_druid.py", - "airflow/operators/hive_to_mysql.py", - "airflow/operators/hive_to_samba_operator.py", - "airflow/operators/http_operator.py", - "airflow/operators/jdbc_operator.py", - "airflow/operators/latest_only_operator.py", - "airflow/operators/mssql_operator.py", - "airflow/operators/mssql_to_hive.py", - "airflow/operators/mysql_operator.py", - "airflow/operators/mysql_to_hive.py", - "airflow/operators/oracle_operator.py", - "airflow/operators/papermill_operator.py", - "airflow/operators/pig_operator.py", - "airflow/operators/postgres_operator.py", - "airflow/operators/presto_check_operator.py", - "airflow/operators/presto_to_mysql.py", - "airflow/operators/python_operator.py", - "airflow/operators/redshift_to_s3_operator.py", - "airflow/operators/s3_file_transform_operator.py", - "airflow/operators/s3_to_hive_operator.py", - "airflow/operators/s3_to_redshift_operator.py", - "airflow/operators/slack_operator.py", - "airflow/operators/sql.py", - "airflow/operators/sql_branch_operator.py", - "airflow/operators/sqlite_operator.py", - "airflow/operators/subdag_operator.py", - "airflow/sensors/base_sensor_operator.py", - "airflow/sensors/date_time_sensor.py", - "airflow/sensors/external_task_sensor.py", - "airflow/sensors/hdfs_sensor.py", - "airflow/sensors/hive_partition_sensor.py", - "airflow/sensors/http_sensor.py", - "airflow/sensors/metastore_partition_sensor.py", - "airflow/sensors/named_hive_partition_sensor.py", - "airflow/sensors/s3_key_sensor.py", - "airflow/sensors/sql.py", - "airflow/sensors/sql_sensor.py", - "airflow/sensors/time_delta_sensor.py", - "airflow/sensors/web_hdfs_sensor.py", - "airflow/utils/log/cloudwatch_task_handler.py", - "airflow/utils/log/es_task_handler.py", - "airflow/utils/log/gcs_task_handler.py", - "airflow/utils/log/s3_task_handler.py", - "airflow/utils/log/stackdriver_task_handler.py", - "airflow/utils/log/wasb_task_handler.py", -] - -CONTRIB_FILES = (AIRFLOW_SOURCES_ROOT / "airflow" / "contrib").rglob("*.py") - - -if __name__ == "__main__": - console = Console(color_system="standard", width=300) - all_deprecated_imports: dict[str, dict[str, list[Import]]] = defaultdict(lambda: defaultdict(list)) - # delete = True - delete = False - # for file in DEPRECATED_MODULES: - for file in CONTRIB_FILES: - file_path = AIRFLOW_SOURCES_ROOT / file - if not file_path.exists() or file.name == "__init__.py": - continue - original_module = os.fspath(file_path.parent.relative_to(AIRFLOW_SOURCES_ROOT)).replace(os.sep, ".") - for _import in get_imports(file_path): - module_name = file_path.name[: -len(".py")] - if _import.name not in ["warnings", "RemovedInAirflow3Warning"]: - 
-                all_deprecated_imports[original_module][module_name].append(_import)
-        if delete:
-            file_path.unlink()
-
-    for module_path, package_imports in all_deprecated_imports.items():
-        console.print(f"[yellow]Import dictionary for {module_path}:\n")
-        template = Environment(loader=BaseLoader()).from_string(DEPRECATED_CLASSES_TEMPLATE)
-        print(black_format(template.render(package_imports=dict(sorted(package_imports.items())))))
diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py
index 4e6b6962f6c2..1d4ec06e7574 100755
--- a/dev/provider_packages/prepare_provider_packages.py
+++ b/dev/provider_packages/prepare_provider_packages.py
@@ -46,6 +46,7 @@
 import jsonschema
 import rich_click as click
 import semver as semver
+from black import Mode, TargetVersion, format_str, parse_pyproject_toml
 from packaging.version import Version
 from rich.console import Console
 from rich.syntax import Syntax
@@ -1387,29 +1388,16 @@ def update_commits_rst(


 @lru_cache(maxsize=None)
-def black_mode():
-    from black import Mode, parse_pyproject_toml, target_version_option_callback
-
+def black_mode() -> Mode:
     config = parse_pyproject_toml(os.path.join(AIRFLOW_SOURCES_ROOT_PATH, "pyproject.toml"))
-
-    target_versions = set(
-        target_version_option_callback(None, None, tuple(config.get("target_version", ()))),
-    )
-
+    target_versions = {TargetVersion[val.upper()] for val in config.get("target_version", ())}
     return Mode(
         target_versions=target_versions,
         line_length=config.get("line_length", Mode.line_length),
-        is_pyi=bool(config.get("is_pyi", Mode.is_pyi)),
-        string_normalization=not bool(config.get("skip_string_normalization", not Mode.string_normalization)),
-        experimental_string_processing=bool(
-            config.get("experimental_string_processing", Mode.experimental_string_processing)
-        ),
     )


 def black_format(content) -> str:
-    from black import format_str
-
     return format_str(content, mode=black_mode())
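
The rewritten `black_mode` above swaps black's private `target_version_option_callback` for the public `TargetVersion` enum. A standalone sketch of the same pattern (assumes a `pyproject.toml` with a `[tool.black]` section defining `target_version` and `line_length`, as Airflow's does):

```python
from black import Mode, TargetVersion, format_str, parse_pyproject_toml

# parse_pyproject_toml returns the [tool.black] section as a plain dict
config = parse_pyproject_toml("pyproject.toml")
mode = Mode(
    # e.g. ["py37"] -> {TargetVersion.PY37}
    target_versions={TargetVersion[v.upper()] for v in config.get("target_version", ())},
    line_length=config.get("line_length", Mode.line_length),
)
print(format_str("x = {  'a':1 }", mode=mode))  # -> x = {"a": 1}
```
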
diff --git a/scripts/ci/pre_commit/common_precommit_black_utils.py b/scripts/ci/pre_commit/common_precommit_black_utils.py
new file mode 100644
index 000000000000..c9d0f7712239
--- /dev/null
+++ b/scripts/ci/pre_commit/common_precommit_black_utils.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import sys
+from functools import lru_cache
+from pathlib import Path
+
+from black import Mode, TargetVersion, format_str, parse_pyproject_toml
+
+sys.path.insert(0, str(Path(__file__).parent.resolve()))  # make sure common_precommit_utils is imported
+
+from common_precommit_utils import AIRFLOW_BREEZE_SOURCES_PATH  # isort: skip # noqa E402
+
+
+@lru_cache(maxsize=None)
+def black_mode(is_pyi: bool = Mode.is_pyi) -> Mode:
+    config = parse_pyproject_toml(os.fspath(AIRFLOW_BREEZE_SOURCES_PATH / "pyproject.toml"))
+    target_versions = {TargetVersion[val.upper()] for val in config.get("target_version", ())}
+
+    return Mode(
+        target_versions=target_versions,
+        line_length=config.get("line_length", Mode.line_length),
+        is_pyi=is_pyi,
+    )
+
+
+def black_format(content: str, is_pyi: bool = Mode.is_pyi) -> str:
+    return format_str(content, mode=black_mode(is_pyi=is_pyi))
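
A sketch of how the scripts below consume this new shared helper; because `black_mode` is wrapped in `lru_cache`, `pyproject.toml` is parsed at most once per distinct `is_pyi` value in a run:

```python
import sys
from pathlib import Path

# Same sys.path setup the pre-commit scripts use, so the shared module resolves
# when a script is executed directly (assumes this file sits next to the helpers):
sys.path.insert(0, str(Path(__file__).parent.resolve()))
from common_precommit_black_utils import black_format

print(black_format("def f( a ):return a", is_pyi=False))
# def f(a):
#     return a
```
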
diff --git a/scripts/ci/pre_commit/common_precommit_utils.py b/scripts/ci/pre_commit/common_precommit_utils.py
index aef6bc3dcea3..e7aa186c0f65 100644
--- a/scripts/ci/pre_commit/common_precommit_utils.py
+++ b/scripts/ci/pre_commit/common_precommit_utils.py
@@ -21,7 +21,8 @@
 import re
 from pathlib import Path

-AIRFLOW_SOURCES_ROOT = Path(__file__).parents[3].resolve()
+AIRFLOW_SOURCES_ROOT_PATH = Path(__file__).parents[3].resolve()
+AIRFLOW_BREEZE_SOURCES_PATH = AIRFLOW_SOURCES_ROOT_PATH / "dev" / "breeze"


 def filter_out_providers_on_non_main_branch(files: list[str]) -> list[str]:
diff --git a/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py b/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py
index d6e32a093716..1c98d5cb41a2 100755
--- a/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py
+++ b/scripts/ci/pre_commit/pre_commit_check_pre_commit_hooks.py
@@ -26,22 +26,24 @@
 from pathlib import Path

 sys.path.insert(0, str(Path(__file__).parent.resolve()))  # make sure common_precommit_utils is imported
+from common_precommit_utils import (  # isort: skip # noqa: E402
+    AIRFLOW_BREEZE_SOURCES_PATH,
+    AIRFLOW_SOURCES_ROOT_PATH,
+    insert_documentation,
+)
+from common_precommit_black_utils import black_format  # isort: skip # noqa E402

 from collections import defaultdict  # noqa: E402
-from functools import lru_cache  # noqa: E402
 from typing import Any  # noqa: E402

 import yaml  # noqa: E402
-from common_precommit_utils import insert_documentation  # noqa: E402
 from rich.console import Console  # noqa: E402
 from tabulate import tabulate  # noqa: E402

 console = Console(width=400, color_system="standard")

-AIRFLOW_SOURCES_PATH = Path(__file__).parents[3].resolve()
-AIRFLOW_BREEZE_SOURCES_PATH = AIRFLOW_SOURCES_PATH / "dev" / "breeze"
 PRE_COMMIT_IDS_PATH = AIRFLOW_BREEZE_SOURCES_PATH / "src" / "airflow_breeze" / "pre_commit_ids.py"
-PRE_COMMIT_YAML_FILE = AIRFLOW_SOURCES_PATH / ".pre-commit-config.yaml"
+PRE_COMMIT_YAML_FILE = AIRFLOW_SOURCES_ROOT_PATH / ".pre-commit-config.yaml"


 def get_errors_and_hooks(content: Any, max_length: int) -> tuple[list[str], dict[str, list[str]], list[str]]:
@@ -75,6 +77,22 @@ def get_errors_and_hooks(content: Any, max_length: int) -> tuple[list[str], dict
     return errors, hooks, image_hooks


+def prepare_pre_commit_ids_py_file(pre_commit_ids):
+    PRE_COMMIT_IDS_PATH.write_text(
+        black_format(
+            content=render_template(
+                searchpath=AIRFLOW_BREEZE_SOURCES_PATH / "src" / "airflow_breeze",
+                template_name="pre_commit_ids",
+                context={"PRE_COMMIT_IDS": pre_commit_ids},
+                extension=".py",
+                autoescape=False,
+                keep_trailing_newline=True,
+            ),
+            is_pyi=False,
+        )
+    )
+
+
 def render_template(
     searchpath: Path,
     template_name: str,
@@ -107,46 +125,6 @@ def render_template(
     return content


-@lru_cache(maxsize=None)
-def black_mode():
-    from black import Mode, parse_pyproject_toml, target_version_option_callback
-
-    config = parse_pyproject_toml(AIRFLOW_BREEZE_SOURCES_PATH / "pyproject.toml")
-
-    target_versions = set(
-        target_version_option_callback(None, None, tuple(config.get("target_version", ()))),
-    )
-
-    return Mode(
-        target_versions=target_versions,
-        line_length=config.get("line_length", Mode.line_length),
-        is_pyi=config.get("is_pyi", False),
-        string_normalization=not config.get("skip_string_normalization", False),
-        preview=config.get("preview", False),
-    )
-
-
-def black_format(content) -> str:
-    from black import format_str
-
-    return format_str(content, mode=black_mode())
-
-
-def prepare_pre_commit_ids_py_file(pre_commit_ids):
-    PRE_COMMIT_IDS_PATH.write_text(
-        black_format(
-            render_template(
-                searchpath=AIRFLOW_BREEZE_SOURCES_PATH / "src" / "airflow_breeze",
-                template_name="pre_commit_ids",
-                context={"PRE_COMMIT_IDS": pre_commit_ids},
-                extension=".py",
-                autoescape=False,
-                keep_trailing_newline=True,
-            )
-        )
-    )
-
-
 def update_static_checks_array(hooks: dict[str, list[str]], image_hooks: list[str]):
     rows = []
     hook_ids = list(hooks.keys())
@@ -159,7 +137,7 @@ def update_static_checks_array(hooks: dict[str, list[str]], image_hooks: list[st
         rows.append((hook_id, formatted_hook_description, " * " if hook_id in image_hooks else "   "))
     formatted_table = "\n" + tabulate(rows, tablefmt="grid", headers=("ID", "Description", "Image")) + "\n\n"
     insert_documentation(
-        file_path=AIRFLOW_SOURCES_PATH / "STATIC_CODE_CHECKS.rst",
+        file_path=AIRFLOW_SOURCES_ROOT_PATH / "STATIC_CODE_CHECKS.rst",
         content=formatted_table.splitlines(keepends=True),
         header="  .. BEGIN AUTO-GENERATED STATIC CHECK LIST",
         footer="  .. END AUTO-GENERATED STATIC CHECK LIST",
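
For reference, the grid table that `update_static_checks_array` splices into `STATIC_CODE_CHECKS.rst` comes straight from the `tabulate` call shown above; a standalone sketch:

```python
from tabulate import tabulate

# One row per pre-commit hook id; the third column marks hooks needing the CI image.
rows = [("black", "Run black (python formatter)", "   ")]
print(tabulate(rows, tablefmt="grid", headers=("ID", "Description", "Image")))
# Prints an RST-compatible grid table with ID / Description / Image columns.
```
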
END AUTO-GENERATED STATIC CHECK LIST", diff --git a/scripts/ci/pre_commit/pre_commit_compile_www_assets.py b/scripts/ci/pre_commit/pre_commit_compile_www_assets.py index 27975f6b8c8e..4733a1460b88 100755 --- a/scripts/ci/pre_commit/pre_commit_compile_www_assets.py +++ b/scripts/ci/pre_commit/pre_commit_compile_www_assets.py @@ -23,7 +23,8 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.resolve())) # make sure common_precommit_utils is imported -from common_precommit_utils import get_directory_hash # isort: skip # noqa +from common_precommit_utils import get_directory_hash # isort: skip # noqa E402 +from common_precommit_black_utils import black_format # isort: skip # noqa E402 AIRFLOW_SOURCES_PATH = Path(__file__).parents[3].resolve() WWW_HASH_FILE = AIRFLOW_SOURCES_PATH / ".build" / "www" / "hash.txt" diff --git a/scripts/ci/pre_commit/pre_commit_insert_extras.py b/scripts/ci/pre_commit/pre_commit_insert_extras.py index fac926f611ca..3e08bd674d71 100755 --- a/scripts/ci/pre_commit/pre_commit_insert_extras.py +++ b/scripts/ci/pre_commit/pre_commit_insert_extras.py @@ -27,8 +27,8 @@ sys.path.insert(0, str(AIRFLOW_SOURCES_DIR)) # make sure setup is imported from Airflow # flake8: noqa: F401 -from common_precommit_utils import insert_documentation # isort: skip -from setup import EXTRAS_DEPENDENCIES # isort:skip +from common_precommit_utils import insert_documentation # isort: skip # noqa E402 +from setup import EXTRAS_DEPENDENCIES # isort:skip # noqa sys.path.append(str(AIRFLOW_SOURCES_DIR)) diff --git a/scripts/ci/pre_commit/pre_commit_local_yml_mounts.py b/scripts/ci/pre_commit/pre_commit_local_yml_mounts.py index 0f9f954959e8..6efba5a6aa08 100755 --- a/scripts/ci/pre_commit/pre_commit_local_yml_mounts.py +++ b/scripts/ci/pre_commit/pre_commit_local_yml_mounts.py @@ -20,18 +20,20 @@ import sys from pathlib import Path -AIRFLOW_SOURCES_DIR = Path(__file__).parents[3].resolve() - sys.path.insert(0, str(Path(__file__).parent.resolve())) # make sure common_precommit_utils is imported -sys.path.insert(0, str(AIRFLOW_SOURCES_DIR)) # make sure setup is imported from Airflow + +from common_precommit_utils import AIRFLOW_SOURCES_ROOT_PATH # isort: skip # noqa E402 + +sys.path.insert(0, str(AIRFLOW_SOURCES_ROOT_PATH)) # make sure setup is imported from Airflow sys.path.insert( - 0, str(AIRFLOW_SOURCES_DIR / "dev" / "breeze" / "src") + 0, str(AIRFLOW_SOURCES_ROOT_PATH / "dev" / "breeze" / "src") ) # make sure setup is imported from Airflow # flake8: noqa: F401 +from airflow_breeze.utils.docker_command_utils import VOLUMES_FOR_SELECTED_MOUNTS # isort: skip # noqa E402 -from common_precommit_utils import insert_documentation # isort: skip +from common_precommit_utils import insert_documentation # isort: skip # noqa E402 -sys.path.append(str(AIRFLOW_SOURCES_DIR)) +sys.path.append(str(AIRFLOW_SOURCES_ROOT_PATH)) MOUNTS_HEADER = ( " # START automatically generated volumes from " @@ -43,9 +45,7 @@ ) if __name__ == "__main__": - from airflow_breeze.utils.docker_command_utils import VOLUMES_FOR_SELECTED_MOUNTS - - local_mount_file_path = AIRFLOW_SOURCES_DIR / "scripts" / "ci" / "docker-compose" / "local.yml" + local_mount_file_path = AIRFLOW_SOURCES_ROOT_PATH / "scripts" / "ci" / "docker-compose" / "local.yml" PREFIX = " " volumes = [] for (src, dest) in VOLUMES_FOR_SELECTED_MOUNTS: diff --git a/scripts/ci/pre_commit/pre_commit_mypy.py b/scripts/ci/pre_commit/pre_commit_mypy.py index 2b99ab68dce9..d0a8cc50c294 100755 --- a/scripts/ci/pre_commit/pre_commit_mypy.py +++ 
diff --git a/scripts/ci/pre_commit/pre_commit_mypy.py b/scripts/ci/pre_commit/pre_commit_mypy.py
index 2b99ab68dce9..d0a8cc50c294 100755
--- a/scripts/ci/pre_commit/pre_commit_mypy.py
+++ b/scripts/ci/pre_commit/pre_commit_mypy.py
@@ -37,11 +37,14 @@
     from common_precommit_utils import filter_out_providers_on_non_main_branch

     sys.path.insert(0, str(AIRFLOW_SOURCES / "dev" / "breeze" / "src"))
-    from airflow_breeze.global_constants import MOUNT_SELECTED
-    from airflow_breeze.utils.console import get_console
-    from airflow_breeze.utils.docker_command_utils import get_extra_docker_flags
-    from airflow_breeze.utils.path_utils import create_mypy_volume_if_needed
-    from airflow_breeze.utils.run_utils import get_ci_image_for_pre_commits, run_command
+    from airflow_breeze.global_constants import MOUNT_SELECTED  # isort: skip
+    from airflow_breeze.utils.console import get_console  # isort: skip
+    from airflow_breeze.utils.docker_command_utils import get_extra_docker_flags  # isort: skip
+    from airflow_breeze.utils.path_utils import create_mypy_volume_if_needed  # isort: skip
+    from airflow_breeze.utils.run_utils import (
+        get_ci_image_for_pre_commits,
+        run_command,
+    )

     files_to_test = filter_out_providers_on_non_main_branch(sys.argv[1:])
     if files_to_test == ["--namespace-packages"]:
diff --git a/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py b/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
index cce59257cee7..80a7f081da84 100755
--- a/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
+++ b/scripts/ci/pre_commit/pre_commit_update_common_sql_api_stubs.py
@@ -23,7 +23,6 @@
 import subprocess
 import sys
 import textwrap
-from functools import lru_cache
 from pathlib import Path

 import jinja2
@@ -35,43 +34,20 @@
     f"To execute this script, run ./{__file__} [FILE] ..."
 )

-AIRFLOW_SOURCES_ROOT = Path(__file__).parents[3].resolve()
-PROVIDERS_ROOT = AIRFLOW_SOURCES_ROOT / "airflow" / "providers"
+sys.path.insert(0, str(Path(__file__).parent.resolve()))  # make sure common_precommit_utils is imported
+
+from common_precommit_utils import AIRFLOW_SOURCES_ROOT_PATH  # isort: skip # noqa E402
+from common_precommit_black_utils import black_format  # isort: skip # noqa E402
+
+PROVIDERS_ROOT = AIRFLOW_SOURCES_ROOT_PATH / "airflow" / "providers"
 COMMON_SQL_ROOT = PROVIDERS_ROOT / "common" / "sql"
-OUT_DIR = AIRFLOW_SOURCES_ROOT / "out"
+OUT_DIR = AIRFLOW_SOURCES_ROOT_PATH / "out"

 COMMON_SQL_PACKAGE_PREFIX = "airflow.providers.common.sql."

 console = Console(width=400, color_system="standard")


-@lru_cache(maxsize=None)
-def black_mode():
-    from black import Mode, parse_pyproject_toml, target_version_option_callback
-
-    config = parse_pyproject_toml(os.path.join(AIRFLOW_SOURCES_ROOT, "pyproject.toml"))
-
-    target_versions = set(
-        target_version_option_callback(None, None, tuple(config.get("target_version", ()))),
-    )
-
-    return Mode(
-        target_versions=target_versions,
-        line_length=config.get("line_length", Mode.line_length),
-        is_pyi=bool(config.get("is_pyi", Mode.is_pyi)),
-        string_normalization=not bool(config.get("skip_string_normalization", not Mode.string_normalization)),
-        experimental_string_processing=bool(
-            config.get("experimental_string_processing", Mode.experimental_string_processing)
-        ),
-    )
-
-
-def black_format(content) -> str:
-    from black import format_str
-
-    return format_str(content, mode=black_mode())
-
-
 class ConsoleDiff(difflib.Differ):
     def _dump(self, tag, x, lo, hi):
         """Generate comparison results for a same-tagged range."""
@@ -116,12 +92,31 @@ def summarize_changes(results: list[str]) -> tuple[int, int]:
     return removals, additions


-def post_process_historically_publicised_methods(stub_file_path: Path, line: str, new_lines: list[str]):
+def post_process_line(stub_file_path: Path, line: str, new_lines: list[str]) -> None:
+    """
+    Post-process a line of the stub file.
+
+    Stubgen is not a perfect tool for generating stub files, but it is a good starting point. We have
+    to modify the stub files to make them more useful for us (the stubgen developers are not very open
+    to adding options or features that are not generic).
+
+    The patching that we currently perform:
+      * we add noqa to Incomplete imports from _typeshed (IntelliJ does not like them)
+      * we add historically published methods
+      * we fix missing Union imports (see https://github.com/python/mypy/issues/12929)
+
+    :param stub_file_path: path of the file we process
+    :param line: line to post-process
+    :param new_lines: the list to which post-processed lines are appended
+    """
     if stub_file_path.relative_to(OUT_DIR) == Path("common") / "sql" / "operators" / "sql.pyi":
-        if line.strip().startswith("parse_boolean: Incomplete"):
+        stripped_line = line.strip()
+        if stripped_line.startswith("parse_boolean: Incomplete"):
             # Handle Special case - historically we allow _parse_boolean to be part of the public API,
             # and we handle it via parse_boolean = _parse_boolean which produces Incomplete entry in the
-            # stub - we replace the Incomplete method with both API methods that should be allowed
+            # stub - we replace the Incomplete method with both API methods that should be allowed.
+            # We also strip empty lines to let black figure out where to add them.
             #
             # We can remove those when we determine it is not risky for the community - when we determine
             # That most of the historically released providers have a way to easily update them, and they
             # provider to 2.*, the old providers (mainly google providers) might still use them.
new_lines.append("def _parse_boolean(val: str) -> str | bool: ...") new_lines.append("def parse_boolean(val: str) -> str | bool: ...") - elif line.strip() == "class SQLExecuteQueryOperator(BaseSQLOperator):": + elif stripped_line == "class SQLExecuteQueryOperator(BaseSQLOperator):": # The "_raise_exception" method is really part of the public API and should not be removed new_lines.append(line) new_lines.append(" def _raise_exception(self, exception_string: str) -> Incomplete: ...") + elif stripped_line.startswith("from _typeshed import Incomplete"): + new_lines.append(line + " # noqa: F401") + elif stripped_line.startswith("from typing import") and "Union" not in line: + new_lines.append(line + ", Union") + elif stripped_line == "": + pass else: new_lines.append(line) else: @@ -141,31 +142,41 @@ def post_process_historically_publicised_methods(stub_file_path: Path, line: str def post_process_generated_stub_file( - module_name: str, stub_file_path: Path, lines: list[str], patch_historical_methods=False + module_name: str, stub_file_path: Path, lines: list[str], patch_generated_file=False ): """ - Post process the stub file - add the preamble and optionally patch historical methods. - Adding preamble always, makes sure that we can update the preamble and have it automatically updated - in generated files even if no API specification changes. + Post process the stub file: + * adding (or replacing) preamble (makes sure we can replace preamble with new one in old files) + * optionally patch the generated file :param module_name: name of the module of the file :param stub_file_path: path of the stub fil :param lines: lines that were read from the file (with stripped comments) - :param patch_historical_methods: whether we should patch historical methods + :param patch_generated_file: whether we should patch generated file :return: resulting lines of the file after post-processing """ template = jinja2.Template(PREAMBLE) new_lines = template.render(module_name=module_name).splitlines() for line in lines: - if patch_historical_methods: - post_process_historically_publicised_methods(stub_file_path, line, new_lines) + if patch_generated_file: + post_process_line(stub_file_path, line, new_lines) else: new_lines.append(line) return new_lines +def write_pyi_file(pyi_file_path: Path, content: str) -> None: + """ + Writes the content to the file. + + :param pyi_file_path: path of the file to write + :param content: content to write (will be properly formatted) + """ + pyi_file_path.write_text(black_format(content, is_pyi=True), encoding="utf-8") + + def read_pyi_file_content( - module_name: str, pyi_file_path: Path, patch_historical_methods=False + module_name: str, pyi_file_path: Path, patch_generated_files=False ) -> list[str] | None: """ Reads stub file content with post-processing and optionally patching historical methods. The comments @@ -176,7 +187,7 @@ def read_pyi_file_content( :param module_name: name of the module in question :param pyi_file_path: the path of the file to read - :param patch_historical_methods: whether the historical methods should be patched + :param patch_generated_files: whether the historical methods should be patched :return: list of lines of post-processed content or None if the file should be deleted. 
""" lines_no_comments = [ @@ -196,7 +207,7 @@ def read_pyi_file_content( console.print(f"[yellow]Skip {pyi_file_path} as it is an empty stub for __init__.py file") return None return post_process_generated_stub_file( - module_name, pyi_file_path, lines, patch_historical_methods=patch_historical_methods + module_name, pyi_file_path, lines, patch_generated_file=patch_generated_files ) @@ -209,63 +220,67 @@ def compare_stub_files(generated_stub_path: Path, force_override: bool) -> tuple """ _removals, _additions = 0, 0 rel_path = generated_stub_path.relative_to(OUT_DIR) - target_path = PROVIDERS_ROOT / rel_path + stub_file_target_path = PROVIDERS_ROOT / rel_path module_name = "airflow.providers." + os.fspath(rel_path.with_suffix("")).replace(os.path.sep, ".") generated_pyi_content = read_pyi_file_content( - module_name, generated_stub_path, patch_historical_methods=True + module_name, generated_stub_path, patch_generated_files=True ) if generated_pyi_content is None: os.unlink(generated_stub_path) - if target_path.exists(): + if stub_file_target_path.exists(): console.print( - f"[red]The {target_path} file is missing in generated files: but we are deleting it because" - " it is an empty __init__.pyi file." + f"[red]The {stub_file_target_path} file is missing in generated files: " + "but we are deleting it because it is an empty __init__.pyi file." ) if _force_override: console.print( - f"[yellow]The file {target_path} has been removed as changes are force-overridden" + f"[yellow]The file {stub_file_target_path} has been removed " + "as changes are force-overridden" ) - os.unlink(target_path) + os.unlink(stub_file_target_path) return 1, 0 else: console.print( f"[blue]The {generated_stub_path} file is an empty __init__.pyi file, we just ignore it." ) return 0, 0 - if not target_path.exists(): - console.print(f"[yellow]New file {target_path} has been missing. Treated as addition.") - target_path.write_text("\n".join(generated_pyi_content), encoding="utf-8") + if not stub_file_target_path.exists(): + console.print(f"[yellow]New file {stub_file_target_path} has been missing. 
Treated as addition.") + write_pyi_file(stub_file_target_path, "\n".join(generated_pyi_content) + "\n") return 0, 1 - target_pyi_content = read_pyi_file_content(module_name, target_path, patch_historical_methods=False) + target_pyi_content = read_pyi_file_content( + module_name, stub_file_target_path, patch_generated_files=False + ) if target_pyi_content is None: target_pyi_content = [] if generated_pyi_content != target_pyi_content: - console.print(f"[yellow]The {target_path} has changed.") + console.print(f"[yellow]The {stub_file_target_path} has changed.") diff = ConsoleDiff() comparison_results = list(diff.compare(target_pyi_content, generated_pyi_content)) _removals, _additions = summarize_changes(comparison_results) console.print( - f"[bright_blue]Summary of the generated changes in common.sql stub API file {target_path}:[/]\n" + "[bright_blue]Summary of the generated changes in common.sql " + f"stub API file {stub_file_target_path}:[/]\n" ) console.print(textwrap.indent("\n".join(comparison_results), " " * 4)) if _removals == 0 or force_override: - console.print(f"[yellow]The {target_path} has been updated\n") - console.print(f"[yellow]* additions: {additions}[/]") - console.print(f"[yellow]* removals: {removals}[/]") - target_path.write_text("\n".join(generated_pyi_content), encoding="utf-8") + console.print(f"[yellow]The {stub_file_target_path} has been updated\n") + console.print(f"[yellow]* additions: {total_additions}[/]") + console.print(f"[yellow]* removals: {total_removals}[/]") + write_pyi_file(stub_file_target_path, "\n".join(generated_pyi_content) + "\n") console.print( - f"\n[bright_blue]The {target_path} file has been updated automatically.[/]\n" + f"\n[bright_blue]The {stub_file_target_path} file has been updated automatically.[/]\n" "\n[yellow]Make sure to commit the changes.[/]" ) else: if force_override: - target_path.write_text("\n".join(generated_pyi_content), encoding="utf-8") + write_pyi_file(stub_file_target_path, "\n".join(generated_pyi_content) + "\n") console.print( - f"\n[bright_blue]The {target_path} file has been updated automatically.[/]\n" + f"\n[bright_blue]The {stub_file_target_path} file has been updated automatically.[/]\n" "\n[yellow]Make sure to commit the changes.[/]" ) else: - console.print(f"[green]OK. The {target_path} has not changed.") + console.print(f"[green]OK. 
The {stub_file_target_path} has not changed.") return _removals, _additions @@ -318,18 +333,20 @@ def compare_stub_files(generated_stub_path: Path, force_override: bool) -> tuple shutil.rmtree(OUT_DIR, ignore_errors=True) subprocess.run( - ["stubgen", *[os.fspath(path) for path in COMMON_SQL_ROOT.rglob("**/*.py")]], cwd=AIRFLOW_SOURCES_ROOT + ["stubgen", *[os.fspath(path) for path in COMMON_SQL_ROOT.rglob("**/*.py")]], + cwd=AIRFLOW_SOURCES_ROOT_PATH, ) - removals, additions = 0, 0 + total_removals, total_additions = 0, 0 _force_override = os.environ.get("UPDATE_COMMON_SQL_API") == "1" if _force_override: console.print("\n[yellow]The committed stub APIs are force-updated\n") + # reformat the generated stubs first for stub_path in OUT_DIR.rglob("**/*.pyi"): - stub_path.write_text(black_format(stub_path.read_text(encoding="utf-8")), encoding="utf-8") + write_pyi_file(stub_path, stub_path.read_text(encoding="utf-8")) for stub_path in OUT_DIR.rglob("**/*.pyi"): _new_removals, _new_additions = compare_stub_files(stub_path, force_override=_force_override) - removals += _new_removals - additions += _new_additions + total_removals += _new_removals + total_additions += _new_additions for target_path in COMMON_SQL_ROOT.rglob("*.pyi"): generated_path = OUT_DIR / target_path.relative_to(PROVIDERS_ROOT) if not generated_path.exists(): @@ -337,21 +354,21 @@ def compare_stub_files(generated_stub_path: Path, force_override: bool) -> tuple f"[red]The {target_path} file is missing in generated files:. " f"This is treated as breaking change." ) - removals += 1 + total_removals += 1 if _force_override: console.print( f"[yellow]The file {target_path} has been removed as changes are force-overridden" ) os.unlink(target_path) - if not removals and not additions: + if not total_removals and not total_additions: console.print("\n[green]All OK. The common.sql APIs did not change[/]") sys.exit(0) - if removals: + if total_removals: if not _force_override: console.print( f"\n[red]ERROR! As you can see above, there are changes in the common.sql stub API files.\n" - f"[red]* additions: {additions}[/]\n" - f"[red]* removals: {removals}[/]\n" + f"[red]* additions: {total_additions}[/]\n" + f"[red]* removals: {total_removals}[/]\n" ) console.print( "[bright_blue]Make sure to review the removals and changes for back-compatibility.[/]\n" @@ -366,15 +383,15 @@ def compare_stub_files(generated_stub_path: Path, force_override: bool) -> tuple else: console.print( f"\n[bright_blue]As you can see above, there are changes in the common.sql API:[/]\n\n" - f"[bright_blue]* additions: {additions}[/]\n" - f"[bright_blue]* removals: {removals}[/]\n" + f"[bright_blue]* additions: {total_additions}[/]\n" + f"[bright_blue]* removals: {total_removals}[/]\n" ) console.print("[yellow]You've set UPDATE_COMMON_SQL_API to 1 to update the API.[/]\n\n") console.print("[yellow]So the files were updated automatically.") else: console.print( f"\n[yellow]There are only additions in the API extracted from the common.sql code[/]\n\n" - f"[bright_blue]* additions: {additions}[/]\n" + f"[bright_blue]* additions: {total_additions}[/]\n" ) console.print("[bright_blue]So the files were updated automatically.") sys.exit(1)