Skip to content

Commit

Permalink
Add type alias for dataset factory patterns (#2779)
Browse files (browse the repository at this point in the history)
* Add type alias for patterns and refactoring

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

* Update kedro/io/data_catalog.py

Co-authored-by: Ivan Danov <idanov@users.noreply.github.com>

* Also add type alias to init fn

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>

---------

Signed-off-by: Ankita Katiyar <ankitakatiyar2401@gmail.com>
Co-authored-by: Ivan Danov <idanov@users.noreply.github.com>
  • Loading branch information
ankatiyar and idanov committed Jul 7, 2023
1 parent e6671e8 commit 26057fc
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions kedro/io/data_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logging
import re
from collections import defaultdict
from typing import Any, Iterable
from typing import Any, Dict, Iterable

from parse import parse

Expand All @@ -26,6 +26,8 @@
)
from kedro.io.memory_dataset import MemoryDataset

Patterns = Dict[str, Dict[str, Any]]

CATALOG_KEY = "catalog"
CREDENTIALS_KEY = "credentials"
WORDS_REGEX_PATTERN = re.compile(r"\W+")
Expand Down Expand Up @@ -143,7 +145,7 @@ def __init__( # pylint: disable=too-many-arguments
data_sets: dict[str, AbstractDataSet] = None,
feed_dict: dict[str, Any] = None,
layers: dict[str, set[str]] = None,
dataset_patterns: dict[str, dict[str, Any]] = None,
dataset_patterns: Patterns = None,
load_versions: dict[str, str] = None,
save_version: str = None,
) -> None:
Expand Down Expand Up @@ -300,7 +302,7 @@ class to be loaded is specified with the key ``type`` and their
missing_keys = [
key
for key in load_versions.keys()
if not (cls._match_pattern(sorted_patterns, key) or key in catalog)
if not (key in catalog or cls._match_pattern(sorted_patterns, key))
]
if missing_keys:
raise DatasetNotFoundError(
Expand All @@ -322,20 +324,17 @@ def _is_pattern(pattern: str):
return "{" in pattern

@staticmethod
def _match_pattern(
data_set_patterns: dict[str, dict[str, Any]], data_set_name: str
) -> str | None:
def _match_pattern(data_set_patterns: Patterns, data_set_name: str) -> str | None:
"""Match a dataset name against patterns in a dictionary containing patterns"""
for pattern, _ in data_set_patterns.items():
result = parse(pattern, data_set_name)
if result:
return pattern
return None
matches = (
pattern
for pattern in data_set_patterns.keys()
if parse(pattern, data_set_name)
)
return next(matches, None)

@classmethod
def _sort_patterns(
cls, data_set_patterns: dict[str, dict[str, Any]]
) -> dict[str, dict[str, Any]]:
def _sort_patterns(cls, data_set_patterns: Patterns) -> dict[str, dict[str, Any]]:
"""Sort a dictionary of dataset patterns according to parsing rules -
1. Decreasing specificity (number of characters outside the curly brackets)
2. Decreasing number of placeholders (number of curly bracket pairs)
Expand All @@ -349,10 +348,7 @@ def _sort_patterns(
pattern,
),
)
sorted_patterns = {}
for key in sorted_keys:
sorted_patterns[key] = data_set_patterns[key]
return sorted_patterns
return {key: data_set_patterns[key] for key in sorted_keys}

@staticmethod
def _specificity(pattern: str) -> int:
Expand Down

0 comments on commit 26057fc

Please sign in to comment.