Improve resume suggestions #3719

Merged: 13 commits, Apr 2, 2024
kedro/runner/runner.py (172 changes: 127 additions & 45 deletions)

@@ -15,7 +15,7 @@
as_completed,
wait,
)
from typing import Any, Iterable, Iterator
from typing import Any, Collection, Iterable, Iterator

from more_itertools import interleave
from pluggy import PluginManager
@@ -198,39 +198,62 @@ def _suggest_resume_scenario(

postfix = ""
if done_nodes:
node_names = (n.name for n in remaining_nodes)
resume_p = pipeline.only_nodes(*node_names)
start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())

# find the nearest persistent ancestors of the nodes in start_p
start_p_persistent_ancestors = _find_persistent_ancestors(
pipeline, start_p.nodes, catalog
start_node_names = find_nodes_to_resume_from(
pipeline=pipeline,
unfinished_nodes=remaining_nodes,
catalog=catalog,
)

start_node_names = (n.name for n in start_p_persistent_ancestors)
postfix += f" --from-nodes \"{','.join(start_node_names)}\""
start_nodes_str = ",".join(sorted(start_node_names))
postfix += f' --from-nodes "{start_nodes_str}"'

if not postfix:
self._logger.warning(
"No nodes ran. Repeat the previous command to attempt a new run."
)
else:
self._logger.warning(
"There are %d nodes that have not run.\n"
f"There are {len(remaining_nodes)} nodes that have not run.\n"
"You can resume the pipeline run from the nearest nodes with "
"persisted inputs by adding the following "
"argument to your previous command:\n%s",
len(remaining_nodes),
postfix,
f"argument to your previous command:\n{postfix}"
)
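
For a concrete, hypothetical illustration: with the X-shaped test pipeline from tests/runner/conftest.py (ds0_A and ds0_B persisted, ds1_A and ds1_B memory-only), a failure in node1_B after node1_A has finished would now produce a warning along these lines; node1_A is suggested as well, because its in-memory output was lost with the crash:

    There are 7 nodes that have not run.
    You can resume the pipeline run from the nearest nodes with persisted inputs by adding the following argument to your previous command:
     --from-nodes "node1_A,node1_B"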


def _find_persistent_ancestors(
pipeline: Pipeline, children: Iterable[Node], catalog: DataCatalog
def find_nodes_to_resume_from(
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: DataCatalog
) -> set[str]:
"""Given a collection of unfinished nodes in a pipeline using
a certain catalog, find the node names to pass to pipeline.from_nodes()
to cover all unfinished nodes, including any additional nodes
that should be re-run if their outputs are not persisted.

Args:
pipeline: the ``Pipeline`` to find starting nodes for.
unfinished_nodes: collection of ``Node``s that have not finished yet
catalog: the ``DataCatalog`` of the run.

Returns:
Set of node names to pass to pipeline.from_nodes() to continue
the run.

"""
all_nodes_that_need_to_run = find_all_required_nodes(
Contributor (review comment), suggested change:

    all_nodes_that_need_to_run = find_all_required_nodes(
→   nodes_to_be_run = find_all_required_nodes(

?

pipeline, unfinished_nodes, catalog
)

# Find which of the remaining nodes would need to run first (in topo sort)
persistent_ancestors = find_initial_node_group(pipeline, all_nodes_that_need_to_run)
Contributor (review comment): Is this needed for "Number of suggested nodes to re-run from is minimised (-> shorter message for the same pipeline)"?


return {n.name for n in persistent_ancestors}
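
A minimal usage sketch of the new helper (toy pipeline and names; assumes this branch is installed so find_nodes_to_resume_from is importable, and uses LambdaDataset as a stand-in for any persisted dataset):

    from kedro.io import DataCatalog, LambdaDataset
    from kedro.pipeline import node, pipeline
    from kedro.runner.runner import find_nodes_to_resume_from

    def first_arg(*args):
        return args[0]

    # Three-node chain: ds0 and ds2 are persisted, _ds1 is memory-only.
    p = pipeline(
        [
            node(first_arg, "ds0", "_ds1", name="node1"),
            node(first_arg, "_ds1", "ds2", name="node2"),
            node(first_arg, "ds2", "ds3", name="node3"),
        ]
    )
    catalog = DataCatalog(
        {
            "ds0": LambdaDataset(load=lambda: 42, save=lambda data: None),
            "ds2": LambdaDataset(load=lambda: 42, save=lambda data: None),
        }
    )

    # node2 crashed: its input _ds1 lived only in memory, so node1 must re-run.
    unfinished = [n for n in p.nodes if n.name in {"node2", "node3"}]
    print(find_nodes_to_resume_from(p, unfinished, catalog))  # {'node1'}

    # node3 crashed: its input ds2 is persisted, so resume from node3 itself.
    unfinished = [n for n in p.nodes if n.name == "node3"]
    print(find_nodes_to_resume_from(p, unfinished, catalog))  # {'node3'}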


def find_all_required_nodes(
pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: DataCatalog
) -> set[Node]:
"""Breadth-first search approach to finding the complete set of
persistent ancestors of an iterable of ``Node``s. Persistent
ancestors exclusively have persisted ``Dataset``s as inputs.
``Node``s which need to run to cover all unfinished nodes,
including any additional nodes that should be re-run if their outputs
are not persisted.

Args:
pipeline: the ``Pipeline`` to find ancestors in.
@@ -242,54 +265,113 @@ def _find_persistent_ancestors(
``Node``s.

"""
ancestor_nodes_to_run = set()
queue, visited = deque(children), set(children)
nodes_to_run = set(unfinished_nodes)
initial_nodes = _nodes_with_external_inputs(pipeline, unfinished_nodes)

queue, visited = deque(initial_nodes), set(initial_nodes)
while queue:
current_node = queue.popleft()
if _has_persistent_inputs(current_node, catalog):
ancestor_nodes_to_run.add(current_node)
continue
for parent in _enumerate_parents(pipeline, current_node):
if parent in visited:
nodes_to_run.add(current_node)
non_persistent_inputs = _enumerate_non_persistent_inputs(current_node, catalog)
# Look for the nodes that produce non-persistent inputs (if those exist)
for node in _enumerate_nodes_with_outputs(pipeline, non_persistent_inputs):
if node in visited:
continue
visited.add(parent)
queue.append(parent)
return ancestor_nodes_to_run
visited.add(node)
queue.append(node)

# Make sure no downstream tasks are skipped
nodes_to_run = pipeline.from_nodes(*(n.name for n in nodes_to_run)).nodes

def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:
"""For a given ``Node``, returns a list containing the direct parents
of that ``Node`` in the given ``Pipeline``.
return set(nodes_to_run)
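
Why the from_nodes expansion at the end: re-running an upstream node can change its outputs, so every node downstream of the required set must be re-queued even if it finished before the failure. A short sketch of that behaviour using only the Pipeline API (toy names):

    from kedro.pipeline import node, pipeline

    def first_arg(*args):
        return args[0]

    chain = pipeline(
        [
            node(first_arg, "a", "b", name="n1"),
            node(first_arg, "b", "c", name="n2"),
            node(first_arg, "c", "d", name="n3"),
        ]
    )
    # from_nodes pulls in the named nodes plus everything downstream of them.
    print([n.name for n in chain.from_nodes("n2").nodes])  # ['n2', 'n3']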


def _nodes_with_external_inputs(
pipeline: Pipeline, nodes_of_interest: Iterable[Node]
) -> set[Node]:
"""For given ``Node``s in a ``Pipeline``, find their
subset which depends on external inputs of the ``Pipeline``.

Args:
pipeline: the ``Pipeline`` to search for direct parents in.
child: the ``Node`` to find parents of.
pipeline: the ``Pipeline`` to search for nodes in.
nodes_of_interest: the ``Node``s to analyze.

Returns:
A list of all ``Node``s that are direct parents of ``child``.
A set of ``Node``s that depend on external inputs
of nodes of interest.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*child.inputs)
return parent_pipeline.nodes
p_nodes_of_interest = pipeline.only_nodes(*(n.name for n in nodes_of_interest))
p_nodes_with_external_inputs = p_nodes_of_interest.only_nodes_with_inputs(
*p_nodes_of_interest.inputs()
)
return set(p_nodes_with_external_inputs.nodes)
Contributor (review comment): The function name is _nodes_with_external_inputs, but its argument is a list of nodes. If you are trying to search for pipeline nodes that have external inputs, can you use pipeline.inputs() and then pipeline.only_nodes_with_inputs()?

Contributor (author): Good point, simplified that.
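
To make the only_nodes / only_nodes_with_inputs composition concrete, a toy sketch (illustrative names): for a chain a -> n1 -> n2, only n1 consumes an input produced outside the sub-pipeline.

    from kedro.pipeline import node, pipeline

    def first_arg(*args):
        return args[0]

    chain = pipeline(
        [
            node(first_arg, "a", "b", name="n1"),
            node(first_arg, "b", "c", name="n2"),
        ]
    )
    sub = chain.only_nodes("n1", "n2")
    print(sub.inputs())  # {'a'}; 'b' is produced inside the sub-pipeline
    print({n.name for n in sub.only_nodes_with_inputs("a").nodes})  # {'n1'}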



def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
"""Check if a ``Node`` exclusively has persisted Datasets as inputs.
If at least one input is a ``MemoryDataset``, return False.
def _enumerate_non_persistent_inputs(node: Node, catalog: DataCatalog) -> set[str]:
"""Enumerate non-persistent input datasets of a ``Node``.

Args:
node: the ``Node`` to check the inputs of.
catalog: the ``DataCatalog`` of the run.

Returns:
True if the ``Node`` being checked exclusively has inputs that
are not ``MemoryDataset``, else False.
Set of names of non-persistent inputs of given ``Node``.

"""
# We use _datasets because it retains the parameter name format
catalog_datasets = catalog._datasets
non_persistent_inputs: set[str] = set()
for node_input in node.inputs:
if isinstance(catalog._datasets[node_input], MemoryDataset):
return False
return True
if node_input.startswith("params:"):
continue
if node_input not in catalog_datasets or isinstance(
catalog_datasets[node_input], MemoryDataset
):
non_persistent_inputs.add(node_input)

return non_persistent_inputs
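
A sketch of the three cases this helper distinguishes (toy node and dataset names; LambdaDataset stands in for any persisted dataset, and the private helper is imported from this branch):

    from kedro.io import DataCatalog, LambdaDataset, MemoryDataset
    from kedro.pipeline import node
    from kedro.runner.runner import _enumerate_non_persistent_inputs

    def combine(a, b, c):
        return (a, b, c)

    probe = node(combine, ["on_disk", "in_memory", "params:p"], "out", name="probe")
    catalog = DataCatalog(
        {
            "on_disk": LambdaDataset(load=lambda: 1, save=lambda data: None),
            "in_memory": MemoryDataset(2),
        }
    )
    # "params:p" is skipped outright; "on_disk" is persistent; "in_memory" is a
    # MemoryDataset, and a name missing from the catalog would count the same way.
    print(_enumerate_non_persistent_inputs(probe, catalog))  # {'in_memory'}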


def _enumerate_nodes_with_outputs(
pipeline: Pipeline, outputs: Collection[str]
) -> list[Node]:
"""For given outputs, returns a list containing nodes that
generate them in the given ``Pipeline``.

Args:
pipeline: the ``Pipeline`` to search for nodes in.
outputs: the dataset names to find source nodes for.

Returns:
A list of all ``Node``s that produce ``outputs``.

"""
parent_pipeline = pipeline.only_nodes_with_outputs(*outputs)
return parent_pipeline.nodes


def find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
"""Given a collection of ``Node``s in a ``Pipeline``,
find the initial group of ``Node``s to be run (in topological order).

This can be used to define a sub-pipeline with the smallest possible
set of nodes to pass to --from-nodes.

Args:
pipeline: the ``Pipeline`` to search for initial ``Node``s in.
nodes: the ``Node``s to find initial group for.

Returns:
A list of initial ``Node``s to run given inputs (in topological order).

"""
node_names = set(n.name for n in nodes)
if len(node_names) == 0:
return []
sub_pipeline = pipeline.only_nodes(*node_names)
initial_nodes = sub_pipeline.grouped_nodes[0]
return initial_nodes
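
grouped_nodes returns the pipeline's nodes as topologically sorted groups, so the first group is exactly the set of nodes with no upstream dependency inside the sub-pipeline. A diamond-shaped sketch (toy names):

    from kedro.pipeline import node, pipeline

    def first_arg(*args):
        return args[0]

    diamond = pipeline(
        [
            node(first_arg, "a", "b", name="left"),
            node(first_arg, "a", "c", name="right"),
            node(first_arg, ["b", "c"], "d", name="join"),
        ]
    )
    # 'left' and 'right' can both start immediately; 'join' waits for them.
    print(sorted(n.name for n in diamond.grouped_nodes[0]))  # ['left', 'right']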


def run_node(
tests/runner/conftest.py (111 changes: 103 additions & 8 deletions)

@@ -15,6 +15,10 @@ def identity(arg):
return arg


def first_arg(*args):
return args[0]


def sink(arg):
pass

@@ -36,7 +40,7 @@ def return_not_serialisable(arg):
return lambda x: x


def multi_input_list_output(arg1, arg2):
def multi_input_list_output(arg1, arg2, arg3=None):
return [arg1, arg2]


@@ -80,6 +84,9 @@ def _save(arg):
"ds0_B": persistent_dataset,
"ds2_A": persistent_dataset,
"ds2_B": persistent_dataset,
"dsX": persistent_dataset,
"dsY": persistent_dataset,
"params:p": MemoryDataset(1),
}
)

@@ -148,21 +155,31 @@ def unfinished_outputs_pipeline():

@pytest.fixture
def two_branches_crossed_pipeline():
"""A ``Pipeline`` with an X-shape (two branches with one common node)"""
r"""A ``Pipeline`` with an X-shape (two branches with one common node):

(node1_A)   (node1_B)
        \   /
       (node2)
        /   \
(node3_A)   (node3_B)
    /           \
(node4_A)   (node4_B)

"""
return pipeline(
[
node(identity, "ds0_A", "ds1_A", name="node1_A"),
node(identity, "ds0_B", "ds1_B", name="node1_B"),
node(first_arg, "ds0_A", "ds1_A", name="node1_A"),
node(first_arg, "ds0_B", "ds1_B", name="node1_B"),
node(
multi_input_list_output,
["ds1_A", "ds1_B"],
["ds2_A", "ds2_B"],
name="node2",
),
node(identity, "ds2_A", "ds3_A", name="node3_A"),
node(identity, "ds2_B", "ds3_B", name="node3_B"),
node(identity, "ds3_A", "ds4_A", name="node4_A"),
node(identity, "ds3_B", "ds4_B", name="node4_B"),
node(first_arg, "ds2_A", "ds3_A", name="node3_A"),
node(first_arg, "ds2_B", "ds3_B", name="node3_B"),
node(first_arg, "ds3_A", "ds4_A", name="node4_A"),
node(first_arg, "ds3_B", "ds4_B", name="node4_B"),
]
)

@@ -175,3 +192,81 @@ def pipeline_with_memory_datasets():
node(func=identity, inputs="Input2", outputs="MemOutput2", name="node2"),
]
)


@pytest.fixture
def pipeline_asymmetric():
r"""

(node1)
     \
   (node3)   (node2)
         \   /
        (node4)

"""
return pipeline(
[
node(first_arg, ["ds0_A"], ["_ds1"], name="node1"),
node(first_arg, ["ds0_B"], ["_ds2"], name="node2"),
node(first_arg, ["_ds1"], ["_ds3"], name="node3"),
node(first_arg, ["_ds2", "_ds3"], ["_ds4"], name="node4"),
]
)


@pytest.fixture
def pipeline_triangular():
r"""

(node1)
  |  \
  |   (node2)
  |  /
(node3)

"""
return pipeline(
[
node(first_arg, ["ds0_A"], ["_ds1_A"], name="node1"),
node(first_arg, ["_ds1_A"], ["ds2_A"], name="node2"),
node(first_arg, ["ds2_A", "_ds1_A"], ["_ds3_A"], name="node3"),
]
)


@pytest.fixture
def empty_pipeline():
return pipeline([])


@pytest.fixture(
params=[(), ("dsX",), ("params:p",)],
ids=[
"no_extras",
"extra_persistent_ds",
"extra_param",
],
)
def two_branches_crossed_pipeline_variable_inputs(request):
"""A ``Pipeline`` with an X-shape (two branches with one common node).
Non-persistent datasets (other than parameters) are prefixed with an underscore.
"""
extra_inputs = list(request.param)

return pipeline(
[
node(first_arg, ["ds0_A"] + extra_inputs, "_ds1_A", name="node1_A"),
node(first_arg, ["ds0_B"] + extra_inputs, "_ds1_B", name="node1_B"),
node(
multi_input_list_output,
["_ds1_A", "_ds1_B"] + extra_inputs,
["ds2_A", "ds2_B"],
name="node2",
),
node(first_arg, ["ds2_A"] + extra_inputs, "_ds3_A", name="node3_A"),
node(first_arg, ["ds2_B"] + extra_inputs, "_ds3_B", name="node3_B"),
node(first_arg, ["_ds3_A"] + extra_inputs, "_ds4_A", name="node4_A"),
node(first_arg, ["_ds3_B"] + extra_inputs, "_ds4_B", name="node4_B"),
]
)