Improve resume suggestions for SequentialRunner #3026
Changes from 8 commits: 994d6b5, 688f1eb, bf24347, 7ce2ea6, 36294d5, 0214c4e, 10db9dc, 700645c, 9e25f6b
```diff
@@ -15,7 +15,7 @@
     as_completed,
     wait,
 )
-from typing import Any, Iterable, Iterator
+from typing import Any, Collection, Iterable, Iterator
 
 from more_itertools import interleave
 from pluggy import PluginManager
```
```diff
@@ -198,16 +198,15 @@ def _suggest_resume_scenario(
     postfix = ""
     if done_nodes:
-        node_names = (n.name for n in remaining_nodes)
-        resume_p = pipeline.only_nodes(*node_names)
-        start_p = resume_p.only_nodes_with_inputs(*resume_p.inputs())
+        # Find which of the remaining nodes would need to run first (in topo sort)
+        remaining_initial_nodes = _find_initial_nodes(pipeline, remaining_nodes)
 
-        # find the nearest persistent ancestors of the nodes in start_p
-        start_p_persistent_ancestors = _find_persistent_ancestors(
-            pipeline, start_p.nodes, catalog
+        # Find the nearest persistent ancestors of these nodes
+        persistent_ancestors = _find_persistent_ancestors(
+            pipeline, remaining_initial_nodes, catalog
         )
 
-        start_node_names = (n.name for n in start_p_persistent_ancestors)
+        start_node_names = sorted(n.name for n in persistent_ancestors)
         postfix += f" --from-nodes \"{','.join(start_node_names)}\""
 
     if not postfix:
```
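Switching from a generator to `sorted(...)` makes the suggested command deterministic. A minimal sketch of how the postfix is assembled and surfaced to the user (node names here are invented for illustration):

```python
# Hypothetical node names; in the runner these come from the persistent
# ancestors found above.
persistent_ancestor_names = ["train_model_node", "split_data_node"]

# Mirrors the diff: sort for a stable order, then build the CLI postfix.
start_node_names = sorted(persistent_ancestor_names)
postfix = f" --from-nodes \"{','.join(start_node_names)}\""

print("kedro run" + postfix)
# kedro run --from-nodes "split_data_node,train_model_node"
```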
```diff
@@ -230,7 +229,7 @@ def _find_persistent_ancestors(
 ) -> set[Node]:
     """Breadth-first search approach to finding the complete set of
     persistent ancestors of an iterable of ``Node``s. Persistent
-    ancestors exclusively have persisted ``Dataset``s as inputs.
+    ancestors exclusively have persisted ``Dataset``s or parameters as inputs.
 
     Args:
         pipeline: the ``Pipeline`` to find ancestors in.
```
```diff
@@ -242,54 +241,86 @@
         ``Node``s.
 
     """
-    ancestor_nodes_to_run = set()
+    initial_nodes_to_run: set[Node] = set()
 
     queue, visited = deque(children), set(children)
     while queue:
         current_node = queue.popleft()
-        if _has_persistent_inputs(current_node, catalog):
-            ancestor_nodes_to_run.add(current_node)
+        impersistent_inputs = _enumerate_impersistent_inputs(current_node, catalog)
+
+        # If all inputs are persistent, we can run this node as is
+        if not impersistent_inputs:
+            initial_nodes_to_run.add(current_node)
             continue
-        for parent in _enumerate_parents(pipeline, current_node):
-            if parent in visited:
+
+        # Otherwise, look for the nodes that produce impersistent inputs
+        for node in _enumerate_nodes_with_outputs(pipeline, impersistent_inputs):
+            if node in visited:
                 continue
-            visited.add(parent)
-            queue.append(parent)
-    return ancestor_nodes_to_run
+            visited.add(node)
+            queue.append(node)
+
+    return initial_nodes_to_run
```
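To make the traversal concrete, here is a self-contained sketch of the same breadth-first idea on a toy DAG. Plain dicts stand in for the ``Pipeline`` and ``DataCatalog``, and all node and dataset names are invented:

```python
from collections import deque

# Toy model: each node maps to its input datasets; one dataset is persisted,
# the rest are memory-only. "split" produces train_set/test_set, "train"
# produces "model".
node_inputs = {
    "split": {"model_input_table"},
    "train": {"train_set"},
    "evaluate": {"model", "test_set"},
}
persisted = {"model_input_table"}
producers = {"train_set": "split", "test_set": "split", "model": "train"}

def find_persistent_ancestors(children: set[str]) -> set[str]:
    """BFS upwards from ``children`` until every queued node can be
    re-run from persisted inputs alone (mirrors the PR's logic)."""
    initial_nodes_to_run: set[str] = set()
    queue, visited = deque(children), set(children)
    while queue:
        current = queue.popleft()
        impersistent = node_inputs[current] - persisted
        if not impersistent:
            initial_nodes_to_run.add(current)
            continue
        for parent in {producers[d] for d in impersistent}:
            if parent not in visited:
                visited.add(parent)
                queue.append(parent)
    return initial_nodes_to_run

print(find_persistent_ancestors({"evaluate"}))  # {'split'}
```

Even though "evaluate" failed, the suggestion walks back through "train" to "split", the nearest node whose inputs are all persisted.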
```diff
+
+
+def _enumerate_impersistent_inputs(node: Node, catalog: DataCatalog) -> set[str]:
+    """Enumerate impersistent input Datasets of a ``Node``.
+
+    Args:
+        node: the ``Node`` to check the inputs of.
+        catalog: the ``DataCatalog`` of the run.
+
+    Returns:
+        Set of names of impersistent inputs of given ``Node``.
+
+    """
+    # We use ``_datasets`` because it preserves the parameter name format
+    catalog_datasets = catalog._datasets
```

Review comment: Just to clarify, does this mean that …

Reply: @merelcht Mostly terminology; it's just hard to remember, as previously it was just ``MemoryDataset`` and now it is basically ``MemoryDataset`` + params. New thought: #3520 - any chance we can utilise the new flag that we added? Can we load …
```diff
+
+    missing_inputs: set[str] = set()
+    for node_input in node.inputs:
+        # Important difference vs. Kedro approach
+        if node_input.startswith("params:"):
+            continue
+        if isinstance(catalog_datasets[node_input], MemoryDataset):
+            missing_inputs.add(node_input)
```

Review comment (on ``# Important difference vs. Kedro approach``): Don't forget to remove this comment, or otherwise clarify that parameters are treated as persistent.

Review comment (on the ``isinstance`` check): Can we use …

Reply: Is it necessary to load the dataset though? Isn't that a bit of overkill?

Reply: In this approach I need to check the ``Dataset`` instance, not its data contents.
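As the exchange above notes, the classification relies only on which ``Dataset`` instance is registered in the catalog; nothing is ever loaded. A sketch of the check in isolation, assuming a Kedro version that exposes ``MemoryDataset`` and the private ``_datasets`` mapping used in this PR, plus ``kedro-datasets`` for a persisted example dataset:

```python
from kedro.io import DataCatalog, MemoryDataset
from kedro_datasets.pandas import CSVDataset  # assumption: kedro-datasets is installed

catalog = DataCatalog(
    {
        "model_input_table": CSVDataset(filepath="data/model_input_table.csv"),
        "train_set": MemoryDataset(),
    }
)

# Mirrors _enumerate_impersistent_inputs: an input counts as "impersistent"
# if the registered dataset is a MemoryDataset; no data is loaded.
for name, dataset in catalog._datasets.items():
    kind = "memory-only" if isinstance(dataset, MemoryDataset) else "persisted"
    print(name, kind)
```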
```diff
+
+    return missing_inputs
```

Review comment: Maybe call this …

```diff
 
 
-def _enumerate_parents(pipeline: Pipeline, child: Node) -> list[Node]:
-    """For a given ``Node``, returns a list containing the direct parents
-    of that ``Node`` in the given ``Pipeline``.
+def _enumerate_nodes_with_outputs(
+    pipeline: Pipeline, outputs: Collection[str]
+) -> list[Node]:
+    """For given outputs, returns a list containing nodes that
+    generate them in the given ``Pipeline``.
```
```diff
 
     Args:
-        pipeline: the ``Pipeline`` to search for direct parents in.
-        child: the ``Node`` to find parents of.
+        pipeline: the ``Pipeline`` to search for nodes in.
+        outputs: the dataset names to find source nodes for.
 
     Returns:
-        A list of all ``Node``s that are direct parents of ``child``.
+        A list of all ``Node``s that produce ``outputs``.
 
     """
-    parent_pipeline = pipeline.only_nodes_with_outputs(*child.inputs)
+    parent_pipeline = pipeline.only_nodes_with_outputs(*outputs)
     return parent_pipeline.nodes
```
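The helper above is a thin wrapper over ``Pipeline.only_nodes_with_outputs``. A toy illustration of that filter (node names and functions are invented):

```python
from kedro.pipeline import Pipeline, node

def make_features(raw):
    return raw

def train(features):
    return "model"

toy_pipeline = Pipeline(
    [
        node(make_features, inputs="raw", outputs="features", name="features_node"),
        node(train, inputs="features", outputs="model", name="train_node"),
    ]
)

# Keep only the nodes that produce the dataset "features".
producers = toy_pipeline.only_nodes_with_outputs("features").nodes
print([n.name for n in producers])  # ['features_node']
```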
```diff
 
 
-def _has_persistent_inputs(node: Node, catalog: DataCatalog) -> bool:
-    """Check if a ``Node`` exclusively has persisted Datasets as inputs.
-    If at least one input is a ``MemoryDataset``, return False.
+def _find_initial_nodes(pipeline: Pipeline, nodes: Iterable[Node]) -> list[Node]:
+    """Given a collection of ``Node``s in a ``Pipeline``,
+    find the initial group of ``Node``s to be run (in topological order).
 
     Args:
-        node: the ``Node`` to check the inputs of.
-        catalog: the ``DataCatalog`` of the run.
+        pipeline: the ``Pipeline`` to search for initial ``Node``s in.
+        nodes: the ``Node``s to find the initial group for.
 
     Returns:
-        True if the ``Node`` being checked exclusively has inputs that
-        are not ``MemoryDataset``, else False.
+        A list of initial ``Node``s to run given the inputs (in topological order).
 
     """
-    for node_input in node.inputs:
-        if isinstance(catalog._datasets[node_input], MemoryDataset):
-            return False
-    return True
+    node_names = set(n.name for n in nodes)
+    sub_pipeline = pipeline.only_nodes(*node_names)
+    initial_nodes = sub_pipeline.grouped_nodes[0]
+
+    return initial_nodes
 
 
 def run_node(
```
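``_find_initial_nodes`` leans on ``Pipeline.grouped_nodes``, whose groups are topological layers, so index ``[0]`` is the set of nodes with no upstream dependency inside the sub-pipeline. A toy sketch, again with invented names:

```python
from kedro.pipeline import Pipeline, node

def identity(x):
    return x

toy_pipeline = Pipeline(
    [
        node(identity, inputs="a", outputs="b", name="first"),
        node(identity, inputs="b", outputs="c", name="second"),
        node(identity, inputs="c", outputs="d", name="third"),
    ]
)

# Restrict to the "remaining" nodes, then take the first topological group.
sub_pipeline = toy_pipeline.only_nodes("second", "third")
initial_nodes = sub_pipeline.grouped_nodes[0]
print([n.name for n in initial_nodes])  # ['second']
```

Within the restricted sub-pipeline, "second" has no upstream node left, so it is the initial node the resume suggestion should start from.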