From e949c6c7d8814ba5fc97c7204c69103a3e9ad21a Mon Sep 17 00:00:00 2001 From: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> Date: Thu, 7 Sep 2023 17:52:53 +0100 Subject: [PATCH] Make dataset factory resolve nested dict properly (#2993) --- RELEASE.md | 2 ++ kedro/framework/cli/catalog.py | 18 +++++++++++++-- kedro/io/data_catalog.py | 41 +++++++++++++++++++++------------- tests/io/test_data_catalog.py | 38 ++++++++++++++++++++++++++++++- 4 files changed, 80 insertions(+), 19 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 8d5cb62415..6f18db1fe1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,8 @@ ## Major features and improvements ## Bug fixes and other changes +* Updated dataset factories to resolve nested catalog config properly. + ## Documentation changes ## Breaking changes to the API ## Upcoming deprecations for Kedro 0.19.0 diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py index b8849b5843..24816a9492 100644 --- a/kedro/framework/cli/catalog.py +++ b/kedro/framework/cli/catalog.py @@ -1,4 +1,5 @@ """A collection of CLI commands for working with Kedro catalog.""" +import copy from collections import defaultdict from itertools import chain @@ -84,7 +85,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env): data_catalog._dataset_patterns, ds_name ) if matched_pattern: - ds_config = data_catalog._resolve_config(ds_name, matched_pattern) + ds_config_copy = copy.deepcopy( + data_catalog._dataset_patterns[matched_pattern] + ) + + ds_config = data_catalog._resolve_config( + ds_name, matched_pattern, ds_config_copy + ) factory_ds_by_type[ds_config["type"]].append(ds_name) default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values())) @@ -244,7 +251,14 @@ def resolve_patterns(metadata: ProjectMetadata, env): data_catalog._dataset_patterns, ds_name ) if matched_pattern: - ds_config = data_catalog._resolve_config(ds_name, matched_pattern) + ds_config_copy = copy.deepcopy( + data_catalog._dataset_patterns[matched_pattern] + ) + + ds_config = data_catalog._resolve_config( + ds_name, matched_pattern, ds_config_copy + ) + ds_config["filepath"] = _trim_filepath( str(context.project_path) + "/", ds_config["filepath"] ) diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index 4cbe6c0142..031abb5b51 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -11,7 +11,7 @@ import logging import re from collections import defaultdict -from typing import Any, Dict, Iterable +from typing import Any, Dict from parse import parse @@ -388,7 +388,10 @@ def _get_dataset( if data_set_name not in self._data_sets and matched_pattern: # If the dataset is a patterned dataset, materialise it and add it to # the catalog - data_set_config = self._resolve_config(data_set_name, matched_pattern) + config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern]) + data_set_config = self._resolve_config( + data_set_name, matched_pattern, config_copy + ) ds_layer = data_set_config.pop("layer", None) if ds_layer: self.layers = self.layers or {} @@ -436,27 +439,33 @@ def __contains__(self, data_set_name): return True return False + @classmethod def _resolve_config( - self, + cls, data_set_name: str, matched_pattern: str, + config: dict, ) -> dict[str, Any]: """Get resolved AbstractDataset from a factory config""" result = parse(matched_pattern, data_set_name) - config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern]) # Resolve the factory config for the dataset - for key, value in config_copy.items(): - if isinstance(value, Iterable) and "}" in value: - # result.named: gives access to all dict items in the match result. - # format_map fills in dict values into a string with {...} placeholders - # of the same key name. - try: - config_copy[key] = str(value).format_map(result.named) - except KeyError as exc: - raise DatasetError( - f"Unable to resolve '{key}' for the pattern '{matched_pattern}'" - ) from exc - return config_copy + if isinstance(config, dict): + for key, value in config.items(): + config[key] = cls._resolve_config(data_set_name, matched_pattern, value) + elif isinstance(config, (list, tuple)): + config = [ + cls._resolve_config(data_set_name, matched_pattern, value) + for value in config + ] + elif isinstance(config, str) and "}" in config: + try: + config = str(config).format_map(result.named) + except KeyError as exc: + raise DatasetError( + f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the configuration " + f"should be present in the dataset factory pattern." + ) from exc + return config def load(self, name: str, version: str = None) -> Any: """Loads a registered data set. diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py index 9273fa5200..f4ac13974f 100644 --- a/tests/io/test_data_catalog.py +++ b/tests/io/test_data_catalog.py @@ -107,6 +107,29 @@ def config_with_dataset_factories(): } +@pytest.fixture +def config_with_dataset_factories_nested(): + return { + "catalog": { + "{brand}_cars": { + "type": "PartitionedDataset", + "path": "data/01_raw", + "dataset": "pandas.CSVDataSet", + "metadata": { + "my-plugin": { + "brand": "{brand}", + "list_config": [ + "NA", + "{brand}", + ], + "nested_list_dict": [{}, {"brand": "{brand}"}], + } + }, + }, + }, + } + + @pytest.fixture def config_with_dataset_factories_with_default(config_with_dataset_factories): config_with_dataset_factories["catalog"]["{default_dataset}"] = { @@ -840,7 +863,10 @@ def test_unmatched_key_error_when_parsing_config( ): """Check error raised when key mentioned in the config is not in pattern name""" catalog = DataCatalog.from_config(**config_with_dataset_factories_bad_pattern) - pattern = "Unable to resolve 'filepath' for the pattern '{type}@planes'" + pattern = ( + "Unable to resolve 'data/01_raw/{brand}_plane.pq' from the pattern '{type}@planes'. " + "Keys used in the configuration should be present in the dataset factory pattern." + ) with pytest.raises(DatasetError, match=re.escape(pattern)): catalog._get_dataset("jet@planes") @@ -896,3 +922,13 @@ def test_factory_config_versioned( microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None ) assert actual_timestamp == expected_timestamp + + def test_factory_nested_config(self, config_with_dataset_factories_nested): + catalog = DataCatalog.from_config(**config_with_dataset_factories_nested) + dataset = catalog._get_dataset("tesla_cars") + assert dataset.metadata["my-plugin"]["brand"] == "tesla" + assert dataset.metadata["my-plugin"]["list_config"] == ["NA", "tesla"] + assert dataset.metadata["my-plugin"]["nested_list_dict"] == [ + {}, + {"brand": "tesla"}, + ]