
Merge branch '4138-catalog-protocol' into 3995-data-catalog-2.0
ElenaKhaustova committed Sep 13, 2024
2 parents caa7316 + 0833a84 commit 9540a32
Showing 12 changed files with 169 additions and 69 deletions.
2 changes: 2 additions & 0 deletions docs/source/conf.py
@@ -130,6 +130,7 @@
         "kedro.io.catalog_config_resolver.CatalogConfigResolver",
         "kedro.io.core.AbstractDataset",
         "kedro.io.core.AbstractVersionedDataset",
+        "kedro.io.core.CatalogProtocol",
         "kedro.io.core.DatasetError",
         "kedro.io.core.Version",
         "kedro.io.data_catalog.DataCatalog",
@@ -171,6 +172,7 @@
         "None. Update D from mapping/iterable E and F.",
         "Patterns",
         "CatalogConfigResolver",
+        "CatalogProtocol",
     ),
     "py:data": (
         "typing.Any",
20 changes: 10 additions & 10 deletions kedro/framework/context/context.py
@@ -14,7 +14,7 @@

 from kedro.config import AbstractConfigLoader, MissingConfigException
 from kedro.framework.project import settings
-from kedro.io import BaseDataCatalog, DataCatalog  # noqa: TCH001
+from kedro.io import CatalogProtocol, DataCatalog  # noqa: TCH001
 from kedro.pipeline.transcoding import _transcode_split

 if TYPE_CHECKING:
@@ -123,7 +123,7 @@ def _convert_paths_to_absolute_posix(
     return conf_dictionary


-def _validate_transcoded_datasets(catalog: BaseDataCatalog) -> None:
+def _validate_transcoded_datasets(catalog: CatalogProtocol) -> None:
     """Validates transcoded datasets are correctly named

     Args:
@@ -178,13 +178,13 @@ class KedroContext:
         )

     @property
-    def catalog(self) -> BaseDataCatalog:
-        """Read-only property referring to Kedro's ``BaseDataCatalog`` for this context.
+    def catalog(self) -> CatalogProtocol:
+        """Read-only property referring to Kedro's catalog for this context.

         Returns:
-            DataCatalog defined in `catalog.yml`.
+            catalog defined in `catalog.yml`.

         Raises:
-            KedroContextError: Incorrect ``BaseDataCatalog`` registered for the project.
+            KedroContextError: Incorrect catalog registered for the project.

         """
         return self._get_catalog()
@@ -213,13 +213,13 @@ def _get_catalog(
         self,
         save_version: str | None = None,
         load_versions: dict[str, str] | None = None,
-    ) -> BaseDataCatalog:
-        """A hook for changing the creation of a BaseDataCatalog instance.
+    ) -> CatalogProtocol:
+        """A hook for changing the creation of a catalog instance.

         Returns:
-            DataCatalog defined in `catalog.yml`.
+            catalog defined in `catalog.yml`.

         Raises:
-            KedroContextError: Incorrect ``BaseDataCatalog`` registered for the project.
+            KedroContextError: Incorrect catalog registered for the project.

         """
         # '**/catalog*' reads modular pipeline configs
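With these changes, `KedroContext.catalog` is typed against the protocol rather than a concrete class, so callers receive whatever catalog implementation the project registered. A minimal sketch of reading it from a session (the project path is illustrative):

```python
# Minimal sketch: context.catalog may be DataCatalog or any other
# CatalogProtocol-compliant implementation the project configured.
from kedro.framework.session import KedroSession

with KedroSession.create(project_path=".") as session:  # illustrative path
    context = session.load_context()
    catalog = context.catalog
    print(catalog.list())  # names of datasets registered in catalog.yml
```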
19 changes: 8 additions & 11 deletions kedro/framework/hooks/specs.py
@@ -11,7 +11,7 @@

 if TYPE_CHECKING:
     from kedro.framework.context import KedroContext
-    from kedro.io import BaseDataCatalog
+    from kedro.io import CatalogProtocol
     from kedro.pipeline import Pipeline
     from kedro.pipeline.node import Node

@@ -22,7 +22,7 @@ class DataCatalogSpecs:
     @hook_spec
     def after_catalog_created(  # noqa: PLR0913
         self,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
         conf_catalog: dict[str, Any],
         conf_creds: dict[str, Any],
         feed_dict: dict[str, Any],
@@ -53,7 +53,7 @@ class NodeSpecs:
     def before_node_run(
         self,
         node: Node,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
         inputs: dict[str, Any],
         is_async: bool,
         session_id: str,
@@ -81,7 +81,7 @@ def before_node_run(
     def after_node_run(  # noqa: PLR0913
         self,
         node: Node,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
         inputs: dict[str, Any],
         outputs: dict[str, Any],
         is_async: bool,
@@ -110,7 +110,7 @@ def on_node_error(  # noqa: PLR0913
         self,
         error: Exception,
         node: Node,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
         inputs: dict[str, Any],
         is_async: bool,
         session_id: str,
@@ -137,10 +137,7 @@ class PipelineSpecs:

     @hook_spec
     def before_pipeline_run(
-        self,
-        run_params: dict[str, Any],
-        pipeline: Pipeline,
-        catalog: BaseDataCatalog,
+        self, run_params: dict[str, Any], pipeline: Pipeline, catalog: CatalogProtocol
     ) -> None:
         """Hook to be invoked before a pipeline runs.
@@ -177,7 +174,7 @@ def after_pipeline_run(
         run_params: dict[str, Any],
         run_result: dict[str, Any],
         pipeline: Pipeline,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
     ) -> None:
         """Hook to be invoked after a pipeline runs.
@@ -215,7 +212,7 @@ def on_pipeline_error(
         error: Exception,
         run_params: dict[str, Any],
         pipeline: Pipeline,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
     ) -> None:
         """Hook to be invoked if a pipeline run throws an uncaught Exception.

         The signature of this error hook should match the signature of ``before_pipeline_run``
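For plugin authors, the practical effect of these spec changes is that hook implementations can annotate `catalog` with the protocol instead of a concrete catalog class. A minimal sketch (the hooks class and its body are illustrative, not part of this commit):

```python
# Illustrative hook implementation typed against the new protocol.
# `after_catalog_created` is the hook spec shown above; pluggy lets an
# implementation accept a subset of the spec's arguments.
from kedro.framework.hooks import hook_impl
from kedro.io import CatalogProtocol


class CatalogLoggingHooks:
    @hook_impl
    def after_catalog_created(self, catalog: CatalogProtocol) -> None:
        # Works for any catalog implementation, not just DataCatalog.
        print(f"Catalog created with datasets: {catalog.list()}")
```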
25 changes: 23 additions & 2 deletions kedro/framework/project/__init__.py
@@ -20,6 +20,7 @@
 from dynaconf import LazySettings
 from dynaconf.validator import ValidationError, Validator

+from kedro.io import CatalogProtocol
 from kedro.pipeline import Pipeline, pipeline

 if TYPE_CHECKING:
@@ -68,6 +69,25 @@ def validate(
             )


+class _ImplementsCatalogProtocolValidator(Validator):
+    """A validator to check that the supplied setting value implements ``CatalogProtocol``."""
+
+    def validate(
+        self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any
+    ) -> None:
+        super().validate(settings, *args, **kwargs)
+
+        protocol = CatalogProtocol
+        for name in self.names:
+            setting_value = getattr(settings, name)
+            if not isinstance(setting_value(), protocol):
+                raise ValidationError(
+                    f"Invalid value '{setting_value.__module__}.{setting_value.__qualname__}' "
+                    f"received for setting '{name}'. It must implement "
+                    f"'{protocol.__module__}.{protocol.__qualname__}'."
+                )
+
+
 class _HasSharedParentClassValidator(Validator):
     """A validator to check that the parent of the default class is an ancestor of
     the settings value."""
@@ -124,8 +144,9 @@ class _ProjectSettings(LazySettings):
     _CONFIG_LOADER_ARGS = Validator(
         "CONFIG_LOADER_ARGS", default={"base_env": "base", "default_run_env": "local"}
     )
-    _DATA_CATALOG_CLASS = _IsSubclassValidator(
-        "DATA_CATALOG_CLASS", default=_get_default_class("kedro.io.DataCatalog")
+    _DATA_CATALOG_CLASS = _ImplementsCatalogProtocolValidator(
+        "DATA_CATALOG_CLASS",
+        default=_get_default_class("kedro.io.DataCatalog"),
     )

     def __init__(self, *args: Any, **kwargs: Any):
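Note that the validator instantiates the configured class with no arguments (`setting_value()`) before the `isinstance` check, so a custom catalog must be constructible without arguments for validation to pass. A sketch of the corresponding `settings.py` override — `my_project.catalog.CustomDataCatalog` is hypothetical:

```python
# src/<package>/settings.py (sketch): swap in a custom catalog implementation.
# CustomDataCatalog is hypothetical; it must satisfy CatalogProtocol and be
# constructible with no arguments.
from my_project.catalog import CustomDataCatalog

DATA_CATALOG_CLASS = CustomDataCatalog
```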
2 changes: 2 additions & 0 deletions kedro/io/__init__.py
@@ -9,6 +9,7 @@
 from .core import (
     AbstractDataset,
     AbstractVersionedDataset,
+    CatalogProtocol,
     DatasetAlreadyExistsError,
     DatasetError,
     DatasetNotFoundError,
@@ -25,6 +26,7 @@
     "BaseDataCatalog",
     "AbstractVersionedDataset",
     "CachedDataset",
+    "CatalogProtocol",
     "DataCatalog",
     "CatalogConfigResolver",
     "DatasetAlreadyExistsError",
83 changes: 82 additions & 1 deletion kedro/io/core.py
@@ -17,7 +17,15 @@
 from glob import iglob
 from operator import attrgetter
 from pathlib import Path, PurePath, PurePosixPath
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Protocol,
+    TypeVar,
+    runtime_checkable,
+)
 from urllib.parse import urlsplit

 from cachetools import Cache, cachedmethod
@@ -29,6 +37,8 @@
 if TYPE_CHECKING:
     import os

+    from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns
+
 VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ"
 VERSIONED_FLAG_KEY = "versioned"
 VERSION_KEY = "version"
@@ -871,3 +881,74 @@ def validate_on_forbidden_chars(**kwargs: Any) -> None:
             raise DatasetError(
                 f"Neither white-space nor semicolon are allowed in '{key}'."
             )
+
+
+_C = TypeVar("_C")
+
+
+@runtime_checkable
+class CatalogProtocol(Protocol[_C]):
+    _datasets: dict[str, AbstractDataset]
+
+    def __contains__(self, ds_name: str) -> bool:
+        """Check if a dataset is in the catalog."""
+        ...
+
+    @property
+    def config_resolver(self) -> CatalogConfigResolver:
+        """Return the catalog's config resolver."""
+        ...
+
+    @classmethod
+    def from_config(cls, catalog: dict[str, dict[str, Any]] | None) -> _C:
+        """Create a catalog instance from configuration."""
+        ...
+
+    def _get_dataset(
+        self,
+        dataset_name: str,
+        version: Any = None,
+        suggest: bool = True,
+    ) -> AbstractDataset:
+        """Retrieve a dataset by its name."""
+        ...
+
+    def list(self, regex_search: str | None = None) -> list[str]:
+        """List all dataset names registered in the catalog."""
+        ...
+
+    def save(self, name: str, data: Any) -> None:
+        """Save data to a registered dataset."""
+        ...
+
+    def load(self, name: str, version: str | None = None) -> Any:
+        """Load data from a registered dataset."""
+        ...
+
+    def add(self, ds_name: str, dataset: Any, replace: bool = False) -> None:
+        """Add a new dataset to the catalog."""
+        ...
+
+    def add_all(self, datasets: dict[str, Any], replace: bool = False) -> None:
+        """Add multiple datasets to the catalog."""
+        ...
+
+    def add_feed_dict(self, datasets: dict[str, Any], replace: bool = False) -> None:
+        """Add datasets to the catalog using the data provided through the `feed_dict`."""
+        ...
+
+    def exists(self, name: str) -> bool:
+        """Check whether a registered dataset exists by calling its `exists()` method."""
+        ...
+
+    def release(self, name: str) -> None:
+        """Release any cached data associated with a dataset."""
+        ...
+
+    def confirm(self, name: str) -> None:
+        """Confirm a dataset by its name."""
+        ...
+
+    def shallow_copy(self, extra_dataset_patterns: Patterns | None = None) -> _C:
+        """Return a shallow copy of the current object."""
+        ...
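Because the protocol is decorated with `@runtime_checkable`, an `isinstance` check verifies only that the members above exist on the object; it does not validate signatures or return types. A quick sketch of what that buys, mirroring the check the new settings validator performs:

```python
# Sketch: DataCatalog satisfies the structural check, and so would any
# third-party class exposing the same members.
from kedro.io import CatalogProtocol, DataCatalog

catalog = DataCatalog()
assert isinstance(catalog, CatalogProtocol)


class NotACatalog:
    pass


assert not isinstance(NotACatalog(), CatalogProtocol)
```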
16 changes: 8 additions & 8 deletions kedro/runner/parallel_runner.py
@@ -22,7 +22,7 @@
 )
 from kedro.framework.project import settings
 from kedro.io import (
-    BaseDataCatalog,
+    CatalogProtocol,
     DatasetNotFoundError,
     MemoryDataset,
     SharedMemoryDataset,
@@ -60,7 +60,7 @@ def _bootstrap_subprocess(

 def _run_node_synchronization(  # noqa: PLR0913
     node: Node,
-    catalog: BaseDataCatalog,
+    catalog: CatalogProtocol,
     is_async: bool = False,
     session_id: str | None = None,
     package_name: str | None = None,
@@ -73,7 +73,7 @@ def _run_node_synchronization(  # noqa: PLR0913

     Args:
         node: The ``Node`` to run.
-        catalog: A ``BaseDataCatalog`` containing the node's inputs and outputs.
+        catalog: A catalog containing the node's inputs and outputs.
        is_async: If True, the node inputs and outputs are loaded and saved
            asynchronously with threads. Defaults to False.
        session_id: The session id of the pipeline run.
@@ -118,7 +118,7 @@ def __init__(
                cannot be larger than 61 and will be set to min(61, max_workers).
            is_async: If True, the node inputs and outputs are loaded and saved
                asynchronously with threads. Defaults to False.
-            extra_dataset_patterns: Extra dataset factory patterns to be added to the BaseDataCatalog
+            extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog
                during the run. This is used to set the default datasets to SharedMemoryDataset
                for `ParallelRunner`.
@@ -168,7 +168,7 @@ def _validate_nodes(cls, nodes: Iterable[Node]) -> None:
            )

     @classmethod
-    def _validate_catalog(cls, catalog: BaseDataCatalog, pipeline: Pipeline) -> None:
+    def _validate_catalog(cls, catalog: CatalogProtocol, pipeline: Pipeline) -> None:
         """Ensure that all data sets are serialisable and that we do not have
         any non proxied memory data sets being used as outputs as their content
         will not be synchronized across threads.
@@ -214,7 +214,7 @@ def _validate_catalog(cls, catalog: BaseDataCatalog, pipeline: Pipeline) -> None
            )

     def _set_manager_datasets(
-        self, catalog: BaseDataCatalog, pipeline: Pipeline
+        self, catalog: CatalogProtocol, pipeline: Pipeline
     ) -> None:
         for dataset in pipeline.datasets():
             try:
@@ -242,15 +242,15 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int:
     def _run(
         self,
         pipeline: Pipeline,
-        catalog: BaseDataCatalog,
+        catalog: CatalogProtocol,
         hook_manager: PluginManager,
         session_id: str | None = None,
     ) -> None:
         """The abstract interface for running pipelines.

         Args:
             pipeline: The ``Pipeline`` to run.
-            catalog: The ``BaseDataCatalog`` from which to fetch data.
+            catalog: The catalog from which to fetch data.
             hook_manager: The ``PluginManager`` to activate hooks.
             session_id: The id of the session.
(5 more changed files not shown.)
