
Improve resume suggestions for SequentialRunner #3026

Closed
101 changes: 66 additions & 35 deletions kedro/runner/runner.py
@@ -15,7 +15,7 @@
as_completed,
wait,
)
from typing import Any, Iterable, Iterator
from typing import Any, Collection, Iterable, Iterator

from more_itertools import interleave
from pluggy import PluginManager
@@ -198,16 +198,15 @@ def _suggest_resume_scenario(

postfix = ""
if done_nodes:
node_names = (n.name for n in remaining_nodes)
resume_p = pipeline.only_nodes(*node_names)
start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())
# Find which of the remaining nodes would need to run first (in topo sort)
remaining_initial_nodes = _find_initial_nodes(pipeline, remaining_nodes)

# find the nearest persistent ancestors of the nodes in start_p
start_p_persistent_ancestors = _find_persistent_ancestors(
pipeline, start_p.nodes, catalog
# Find the nearest persistent ancestors of these nodes
persistent_ancestors = _find_persistent_ancestors(
pipeline, remaining_initial_nodes, catalog
)

start_node_names = (n.name for n in start_p_persistent_ancestors)
start_node_names = sorted(n.name for n in persistent_ancestors)
postfix += f" --from-nodes \"{','.join(start_node_names)}\""

if not postfix:
@@ -230,7 +229,7 @@ def _find_persistent_ancestors(
) -> set[Node]:
"""Breadth-first search approach to finding the complete set of
persistent ancestors of an iterable of ``Node``s. Persistent
ancestors exclusively have persisted ``Dataset``s as inputs.
ancestors exclusively have persisted ``Dataset``s or parameters as inputs.

Args:
pipeline: the ``Pipeline`` to find ancestors in.
@@ -242,54 +241,86 @@
``Node``s.

"""
ancestor_nodes_to_run = set()
initial_nodes_to_run: set[Node] = set()

queue, visited = deque(children), set(children)
while queue:
current_node = queue.popleft()
if _has_persistent_inputs(current_node, catalog):
ancestor_nodes_to_run.add(current_node)
impersistent_inputs = _enumerate_impersistent_inputs(current_node, catalog)

# If all inputs are persistent, we can run this node as is
if not impersistent_inputs:
initial_nodes_to_run.add(current_node)
continue
for parent in _enumerate_parents(pipeline, current_node):
if parent in visited:

# Otherwise, look for the nodes that produce impersistent inputs
for node in _enumerate_nodes_with_outputs(pipeline, impersistent_inputs):
if node in visited:
continue
visited.add(parent)
queue.append(parent)
return ancestor_nodes_to_run
visited.add(node)
queue.append(node)

return initial_nodes_to_run


def _enumerate_impersistent_inputs(node: Node, catalog: DataCatalog) -> set[str]:
"""Enumerate impersistent input Datasets of a ``Node``.
Member (review comment): Suggested change: "Enumerate impersistent input Datasets of a ``Node``." → "Enumerate impersistent input datasets of a ``Node``." (lowercase "datasets").


Args:
node: the ``Node`` to check the inputs of.
catalog: the ``DataCatalog`` of the run.

Returns:
Set of names of impersistent inputs of given ``Node``.

"""
# We use _datasets because it preserves the parameter name format
catalog_datasets = catalog._datasets
Member (review comment): Just to clarify, does this mean that catalog.datasets (the public attribute) does not preserve the parameter name format?

Contributor (review comment):
@noklam Do you mean here you don't like the terminology of the word impersistent or do you think the implementation itself is leaky and should be changed?

If it's about terminology, I do agree that impersistent is a tricky word. It's of course the correct opposite of persistent, but for a non-native speaker (me) it makes the code harder to read for some reason. Maybe just call it non_persistent? 😅 In any case I think it would be helpful to clarify what we mean by impersistent/non_persistent in the context of Kedro. From reading the code (specifically _enumerate_impersistent_inputs), I interpret the definition as anything that is a MemoryDataset, even if it's defined in the catalog.

@merelcht Mostly terminology; it's just hard to remember, as previously it was just MemoryDataset and now it's basically MemoryDataset + params.

New thought: #3520 - any chance we can utilise the new flag that we added? Can we load parameters as a "special" kind of MemoryDataset that will have self._EPHEMERAL = False? We can take this PR first and refactor this afterward if needed.
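For reference, a rough sketch of what that follow-up refactor could look like (hypothetical helper and subclass names, not part of this PR, assuming the _EPHEMERAL flag from #3520 behaves as described above):

# Hypothetical sketch only: rely on the _EPHEMERAL flag instead of an
# isinstance(MemoryDataset) check plus a "params:" special case.
from typing import Any

from kedro.io import MemoryDataset


class _ParameterDataset(MemoryDataset):
    """Hypothetical in-memory dataset for parameters, marked as non-ephemeral."""

    def __init__(self, data: Any):
        super().__init__(data)
        self._EPHEMERAL = False  # parameters would then count as persistent


def _is_impersistent(dataset: Any) -> bool:
    # MemoryDataset instances carry _EPHEMERAL = True; persisted datasets default to False.
    return getattr(dataset, "_EPHEMERAL", False)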

Contributor Author (review comment):
catalog.datasets members need to be accessible as attributes, so colons and periods are replaced with __. I agree this is very far from an elegant solution, so I'm definitely open to suggestions
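As a minimal illustration of the constraint described above (a sketch assuming the catalog behaviour of the Kedro version this PR targets, not part of the diff):

from kedro.io import DataCatalog, MemoryDataset

# The internal mapping is keyed by the original dataset names, "params:" prefix included.
catalog = DataCatalog({"params:p": MemoryDataset(1)})
assert "params:p" in catalog._datasets

# The public `datasets` container exposes entries as attributes, so ":" and "."
# are rewritten to "__" and the original name format is lost.
assert hasattr(catalog.datasets, "params__p")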

missing_inputs: set[str] = set()
for node_input in node.inputs:
# Important difference vs. Kedro approach
Member (review comment):
Don't forget to remove this comment, or otherwise clarify that parameters are treated as persistent.

if node_input.startswith("params:"):
continue
if isinstance(catalog_datasets[node_input], MemoryDataset):
Contributor (review comment):
Can we use catalog.load instead of catalog._datasets? If I understand correctly, you want to preserve params:, and catalog.load should work.

Member (review comment):
Is it necessary to load the dataset though? Isn't that a bit of overkill?

Contributor Author (review comment):
In this approach I need to check the Dataset instance, not its data contents.

missing_inputs.add(node_input)

def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:
"""For a given ``Node``, returns a list containing the direct parents
of that ``Node`` in the given ``Pipeline``.
return missing_inputs
Member (review comment):
Maybe call this impersistent_inputs instead? It's not immediately clear that impersistent and missing are meant to be the same thing.



def _enumerate_nodes_with_outputs(
pipeline: Pipeline, outputs: Collection[str]
) -> list[Node]:
"""For given outputs, returns a list containing nodes that
generate them in the given ``Pipeline``.

Args:
pipeline: the ``Pipeline`` to search for direct parents in.
child: the ``Node`` to find parents of.
pipeline: the ``Pipeline`` to search for nodes in.
outputs: the dataset names to find source nodes for.

Returns:
A list of all ``Node``s that are direct parents of ``child``.
A list of all ``Node``s that are producing ``outputs``.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*child.inputs)
parent_pipeline = pipeline.only_nodes_with_outputs(*outputs)
return parent_pipeline.nodes


def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
"""Check if a ``Node`` exclusively has persisted Datasets as inputs.
If at least one input is a ``MemoryDataset``, return False.
def _find_initial_nodes(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
"""Given a collection of ``Node``s in a ``Pipeline``,
find the initial group of ``Node``s to be run (in topological order).

Args:
node: the ``Node`` to check the inputs of.
catalog: the ``DataCatalog`` of the run.
pipeline: the ``Pipeline`` to search for initial ``Node``s in.
nodes: the ``Node``s to find initial group for.

Returns:
True if the ``Node`` being checked exclusively has inputs that
are not ``MemoryDataset``, else False.
A list of initial ``Node``s to run given inputs (in topological order).

"""
for node_input in node.inputs:
if isinstance(catalog._datasets[node_input], MemoryDataset):
return False
return True
node_names = set(n.name for n in nodes)
sub_pipeline = pipeline.only_nodes(*node_names)
initial_nodes = sub_pipeline.grouped_nodes[0]
return initial_nodes
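For context, Pipeline.grouped_nodes returns the nodes in topologically ordered groups, so index 0 is the first group that can run. A toy sketch (example pipeline assumed, not part of the diff):

from kedro.pipeline import node, pipeline


def identity(x):
    return x


# "second" depends on "first", so only "first" appears in the initial group.
toy = pipeline(
    [
        node(identity, "a", "b", name="first"),
        node(identity, "b", "c", name="second"),
    ]
)
assert [n.name for n in toy.grouped_nodes[0]] == ["first"]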


def run_node(
41 changes: 39 additions & 2 deletions tests/runner/conftest.py
@@ -15,6 +15,10 @@ def identity(arg):
return arg


def first_arg(*args):
return args[0]


def sink(arg):
pass

@@ -36,7 +40,7 @@ def return_not_serialisable(arg):
return lambda x: x


def multi_input_list_output(arg1, arg2):
def multi_input_list_output(arg1, arg2, arg3=None): # pylint: disable=unused-argument
return [arg1, arg2]


@@ -80,6 +84,8 @@ def _save(arg):
"ds0_B": persistent_dataset,
"ds2_A": persistent_dataset,
"ds2_B": persistent_dataset,
"dsX": persistent_dataset,
"params:p": MemoryDataset(1),
}
)

@@ -167,7 +173,38 @@ def two_branches_crossed_pipeline():
)


@pytest.fixture
@pytest.fixture(
params=[(), ("dsX",), ("params:p",)],
ids=[
"no_extras",
"extra_persistent_ds",
"extra_param",
],
)
def two_branches_crossed_pipeline_variable_inputs(request):
"""A ``Pipeline`` with an X-shape (two branches with one common node).
Non-persistent datasets (other than parameters) are prefixed with an underscore.
"""
extra_inputs = list(request.param)

return pipeline(
[
node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"),
node(first_arg, ["ds0_B"] + extra_inputs, "_ds1_B", name="node1_B"),
node(
multi_input_list_output,
["_ds1_A", "_ds1_B"] + extra_inputs,
["ds2_A", "ds2_B"],
name="node2",
),
node(first_arg, ["ds2_A"] + extra_inputs, "_ds3_A", name="node3_A"),
node(first_arg, ["ds2_B"] + extra_inputs, "_ds3_B", name="node3_B"),
node(first_arg, ["_ds3_A"] + extra_inputs, "_ds4_A", name="node4_A"),
node(first_arg, ["_ds3_B"] + extra_inputs, "_ds4_B", name="node4_B"),
]
)


def pipeline_with_memory_datasets():
return pipeline(
[
66 changes: 54 additions & 12 deletions tests/runner/test_sequential_runner.py
@@ -252,18 +252,18 @@ def test_confirms(self, mocker, test_pipeline, is_async):
fake_dataset_instance.confirm.assert_called_once_with()


@pytest.mark.parametrize(
"failing_node_names,expected_pattern",
[
(["node1_A"], r"No nodes ran."),
(["node2"], r"(node1_A,node1_B|node1_B,node1_A)"),
(["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
(["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
(["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
(["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"),
],
)
class TestSuggestResumeScenario:
@pytest.mark.parametrize(
"failing_node_names,expected_pattern",
[
(["node1_A"], r"No nodes ran."),
(["node2"], r"(node1_A,node1_B|node1_B,node1_A)"),
(["node3_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
(["node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
(["node3_A", "node4_A"], r"(node3_A,node3_B|node3_B,node3_A)"),
(["node2", "node4_A"], r"(node1_A,node1_B|node1_B,node1_A)"),
],
)
def test_suggest_resume_scenario(
self,
caplog,
@@ -284,7 +284,49 @@
persistent_dataset_catalog,
hook_manager=_create_hook_manager(),
)
assert re.search(expected_pattern, caplog.text)
assert re.search(
expected_pattern, caplog.text
), f"{expected_pattern=}, {caplog.text=}"

@pytest.mark.parametrize(
"failing_node_names,expected_pattern",
[
(["node1_A"], r"No nodes ran."),
(["node2"], r'"node1_A,node1_B"'),
(["node3_A"], r'"node3_A,node3_B"'),
(["node4_A"], r'"node3_A,node3_B"'),
(["node3_A", "node4_A"], r'"node3_A,node3_B"'),
(["node2", "node4_A"], r'"node1_A,node1_B"'),
],
)
def test_stricter_suggest_resume_scenario(
self,
caplog,
two_branches_crossed_pipeline_variable_inputs,
persistent_dataset_catalog,
failing_node_names,
expected_pattern,
):
"""
Stricter version of previous test.
Covers pipelines where inputs are shared across nodes.
"""
test_pipeline = two_branches_crossed_pipeline_variable_inputs

nodes = {n.name: n for n in test_pipeline.nodes}
for name in failing_node_names:
test_pipeline -= modular_pipeline([nodes[name]])
test_pipeline += modular_pipeline([nodes[name]._copy(func=exception_fn)])

with pytest.raises(Exception, match="test exception"):
SequentialRunner().run(
test_pipeline,
persistent_dataset_catalog,
hook_manager=_create_hook_manager(),
)
assert re.search(
expected_pattern, caplog.text
), f"{expected_pattern=}, {caplog.text=}"


class TestMemoryDatasetBehaviour: