From a153afd7924acd77765151aa7d48465852b35a35 Mon Sep 17 00:00:00 2001 From: Merel Theisen <49397448+merelcht@users.noreply.github.com> Date: Tue, 6 Aug 2024 09:04:02 +0100 Subject: [PATCH] Move `_find_run_command` from template to framework (#4012) * Move _find_run_command from template to framework * Solve importing issues by moving _find_run_command to CLI * Make _find_run_command public * Move find run command functions to cli utils * Add tests for find run command * Update release notes --------- Signed-off-by: Merel Theisen Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com> Co-authored-by: Nok Lam Chan --- RELEASE.md | 1 + .../__main__.py | 33 +-------- kedro/framework/cli/utils.py | 44 ++++++++++- .../__main__.py | 33 +-------- tests/framework/cli/test_cli.py | 73 +++++++++++++++++++ 5 files changed, 121 insertions(+), 63 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index cbb3a9f3a0..0df120a8c9 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -3,6 +3,7 @@ ## Major features and improvements ## Bug fixes and other changes +* Moved `_find_run_command()` and `_find_run_command_in_plugins()` from `__main__.py` in the project template to the framework itself. ## Breaking changes to the API diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py index 9e6750922a..d951412ad1 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py @@ -1,45 +1,16 @@ """{{ cookiecutter.project_name }} file for ensuring the package is executable as `{{ cookiecutter.repo_name }}` and `python -m {{ cookiecutter.python_package }}` """ -import importlib from pathlib import Path -from kedro.framework.cli.utils import KedroCliError, load_entry_points +from kedro.framework.cli.utils import find_run_command from kedro.framework.project import configure_project -def _find_run_command(package_name): - try: - project_cli = importlib.import_module(f"{package_name}.cli") - # fail gracefully if cli.py does not exist - except ModuleNotFoundError as exc: - if f"{package_name}.cli" not in str(exc): - raise - plugins = load_entry_points("project") - run = _find_run_command_in_plugins(plugins) if plugins else None - if run: - # use run command from installed plugin if it exists - return run - # use run command from `kedro.framework.cli.project` - from kedro.framework.cli.project import run - - return run - # fail badly if cli.py exists, but has no `cli` in it - if not hasattr(project_cli, "cli"): - raise KedroCliError(f"Cannot load commands from {package_name}.cli") - return project_cli.run - - -def _find_run_command_in_plugins(plugins): - for group in plugins: - if "run" in group.commands: - return group.commands["run"] - - def main(*args, **kwargs): package_name = Path(__file__).parent.name configure_project(package_name) - run = _find_run_command(package_name) + run = find_run_command(package_name) run(*args, **kwargs) diff --git a/kedro/framework/cli/utils.py b/kedro/framework/cli/utils.py index 7258fb2680..6611cabac5 100644 --- a/kedro/framework/cli/utils.py +++ b/kedro/framework/cli/utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import difflib +import importlib import logging import re import shlex @@ -16,7 +17,7 @@ from importlib import import_module from itertools import chain from pathlib import Path -from typing import IO, Any, Iterable, Sequence +from typing import IO, Any, Callable, Iterable, Sequence import click import importlib_metadata @@ -388,6 +389,47 @@ def load_entry_points(name: str) -> Sequence[click.MultiCommand]: return entry_point_commands +def find_run_command(package_name: str) -> Callable: + """Find the run command to be executed. + This is either the default run command defined in the Kedro framework or a run command defined by + an installed plugin. + + Args: + package_name: The name of the package being run. + + Raises: + KedroCliError: If the run command is not found. + + Returns: + Run command to be executed. + """ + try: + project_cli = importlib.import_module(f"{package_name}.cli") + # fail gracefully if cli.py does not exist + except ModuleNotFoundError as exc: + if f"{package_name}.cli" not in str(exc): + raise + plugins = load_entry_points("project") + run = _find_run_command_in_plugins(plugins) if plugins else None + if run: + # use run command from installed plugin if it exists + return run # type: ignore[no-any-return] + # use run command from `kedro.framework.cli.project` + from kedro.framework.cli.project import run + + return run # type: ignore[no-any-return] + # fail badly if cli.py exists, but has no `cli` in it + if not hasattr(project_cli, "cli"): + raise KedroCliError(f"Cannot load commands from {package_name}.cli") + return project_cli.run # type: ignore[no-any-return] + + +def _find_run_command_in_plugins(plugins: Any) -> Any: + for group in plugins: + if "run" in group.commands: + return group.commands["run"] + + @typing.no_type_check def _config_file_callback(ctx: click.Context, param: Any, value: Any) -> Any: """CLI callback that replaces command line options diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py index 9e6750922a..d951412ad1 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/src/{{ cookiecutter.python_package }}/__main__.py @@ -1,45 +1,16 @@ """{{ cookiecutter.project_name }} file for ensuring the package is executable as `{{ cookiecutter.repo_name }}` and `python -m {{ cookiecutter.python_package }}` """ -import importlib from pathlib import Path -from kedro.framework.cli.utils import KedroCliError, load_entry_points +from kedro.framework.cli.utils import find_run_command from kedro.framework.project import configure_project -def _find_run_command(package_name): - try: - project_cli = importlib.import_module(f"{package_name}.cli") - # fail gracefully if cli.py does not exist - except ModuleNotFoundError as exc: - if f"{package_name}.cli" not in str(exc): - raise - plugins = load_entry_points("project") - run = _find_run_command_in_plugins(plugins) if plugins else None - if run: - # use run command from installed plugin if it exists - return run - # use run command from `kedro.framework.cli.project` - from kedro.framework.cli.project import run - - return run - # fail badly if cli.py exists, but has no `cli` in it - if not hasattr(project_cli, "cli"): - raise KedroCliError(f"Cannot load commands from {package_name}.cli") - return project_cli.run - - -def _find_run_command_in_plugins(plugins): - for group in plugins: - if "run" in group.commands: - return group.commands["run"] - - def main(*args, **kwargs): package_name = Path(__file__).parent.name configure_project(package_name) - run = _find_run_command(package_name) + run = find_run_command(package_name) run(*args, **kwargs) diff --git a/tests/framework/cli/test_cli.py b/tests/framework/cli/test_cli.py index b9144ebc1b..cd83c9bf21 100644 --- a/tests/framework/cli/test_cli.py +++ b/tests/framework/cli/test_cli.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import click +import pytest from click.testing import CliRunner from omegaconf import OmegaConf from pytest import fixture, mark, raises, warns @@ -23,6 +24,7 @@ CommandCollection, KedroCliError, _clean_pycache, + find_run_command, forward_command, get_pkg_version, ) @@ -276,6 +278,77 @@ def test_clean_pycache(self, tmp_path, mocker): ] assert mocked_rmtree.mock_calls == expected_calls + def test_find_run_command_non_existing_project(self): + with pytest.raises(ModuleNotFoundError, match="No module named 'fake_project'"): + _ = find_run_command("fake_project") + + def test_find_run_command_with_clipy( + self, fake_metadata, fake_repo_path, fake_project_cli, mocker + ): + mocker.patch("kedro.framework.cli.cli._is_project", return_value=True) + mocker.patch( + "kedro.framework.cli.cli.bootstrap_project", return_value=fake_metadata + ) + + mock_project_cli = MagicMock(spec=[fake_repo_path / "cli.py"]) + mock_project_cli.cli = MagicMock(spec=["cli"]) + mock_project_cli.run = MagicMock(spec=["run"]) + mocker.patch( + "kedro.framework.cli.utils.importlib.import_module", + return_value=mock_project_cli, + ) + + run = find_run_command(fake_metadata.package_name) + assert run is mock_project_cli.run + + def test_find_run_command_no_clipy(self, fake_metadata, fake_repo_path, mocker): + mocker.patch("kedro.framework.cli.cli._is_project", return_value=True) + mocker.patch( + "kedro.framework.cli.cli.bootstrap_project", return_value=fake_metadata + ) + mock_project_cli = MagicMock(spec=[fake_repo_path / "cli.py"]) + mocker.patch( + "kedro.framework.cli.utils.importlib.import_module", + return_value=mock_project_cli, + ) + + with raises(KedroCliError, match="Cannot load commands from"): + _ = find_run_command(fake_metadata.package_name) + + def test_find_run_command_use_plugin_run( + self, fake_metadata, fake_repo_path, mocker + ): + mock_plugin = MagicMock(spec=["plugins"]) + mock_command = MagicMock(name="run_command") + mock_plugin.commands = {"run": mock_command} + mocker.patch( + "kedro.framework.cli.utils.load_entry_points", return_value=[mock_plugin] + ) + + mocker.patch("kedro.framework.cli.cli._is_project", return_value=True) + mocker.patch( + "kedro.framework.cli.cli.bootstrap_project", return_value=fake_metadata + ) + mocker.patch( + "kedro.framework.cli.cli.importlib.import_module", + side_effect=ModuleNotFoundError("dummy_package.cli"), + ) + + run = find_run_command(fake_metadata.package_name) + assert run == mock_command + + def test_find_run_command_use_default_run(self, fake_metadata, mocker): + mocker.patch("kedro.framework.cli.cli._is_project", return_value=True) + mocker.patch( + "kedro.framework.cli.cli.bootstrap_project", return_value=fake_metadata + ) + mocker.patch( + "kedro.framework.cli.cli.importlib.import_module", + side_effect=ModuleNotFoundError("dummy_package.cli"), + ) + run = find_run_command(fake_metadata.package_name) + assert run.help == "Run the pipeline." + class TestEntryPoints: def test_project_groups(self, entry_points, entry_point):