diff --git a/Makefile b/Makefile index 6618d20657..fe0022ef23 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ clean: lint: pre-commit run -a --hook-stage manual $(hook) - mypy kedro + mypy kedro --strict --allow-any-generics test: pytest --numprocesses 4 --dist loadfile diff --git a/features/steps/sh_run.py b/features/steps/sh_run.py index 8f54245b54..c925c02797 100644 --- a/features/steps/sh_run.py +++ b/features/steps/sh_run.py @@ -82,7 +82,7 @@ def __init__(self, cmd: list[str], **kwargs) -> None: **kwargs: keyword arguments such as env and cwd """ - super().__init__( # type: ignore + super().__init__( cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs ) diff --git a/kedro/config/abstract_config.py b/kedro/config/abstract_config.py index ae9be039dd..60b75fcba8 100644 --- a/kedro/config/abstract_config.py +++ b/kedro/config/abstract_config.py @@ -19,7 +19,7 @@ def __init__( conf_source: str, env: str | None = None, runtime_params: dict[str, Any] | None = None, - **kwargs, + **kwargs: Any, ): super().__init__() self.conf_source = conf_source diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index a6fd1b24de..229d6a4060 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -145,13 +145,13 @@ def __init__( # noqa: PLR0913 except MissingConfigException: self._globals = {} - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: if key == "globals": # Update the cached value at self._globals since it is used by the globals resolver self._globals = value super().__setitem__(key, value) - def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912 + def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912 """Get configuration files by key, load and merge them, and return them in the form of a config dictionary. 
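Reviewer note (not part of the patch): the new `mypy kedro --strict --allow-any-generics` invocation explains two patterns that recur throughout this diff. `--strict` enables `disallow-untyped-defs`, hence the many `-> None` and `**kwargs: Any` annotations, while `--allow-any-generics` switches the `disallow-any-generics` check back off so that bare generics such as a plain `dict` return type can stay. A hypothetical sketch (not from the Kedro codebase) of what the relaxed flag permits:

from __future__ import annotations

from collections import defaultdict


def group_by_prefix(names: list[str]) -> dict:
    # The bare "dict" return type would be rejected under plain --strict
    # (disallow-any-generics); --allow-any-generics lets it type-check.
    mapping: defaultdict[str, list[str]] = defaultdict(list)
    for name in names:
        mapping[name.split(".")[0]].append(name)
    return mapping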
@@ -175,7 +175,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912 self._register_runtime_params_resolver() if key in self: - return super().__getitem__(key) + return super().__getitem__(key) # type: ignore[no-any-return] if key not in self.config_patterns: raise KeyError( @@ -196,7 +196,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912 else: base_path = str(Path(self._fs.ls("", detail=False)[-1]) / self.base_env) try: - base_config = self.load_and_merge_dir_config( + base_config = self.load_and_merge_dir_config( # type: ignore[no-untyped-call] base_path, patterns, key, processed_files, read_environment_variables ) except UnsupportedInterpolationType as exc: @@ -216,7 +216,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912 else: env_path = str(Path(self._fs.ls("", detail=False)[-1]) / run_env) try: - env_config = self.load_and_merge_dir_config( + env_config = self.load_and_merge_dir_config( # type: ignore[no-untyped-call] env_path, patterns, key, processed_files, read_environment_variables ) except UnsupportedInterpolationType as exc: @@ -244,9 +244,9 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912 f" the glob pattern(s): {[*self.config_patterns[key]]}" ) - return resulting_config + return resulting_config # type: ignore[no-any-return] - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover return ( f"OmegaConfigLoader(conf_source={self.conf_source}, env={self.env}, " f"config_patterns={self.config_patterns})" @@ -312,8 +312,8 @@ def load_and_merge_dir_config( # noqa: PLR0913 self._resolve_environment_variables(config) config_per_file[config_filepath] = config except (ParserError, ScannerError) as exc: - line = exc.problem_mark.line # type: ignore - cursor = exc.problem_mark.column # type: ignore + line = exc.problem_mark.line + cursor = exc.problem_mark.column raise ParserError( f"Invalid YAML or JSON file {Path(conf_path, config_filepath.name).as_posix()}," f" unable to read line {line}, position {cursor}." @@ -342,7 +342,7 @@ def load_and_merge_dir_config( # noqa: PLR0913 if not k.startswith("_") } - def _is_valid_config_path(self, path): + def _is_valid_config_path(self, path: Path) -> bool: """Check if given path is a file path and file type is yaml or json.""" posix_path = path.as_posix() return self._fs.isfile(str(posix_path)) and path.suffix in [ @@ -351,7 +351,7 @@ def _is_valid_config_path(self, path): ".json", ] - def _register_globals_resolver(self): + def _register_globals_resolver(self) -> None: """Register the globals resolver""" OmegaConf.register_new_resolver( "globals", @@ -359,14 +359,14 @@ def _register_globals_resolver(self): replace=True, ) - def _register_runtime_params_resolver(self): + def _register_runtime_params_resolver(self) -> None: OmegaConf.register_new_resolver( "runtime_params", self._get_runtime_value, replace=True, ) - def _get_globals_value(self, variable, default_value=_NO_VALUE): + def _get_globals_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any: """Return the globals values to the resolver""" if variable.startswith("_"): raise InterpolationResolutionError( @@ -383,7 +383,7 @@ def _get_globals_value(self, variable, default_value=_NO_VALUE): f"Globals key '{variable}' not found and no default value provided." 
) - def _get_runtime_value(self, variable, default_value=_NO_VALUE): + def _get_runtime_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any: """Return the runtime params values to the resolver""" runtime_oc = OmegaConf.create(self.runtime_params) interpolated_value = OmegaConf.select( @@ -397,7 +397,7 @@ def _get_runtime_value(self, variable, default_value=_NO_VALUE): ) @staticmethod - def _register_new_resolvers(resolvers: dict[str, Callable]): + def _register_new_resolvers(resolvers: dict[str, Callable]) -> None: """Register custom resolvers""" for name, resolver in resolvers.items(): if not OmegaConf.has_resolver(name): @@ -406,7 +406,7 @@ def _register_new_resolvers(resolvers: dict[str, Callable]): OmegaConf.register_new_resolver(name=name, resolver=resolver) @staticmethod - def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]): + def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]) -> None: duplicates = [] filepaths = list(seen_files_to_keys.keys()) @@ -449,7 +449,9 @@ def _resolve_environment_variables(config: DictConfig) -> None: OmegaConf.resolve(config) @staticmethod - def _destructive_merge(config, env_config, env_path): + def _destructive_merge( + config: dict[str, Any], env_config: dict[str, Any], env_path: str + ) -> dict[str, Any]: # Destructively merge the two env dirs. The chosen env will override base. common_keys = config.keys() & env_config.keys() if common_keys: @@ -464,11 +466,11 @@ def _destructive_merge(config, env_config, env_path): return config @staticmethod - def _soft_merge(config, env_config): + def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any: # Soft merge the two env dirs. The chosen env will override base if keys clash. return OmegaConf.to_container(OmegaConf.merge(config, env_config)) - def _is_hidden(self, path_str: str): + def _is_hidden(self, path_str: str) -> bool: """Check if path contains any hidden directory or is a hidden file""" path = Path(path_str) conf_path = Path(self.conf_source).resolve().as_posix() diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py index 933e7b9616..3208ba0de1 100644 --- a/kedro/framework/cli/catalog.py +++ b/kedro/framework/cli/catalog.py @@ -1,7 +1,11 @@ """A collection of CLI commands for working with Kedro catalog.""" +from __future__ import annotations + import copy from collections import defaultdict from itertools import chain +from pathlib import Path +from typing import Any import click import yaml @@ -11,21 +15,22 @@ from kedro.framework.project import pipelines, settings from kedro.framework.session import KedroSession from kedro.framework.startup import ProjectMetadata +from kedro.io import AbstractDataset -def _create_session(package_name: str, **kwargs): +def _create_session(package_name: str, **kwargs: Any) -> KedroSession: kwargs.setdefault("save_on_close", False) return KedroSession.create(**kwargs) # noqa: missing-function-docstring @click.group(name="Kedro") -def catalog_cli(): # pragma: no cover +def catalog_cli() -> None: # pragma: no cover pass @catalog_cli.group() -def catalog(): +def catalog() -> None: """Commands for working with catalog.""" @@ -42,7 +47,7 @@ def catalog(): callback=split_string, ) @click.pass_obj -def list_datasets(metadata: ProjectMetadata, pipeline, env): +def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None: """Show datasets per type.""" title = "Datasets in '{}' pipeline" not_mentioned = "Datasets not mentioned in pipeline" @@ -111,11 +116,13 @@ def 
list_datasets(metadata: ProjectMetadata, pipeline, env): secho(yaml.dump(result)) -def _map_type_to_datasets(datasets, datasets_meta): +def _map_type_to_datasets( + datasets: set[str], datasets_meta: dict[str, AbstractDataset] +) -> dict: """Build dictionary with a dataset type as a key and list of datasets of the specific type as a value. """ - mapping = defaultdict(list) + mapping = defaultdict(list) # type: ignore[var-annotated] for dataset in datasets: is_param = dataset.startswith("params:") or dataset == "parameters" if not is_param: @@ -136,7 +143,7 @@ def _map_type_to_datasets(datasets, datasets_meta): help="Name of a pipeline.", ) @click.pass_obj -def create_catalog(metadata: ProjectMetadata, pipeline_name, env): +def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> None: """Create Data Catalog YAML configuration with missing datasets. Add ``MemoryDataset`` datasets to Data Catalog YAML configuration @@ -185,7 +192,7 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name, env): click.echo("All datasets are already configured.") -def _add_missing_datasets_to_catalog(missing_ds, catalog_path): +def _add_missing_datasets_to_catalog(missing_ds: list[str], catalog_path: Path) -> None: if catalog_path.is_file(): catalog_config = yaml.safe_load(catalog_path.read_text()) or {} else: @@ -204,7 +211,7 @@ def _add_missing_datasets_to_catalog(missing_ds, catalog_path): @catalog.command("rank") @env_option @click.pass_obj -def rank_catalog_factories(metadata: ProjectMetadata, env): +def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None: """List all dataset factories in the catalog, ranked by priority by which they are matched.""" session = _create_session(metadata.package_name, env=env) context = session.load_context() @@ -219,7 +226,7 @@ def rank_catalog_factories(metadata: ProjectMetadata, env): @catalog.command("resolve") @env_option @click.pass_obj -def resolve_patterns(metadata: ProjectMetadata, env): +def resolve_patterns(metadata: ProjectMetadata, env: str) -> None: """Resolve catalog factories against pipeline datasets. Note that this command is runner agnostic and thus won't take into account any default dataset creation defined in the runner.""" @@ -268,5 +275,5 @@ def resolve_patterns(metadata: ProjectMetadata, env): secho(yaml.dump(explicit_datasets)) -def _trim_filepath(project_path: str, file_path: str): +def _trim_filepath(project_path: str, file_path: str) -> str: return file_path.replace(project_path, "", 1) diff --git a/kedro/framework/cli/cli.py b/kedro/framework/cli/cli.py index 304fb6b4bc..c65fe70556 100644 --- a/kedro/framework/cli/cli.py +++ b/kedro/framework/cli/cli.py @@ -2,11 +2,13 @@ This module implements commands available from the kedro CLI. """ +from __future__ import annotations + import importlib import sys from collections import defaultdict from pathlib import Path -from typing import Sequence +from typing import Any, Sequence import click @@ -42,7 +44,7 @@ @click.group(context_settings=CONTEXT_SETTINGS, name="Kedro") @click.version_option(version, "--version", "-V", help="Show version and exit") -def cli(): # pragma: no cover +def cli() -> None: # pragma: no cover """Kedro is a CLI for creating and using Kedro projects. For more information, type ``kedro info``. 
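Note (illustrative, not from the patch): the CLI callbacks in this file don't return anything meaningful, so under `disallow-untyped-defs` even trivial click groups and commands pick up explicit `-> None` annotations. A minimal standalone sketch with hypothetical names:

from __future__ import annotations

import click


@click.group(name="demo")
def demo_cli() -> None:  # pragma: no cover
    """Top-level command group."""


@demo_cli.command()
@click.option("--env", default=None, help="Environment name.")
def show_env(env: str | None) -> None:
    """Echo the selected environment."""
    click.echo(f"env={env}")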
@@ -51,7 +53,7 @@ def cli(): # pragma: no cover @cli.command() -def info(): +def info() -> None: """Get more information about kedro.""" click.secho(LOGO, fg="green") click.echo( @@ -104,12 +106,12 @@ def __init__(self, project_path: Path): def main( self, - args=None, - prog_name=None, - complete_var=None, - standalone_mode=True, - **extra, - ): + args: Any | None = None, + prog_name: Any | None = None, + complete_var: Any | None = None, + standalone_mode: bool = True, + **extra: Any, + ) -> Any: if self._metadata: extra.update(obj=self._metadata) @@ -182,13 +184,13 @@ def project_groups(self) -> Sequence[click.MultiCommand]: raise KedroCliError( f"Cannot load commands from {self._metadata.package_name}.cli" ) - user_defined = project_cli.cli # type: ignore + user_defined = project_cli.cli # return built-in commands, plugin commands and user defined commands # (overriding happens as follows built-in < plugins < cli.py) return [*built_in, *plugins, user_defined] -def main(): # pragma: no cover +def main() -> None: # pragma: no cover """Main entry point. Look for a ``cli.py``, and, if found, add its commands to `kedro`'s before invoking the CLI. """ diff --git a/kedro/framework/cli/hooks/manager.py b/kedro/framework/cli/hooks/manager.py index 714f8f780e..8b07a7b746 100644 --- a/kedro/framework/cli/hooks/manager.py +++ b/kedro/framework/cli/hooks/manager.py @@ -14,7 +14,7 @@ _CLI_PLUGIN_HOOKS = "kedro.cli_hooks" -def get_cli_hook_manager(): +def get_cli_hook_manager() -> PluginManager: """Create or return the global _hook_manager singleton instance.""" global _cli_hook_manager # noqa: PLW0603 if _cli_hook_manager is None: diff --git a/kedro/framework/cli/hooks/specs.py b/kedro/framework/cli/hooks/specs.py index cc8c23a9f2..f772c6de57 100644 --- a/kedro/framework/cli/hooks/specs.py +++ b/kedro/framework/cli/hooks/specs.py @@ -17,7 +17,7 @@ def before_command_run( self, project_metadata: ProjectMetadata, command_args: list[str], - ): + ) -> None: """Hooks to be invoked before a CLI command runs. It receives the ``project_metadata`` as well as all command line arguments that were used, including the command @@ -32,7 +32,7 @@ def before_command_run( @cli_hook_spec def after_command_run( self, project_metadata: ProjectMetadata, command_args: list[str], exit_code: int - ): + ) -> None: """Hooks to be invoked after a CLI command runs. It receives the ``project_metadata`` as well as all command line arguments that were used, including the command diff --git a/kedro/framework/cli/jupyter.py b/kedro/framework/cli/jupyter.py index d2facef34b..3c9418fbb6 100644 --- a/kedro/framework/cli/jupyter.py +++ b/kedro/framework/cli/jupyter.py @@ -7,6 +7,7 @@ import os import shutil from pathlib import Path +from typing import Any import click @@ -30,19 +31,19 @@ class JupyterCommandGroup(click.Group): """A custom class for ordering the `kedro jupyter` command groups""" - def list_commands(self, ctx): + def list_commands(self, ctx: click.Context) -> list[str]: """List commands according to a custom order""" return ["setup", "notebook", "lab", "convert"] # noqa: missing-function-docstring @click.group(name="Kedro") -def jupyter_cli(): # pragma: no cover +def jupyter_cli() -> None: # pragma: no cover pass @jupyter_cli.group(cls=JupyterCommandGroup) -def jupyter(): +def jupyter() -> None: """Open Jupyter Notebook / Lab with project specific variables loaded, or convert notebooks into Kedro code. 
""" @@ -50,7 +51,7 @@ def jupyter(): @forward_command(jupyter, "setup", forward_help=True) @click.pass_obj # this will pass the metadata as first argument -def setup(metadata: ProjectMetadata, args, **kwargs): # noqa: unused-argument +def setup(metadata: ProjectMetadata, /, args: Any, **kwargs: Any) -> None: # noqa: unused-argument """Initialise the Jupyter Kernel for a kedro project.""" _check_module_importable("ipykernel") validate_settings() @@ -65,10 +66,11 @@ def setup(metadata: ProjectMetadata, args, **kwargs): # noqa: unused-argument @click.pass_obj # this will pass the metadata as first argument def jupyter_notebook( metadata: ProjectMetadata, - env, - args, - **kwargs, -): # noqa: unused-argument + /, + env: str, + args: Any, + **kwargs: Any, +) -> None: # noqa: unused-argument """Open Jupyter Notebook with project specific variables loaded.""" _check_module_importable("notebook") validate_settings() @@ -91,10 +93,11 @@ def jupyter_notebook( @click.pass_obj # this will pass the metadata as first argument def jupyter_lab( metadata: ProjectMetadata, - env, - args, - **kwargs, -): # noqa: unused-argument + /, + env: str, + args: Any, + **kwargs: Any, +) -> None: # noqa: unused-argument """Open Jupyter Lab with project specific variables loaded.""" _check_module_importable("jupyterlab") validate_settings() diff --git a/kedro/framework/cli/micropkg.py b/kedro/framework/cli/micropkg.py index 32c00b2323..1d4c4fcdf8 100644 --- a/kedro/framework/cli/micropkg.py +++ b/kedro/framework/cli/micropkg.py @@ -14,6 +14,7 @@ from typing import Any, Iterable, Iterator import click +from importlib_metadata import PackageMetadata from omegaconf import OmegaConf from packaging.requirements import InvalidRequirement, Requirement from packaging.utils import canonicalize_name @@ -106,7 +107,7 @@ def __eq__(self, other: Any) -> bool: ) -def _check_module_path(ctx, param, value): # noqa: unused-argument +def _check_module_path(ctx: click.core.Context, param: Any, value: str) -> str: # noqa: unused-argument if value and not re.match(r"^[\w.]+$", value): message = ( "The micro-package location you provided is not a valid Python module path" @@ -117,12 +118,12 @@ def _check_module_path(ctx, param, value): # noqa: unused-argument # noqa: missing-function-docstring @click.group(name="Kedro") -def micropkg_cli(): # pragma: no cover +def micropkg_cli() -> None: # pragma: no cover pass @micropkg_cli.group() -def micropkg(): +def micropkg() -> None: """Commands for working with micro-packages.""" @@ -157,13 +158,14 @@ def micropkg(): @click.pass_obj # this will pass the metadata as first argument def pull_package( # noqa: PLR0913 metadata: ProjectMetadata, - package_path, - env, - alias, - destination, - fs_args, - all_flag, - **kwargs, + /, + package_path: str, + env: str, + alias: str, + destination: str, + fs_args: str, + all_flag: str, + **kwargs: Any, ) -> None: """Pull and unpack a modular pipeline and other micro-packages in your project.""" if not package_path and not all_flag: @@ -197,7 +199,7 @@ def _pull_package( # noqa: PLR0913 alias: str | None = None, destination: str | None = None, fs_args: str | None = None, -): +) -> None: with tempfile.TemporaryDirectory() as temp_dir: temp_dir_path = Path(temp_dir).resolve() _unpack_sdist(package_path, temp_dir_path, fs_args) @@ -330,13 +332,14 @@ def _package_micropkgs_from_manifest(metadata: ProjectMetadata) -> None: @click.pass_obj # this will pass the metadata as first argument def package_micropkg( # noqa: PLR0913 metadata: ProjectMetadata, - module_path, - 
env, - alias, - destination, - all_flag, - **kwargs, -): + /, + module_path: str, + env: str, + alias: str, + destination: str, + all_flag: str, + **kwargs: Any, +) -> None: """Package up a modular pipeline or micro-package as a Python source distribution.""" if not module_path and not all_flag: click.secho( @@ -361,7 +364,7 @@ def package_micropkg( # noqa: PLR0913 click.secho(message, fg="green") -def _get_fsspec_filesystem(location: str, fs_args: str | None): +def _get_fsspec_filesystem(location: str, fs_args: str | None) -> Any: # noqa: import-outside-toplevel import fsspec @@ -380,13 +383,13 @@ def _get_fsspec_filesystem(location: str, fs_args: str | None): return None -def _is_within_directory(directory, target): +def _is_within_directory(directory: Path, target: Path) -> bool: abs_directory = directory.resolve() abs_target = target.resolve() return abs_directory in abs_target.parents -def safe_extract(tar, path): +def safe_extract(tar: tarfile.TarFile, path: Path) -> None: for member in tar.getmembers(): member_path = path / member.name if not _is_within_directory(path, member_path): @@ -428,7 +431,7 @@ def _unpack_sdist(location: str, destination: Path, fs_args: str | None) -> None safe_extract(fs_file, destination) -def _rename_files(conf_source: Path, old_name: str, new_name: str): +def _rename_files(conf_source: Path, old_name: str, new_name: str) -> None: config_files_to_rename = ( each for each in conf_source.rglob("*") @@ -527,7 +530,7 @@ def _install_files( # noqa: PLR0913, too-many-locals env: str | None = None, alias: str | None = None, destination: str | None = None, -): +) -> None: env = env or "base" package_source, test_source, conf_source = _get_package_artifacts( @@ -589,7 +592,7 @@ def _get_default_version(metadata: ProjectMetadata, micropkg_module_path: str) - micropkg_module = import_module( f"{metadata.package_name}.{micropkg_module_path}" ) - return micropkg_module.__version__ # type: ignore + return micropkg_module.__version__ # type: ignore[no-any-return] except (AttributeError, ModuleNotFoundError): logger.warning( "Micropackage version not found in '%s.%s', will take the top-level one in '%s'", @@ -599,7 +602,7 @@ def _get_default_version(metadata: ProjectMetadata, micropkg_module_path: str) - ) # if micropkg version doesn't exist, take the project one project_module = import_module(f"{metadata.package_name}") - return project_module.__version__ # type: ignore + return project_module.__version__ # type: ignore[no-any-return] def _package_micropkg( @@ -662,7 +665,7 @@ def _validate_dir(path: Path) -> None: raise KedroCliError(f"'{path}' is an empty directory.") -def _get_sdist_name(name, version): +def _get_sdist_name(name: str, version: str) -> str: return f"{name}-{version}.tar.gz" @@ -672,7 +675,7 @@ def _sync_path_list(source: list[tuple[Path, str]], target: Path) -> None: _sync_dirs(source_path, target_with_suffix) -def _drop_comment(line): +def _drop_comment(line: str) -> str: # https://github.com/pypa/setuptools/blob/b545fc7/\ # pkg_resources/_vendor/jaraco/text/__init__.py#L554-L566 return line.partition(" #")[0] @@ -779,7 +782,9 @@ def _refactor_code_for_package( |__ test.py """ - def _move_package_with_conflicting_name(target: Path, conflicting_name: str): + def _move_package_with_conflicting_name( + target: Path, conflicting_name: str + ) -> None: tmp_name = "tmp_name" tmp_module = target.parent / tmp_name _rename_package(project, target.as_posix(), tmp_name) @@ -886,7 +891,7 @@ def _generate_sdist_file( # noqa: PLR0913,too-many-locals ) -def 
_generate_manifest_file(output_dir: Path): +def _generate_manifest_file(output_dir: Path) -> None: manifest_file = output_dir / "MANIFEST.in" manifest_file.write_text( """ @@ -966,7 +971,7 @@ def _append_package_reqs( ) -def _get_all_library_reqs(metadata): +def _get_all_library_reqs(metadata: PackageMetadata) -> list[str]: """Get all library requirements from metadata, leaving markers intact.""" # See https://discuss.python.org/t/\ # programmatically-getting-non-optional-requirements-of-current-directory/26963/2 diff --git a/kedro/framework/cli/pipeline.py b/kedro/framework/cli/pipeline.py index 89e8d281c5..6b09ad0e2c 100644 --- a/kedro/framework/cli/pipeline.py +++ b/kedro/framework/cli/pipeline.py @@ -5,7 +5,7 @@ import shutil from pathlib import Path from textwrap import indent -from typing import NamedTuple +from typing import Any, NamedTuple import click @@ -41,7 +41,7 @@ class PipelineArtifacts(NamedTuple): pipeline_conf: Path -def _assert_pkg_name_ok(pkg_name: str): +def _assert_pkg_name_ok(pkg_name: str) -> None: """Check that python package name is in line with PEP8 requirements. Args: @@ -65,7 +65,7 @@ def _assert_pkg_name_ok(pkg_name: str): raise KedroCliError(message) -def _check_pipeline_name(ctx, param, value): # noqa: unused-argument +def _check_pipeline_name(ctx: click.Context, param: Any, value: str) -> str: # noqa: unused-argument if value: _assert_pkg_name_ok(value) return value @@ -73,12 +73,12 @@ def _check_pipeline_name(ctx, param, value): # noqa: unused-argument # noqa: missing-function-docstring @click.group(name="Kedro") -def pipeline_cli(): # pragma: no cover +def pipeline_cli() -> None: # pragma: no cover pass @pipeline_cli.group() -def pipeline(): +def pipeline() -> None: """Commands for working with pipelines.""" @@ -99,8 +99,14 @@ def pipeline(): @env_option(help="Environment to create pipeline configuration in. Defaults to `base`.") @click.pass_obj # this will pass the metadata as first argument def create_pipeline( - metadata: ProjectMetadata, name, template_path, skip_config, env, **kwargs -): # noqa: unused-argument + metadata: ProjectMetadata, + /, + name: str, + template_path: Path, + skip_config: bool, + env: str, + **kwargs: Any, +) -> None: # noqa: unused-argument """Create a new modular pipeline by providing a name.""" package_dir = metadata.source_dir / metadata.package_name conf_source = settings.CONF_SOURCE @@ -140,7 +146,9 @@ def create_pipeline( "-y", "--yes", is_flag=True, help="Confirm deletion of pipeline non-interactively." 
) @click.pass_obj # this will pass the metadata as first argument -def delete_pipeline(metadata: ProjectMetadata, name, env, yes, **kwargs): # noqa: unused-argument +def delete_pipeline( + metadata: ProjectMetadata, /, name: str, env: str, yes: bool, **kwargs: Any +) -> None: # noqa: unused-argument """Delete a modular pipeline by providing a name.""" package_dir = metadata.source_dir / metadata.package_name conf_source = settings.CONF_SOURCE @@ -195,7 +203,7 @@ def delete_pipeline(metadata: ProjectMetadata, name, env, yes, **kwargs): # noq ) -def _echo_deletion_warning(message: str, **paths: list[Path]): +def _echo_deletion_warning(message: str, **paths: list[Path]) -> None: paths = {key: values for key, values in paths.items() if values} if paths: @@ -216,7 +224,7 @@ def _create_pipeline(name: str, template_path: Path, output_dir: Path) -> Path: click.echo(f"Creating the pipeline '{name}': ", nl=False) try: - result_path = cookiecutter( + cookiecutter_result = cookiecutter( str(template_path), output_dir=str(output_dir), no_input=True, @@ -228,7 +236,7 @@ def _create_pipeline(name: str, template_path: Path, output_dir: Path) -> Path: raise KedroCliError(f"{cls.__module__}.{cls.__qualname__}: {exc}") from exc click.secho("OK", fg="green") - result_path = Path(result_path) + result_path = Path(cookiecutter_result) message = indent(f"Location: '{result_path.resolve()}'", " " * 2) click.secho(message, bold=True) @@ -237,7 +245,9 @@ def _create_pipeline(name: str, template_path: Path, output_dir: Path) -> Path: return result_path -def _sync_dirs(source: Path, target: Path, prefix: str = "", overwrite: bool = False): +def _sync_dirs( + source: Path, target: Path, prefix: str = "", overwrite: bool = False +) -> None: """Recursively copies `source` directory (or file) into `target` directory without overwriting any existing files/directories in the target using the following rules: @@ -313,7 +323,9 @@ def _get_artifacts_to_package( return artifacts -def _copy_pipeline_tests(pipeline_name: str, result_path: Path, package_dir: Path): +def _copy_pipeline_tests( + pipeline_name: str, result_path: Path, package_dir: Path +) -> None: tests_source = result_path / "tests" tests_target = package_dir.parent / "tests" / "pipelines" / pipeline_name try: @@ -324,7 +336,7 @@ def _copy_pipeline_tests(pipeline_name: str, result_path: Path, package_dir: Pat def _copy_pipeline_configs( result_path: Path, conf_path: Path, skip_config: bool, env: str -): +) -> None: config_source = result_path / "config" try: if not skip_config: @@ -334,7 +346,7 @@ def _copy_pipeline_configs( shutil.rmtree(config_source) -def _delete_artifacts(*artifacts: Path): +def _delete_artifacts(*artifacts: Path) -> None: for artifact in artifacts: click.echo(f"Deleting '{artifact}': ", nl=False) try: diff --git a/kedro/framework/cli/project.py b/kedro/framework/cli/project.py index 5a15a761a2..a38758c767 100644 --- a/kedro/framework/cli/project.py +++ b/kedro/framework/cli/project.py @@ -1,8 +1,10 @@ """A collection of CLI commands for working with Kedro project.""" +from __future__ import annotations import os import sys from pathlib import Path +from typing import Any import click @@ -61,14 +63,14 @@ # noqa: missing-function-docstring @click.group(name="Kedro") -def project_group(): # pragma: no cover +def project_group() -> None: # pragma: no cover pass @forward_command(project_group, forward_help=True) @env_option @click.pass_obj # this will pass the metadata as first argument -def ipython(metadata: ProjectMetadata, env, args, **kwargs): 
# noqa: unused-argument +def ipython(metadata: ProjectMetadata, /, env: str, args: Any, **kwargs: Any) -> None: # noqa: unused-argument """Open IPython with project specific variables loaded.""" _check_module_importable("IPython") @@ -79,7 +81,7 @@ def ipython(metadata: ProjectMetadata, env, args, **kwargs): # noqa: unused-arg @project_group.command() @click.pass_obj # this will pass the metadata as first argument -def package(metadata: ProjectMetadata): +def package(metadata: ProjectMetadata) -> None: """Package the project as a Python wheel.""" # Even if the user decides for the older setup.py on purpose, # pyproject.toml is needed for Kedro metadata @@ -196,35 +198,35 @@ def package(metadata: ProjectMetadata): callback=_split_params, ) def run( # noqa: PLR0913,unused-argument,too-many-locals - tags, - env, - runner, - is_async, - node_names, - to_nodes, - from_nodes, - from_inputs, - to_outputs, - load_versions, - pipeline, - config, - conf_source, - params, - namespace, -): + tags: str, + env: str, + runner: str, + is_async: bool, + node_names: str, + to_nodes: str, + from_nodes: str, + from_inputs: str, + to_outputs: str, + load_versions: dict[str, str] | None, + pipeline: str, + config: str, + conf_source: str, + params: dict[str, Any], + namespace: str, +) -> None: """Run the pipeline.""" - runner = load_obj(runner or "SequentialRunner", "kedro.runner") - tags = tuple(tags) - node_names = tuple(node_names) + runner_obj = load_obj(runner or "SequentialRunner", "kedro.runner") + tuple_tags = tuple(tags) + tuple_node_names = tuple(node_names) with KedroSession.create( env=env, conf_source=conf_source, extra_params=params ) as session: session.run( - tags=tags, - runner=runner(is_async=is_async), - node_names=node_names, + tags=tuple_tags, + runner=runner_obj(is_async=is_async), + node_names=tuple_node_names, from_nodes=from_nodes, to_nodes=to_nodes, from_inputs=from_inputs, diff --git a/kedro/framework/cli/registry.py b/kedro/framework/cli/registry.py index c57b3551b3..ca82681e04 100644 --- a/kedro/framework/cli/registry.py +++ b/kedro/framework/cli/registry.py @@ -1,4 +1,6 @@ """A collection of CLI commands for working with registered Kedro pipelines.""" +from typing import Any + import click import yaml @@ -9,17 +11,17 @@ # noqa: missing-function-docstring @click.group(name="Kedro") -def registry_cli(): # pragma: no cover +def registry_cli() -> None: # pragma: no cover pass @registry_cli.group() -def registry(): +def registry() -> None: """Commands for working with registered pipelines.""" @registry.command("list") -def list_registered_pipelines(): +def list_registered_pipelines() -> None: """List all pipelines defined in your pipeline_registry.py file.""" click.echo(yaml.dump(sorted(pipelines))) @@ -27,7 +29,9 @@ def list_registered_pipelines(): @command_with_verbosity(registry, "describe") @click.argument("name", nargs=1, default="__default__") @click.pass_obj -def describe_registered_pipeline(metadata: ProjectMetadata, name, **kwargs): # noqa: unused-argument, protected-access +def describe_registered_pipeline( + metadata: ProjectMetadata, /, name: str, **kwargs: Any +) -> None: # noqa: unused-argument, protected-access """Describe a registered pipeline by providing a pipeline name. Defaults to the `__default__` pipeline. 
""" diff --git a/kedro/framework/cli/starters.py b/kedro/framework/cli/starters.py index 7fbe4f5ad6..e0ff762589 100644 --- a/kedro/framework/cli/starters.py +++ b/kedro/framework/cli/starters.py @@ -19,6 +19,7 @@ import click import yaml from attrs import define, field +from importlib_metadata import EntryPoints import kedro from kedro import __version__ as version @@ -148,7 +149,7 @@ def _validate_flag_inputs(flag_inputs: dict[str, Any]) -> None: ) -def _validate_regex(pattern_name, text): +def _validate_regex(pattern_name: str, text: str) -> None: VALIDATION_PATTERNS = { "yes_no": { "regex": r"(?i)^\s*(y|yes|n|no)\s*$", @@ -179,11 +180,11 @@ def _validate_regex(pattern_name, text): sys.exit(1) -def _parse_yes_no_to_bool(value): +def _parse_yes_no_to_bool(value: str) -> Any: return value.strip().lower() in ["y", "yes"] if value is not None else None -def _validate_selected_tools(selected_tools): +def _validate_selected_tools(selected_tools: str | None) -> None: valid_tools = list(TOOLS_SHORTNAME_TO_NUMBER) + ["all", "none"] if selected_tools is not None: @@ -206,7 +207,7 @@ def _validate_selected_tools(selected_tools): def _print_selection_and_prompt_info( - selected_tools: str | None, example_pipeline: bool | None, interactive: bool + selected_tools: str | None, example_pipeline: str | None, interactive: bool ) -> None: # Confirm tools selection if selected_tools is not None: @@ -240,12 +241,12 @@ def _print_selection_and_prompt_info( # noqa: missing-function-docstring @click.group(context_settings=CONTEXT_SETTINGS, name="Kedro") -def create_cli(): # pragma: no cover +def create_cli() -> None: # pragma: no cover pass @create_cli.group() -def starter(): +def starter() -> None: """Commands for working with project starters.""" @@ -264,15 +265,15 @@ def starter(): @click.option("--name", "-n", "project_name", help=NAME_ARG_HELP) @click.option("--example", "-e", "example_pipeline", help=EXAMPLE_ARG_HELP) def new( # noqa: PLR0913 - config_path, - starter_alias, - selected_tools, - project_name, - checkout, - directory, - example_pipeline, # This will be True or False - **kwargs, -): + config_path: str, + starter_alias: str, + selected_tools: str, + project_name: str, + checkout: str, + directory: str, + example_pipeline: str, # This will be True or False + **kwargs: Any, +) -> None: """Create a new kedro project.""" flag_inputs = { "config": config_path, @@ -295,7 +296,7 @@ def new( # noqa: PLR0913 template_path = spec.template_path # "directory" is an optional key for starters from plugins, so if the key is # not present we will use "None". - directory = spec.directory + directory = spec.directory # type: ignore[assignment] checkout = checkout or version elif starter_alias is not None: template_path = starter_alias @@ -325,13 +326,13 @@ def new( # noqa: PLR0913 # but it causes an issue with readonly files on windows # see: https://bugs.python.org/issue26660. # So on error, we will attempt to clear the readonly bits and re-attempt the cleanup - shutil.rmtree(tmpdir, onerror=_remove_readonly) + shutil.rmtree(tmpdir, onerror=_remove_readonly) # type: ignore[arg-type] # Obtain config, either from a file or from interactive user prompts. 
extra_context = _get_extra_context( prompts_required=prompts_required, config_path=config_path, - cookiecutter_context=cookiecutter_context, # type: ignore + cookiecutter_context=cookiecutter_context, selected_tools=selected_tools, project_name=project_name, example_pipeline=example_pipeline, @@ -355,7 +356,7 @@ def new( # noqa: PLR0913 extra_context.get("tools"), extra_context.get("example_pipeline"), interactive_flow, - ) # type: ignore + ) @starter.command("list") @@ -373,7 +374,7 @@ def list_starters() -> None: # ensure kedro starters are listed first sorted_starters_dict = dict( - sorted(sorted_starters_dict.items(), key=lambda x: x == "kedro") + sorted(sorted_starters_dict.items(), key=lambda x: x == "kedro") # type: ignore[comparison-overlap] ) for origin, starters_spec in sorted_starters_dict.items(): @@ -425,7 +426,7 @@ def _get_prompts_required_and_clear_from_CLI_provided( selected_tools: str, project_name: str, example_pipeline: str, -) -> dict[str, Any]: +) -> Any: """Finds the information a user must supply according to prompts.yml, and clear it from what has already been provided via the CLI(validate it before)""" prompts_yml = cookiecutter_dir / "prompts.yml" @@ -493,7 +494,7 @@ def _get_starters_dict() -> dict[str, KedroStarterSpec]: for starter_entry_point in _get_entry_points(name="starters"): origin = starter_entry_point.module.split(".")[0] - specs = _safe_load_entry_point(starter_entry_point) or [] + specs: EntryPoints | list = _safe_load_entry_point(starter_entry_point) or [] for spec in specs: if not isinstance(spec, KedroStarterSpec): click.secho( @@ -517,7 +518,7 @@ def _get_starters_dict() -> dict[str, KedroStarterSpec]: def _get_extra_context( # noqa: PLR0913 prompts_required: dict, config_path: str, - cookiecutter_context: OrderedDict, + cookiecutter_context: OrderedDict | None, selected_tools: str | None, project_name: str | None, example_pipeline: str | None, @@ -556,35 +557,28 @@ def _get_extra_context( # noqa: PLR0913 # Format extra_context.setdefault("kedro_version", version) - tools = _convert_tool_names_to_numbers(selected_tools) + converted_tools = _convert_tool_names_to_numbers(selected_tools) - if tools is not None: - extra_context["tools"] = tools + if converted_tools is not None: + extra_context["tools"] = converted_tools if project_name is not None: extra_context["project_name"] = project_name # Map the selected tools lists to readable name - tools = extra_context.get("tools") - tools = _parse_tools_input(tools) + tools_context = extra_context.get("tools") + tools = _parse_tools_input(tools_context) # Check if no tools selected if not tools: extra_context["tools"] = str(["None"]) else: - extra_context["tools"] = str( - [ - NUMBER_TO_TOOLS_NAME[tool] - for tool in tools # type: ignore - ] - ) + extra_context["tools"] = str([NUMBER_TO_TOOLS_NAME[tool] for tool in tools]) - extra_context["example_pipeline"] = ( - _parse_yes_no_to_bool( - example_pipeline - if example_pipeline is not None - else extra_context.get("example_pipeline", "no") - ) # type: ignore + extra_context["example_pipeline"] = _parse_yes_no_to_bool( + example_pipeline + if example_pipeline is not None + else extra_context.get("example_pipeline", "no") ) return extra_context @@ -647,11 +641,12 @@ def _fetch_config_from_file(config_path: str) -> dict[str, str]: f"Failed to generate project: could not load config at {config_path}." 
) from exc - return config + # The return type defined is more specific than the "Any" type config return from yaml.safe_load + return config # type: ignore[no-any-return] def _fetch_config_from_user_prompts( - prompts: dict[str, Any], cookiecutter_context: OrderedDict + prompts: dict[str, Any], cookiecutter_context: OrderedDict | None ) -> dict[str, str]: """Interactively obtains information from user prompts. @@ -667,6 +662,9 @@ def _fetch_config_from_user_prompts( from cookiecutter.environment import StrictEnvironment from cookiecutter.prompt import read_user_variable, render_variable + if not cookiecutter_context: + raise Exception("No cookiecutter context available.") + config: dict[str, str] = {} for variable_name, prompt_dict in prompts.items(): @@ -687,20 +685,20 @@ def _fetch_config_from_user_prompts( return config -def _make_cookiecutter_context_for_prompts(cookiecutter_dir: Path): +def _make_cookiecutter_context_for_prompts(cookiecutter_dir: Path) -> OrderedDict: # noqa: import-outside-toplevel from cookiecutter.generate import generate_context cookiecutter_context = generate_context(cookiecutter_dir / "cookiecutter.json") - return cookiecutter_context.get("cookiecutter", {}) + return cookiecutter_context.get("cookiecutter", {}) # type: ignore[no-any-return] def _make_cookiecutter_args_and_fetch_template( - config: dict[str, str | list[str]], + config: dict[str, str], checkout: str, directory: str, template_path: str, -): +) -> tuple[dict[str, object], str]: """Creates a dictionary of arguments to pass to cookiecutter and returns project template path. Args: @@ -734,7 +732,7 @@ def _make_cookiecutter_args_and_fetch_template( # If 'tools' or 'example_pipeline' are not specified in prompts.yml, CLI or config.yml, # default options will be used instead # That can be when starter used or while loading from config.yml - tools = config.get("tools", []) + tools: str | list = config.get("tools", []) example_pipeline = config.get("example_pipeline", False) starter_path = "git+https://github.com/kedro-org/kedro-starters.git" @@ -759,7 +757,7 @@ def _make_cookiecutter_args_and_fetch_template( def _validate_config_file_against_prompts( config: dict[str, str], prompts: dict[str, Any] -): +) -> None: """Checks that the configuration file contains all needed variables. Args: @@ -794,11 +792,13 @@ def _validate_config_file_against_prompts( ) -def _validate_config_file_inputs(config: dict[str, str], starter_alias: str | None): +def _validate_config_file_inputs( + config: dict[str, str], starter_alias: str | None +) -> None: """Checks that variables provided through the config file are of the expected format. This - validate the config provided by `kedro new --config` in a similar way to `prompts.yml` - for starters. - Also validates that "tools" or "example_pipeline" options cannot be used in config when any starter option is selected. + validates the config provided by `kedro new --config` in a similar way to `prompts.yml` for starters. + Also validates that "tools" or "example_pipeline" options cannot be used in config when any starter option is + selected. 
Args: config: The config as a dictionary @@ -822,7 +822,7 @@ def _validate_config_file_inputs(config: dict[str, str], starter_alias: str | No _validate_regex("yes_no", config.get("example_pipeline", "no")) -def _validate_selection(tools: list[str]): +def _validate_selection(tools: list[str]) -> None: # start validating from the end, when user select 1-20, it will generate a message # '20' is not a valid selection instead of '8' for tool in tools[::-1]: @@ -832,7 +832,7 @@ def _validate_selection(tools: list[str]): sys.exit(1) -def _parse_tools_input(tools_str: str | None): +def _parse_tools_input(tools_str: str | None) -> list[str]: """Parse the tools input string. Args: @@ -842,7 +842,7 @@ def _parse_tools_input(tools_str: str | None): list: List of selected tools as strings. """ - def _validate_range(start, end): + def _validate_range(start: Any, end: Any) -> None: if int(start) > int(end): message = f"'{start}-{end}' is an invalid range for project tools.\nPlease ensure range values go from smaller to larger." click.secho(message, fg="red", err=True) @@ -872,7 +872,7 @@ def _validate_range(start, end): return selected -def _create_project(template_path: str, cookiecutter_args: dict[str, Any]): +def _create_project(template_path: str, cookiecutter_args: dict[str, Any]) -> None: """Creates a new kedro project using cookiecutter. Args: @@ -909,7 +909,7 @@ def _create_project(template_path: str, cookiecutter_args: dict[str, Any]): class _Prompt: """Represent a single CLI prompt for `kedro new`""" - def __init__(self, *args, **kwargs) -> None: # noqa: unused-argument + def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: unused-argument try: self.title = kwargs["title"] except KeyError as exc: @@ -943,7 +943,9 @@ def validate(self, user_input: str) -> None: # noqa: unused-argument -def _remove_readonly(func: Callable, path: Path, excinfo: tuple): # pragma: no cover +def _remove_readonly( + func: Callable, path: Path, excinfo: tuple +) -> None: # pragma: no cover """Remove readonly files on Windows See: https://docs.python.org/3/library/shutil.html?highlight=shutil#rmtree-example """ diff --git a/kedro/framework/cli/utils.py b/kedro/framework/cli/utils.py index a23b9bf3cd..40e2a92163 100644 --- a/kedro/framework/cli/utils.py +++ b/kedro/framework/cli/utils.py @@ -10,11 +10,12 @@ import sys import textwrap import traceback +import typing from collections import defaultdict from importlib import import_module from itertools import chain from pathlib import Path -from typing import Iterable, Sequence +from typing import IO, Any, Iterable, Sequence import click import importlib_metadata @@ -39,7 +40,7 @@ logger = logging.getLogger(__name__) -def call(cmd: list[str], **kwargs): # pragma: no cover +def call(cmd: list[str], **kwargs: Any) -> None: # pragma: no cover """Run a subprocess command and raise if it fails. 
Args: @@ -55,7 +56,9 @@ def call(cmd: list[str], **kwargs): # pragma: no cover raise click.exceptions.Exit(code=code) -def python_call(module: str, arguments: Iterable[str], **kwargs): # pragma: no cover +def python_call( + module: str, arguments: Iterable[str], **kwargs: Any +) -> None: # pragma: no cover """Run a subprocess command that invokes a Python module.""" call([sys.executable, "-m", module] + list(arguments), **kwargs) @@ -70,10 +73,12 @@ def find_stylesheets() -> Iterable[str]: # pragma: no cover ) -def forward_command(group, name=None, forward_help=False): +def forward_command( + group: Any, name: str | None = None, forward_help: bool = False +) -> Any: """A command that receives the rest of the command line as 'args'.""" - def wrapit(func): + def wrapit(func: Any) -> Any: func = click.argument("args", nargs=-1, type=click.UNPROCESSED)(func) func = command_with_verbosity( group, @@ -102,7 +107,7 @@ def _suggest_cli_command( suggestion = "\n\nDid you mean this?" else: suggestion = "\n\nDid you mean one of these?\n" - suggestion += textwrap.indent("\n".join(matches), " " * 4) # type: ignore + suggestion += textwrap.indent("\n".join(matches), " " * 4) return suggestion @@ -124,7 +129,7 @@ def __init__(self, *groups: tuple[str, Sequence[click.MultiCommand]]): ] self._dedupe_commands(sources) super().__init__( - sources=sources, + sources=sources, # type: ignore[arg-type] help="\n\n".join(help_texts), context_settings=CONTEXT_SETTINGS, ) @@ -132,36 +137,38 @@ def __init__(self, *groups: tuple[str, Sequence[click.MultiCommand]]): self.callback = sources[0].callback @staticmethod - def _dedupe_commands(cli_collections: Sequence[click.CommandCollection]): + def _dedupe_commands(cli_collections: Sequence[click.CommandCollection]) -> None: """Deduplicate commands by keeping the ones from the last source in the list. 
""" seen_names: set[str] = set() for cli_collection in reversed(cli_collections): for cmd_group in reversed(cli_collection.sources): - cmd_group.commands = { # type: ignore + cmd_group.commands = { # type: ignore[attr-defined] cmd_name: cmd - for cmd_name, cmd in cmd_group.commands.items() # type: ignore + for cmd_name, cmd in cmd_group.commands.items() # type: ignore[attr-defined] if cmd_name not in seen_names } - seen_names |= cmd_group.commands.keys() # type: ignore + seen_names |= cmd_group.commands.keys() # type: ignore[attr-defined] # remove empty command groups for cli_collection in cli_collections: cli_collection.sources = [ cmd_group for cmd_group in cli_collection.sources - if cmd_group.commands # type: ignore + if cmd_group.commands # type: ignore[attr-defined] ] @staticmethod - def _merge_same_name_collections(groups: Sequence[click.MultiCommand]): + def _merge_same_name_collections( + groups: Sequence[click.MultiCommand] + ) -> list[click.CommandCollection]: named_groups: defaultdict[str, list[click.MultiCommand]] = defaultdict(list) helps: defaultdict[str, list] = defaultdict(list) for group in groups: - named_groups[group.name].append(group) # type: ignore + named_groups[group.name].append(group) # type: ignore[index] if group.help: - helps[group.name].append(group.help) # type: ignore + helps[group.name].append(group.help) # type: ignore[index] return [ click.CommandCollection( @@ -175,7 +182,9 @@ def _merge_same_name_collections(groups: Sequence[click.MultiCommand]): if cli_list ] - def resolve_command(self, ctx: click.core.Context, args: list): + def resolve_command( + self, ctx: click.core.Context, args: list + ) -> tuple[str | None, click.Command | None, list[str]]: try: return super().resolve_command(ctx, args) except click.exceptions.UsageError as exc: @@ -188,7 +197,7 @@ def resolve_command(self, ctx: click.core.Context, args: list): def format_commands( self, ctx: click.core.Context, formatter: click.formatting.HelpFormatter - ): + ) -> None: for title, cli in self.groups: for group in cli: if group.sources: @@ -226,11 +235,11 @@ def get_pkg_version(reqs_path: (str | Path), package_name: str) -> str: raise KedroCliError(f"Cannot find '{package_name}' package in '{reqs_path}'.") -def _update_verbose_flag(ctx, param, value): # noqa: unused-argument +def _update_verbose_flag(ctx: click.Context, param: Any, value: bool) -> None: # noqa: unused-argument KedroCliError.VERBOSE_ERROR = value -def _click_verbose(func): +def _click_verbose(func: Any) -> Any: """Click option for enabling verbose mode.""" return click.option( "--verbose", @@ -241,10 +250,10 @@ def _click_verbose(func): )(func) -def command_with_verbosity(group: click.core.Group, *args, **kwargs): +def command_with_verbosity(group: click.core.Group, *args: Any, **kwargs: Any) -> Any: """Custom command decorator with verbose flag added.""" - def decorator(func): + def decorator(func: Any) -> Any: func = _click_verbose(func) func = group.command(*args, **kwargs)(func) return func @@ -261,7 +270,7 @@ class KedroCliError(click.exceptions.ClickException): VERBOSE_ERROR = False VERBOSE_EXISTS = True - def show(self, file=None): + def show(self, file: IO | None = None) -> None: if self.VERBOSE_ERROR: click.secho(traceback.format_exc(), nl=False, fg="yellow") elif self.VERBOSE_EXISTS: @@ -280,7 +289,7 @@ def show(self, file=None): ) -def _clean_pycache(path: Path): +def _clean_pycache(path: Path) -> None: """Recursively clean all __pycache__ folders from `path`. 
Args: @@ -292,13 +301,13 @@ def _clean_pycache(path: Path): shutil.rmtree(each, ignore_errors=True) -def split_string(ctx, param, value): # noqa: unused-argument +def split_string(ctx: click.Context, param: Any, value: str) -> list[str]: # noqa: unused-argument """Split string by comma.""" return [item.strip() for item in value.split(",") if item.strip()] # noqa: unused-argument,missing-param-doc,missing-type-doc -def split_node_names(ctx, param, to_split: str) -> list[str]: +def split_node_names(ctx: click.Context, param: Any, to_split: str) -> list[str]: """Split string by comma, ignoring commas enclosed by square parentheses. This avoids splitting the string of nodes names on commas included in default node names, which have the pattern @@ -333,7 +342,7 @@ def split_node_names(ctx, param, to_split: str) -> list[str]: return result -def env_option(func_=None, **kwargs): +def env_option(func_: Any | None = None, **kwargs: Any) -> Any: """Add `--env` CLI option to a function.""" default_args = {"type": str, "default": None, "help": ENV_HELP} kwargs = {**default_args, **kwargs} @@ -351,14 +360,16 @@ def _check_module_importable(module_name: str) -> None: ) from exc -def _get_entry_points(name: str) -> importlib_metadata.EntryPoints: +def _get_entry_points(name: str) -> Any: """Get all kedro related entry points""" - return importlib_metadata.entry_points().select(group=ENTRY_POINT_GROUPS[name]) + return importlib_metadata.entry_points().select( # type: ignore[no-untyped-call] + group=ENTRY_POINT_GROUPS[name] + ) def _safe_load_entry_point( # noqa: inconsistent-return-statements - entry_point, -): + entry_point: Any, +) -> Any: """Load entrypoint safely, if fails it will just skip the entrypoint.""" try: return entry_point.load() @@ -394,7 +405,8 @@ def load_entry_points(name: str) -> Sequence[click.MultiCommand]: return entry_point_commands -def _config_file_callback(ctx, param, value): # noqa: unused-argument +@typing.no_type_check +def _config_file_callback(ctx: click.Context, param: Any, value: Any) -> Any: # noqa: unused-argument """CLI callback that replaces command line options with values specified in a config file. If command line options are passed, they override config file values. @@ -412,7 +424,7 @@ def _config_file_callback(ctx, param, value): # noqa: unused-argument return value -def _validate_config_file(key): +def _validate_config_file(key: str) -> None: """Validate the keys provided in the config file against the accepted keys.""" from kedro.framework.cli.project import run @@ -420,13 +432,13 @@ def _validate_config_file(key): run_args.remove("config") if key not in run_args: KedroCliError.VERBOSE_EXISTS = False - message = _suggest_cli_command(key, run_args) + message = _suggest_cli_command(key, run_args) # type: ignore[arg-type] raise KedroCliError( f"Key `{key}` in provided configuration is not valid. 
{message}" ) -def _split_params(ctx, param, value): +def _split_params(ctx: click.Context, param: Any, value: Any) -> Any: if isinstance(value, dict): return value dot_list = [] @@ -454,7 +466,7 @@ def _split_params(ctx, param, value): return OmegaConf.to_container(conf) -def _split_load_versions(ctx, param, value): +def _split_load_versions(ctx: click.Context, param: Any, value: str) -> dict[str, str]: """Split and format the string coming from the --load-versions flag in kedro run, e.g.: "dataset1:time1,dataset2:time2" -> {"dataset1": "time1", "dataset2": "time2"} diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index c0991b2bea..874f21a721 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -120,7 +120,7 @@ def _convert_paths_to_absolute_posix( return conf_dictionary -def _validate_transcoded_datasets(catalog: DataCatalog): +def _validate_transcoded_datasets(catalog: DataCatalog) -> None: """Validates transcoded datasets are correctly named Args: @@ -205,7 +205,7 @@ def params(self) -> dict[str, Any]: # Merge nested structures params = OmegaConf.merge(params, self._extra_params) - return OmegaConf.to_container(params) if OmegaConf.is_config(params) else params + return OmegaConf.to_container(params) if OmegaConf.is_config(params) else params # type: ignore[no-any-return] def _get_catalog( self, @@ -229,7 +229,7 @@ def _get_catalog( ) conf_creds = self._get_config_credentials() - catalog = settings.DATA_CATALOG_CLASS.from_config( + catalog: DataCatalog = settings.DATA_CATALOG_CLASS.from_config( catalog=conf_catalog, credentials=conf_creds, load_versions=load_versions, @@ -254,7 +254,7 @@ def _get_feed_dict(self) -> dict[str, Any]: params = self.params feed_dict = {"parameters": params} - def _add_param_to_feed_dict(param_name, param_value): + def _add_param_to_feed_dict(param_name: str, param_value: Any) -> None: """This recursively adds parameter paths to the `feed_dict`, whenever `param_value` is a dictionary itself, so that users can specify specific nested parameters in their node inputs. 
@@ -281,7 +281,7 @@ def _add_param_to_feed_dict(param_name, param_value): def _get_config_credentials(self) -> dict[str, Any]: """Getter for credentials specified in credentials directory.""" try: - conf_creds = self.config_loader["credentials"] + conf_creds: dict[str, Any] = self.config_loader["credentials"] except MissingConfigException as exc: logging.getLogger(__name__).debug( "Credentials not found in your Kedro project config.\n %s", str(exc) diff --git a/kedro/framework/hooks/manager.py b/kedro/framework/hooks/manager.py index 21bbbd0f3f..64a63257be 100644 --- a/kedro/framework/hooks/manager.py +++ b/kedro/framework/hooks/manager.py @@ -101,11 +101,11 @@ class _NullPluginManager: """This class creates an empty ``hook_manager`` that will ignore all calls to hooks, allowing the runner to function if no ``hook_manager`` has been instantiated.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: pass - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return self - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> None: pass diff --git a/kedro/framework/hooks/specs.py b/kedro/framework/hooks/specs.py index 173edceb0e..14431f0362 100644 --- a/kedro/framework/hooks/specs.py +++ b/kedro/framework/hooks/specs.py @@ -112,7 +112,7 @@ def on_node_error( # noqa: PLR0913 inputs: dict[str, Any], is_async: bool, session_id: str, - ): + ) -> None: """Hook to be invoked if a node run throws an uncaught error. The signature of this error hook should match the signature of ``before_node_run`` along with the error that was raised. @@ -211,7 +211,7 @@ def on_pipeline_error( run_params: dict[str, Any], pipeline: Pipeline, catalog: DataCatalog, - ): + ) -> None: """Hook to be invoked if a pipeline run throws an uncaught Exception. The signature of this error hook should match the signature of ``before_pipeline_run`` along with the error that was raised. 
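Note (illustrative): `_NullPluginManager.__getattr__` above returns the instance itself, so any `hook_manager.hook.some_hook(...)` chain resolves to a callable no-op; the `Any`/`None` annotations keep that trick compatible with `--strict`. A tiny standalone sketch of the null-object pattern (hypothetical class name, not Kedro's):

from typing import Any


class NullHooks:
    """Stand-in plugin manager: every attribute is itself, every call a no-op."""

    def __getattr__(self, name: str) -> Any:
        return self

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        pass


hooks = NullHooks()
hooks.hook.before_node_run(node=None, inputs={})  # silently does nothing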
diff --git a/kedro/framework/project/__init__.py b/kedro/framework/project/__init__.py index 8f507df5e4..b7e986010d 100644 --- a/kedro/framework/project/__init__.py +++ b/kedro/framework/project/__init__.py @@ -15,6 +15,7 @@ from pathlib import Path from typing import Any +import dynaconf import importlib_resources import yaml from dynaconf import LazySettings @@ -28,10 +29,10 @@ ) -def _get_default_class(class_import_path): +def _get_default_class(class_import_path: str) -> Any: module, _, class_name = class_import_path.rpartition(".") - def validator_func(settings, validators): + def validator_func(settings: dynaconf.base.Settings, validators: Any) -> Any: return getattr(importlib.import_module(module), class_name) return validator_func @@ -40,7 +41,9 @@ def validator_func(settings, validators): class _IsSubclassValidator(Validator): """A validator to check if the supplied setting value is a subclass of the default class""" - def validate(self, settings, *args, **kwargs): + def validate( + self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any + ) -> None: super().validate(settings, *args, **kwargs) default_class = self.default(settings, self) @@ -58,7 +61,9 @@ class _HasSharedParentClassValidator(Validator): """A validator to check that the parent of the default class is an ancestor of the settings value.""" - def validate(self, settings, *args, **kwargs): + def validate( + self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any + ) -> None: super().validate(settings, *args, **kwargs) default_class = self.default(settings, self) @@ -112,7 +117,7 @@ class _ProjectSettings(LazySettings): "DATA_CATALOG_CLASS", default=_get_default_class("kedro.io.DataCatalog") ) - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): kwargs.update( validators=[ self._CONF_SOURCE, @@ -129,13 +134,13 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) -def _load_data_wrapper(func): +def _load_data_wrapper(func: Any) -> Any: """Wrap a method in _ProjectPipelines so that data is loaded on first access. Taking inspiration from dynaconf.utils.functional.new_method_proxy """ # noqa: protected-access - def inner(self, *args, **kwargs): + def inner(self: Any, *args: Any, **kwargs: Any) -> Any: self._load_data() return func(self._content, *args, **kwargs) @@ -165,12 +170,12 @@ def __init__(self) -> None: self._content: dict[str, Pipeline] = {} @staticmethod - def _get_pipelines_registry_callable(pipelines_module: str): + def _get_pipelines_registry_callable(pipelines_module: str) -> Any: module_obj = importlib.import_module(pipelines_module) register_pipelines = getattr(module_obj, "register_pipelines") return register_pipelines - def _load_data(self): + def _load_data(self) -> None: """Lazily read pipelines defined in the pipelines registry module.""" # If the pipelines dictionary has not been configured with a pipelines module @@ -212,7 +217,7 @@ def configure(self, pipelines_module: str | None = None) -> None: class _ProjectLogging(UserDict): # noqa: super-init-not-called - def __init__(self): + def __init__(self) -> None: """Initialise project logging. 
The path to logging configuration is given in environment variable KEDRO_LOGGING_CONFIG (defaults to default_logging.yml).""" path = os.environ.get( @@ -229,7 +234,7 @@ def configure(self, logging_config: dict[str, Any]) -> None: logging.config.dictConfig(logging_config) self.data = logging_config - def set_project_logging(self, package_name: str): + def set_project_logging(self, package_name: str) -> None: """Add the project level logging to the loggers upon provision of a package name. Checks if project logger already exists to prevent overwriting, if none exists it defaults to setting project logs at INFO level.""" @@ -246,7 +251,7 @@ def set_project_logging(self, package_name: str): pipelines = _ProjectPipelines() -def configure_project(package_name: str): +def configure_project(package_name: str) -> None: """Configure a Kedro project by populating its settings with values defined in user's settings.py and pipeline_registry.py. """ @@ -272,7 +277,7 @@ def configure_logging(logging_config: dict[str, Any]) -> None: LOGGING.configure(logging_config) -def validate_settings(): +def validate_settings() -> None: """Eagerly validate that the settings module is importable if it exists. This is desirable to surface any syntax or import errors early. In particular, without eagerly importing the settings module, dynaconf would silence any import error (e.g. missing @@ -287,7 +292,7 @@ def validate_settings(): "Kedro command line interface." ) # Check if file exists, if it does, validate it. - if importlib.util.find_spec(f"{PACKAGE_NAME}.settings") is not None: # type: ignore + if importlib.util.find_spec(f"{PACKAGE_NAME}.settings") is not None: importlib.import_module(f"{PACKAGE_NAME}.settings") else: logger = logging.getLogger(__name__) diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 24def9e27f..166cc521d7 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -193,7 +193,7 @@ def _init_store(self) -> BaseSessionStore: store_args["session_id"] = self.session_id try: - return store_class(**store_args) + return store_class(**store_args) # type: ignore[no-any-return] except TypeError as err: raise ValueError( f"\n{err}.\nStore config must only contain arguments valid " @@ -204,7 +204,7 @@ def _init_store(self) -> BaseSessionStore: f"\n{err}.\nFailed to instantiate session store of type '{classpath}'." ) from err - def _log_exception(self, exc_type, exc_value, exc_tb): + def _log_exception(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None: type_ = [] if exc_type.__module__ == "builtins" else [exc_type.__module__] type_.append(exc_type.__qualname__) @@ -240,7 +240,7 @@ def load_context(self) -> KedroContext: ) self._hook_manager.hook.after_context_created(context=context) - return context + return context # type: ignore[no-any-return] def _get_config_loader(self) -> AbstractConfigLoader: """An instance of the config loader.""" @@ -248,24 +248,24 @@ def _get_config_loader(self) -> AbstractConfigLoader: extra_params = self.store.get("extra_params") config_loader_class = settings.CONFIG_LOADER_CLASS - return config_loader_class( + return config_loader_class( # type: ignore[no-any-return] conf_source=self._conf_source, env=env, runtime_params=extra_params, **settings.CONFIG_LOADER_ARGS, ) - def close(self): + def close(self) -> None: """Close the current session and save its store to disk if `save_on_close` attribute is True. 
""" if self.save_on_close: self._store.save() - def __enter__(self): + def __enter__(self) -> KedroSession: return self - def __exit__(self, exc_type, exc_value, tb_): + def __exit__(self, exc_type: Any, exc_value: Any, tb_: Any) -> None: if exc_type: self._log_exception(exc_type, exc_value, tb_) self.close() diff --git a/kedro/framework/session/store.py b/kedro/framework/session/store.py index 6aee727528..b0b82fc663 100644 --- a/kedro/framework/session/store.py +++ b/kedro/framework/session/store.py @@ -33,7 +33,7 @@ def read(self) -> dict[str, Any]: ) return {} - def save(self): + def save(self) -> None: """Persist the session store""" self._logger.debug( "'save()' not implemented for '%s'. Skipping the step.", diff --git a/kedro/framework/startup.py b/kedro/framework/startup.py index adbb7b19f2..166fa4505c 100644 --- a/kedro/framework/startup.py +++ b/kedro/framework/startup.py @@ -25,7 +25,7 @@ class ProjectMetadata(NamedTuple): example_pipeline: str -def _version_mismatch_error(kedro_init_version) -> str: +def _version_mismatch_error(kedro_init_version: str) -> str: return ( f"Your Kedro project version {kedro_init_version} does not match Kedro package " f"version {kedro_version} you are running. Make sure to update your project " diff --git a/kedro/io/cached_dataset.py b/kedro/io/cached_dataset.py index 24dcc7d96e..410073a65d 100644 --- a/kedro/io/cached_dataset.py +++ b/kedro/io/cached_dataset.py @@ -77,7 +77,7 @@ def _release(self) -> None: self._dataset.release() @staticmethod - def _from_config(config, version): + def _from_config(config: dict, version: Version | None) -> AbstractDataset: if VERSIONED_FLAG_KEY in config: raise ValueError( "Cached datasets should specify that they are versioned in the " @@ -96,7 +96,7 @@ def _describe(self) -> dict[str, Any]: "cache": self._cache._describe(), # noqa: protected-access } - def _load(self): + def _load(self) -> Any: data = self._cache.load() if self._cache.exists() else self._dataset.load() if not self._cache.exists(): @@ -111,7 +111,7 @@ def _save(self, data: Any) -> None: def _exists(self) -> bool: return self._cache.exists() or self._dataset.exists() - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: # clearing the cache can be prevented by modifying # how parallel runner handles datasets (not trivial!) logging.getLogger(__name__).warning("%s: clearing cache to pickle.", str(self)) diff --git a/kedro/io/core.py b/kedro/io/core.py index 7335c0b28a..10ee60aeb4 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -152,7 +152,7 @@ def from_config( ) from exc try: - dataset = class_obj(**config) # type: ignore + dataset = class_obj(**config) except TypeError as err: raise DatasetError( f"\n{err}.\nDataset '{name}' must only contain arguments valid for the " @@ -220,8 +220,8 @@ def save(self, data: _DI) -> None: message = f"Failed while saving data to data set {str(self)}.\n{str(exc)}" raise DatasetError(message) from exc - def __str__(self): - def _to_str(obj, is_root=False): + def __str__(self) -> str: + def _to_str(obj: Any, is_root: bool = False) -> str: """Returns a string representation where 1. The root level (i.e. the Dataset.__init__ arguments) are formatted like Dataset(key=value). 
@@ -312,7 +312,7 @@ def release(self) -> None: def _release(self) -> None: pass - def _copy(self, **overwrite_params) -> AbstractDataset: + def _copy(self, **overwrite_params: Any) -> AbstractDataset: dataset_copy = copy.deepcopy(self) for name, value in overwrite_params.items(): setattr(dataset_copy, name, value) @@ -566,7 +566,7 @@ def resolve_load_version(self) -> str | None: if not self._version: return None if self._version.load: - return self._version.load + return self._version.load # type: ignore[no-any-return] return self._fetch_latest_load_version() def _get_load_path(self) -> PurePosixPath: @@ -575,14 +575,14 @@ def _get_load_path(self) -> PurePosixPath: return self._filepath load_version = self.resolve_load_version() - return self._get_versioned_path(load_version) # type: ignore + return self._get_versioned_path(load_version) # type: ignore[arg-type] def resolve_save_version(self) -> str | None: """Compute the version the dataset should be saved with.""" if not self._version: return None if self._version.save: - return self._version.save + return self._version.save # type: ignore[no-any-return] return self._fetch_latest_save_version() def _get_save_path(self) -> PurePosixPath: @@ -591,7 +591,7 @@ def _get_save_path(self) -> PurePosixPath: return self._filepath save_version = self.resolve_save_version() - versioned_path = self._get_versioned_path(save_version) # type: ignore + versioned_path = self._get_versioned_path(save_version) # type: ignore[arg-type] if self._exists_function(str(versioned_path)): raise DatasetError( @@ -750,7 +750,7 @@ def get_filepath_str(raw_path: PurePath, protocol: str) -> str: return path -def validate_on_forbidden_chars(**kwargs): +def validate_on_forbidden_chars(**kwargs: Any) -> None: """Validate that string values do not include white-spaces or ;""" for key, value in kwargs.items(): if " " in value or ";" in value: diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index c36938a9ed..1b05c8634d 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -32,9 +32,7 @@ WORDS_REGEX_PATTERN = re.compile(r"\W+") -def _get_credentials( - credentials_name: str, credentials: dict[str, Any] -) -> dict[str, Any]: +def _get_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any: """Return a set of credentials from the provided credentials dict. Args: @@ -121,7 +119,7 @@ def __init__( ) # Don't allow users to add/change attributes on the fly - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: msg = "Operation not allowed! " if key in self.__dict__: msg += "Please change datasets through configuration." @@ -197,7 +195,7 @@ def __init__( # noqa: PLR0913 self.add_feed_dict(feed_dict) @property - def _logger(self): + def _logger(self) -> logging.Logger: return logging.getLogger(__name__) @classmethod @@ -316,7 +314,7 @@ class to be loaded is specified with the key ``type`` and their ) @staticmethod - def _is_pattern(pattern: str): + def _is_pattern(pattern: str) -> bool: """Check if a given string is a pattern. 
Assume that any name with '{' is a pattern.""" return "{" in pattern @@ -416,7 +414,7 @@ def _get_dataset( return dataset - def __contains__(self, dataset_name): + def __contains__(self, dataset_name: str) -> bool: """Check if an item is in the catalog as a materialised dataset or pattern""" matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name) if dataset_name in self._datasets or matched_pattern: @@ -551,7 +549,7 @@ def exists(self, name: str) -> bool: return False return dataset.exists() - def release(self, name: str): + def release(self, name: str) -> None: """Release any cached data associated with a data set Args: @@ -738,7 +736,7 @@ def shallow_copy( save_version=self._save_version, ) - def __eq__(self, other): + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] return (self._datasets, self._dataset_patterns) == ( other._datasets, other._dataset_patterns, @@ -757,6 +755,6 @@ def confirm(self, name: str) -> None: dataset = self._get_dataset(name) if hasattr(dataset, "confirm"): - dataset.confirm() # type: ignore + dataset.confirm() else: raise DatasetError(f"Dataset '{name}' does not have 'confirm' method") diff --git a/kedro/io/lambda_dataset.py b/kedro/io/lambda_dataset.py index 54b4f531c0..6b901d60b8 100644 --- a/kedro/io/lambda_dataset.py +++ b/kedro/io/lambda_dataset.py @@ -32,7 +32,7 @@ class LambdaDataset(AbstractDataset): """ def _describe(self) -> dict[str, Any]: - def _to_str(func): + def _to_str(func: Any) -> str | None: if not func: return None try: diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py index 7c696bdfc9..3645de0a29 100644 --- a/kedro/io/memory_dataset.py +++ b/kedro/io/memory_dataset.py @@ -65,7 +65,7 @@ def _load(self) -> Any: data = _copy_with_mode(self._data, copy_mode=copy_mode) return data - def _save(self, data: Any): + def _save(self, data: Any) -> None: copy_mode = self._copy_mode or _infer_copy_mode(data) self._data = _copy_with_mode(data, copy_mode=copy_mode) @@ -96,11 +96,11 @@ def _infer_copy_mode(data: Any) -> str: try: import pandas as pd except ImportError: # pragma: no cover - pd = None # type: ignore # pragma: no cover + pd = None # type: ignore[assignment] # pragma: no cover try: import numpy as np except ImportError: # pragma: no cover - np = None # type: ignore # pragma: no cover + np = None # type: ignore[assignment] # pragma: no cover if pd and isinstance(data, pd.DataFrame) or np and isinstance(data, np.ndarray): copy_mode = "copy" diff --git a/kedro/io/shared_memory_dataset.py b/kedro/io/shared_memory_dataset.py index 2fa952ff65..8aff822db9 100644 --- a/kedro/io/shared_memory_dataset.py +++ b/kedro/io/shared_memory_dataset.py @@ -19,14 +19,14 @@ def __init__(self, manager: SyncManager | None = None): """ if manager: - self.shared_memory_dataset = manager.MemoryDataset() # type: ignore + self.shared_memory_dataset = manager.MemoryDataset() # type: ignore[attr-defined] else: - self.shared_memory_dataset = None # type: ignore + self.shared_memory_dataset = None - def set_manager(self, manager: SyncManager): - self.shared_memory_dataset = manager.MemoryDataset() # type: ignore + def set_manager(self, manager: SyncManager) -> None: + self.shared_memory_dataset = manager.MemoryDataset() # type: ignore[attr-defined] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This if condition prevents recursive call when deserialising if name == "__setstate__": raise AttributeError() @@ -35,7 +35,7 @@ def __getattr__(self, name): def _load(self) -> Any: return 
self.shared_memory_dataset.load() - def _save(self, data: Any): + def _save(self, data: Any) -> None: """Calls save method of a shared MemoryDataset in SyncManager.""" try: self.shared_memory_dataset.save(data) diff --git a/kedro/ipython/__init__.py b/kedro/ipython/__init__.py index f814817991..d193789b59 100644 --- a/kedro/ipython/__init__.py +++ b/kedro/ipython/__init__.py @@ -6,10 +6,11 @@ import logging import sys +import typing from pathlib import Path from typing import Any -from IPython import get_ipython +import IPython from IPython.core.magic import needs_local_scope, register_line_magic from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring @@ -27,7 +28,7 @@ logger = logging.getLogger(__name__) -def load_ipython_extension(ipython): +def load_ipython_extension(ipython: Any) -> None: """ Main entry point when %load_ext kedro.ipython is executed, either manually or automatically through `kedro ipython` or `kedro jupyter lab/notebook`. @@ -46,6 +47,7 @@ def load_ipython_extension(ipython): reload_kedro() +@typing.no_type_check @needs_local_scope @magic_arguments() @argument( @@ -70,7 +72,7 @@ def magic_reload_kedro( line: str, local_ns: dict[str, Any] | None = None, conf_source: str | None = None, -): +) -> None: """ The `%reload_kedro` IPython line magic. See https://kedro.readthedocs.io/en/stable/notebooks_and_ipython/kedro_and_notebooks.html#reload-kedro-line-magic # noqa: line-too-long @@ -105,7 +107,7 @@ def reload_kedro( context = session.load_context() catalog = context.catalog - get_ipython().push( + IPython.get_ipython().push( # type: ignore[attr-defined, no-untyped-call] variables={ "context": context, "catalog": catalog, @@ -120,8 +122,8 @@ def reload_kedro( ) for line_magic in load_entry_points("line_magic"): - register_line_magic(needs_local_scope(line_magic)) - logger.info("Registered line magic '%s'", line_magic.__name__) # type: ignore + register_line_magic(needs_local_scope(line_magic)) # type: ignore[no-untyped-call] + logger.info("Registered line magic '%s'", line_magic.__name__) # type: ignore[attr-defined] def _resolve_project_path( @@ -163,7 +165,7 @@ def _resolve_project_path( return project_path -def _remove_cached_modules(package_name): # pragma: no cover +def _remove_cached_modules(package_name: str) -> None: # pragma: no cover to_remove = [mod for mod in sys.modules if mod.startswith(package_name)] # `del` is used instead of `reload()` because: If the new version of a module does not # define a name that was defined by the old version, the old definition remains. 
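In the ``kedro/ipython`` hunk above, the ``%reload_kedro`` magic gains ``@typing.no_type_check`` instead of full annotations, presumably because the stacked IPython decorators (``magic_arguments``, ``argument``, ``needs_local_scope``) make its effective signature awkward to type. The decorator marks the whole function as exempt from type checking. A tiny sketch with a made-up function, not taken from the diff:

import typing


@typing.no_type_check
def parse_line(line, local_ns=None):
    # Deliberately unannotated: the decorator tells type checkers to skip
    # this function, so the missing parameter and return types should not
    # be flagged even under a strict configuration.
    return dict(token.split("=", 1) for token in line.split())


print(parse_line("env=local pipeline=__default__"))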
@@ -171,7 +173,7 @@ def _remove_cached_modules(package_name): # pragma: no cover del sys.modules[module] -def _find_kedro_project(current_dir: Path): # pragma: no cover +def _find_kedro_project(current_dir: Path) -> Any: # pragma: no cover while current_dir != current_dir.parent: if _is_project(current_dir): return current_dir diff --git a/kedro/logging.py b/kedro/logging.py index 534776c566..bb8a4fd544 100644 --- a/kedro/logging.py +++ b/kedro/logging.py @@ -6,6 +6,7 @@ import os import sys from pathlib import Path +from typing import Any import click import rich.logging @@ -27,7 +28,7 @@ class RichHandler(rich.logging.RichHandler): https://rich.readthedocs.io/en/stable/reference/traceback.html#rich.traceback.install """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) logging.captureWarnings(True) rich.pretty.install() @@ -55,4 +56,4 @@ def __init__(self, *args, **kwargs): # Rich traceback handling does not work on databricks. Hopefully this will be # fixed on their side at some point, but until then we disable it. # See https://github.com/Textualize/rich/issues/2455 - rich.traceback.install(**traceback_install_kwargs) + rich.traceback.install(**traceback_install_kwargs) # type: ignore[arg-type] diff --git a/kedro/pipeline/modular_pipeline.py b/kedro/pipeline/modular_pipeline.py index 1d583daca0..68cb4e2b12 100644 --- a/kedro/pipeline/modular_pipeline.py +++ b/kedro/pipeline/modular_pipeline.py @@ -233,11 +233,11 @@ def _is_transcode_base_in_mapping(name: str) -> bool: base_name, _ = _transcode_split(name) return base_name in mapping - def _map_transcode_base(name: str): + def _map_transcode_base(name: str) -> str: base_name, transcode_suffix = _transcode_split(name) return TRANSCODING_SEPARATOR.join((mapping[base_name], transcode_suffix)) - def _rename(name: str): + def _rename(name: str) -> str: rules = [ # if name mapped to new name, update with new name (lambda n: n in mapping, lambda n: mapping[n]), @@ -252,8 +252,9 @@ def _rename(name: str): ] for predicate, processor in rules: - if predicate(name): - return processor(name) + if predicate(name): # type: ignore[no-untyped-call] + processor_name: str = processor(name) # type: ignore[no-untyped-call] + return processor_name # leave name as is return name @@ -281,11 +282,12 @@ def _copy_node(node: Node) -> Node: f"{namespace}.{node.namespace}" if node.namespace else namespace ) - return node._copy( + node_copy: Node = node._copy( inputs=_process_dataset_names(node._inputs), outputs=_process_dataset_names(node._outputs), namespace=new_namespace, ) + return node_copy new_nodes = [_copy_node(n) for n in pipe.nodes] diff --git a/kedro/pipeline/node.py b/kedro/pipeline/node.py index 231a05399a..5f470e14e1 100644 --- a/kedro/pipeline/node.py +++ b/kedro/pipeline/node.py @@ -100,7 +100,10 @@ def __init__( # noqa: PLR0913 self._func = func self._inputs = inputs - self._outputs = outputs + # The type of _outputs is picked up as possibly being None, however the checks above prevent that + # ever being the case. Mypy doesn't get that though, so it complains about the assignment of outputs to + # _outputs with different types. + self._outputs: str | list[str] | dict[str, str] = outputs # type: ignore[assignment] if name and not re.match(r"[\w\.-]+$", name): raise ValueError( f"'{name}' is not a valid node name. 
It must contain only " @@ -120,7 +123,7 @@ def __init__( # noqa: PLR0913 self._validate_inputs_dif_than_outputs() self._confirms = confirms - def _copy(self, **overwrite_params): + def _copy(self, **overwrite_params: Any) -> Node: """ Helper function to copy the node, replacing some values. """ @@ -134,15 +137,15 @@ def _copy(self, **overwrite_params): "confirms": self._confirms, } params.update(overwrite_params) - return Node(**params) + return Node(**params) # type: ignore[arg-type] @property - def _logger(self): + def _logger(self) -> logging.Logger: return logging.getLogger(__name__) @property - def _unique_key(self): - def hashable(value): + def _unique_key(self) -> tuple[Any, Any] | Any | tuple: + def hashable(value: Any) -> tuple[Any, Any] | Any | tuple: if isinstance(value, dict): # we sort it because a node with inputs/outputs # {"arg1": "a", "arg2": "b"} is equivalent to @@ -152,23 +155,23 @@ def hashable(value): return tuple(value) return value - return (self.name, hashable(self._inputs), hashable(self._outputs)) + return self.name, hashable(self._inputs), hashable(self._outputs) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, Node): return NotImplemented return self._unique_key == other._unique_key - def __lt__(self, other): + def __lt__(self, other: Any) -> bool: if not isinstance(other, Node): return NotImplemented return self._unique_key < other._unique_key - def __hash__(self): + def __hash__(self) -> int: return hash(self._unique_key) - def __str__(self): - def _set_to_str(xset): + def __str__(self) -> str: + def _set_to_str(xset: set | list[str]) -> str: return f"[{';'.join(xset)}]" out_str = _set_to_str(self.outputs) if self._outputs else "None" @@ -177,13 +180,13 @@ def _set_to_str(xset): prefix = self._name + ": " if self._name else "" return prefix + f"{self._func_name}({in_str}) -> {out_str}" - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover return ( f"Node({self._func_name}, {repr(self._inputs)}, {repr(self._outputs)}, " f"{repr(self._name)})" ) - def __call__(self, **kwargs) -> dict[str, Any]: + def __call__(self, **kwargs: Any) -> dict[str, Any]: return self.run(inputs=kwargs) @property @@ -207,7 +210,7 @@ def func(self) -> Callable: return self._func @func.setter - def func(self, func: Callable): + def func(self, func: Callable) -> None: """Sets the underlying function of the node. Useful if user wants to decorate the function in a node's Hook implementation. 
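The ``node.py`` hunk above types the comparison dunders as ``__eq__(self, other: Any) -> bool`` and ``__lt__(self, other: Any) -> bool`` while still returning ``NotImplemented`` for foreign types, a combination mypy tolerates because ``NotImplemented`` is typed permissively. A compact sketch of the same convention on a hypothetical class:

from __future__ import annotations

from typing import Any


class Record:
    """Hypothetical class using the same dunder annotations as ``Node``."""

    def __init__(self, key: str) -> None:
        self.key = key

    def __eq__(self, other: Any) -> bool:
        # ``other: Any`` allows comparison with arbitrary objects;
        # ``NotImplemented`` defers to the other operand's ``__eq__``.
        if not isinstance(other, Record):
            return NotImplemented
        return self.key == other.key

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, Record):
            return NotImplemented
        return self.key < other.key

    def __hash__(self) -> int:
        # Redefined because overriding ``__eq__`` clears the inherited hash.
        return hash(self.key)


print(sorted([Record("b"), Record("a")])[0].key)  # prints "a"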
@@ -367,7 +370,7 @@ def run(self, inputs: dict[str, Any] | None = None) -> dict[str, Any]: ) raise exc - def _run_with_no_inputs(self, inputs: dict[str, Any]): + def _run_with_no_inputs(self, inputs: dict[str, Any]) -> Any: if inputs: raise ValueError( f"Node {str(self)} expected no inputs, " @@ -377,7 +380,7 @@ def _run_with_no_inputs(self, inputs: dict[str, Any]): return self._func() - def _run_with_one_input(self, inputs: dict[str, Any], node_input: str): + def _run_with_one_input(self, inputs: dict[str, Any], node_input: str) -> Any: if len(inputs) != 1 or node_input not in inputs: raise ValueError( f"Node {str(self)} expected one input named '{node_input}', " @@ -387,7 +390,7 @@ def _run_with_one_input(self, inputs: dict[str, Any], node_input: str): return self._func(inputs[node_input]) - def _run_with_list(self, inputs: dict[str, Any], node_inputs: list[str]): + def _run_with_list(self, inputs: dict[str, Any], node_inputs: list[str]) -> Any: # Node inputs and provided run inputs should completely overlap if set(node_inputs) != set(inputs.keys()): raise ValueError( @@ -398,7 +401,9 @@ def _run_with_list(self, inputs: dict[str, Any], node_inputs: list[str]): # Ensure the function gets the inputs in the correct order return self._func(*(inputs[item] for item in node_inputs)) - def _run_with_dict(self, inputs: dict[str, Any], node_inputs: dict[str, str]): + def _run_with_dict( + self, inputs: dict[str, Any], node_inputs: dict[str, str] + ) -> Any: # Node inputs and provided run inputs should completely overlap if set(node_inputs.values()) != set(inputs.keys()): raise ValueError( @@ -410,15 +415,17 @@ def _run_with_dict(self, inputs: dict[str, Any], node_inputs: dict[str, str]): kwargs = {arg: inputs[alias] for arg, alias in node_inputs.items()} return self._func(**kwargs) - def _outputs_to_dictionary(self, outputs): - def _from_dict(): + def _outputs_to_dictionary(self, outputs: Any) -> dict[str, Any]: + def _from_dict() -> dict[str, Any]: result, iterator = outputs, None # generator functions are lazy and we need a peek into their first output if inspect.isgenerator(outputs): (result,), iterator = spy(outputs) - keys = list(self._outputs.keys()) - names = list(self._outputs.values()) + # The type of _outputs is picked up as possibly not being a dict, but _from_dict is only called when + # it is a dictionary and so the calls to .keys and .values will work even though Mypy doesn't pick that up. + keys = list(self._outputs.keys()) # type: ignore[union-attr] + names = list(self._outputs.values()) # type: ignore[union-attr] if not isinstance(result, dict): raise ValueError( f"Failed to save outputs of node {self}.\n" @@ -439,7 +446,7 @@ def _from_dict(): result = tuple(result[k] for k in keys) return dict(zip(names, result)) - def _from_list(): + def _from_list() -> dict: result, iterator = outputs, None # generator functions are lazy and we need a peek into their first output if inspect.isgenerator(outputs): @@ -472,7 +479,9 @@ def _from_list(): return _from_dict() return _from_list() - def _validate_inputs(self, func, inputs): + def _validate_inputs( + self, func: Callable, inputs: None | str | list[str] | dict[str, str] + ) -> None: # inspect does not support built-in Python functions written in C. # Thus we only validate func if it is not built-in. 
if not inspect.isbuiltin(func): @@ -490,7 +499,7 @@ def _validate_inputs(self, func, inputs): f"but got {inputs}" ) from exc - def _validate_unique_outputs(self): + def _validate_unique_outputs(self) -> None: cnt = Counter(self.outputs) diff = {k for k in cnt if cnt[k] > 1} if diff: @@ -499,7 +508,7 @@ def _validate_unique_outputs(self): f"output(s) {diff}.\nNode outputs must be unique." ) - def _validate_inputs_dif_than_outputs(self): + def _validate_inputs_dif_than_outputs(self) -> None: common_in_out = set(self.inputs).intersection(set(self.outputs)) if common_in_out: raise ValueError( @@ -509,7 +518,9 @@ def _validate_inputs_dif_than_outputs(self): ) @staticmethod - def _process_inputs_for_bind(inputs: str | list[str] | dict[str, str] | None): + def _process_inputs_for_bind( + inputs: str | list[str] | dict[str, str] | None + ) -> tuple[list[str], dict[str, str]]: # Safeguard that we do not mutate list inputs inputs = copy.copy(inputs) args: list[str] = [] @@ -523,7 +534,7 @@ def _process_inputs_for_bind(inputs: str | list[str] | dict[str, str] | None): return args, kwargs -def _node_error_message(msg) -> str: +def _node_error_message(msg: str) -> str: return ( f"Invalid Node definition: {msg}\n" f"Format should be: node(function, inputs, outputs)" @@ -606,7 +617,9 @@ def node( # noqa: PLR0913 ) -def _dict_inputs_to_list(func: Callable[[Any], Any], inputs: dict[str, str]): +def _dict_inputs_to_list( + func: Callable[[Any], Any], inputs: dict[str, str] +) -> list[str]: """Convert a dict representation of the node inputs to a list, ensuring the appropriate order for binding them to the node's function. """ diff --git a/kedro/pipeline/pipeline.py b/kedro/pipeline/pipeline.py index c9794fb3b4..802d25e3c2 100644 --- a/kedro/pipeline/pipeline.py +++ b/kedro/pipeline/pipeline.py @@ -166,7 +166,7 @@ def __init__( self._nodes = tagged_nodes self._topo_sorted_nodes = _topologically_sorted(self.node_dependencies) - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Pipeline ([node1, ..., node10 ...], name='pipeline_name')""" max_nodes_to_display = 10 @@ -178,27 +178,27 @@ def __repr__(self): # pragma: no cover constructor_repr = f"({nodes_reprs_str})" return f"{self.__class__.__name__}{constructor_repr}" - def __add__(self, other): + def __add__(self, other: Any) -> Pipeline: if not isinstance(other, Pipeline): return NotImplemented return Pipeline(set(self.nodes + other.nodes)) - def __radd__(self, other): + def __radd__(self, other: Any) -> Pipeline: if isinstance(other, int) and other == 0: return self return self.__add__(other) - def __sub__(self, other): + def __sub__(self, other: Any) -> Pipeline: if not isinstance(other, Pipeline): return NotImplemented return Pipeline(set(self.nodes) - set(other.nodes)) - def __and__(self, other): + def __and__(self, other: Any) -> Pipeline: if not isinstance(other, Pipeline): return NotImplemented return Pipeline(set(self.nodes) & set(other.nodes)) - def __or__(self, other): + def __or__(self, other: Any) -> Pipeline: if not isinstance(other, Pipeline): return NotImplemented return Pipeline(set(self.nodes + other.nodes)) @@ -260,7 +260,7 @@ def datasets(self) -> set[str]: """ return self.all_outputs() | self.all_inputs() - def _transcode_compatible_names(self): + def _transcode_compatible_names(self) -> set[str]: return {_strip_transcoding(ds) for ds in self.datasets()} def describe(self, names_only: bool = True) -> str: @@ -300,7 +300,7 @@ def describe(self, names_only: bool = True) -> str: """ - def 
set_to_string(set_of_strings): + def set_to_string(set_of_strings: set[str]) -> str: """Convert set to a string but return 'None' in case of an empty set. """ @@ -782,7 +782,7 @@ def tag(self, tags: str | Iterable[str]) -> Pipeline: nodes = [n.tag(tags) for n in self.nodes] return Pipeline(nodes) - def to_json(self): + def to_json(self) -> str: """Return a json representation of the pipeline.""" transformed = [ { @@ -801,11 +801,11 @@ def to_json(self): return json.dumps(pipeline_versioned) -def _validate_duplicate_nodes(nodes_or_pipes: Iterable[Node | Pipeline]): +def _validate_duplicate_nodes(nodes_or_pipes: Iterable[Node | Pipeline]) -> None: seen_nodes: set[str] = set() duplicates: dict[Pipeline | None, set[str]] = defaultdict(set) - def _check_node(node_: Node, pipeline_: Pipeline | None = None): + def _check_node(node_: Node, pipeline_: Pipeline | None = None) -> None: name = node_.name if name in seen_nodes: duplicates[pipeline_].add(name) @@ -884,7 +884,7 @@ def _validate_transcoded_inputs_outputs(nodes: list[Node]) -> None: ) -def _topologically_sorted(node_dependencies) -> list[list[Node]]: +def _topologically_sorted(node_dependencies: dict[Node, set[Node]]) -> list[list[Node]]: """Topologically group and sort (order) nodes such that no node depends on a node that appears in the same or a later group. diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 2a2a5c3b03..ae4db946f3 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -9,7 +9,7 @@ from collections import Counter from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait from itertools import chain -from multiprocessing.managers import BaseProxy, SyncManager # type: ignore +from multiprocessing.managers import BaseProxy, SyncManager from multiprocessing.reduction import ForkingPickler from pickle import PicklingError from typing import Any, Iterable @@ -45,12 +45,15 @@ class ParallelRunnerManager(SyncManager): ParallelRunnerManager.register("MemoryDataset", MemoryDataset) # noqa: no-member -def _bootstrap_subprocess(package_name: str, logging_config: dict[str, Any]): +def _bootstrap_subprocess( + package_name: str, logging_config: dict[str, Any] | None = None +) -> None: # noqa: import-outside-toplevel,cyclic-import from kedro.framework.project import configure_logging, configure_project configure_project(package_name) - configure_logging(logging_config) + if logging_config: + configure_logging(logging_config) def _run_node_synchronization( # noqa: PLR0913 @@ -80,7 +83,7 @@ def _run_node_synchronization( # noqa: PLR0913 """ if multiprocessing.get_start_method() == "spawn" and package_name: - _bootstrap_subprocess(package_name, logging_config) # type: ignore + _bootstrap_subprocess(package_name, logging_config) hook_manager = _create_hook_manager() _register_hooks(hook_manager, settings.HOOKS) @@ -139,11 +142,11 @@ def __init__( self._max_workers = max_workers - def __del__(self): + def __del__(self) -> None: self._manager.shutdown() @classmethod - def _validate_nodes(cls, nodes: Iterable[Node]): + def _validate_nodes(cls, nodes: Iterable[Node]) -> None: """Ensure all tasks are serialisable.""" unserialisable = [] for node in nodes: @@ -163,7 +166,7 @@ def _validate_nodes(cls, nodes: Iterable[Node]): ) @classmethod - def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline): + def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline) -> None: """Ensure that all data sets are serialisable and that we do not have any non 
proxied memory data sets being used as outputs as their content will not be synchronized across threads. @@ -208,7 +211,7 @@ def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline): f"MemoryDatasets" ) - def _set_manager_datasets(self, catalog, pipeline): + def _set_manager_datasets(self, catalog: DataCatalog, pipeline: Pipeline) -> None: for dataset in pipeline.datasets(): try: catalog.exists(dataset) @@ -218,7 +221,7 @@ def _set_manager_datasets(self, catalog, pipeline): if isinstance(ds, SharedMemoryDataset): ds.set_manager(self._manager) - def _get_required_workers_count(self, pipeline: Pipeline): + def _get_required_workers_count(self, pipeline: Pipeline) -> int: """ Calculate the max number of processes required for the pipeline, limit to the number of CPU cores. @@ -289,7 +292,7 @@ def _run( # noqa: too-many-locals,useless-suppression self._is_async, session_id, package_name=PACKAGE_NAME, - logging_config=LOGGING, # type: ignore + logging_config=LOGGING, # type: ignore[arg-type] ) ) if not futures: diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 1afa8dde92..20b594e8e0 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -49,7 +49,7 @@ def __init__( self._extra_dataset_patterns = extra_dataset_patterns @property - def _logger(self): + def _logger(self) -> logging.Logger: return logging.getLogger(self.__module__) def run( @@ -114,7 +114,7 @@ def run( self._logger.info( "Asynchronous mode is enabled for loading and saving data" ) - self._run(pipeline, catalog, hook_or_null_manager, session_id) # type: ignore + self._run(pipeline, catalog, hook_or_null_manager, session_id) # type: ignore[arg-type] self._logger.info("Pipeline execution completed successfully.") @@ -453,7 +453,7 @@ def _run_node_async( hook_manager: PluginManager, session_id: str | None = None, ) -> Node: - def _synchronous_dataset_load(dataset_name: str): + def _synchronous_dataset_load(dataset_name: str) -> Any: """Minimal wrapper to ensure Hooks are run synchronously within an asynchronous dataset load.""" hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node) diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index 577863f8b3..0d28d3070f 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -64,7 +64,7 @@ def __init__( self._max_workers = max_workers - def _get_required_workers_count(self, pipeline: Pipeline): + def _get_required_workers_count(self, pipeline: Pipeline) -> int: """ Calculate the max number of processes required for the pipeline """ diff --git a/pyproject.toml b/pyproject.toml index 25f56de882..dd8e4cb8ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,4 +242,5 @@ known-first-party = ["kedro"] [tool.mypy] ignore_missing_imports = true +disable_error_code = ['misc'] exclude = ['^kedro/templates/', '^docs/', '^features/steps/test_starter/'] diff --git a/tests/framework/cli/test_starters.py b/tests/framework/cli/test_starters.py index d35beb7dd1..64fdd5006a 100644 --- a/tests/framework/cli/test_starters.py +++ b/tests/framework/cli/test_starters.py @@ -3,6 +3,7 @@ from __future__ import annotations import shutil +from collections import OrderedDict from pathlib import Path import pytest @@ -17,6 +18,7 @@ TEMPLATE_PATH, KedroStarterSpec, _convert_tool_names_to_numbers, + _fetch_config_from_user_prompts, _parse_tools_input, _parse_yes_no_to_bool, _validate_selection, @@ -439,6 +441,46 @@ def test_custom_prompt_for_essential_variable(self, fake_kedro_cli): ) 
_clean_up_project(Path("./my_custom_repo")) + def test_fetch_config_from_user_prompts_with_context(self, mocker): + required_prompts = { + "project_name": { + "title": "Project Name", + "text": "Please enter a name for your new project.", + }, + "tools": { + "title": "Project Tools", + "text": "These optional tools can help you apply software engineering best practices.", + }, + "example_pipeline": { + "title": "Example Pipeline", + "text": "Select whether you would like an example spaceflights pipeline included in your project.", + }, + } + cookiecutter_context = OrderedDict( + [ + ("project_name", "New Kedro Project"), + ("tools", "none"), + ("example_pipeline", "no"), + ] + ) + mocker.patch("cookiecutter.prompt.read_user_variable", return_value="none") + config = _fetch_config_from_user_prompts( + prompts=required_prompts, cookiecutter_context=cookiecutter_context + ) + assert config == { + "example_pipeline": "none", + "project_name": "none", + "tools": "none", + } + + def test_fetch_config_from_user_prompts_without_context(self): + required_prompts = {} + message = "No cookiecutter context available." + with pytest.raises(Exception, match=message): + _fetch_config_from_user_prompts( + prompts=required_prompts, cookiecutter_context=None + ) + @pytest.mark.usefixtures("chdir_to_tmp") class TestNewFromUserPromptsInvalid: