Skip to content

Commit

Permalink
Fixe using ThreadRunner with dataset factories (#4093)
Browse files Browse the repository at this point in the history
* Added lock for dataset load

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Replaced multiprocessing lock with multithreading

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Added an alternative solution

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Updated solution 2

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Made solution 1 main

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Removed solution 2

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Updated release notes

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Removed lock solution

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Returned solution with patterns resolving

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Added test for ThreadRunner

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Added _match_pattern patching

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Fixed tests to satisfy required coverage

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

* Reverted heading for the release notes

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>

---------

Signed-off-by: Elena Khaustova <ymax70rus@gmail.com>
  • Loading branch information
ElenaKhaustova authored Aug 29, 2024
1 parent f738dc8 commit f6319dd
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 1 deletion.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
## Major features and improvements
* Enhanced `OmegaConfigLoader` configuration validation to detect duplicate keys at all parameter levels, ensuring comprehensive nested key checking.
## Bug fixes and other changes
* Fixed bug where using dataset factories breaks with `ThreadRunner`.

## Breaking changes to the API

## Documentation changes
Expand Down
8 changes: 7 additions & 1 deletion kedro/framework/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
validate_settings,
)
from kedro.io.core import generate_timestamp
from kedro.runner import AbstractRunner, SequentialRunner
from kedro.runner import AbstractRunner, SequentialRunner, ThreadRunner
from kedro.utils import _find_kedro_project

if TYPE_CHECKING:
Expand Down Expand Up @@ -395,6 +395,12 @@ def run( # noqa: PLR0913
)

try:
if isinstance(runner, ThreadRunner):
for ds in filtered_pipeline.datasets():
if catalog._match_pattern(
catalog._dataset_patterns, ds
) or catalog._match_pattern(catalog._default_pattern, ds):
_ = catalog._get_dataset(ds)
run_result = runner.run(
filtered_pipeline, catalog, hook_manager, session_id
)
Expand Down
78 changes: 78 additions & 0 deletions tests/framework/session/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def mock_runner(mocker):
return mock_runner


@pytest.fixture
def mock_thread_runner(mocker):
mock_runner = mocker.patch(
"kedro.runner.thread_runner.ThreadRunner",
autospec=True,
)
mock_runner.__name__ = "MockThreadRunner`"
return mock_runner


@pytest.fixture
def mock_context_class(mocker):
mock_cls = create_attrs_autospec(KedroContext)
Expand Down Expand Up @@ -693,6 +703,74 @@ def test_run(
catalog=mock_catalog,
)

@pytest.mark.usefixtures("mock_settings_context_class")
@pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
@pytest.mark.parametrize("match_pattern", [True, False])
def test_run_thread_runner(
self,
fake_project,
fake_session_id,
fake_pipeline_name,
mock_context_class,
mock_thread_runner,
mocker,
match_pattern,
):
"""Test running the project via the session"""

mock_hook = mocker.patch(
"kedro.framework.session.session._create_hook_manager"
).return_value.hook

ds_mock = mocker.Mock(**{"datasets.return_value": ["ds_1", "ds_2"]})
filter_mock = mocker.Mock(**{"filter.return_value": ds_mock})
pipelines_ret = {
_FAKE_PIPELINE_NAME: filter_mock,
"__default__": filter_mock,
}
mocker.patch("kedro.framework.session.session.pipelines", pipelines_ret)
mocker.patch(
"kedro.io.data_catalog.DataCatalog._match_pattern",
return_value=match_pattern,
)

with KedroSession.create(fake_project) as session:
session.run(runner=mock_thread_runner, pipeline_name=fake_pipeline_name)

mock_context = mock_context_class.return_value
record_data = {
"session_id": fake_session_id,
"project_path": fake_project.as_posix(),
"env": mock_context.env,
"kedro_version": kedro_version,
"tags": None,
"from_nodes": None,
"to_nodes": None,
"node_names": None,
"from_inputs": None,
"to_outputs": None,
"load_versions": None,
"extra_params": {},
"pipeline_name": fake_pipeline_name,
"namespace": None,
"runner": mock_thread_runner.__name__,
}
mock_catalog = mock_context._get_catalog.return_value
mock_pipeline = filter_mock.filter()

mock_hook.before_pipeline_run.assert_called_once_with(
run_params=record_data, pipeline=mock_pipeline, catalog=mock_catalog
)
mock_thread_runner.run.assert_called_once_with(
mock_pipeline, mock_catalog, session._hook_manager, fake_session_id
)
mock_hook.after_pipeline_run.assert_called_once_with(
run_params=record_data,
run_result=mock_thread_runner.run.return_value,
pipeline=mock_pipeline,
catalog=mock_catalog,
)

@pytest.mark.usefixtures("mock_settings_context_class")
@pytest.mark.parametrize("fake_pipeline_name", [None, _FAKE_PIPELINE_NAME])
def test_run_multiple_times(
Expand Down

0 comments on commit f6319dd

Please sign in to comment.