From 5e80b79e037bb2d5e6cc2c10f6e51614bdf3d310 Mon Sep 17 00:00:00 2001 From: Merel Theisen <49397448+merelcht@users.noreply.github.com> Date: Thu, 11 Jan 2024 11:51:07 +0000 Subject: [PATCH] Revisit `mypy` setup (#3485) * Move mypy setup from pre-commit * Fix no implicit optionals * Fix Optional[...] must have exactly one type argument * Fix mypy type incompatibility errors * Address comment about order of type | None --------- Signed-off-by: Merel Theisen --- .pre-commit-config.yaml | 44 -------------------- Makefile | 1 + docs/source/conf.py | 1 + docs/source/hooks/examples.md | 2 +- kedro/config/abstract_config.py | 4 +- kedro/config/omegaconf_config.py | 32 ++++++++------- kedro/framework/cli/micropkg.py | 32 ++++++++------- kedro/framework/cli/starters.py | 10 ++--- kedro/framework/cli/utils.py | 4 +- kedro/framework/context/context.py | 4 +- kedro/framework/session/session.py | 34 ++++++++-------- kedro/io/cached_dataset.py | 6 +-- kedro/io/core.py | 24 ++++++----- kedro/io/data_catalog.py | 23 ++++++----- kedro/io/lambda_dataset.py | 6 +-- kedro/io/memory_dataset.py | 9 +++-- kedro/io/shared_memory_dataset.py | 2 +- kedro/ipython/__init__.py | 12 +++--- kedro/pipeline/modular_pipeline.py | 6 +-- kedro/pipeline/node.py | 22 +++++----- kedro/pipeline/pipeline.py | 60 ++++++++++++++-------------- kedro/runner/parallel_runner.py | 10 ++--- kedro/runner/runner.py | 20 +++++----- kedro/runner/sequential_runner.py | 2 +- kedro/runner/thread_runner.py | 4 +- pyproject.toml | 11 +++++ tests/framework/cli/test_starters.py | 4 +- 27 files changed, 186 insertions(+), 203 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 96895deafb..51a14935f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,27 +32,6 @@ repos: - id: requirements-txt-fixer # Sorts entries in requirements.txt exclude: "^kedro/templates/|^features/steps/test_starter/" - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.961 - hooks: - - id: mypy - args: [--allow-redefinition, --ignore-missing-imports] - exclude: | - (?x)( - ^kedro/templates/| - ^docs/| - ^features/steps/test_starter/ - ) - additional_dependencies: - - types-cachetools - - types-filelock - - types-PyYAML - - types-redis - - types-requests - - types-setuptools - - types-toml - - attrs - - repo: https://github.com/asottile/blacken-docs rev: v1.12.1 hooks: @@ -69,29 +48,8 @@ repos: pass_filenames: false entry: lint-imports - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.961 - hooks: - - id: mypy - args: [--allow-redefinition, --ignore-missing-imports] - exclude: | - (?x)( - ^kedro/templates/| - ^docs/| - ^features/steps/test_starter/ - ) - additional_dependencies: - - types-cachetools - - types-filelock - - types-PyYAML - - types-redis - - types-requests - - types-setuptools - - types-toml - - attrs - repo: local hooks: - # Slow lintint - id: secret_scan name: "Secret scan" language: system @@ -104,5 +62,3 @@ repos: types: [file, python] exclude: ^kedro/templates/|^tests/|^features/steps/test_starter entry: bandit -ll - -# Manual only diff --git a/Makefile b/Makefile index fbd3df38c0..6618d20657 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ clean: lint: pre-commit run -a --hook-stage manual $(hook) + mypy kedro test: pytest --numprocesses 4 --dist loadfile diff --git a/docs/source/conf.py b/docs/source/conf.py index ccb19d8faa..b16ed019a9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -157,6 +157,7 @@ "D.get(k,d), also set D[k]=d if k not in D", "D[k] if k in D, else 
d. d defaults to None.", "None. Update D from mapping/iterable E and F.", + "Patterns", ), "py:data": ( "typing.Any", diff --git a/docs/source/hooks/examples.md b/docs/source/hooks/examples.md index 54e584d89c..3c8220effe 100644 --- a/docs/source/hooks/examples.md +++ b/docs/source/hooks/examples.md @@ -368,7 +368,7 @@ class NodeInputReplacementHook: @hook_impl def before_node_run( self, node: Node, catalog: DataCatalog - ) -> Optional[Dict[str, Any]]: + ) -> dict[str, Any] | None: """Replace `first_input` for `my_node`""" if node.name == "my_node": # return the string filepath to the `first_input` dataset diff --git a/kedro/config/abstract_config.py b/kedro/config/abstract_config.py index 776ec6c836..ae9be039dd 100644 --- a/kedro/config/abstract_config.py +++ b/kedro/config/abstract_config.py @@ -17,8 +17,8 @@ class AbstractConfigLoader(UserDict): def __init__( self, conf_source: str, - env: str = None, - runtime_params: dict[str, Any] = None, + env: str | None = None, + runtime_params: dict[str, Any] | None = None, **kwargs, ): super().__init__() diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index f8ec549b38..a6fd1b24de 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -6,11 +6,12 @@ import io import logging import mimetypes +import typing from pathlib import Path from typing import Any, Callable, Iterable import fsspec -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from omegaconf.errors import InterpolationResolutionError, UnsupportedInterpolationType from omegaconf.resolvers import oc from yaml.parser import ParserError @@ -76,14 +77,14 @@ class OmegaConfigLoader(AbstractConfigLoader): def __init__( # noqa: PLR0913 self, conf_source: str, - env: str = None, - runtime_params: dict[str, Any] = None, + env: str | None = None, + runtime_params: dict[str, Any] | None = None, *, - config_patterns: dict[str, list[str]] = None, - base_env: str = None, - default_run_env: str = None, - custom_resolvers: dict[str, Callable] = None, - merge_strategy: dict[str, str] = None, + config_patterns: dict[str, list[str]] | None = None, + base_env: str | None = None, + default_run_env: str | None = None, + custom_resolvers: dict[str, Callable] | None = None, + merge_strategy: dict[str, str] | None = None, ): """Instantiates a ``OmegaConfigLoader``. @@ -251,6 +252,7 @@ def __repr__(self): # pragma: no cover f"config_patterns={self.config_patterns})" ) + @typing.no_type_check def load_and_merge_dir_config( # noqa: PLR0913 self, conf_path: str, @@ -431,7 +433,7 @@ def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]): raise ValueError(f"{dup_str}") @staticmethod - def _resolve_environment_variables(config: dict[str, Any]) -> None: + def _resolve_environment_variables(config: DictConfig) -> None: """Use the ``oc.env`` resolver to read environment variables and replace them in-place, clearing the resolver after the operation is complete if it was not registered beforehand. @@ -466,16 +468,16 @@ def _soft_merge(config, env_config): # Soft merge the two env dirs. The chosen env will override base if keys clash. 
return OmegaConf.to_container(OmegaConf.merge(config, env_config)) - def _is_hidden(self, path: str): + def _is_hidden(self, path_str: str): """Check if path contains any hidden directory or is a hidden file""" - path = Path(path) + path = Path(path_str) conf_path = Path(self.conf_source).resolve().as_posix() if self._protocol == "file": path = path.resolve() - path = path.as_posix() - if path.startswith(conf_path): - path = path.replace(conf_path, "") - parts = path.split(self._fs.sep) # filesystem specific separator + posix_path = path.as_posix() + if posix_path.startswith(conf_path): + posix_path = posix_path.replace(conf_path, "") + parts = posix_path.split(self._fs.sep) # filesystem specific separator HIDDEN = "." # Check if any component (folder or file) starts with a dot (.) return any(part.startswith(HIDDEN) for part in parts) diff --git a/kedro/framework/cli/micropkg.py b/kedro/framework/cli/micropkg.py index 6002f8f9d7..32c00b2323 100644 --- a/kedro/framework/cli/micropkg.py +++ b/kedro/framework/cli/micropkg.py @@ -193,10 +193,10 @@ def pull_package( # noqa: PLR0913 def _pull_package( # noqa: PLR0913 package_path: str, metadata: ProjectMetadata, - env: str = None, - alias: str = None, - destination: str = None, - fs_args: str = None, + env: str | None = None, + alias: str | None = None, + destination: str | None = None, + fs_args: str | None = None, ): with tempfile.TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir).resolve() @@ -473,7 +473,7 @@ def _refactor_code_for_unpacking( # noqa: PLR0913 """ def _move_package_with_conflicting_name( - target: Path, original_name: str, desired_name: str = None + target: Path, original_name: str, desired_name: str | None = None ) -> Path: _rename_package(project, original_name, "tmp_name") full_path = _create_nested_package(project, target) @@ -524,9 +524,9 @@ def _install_files( # noqa: PLR0913, too-many-locals project_metadata: ProjectMetadata, package_name: str, source_path: Path, - env: str = None, - alias: str = None, - destination: str = None, + env: str | None = None, + alias: str | None = None, + destination: str | None = None, ): env = env or "base" @@ -605,9 +605,9 @@ def _get_default_version(metadata: ProjectMetadata, micropkg_module_path: str) - def _package_micropkg( micropkg_module_path: str, metadata: ProjectMetadata, - alias: str = None, - destination: str = None, - env: str = None, + alias: str | None = None, + destination: str | None = None, + env: str | None = None, ) -> Path: micropkg_name = micropkg_module_path.split(".")[-1] package_dir = metadata.source_dir / metadata.package_name @@ -635,12 +635,14 @@ def _package_micropkg( # Check that micropkg directory exists and not empty _validate_dir(package_source) - destination = Path(destination) if destination else metadata.project_path / "dist" + package_destination = ( + Path(destination) if destination else metadata.project_path / "dist" + ) version = _get_default_version(metadata, micropkg_module_path) _generate_sdist_file( micropkg_name=micropkg_name, - destination=destination.resolve(), + destination=package_destination.resolve(), source_paths=source_paths, version=version, metadata=metadata, @@ -650,7 +652,7 @@ def _package_micropkg( _clean_pycache(package_dir) _clean_pycache(metadata.project_path) - return destination + return package_destination def _validate_dir(path: Path) -> None: @@ -826,7 +828,7 @@ def _generate_sdist_file( # noqa: PLR0913,too-many-locals source_paths: tuple[Path, Path, list[tuple[Path, str]]], version: str, metadata: 
ProjectMetadata, - alias: str = None, + alias: str | None = None, ) -> None: package_name = alias or micropkg_name package_source, tests_source, conf_source = source_paths diff --git a/kedro/framework/cli/starters.py b/kedro/framework/cli/starters.py index 20e0916775..2d1d31aad3 100644 --- a/kedro/framework/cli/starters.py +++ b/kedro/framework/cli/starters.py @@ -105,7 +105,7 @@ class KedroStarterSpec: # noqa: too-few-public-methods for starter_spec in _OFFICIAL_STARTER_SPECS: starter_spec.origin = "kedro" -_OFFICIAL_STARTER_SPECS = {spec.alias: spec for spec in _OFFICIAL_STARTER_SPECS} +_OFFICIAL_STARTER_SPECS_DICT = {spec.alias: spec for spec in _OFFICIAL_STARTER_SPECS} TOOLS_SHORTNAME_TO_NUMBER = { "lint": "1", @@ -308,7 +308,7 @@ def new( # noqa: PLR0913 @starter.command("list") -def list_starters(): +def list_starters() -> None: """List all official project starters available.""" starters_dict = _get_starters_dict() @@ -360,7 +360,7 @@ def _get_cookiecutter_dir( f" Specified tag {checkout}. The following tags are available: " + ", ".join(_get_available_tags(template_path)) ) - official_starters = sorted(_OFFICIAL_STARTER_SPECS) + official_starters = sorted(_OFFICIAL_STARTER_SPECS_DICT) raise KedroCliError( f"{error_message}. The aliases for the official Kedro starters are: \n" f"{yaml.safe_dump(official_starters, sort_keys=False)}" @@ -438,7 +438,7 @@ def _get_starters_dict() -> dict[str, KedroStarterSpec]: ), } """ - starter_specs = _OFFICIAL_STARTER_SPECS + starter_specs = _OFFICIAL_STARTER_SPECS_DICT for starter_entry_point in _get_entry_points(name="starters"): origin = starter_entry_point.module.split(".")[0] @@ -777,7 +777,7 @@ def _validate_selection(tools: list[str]): sys.exit(1) -def _parse_tools_input(tools_str: None | str): +def _parse_tools_input(tools_str: str | None): """Parse the tools input string. Args: diff --git a/kedro/framework/cli/utils.py b/kedro/framework/cli/utils.py index 6de7497af7..a23b9bf3cd 100644 --- a/kedro/framework/cli/utils.py +++ b/kedro/framework/cli/utils.py @@ -159,9 +159,9 @@ def _merge_same_name_collections(groups: Sequence[click.MultiCommand]): named_groups: defaultdict[str, list[click.MultiCommand]] = defaultdict(list) helps: defaultdict[str, list] = defaultdict(list) for group in groups: - named_groups[group.name].append(group) + named_groups[group.name].append(group) # type: ignore if group.help: - helps[group.name].append(group.help) + helps[group.name].append(group.help) # type: ignore return [ click.CommandCollection( diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index 919895bc34..c0991b2bea 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -209,8 +209,8 @@ def params(self) -> dict[str, Any]: def _get_catalog( self, - save_version: str = None, - load_versions: dict[str, str] = None, + save_version: str | None = None, + load_versions: dict[str, str] | None = None, ) -> DataCatalog: """A hook for changing the creation of a DataCatalog instance. 
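The hunks above (and most of the remaining ones) apply the same fix for mypy's implicit-Optional check: a parameter annotated `str` but defaulting to `None` becomes `str | None = None`, with `| None` written last, matching the "order of type | None" review comment in the commit message. A minimal sketch of the before/after pattern, using a hypothetical function rather than any real Kedro signature:

```python
from __future__ import annotations  # allows `str | None` syntax on older Python versions

from typing import Any

# Before: with implicit Optional disabled, mypy rejects this, because the
# annotation says `str` / `dict[str, Any]` while the default value is `None`.
#
#   def load_config(env: str = None, params: dict[str, Any] = None) -> dict[str, Any]: ...


# After: the union states explicitly that None is an accepted value,
# with `| None` placed after the concrete type.
def load_config(
    env: str | None = None,
    params: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Hypothetical example; not a real Kedro API."""
    return {"env": env or "base", **(params or {})}


print(load_config())             # {'env': 'base'}
print(load_config(env="local"))  # {'env': 'local'}
```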
diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index da908c89a1..24def9e27f 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -30,17 +30,17 @@ def _describe_git(project_path: Path) -> dict[str, dict[str, Any]]: - project_path = str(project_path) + path = str(project_path) try: res = subprocess.check_output( ["git", "rev-parse", "--short", "HEAD"], - cwd=project_path, + cwd=path, stderr=subprocess.STDOUT, ) git_data: dict[str, Any] = {"commit_sha": res.decode().strip()} git_status_res = subprocess.check_output( ["git", "status", "--short"], - cwd=project_path, + cwd=path, stderr=subprocess.STDOUT, ) git_data["dirty"] = bool(git_status_res.decode().strip()) @@ -48,7 +48,7 @@ def _describe_git(project_path: Path) -> dict[str, dict[str, Any]]: # `subprocess.check_output()` raises `NotADirectoryError` on Windows except Exception: # noqa: broad-except logger = logging.getLogger(__name__) - logger.debug("Unable to git describe %s", project_path) + logger.debug("Unable to git describe %s", path) logger.debug(traceback.format_exc()) return {} @@ -100,7 +100,7 @@ class KedroSession: def __init__( # noqa: PLR0913 self, session_id: str, - package_name: str = None, + package_name: str | None = None, project_path: Path | str | None = None, save_on_close: bool = False, conf_source: str | None = None, @@ -126,8 +126,8 @@ def create( # noqa: PLR0913 cls, project_path: Path | str | None = None, save_on_close: bool = True, - env: str = None, - extra_params: dict[str, Any] = None, + env: str | None = None, + extra_params: dict[str, Any] | None = None, conf_source: str | None = None, ) -> KedroSession: """Create a new instance of ``KedroSession`` with the session data. @@ -272,16 +272,16 @@ def __exit__(self, exc_type, exc_value, tb_): def run( # noqa: PLR0913,too-many-locals self, - pipeline_name: str = None, - tags: Iterable[str] = None, - runner: AbstractRunner = None, - node_names: Iterable[str] = None, - from_nodes: Iterable[str] = None, - to_nodes: Iterable[str] = None, - from_inputs: Iterable[str] = None, - to_outputs: Iterable[str] = None, - load_versions: dict[str, str] = None, - namespace: str = None, + pipeline_name: str | None = None, + tags: Iterable[str] | None = None, + runner: AbstractRunner | None = None, + node_names: Iterable[str] | None = None, + from_nodes: Iterable[str] | None = None, + to_nodes: Iterable[str] | None = None, + from_inputs: Iterable[str] | None = None, + to_outputs: Iterable[str] | None = None, + load_versions: dict[str, str] | None = None, + namespace: str | None = None, ) -> dict[str, Any]: """Runs the pipeline with a specified runner. diff --git a/kedro/io/cached_dataset.py b/kedro/io/cached_dataset.py index e0935c8100..24dcc7d96e 100644 --- a/kedro/io/cached_dataset.py +++ b/kedro/io/cached_dataset.py @@ -37,9 +37,9 @@ class as shown above. def __init__( self, dataset: AbstractDataset | dict, - version: Version = None, - copy_mode: str = None, - metadata: dict[str, Any] = None, + version: Version | None = None, + copy_mode: str | None = None, + metadata: dict[str, Any] | None = None, ): """Creates a new instance of ``CachedDataset`` pointing to the provided Python object. 
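Several hunks, such as `_describe_git` below and `_is_hidden`/`get_filepath_str` elsewhere in the patch, rename intermediate variables instead of re-binding a parameter to a value of a different type. The removed pre-commit hook passed `--allow-redefinition`; the new `[tool.mypy]` configuration does not, so re-using one name for two types is reported. A rough illustration with a made-up helper:

```python
from __future__ import annotations

from pathlib import Path

# Re-binding `project_path` from Path to str is rejected once redefinition is
# no longer allowed:
#
#   def describe(project_path: Path) -> str:
#       project_path = str(project_path)  # error: incompatible types in assignment
#       ...


# Giving the converted value its own name keeps each variable a single type.
def describe(project_path: Path) -> str:
    """Hypothetical stand-in for the `_describe_git`-style rename."""
    path = str(project_path)  # `path` is always str; `project_path` stays a Path
    return f"running git commands in {path}"


print(describe(Path(".")))
```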
diff --git a/kedro/io/core.py b/kedro/io/core.py index 9071bf4730..7335c0b28a 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -118,8 +118,8 @@ def from_config( cls: type, name: str, config: dict[str, Any], - load_version: str = None, - save_version: str = None, + load_version: str | None = None, + save_version: str | None = None, ) -> AbstractDataset: """Create a data set instance using the configuration provided. @@ -351,7 +351,9 @@ class Version(namedtuple("Version", ["load", "save"])): def parse_dataset_definition( - config: dict[str, Any], load_version: str = None, save_version: str = None + config: dict[str, Any], + load_version: str | None = None, + save_version: str | None = None, ) -> tuple[type[AbstractDataset], dict[str, Any]]: """Parse and instantiate a dataset class using the configuration provided. @@ -451,8 +453,8 @@ def _load_obj(class_path: str) -> Any | None: return class_obj -def _local_exists(filepath: str) -> bool: # SKIP_IF_NO_SPARK - filepath = Path(filepath) +def _local_exists(local_filepath: str) -> bool: # SKIP_IF_NO_SPARK + filepath = Path(local_filepath) return filepath.exists() or any(par.is_file() for par in filepath.parents) @@ -506,8 +508,8 @@ def __init__( self, filepath: PurePosixPath, version: Version | None, - exists_function: Callable[[str], bool] = None, - glob_function: Callable[[str], list[str]] = None, + exists_function: Callable[[str], bool] | None = None, + glob_function: Callable[[str], list[str]] | None = None, ): """Creates a new instance of ``AbstractVersionedDataset``. @@ -700,7 +702,7 @@ def _parse_filepath(filepath: str) -> dict[str, str]: def get_protocol_and_path( - filepath: str | os.PathLike, version: Version = None + filepath: str | os.PathLike, version: Version | None = None ) -> tuple[str, str]: """Parses filepath on protocol and path. @@ -732,17 +734,17 @@ def get_protocol_and_path( return protocol, path -def get_filepath_str(path: PurePath, protocol: str) -> str: +def get_filepath_str(raw_path: PurePath, protocol: str) -> str: """Returns filepath. Returns full filepath (with protocol) if protocol is HTTP(s). Args: - path: filepath without protocol. + raw_path: filepath without protocol. protocol: protocol. Returns: Filepath string. 
""" - path = path.as_posix() + path = raw_path.as_posix() if protocol in HTTP_PROTOCOLS: path = "".join((protocol, PROTOCOL_DELIMITER, path)) return path diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index 57e911be05..c36938a9ed 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -141,11 +141,11 @@ class DataCatalog: def __init__( # noqa: PLR0913 self, - datasets: dict[str, AbstractDataset] = None, - feed_dict: dict[str, Any] = None, - dataset_patterns: Patterns = None, - load_versions: dict[str, str] = None, - save_version: str = None, + datasets: dict[str, AbstractDataset] | None = None, + feed_dict: dict[str, Any] | None = None, + dataset_patterns: Patterns | None = None, + load_versions: dict[str, str] | None = None, + save_version: str | None = None, ) -> None: """``DataCatalog`` stores instances of ``AbstractDataset`` implementations to provide ``load`` and ``save`` capabilities from @@ -204,9 +204,9 @@ def _logger(self): def from_config( cls, catalog: dict[str, dict[str, Any]] | None, - credentials: dict[str, dict[str, Any]] = None, - load_versions: dict[str, str] = None, - save_version: str = None, + credentials: dict[str, dict[str, Any]] | None = None, + load_versions: dict[str, str] | None = None, + save_version: str | None = None, ) -> DataCatalog: """Create a ``DataCatalog`` instance from configuration. This is a factory method used to provide developers with a way to instantiate @@ -366,7 +366,10 @@ def _specificity(pattern: str) -> int: return len(result) def _get_dataset( - self, dataset_name: str, version: Version = None, suggest: bool = True + self, + dataset_name: str, + version: Version | None = None, + suggest: bool = True, ) -> AbstractDataset: matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name) if dataset_name not in self._datasets and matched_pattern: @@ -448,7 +451,7 @@ def _resolve_config( ) from exc return config - def load(self, name: str, version: str = None) -> Any: + def load(self, name: str, version: str | None = None) -> Any: """Loads a registered data set. Args: diff --git a/kedro/io/lambda_dataset.py b/kedro/io/lambda_dataset.py index f3012aa1ad..54b4f531c0 100644 --- a/kedro/io/lambda_dataset.py +++ b/kedro/io/lambda_dataset.py @@ -80,9 +80,9 @@ def __init__( # noqa: PLR0913 self, load: Callable[[], Any] | None, save: Callable[[Any], None] | None, - exists: Callable[[], bool] = None, - release: Callable[[], None] = None, - metadata: dict[str, Any] = None, + exists: Callable[[], bool] | None = None, + release: Callable[[], None] | None = None, + metadata: dict[str, Any] | None = None, ): """Creates a new instance of ``LambdaDataset`` with references to the required input/output data set methods. diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py index 5b1075fdb0..7c696bdfc9 100644 --- a/kedro/io/memory_dataset.py +++ b/kedro/io/memory_dataset.py @@ -35,7 +35,10 @@ class MemoryDataset(AbstractDataset): """ def __init__( - self, data: Any = _EMPTY, copy_mode: str = None, metadata: dict[str, Any] = None + self, + data: Any = _EMPTY, + copy_mode: str | None = None, + metadata: dict[str, Any] | None = None, ): """Creates a new instance of ``MemoryDataset`` pointing to the provided Python object. 
@@ -93,11 +96,11 @@ def _infer_copy_mode(data: Any) -> str: try: import pandas as pd except ImportError: # pragma: no cover - pd = None # pragma: no cover + pd = None # type: ignore # pragma: no cover try: import numpy as np except ImportError: # pragma: no cover - np = None # pragma: no cover + np = None # type: ignore # pragma: no cover if pd and isinstance(data, pd.DataFrame) or np and isinstance(data, np.ndarray): copy_mode = "copy" diff --git a/kedro/io/shared_memory_dataset.py b/kedro/io/shared_memory_dataset.py index b27bb8bc99..2fa952ff65 100644 --- a/kedro/io/shared_memory_dataset.py +++ b/kedro/io/shared_memory_dataset.py @@ -10,7 +10,7 @@ class SharedMemoryDataset(AbstractDataset): """``SharedMemoryDataset`` is a wrapper class for a shared MemoryDataset in SyncManager.""" - def __init__(self, manager: SyncManager = None): + def __init__(self, manager: SyncManager | None = None): """Creates a new instance of ``SharedMemoryDataset``, and creates shared MemoryDataset attribute. diff --git a/kedro/ipython/__init__.py b/kedro/ipython/__init__.py index 661371cf9b..f814817991 100644 --- a/kedro/ipython/__init__.py +++ b/kedro/ipython/__init__.py @@ -67,7 +67,9 @@ def load_ipython_extension(ipython): ) @argument("--conf-source", type=str, default=None, help=CONF_SOURCE_HELP) def magic_reload_kedro( - line: str, local_ns: dict[str, Any] = None, conf_source: str = None + line: str, + local_ns: dict[str, Any] | None = None, + conf_source: str | None = None, ): """ The `%reload_kedro` IPython line magic. @@ -79,11 +81,11 @@ def magic_reload_kedro( def reload_kedro( - path: str = None, - env: str = None, - extra_params: dict[str, Any] = None, + path: str | None = None, + env: str | None = None, + extra_params: dict[str, Any] | None = None, local_namespace: dict[str, Any] | None = None, - conf_source: str = None, + conf_source: str | None = None, ) -> None: # pragma: no cover """Function that underlies the %reload_kedro Line magic. This should not be imported or run directly but instead invoked through %reload_kedro.""" diff --git a/kedro/pipeline/modular_pipeline.py b/kedro/pipeline/modular_pipeline.py index 13d5a0e6b0..1d583daca0 100644 --- a/kedro/pipeline/modular_pipeline.py +++ b/kedro/pipeline/modular_pipeline.py @@ -157,7 +157,7 @@ def pipeline( # noqa: PLR0913 outputs: str | set[str] | dict[str, str] | None = None, parameters: str | set[str] | dict[str, str] | None = None, tags: str | Iterable[str] | None = None, - namespace: str = None, + namespace: str | None = None, ) -> Pipeline: r"""Create a ``Pipeline`` from a collection of nodes and/or ``Pipeline``\s. 
@@ -259,8 +259,8 @@ def _rename(name: str): return name def _process_dataset_names( - datasets: None | str | list[str] | dict[str, str] - ) -> None | str | list[str] | dict[str, str]: + datasets: str | list[str] | dict[str, str] | None + ) -> str | list[str] | dict[str, str] | None: if datasets is None: return None if isinstance(datasets, str): diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 440f402bb4..231a05399a 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -22,13 +22,13 @@ class Node: def __init__( # noqa: PLR0913 self, func: Callable, - inputs: None | str | list[str] | dict[str, str], - outputs: None | str | list[str] | dict[str, str], + inputs: str | list[str] | dict[str, str] | None, + outputs: str | list[str] | dict[str, str] | None, *, - name: str = None, + name: str | None = None, tags: str | Iterable[str] | None = None, confirms: str | list[str] | None = None, - namespace: str = None, + namespace: str | None = None, ): """Create a node in the pipeline by providing a function to be called along with variable names for inputs and/or outputs. @@ -306,7 +306,7 @@ def confirms(self) -> list[str]: """ return _to_list(self._confirms) - def run(self, inputs: dict[str, Any] = None) -> dict[str, Any]: + def run(self, inputs: dict[str, Any] | None = None) -> dict[str, Any]: """Run this node using the provided inputs and return its results in a dictionary. @@ -509,7 +509,7 @@ def _validate_inputs_dif_than_outputs(self): ) @staticmethod - def _process_inputs_for_bind(inputs: None | str | list[str] | dict[str, str]): + def _process_inputs_for_bind(inputs: str | list[str] | dict[str, str] | None): # Safeguard that we do not mutate list inputs inputs = copy.copy(inputs) args: list[str] = [] @@ -532,13 +532,13 @@ def _node_error_message(msg) -> str: def node( # noqa: PLR0913 func: Callable, - inputs: None | str | list[str] | dict[str, str], - outputs: None | str | list[str] | dict[str, str], + inputs: str | list[str] | dict[str, str] | None, + outputs: str | list[str] | dict[str, str] | None, *, - name: str = None, + name: str | None = None, tags: str | Iterable[str] | None = None, confirms: str | list[str] | None = None, - namespace: str = None, + namespace: str | None = None, ) -> Node: """Create a node in the pipeline by providing a function to be called along with variable names for inputs and/or outputs. @@ -614,7 +614,7 @@ def _dict_inputs_to_list(func: Callable[[Any], Any], inputs: dict[str, str]): return [*sig.args, *sig.kwargs.values()] -def _to_list(element: None | str | Iterable[str] | dict[str, str]) -> list[str]: +def _to_list(element: str | Iterable[str] | dict[str, str] | None) -> list[str]: """Make a list out of node inputs/outputs. Returns: diff --git a/kedro/pipeline/pipeline.py b/kedro/pipeline/pipeline.py index 840e446f9d..c9794fb3b4 100644 --- a/kedro/pipeline/pipeline.py +++ b/kedro/pipeline/pipeline.py @@ -9,7 +9,7 @@ import json from collections import Counter, defaultdict from itertools import chain -from typing import Iterable +from typing import Any, Iterable from toposort import CircularDependencyError as ToposortCircleError from toposort import toposort @@ -134,36 +134,36 @@ def __init__( "'nodes' argument of 'Pipeline' is None. It must be an " "iterable of nodes and/or pipelines instead." 
) - nodes = list(nodes) # in case it's a generator - _validate_duplicate_nodes(nodes) + nodes_list = list(nodes) # in case it's a generator + _validate_duplicate_nodes(nodes_list) - nodes = list( + nodes_chain = list( chain.from_iterable( - [[n] if isinstance(n, Node) else n.nodes for n in nodes] + [[n] if isinstance(n, Node) else n.nodes for n in nodes_list] ) ) - _validate_transcoded_inputs_outputs(nodes) + _validate_transcoded_inputs_outputs(nodes_chain) _tags = set(_to_list(tags)) - nodes = [n.tag(_tags) for n in nodes] + tagged_nodes = [n.tag(_tags) for n in nodes_chain] - self._nodes_by_name = {node.name: node for node in nodes} - _validate_unique_outputs(nodes) - _validate_unique_confirms(nodes) + self._nodes_by_name = {node.name: node for node in tagged_nodes} + _validate_unique_outputs(tagged_nodes) + _validate_unique_confirms(tagged_nodes) # input -> nodes with input self._nodes_by_input: dict[str, set[Node]] = defaultdict(set) - for node in nodes: + for node in tagged_nodes: for input_ in node.inputs: self._nodes_by_input[_strip_transcoding(input_)].add(node) # output -> node with output self._nodes_by_output: dict[str, Node] = {} - for node in nodes: + for node in tagged_nodes: for output in node.outputs: self._nodes_by_output[_strip_transcoding(output)] = node - self._nodes = nodes + self._nodes = tagged_nodes self._topo_sorted_nodes = _topologically_sorted(self.node_dependencies) def __repr__(self): # pragma: no cover @@ -675,19 +675,19 @@ def only_nodes_with_tags(self, *tags: str) -> Pipeline: nodes of the current one such that only nodes containing *any* of the tags provided are being copied. """ - tags = set(tags) - nodes = [node for node in self.nodes if tags & node.tags] + unique_tags = set(tags) + nodes = [node for node in self.nodes if unique_tags & node.tags] return Pipeline(nodes) def filter( # noqa: PLR0913 self, - tags: Iterable[str] = None, - from_nodes: Iterable[str] = None, - to_nodes: Iterable[str] = None, - node_names: Iterable[str] = None, - from_inputs: Iterable[str] = None, - to_outputs: Iterable[str] = None, - node_namespace: str = None, + tags: Iterable[str] | None = None, + from_nodes: Iterable[str] | None = None, + to_nodes: Iterable[str] | None = None, + node_names: Iterable[str] | None = None, + from_inputs: Iterable[str] | None = None, + to_outputs: Iterable[str] | None = None, + node_namespace: str | None = None, ) -> Pipeline: """Creates a new ``Pipeline`` object with the nodes that meet all of the specified filtering conditions. @@ -733,7 +733,7 @@ def filter( # noqa: PLR0913 """ # Use [node_namespace] so only_nodes_with_namespace can follow the same # *filter_args pattern as the other filtering methods, which all take iterables. 
- node_namespace = [node_namespace] if node_namespace else None + node_namespace_iterable = [node_namespace] if node_namespace else None filter_methods = { self.only_nodes_with_tags: tags, @@ -742,7 +742,7 @@ def filter( # noqa: PLR0913 self.only_nodes: node_names, self.from_inputs: from_inputs, self.to_outputs: to_outputs, - self.only_nodes_with_namespace: node_namespace, + self.only_nodes_with_namespace: node_namespace_iterable, } subset_pipelines = { @@ -805,7 +805,7 @@ def _validate_duplicate_nodes(nodes_or_pipes: Iterable[Node | Pipeline]): seen_nodes: set[str] = set() duplicates: dict[Pipeline | None, set[str]] = defaultdict(set) - def _check_node(node_: Node, pipeline_: Pipeline = None): + def _check_node(node_: Node, pipeline_: Pipeline | None = None): name = node_.name if name in seen_nodes: duplicates[pipeline_].add(name) @@ -837,8 +837,8 @@ def _check_node(node_: Node, pipeline_: Pipeline = None): def _validate_unique_outputs(nodes: list[Node]) -> None: - outputs = chain.from_iterable(node.outputs for node in nodes) - outputs = map(_strip_transcoding, outputs) + outputs_chain = chain.from_iterable(node.outputs for node in nodes) + outputs = map(_strip_transcoding, outputs_chain) duplicates = [key for key, value in Counter(outputs).items() if value > 1] if duplicates: raise OutputNotUniqueError( @@ -848,8 +848,8 @@ def _validate_unique_outputs(nodes: list[Node]) -> None: def _validate_unique_confirms(nodes: list[Node]) -> None: - confirms = chain.from_iterable(node.confirms for node in nodes) - confirms = map(_strip_transcoding, confirms) + confirms_chain = chain.from_iterable(node.confirms for node in nodes) + confirms = map(_strip_transcoding, confirms_chain) duplicates = [key for key, value in Counter(confirms).items() if value > 1] if duplicates: raise ConfirmNotUniqueError( @@ -898,7 +898,7 @@ def _topologically_sorted(node_dependencies) -> list[list[Node]]: executed on the second step, etc. """ - def _circle_error_message(error_data: dict[str, str]) -> str: + def _circle_error_message(error_data: dict[Any, set]) -> str: """Error messages provided by the toposort library will refer to indices that are used as an intermediate step. This method can be used to replace that message with diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 9d4081f121..2a2a5c3b03 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -57,9 +57,9 @@ def _run_node_synchronization( # noqa: PLR0913 node: Node, catalog: DataCatalog, is_async: bool = False, - session_id: str = None, - package_name: str = None, - logging_config: dict[str, Any] = None, + session_id: str | None = None, + package_name: str | None = None, + logging_config: dict[str, Any] | None = None, ) -> Node: """Run a single `Node` with inputs from and outputs to the `catalog`. @@ -99,7 +99,7 @@ class ParallelRunner(AbstractRunner): def __init__( self, - max_workers: int = None, + max_workers: int | None = None, is_async: bool = False, extra_dataset_patterns: dict[str, dict[str, Any]] | None = None, ): @@ -237,7 +237,7 @@ def _run( # noqa: too-many-locals,useless-suppression pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> None: """The abstract interface for running pipelines. 
diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 0aa9f07e13..1afa8dde92 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -56,8 +56,8 @@ def run( self, pipeline: Pipeline, catalog: DataCatalog, - hook_manager: PluginManager = None, - session_id: str = None, + hook_manager: PluginManager | None = None, + session_id: str | None = None, ) -> dict[str, Any]: """Run the ``Pipeline`` using the datasets provided by ``catalog`` and save results back to the same objects. @@ -78,7 +78,7 @@ def run( """ - hook_manager = hook_manager or _NullPluginManager() + hook_or_null_manager = hook_manager or _NullPluginManager() catalog = catalog.shallow_copy() # Check which datasets used in the pipeline are in the catalog or match @@ -114,7 +114,7 @@ def run( self._logger.info( "Asynchronous mode is enabled for loading and saving data" ) - self._run(pipeline, catalog, hook_manager, session_id) + self._run(pipeline, catalog, hook_or_null_manager, session_id) # type: ignore self._logger.info("Pipeline execution completed successfully.") @@ -163,7 +163,7 @@ def _run( pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> None: """The abstract interface for running pipelines, assuming that the inputs have already been checked and normalized by run(). @@ -298,7 +298,7 @@ def run_node( catalog: DataCatalog, hook_manager: PluginManager, is_async: bool = False, - session_id: str = None, + session_id: str | None = None, ) -> Node: """Run a single `Node` with inputs from and outputs to the `catalog`. @@ -342,7 +342,7 @@ def _collect_inputs_from_hook( # noqa: PLR0913 inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> dict[str, Any]: inputs = inputs.copy() # shallow copy to prevent in-place modification by the hook hook_response = hook_manager.hook.before_node_run( @@ -375,7 +375,7 @@ def _call_node_run( # noqa: PLR0913 inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> dict[str, Any]: try: outputs = node.run(inputs) @@ -404,7 +404,7 @@ def _run_node_sequential( node: Node, catalog: DataCatalog, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> Node: inputs = {} @@ -451,7 +451,7 @@ def _run_node_async( node: Node, catalog: DataCatalog, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> Node: def _synchronous_dataset_load(dataset_name: str): """Minimal wrapper to ensure Hooks are run synchronously diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 3270df5122..5e14592dd8 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -47,7 +47,7 @@ def _run( pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> None: """The method implementing sequential pipeline running. 
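`AbstractRunner.run` below falls back to `_NullPluginManager()` when no hook manager is supplied, so the resulting variable is a union that mypy will not accept where a plain `PluginManager` is expected; the patch gives it a distinct name (`hook_or_null_manager`) and silences the `_run` call with `# type: ignore`. A stripped-down sketch of the same situation, with dummy classes standing in for pluggy's manager and Kedro's no-op manager:

```python
from __future__ import annotations


class PluginManager:
    """Stand-in for pluggy's PluginManager."""


class _NullPluginManager:
    """Stand-in for a no-op manager; deliberately not a PluginManager subclass."""


def _run(hook_manager: PluginManager) -> None:
    print(f"running with {type(hook_manager).__name__}")


def run(hook_manager: PluginManager | None = None) -> None:
    # The inferred type is `PluginManager | _NullPluginManager`, which is not
    # assignable to the `PluginManager` parameter of `_run` - hence the
    # `# type: ignore` in the real hunk (a shared Protocol would be the stricter fix).
    hook_or_null_manager = hook_manager or _NullPluginManager()
    _run(hook_or_null_manager)  # type: ignore


run()                  # running with _NullPluginManager
run(PluginManager())   # running with PluginManager
```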
diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index 2b2cb01ecb..577863f8b3 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -26,7 +26,7 @@ class ThreadRunner(AbstractRunner): def __init__( self, - max_workers: int = None, + max_workers: int | None = None, is_async: bool = False, extra_dataset_patterns: dict[str, dict[str, Any]] | None = None, ): @@ -86,7 +86,7 @@ def _run( # noqa: too-many-locals,useless-suppression pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager, - session_id: str = None, + session_id: str | None = None, ) -> None: """The abstract interface for running pipelines. diff --git a/pyproject.toml b/pyproject.toml index e2707f37db..b217f6a8e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ test = [ "kedro-datasets", "moto==1.3.7; python_version < '3.10'", "moto==4.1.12; python_version >= '3.10'", + "mypy~=1.0", "pandas~=2.0", "pip-tools>=6.5", "pre-commit>=2.9.2, <4.0", # The hook `mypy` requires pre-commit version 2.9.2. @@ -81,6 +82,12 @@ test = [ "s3fs>=2021.4, <2024.1", # Upper bound set arbitrarily, to be reassessed in early 2024 "semver", "trufflehog~=2.1", + # mypy related dependencies + "pandas-stubs", + "types-PyYAML", + "types-cachetools", + "types-toml", + "types-toposort" ] docs = [ "docutils<0.18", @@ -233,3 +240,7 @@ known-first-party = ["kedro"] [tool.ruff.per-file-ignores] "{tests,docs}/*" = ["PLR2004","PLR0913"] "{tests,docs,tools,static,features,docs}/*" = ["T201"] # Check print statement for kedro/ only + +[tool.mypy] +ignore_missing_imports = true +exclude = ['^kedro/templates/', '^docs/', '^features/steps/test_starter/'] diff --git a/tests/framework/cli/test_starters.py b/tests/framework/cli/test_starters.py index f64b7ba245..b9cff3fac5 100644 --- a/tests/framework/cli/test_starters.py +++ b/tests/framework/cli/test_starters.py @@ -13,7 +13,7 @@ from kedro import __version__ as version from kedro.framework.cli.starters import ( - _OFFICIAL_STARTER_SPECS, + _OFFICIAL_STARTER_SPECS_DICT, TEMPLATE_PATH, KedroStarterSpec, _convert_tool_names_to_numbers, @@ -251,7 +251,7 @@ def test_starter_list(fake_kedro_cli): result = CliRunner().invoke(fake_kedro_cli, ["starter", "list"]) assert result.exit_code == 0, result.output - for alias in _OFFICIAL_STARTER_SPECS: + for alias in _OFFICIAL_STARTER_SPECS_DICT: assert alias in result.output
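With the pre-commit hook removed, type checking now runs through the Makefile (`make lint` calls `mypy kedro`) against the `[tool.mypy]` table in `pyproject.toml`, which supplies `ignore_missing_imports` and the template/docs/starter excludes. For completeness, the same check can also be driven programmatically; a small sketch, assuming the dev dependencies are installed and the command is run from the repository root:

```python
# Minimal sketch: invoke the same check via mypy's programmatic API.
# `make lint` simply shells out to `mypy kedro`, which reads [tool.mypy]
# from pyproject.toml automatically.
from mypy import api

stdout, stderr, exit_status = api.run(["kedro"])  # equivalent to `mypy kedro`
print(stdout)
print(f"mypy exited with status {exit_status}")  # 0 means no type errors found
```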