Refactor Runners, introduce Task class (#4206)
* Introduce task class + small refactoring
* Merge in main and make release_datasets private
* Fix session tests
* Make Task runnable and call inside runners
* Move helper methods from parallel runner to task
* Mark run_node as deprecated
* Refactor helper methods to go inside Task, making hook_manager an optional argument and adding parallel as a boolean flag

---------

Signed-off-by: Merel Theisen <merel.theisen@quantumblack.com>
Signed-off-by: Merel Theisen <49397448+merelcht@users.noreply.github.com>
Signed-off-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>
Co-authored-by: Nok Lam Chan <nok.lam.chan@quantumblack.com>
merelcht and noklam authored Nov 1, 2024
1 parent 974c517 commit 18bde07
Showing 13 changed files with 495 additions and 362 deletions.
2 changes: 2 additions & 0 deletions kedro/runner/__init__.py
@@ -5,12 +5,14 @@
from .parallel_runner import ParallelRunner
from .runner import AbstractRunner, run_node
from .sequential_runner import SequentialRunner
from .task import Task
from .thread_runner import ThreadRunner

__all__ = [
"AbstractRunner",
"ParallelRunner",
"SequentialRunner",
"Task",
"ThreadRunner",
"run_node",
]
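
For orientation before the larger hunks below: Task is now exported from kedro.runner alongside the deprecated run_node. The sketch that follows shows how a single node might be executed through the new class, using only the keyword arguments and the execute() call visible in this commit; the full Task definition lives in kedro/runner/task.py, which is not among the hunks shown here, so treat the exact signature and defaults as assumptions.

# Hedged sketch: run one node through the new Task class instead of the
# deprecated run_node(). Keyword arguments mirror those visible in this diff;
# the node, catalog and hook manager below are minimal stand-ins.
from kedro.framework.hooks.manager import _create_hook_manager
from kedro.io import DataCatalog, MemoryDataset
from kedro.pipeline import node as make_node
from kedro.runner import Task

double = make_node(lambda x: x * 2, inputs="x", outputs="y", name="double")
catalog = DataCatalog({"x": MemoryDataset(21), "y": MemoryDataset()})

task = Task(
    node=double,
    catalog=catalog,
    hook_manager=_create_hook_manager(),  # optional per the commit message
    is_async=False,
    session_id=None,
)
task.execute()
print(catalog.load("y"))  # expected: 42
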
92 changes: 10 additions & 82 deletions kedro/runner/parallel_runner.py
@@ -4,7 +4,6 @@

from __future__ import annotations

import multiprocessing
import os
import sys
from collections import Counter
@@ -15,19 +14,14 @@
from pickle import PicklingError
from typing import TYPE_CHECKING, Any

from kedro.framework.hooks.manager import (
_create_hook_manager,
_register_hooks,
_register_hooks_entry_points,
)
from kedro.framework.project import settings
from kedro.io import (
CatalogProtocol,
DatasetNotFoundError,
MemoryDataset,
SharedMemoryDataset,
)
from kedro.runner.runner import AbstractRunner, run_node
from kedro.runner.runner import AbstractRunner
from kedro.runner.task import Task

if TYPE_CHECKING:
from collections.abc import Iterable
@@ -50,52 +44,6 @@ class ParallelRunnerManager(SyncManager):
ParallelRunnerManager.register("MemoryDataset", MemoryDataset)


def _bootstrap_subprocess(
package_name: str, logging_config: dict[str, Any] | None = None
) -> None:
from kedro.framework.project import configure_logging, configure_project

configure_project(package_name)
if logging_config:
configure_logging(logging_config)


def _run_node_synchronization( # noqa: PLR0913
node: Node,
catalog: CatalogProtocol,
is_async: bool = False,
session_id: str | None = None,
package_name: str | None = None,
logging_config: dict[str, Any] | None = None,
) -> Node:
"""Run a single `Node` with inputs from and outputs to the `catalog`.
A ``PluginManager`` instance is created in each subprocess because the
``PluginManager`` can't be serialised.
Args:
node: The ``Node`` to run.
catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs.
is_async: If True, the node inputs and outputs are loaded and saved
asynchronously with threads. Defaults to False.
session_id: The session id of the pipeline run.
package_name: The name of the project Python package.
logging_config: A dictionary containing logging configuration.
Returns:
The node argument.
"""
if multiprocessing.get_start_method() == "spawn" and package_name:
_bootstrap_subprocess(package_name, logging_config)

hook_manager = _create_hook_manager()
_register_hooks(hook_manager, settings.HOOKS)
_register_hooks_entry_points(hook_manager, settings.DISABLE_HOOKS_FOR_PLUGINS)

return run_node(node, catalog, hook_manager, is_async, session_id)


class ParallelRunner(AbstractRunner):
"""``ParallelRunner`` is an ``AbstractRunner`` implementation. It can
be used to run the ``Pipeline`` in parallel groups formed by toposort.
@@ -282,24 +230,19 @@ def _run(
done = None
max_workers = self._get_required_workers_count(pipeline)

from kedro.framework.project import LOGGING, PACKAGE_NAME

with ProcessPoolExecutor(max_workers=max_workers) as pool:
while True:
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
todo_nodes -= ready
for node in ready:
futures.add(
pool.submit(
_run_node_synchronization,
node,
catalog,
self._is_async,
session_id,
package_name=PACKAGE_NAME,
logging_config=LOGGING, # type: ignore[arg-type]
)
task = Task(
node=node,
catalog=catalog,
is_async=self._is_async,
session_id=session_id,
parallel=True,
)
futures.add(pool.submit(task))
if not futures:
if todo_nodes:
debug_data = {
@@ -321,19 +264,4 @@ def _run(
node = future.result()
done_nodes.add(node)

# Decrement load counts, and release any datasets we
# have finished with. This is particularly important
# for the shared, default datasets we created above.
for dataset in node.inputs:
load_counts[dataset] -= 1
if (
load_counts[dataset] < 1
and dataset not in pipeline.inputs()
):
catalog.release(dataset)
for dataset in node.outputs:
if (
load_counts[dataset] < 1
and dataset not in pipeline.outputs()
):
catalog.release(dataset)
self._release_datasets(node, catalog, load_counts, pipeline)
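
The change above is the core of the refactor in this file: the module-level _bootstrap_subprocess and _run_node_synchronization helpers are gone, and the runner now submits a self-contained Task with parallel=True straight to the process pool. The generic sketch below, an illustration rather than Kedro's actual Task, shows why that pattern works: a small picklable callable object can cross the process boundary and rebuild any unpicklable machinery, such as a hook manager, inside the worker.

# Generic, hedged sketch of the pattern ParallelRunner now relies on: submit a
# picklable callable object to a ProcessPoolExecutor and let it rebuild any
# unpicklable machinery (here, a fake "hook manager") inside the worker process.
# DemoTask and _build_hooks are made-up names, not part of Kedro.
from __future__ import annotations

from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass


def _build_hooks() -> list[str]:
    # Stand-in for the hook registration that must happen per subprocess,
    # because a PluginManager cannot be pickled and shipped to the worker.
    return ["logging-hook", "tracking-hook"]


@dataclass
class DemoTask:
    name: str
    parallel: bool = False

    def __call__(self) -> str:
        hooks = _build_hooks() if self.parallel else []
        return f"{self.name} ran with hooks={hooks}"


if __name__ == "__main__":
    tasks = [DemoTask(name=f"node-{i}", parallel=True) for i in range(3)]
    with ProcessPoolExecutor(max_workers=2) as pool:
        futures = [pool.submit(task) for task in tasks]  # DemoTask is callable
        for future in futures:
            print(future.result())
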
201 changes: 24 additions & 177 deletions kedro/runner/runner.py
@@ -5,25 +5,17 @@
from __future__ import annotations

import inspect
import itertools as it
import logging
import warnings
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterator
from concurrent.futures import (
ALL_COMPLETED,
Future,
ThreadPoolExecutor,
as_completed,
wait,
)
from typing import TYPE_CHECKING, Any

from more_itertools import interleave

from kedro import KedroDeprecationWarning
from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset
from kedro.pipeline import Pipeline
from kedro.runner.task import Task

if TYPE_CHECKING:
from collections.abc import Collection, Iterable
@@ -229,6 +221,19 @@ def _suggest_resume_scenario(
f"argument to your previous command:\n{postfix}"
)

@staticmethod
def _release_datasets(
node: Node, catalog: CatalogProtocol, load_counts: dict, pipeline: Pipeline
) -> None:
"""Decrement dataset load counts and release any datasets we've finished with"""
for dataset in node.inputs:
load_counts[dataset] -= 1
if load_counts[dataset] < 1 and dataset not in pipeline.inputs():
catalog.release(dataset)
for dataset in node.outputs:
if load_counts[dataset] < 1 and dataset not in pipeline.outputs():
catalog.release(dataset)


def _find_nodes_to_resume_from(
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol
@@ -410,6 +415,11 @@ def run_node(
The node argument.
"""
warnings.warn(
"`run_node()` has been deprecated and will be removed in Kedro 0.20.0",
KedroDeprecationWarning,
)

if is_async and inspect.isgeneratorfunction(node.func):
raise ValueError(
f"Async data loading and saving does not work with "
@@ -418,175 +428,12 @@
f"in node {node!s}."
)

if is_async:
node = _run_node_async(node, catalog, hook_manager, session_id)
else:
node = _run_node_sequential(node, catalog, hook_manager, session_id)

for name in node.confirms:
catalog.confirm(name)
return node


def _collect_inputs_from_hook( # noqa: PLR0913
node: Node,
catalog: CatalogProtocol,
inputs: dict[str, Any],
is_async: bool,
hook_manager: PluginManager,
session_id: str | None = None,
) -> dict[str, Any]:
inputs = inputs.copy() # shallow copy to prevent in-place modification by the hook
hook_response = hook_manager.hook.before_node_run(
node=node,
catalog=catalog,
inputs=inputs,
is_async=is_async,
session_id=session_id,
)

additional_inputs = {}
if (
hook_response is not None
): # all hooks on a _NullPluginManager will return None instead of a list
for response in hook_response:
if response is not None and not isinstance(response, dict):
response_type = type(response).__name__
raise TypeError(
f"'before_node_run' must return either None or a dictionary mapping "
f"dataset names to updated values, got '{response_type}' instead."
)
additional_inputs.update(response or {})

return additional_inputs


def _call_node_run( # noqa: PLR0913
node: Node,
catalog: CatalogProtocol,
inputs: dict[str, Any],
is_async: bool,
hook_manager: PluginManager,
session_id: str | None = None,
) -> dict[str, Any]:
try:
outputs = node.run(inputs)
except Exception as exc:
hook_manager.hook.on_node_error(
error=exc,
node=node,
catalog=catalog,
inputs=inputs,
is_async=is_async,
session_id=session_id,
)
raise exc
hook_manager.hook.after_node_run(
node=node,
catalog=catalog,
inputs=inputs,
outputs=outputs,
is_async=is_async,
session_id=session_id,
)
return outputs
task = Task(
node=node,
catalog=catalog,
hook_manager=hook_manager,
is_async=is_async,
session_id=session_id,
)


def _run_node_sequential(
node: Node,
catalog: CatalogProtocol,
hook_manager: PluginManager,
session_id: str | None = None,
) -> Node:
inputs = {}

for name in node.inputs:
hook_manager.hook.before_dataset_loaded(dataset_name=name, node=node)
inputs[name] = catalog.load(name)
hook_manager.hook.after_dataset_loaded(
dataset_name=name, data=inputs[name], node=node
)

is_async = False

additional_inputs = _collect_inputs_from_hook(
node, catalog, inputs, is_async, hook_manager, session_id=session_id
)
inputs.update(additional_inputs)

outputs = _call_node_run(
node, catalog, inputs, is_async, hook_manager, session_id=session_id
)

items: Iterable = outputs.items()
# if all outputs are iterators, then the node is a generator node
if all(isinstance(d, Iterator) for d in outputs.values()):
# Python dictionaries are ordered, so we are sure
# the keys and the chunk streams are in the same order
# [a, b, c]
keys = list(outputs.keys())
# [Iterator[chunk_a], Iterator[chunk_b], Iterator[chunk_c]]
streams = list(outputs.values())
# zip an endless cycle of the keys
# with an interleaved iterator of the streams
# [(a, chunk_a), (b, chunk_b), ...] until all outputs complete
items = zip(it.cycle(keys), interleave(*streams))

for name, data in items:
hook_manager.hook.before_dataset_saved(dataset_name=name, data=data, node=node)
catalog.save(name, data)
hook_manager.hook.after_dataset_saved(dataset_name=name, data=data, node=node)
return node


def _run_node_async(
node: Node,
catalog: CatalogProtocol,
hook_manager: PluginManager,
session_id: str | None = None,
) -> Node:
def _synchronous_dataset_load(dataset_name: str) -> Any:
"""Minimal wrapper to ensure Hooks are run synchronously
within an asynchronous dataset load."""
hook_manager.hook.before_dataset_loaded(dataset_name=dataset_name, node=node)
return_ds = catalog.load(dataset_name)
hook_manager.hook.after_dataset_loaded(
dataset_name=dataset_name, data=return_ds, node=node
)
return return_ds

with ThreadPoolExecutor() as pool:
inputs: dict[str, Future] = {}

for name in node.inputs:
inputs[name] = pool.submit(_synchronous_dataset_load, name)

wait(inputs.values(), return_when=ALL_COMPLETED)
inputs = {key: value.result() for key, value in inputs.items()}
is_async = True
additional_inputs = _collect_inputs_from_hook(
node, catalog, inputs, is_async, hook_manager, session_id=session_id
)
inputs.update(additional_inputs)

outputs = _call_node_run(
node, catalog, inputs, is_async, hook_manager, session_id=session_id
)

future_dataset_mapping = {}
for name, data in outputs.items():
hook_manager.hook.before_dataset_saved(
dataset_name=name, data=data, node=node
)
future = pool.submit(catalog.save, name, data)
future_dataset_mapping[future] = (name, data)

for future in as_completed(future_dataset_mapping):
exception = future.exception()
if exception:
raise exception
name, data = future_dataset_mapping[future]
hook_manager.hook.after_dataset_saved(
dataset_name=name, data=data, node=node
)
node = task.execute()
return node
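
One piece of the removed _run_node_sequential worth calling out is the generator-node handling: when every output is an iterator, the output names are cycled against an interleaved stream of chunks so that each chunk can be saved as soon as it is produced. Per the commit message, this logic now sits inside Task (not shown in these hunks). The self-contained demo below, with made-up dataset names and chunk values, shows what that pairing produces.

# Demo of the key/chunk pairing used for generator nodes in the removed
# _run_node_sequential. The output names and chunk values are made up.
import itertools as it

from more_itertools import interleave

outputs = {
    "a": iter([1, 2, 3]),        # Iterator[chunk_a]
    "b": iter(["x", "y", "z"]),  # Iterator[chunk_b]
}

keys = list(outputs.keys())      # ["a", "b"]
streams = list(outputs.values())

# Cycle the dataset names against an interleaved stream of chunks, so every
# chunk is paired with the dataset it should be saved to.
items = zip(it.cycle(keys), interleave(*streams))
print(list(items))
# [('a', 1), ('b', 'x'), ('a', 2), ('b', 'y'), ('a', 3), ('b', 'z')]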