From 26057fc017f0025044249b054b40baa12796aacc Mon Sep 17 00:00:00 2001 From: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> Date: Fri, 7 Jul 2023 16:34:45 +0100 Subject: [PATCH] Add type alias for dataset factory patterns (#2779) * Add type alias for patterns and refactoring Signed-off-by: Ankita Katiyar * Update kedro/io/data_catalog.py Co-authored-by: Ivan Danov * Also add type alias to init fn Signed-off-by: Ankita Katiyar --------- Signed-off-by: Ankita Katiyar Co-authored-by: Ivan Danov --- kedro/io/data_catalog.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index 4143024a73..425e491c0b 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -11,7 +11,7 @@ import logging import re from collections import defaultdict -from typing import Any, Iterable +from typing import Any, Dict, Iterable from parse import parse @@ -26,6 +26,8 @@ ) from kedro.io.memory_dataset import MemoryDataset +Patterns = Dict[str, Dict[str, Any]] + CATALOG_KEY = "catalog" CREDENTIALS_KEY = "credentials" WORDS_REGEX_PATTERN = re.compile(r"\W+") @@ -143,7 +145,7 @@ def __init__( # pylint: disable=too-many-arguments data_sets: dict[str, AbstractDataSet] = None, feed_dict: dict[str, Any] = None, layers: dict[str, set[str]] = None, - dataset_patterns: dict[str, dict[str, Any]] = None, + dataset_patterns: Patterns = None, load_versions: dict[str, str] = None, save_version: str = None, ) -> None: @@ -300,7 +302,7 @@ class to be loaded is specified with the key ``type`` and their missing_keys = [ key for key in load_versions.keys() - if not (cls._match_pattern(sorted_patterns, key) or key in catalog) + if not (key in catalog or cls._match_pattern(sorted_patterns, key)) ] if missing_keys: raise DatasetNotFoundError( @@ -322,20 +324,17 @@ def _is_pattern(pattern: str): return "{" in pattern @staticmethod - def _match_pattern( - data_set_patterns: dict[str, dict[str, Any]], data_set_name: str - ) -> str | None: + def _match_pattern(data_set_patterns: Patterns, data_set_name: str) -> str | None: """Match a dataset name against patterns in a dictionary containing patterns""" - for pattern, _ in data_set_patterns.items(): - result = parse(pattern, data_set_name) - if result: - return pattern - return None + matches = ( + pattern + for pattern in data_set_patterns.keys() + if parse(pattern, data_set_name) + ) + return next(matches, None) @classmethod - def _sort_patterns( - cls, data_set_patterns: dict[str, dict[str, Any]] - ) -> dict[str, dict[str, Any]]: + def _sort_patterns(cls, data_set_patterns: Patterns) -> dict[str, dict[str, Any]]: """Sort a dictionary of dataset patterns according to parsing rules - 1. Decreasing specificity (number of characters outside the curly brackets) 2. Decreasing number of placeholders (number of curly bracket pairs) @@ -349,10 +348,7 @@ def _sort_patterns( pattern, ), ) - sorted_patterns = {} - for key in sorted_keys: - sorted_patterns[key] = data_set_patterns[key] - return sorted_patterns + return {key: data_set_patterns[key] for key in sorted_keys} @staticmethod def _specificity(pattern: str) -> int: