Add method to remove runtime patterns after run (#4236)
* Added method to remove runtime patterns

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Added test for remove_runtime_pattern

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Fixed type matching

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Implemented alternative solution

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Moved catalog validation before it is extended with runtime patterns

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Removed debug output

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Added test to call run twice

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

---------

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>
ElenaKhaustova authored Oct 21, 2024
1 parent 3fe61a0 commit 3818a2a
Showing 2 changed files with 33 additions and 17 deletions.
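
At a glance: run() previously classified memory datasets and computed its free outputs before the catalog was extended with the runner's runtime dataset patterns, and never re-resolved registered datasets afterwards. This commit reorders the method so that input validation runs against the un-extended catalog, dataset resolution runs against the extended copy, and memory datasets are classified only after execution. A schematic reading of the diff below (an interpretation, not verbatim code):

# before: validate inputs -> classify MemoryDatasets -> compute free_outputs
#         -> extend catalog with runtime patterns -> execute -> return
# after:  validate inputs (un-extended catalog) -> extend catalog
#         -> re-resolve registered datasets -> execute
#         -> classify Memory/SharedMemoryDatasets -> compute free_outputs
#         -> return run_output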
39 changes: 22 additions & 17 deletions kedro/runner/runner.py
@@ -22,7 +22,7 @@
 from more_itertools import interleave
 
 from kedro.framework.hooks.manager import _NullPluginManager
-from kedro.io import CatalogProtocol, MemoryDataset
+from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset
 from kedro.pipeline import Pipeline
 
 if TYPE_CHECKING:
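
The added SharedMemoryDataset import feeds the widened memory-dataset check later in this diff. A minimal sketch of that check; the explicit two-part isinstance below suggests that SharedMemoryDataset (the multiprocessing-backed dataset used by ParallelRunner) is not a MemoryDataset subclass, which is an inference rather than something the diff states:

from kedro.io import MemoryDataset, SharedMemoryDataset

ds = MemoryDataset(data=42)

# Both in-memory flavours count as "memory" when deciding which outputs
# run() should load and return; the tuple form is equivalent to the
# two-way isinstance check used in the diff.
assert isinstance(ds, (MemoryDataset, SharedMemoryDataset))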
@@ -84,11 +84,8 @@ def run(
             by the node outputs.
         """
 
-        hook_or_null_manager = hook_manager or _NullPluginManager()
-
         # Check which datasets used in the pipeline are in the catalog or match
-        # a pattern in the catalog
+        # a pattern in the catalog, not including extra dataset patterns
         registered_ds = [ds for ds in pipeline.datasets() if ds in catalog]
 
         # Check if there are any input datasets that aren't in the catalog and
@@ -100,22 +97,17 @@
                 f"Pipeline input(s) {unsatisfied} not found in the {catalog.__class__.__name__}"
             )
 
-        # Identify MemoryDataset in the catalog
-        memory_datasets = {
-            ds_name
-            for ds_name, ds in catalog._datasets.items()
-            if isinstance(ds, MemoryDataset)
-        }
-
-        # Check if there's any output datasets that aren't in the catalog and don't match a pattern
-        # in the catalog and include MemoryDataset.
-        free_outputs = pipeline.outputs() - (set(registered_ds) - memory_datasets)
-
         # Register the default dataset pattern with the catalog
         catalog = catalog.shallow_copy(
             extra_dataset_patterns=self._extra_dataset_patterns
         )
 
+        hook_or_null_manager = hook_manager or _NullPluginManager()
+
+        # Check which datasets used in the pipeline are in the catalog or match
+        # a pattern in the catalog, including added extra_dataset_patterns
+        registered_ds = [ds for ds in pipeline.datasets() if ds in catalog]
+
         if self._is_async:
             self._logger.info(
                 "Asynchronous mode is enabled for loading and saving data"
@@ -124,7 +116,20 @@
 
         self._logger.info("Pipeline execution completed successfully.")
 
-        return {ds_name: catalog.load(ds_name) for ds_name in free_outputs}
+        # Identify MemoryDataset in the catalog
+        memory_datasets = {
+            ds_name
+            for ds_name, ds in catalog._datasets.items()
+            if isinstance(ds, MemoryDataset) or isinstance(ds, SharedMemoryDataset)
+        }
+
+        # Check if there's any output datasets that aren't in the catalog and don't match a pattern
+        # in the catalog and include MemoryDataset.
+        free_outputs = pipeline.outputs() - (set(registered_ds) - memory_datasets)
+
+        run_output = {ds_name: catalog.load(ds_name) for ds_name in free_outputs}
+
+        return run_output
 
     def run_only_missing(
         self, pipeline: Pipeline, catalog: CatalogProtocol, hook_manager: PluginManager
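
The free-outputs computation above is plain set arithmetic: anything the pipeline produces that is unregistered, or registered only in memory, gets loaded and returned. A worked example with illustrative names:

# Illustrative names only; the formula is the one from the diff above.
pipeline_outputs = {"model", "metrics"}
registered_ds = {"model", "metrics"}  # both resolve via the runtime pattern
memory_datasets = {"metrics"}         # materialised as MemoryDataset during the run

# Subtracting memory datasets from the registered set keeps them in
# free_outputs, so their values are returned to the caller instead of
# vanishing with the in-memory catalog copy.
free_outputs = pipeline_outputs - (registered_ds - memory_datasets)
assert free_outputs == {"metrics"}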
11 changes: 11 additions & 0 deletions tests/runner/test_sequential_runner.py
@@ -40,6 +40,17 @@ def test_log_not_using_async(self, fan_out_fan_in, catalog, caplog):
         SequentialRunner().run(fan_out_fan_in, catalog)
         assert "Using synchronous mode for loading and saving data." in caplog.text
 
+    def test_run_twice_giving_same_result(self, fan_out_fan_in, catalog):
+        catalog.add_feed_dict({"A": 42})
+        result_first_run = SequentialRunner().run(
+            fan_out_fan_in, catalog, hook_manager=_create_hook_manager()
+        )
+        result_second_run = SequentialRunner().run(
+            fan_out_fan_in, catalog, hook_manager=_create_hook_manager()
+        )
+
+        assert result_first_run == result_second_run
+
 
 @pytest.mark.parametrize("is_async", [False, True])
 class TestSeqentialRunnerBranchlessPipeline:
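
The new test pins down the regression: two identical runs against the same catalog must return the same result. A self-contained sketch of the same idea without the test fixtures, assuming the standard Kedro 0.19 API in which hook_manager is optional and falls back to _NullPluginManager, as the runner diff above shows:

from kedro.io import DataCatalog
from kedro.pipeline import node, pipeline
from kedro.runner import SequentialRunner

def double(x: int) -> int:
    return x * 2

catalog = DataCatalog()
catalog.add_feed_dict({"A": 21})

p = pipeline([node(double, inputs="A", outputs="B", name="double_a")])

# "B" is unregistered, so it is a free output: each run materialises it
# as a MemoryDataset via the runtime pattern and returns it.
first = SequentialRunner().run(p, catalog)
second = SequentialRunner().run(p, catalog)
assert first == second == {"B": 42}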
