diff --git a/kedro/config/omegaconf_config.py b/kedro/config/omegaconf_config.py index 691ae30385..c4850159a1 100644 --- a/kedro/config/omegaconf_config.py +++ b/kedro/config/omegaconf_config.py @@ -9,6 +9,7 @@ import mimetypes import typing from collections.abc import KeysView +from enum import Enum, auto from pathlib import Path from typing import Any, Callable, Iterable @@ -26,6 +27,17 @@ _NO_VALUE = object() +class MergeStrategies(Enum): + SOFT = auto() + DESTRUCTIVE = auto() + + +MERGING_IMPLEMENTATIONS = { + MergeStrategies.SOFT: "_soft_merge", + MergeStrategies.DESTRUCTIVE: "_destructive_merge", +} + + class OmegaConfigLoader(AbstractConfigLoader): """Recursively scan directories (config paths) contained in ``conf_source`` for configuration files with a ``yaml``, ``yml`` or ``json`` extension, load and merge @@ -131,18 +143,9 @@ def __init__( # noqa: PLR0913 self._register_new_resolvers(custom_resolvers) # Register globals resolver self._register_globals_resolver() - file_mimetype, _ = mimetypes.guess_type(conf_source) - if file_mimetype == "application/x-tar": - self._protocol = "tar" - elif file_mimetype in ( - "application/zip", - "application/x-zip-compressed", - "application/zip-compressed", - ): - self._protocol = "zip" - else: - self._protocol = "file" - self._fs = fsspec.filesystem(protocol=self._protocol, fo=conf_source) + + # Setup file system and protocol + self._fs, self._protocol = self._initialise_filesystem_and_protocol(conf_source) super().__init__( conf_source=conf_source, @@ -220,6 +223,11 @@ def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912 # Load chosen env config run_env = self.env or self.default_run_env + + # Return if chosen env config is the same as base config to avoid loading the same config twice + if run_env == self.base_env: + return config # type: ignore[no-any-return] + if self._protocol == "file": env_path = str(Path(self.conf_source) / run_env) else: @@ -236,16 +244,7 @@ def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912 else: raise exc - merging_strategy = self.merge_strategy.get(key) - if merging_strategy == "soft": - resulting_config = self._soft_merge(config, env_config) - elif merging_strategy == "destructive" or not merging_strategy: - resulting_config = self._destructive_merge(config, env_config, env_path) - else: - raise ValueError( - f"Merging strategy {merging_strategy} not supported. The accepted merging " - f"strategies are `soft` and `destructive`." - ) + resulting_config = self._merge_configs(config, env_config, key, env_path) if not processed_files and key != "globals": raise MissingConfigException( @@ -355,6 +354,47 @@ def load_and_merge_dir_config( if not k.startswith("_") } + @staticmethod + def _initialise_filesystem_and_protocol( + conf_source: str, + ) -> tuple[fsspec.AbstractFileSystem, str]: + """Set up the file system based on the file type detected in conf_source.""" + file_mimetype, _ = mimetypes.guess_type(conf_source) + if file_mimetype == "application/x-tar": + protocol = "tar" + elif file_mimetype in ( + "application/zip", + "application/x-zip-compressed", + "application/zip-compressed", + ): + protocol = "zip" + else: + protocol = "file" + fs = fsspec.filesystem(protocol=protocol, fo=conf_source) + return fs, protocol + + def _merge_configs( + self, + config: dict[str, Any], + env_config: dict[str, Any], + key: str, + env_path: str, + ) -> Any: + merging_strategy = self.merge_strategy.get(key, "destructive") + try: + strategy = MergeStrategies[merging_strategy.upper()] + + # Get the corresponding merge function and call it + merge_function_name = MERGING_IMPLEMENTATIONS[strategy] + merge_function = getattr(self, merge_function_name) + return merge_function(config, env_config, env_path) + except KeyError: + allowed_strategies = [strategy.name.lower() for strategy in MergeStrategies] + raise ValueError( + f"Merging strategy {merging_strategy} not supported. The accepted merging " + f"strategies are {allowed_strategies}." + ) + def _get_all_keys(self, cfg: Any, parent_key: str = "") -> set[str]: keys: set[str] = set() @@ -499,7 +539,9 @@ def _destructive_merge( return config @staticmethod - def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any: + def _soft_merge( + config: dict[str, Any], env_config: dict[str, Any], env_path: str | None = None + ) -> Any: # Soft merge the two env dirs. The chosen env will override base if keys clash. return OmegaConf.to_container(OmegaConf.merge(config, env_config))