Skip to content

Commit

Permalink
improve %load_node support full node syntax (#3633)
Browse files Browse the repository at this point in the history
* add docs for line magic

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* add release note

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* rename node->node_name

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* Introduce NodeBoundArguments for refactoring - introduce a proper object instead of generating strings for cells directly

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* revert accidental changes of notebook docs

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* revert notes

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* rename variables

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* refactor function call to use the new arg list

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* bug fix, should iterate through *args because it can be more than one element

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* merge with main

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>

* Lint

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Add some tests

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Add another test

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Tests for format_node_inputs_text

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Lint

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Lint

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Attempt to suppress sphinx warning

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Attempt to suppress sphinx warning

Signed-off-by: lrcouto <laurarccouto@gmail.com>

* Make NodeBoundArguments private

Signed-off-by: lrcouto <laurarccouto@gmail.com>

---------

Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>
Signed-off-by: lrcouto <laurarccouto@gmail.com>
Co-authored-by: L. R. Couto <57910428+lrcouto@users.noreply.github.com>
Co-authored-by: lrcouto <laurarccouto@gmail.com>
  • Loading branch information
3 people authored Feb 26, 2024
1 parent 527108c commit ffc2683
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 31 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Major features and improvements
* Create the debugging line magic `%load_node` for Jupyter Notebook and Jupyter Lab.
* Add better IPython, VSCode Notebook support for `%load_node` and minimal support for Databricks.
* Add full Kedro Node input syntax for `%load_node`.

## Bug fixes and other changes
* Updated CLI Command `kedro catalog resolve` to work with dataset factories that use `PartitionedDataset`.
Expand Down
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
"kedro_docs_style_guide.md",
]


type_targets = {
"py:class": (
"object",
Expand Down
106 changes: 84 additions & 22 deletions kedro/ipython/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import typing
import warnings
from pathlib import Path
from typing import Any, Callable
from types import MappingProxyType
from typing import Any, Callable, OrderedDict

from IPython.core.getipython import get_ipython
from IPython.core.magic import needs_local_scope, register_line_magic
Expand All @@ -36,6 +37,8 @@

logger = logging.getLogger(__name__)

FunctionParameters = MappingProxyType


def load_ipython_extension(ipython: Any) -> None:
"""
Expand All @@ -45,9 +48,9 @@ def load_ipython_extension(ipython: Any) -> None:
See https://ipython.readthedocs.io/en/stable/config/extensions/index.html
"""
ipython.register_magic_function(magic_reload_kedro, magic_name="reload_kedro")
logger.info("Registered line magic 'reload_kedro'")
logger.info("Registered line magic '%reload_kedro'")
ipython.register_magic_function(magic_load_node, magic_name="load_node")
logger.info("Registered line magic 'load_node'")
logger.info("Registered line magic '%load_node'")

if _find_kedro_project(Path.cwd()) is None:
logger.warning(
Expand Down Expand Up @@ -225,7 +228,9 @@ def magic_load_node(args: str) -> None:
"""

parameters = parse_argstring(magic_load_node, args)
cells = _load_node(parameters.node, pipelines)
node_name = parameters.node

cells = _load_node(node_name, pipelines)

run_environment = _guess_run_environment()
if run_environment == "jupyter":
Expand All @@ -240,6 +245,36 @@ def magic_load_node(args: str) -> None:
_print_cells(cells)


class _NodeBoundArguments(inspect.BoundArguments):
"""Similar to inspect.BoundArguments"""

def __init__(
self, signature: inspect.Signature, arguments: OrderedDict[str, Any]
) -> None:
super().__init__(signature, arguments)

@property
def input_params_dict(self) -> dict[str, str] | None:
"""A mapping of {variable name: dataset_name}"""
var_positional_arg_name = self._find_var_positional_arg()
inputs_params_dict = {}
for param, dataset_name in self.arguments.items():
if param == var_positional_arg_name:
# If the argument is *args, use the dataset name instead
for arg in dataset_name:
inputs_params_dict[arg] = arg
else:
inputs_params_dict[param] = dataset_name
return inputs_params_dict

def _find_var_positional_arg(self) -> str | None:
"""Find the name of the VAR_POSITIONAL argument( *args), if any."""
for k, v in self.signature.parameters.items():
if v.kind == inspect.Parameter.VAR_POSITIONAL:
return k
return None


def _create_cell_with_text(text: str, is_jupyter: bool = True) -> None:
if is_jupyter:
from ipylab import JupyterFrontEnd
Expand Down Expand Up @@ -277,16 +312,20 @@ def _load_node(node_name: str, pipelines: _ProjectPipelines) -> list[str]:
node = _find_node(node_name, pipelines)
node_func = node.func

node_inputs = _prepare_node_inputs(node)
imports = _prepare_imports(node_func)
function_definition = _prepare_function_body(node_func)
function_call = _prepare_function_call(node_func)
imports_cell = _prepare_imports(node_func)
function_definition_cell = _prepare_function_body(node_func)

node_bound_arguments = _get_node_bound_arguments(node)
inputs_params_mapping = _prepare_node_inputs(node_bound_arguments)
node_inputs_cell = _format_node_inputs_text(inputs_params_mapping)
function_call_cell = _prepare_function_call(node_func, node_bound_arguments)

cells: list[str] = []
cells.append(node_inputs)
cells.append(imports)
cells.append(function_definition)
cells.append(function_call)
if node_inputs_cell:
cells.append(node_inputs_cell)
cells.append(imports_cell)
cells.append(function_definition_cell)
cells.append(function_call_cell)
return cells


Expand Down Expand Up @@ -323,20 +362,37 @@ def _prepare_imports(node_func: Callable) -> str:
raise FileNotFoundError(f"Could not find {node_func.__name__}")


def _prepare_node_inputs(node: Node) -> str:
def _get_node_bound_arguments(node: Node) -> _NodeBoundArguments:
    """Bind the node's declared inputs to the signature of the node's function.

    The inputs are first split into positional and keyword parts by
    ``Node._process_inputs_for_bind`` and then bound with
    ``inspect.Signature.bind``, so a mismatch between the inputs and the
    function signature surfaces as the ``TypeError`` raised by ``bind``.
    """
    func, declared_inputs = node.func, node.inputs
    positional, keyword = Node._process_inputs_for_bind(declared_inputs)
    bound = inspect.signature(func).bind(*positional, **keyword)
    return _NodeBoundArguments(bound.signature, bound.arguments)

node_inputs = node.inputs
func_params = list(signature.parameters)

def _prepare_node_inputs(
node_bound_arguments: _NodeBoundArguments,
) -> dict[str, str] | None:
# Remove the *args. For example {'first_arg':'a', 'args': ('b','c')}
# will be loaded as follow:
# first_arg = catalog.load("a")
# b = catalog.load("b") # It doesn't have an arg name, so use the dataset name instead.
# c = catalog.load("c")
return node_bound_arguments.input_params_dict


def _format_node_inputs_text(input_params_dict: dict[str, str] | None) -> str | None:
statements = [
"# Prepare necessary inputs for debugging",
"# All debugging inputs must be defined in your project catalog",
]
if not input_params_dict:
return None

for node_input, func_param in zip(node_inputs, func_params):
statements.append(f'{func_param} = catalog.load("{node_input}")')
for func_param, dataset_name in input_params_dict.items():
statements.append(f'{func_param} = catalog.load("{dataset_name}")')

input_statements = "\n".join(statements)
return input_statements
Expand All @@ -348,13 +404,19 @@ def _prepare_function_body(func: Callable) -> str:
return body


def _prepare_function_call(node_func: Callable) -> str:
def _prepare_function_call(
node_func: Callable, node_bound_arguments: _NodeBoundArguments
) -> str:
"""Prepare the text for the function call."""
func_name = node_func.__name__
signature = inspect.signature(node_func)
func_params = list(signature.parameters)
args = node_bound_arguments.input_params_dict
kwargs = node_bound_arguments.kwargs

# Construct the statement of func_name(a=1,b=2,c=3)
func_args = ", ".join(func_params)
body = f"""{func_name}({func_args})"""
args_str_literal = [f"{node_input}" for node_input in args] if args else []
kwargs_str_literal = [
f"{node_input}={dataset_name}" for node_input, dataset_name in kwargs.items()
]
func_params = ", ".join(args_str_literal + kwargs_str_literal)
body = f"""{func_name}({func_params})"""
return body
31 changes: 31 additions & 0 deletions tests/ipython/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .dummy_function_fixtures import (
dummy_function,
dummy_function_with_loop,
dummy_function_with_variable_length,
dummy_nested_function,
)

Expand Down Expand Up @@ -105,6 +106,36 @@ def dummy_node():
)


@pytest.fixture
def dummy_node_empty_input():
    """Node wrapping ``dummy_function`` whose two inputs are empty strings."""
    empty_input_node = node(
        name="dummy_node_empty_input",
        func=dummy_function,
        inputs=["", ""],
        outputs=[None],
    )
    return empty_input_node


@pytest.fixture
def dummy_node_dict_input():
    """Node wrapping ``dummy_function`` whose inputs are given as a dict.

    The dict maps the function's parameter names to dataset names.
    """
    return node(
        func=dummy_function,
        inputs={"dummy_input": "dummy_input", "my_input": "extra_input"},
        outputs=["dummy_output"],
        # Named after this fixture; it previously reused "dummy_node_empty_input".
        name="dummy_node_dict_input",
    )


@pytest.fixture
def dummy_node_with_variable_length():
    """Node wrapping ``dummy_function_with_variable_length`` with four list inputs."""
    dataset_names = ["dummy_input", "extra_input", "first", "second"]
    variable_length_node = node(
        name="dummy_node_with_variable_length",
        func=dummy_function_with_variable_length,
        inputs=dataset_names,
        outputs=["dummy_output"],
    )
    return variable_length_node


@pytest.fixture
def lambda_node():
return node(
Expand Down
4 changes: 4 additions & 0 deletions tests/ipython/dummy_function_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ def dummy_function_with_loop(dummy_list):
for x in dummy_list:
continue
return len(dummy_list)


def dummy_function_with_variable_length(dummy_input, my_input, *args, **kwargs):
    """No-op taking two named inputs plus ``*args`` and ``**kwargs``."""
    return None
91 changes: 83 additions & 8 deletions tests/ipython/test_ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from kedro.framework.project import pipelines
from kedro.ipython import (
_find_node,
_format_node_inputs_text,
_get_node_bound_arguments,
_load_node,
_prepare_function_body,
_prepare_imports,
Expand Down Expand Up @@ -199,13 +201,13 @@ def test_load_extension_register_line_magic(self, mocker, ipython):
"--conf-source=new_conf",
],
)
def test_reload_kedro_magic_with_valid_arguments(self, mocker, args, ipython):
def test_line_magic_with_valid_arguments(self, mocker, args, ipython):
mocker.patch("kedro.ipython._find_kedro_project")
mocker.patch("kedro.ipython.reload_kedro")

ipython.magic(f"reload_kedro {args}")

def test_reload_kedro_with_invalid_arguments(self, mocker, ipython):
def test_line_magic_with_invalid_arguments(self, mocker, ipython):
mocker.patch("kedro.ipython._find_kedro_project")
mocker.patch("kedro.ipython.reload_kedro")
load_ipython_extension(ipython)
Expand Down Expand Up @@ -357,13 +359,48 @@ def test_prepare_node_inputs(
self,
dummy_node,
):
func_inputs = """# Prepare necessary inputs for debugging
# All debugging inputs must be defined in your project catalog
dummy_input = catalog.load("dummy_input")
my_input = catalog.load("extra_input")"""
expected = {"dummy_input": "dummy_input", "my_input": "extra_input"}

node_bound_arguments = _get_node_bound_arguments(dummy_node)
result = _prepare_node_inputs(node_bound_arguments)
assert result == expected

def test_prepare_node_inputs_when_input_is_empty(
    self,
    dummy_node_empty_input,
):
    """Empty dataset names are carried through to the inputs mapping."""
    bound_arguments = _get_node_bound_arguments(dummy_node_empty_input)
    inputs_mapping = _prepare_node_inputs(bound_arguments)
    assert inputs_mapping == {"dummy_input": "", "my_input": ""}

def test_prepare_node_inputs_with_dict_input(
    self,
    dummy_node_dict_input,
):
    """Dict-style node inputs map parameter names to their dataset names."""
    bound_arguments = _get_node_bound_arguments(dummy_node_dict_input)
    inputs_mapping = _prepare_node_inputs(bound_arguments)
    assert inputs_mapping == {
        "dummy_input": "dummy_input",
        "my_input": "extra_input",
    }

result = _prepare_node_inputs(dummy_node)
assert result == func_inputs
def test_prepare_node_inputs_with_variable_length_args(
    self,
    dummy_node_with_variable_length,
):
    """Inputs captured by ``*args`` appear under their own dataset names."""
    bound_arguments = _get_node_bound_arguments(dummy_node_with_variable_length)
    inputs_mapping = _prepare_node_inputs(bound_arguments)

    named_inputs = {"dummy_input": "dummy_input", "my_input": "extra_input"}
    var_positional_inputs = {"first": "first", "second": "second"}
    assert inputs_mapping == {**named_inputs, **var_positional_inputs}

def test_prepare_function_body(self, dummy_function_defintion):
result = _prepare_function_body(dummy_function)
Expand Down Expand Up @@ -430,3 +467,41 @@ def test_load_node_with_other(self, mocker, ipython, run_env):
load_ipython_extension(ipython)
ipython.magic("load_node dummy_node")
spy.assert_called_once()


class TestFormatNodeInputsText:
    """Checks for the cell text produced by ``_format_node_inputs_text``."""

    def test_format_node_inputs_text_empty_input(self):
        # An empty mapping produces no text at all.
        assert _format_node_inputs_text({}) is None

    def test_format_node_inputs_text_single_input(self):
        # One entry: the two header comments followed by one load statement.
        expected_lines = [
            "# Prepare necessary inputs for debugging",
            "# All debugging inputs must be defined in your project catalog",
            'input1 = catalog.load("dataset1")',
        ]
        generated = _format_node_inputs_text({"input1": "dataset1"})
        assert generated == "\n".join(expected_lines)

    def test_format_node_inputs_text_multiple_inputs(self):
        # Several entries: one load statement per entry, in mapping order.
        mapping = {
            "input1": "dataset1",
            "input2": "dataset2",
            "input3": "dataset3",
        }
        expected_lines = [
            "# Prepare necessary inputs for debugging",
            "# All debugging inputs must be defined in your project catalog",
            'input1 = catalog.load("dataset1")',
            'input2 = catalog.load("dataset2")',
            'input3 = catalog.load("dataset3")',
        ]
        generated = _format_node_inputs_text(mapping)
        assert generated == "\n".join(expected_lines)

    def test_format_node_inputs_text_no_catalog_load(self):
        # ``None`` instead of a mapping yields no catalog.load() statements.
        generated = _format_node_inputs_text(None)
        assert generated is None

0 comments on commit ffc2683

Please sign in to comment.