Improve resume suggestions for SequentialRunner #3026

Closed
43 changes: 32 additions & 11 deletions kedro/runner/runner.py
@@ -197,16 +197,15 @@ def _suggest_resume_scenario(

postfix = ""
if done_nodes:
-node_names = (n.name for n in remaining_nodes)
-resume_p = pipeline.only_nodes(*node_names)
-start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())
+# Find which of the remaining nodes would need to run first (in topo sort)
+remaining_initial_nodes = _find_initial_nodes(pipeline, remaining_nodes)

-# find the nearest persistent ancestors of the nodes in start_p
-start_p_persistent_ancestors = _find_persistent_ancestors(
-pipeline, start_p.nodes, catalog
+# Find the nearest persistent ancestors of these nodes
+persistent_ancestors = _find_persistent_ancestors(
+pipeline, remaining_initial_nodes, catalog
)

-start_node_names = (n.name for n in start_p_persistent_ancestors)
+start_node_names = sorted(n.name for n in persistent_ancestors)
postfix += f" --from-nodes \"{','.join(start_node_names)}\""

if not postfix:
@@ -229,7 +228,7 @@ def _find_persistent_ancestors(
) -> set[Node]:
"""Breadth-first search approach to finding the complete set of
persistent ancestors of an iterable of ``Node``s. Persistent
-ancestors exclusively have persisted ``Dataset``s as inputs.
+ancestors exclusively have persisted ``Dataset``s or parameters as inputs.

Args:
pipeline: the ``Pipeline`` to find ancestors in.
@@ -273,25 +272,47 @@ def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:


def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
"""Check if a ``Node`` exclusively has persisted Datasets as inputs.
If at least one input is a ``MemoryDataset``, return False.
"""Check if a ``Node`` exclusively has either persisted Datasets
or parameters as inputs. If at least one non-parametric input
is a ``MemoryDataset``, return False.

Args:
node: the ``Node`` to check the inputs of.
catalog: the ``DataCatalog`` of the run.

Returns:
True if the ``Node`` being checked exclusively has inputs that
-are not ``MemoryDataset``, else False.
+are persisted Datasets or parameters, else False.

"""
for node_input in node.inputs:
# Parameters are represented as MemoryDatasets but are considered persistent
if node_input.startswith("params:"):
continue
# noqa: protected-access
if isinstance(catalog._data_sets[node_input], MemoryDataset):
return False
return True


def _find_initial_nodes(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
"""Given a collection of ``Node``s in a ``Pipeline``,
find the initial group of ``Node``s to be run (in topological order).

Args:
pipeline: the ``Pipeline`` to search for initial ``Node``s in.
nodes: the ``Node``s to find initial group for.

Returns:
A list of initial ``Node``s to run given inputs (in topological order).

"""
node_names = set(n.name for n in nodes)
sub_pipeline = pipeline.only_nodes(*node_names)
initial_nodes = sub_pipeline.grouped_nodes[0]
return initial_nodes


def run_node(
node: Node,
catalog: DataCatalog,
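
Taken together, the new helpers implement a two-step search: _find_initial_nodes picks the earliest remaining nodes in topological order, and _find_persistent_ancestors walks breadth-first towards the inputs until it reaches nodes whose inputs are all persisted datasets or parameters; their sorted names end up in the --from-nodes suggestion. Below is a minimal, self-contained sketch of that idea — it is not kedro's implementation, the toy graph and names are invented, and the topological-grouping step is folded away by starting directly from the failed node. In-memory datasets are marked with a leading underscore.

from collections import deque

# Hypothetical toy graph (not the kedro fixture): outputs whose names start
# with "_" stand for in-memory datasets; "params:" inputs are parameters.
NODE_INPUTS = {
    "node1_A": ["ds0_A", "params:p"],
    "node1_B": ["ds0_B", "params:p"],
    "node2": ["_ds1_A", "_ds1_B"],
}
NODE_OUTPUTS = {
    "node1_A": ["_ds1_A"],
    "node1_B": ["_ds1_B"],
    "node2": ["ds2_A", "ds2_B"],
}

def has_persistent_inputs(node_name):
    # Parameters live in memory but are always available, so they count as persistent.
    return all(
        inp.startswith("params:") or not inp.startswith("_")
        for inp in NODE_INPUTS[node_name]
    )

def find_persistent_ancestors(nodes):
    # Breadth-first walk towards the inputs until every reached node is resumable.
    ancestors, queue = set(), deque(nodes)
    while queue:
        current = queue.popleft()
        if has_persistent_inputs(current):
            ancestors.add(current)
            continue
        for inp in NODE_INPUTS[current]:
            queue.extend(n for n, outs in NODE_OUTPUTS.items() if inp in outs)
    return ancestors

# Suppose "node2" failed: its inputs are in-memory, so the suggestion walks back
# to node1_A and node1_B, whose inputs are persisted datasets or parameters.
start_nodes = sorted(find_persistent_ancestors(["node2"]))
print(f' --from-nodes "{",".join(start_nodes)}"')  # --from-nodes "node1_A,node1_B"
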
40 changes: 39 additions & 1 deletion tests/runner/conftest.py
@@ -15,6 +15,10 @@ def identity(arg):
return arg


def first_arg(*args):
return args[0]


def sink(arg): # pylint: disable=unused-argument
pass

@@ -36,7 +40,7 @@ def return_not_serialisable(arg): # pylint: disable=unused-argument
return lambda x: x


-def multi_input_list_output(arg1, arg2):
+def multi_input_list_output(arg1, arg2, arg3=None):  # pylint: disable=unused-argument
return [arg1, arg2]


@@ -81,6 +85,8 @@ def _save(arg):
"ds0_B": persistent_dataset,
"ds2_A": persistent_dataset,
"ds2_B": persistent_dataset,
"dsX": persistent_dataset,
"params:p": MemoryDataSet(1),
}
)

@@ -166,3 +172,35 @@ def two_branches_crossed_pipeline():
node(identity, "ds3_B", "ds4_B", name="node4_B"),
]
)


@pytest.fixture(
params=[(), ("dsX",), ("params:p",)],
ids=[
"no_extras",
"extra_persistent_ds",
"extra_param",
],
)
def two_branches_crossed_pipeline_variable_inputs(request):
"""A ``Pipeline`` with an X-shape (two branches with one common node).
Non-persistent datasets (other than parameters) are prefixed with an underscore.
"""
extra_inputs = list(request.param)

return pipeline(
[
node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"),
node(first_arg, ["ds0_B"] + extra_inputs, "_ds1_B", name="node1_B"),
node(
multi_input_list_output,
["_ds1_A", "_ds1_B"] + extra_inputs,
["ds2_A", "ds2_B"],
name="node2",
),
node(first_arg, ["ds2_A"] + extra_inputs, "_ds3_A", name="node3_A"),
node(first_arg, ["ds2_B"] + extra_inputs, "_ds3_B", name="node3_B"),
node(first_arg, ["_ds3_A"] + extra_inputs, "_ds4_A", name="node4_A"),
node(first_arg, ["_ds3_B"] + extra_inputs, "_ds4_B", name="node4_B"),
]
)
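
The new fixture parametrises the same X-shaped pipeline three ways: with no extra inputs, with an extra persisted dataset ("dsX"), and with an extra parameter ("params:p") fed to every node; because parameters are treated as persistent, the resume suggestions should not change between the variants. A rough sketch of how one such variant could be built and inspected with kedro's public API follows (the trimmed two-node pipeline and its node names are invented for illustration, not the fixture itself); grouped_nodes is the same topological grouping that _find_initial_nodes relies on.

from kedro.pipeline import node, pipeline

def first_arg(*args):
    return args[0]

extra_inputs = ["params:p"]  # the "extra_param" variant
toy = pipeline(
    [
        node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"),
        node(first_arg, ["_ds1_A"] + extra_inputs, "_ds2_A", name="node2_A"),
    ]
)

# Topologically sorted groups of nodes; _find_initial_nodes takes the first group.
print([[n.name for n in group] for group in toy.grouped_nodes])
# [['node1_A'], ['node2_A']]
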
60 changes: 49 additions & 11 deletions tests/runner/test_sequential_runner.py
@@ -241,18 +241,18 @@ def test_confirms(self, mocker, test_pipeline, is_async):
fake_dataset_instance.confirm.assert_called_once_with()


-@pytest.mark.parametrize(
-    "failing_node_names,expected_pattern",
-    [
-        (["node1_A"], r"No nodes ran."),
-        (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"),
-        (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
-        (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
-        (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
-        (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"),
-    ],
-)
class TestSuggestResumeScenario:
+    @pytest.mark.parametrize(
+        "failing_node_names,expected_pattern",
+        [
+            (["node1_A"], r"No nodes ran."),
+            (["node2"], r"(node1_A,node1_B|node1_B,node1_A)"),
+            (["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
+            (["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
+            (["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
+            (["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"),
+        ],
+    )
def test_suggest_resume_scenario(
self,
caplog,
@@ -274,3 +274,41 @@ def test_suggest_resume_scenario(
hook_manager=_create_hook_manager(),
)
assert re.search(expected_pattern, caplog.text)

@pytest.mark.parametrize(
"failing_node_names,expected_pattern",
[
(["node1_A"], r"No nodes ran."),
(["node2"], r'"node1_A,node1_B"'),
(["node3_A"], r'"node3_A,node3_B"'),
(["node4_A"], r'"node3_A,node3_B"'),
(["node3_A", "node4_A"], r'"node3_A,node3_B"'),
(["node2", "node4_A"], r'"node1_A,node1_B"'),
],
)
def test_stricter_suggest_resume_scenario(
self,
caplog,
two_branches_crossed_pipeline_variable_inputs,
persistent_dataset_catalog,
failing_node_names,
expected_pattern,
):
"""
Stricter version of previous test.
Covers pipelines where inputs are shared across nodes.
"""
test_pipeline = two_branches_crossed_pipeline_variable_inputs

nodes = {n.name: n for n in test_pipeline.nodes}
for name in failing_node_names:
test_pipeline -= modular_pipeline([nodes[name]])
test_pipeline += modular_pipeline([nodes[name]._copy(func=exception_fn)])

with pytest.raises(Exception, match="test exception"):
SequentialRunner().run(
test_pipeline,
persistent_dataset_catalog,
hook_manager=_create_hook_manager(),
)
assert re.search(expected_pattern, caplog.text)
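
The original parametrisation had to accept either ordering of the suggested node names because they came from an unsorted generator; now that the names are sorted before being joined, this stricter variant can pin down the exact, quoted list. A small illustration of the difference between the two patterns (the log fragment below is invented, not kedro's actual message):

import re

suggestion = ' --from-nodes "node1_A,node1_B"'  # hypothetical fragment of the log output

assert re.search(r"(node1_A,node1_B|node1_B,node1_A)", suggestion)  # old, order-agnostic
assert re.search(r'"node1_A,node1_B"', suggestion)  # new, exact sorted list in quotes
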