Skip to content

Commit

Permalink
Make dataset factory resolve nested dict properly (#2993)
Browse files Browse the repository at this point in the history
  • Loading branch information
ankatiyar authored Sep 7, 2023
1 parent ad002bd commit e949c6c
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 19 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

## Major features and improvements
## Bug fixes and other changes
* Updated dataset factories to resolve nested catalog config properly.

## Documentation changes
## Breaking changes to the API
## Upcoming deprecations for Kedro 0.19.0
Expand Down
18 changes: 16 additions & 2 deletions kedro/framework/cli/catalog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A collection of CLI commands for working with Kedro catalog."""
import copy
from collections import defaultdict
from itertools import chain

Expand Down Expand Up @@ -84,7 +85,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline, env):
data_catalog._dataset_patterns, ds_name
)
if matched_pattern:
ds_config = data_catalog._resolve_config(ds_name, matched_pattern)
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns[matched_pattern]
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)
factory_ds_by_type[ds_config["type"]].append(ds_name)

default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values()))
Expand Down Expand Up @@ -244,7 +251,14 @@ def resolve_patterns(metadata: ProjectMetadata, env):
data_catalog._dataset_patterns, ds_name
)
if matched_pattern:
ds_config = data_catalog._resolve_config(ds_name, matched_pattern)
ds_config_copy = copy.deepcopy(
data_catalog._dataset_patterns[matched_pattern]
)

ds_config = data_catalog._resolve_config(
ds_name, matched_pattern, ds_config_copy
)

ds_config["filepath"] = _trim_filepath(
str(context.project_path) + "/", ds_config["filepath"]
)
Expand Down
41 changes: 25 additions & 16 deletions kedro/io/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import re
from collections import defaultdict
from typing import Any, Dict, Iterable
from typing import Any, Dict

from parse import parse

Expand Down Expand Up @@ -388,7 +388,10 @@ def _get_dataset(
if data_set_name not in self._data_sets and matched_pattern:
# If the dataset is a patterned dataset, materialise it and add it to
# the catalog
data_set_config = self._resolve_config(data_set_name, matched_pattern)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
data_set_config = self._resolve_config(
data_set_name, matched_pattern, config_copy
)
ds_layer = data_set_config.pop("layer", None)
if ds_layer:
self.layers = self.layers or {}
Expand Down Expand Up @@ -436,27 +439,33 @@ def __contains__(self, data_set_name):
return True
return False

@classmethod
def _resolve_config(
self,
cls,
data_set_name: str,
matched_pattern: str,
config: dict,
) -> dict[str, Any]:
"""Get resolved AbstractDataset from a factory config"""
result = parse(matched_pattern, data_set_name)
config_copy = copy.deepcopy(self._dataset_patterns[matched_pattern])
# Resolve the factory config for the dataset
for key, value in config_copy.items():
if isinstance(value, Iterable) and "}" in value:
# result.named: gives access to all dict items in the match result.
# format_map fills in dict values into a string with {...} placeholders
# of the same key name.
try:
config_copy[key] = str(value).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{key}' for the pattern '{matched_pattern}'"
) from exc
return config_copy
if isinstance(config, dict):
for key, value in config.items():
config[key] = cls._resolve_config(data_set_name, matched_pattern, value)
elif isinstance(config, (list, tuple)):
config = [
cls._resolve_config(data_set_name, matched_pattern, value)
for value in config
]
elif isinstance(config, str) and "}" in config:
try:
config = str(config).format_map(result.named)
except KeyError as exc:
raise DatasetError(
f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the configuration "
f"should be present in the dataset factory pattern."
) from exc
return config

def load(self, name: str, version: str = None) -> Any:
"""Loads a registered data set.
Expand Down
38 changes: 37 additions & 1 deletion tests/io/test_data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ def config_with_dataset_factories():
}


@pytest.fixture
def config_with_dataset_factories_nested():
return {
"catalog": {
"{brand}_cars": {
"type": "PartitionedDataset",
"path": "data/01_raw",
"dataset": "pandas.CSVDataSet",
"metadata": {
"my-plugin": {
"brand": "{brand}",
"list_config": [
"NA",
"{brand}",
],
"nested_list_dict": [{}, {"brand": "{brand}"}],
}
},
},
},
}


@pytest.fixture
def config_with_dataset_factories_with_default(config_with_dataset_factories):
config_with_dataset_factories["catalog"]["{default_dataset}"] = {
Expand Down Expand Up @@ -840,7 +863,10 @@ def test_unmatched_key_error_when_parsing_config(
):
"""Check error raised when key mentioned in the config is not in pattern name"""
catalog = DataCatalog.from_config(**config_with_dataset_factories_bad_pattern)
pattern = "Unable to resolve 'filepath' for the pattern '{type}@planes'"
pattern = (
"Unable to resolve 'data/01_raw/{brand}_plane.pq' from the pattern '{type}@planes'. "
"Keys used in the configuration should be present in the dataset factory pattern."
)
with pytest.raises(DatasetError, match=re.escape(pattern)):
catalog._get_dataset("jet@planes")

Expand Down Expand Up @@ -896,3 +922,13 @@ def test_factory_config_versioned(
microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None
)
assert actual_timestamp == expected_timestamp

def test_factory_nested_config(self, config_with_dataset_factories_nested):
catalog = DataCatalog.from_config(**config_with_dataset_factories_nested)
dataset = catalog._get_dataset("tesla_cars")
assert dataset.metadata["my-plugin"]["brand"] == "tesla"
assert dataset.metadata["my-plugin"]["list_config"] == ["NA", "tesla"]
assert dataset.metadata["my-plugin"]["nested_list_dict"] == [
{},
{"brand": "tesla"},
]

0 comments on commit e949c6c

Please sign in to comment.