diff --git a/.github/styles/Kedro/ignore.txt b/.github/styles/Kedro/ignore.txt index 9634d1b14b..3d568cddc9 100644 --- a/.github/styles/Kedro/ignore.txt +++ b/.github/styles/Kedro/ignore.txt @@ -44,3 +44,5 @@ transcoding transcode Claypot ethanknights +Aneira +Printify diff --git a/.github/workflows/benchmark-performance.yml b/.github/workflows/benchmark-performance.yml new file mode 100644 index 0000000000..30922193c3 --- /dev/null +++ b/.github/workflows/benchmark-performance.yml @@ -0,0 +1,59 @@ +name: ASV Benchmark + +on: + push: + branches: + - main # Run benchmarks on every commit to the main branch + workflow_dispatch: + + +jobs: + + benchmark: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: "kedro" + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install asv # Install ASV + + - name: Run ASV benchmarks + run: | + cd kedro + asv machine --machine=github-actions + asv run -v --machine=github-actions + + - name: Set git email and name + run: | + git config --global user.email "kedro@kedro.com" + git config --global user.name "Kedro" + + - name: Checkout target repository + uses: actions/checkout@v4 + with: + repository: kedro-org/kedro-benchmark-results + token: ${{ secrets.GH_TAGGING_TOKEN }} + ref: 'main' + path: "kedro-benchmark-results" + + - name: Copy files to target repository + run: | + cp -r /home/runner/work/kedro/kedro/kedro/.asv /home/runner/work/kedro/kedro/kedro-benchmark-results/ + + - name: Commit and Push changes to kedro-org/kedro-benchmark-results + run: | + cd kedro-benchmark-results + git add . + git commit -m "Add results" + git push diff --git a/RELEASE.md b/RELEASE.md index d1a74d407d..61560acf87 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,11 +1,23 @@ # Upcoming Release ## Major features and improvements +* Implemented `KedroDataCatalog`, replicating `DataCatalog` functionality with a few API enhancements: + * Removed `_FrozenDatasets`; datasets are now accessed as properties; + * Added the ability to get a dataset by name; + * `add_feed_dict()` was simplified and renamed to `add_data()`; + * Dataset initialisation was moved out of the `from_config()` method into the constructor. +* Moved development requirements from `requirements.txt` to a dedicated section in `pyproject.toml` for the project template. +* Implemented a `Protocol` abstraction for the current `DataCatalog` to support adding new catalog implementations. +* Refactored the `kedro run` and `kedro catalog` commands. +* Moved pattern resolution logic from `DataCatalog` to a separate component, `CatalogConfigResolver`, and updated `DataCatalog` to use it internally. * Made packaged Kedro projects return `session.run()` output to be used when running it in the interactive environment. * Enhanced `OmegaConfigLoader` configuration validation to detect duplicate keys at all parameter levels, ensuring comprehensive nested key checking. ## Bug fixes and other changes * Fixed bug where using dataset factories breaks with `ThreadRunner`. * Fixed a bug where `SharedMemoryDataset.exists` would not call the underlying `MemoryDataset`. +* Fixed example tests for template projects. +* Made credentials loading consistent between `KedroContext._get_catalog()` and `resolve_patterns` so that both use +`_get_config_credentials()`. ## Breaking changes to the API * Removed `ShelveStore` to address a security vulnerability.
@@ -17,6 +29,9 @@ ## Community contributions * [Puneet](https://github.com/puneeter) * [ethanknights](https://github.com/ethanknights) +* [Manezki](https://github.com/Manezki) +* [MigQ2](https://github.com/MigQ2) +* [Felix Scherz](https://github.com/felixscherz) # Release 0.19.8 diff --git a/asv.conf.json b/asv.conf.json new file mode 100644 index 0000000000..2cfcd3a057 --- /dev/null +++ b/asv.conf.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "project": "Kedro", + "project_url": "https://kedro.org/", + "repo": ".", + "install_command": ["pip install -e ."], + "branches": ["main"], + "environment_type": "virtualenv", + "show_commit_url": "http://github.com/kedro-org/kedro/commit/", + "results_dir": ".asv/results", + "html_dir": ".asv/html" +} diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/benchmarks/benchmark_dummy.py b/benchmarks/benchmark_dummy.py new file mode 100644 index 0000000000..fc047eb712 --- /dev/null +++ b/benchmarks/benchmark_dummy.py @@ -0,0 +1,16 @@ +# Write the benchmarking functions here. +# See "Writing benchmarks" in the asv docs for more information. + + +class TimeSuite: + """ + A dummy benchmark suite to test with asv framework. + """ + def setup(self): + self.d = {} + for x in range(500): + self.d[x] = None + + def time_keys(self): + for key in self.d.keys(): + pass diff --git a/docs/source/conf.py b/docs/source/conf.py index 562f5a4b0e..50b719f117 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -127,11 +127,14 @@ "typing.Type", "typing.Set", "kedro.config.config.ConfigLoader", + "kedro.io.catalog_config_resolver.CatalogConfigResolver", "kedro.io.core.AbstractDataset", "kedro.io.core.AbstractVersionedDataset", + "kedro.io.core.CatalogProtocol", "kedro.io.core.DatasetError", "kedro.io.core.Version", "kedro.io.data_catalog.DataCatalog", + "kedro.io.kedro_data_catalog.KedroDataCatalog", "kedro.io.memory_dataset.MemoryDataset", "kedro.io.partitioned_dataset.PartitionedDataset", "kedro.pipeline.pipeline.Pipeline", @@ -168,6 +171,9 @@ "D[k] if k in D, else d. d defaults to None.", "None. 
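For reference, the following minimal sketch (not part of the diff above) illustrates what the updated guidance describes: a custom dataset that overrides the public `load` and `save` methods directly and relies on `AbstractDataset` to wrap them with uniform error handling. The class name and plain-text file handling are illustrative assumptions, not taken from the Kedro codebase.

```python
from pathlib import Path
from typing import Any

from kedro.io import AbstractDataset


class TextDataset(AbstractDataset[str, str]):
    """Hypothetical dataset that loads and saves a plain-text file."""

    def __init__(self, filepath: str):
        self._filepath = Path(filepath)

    def load(self) -> str:
        # No explicit error handling needed here: AbstractDataset wraps
        # `load`/`save` with uniform error handling for all subclasses.
        return self._filepath.read_text()

    def save(self, data: str) -> None:
        self._filepath.write_text(data)

    def _describe(self) -> dict[str, Any]:
        # Used by Kedro when logging information about the dataset instance.
        return {"filepath": str(self._filepath)}
```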
Update D from mapping/iterable E and F.", "Patterns", + "CatalogConfigResolver", + "CatalogProtocol", + "KedroDataCatalog", ), "py:data": ( "typing.Any", diff --git a/docs/source/contribution/technical_steering_committee.md b/docs/source/contribution/technical_steering_committee.md index a17590bdad..b324c15910 100644 --- a/docs/source/contribution/technical_steering_committee.md +++ b/docs/source/contribution/technical_steering_committee.md @@ -61,10 +61,10 @@ We look for commitment markers who can do the following: | [Huong Nguyen](https://github.com/Huongg) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Ivan Danov](https://github.com/idanov) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Jitendra Gundaniya](https://github.com/jitu5) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | -| [Joel Schwarzmann](https://github.com/datajoely) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | +| [Joel Schwarzmann](https://github.com/datajoely) | [Aneira Health](https://www.aneira.health) | | [Juan Luis Cano](https://github.com/astrojuanlu) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Laura Couto](https://github.com/lrcouto) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | -| [Marcin Zabłocki](https://github.com/marrrcin) | [Printify, Inc.](https://printify.com/) | +| [Marcin Zabłocki](https://github.com/marrrcin) | [Printify, Inc.](https://printify.com/) | | [Merel Theisen](https://github.com/merelcht) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Nok Lam Chan](https://github.com/noklam) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | | [Rashida Kanchwala](https://github.com/rashidakanchwala) | [QuantumBlack, AI by McKinsey](https://www.mckinsey.com/capabilities/quantumblack) | diff --git a/docs/source/data/how_to_create_a_custom_dataset.md b/docs/source/data/how_to_create_a_custom_dataset.md index 01ad199f55..7f39987dd7 100644 --- a/docs/source/data/how_to_create_a_custom_dataset.md +++ b/docs/source/data/how_to_create_a_custom_dataset.md @@ -4,7 +4,7 @@ ## AbstractDataset -If you are a contributor and would like to submit a new dataset, you must extend the {py:class}`~kedro.io.AbstractDataset` interface or {py:class}`~kedro.io.AbstractVersionedDataset` interface if you plan to support versioning. It requires subclasses to override the `_load` and `_save` and provides `load` and `save` methods that enrich the corresponding private methods with uniform error handling. It also requires subclasses to override `_describe`, which is used in logging the internal information about the instances of your custom `AbstractDataset` implementation. +If you are a contributor and would like to submit a new dataset, you must extend the {py:class}`~kedro.io.AbstractDataset` interface or {py:class}`~kedro.io.AbstractVersionedDataset` interface if you plan to support versioning. It requires subclasses to implement the `load` and `save` methods while providing wrappers that enrich the corresponding methods with uniform error handling. It also requires subclasses to override `_describe`, which is used in logging the internal information about the instances of your custom `AbstractDataset` implementation. 
## Scenario @@ -31,8 +31,8 @@ Consult the [Pillow documentation](https://pillow.readthedocs.io/en/stable/insta At the minimum, a valid Kedro dataset needs to subclass the base {py:class}`~kedro.io.AbstractDataset` and provide an implementation for the following abstract methods: -* `_load` -* `_save` +* `load` +* `save` * `_describe` `AbstractDataset` is generically typed with an input data type for saving data, and an output data type for loading data. @@ -70,7 +70,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): """ self._filepath = filepath - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -78,7 +78,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): """ ... - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath""" ... @@ -96,11 +96,11 @@ src/kedro_pokemon/datasets └── image_dataset.py ``` -## Implement the `_load` method with `fsspec` +## Implement the `load` method with `fsspec` Many of the built-in Kedro datasets rely on [fsspec](https://filesystem-spec.readthedocs.io/en/latest/) as a consistent interface to different data sources, as described earlier in the section about the [Data Catalog](../data/data_catalog.md#dataset-filepath). In this example, it's particularly convenient to use `fsspec` in conjunction with `Pillow` to read image data, since it allows the dataset to work flexibly with different image locations and formats. -Here is the implementation of the `_load` method using `fsspec` and `Pillow` to read the data of a single image into a `numpy` array: +Here is the implementation of the `load` method using `fsspec` and `Pillow` to read the data of a single image into a `numpy` array:
Click to expand @@ -130,7 +130,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): self._filepath = PurePosixPath(path) self._fs = fsspec.filesystem(self._protocol) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -168,14 +168,14 @@ In [2]: from PIL import Image In [3]: Image.fromarray(image).show() ``` -## Implement the `_save` method with `fsspec` +## Implement the `save` method with `fsspec` Similarly, we can implement the `_save` method as follows: ```python class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" # using get_filepath_str ensures that the protocol and path are appended correctly for different filesystems save_path = get_filepath_str(self._filepath, self._protocol) @@ -243,7 +243,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): self._filepath = PurePosixPath(path) self._fs = fsspec.filesystem(self._protocol) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -254,7 +254,7 @@ class ImageDataset(AbstractDataset[np.ndarray, np.ndarray]): image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" save_path = get_filepath_str(self._filepath, self._protocol) with self._fs.open(save_path, mode="wb") as f: @@ -312,7 +312,7 @@ To add versioning support to the new dataset we need to extend the {py:class}`~kedro.io.AbstractVersionedDataset` to: * Accept a `version` keyword argument as part of the constructor -* Adapt the `_load` and `_save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively +* Adapt the `load` and `save` method to use the versioned data path obtained from `_get_load_path` and `_get_save_path` respectively The following amends the full implementation of our basic `ImageDataset`. It now loads and saves data to and from a versioned subfolder (`data/01_raw/pokemon-images-and-types/images/images/pikachu.png//pikachu.png` with `version` being a datetime-formatted string `YYYY-MM-DDThh.mm.ss.sssZ` by default): @@ -359,7 +359,7 @@ class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]): glob_function=self._fs.glob, ) - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. Returns: @@ -370,7 +370,7 @@ class ImageDataset(AbstractVersionedDataset[np.ndarray, np.ndarray]): image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" save_path = get_filepath_str(self._get_save_path(), self._protocol) with self._fs.open(save_path, mode="wb") as f: @@ -435,7 +435,7 @@ The difference between the original `ImageDataset` and the versioned `ImageDatas + glob_function=self._fs.glob, + ) + - def _load(self) -> np.ndarray: + def load(self) -> np.ndarray: """Loads data from the image file. 
Returns: @@ -447,7 +447,7 @@ The difference between the original `ImageDataset` and the versioned `ImageDatas image = Image.open(f).convert("RGBA") return np.asarray(image) - def _save(self, data: np.ndarray) -> None: + def save(self, data: np.ndarray) -> None: """Saves image data to the specified filepath.""" - save_path = get_filepath_str(self._filepath, self._protocol) + save_path = get_filepath_str(self._get_save_path(), self._protocol) diff --git a/docs/source/development/automated_testing.md b/docs/source/development/automated_testing.md index ed3efe3287..c4ca9a6538 100644 --- a/docs/source/development/automated_testing.md +++ b/docs/source/development/automated_testing.md @@ -19,21 +19,36 @@ There are many testing frameworks available for Python. One of the most popular Let's look at how you can start working with `pytest` in your Kedro project. -### Prerequisite: Install your Kedro project +### Install test requirements +Before getting started with test requirements, it is important to ensure you have installed your project locally. This allows you to test different parts of your project by importing them into your test files. + + +To install your project including all the project-specific dependencies and test requirements: +1. Add the following section to the `pyproject.toml` file located in the project root: +```toml +[project.optional-dependencies] +dev = [ + "pytest-cov", + "pytest-mock", + "pytest", +] +``` + +2. Navigate to the root directory of the project and run: +```bash +pip install ."[dev]" +``` -Before getting started with `pytest`, it is important to ensure you have installed your project locally. This allows you to test different parts of your project by importing them into your test files. +Alternatively, you can individually install test requirements as you would install other packages with `pip`, making sure you have installed your project locally and your [project's virtual environment is active](../get_started/install.md#create-a-virtual-environment-for-your-kedro-project). -To install your project, navigate to your project root and run the following command: +1. To install your project, navigate to your project root and run the following command: ```bash pip install -e . ``` - >**NOTE**: The option `-e` installs an editable version of your project, allowing you to make changes to the project files without needing to re-install them each time. -### Install `pytest` - -Install `pytest` as you would install other packages with `pip`, making sure your [project's virtual environment is active](../get_started/install.md#create-a-virtual-environment-for-your-kedro-project). +2. Install test requirements one by one: ```bash pip install pytest ``` diff --git a/docs/source/development/linting.md b/docs/source/development/linting.md index 61989cdf85..fbc0b0147c 100644 --- a/docs/source/development/linting.md +++ b/docs/source/development/linting.md @@ -18,17 +18,17 @@ There are a variety of Python tools available to use with your Kedro projects. T type. 
### Install the tools -Install `ruff` by adding the following lines to your project's `requirements.txt` -file: -```text -ruff # Used for linting, formatting and sorting module imports +To install `ruff` add the following section to the `pyproject.toml` file located in the project root: +```toml +[project.optional-dependencies] +dev = ["ruff"] ``` -To install all the project-specific dependencies, including the linting tools, navigate to the root directory of the +Then to install your project including all the project-specific dependencies and the linting tools, navigate to the root directory of the project and run: ```bash -pip install -r requirements.txt +pip install ."[dev]" ``` Alternatively, you can individually install the linting tools using the following shell commands: diff --git a/docs/source/meta/images/slice_pipeline_kedro_viz.gif b/docs/source/meta/images/slice_pipeline_kedro_viz.gif new file mode 100644 index 0000000000..2d49c9e766 Binary files /dev/null and b/docs/source/meta/images/slice_pipeline_kedro_viz.gif differ diff --git a/docs/source/nodes_and_pipelines/slice_a_pipeline.md b/docs/source/nodes_and_pipelines/slice_a_pipeline.md index 2324a12fb0..2b2871dffe 100644 --- a/docs/source/nodes_and_pipelines/slice_a_pipeline.md +++ b/docs/source/nodes_and_pipelines/slice_a_pipeline.md @@ -1,6 +1,13 @@ # Slice a pipeline -Sometimes it is desirable to run a subset, or a 'slice' of a pipeline's nodes. In this page, we illustrate the programmatic options that Kedro provides. You can also use the [Kedro CLI to pass parameters to `kedro run`](../development/commands_reference.md#run-the-project) command and slice a pipeline. +Sometimes it is desirable to run a subset, or a 'slice' of a pipeline's nodes. There are two primary ways to achieve this: + + +1. **Visually through Kedro-Viz:** This approach allows you to visually choose and slice pipeline nodes, which then generates a run command for executing the slice within your Kedro project. Detailed steps on how to achieve this are available in the Kedro-Viz documentation: [Slice a Pipeline](https://docs.kedro.org/projects/kedro-viz/en/stable/slice_a_pipeline.html). + +![](../meta/images/slice_pipeline_kedro_viz.gif) + +2. **Programmatically with the Kedro CLI.** You can also use the [Kedro CLI to pass parameters to `kedro run`](../development/commands_reference.md#run-the-project) command and slice a pipeline. In this page, we illustrate the programmatic options that Kedro provides. 
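As a quick preview of the programmatic slicing that the rest of this page covers, the sketch below builds a small two-node pipeline and slices it with one of `Pipeline`'s filtering methods. The node names and functions are assumptions for illustration and do not come from the documentation example.

```python
from kedro.pipeline import Pipeline, node


def mean(xs: list[float]) -> float:
    return sum(xs) / len(xs)


def variance(xs: list[float], m: float) -> float:
    return sum((x - m) ** 2 for x in xs) / len(xs)


full_pipeline = Pipeline(
    [
        node(mean, inputs="xs", outputs="m", name="mean_node"),
        node(variance, inputs=["xs", "m"], outputs="v", name="variance_node"),
    ]
)

# Slice the pipeline down to only the nodes needed to produce "m".
sliced = full_pipeline.to_outputs("m")
```

Methods such as `from_nodes`, `to_nodes`, `from_inputs` and `only_nodes_with_tags` work in the same way, each returning a new `Pipeline` containing only the selected nodes.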
Let's look again at the example pipeline from the [pipeline introduction documentation](./pipeline_introduction.md#how-to-build-a-pipeline), which computes the variance of a set of numbers: diff --git a/features/load_node.feature b/features/load_node.feature index fbc5a65a07..e745378e22 100644 --- a/features/load_node.feature +++ b/features/load_node.feature @@ -5,5 +5,6 @@ Feature: load_node in new project And I have run a non-interactive kedro new with starter "default" Scenario: Execute ipython load_node magic - When I execute the load_node magic command + When I install project and its dev dependencies + And I execute the load_node magic command Then the logs should show that load_node executed successfully diff --git a/features/steps/cli_steps.py b/features/steps/cli_steps.py index 7ee2c153d8..62cda23001 100644 --- a/features/steps/cli_steps.py +++ b/features/steps/cli_steps.py @@ -755,3 +755,13 @@ def exec_magic_command(context): def change_dir(context, dir): """Execute Kedro target.""" util.chdir(dir) + + +@when("I install project and its dev dependencies") +def pip_install_project_and_dev_dependencies(context): + """Install project and its development dependencies using pip.""" + _ = run( + [context.pip, "install", ".[dev]"], + env=context.env, + cwd=str(context.root_project_dir), + ) diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml b/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml index 462dd26eee..eb7cb5f113 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/pyproject.toml @@ -12,15 +12,21 @@ dynamic = ["dependencies", "version"] [project.optional-dependencies] docs = [ - "docutils<0.18.0", - "sphinx~=3.4.3", - "sphinx_rtd_theme==0.5.1", + "docutils<0.21", + "sphinx>=5.3,<7.3", + "sphinx_rtd_theme==2.0.0", "nbsphinx==0.8.1", - "sphinx-autodoc-typehints==1.11.1", - "sphinx_copybutton==0.3.1", + "sphinx-autodoc-typehints==1.20.2", + "sphinx_copybutton==0.5.2", "ipykernel>=5.3, <7.0", - "Jinja2<3.1.0", - "myst-parser~=0.17.2", + "Jinja2<3.2.0", + "myst-parser>=1.0,<2.1" +] +dev = [ + "pytest-cov~=3.0", + "pytest-mock>=1.7.1, <2.0", + "pytest~=7.2", + "ruff~=0.1.8" ] [tool.setuptools.dynamic] diff --git a/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt b/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt index 8da5d60851..b07568d9da 100644 --- a/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt +++ b/features/steps/test_starter/{{ cookiecutter.repo_name }}/requirements.txt @@ -1,10 +1,6 @@ -ruff==0.1.8 ipython>=8.10 jupyterlab>=3.0 notebook kedro~={{ cookiecutter.kedro_version}} kedro-datasets[pandas-csvdataset]; python_version >= "3.9" kedro-datasets[pandas.CSVDataset]<2.0.0; python_version < '3.9' -pytest-cov~=3.0 -pytest-mock>=1.7.1, <2.0 -pytest~=7.2 diff --git a/kedro/framework/cli/catalog.py b/kedro/framework/cli/catalog.py index 223980dade..25fad6083d 100644 --- a/kedro/framework/cli/catalog.py +++ b/kedro/framework/cli/catalog.py @@ -2,9 +2,8 @@ from __future__ import annotations -import copy from collections import defaultdict -from itertools import chain +from itertools import chain, filterfalse from typing import TYPE_CHECKING, Any import click @@ -28,6 +27,11 @@ def _create_session(package_name: str, **kwargs: Any) -> KedroSession: return KedroSession.create(**kwargs) +def is_parameter(dataset_name: str) -> bool: + """Check if dataset is a 
parameter.""" + return dataset_name.startswith("params:") or dataset_name == "parameters" + + @click.group(name="Kedro") def catalog_cli() -> None: # pragma: no cover pass @@ -88,21 +92,13 @@ def list_datasets(metadata: ProjectMetadata, pipeline: str, env: str) -> None: # resolve any factory datasets in the pipeline factory_ds_by_type = defaultdict(list) - for ds_name in default_ds: - matched_pattern = data_catalog._match_pattern( - data_catalog._dataset_patterns, ds_name - ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name) - if matched_pattern: - ds_config_copy = copy.deepcopy( - data_catalog._dataset_patterns.get(matched_pattern) - or data_catalog._default_pattern.get(matched_pattern) - or {} - ) - ds_config = data_catalog._resolve_config( - ds_name, matched_pattern, ds_config_copy + for ds_name in default_ds: + if data_catalog.config_resolver.match_pattern(ds_name): + ds_config = data_catalog.config_resolver.resolve_pattern(ds_name) + factory_ds_by_type[ds_config.get("type", "DefaultDataset")].append( + ds_name ) - factory_ds_by_type[ds_config["type"]].append(ds_name) default_ds = default_ds - set(chain.from_iterable(factory_ds_by_type.values())) @@ -128,12 +124,10 @@ def _map_type_to_datasets( datasets of the specific type as a value. """ mapping = defaultdict(list) # type: ignore[var-annotated] - for dataset in datasets: - is_param = dataset.startswith("params:") or dataset == "parameters" - if not is_param: - ds_type = datasets_meta[dataset].__class__.__name__ - if dataset not in mapping[ds_type]: - mapping[ds_type].append(dataset) + for dataset_name in filterfalse(is_parameter, datasets): + ds_type = datasets_meta[dataset_name].__class__.__name__ + if dataset_name not in mapping[ds_type]: + mapping[ds_type].append(dataset_name) return mapping @@ -170,20 +164,12 @@ def create_catalog(metadata: ProjectMetadata, pipeline_name: str, env: str) -> N f"'{pipeline_name}' pipeline not found! 
Existing pipelines: {existing_pipelines}" ) - pipe_datasets = { - ds_name - for ds_name in pipeline.datasets() - if not ds_name.startswith("params:") and ds_name != "parameters" - } + pipeline_datasets = set(filterfalse(is_parameter, pipeline.datasets())) - catalog_datasets = { - ds_name - for ds_name in context.catalog._datasets.keys() - if not ds_name.startswith("params:") and ds_name != "parameters" - } + catalog_datasets = set(filterfalse(is_parameter, context.catalog.list())) # Datasets that are missing in Data Catalog - missing_ds = sorted(pipe_datasets - catalog_datasets) + missing_ds = sorted(pipeline_datasets - catalog_datasets) if missing_ds: catalog_path = ( context.project_path @@ -221,12 +207,9 @@ def rank_catalog_factories(metadata: ProjectMetadata, env: str) -> None: session = _create_session(metadata.package_name, env=env) context = session.load_context() - catalog_factories = { - **context.catalog._dataset_patterns, - **context.catalog._default_pattern, - } + catalog_factories = context.catalog.config_resolver.list_patterns() if catalog_factories: - click.echo(yaml.dump(list(catalog_factories.keys()))) + click.echo(yaml.dump(catalog_factories)) else: click.echo("There are no dataset factories in the catalog.") @@ -242,7 +225,7 @@ def resolve_patterns(metadata: ProjectMetadata, env: str) -> None: context = session.load_context() catalog_config = context.config_loader["catalog"] - credentials_config = context.config_loader.get("credentials", None) + credentials_config = context._get_config_credentials() data_catalog = DataCatalog.from_config( catalog=catalog_config, credentials=credentials_config ) @@ -250,35 +233,25 @@ def resolve_patterns(metadata: ProjectMetadata, env: str) -> None: explicit_datasets = { ds_name: ds_config for ds_name, ds_config in catalog_config.items() - if not data_catalog._is_pattern(ds_name) + if not data_catalog.config_resolver.is_pattern(ds_name) } target_pipelines = pipelines.keys() - datasets = set() + pipeline_datasets = set() for pipe in target_pipelines: pl_obj = pipelines.get(pipe) if pl_obj: - datasets.update(pl_obj.datasets()) + pipeline_datasets.update(pl_obj.datasets()) - for ds_name in datasets: - is_param = ds_name.startswith("params:") or ds_name == "parameters" - if ds_name in explicit_datasets or is_param: + for ds_name in pipeline_datasets: + if ds_name in explicit_datasets or is_parameter(ds_name): continue - matched_pattern = data_catalog._match_pattern( - data_catalog._dataset_patterns, ds_name - ) or data_catalog._match_pattern(data_catalog._default_pattern, ds_name) - if matched_pattern: - ds_config_copy = copy.deepcopy( - data_catalog._dataset_patterns.get(matched_pattern) - or data_catalog._default_pattern.get(matched_pattern) - or {} - ) + ds_config = data_catalog.config_resolver.resolve_pattern(ds_name) - ds_config = data_catalog._resolve_config( - ds_name, matched_pattern, ds_config_copy - ) + # Exclude MemoryDatasets not set in the catalog explicitly + if ds_config: explicit_datasets[ds_name] = ds_config secho(yaml.dump(explicit_datasets)) diff --git a/kedro/framework/context/context.py b/kedro/framework/context/context.py index 3b61b747f6..5c14cbae38 100644 --- a/kedro/framework/context/context.py +++ b/kedro/framework/context/context.py @@ -14,7 +14,7 @@ from kedro.config import AbstractConfigLoader, MissingConfigException from kedro.framework.project import settings -from kedro.io import DataCatalog # noqa: TCH001 +from kedro.io import CatalogProtocol, DataCatalog # noqa: TCH001 from kedro.pipeline.transcoding import 
_transcode_split if TYPE_CHECKING: @@ -123,7 +123,7 @@ def _convert_paths_to_absolute_posix( return conf_dictionary -def _validate_transcoded_datasets(catalog: DataCatalog) -> None: +def _validate_transcoded_datasets(catalog: CatalogProtocol) -> None: """Validates transcoded datasets are correctly named Args: @@ -178,13 +178,13 @@ class KedroContext: ) @property - def catalog(self) -> DataCatalog: - """Read-only property referring to Kedro's ``DataCatalog`` for this context. + def catalog(self) -> CatalogProtocol: + """Read-only property referring to Kedro's catalog` for this context. Returns: - DataCatalog defined in `catalog.yml`. + catalog defined in `catalog.yml`. Raises: - KedroContextError: Incorrect ``DataCatalog`` registered for the project. + KedroContextError: Incorrect catalog registered for the project. """ return self._get_catalog() @@ -213,13 +213,13 @@ def _get_catalog( self, save_version: str | None = None, load_versions: dict[str, str] | None = None, - ) -> DataCatalog: - """A hook for changing the creation of a DataCatalog instance. + ) -> CatalogProtocol: + """A hook for changing the creation of a catalog instance. Returns: - DataCatalog defined in `catalog.yml`. + catalog defined in `catalog.yml`. Raises: - KedroContextError: Incorrect ``DataCatalog`` registered for the project. + KedroContextError: Incorrect catalog registered for the project. """ # '**/catalog*' reads modular pipeline configs diff --git a/kedro/framework/hooks/specs.py b/kedro/framework/hooks/specs.py index b0037a0878..3b32eb294c 100644 --- a/kedro/framework/hooks/specs.py +++ b/kedro/framework/hooks/specs.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: from kedro.framework.context import KedroContext - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -22,7 +22,7 @@ class DataCatalogSpecs: @hook_spec def after_catalog_created( # noqa: PLR0913 self, - catalog: DataCatalog, + catalog: CatalogProtocol, conf_catalog: dict[str, Any], conf_creds: dict[str, Any], feed_dict: dict[str, Any], @@ -53,7 +53,7 @@ class NodeSpecs: def before_node_run( self, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, session_id: str, @@ -63,7 +63,7 @@ def before_node_run( Args: node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. inputs: The dictionary of inputs dataset. The keys are dataset names and the values are the actual loaded input data, not the dataset instance. @@ -81,7 +81,7 @@ def before_node_run( def after_node_run( # noqa: PLR0913 self, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], outputs: dict[str, Any], is_async: bool, @@ -93,7 +93,7 @@ def after_node_run( # noqa: PLR0913 Args: node: The ``Node`` that ran. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. inputs: The dictionary of inputs dataset. The keys are dataset names and the values are the actual loaded input data, not the dataset instance. 
@@ -110,7 +110,7 @@ def on_node_error( # noqa: PLR0913 self, error: Exception, node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, session_id: str, @@ -122,7 +122,7 @@ def on_node_error( # noqa: PLR0913 Args: error: The uncaught exception thrown during the node run. node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. inputs: The dictionary of inputs dataset. The keys are dataset names and the values are the actual loaded input data, not the dataset instance. @@ -137,7 +137,7 @@ class PipelineSpecs: @hook_spec def before_pipeline_run( - self, run_params: dict[str, Any], pipeline: Pipeline, catalog: DataCatalog + self, run_params: dict[str, Any], pipeline: Pipeline, catalog: CatalogProtocol ) -> None: """Hook to be invoked before a pipeline runs. @@ -164,7 +164,7 @@ def before_pipeline_run( } pipeline: The ``Pipeline`` that will be run. - catalog: The ``DataCatalog`` to be used during the run. + catalog: An implemented instance of ``CatalogProtocol`` to be used during the run. """ pass @@ -174,7 +174,7 @@ def after_pipeline_run( run_params: dict[str, Any], run_result: dict[str, Any], pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, ) -> None: """Hook to be invoked after a pipeline runs. @@ -202,7 +202,7 @@ def after_pipeline_run( run_result: The output of ``Pipeline`` run. pipeline: The ``Pipeline`` that was run. - catalog: The ``DataCatalog`` used during the run. + catalog: An implemented instance of ``CatalogProtocol`` used during the run. """ pass @@ -212,7 +212,7 @@ def on_pipeline_error( error: Exception, run_params: dict[str, Any], pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, ) -> None: """Hook to be invoked if a pipeline run throws an uncaught Exception. The signature of this error hook should match the signature of ``before_pipeline_run`` @@ -242,7 +242,7 @@ def on_pipeline_error( } pipeline: The ``Pipeline`` that will was run. - catalog: The ``DataCatalog`` used during the run. + catalog: An implemented instance of ``CatalogProtocol`` used during the run. """ pass diff --git a/kedro/framework/project/__init__.py b/kedro/framework/project/__init__.py index a3248b9daf..195fa077f6 100644 --- a/kedro/framework/project/__init__.py +++ b/kedro/framework/project/__init__.py @@ -20,6 +20,7 @@ from dynaconf import LazySettings from dynaconf.validator import ValidationError, Validator +from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline, pipeline if TYPE_CHECKING: @@ -59,6 +60,25 @@ def validate( ) +class _ImplementsCatalogProtocolValidator(Validator): + """A validator to check if the supplied setting value is a subclass of the default class""" + + def validate( + self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any + ) -> None: + super().validate(settings, *args, **kwargs) + + protocol = CatalogProtocol + for name in self.names: + setting_value = getattr(settings, name) + if not isinstance(setting_value(), protocol): + raise ValidationError( + f"Invalid value '{setting_value.__module__}.{setting_value.__qualname__}' " + f"received for setting '{name}'. It must implement " + f"'{protocol.__module__}.{protocol.__qualname__}'." 
+ ) + + class _HasSharedParentClassValidator(Validator): """A validator to check that the parent of the default class is an ancestor of the settings value.""" @@ -115,8 +135,9 @@ class _ProjectSettings(LazySettings): _CONFIG_LOADER_ARGS = Validator( "CONFIG_LOADER_ARGS", default={"base_env": "base", "default_run_env": "local"} ) - _DATA_CATALOG_CLASS = _IsSubclassValidator( - "DATA_CATALOG_CLASS", default=_get_default_class("kedro.io.DataCatalog") + _DATA_CATALOG_CLASS = _ImplementsCatalogProtocolValidator( + "DATA_CATALOG_CLASS", + default=_get_default_class("kedro.io.DataCatalog"), ) def __init__(self, *args: Any, **kwargs: Any): diff --git a/kedro/framework/session/session.py b/kedro/framework/session/session.py index 91928f7c4b..caa3553954 100644 --- a/kedro/framework/session/session.py +++ b/kedro/framework/session/session.py @@ -394,13 +394,11 @@ def run( # noqa: PLR0913 run_params=record_data, pipeline=filtered_pipeline, catalog=catalog ) + if isinstance(runner, ThreadRunner): + for ds in filtered_pipeline.datasets(): + if catalog.config_resolver.match_pattern(ds): + _ = catalog._get_dataset(ds) try: - if isinstance(runner, ThreadRunner): - for ds in filtered_pipeline.datasets(): - if catalog._match_pattern( - catalog._dataset_patterns, ds - ) or catalog._match_pattern(catalog._default_pattern, ds): - _ = catalog._get_dataset(ds) run_result = runner.run( filtered_pipeline, catalog, hook_manager, session_id ) diff --git a/kedro/io/__init__.py b/kedro/io/__init__.py index aba59827e9..9697e1bd35 100644 --- a/kedro/io/__init__.py +++ b/kedro/io/__init__.py @@ -5,15 +5,18 @@ from __future__ import annotations from .cached_dataset import CachedDataset +from .catalog_config_resolver import CatalogConfigResolver from .core import ( AbstractDataset, AbstractVersionedDataset, + CatalogProtocol, DatasetAlreadyExistsError, DatasetError, DatasetNotFoundError, Version, ) from .data_catalog import DataCatalog +from .kedro_data_catalog import KedroDataCatalog from .lambda_dataset import LambdaDataset from .memory_dataset import MemoryDataset from .shared_memory_dataset import SharedMemoryDataset @@ -22,10 +25,13 @@ "AbstractDataset", "AbstractVersionedDataset", "CachedDataset", + "CatalogProtocol", "DataCatalog", + "CatalogConfigResolver", "DatasetAlreadyExistsError", "DatasetError", "DatasetNotFoundError", + "KedroDataCatalog", "LambdaDataset", "MemoryDataset", "SharedMemoryDataset", diff --git a/kedro/io/cached_dataset.py b/kedro/io/cached_dataset.py index 5f8d96dc36..85d9341db5 100644 --- a/kedro/io/cached_dataset.py +++ b/kedro/io/cached_dataset.py @@ -103,7 +103,7 @@ def __repr__(self) -> str: } return self._pretty_repr(object_description) - def _load(self) -> Any: + def load(self) -> Any: data = self._cache.load() if self._cache.exists() else self._dataset.load() if not self._cache.exists(): @@ -111,7 +111,7 @@ def _load(self) -> Any: return data - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: self._dataset.save(data) self._cache.save(data) diff --git a/kedro/io/catalog_config_resolver.py b/kedro/io/catalog_config_resolver.py new file mode 100644 index 0000000000..8ec624d9e9 --- /dev/null +++ b/kedro/io/catalog_config_resolver.py @@ -0,0 +1,259 @@ +"""``CatalogConfigResolver`` resolves dataset configurations and datasets' +patterns based on catalog configuration and credentials provided. 
+""" + +from __future__ import annotations + +import copy +import logging +import re +from typing import Any, Dict + +from parse import parse + +from kedro.io.core import DatasetError + +Patterns = Dict[str, Dict[str, Any]] + +CREDENTIALS_KEY = "credentials" + + +class CatalogConfigResolver: + """Resolves dataset configurations based on patterns and credentials.""" + + def __init__( + self, + config: dict[str, dict[str, Any]] | None = None, + credentials: dict[str, dict[str, Any]] | None = None, + ): + self._runtime_patterns: Patterns = {} + self._dataset_patterns, self._default_pattern = self._extract_patterns( + config, credentials + ) + self._resolved_configs = self._resolve_config_credentials(config, credentials) + + @property + def config(self) -> dict[str, dict[str, Any]]: + return self._resolved_configs + + @property + def _logger(self) -> logging.Logger: + return logging.getLogger(__name__) + + @staticmethod + def is_pattern(pattern: str) -> bool: + """Check if a given string is a pattern. Assume that any name with '{' is a pattern.""" + return "{" in pattern + + @staticmethod + def _pattern_specificity(pattern: str) -> int: + """Calculate the specificity of a pattern based on characters outside curly brackets.""" + # Remove all the placeholders from the pattern and count the number of remaining chars + result = re.sub(r"\{.*?\}", "", pattern) + return len(result) + + @classmethod + def _sort_patterns(cls, dataset_patterns: Patterns) -> Patterns: + """Sort a dictionary of dataset patterns according to parsing rules. + + In order: + 1. Decreasing specificity (number of characters outside the curly brackets) + 2. Decreasing number of placeholders (number of curly bracket pairs) + 3. Alphabetically + """ + sorted_keys = sorted( + dataset_patterns, + key=lambda pattern: ( + -(cls._pattern_specificity(pattern)), + -pattern.count("{"), + pattern, + ), + ) + catch_all = [ + pattern for pattern in sorted_keys if cls._pattern_specificity(pattern) == 0 + ] + if len(catch_all) > 1: + raise DatasetError( + f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras." + ) + return {key: dataset_patterns[key] for key in sorted_keys} + + @staticmethod + def _fetch_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any: + """Fetch the specified credentials from the provided credentials dictionary. + + Args: + credentials_name: Credentials name. + credentials: A dictionary with all credentials. + + Returns: + The set of requested credentials. + + Raises: + KeyError: When a data set with the given name has not yet been + registered. + + """ + try: + return credentials[credentials_name] + except KeyError as exc: + raise KeyError( + f"Unable to find credentials '{credentials_name}': check your data " + "catalog and credentials configuration. See " + "https://kedro.readthedocs.io/en/stable/kedro.io.DataCatalog.html " + "for an example." + ) from exc + + @classmethod + def _resolve_credentials( + cls, config: dict[str, Any], credentials: dict[str, Any] + ) -> dict[str, Any]: + """Return the dataset configuration where credentials are resolved using + credentials dictionary provided. + + Args: + config: Original dataset config, which may contain unresolved credentials. + credentials: A dictionary with all credentials. + + Returns: + The dataset config, where all the credentials are successfully resolved. 
+ """ + config = copy.deepcopy(config) + + def _resolve_value(key: str, value: Any) -> Any: + if key == CREDENTIALS_KEY and isinstance(value, str): + return cls._fetch_credentials(value, credentials) + if isinstance(value, dict): + return {k: _resolve_value(k, v) for k, v in value.items()} + return value + + return {k: _resolve_value(k, v) for k, v in config.items()} + + @classmethod + def _resolve_dataset_config( + cls, + ds_name: str, + pattern: str, + config: Any, + ) -> Any: + """Resolve dataset configuration based on the provided pattern.""" + resolved_vars = parse(pattern, ds_name) + # Resolve the factory config for the dataset + if isinstance(config, dict): + for key, value in config.items(): + config[key] = cls._resolve_dataset_config(ds_name, pattern, value) + elif isinstance(config, (list, tuple)): + config = [ + cls._resolve_dataset_config(ds_name, pattern, value) for value in config + ] + elif isinstance(config, str) and "}" in config: + try: + config = config.format_map(resolved_vars.named) + except KeyError as exc: + raise DatasetError( + f"Unable to resolve '{config}' from the pattern '{pattern}'. Keys used in the configuration " + f"should be present in the dataset factory pattern." + ) from exc + return config + + def list_patterns(self) -> list[str]: + """List al patterns available in the catalog.""" + return ( + list(self._dataset_patterns.keys()) + + list(self._default_pattern.keys()) + + list(self._runtime_patterns.keys()) + ) + + def match_pattern(self, ds_name: str) -> str | None: + """Match a dataset name against patterns in a dictionary.""" + all_patterns = self.list_patterns() + matches = (pattern for pattern in all_patterns if parse(pattern, ds_name)) + return next(matches, None) + + def _get_pattern_config(self, pattern: str) -> dict[str, Any]: + return ( + self._dataset_patterns.get(pattern) + or self._default_pattern.get(pattern) + or self._runtime_patterns.get(pattern) + or {} + ) + + @classmethod + def _extract_patterns( + cls, + config: dict[str, dict[str, Any]] | None, + credentials: dict[str, dict[str, Any]] | None, + ) -> tuple[Patterns, Patterns]: + """Extract and sort patterns from the configuration.""" + config = config or {} + credentials = credentials or {} + dataset_patterns = {} + user_default = {} + + for ds_name, ds_config in config.items(): + if cls.is_pattern(ds_name): + dataset_patterns[ds_name] = cls._resolve_credentials( + ds_config, credentials + ) + + sorted_patterns = cls._sort_patterns(dataset_patterns) + if sorted_patterns: + # If the last pattern is a catch-all pattern, pop it and set it as the default + if cls._pattern_specificity(list(sorted_patterns.keys())[-1]) == 0: + last_pattern = sorted_patterns.popitem() + user_default = {last_pattern[0]: last_pattern[1]} + + return sorted_patterns, user_default + + def _resolve_config_credentials( + self, + config: dict[str, dict[str, Any]] | None, + credentials: dict[str, dict[str, Any]] | None, + ) -> dict[str, dict[str, Any]]: + """Initialize the dataset configuration with resolved credentials.""" + config = config or {} + credentials = credentials or {} + resolved_configs = {} + + for ds_name, ds_config in config.items(): + if not isinstance(ds_config, dict): + raise DatasetError( + f"Catalog entry '{ds_name}' is not a valid dataset configuration. " + "\nHint: If this catalog entry is intended for variable interpolation, " + "make sure that the key is preceded by an underscore." 
+ ) + if not self.is_pattern(ds_name): + resolved_configs[ds_name] = self._resolve_credentials( + ds_config, credentials + ) + + return resolved_configs + + def resolve_pattern(self, ds_name: str) -> dict[str, Any]: + """Resolve dataset patterns and return resolved configurations based on the existing patterns.""" + matched_pattern = self.match_pattern(ds_name) + + if matched_pattern and ds_name not in self._resolved_configs: + pattern_config = self._get_pattern_config(matched_pattern) + ds_config = self._resolve_dataset_config( + ds_name, matched_pattern, copy.deepcopy(pattern_config) + ) + + if ( + self._pattern_specificity(matched_pattern) == 0 + and matched_pattern in self._default_pattern + ): + self._logger.warning( + "Config from the dataset factory pattern '%s' in the catalog will be used to " + "override the default dataset creation for '%s'", + matched_pattern, + ds_name, + ) + return ds_config # type: ignore[no-any-return] + + return self._resolved_configs.get(ds_name, {}) + + def add_runtime_patterns(self, dataset_patterns: Patterns) -> None: + """Add new runtime patterns and re-sort them.""" + self._runtime_patterns = {**self._runtime_patterns, **dataset_patterns} + self._runtime_patterns = self._sort_patterns(self._runtime_patterns) diff --git a/kedro/io/core.py b/kedro/io/core.py index f3975c9c3c..036babc829 100644 --- a/kedro/io/core.py +++ b/kedro/io/core.py @@ -17,7 +17,15 @@ from glob import iglob from operator import attrgetter from pathlib import Path, PurePath, PurePosixPath -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Protocol, + TypeVar, + runtime_checkable, +) from urllib.parse import urlsplit from cachetools import Cache, cachedmethod @@ -29,12 +37,25 @@ if TYPE_CHECKING: import os + from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns + VERSION_FORMAT = "%Y-%m-%dT%H.%M.%S.%fZ" VERSIONED_FLAG_KEY = "versioned" VERSION_KEY = "version" HTTP_PROTOCOLS = ("http", "https") PROTOCOL_DELIMITER = "://" -CLOUD_PROTOCOLS = ("s3", "s3n", "s3a", "gcs", "gs", "adl", "abfs", "abfss", "gdrive") +CLOUD_PROTOCOLS = ( + "abfs", + "abfss", + "adl", + "gcs", + "gdrive", + "gs", + "oss", + "s3", + "s3a", + "s3n", +) class DatasetError(Exception): @@ -871,3 +892,70 @@ def validate_on_forbidden_chars(**kwargs: Any) -> None: raise DatasetError( f"Neither white-space nor semicolon are allowed in '{key}'." ) + + +_C = TypeVar("_C") + + +@runtime_checkable +class CatalogProtocol(Protocol[_C]): + _datasets: dict[str, AbstractDataset] + + def __contains__(self, ds_name: str) -> bool: + """Check if a dataset is in the catalog.""" + ... + + @property + def config_resolver(self) -> CatalogConfigResolver: + """Return a copy of the datasets dictionary.""" + ... + + @classmethod + def from_config(cls, catalog: dict[str, dict[str, Any]] | None) -> _C: + """Create a catalog instance from configuration.""" + ... + + def _get_dataset( + self, + dataset_name: str, + version: Any = None, + suggest: bool = True, + ) -> AbstractDataset: + """Retrieve a dataset by its name.""" + ... + + def list(self, regex_search: str | None = None) -> list[str]: + """List all dataset names registered in the catalog.""" + ... + + def save(self, name: str, data: Any) -> None: + """Save data to a registered dataset.""" + ... + + def load(self, name: str, version: str | None = None) -> Any: + """Load data from a registered dataset.""" + ... 
+ + def add(self, ds_name: str, dataset: Any, replace: bool = False) -> None: + """Add a new dataset to the catalog.""" + ... + + def add_feed_dict(self, datasets: dict[str, Any], replace: bool = False) -> None: + """Add datasets to the catalog using the data provided through the `feed_dict`.""" + ... + + def exists(self, name: str) -> bool: + """Checks whether registered data set exists by calling its `exists()` method.""" + ... + + def release(self, name: str) -> None: + """Release any cached data associated with a dataset.""" + ... + + def confirm(self, name: str) -> None: + """Confirm a dataset by its name.""" + ... + + def shallow_copy(self, extra_dataset_patterns: Patterns | None = None) -> _C: + """Returns a shallow copy of the current object.""" + ... diff --git a/kedro/io/data_catalog.py b/kedro/io/data_catalog.py index d3fd163230..a010f3e852 100644 --- a/kedro/io/data_catalog.py +++ b/kedro/io/data_catalog.py @@ -7,15 +7,17 @@ from __future__ import annotations -import copy import difflib import logging import pprint import re -from typing import Any, Dict - -from parse import parse +from typing import Any +from kedro.io.catalog_config_resolver import ( + CREDENTIALS_KEY, # noqa: F401 + CatalogConfigResolver, + Patterns, +) from kedro.io.core import ( AbstractDataset, AbstractVersionedDataset, @@ -28,64 +30,10 @@ from kedro.io.memory_dataset import MemoryDataset from kedro.utils import _format_rich, _has_rich_handler -Patterns = Dict[str, Dict[str, Any]] - -CATALOG_KEY = "catalog" -CREDENTIALS_KEY = "credentials" +CATALOG_KEY = "catalog" # Kept to avoid the breaking change WORDS_REGEX_PATTERN = re.compile(r"\W+") -def _get_credentials(credentials_name: str, credentials: dict[str, Any]) -> Any: - """Return a set of credentials from the provided credentials dict. - - Args: - credentials_name: Credentials name. - credentials: A dictionary with all credentials. - - Returns: - The set of requested credentials. - - Raises: - KeyError: When a data set with the given name has not yet been - registered. - - """ - try: - return credentials[credentials_name] - except KeyError as exc: - raise KeyError( - f"Unable to find credentials '{credentials_name}': check your data " - "catalog and credentials configuration. See " - "https://docs.kedro.org/en/stable/api/kedro.io.DataCatalog.html " - "for an example." - ) from exc - - -def _resolve_credentials( - config: dict[str, Any], credentials: dict[str, Any] -) -> dict[str, Any]: - """Return the dataset configuration where credentials are resolved using - credentials dictionary provided. - - Args: - config: Original dataset config, which may contain unresolved credentials. - credentials: A dictionary with all credentials. - - Returns: - The dataset config, where all the credentials are successfully resolved. - """ - config = copy.deepcopy(config) - - def _map_value(key: str, value: Any) -> Any: - if key == CREDENTIALS_KEY and isinstance(value, str): - return _get_credentials(value, credentials) - if isinstance(value, dict): - return {k: _map_value(k, v) for k, v in value.items()} - return value - - return {k: _map_value(k, v) for k, v in config.items()} - - def _sub_nonword_chars(dataset_name: str) -> str: """Replace non-word characters in data set names since Kedro 0.16.2. 
@@ -103,13 +51,15 @@ class _FrozenDatasets: def __init__( self, - *datasets_collections: _FrozenDatasets | dict[str, AbstractDataset], + *datasets_collections: _FrozenDatasets | dict[str, AbstractDataset] | None, ): """Return a _FrozenDatasets instance from some datasets collections. Each collection could either be another _FrozenDatasets or a dictionary. """ self._original_names: dict[str, str] = {} for collection in datasets_collections: + if collection is None: + continue if isinstance(collection, _FrozenDatasets): self.__dict__.update(collection.__dict__) self._original_names.update(collection._original_names) @@ -125,7 +75,7 @@ def __setattr__(self, key: str, value: Any) -> None: if key == "_original_names": super().__setattr__(key, value) return - msg = "Operation not allowed! " + msg = "Operation not allowed. " if key in self.__dict__: msg += "Please change datasets through configuration." else: @@ -161,10 +111,11 @@ def __init__( # noqa: PLR0913 self, datasets: dict[str, AbstractDataset] | None = None, feed_dict: dict[str, Any] | None = None, - dataset_patterns: Patterns | None = None, + dataset_patterns: Patterns | None = None, # Kept for interface compatibility load_versions: dict[str, str] | None = None, save_version: str | None = None, - default_pattern: Patterns | None = None, + default_pattern: Patterns | None = None, # Kept for interface compatibility + config_resolver: CatalogConfigResolver | None = None, ) -> None: """``DataCatalog`` stores instances of ``AbstractDataset`` implementations to provide ``load`` and ``save`` capabilities from @@ -195,6 +146,8 @@ def __init__( # noqa: PLR0913 sorted in lexicographical order. default_pattern: A dictionary of the default catch-all pattern that overrides the default pattern provided through the runners. + config_resolver: An instance of CatalogConfigResolver to resolve dataset patterns and configurations. + Example: :: @@ -206,14 +159,21 @@ def __init__( # noqa: PLR0913 >>> save_args={"index": False}) >>> catalog = DataCatalog(datasets={'cars': cars}) """ - self._datasets = dict(datasets or {}) - self.datasets = _FrozenDatasets(self._datasets) - # Keep a record of all patterns in the catalog. 
- # {dataset pattern name : dataset pattern body} - self._dataset_patterns = dataset_patterns or {} + self._config_resolver = config_resolver or CatalogConfigResolver() + + # Kept to avoid breaking changes + if not config_resolver: + self._config_resolver._dataset_patterns = dataset_patterns or {} + self._config_resolver._default_pattern = default_pattern or {} + + self._datasets: dict[str, AbstractDataset] = {} + self.datasets: _FrozenDatasets | None = None + + self.add_all(datasets or {}) + self._load_versions = load_versions or {} self._save_version = save_version - self._default_pattern = default_pattern or {} + self._use_rich_markup = _has_rich_handler() if feed_dict: @@ -222,6 +182,23 @@ def __init__( # noqa: PLR0913 def __repr__(self) -> str: return self.datasets.__repr__() + def __contains__(self, dataset_name: str) -> bool: + """Check if an item is in the catalog as a materialised dataset or pattern""" + return ( + dataset_name in self._datasets + or self._config_resolver.match_pattern(dataset_name) is not None + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + return (self._datasets, self._config_resolver.list_patterns()) == ( + other._datasets, + other.config_resolver.list_patterns(), + ) + + @property + def config_resolver(self) -> CatalogConfigResolver: + return self._config_resolver + @property def _logger(self) -> logging.Logger: return logging.getLogger(__name__) @@ -303,44 +280,28 @@ class to be loaded is specified with the key ``type`` and their >>> df = catalog.load("cars") >>> catalog.save("boats", df) """ + catalog = catalog or {} datasets = {} - dataset_patterns = {} - catalog = copy.deepcopy(catalog) or {} - credentials = copy.deepcopy(credentials) or {} + config_resolver = CatalogConfigResolver(catalog, credentials) save_version = save_version or generate_timestamp() - load_versions = copy.deepcopy(load_versions) or {} - user_default = {} - - for ds_name, ds_config in catalog.items(): - if not isinstance(ds_config, dict): - raise DatasetError( - f"Catalog entry '{ds_name}' is not a valid dataset configuration. " - "\nHint: If this catalog entry is intended for variable interpolation, " - "make sure that the key is preceded by an underscore." - ) + load_versions = load_versions or {} - ds_config = _resolve_credentials( # noqa: PLW2901 - ds_config, credentials - ) - if cls._is_pattern(ds_name): - # Add each factory to the dataset_patterns dict. 
- dataset_patterns[ds_name] = ds_config - - else: + for ds_name in catalog: + if not config_resolver.is_pattern(ds_name): datasets[ds_name] = AbstractDataset.from_config( - ds_name, ds_config, load_versions.get(ds_name), save_version + ds_name, + config_resolver.config.get(ds_name, {}), + load_versions.get(ds_name), + save_version, ) - sorted_patterns = cls._sort_patterns(dataset_patterns) - if sorted_patterns: - # If the last pattern is a catch-all pattern, pop it and set it as the default - if cls._specificity(list(sorted_patterns.keys())[-1]) == 0: - last_pattern = sorted_patterns.popitem() - user_default = {last_pattern[0]: last_pattern[1]} missing_keys = [ - key - for key in load_versions.keys() - if not (key in catalog or cls._match_pattern(sorted_patterns, key)) + ds_name + for ds_name in load_versions + if not ( + ds_name in config_resolver.config + or config_resolver.match_pattern(ds_name) + ) ] if missing_keys: raise DatasetNotFoundError( @@ -350,107 +311,29 @@ class to be loaded is specified with the key ``type`` and their return cls( datasets=datasets, - dataset_patterns=sorted_patterns, + dataset_patterns=config_resolver._dataset_patterns, load_versions=load_versions, save_version=save_version, - default_pattern=user_default, + default_pattern=config_resolver._default_pattern, + config_resolver=config_resolver, ) - @staticmethod - def _is_pattern(pattern: str) -> bool: - """Check if a given string is a pattern. Assume that any name with '{' is a pattern.""" - return "{" in pattern - - @staticmethod - def _match_pattern(dataset_patterns: Patterns, dataset_name: str) -> str | None: - """Match a dataset name against patterns in a dictionary.""" - matches = ( - pattern - for pattern in dataset_patterns.keys() - if parse(pattern, dataset_name) - ) - return next(matches, None) - - @classmethod - def _sort_patterns(cls, dataset_patterns: Patterns) -> dict[str, dict[str, Any]]: - """Sort a dictionary of dataset patterns according to parsing rules. - - In order: - - 1. Decreasing specificity (number of characters outside the curly brackets) - 2. Decreasing number of placeholders (number of curly bracket pairs) - 3. Alphabetically - """ - sorted_keys = sorted( - dataset_patterns, - key=lambda pattern: ( - -(cls._specificity(pattern)), - -pattern.count("{"), - pattern, - ), - ) - catch_all = [ - pattern for pattern in sorted_keys if cls._specificity(pattern) == 0 - ] - if len(catch_all) > 1: - raise DatasetError( - f"Multiple catch-all patterns found in the catalog: {', '.join(catch_all)}. Only one catch-all pattern is allowed, remove the extras." - ) - return {key: dataset_patterns[key] for key in sorted_keys} - - @staticmethod - def _specificity(pattern: str) -> int: - """Helper function to check the length of exactly matched characters not inside brackets. 
- - Example: - :: - - >>> specificity("{namespace}.companies") = 10 - >>> specificity("{namespace}.{dataset}") = 1 - >>> specificity("france.companies") = 16 - """ - # Remove all the placeholders from the pattern and count the number of remaining chars - result = re.sub(r"\{.*?\}", "", pattern) - return len(result) - def _get_dataset( self, dataset_name: str, version: Version | None = None, suggest: bool = True, ) -> AbstractDataset: - matched_pattern = self._match_pattern( - self._dataset_patterns, dataset_name - ) or self._match_pattern(self._default_pattern, dataset_name) - if dataset_name not in self._datasets and matched_pattern: - # If the dataset is a patterned dataset, materialise it and add it to - # the catalog - config_copy = copy.deepcopy( - self._dataset_patterns.get(matched_pattern) - or self._default_pattern.get(matched_pattern) - or {} - ) - dataset_config = self._resolve_config( - dataset_name, matched_pattern, config_copy - ) - dataset = AbstractDataset.from_config( + ds_config = self._config_resolver.resolve_pattern(dataset_name) + + if dataset_name not in self._datasets and ds_config: + ds = AbstractDataset.from_config( dataset_name, - dataset_config, + ds_config, self._load_versions.get(dataset_name), self._save_version, ) - if ( - self._specificity(matched_pattern) == 0 - and matched_pattern in self._default_pattern - ): - self._logger.warning( - "Config from the dataset factory pattern '%s' in the catalog will be used to " - "override the default dataset creation for '%s'", - matched_pattern, - dataset_name, - ) - - self.add(dataset_name, dataset) + self.add(dataset_name, ds) if dataset_name not in self._datasets: error_msg = f"Dataset '{dataset_name}' not found in the catalog" @@ -462,7 +345,9 @@ def _get_dataset( suggestions = ", ".join(matches) error_msg += f" - did you mean one of these instead: {suggestions}" raise DatasetNotFoundError(error_msg) + dataset = self._datasets[dataset_name] + if version and isinstance(dataset, AbstractVersionedDataset): # we only want to return a similar-looking dataset, # not modify the one stored in the current catalog @@ -470,41 +355,6 @@ def _get_dataset( return dataset - def __contains__(self, dataset_name: str) -> bool: - """Check if an item is in the catalog as a materialised dataset or pattern""" - matched_pattern = self._match_pattern(self._dataset_patterns, dataset_name) - if dataset_name in self._datasets or matched_pattern: - return True - return False - - @classmethod - def _resolve_config( - cls, - dataset_name: str, - matched_pattern: str, - config: dict, - ) -> dict[str, Any]: - """Get resolved AbstractDataset from a factory config""" - result = parse(matched_pattern, dataset_name) - # Resolve the factory config for the dataset - if isinstance(config, dict): - for key, value in config.items(): - config[key] = cls._resolve_config(dataset_name, matched_pattern, value) - elif isinstance(config, (list, tuple)): - config = [ - cls._resolve_config(dataset_name, matched_pattern, value) - for value in config - ] - elif isinstance(config, str) and "}" in config: - try: - config = str(config).format_map(result.named) - except KeyError as exc: - raise DatasetError( - f"Unable to resolve '{config}' from the pattern '{matched_pattern}'. Keys used in the configuration " - f"should be present in the dataset factory pattern." - ) from exc - return config - def load(self, name: str, version: str | None = None) -> Any: """Loads a registered data set. 
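# --- Usage sketch (not part of this diff) ---
# Patterned datasets are still materialised lazily, but resolution now goes through
# ``config_resolver.resolve_pattern``. This mirrors
# tests/io/test_data_catalog.py::test_match_added_to_datasets_on_get and assumes
# ``kedro-datasets`` is installed.
from kedro.io import DataCatalog

catalog = DataCatalog.from_config(
    {"{brand}_cars": {"type": "pandas.CSVDataset", "filepath": "data/{brand}_cars.csv"}}
)
assert "tesla_cars" in catalog            # __contains__ consults the resolver's patterns
assert "tesla_cars" not in catalog.list()

resolved = catalog.config_resolver.resolve_pattern("tesla_cars")
assert resolved["filepath"] == "data/tesla_cars.csv"

catalog._get_dataset("tesla_cars")        # first access materialises and registers it
assert "tesla_cars" in catalog.list()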
@@ -619,7 +469,10 @@ def release(self, name: str) -> None: dataset.release() def add( - self, dataset_name: str, dataset: AbstractDataset, replace: bool = False + self, + dataset_name: str, + dataset: AbstractDataset, + replace: bool = False, ) -> None: """Adds a new ``AbstractDataset`` object to the ``DataCatalog``. @@ -657,7 +510,9 @@ def add( self.datasets = _FrozenDatasets(self.datasets, {dataset_name: dataset}) def add_all( - self, datasets: dict[str, AbstractDataset], replace: bool = False + self, + datasets: dict[str, AbstractDataset], + replace: bool = False, ) -> None: """Adds a group of new data sets to the ``DataCatalog``. @@ -688,8 +543,8 @@ def add_all( >>> >>> assert catalog.list() == ["cars", "planes", "boats"] """ - for name, dataset in datasets.items(): - self.add(name, dataset, replace) + for ds_name, ds in datasets.items(): + self.add(ds_name, ds, replace) def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None: """Add datasets to the ``DataCatalog`` using the data provided through the `feed_dict`. @@ -726,13 +581,13 @@ def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> Non >>> >>> assert catalog.load("data_csv_dataset").equals(df) """ - for dataset_name in feed_dict: - if isinstance(feed_dict[dataset_name], AbstractDataset): - dataset = feed_dict[dataset_name] - else: - dataset = MemoryDataset(data=feed_dict[dataset_name]) # type: ignore[abstract] - - self.add(dataset_name, dataset, replace) + for ds_name, ds_data in feed_dict.items(): + dataset = ( + ds_data + if isinstance(ds_data, AbstractDataset) + else MemoryDataset(data=ds_data) # type: ignore[abstract] + ) + self.add(ds_name, dataset, replace) def list(self, regex_search: str | None = None) -> list[str]: """ @@ -777,7 +632,7 @@ def list(self, regex_search: str | None = None) -> list[str]: raise SyntaxError( f"Invalid regular expression provided: '{regex_search}'" ) from exc - return [dset_name for dset_name in self._datasets if pattern.search(dset_name)] + return [ds_name for ds_name in self._datasets if pattern.search(ds_name)] def shallow_copy( self, extra_dataset_patterns: Patterns | None = None @@ -787,26 +642,15 @@ def shallow_copy( Returns: Copy of the current object. """ - if not self._default_pattern and extra_dataset_patterns: - unsorted_dataset_patterns = { - **self._dataset_patterns, - **extra_dataset_patterns, - } - dataset_patterns = self._sort_patterns(unsorted_dataset_patterns) - else: - dataset_patterns = self._dataset_patterns + if extra_dataset_patterns: + self._config_resolver.add_runtime_patterns(extra_dataset_patterns) return self.__class__( datasets=self._datasets, - dataset_patterns=dataset_patterns, + dataset_patterns=self._config_resolver._dataset_patterns, + default_pattern=self._config_resolver._default_pattern, load_versions=self._load_versions, save_version=self._save_version, - default_pattern=self._default_pattern, - ) - - def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] - return (self._datasets, self._dataset_patterns) == ( - other._datasets, - other._dataset_patterns, + config_resolver=self._config_resolver, ) def confirm(self, name: str) -> None: diff --git a/kedro/io/kedro_data_catalog.py b/kedro/io/kedro_data_catalog.py new file mode 100644 index 0000000000..ce06e34aac --- /dev/null +++ b/kedro/io/kedro_data_catalog.py @@ -0,0 +1,346 @@ +"""``KedroDataCatalog`` stores instances of ``AbstractDataset`` implementations to +provide ``load`` and ``save`` capabilities from anywhere in the program. 
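# --- Usage sketch (not part of this diff) ---
# ``add_feed_dict`` still wraps raw values in ``MemoryDataset``, and ``shallow_copy``
# now registers ``extra_dataset_patterns`` on the shared resolver via
# ``add_runtime_patterns`` instead of re-sorting patterns locally. The runtime pattern
# body below is hypothetical.
from kedro.io import DataCatalog, MemoryDataset

catalog = DataCatalog()
catalog.add_feed_dict({"params:alpha": 0.1, "cached": MemoryDataset(data=[1, 2, 3])})
assert catalog.load("params:alpha") == 0.1

copy_ = catalog.shallow_copy(extra_dataset_patterns={"{default}": {"type": "MemoryDataset"}})
# The copy shares the same resolver, which now also holds the runtime pattern.
assert copy_.config_resolver is catalog.config_resolver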
To +use a ``KedroDataCatalog``, you need to instantiate it with a dictionary of datasets. +Then it will act as a single point of reference for your calls, relaying load and +save functions to the underlying datasets. +""" + +from __future__ import annotations + +import copy +import difflib +import logging +import re +from typing import Any + +from kedro.io.catalog_config_resolver import CatalogConfigResolver, Patterns +from kedro.io.core import ( + AbstractDataset, + AbstractVersionedDataset, + CatalogProtocol, + DatasetAlreadyExistsError, + DatasetError, + DatasetNotFoundError, + Version, + generate_timestamp, +) +from kedro.io.memory_dataset import MemoryDataset +from kedro.utils import _format_rich, _has_rich_handler + + +class KedroDataCatalog(CatalogProtocol): + def __init__( + self, + datasets: dict[str, AbstractDataset] | None = None, + raw_data: dict[str, Any] | None = None, + config_resolver: CatalogConfigResolver | None = None, + load_versions: dict[str, str] | None = None, + save_version: str | None = None, + ) -> None: + """``KedroDataCatalog`` stores instances of ``AbstractDataset`` + implementations to provide ``load`` and ``save`` capabilities from + anywhere in the program. To use a ``KedroDataCatalog``, you need to + instantiate it with a dictionary of datasets. Then it will act as a + single point of reference for your calls, relaying load and save + functions to the underlying datasets. + + Args: + datasets: A dictionary of dataset names and dataset instances. + raw_data: A dictionary with data to be added in memory as `MemoryDataset`` instances. + Keys represent dataset names and the values are raw data. + config_resolver: An instance of CatalogConfigResolver to resolve dataset patterns and configurations. + load_versions: A mapping between dataset names and versions + to load. Has no effect on datasets without enabled versioning. + save_version: Version string to be used for ``save`` operations + by all datasets with enabled versioning. It must: a) be a + case-insensitive string that conforms with operating system + filename limitations, b) always return the latest version when + sorted in lexicographical order. + """ + self._config_resolver = config_resolver or CatalogConfigResolver() + self._datasets = datasets or {} + self._load_versions = load_versions or {} + self._save_version = save_version + + self._use_rich_markup = _has_rich_handler() + + for ds_name, ds_config in self._config_resolver.config.items(): + self._add_from_config(ds_name, ds_config) + + if raw_data: + self.add_data(raw_data) + + @property + def datasets(self) -> dict[str, Any]: + return copy.copy(self._datasets) + + @datasets.setter + def datasets(self, value: Any) -> None: + raise AttributeError( + "Operation not allowed. Please use KedroDataCatalog.add() instead." 
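# --- Usage sketch (not part of this diff) ---
# ``KedroDataCatalog`` drops ``_FrozenDatasets``: ``datasets`` is a read-only property
# returning a copy, and mutation goes through ``add()``.
from kedro.io import MemoryDataset
from kedro.io.kedro_data_catalog import KedroDataCatalog

catalog = KedroDataCatalog(datasets={"cars": MemoryDataset(data={"speed": 100})})
snapshot = catalog.datasets              # a copy, not the internal dict
snapshot["boats"] = MemoryDataset(data={})
assert "boats" not in catalog.list()     # the catalog itself is unchanged

try:
    catalog.datasets = {}                # the setter is disabled
except AttributeError as err:
    print(err)                           # Operation not allowed. Please use KedroDataCatalog.add() instead.

catalog.add("boats", MemoryDataset(data={"length": 7}))
assert "boats" in catalog.list()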
+ ) + + @property + def config_resolver(self) -> CatalogConfigResolver: + return self._config_resolver + + def __repr__(self) -> str: + return repr(self._datasets) + + def __contains__(self, dataset_name: str) -> bool: + """Check if an item is in the catalog as a materialised dataset or pattern""" + return ( + dataset_name in self._datasets + or self._config_resolver.match_pattern(dataset_name) is not None + ) + + def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] + return (self._datasets, self._config_resolver.list_patterns()) == ( + other._datasets, + other.config_resolver.list_patterns(), + ) + + @property + def _logger(self) -> logging.Logger: + return logging.getLogger(__name__) + + @classmethod + def from_config( + cls, + catalog: dict[str, dict[str, Any]] | None, + credentials: dict[str, dict[str, Any]] | None = None, + load_versions: dict[str, str] | None = None, + save_version: str | None = None, + ) -> KedroDataCatalog: + """Create a ``KedroDataCatalog`` instance from configuration. This is a + factory method used to provide developers with a way to instantiate + ``KedroDataCatalog`` with configuration parsed from configuration files. + """ + catalog = catalog or {} + config_resolver = CatalogConfigResolver(catalog, credentials) + save_version = save_version or generate_timestamp() + load_versions = load_versions or {} + + missing_keys = [ + ds_name + for ds_name in load_versions + if not ( + ds_name in config_resolver.config + or config_resolver.match_pattern(ds_name) + ) + ] + if missing_keys: + raise DatasetNotFoundError( + f"'load_versions' keys [{', '.join(sorted(missing_keys))}] " + f"are not found in the catalog." + ) + + return cls( + load_versions=load_versions, + save_version=save_version, + config_resolver=config_resolver, + ) + + @staticmethod + def _validate_dataset_config(ds_name: str, ds_config: Any) -> None: + if not isinstance(ds_config, dict): + raise DatasetError( + f"Catalog entry '{ds_name}' is not a valid dataset configuration. " + "\nHint: If this catalog entry is intended for variable interpolation, " + "make sure that the key is preceded by an underscore." + ) + + def _add_from_config(self, ds_name: str, ds_config: dict[str, Any]) -> None: + # TODO: Add lazy loading feature to store the configuration but not to init actual dataset + # TODO: Initialise actual dataset when load or save + self._validate_dataset_config(ds_name, ds_config) + ds = AbstractDataset.from_config( + ds_name, + ds_config, + self._load_versions.get(ds_name), + self._save_version, + ) + + self.add(ds_name, ds) + + def get_dataset( + self, ds_name: str, version: Version | None = None, suggest: bool = True + ) -> AbstractDataset: + """Get a dataset by name from an internal collection of datasets. + + If a dataset is not in the collection but matches any pattern + it is instantiated and added to the collection first, then returned. + + Args: + ds_name: A dataset name. + version: Optional argument for concrete dataset version to be loaded. + Works only with versioned datasets. + suggest: Optional argument whether to suggest fuzzy-matching datasets' names + in the DatasetNotFoundError message. + + Returns: + An instance of AbstractDataset. + + Raises: + DatasetNotFoundError: When a dataset with the given name + is not in the collection and do not match patterns. 
+ """ + if ds_name not in self._datasets: + ds_config = self._config_resolver.resolve_pattern(ds_name) + if ds_config: + self._add_from_config(ds_name, ds_config) + + dataset = self._datasets.get(ds_name, None) + + if dataset is None: + error_msg = f"Dataset '{ds_name}' not found in the catalog" + # Flag to turn on/off fuzzy-matching which can be time consuming and + # slow down plugins like `kedro-viz` + if suggest: + matches = difflib.get_close_matches(ds_name, self._datasets.keys()) + if matches: + suggestions = ", ".join(matches) + error_msg += f" - did you mean one of these instead: {suggestions}" + raise DatasetNotFoundError(error_msg) + + if version and isinstance(dataset, AbstractVersionedDataset): + # we only want to return a similar-looking dataset, + # not modify the one stored in the current catalog + dataset = dataset._copy(_version=version) + + return dataset + + def _get_dataset( + self, dataset_name: str, version: Version | None = None, suggest: bool = True + ) -> AbstractDataset: + # TODO: remove when removing old catalog + return self.get_dataset(dataset_name, version, suggest) + + def add( + self, ds_name: str, dataset: AbstractDataset, replace: bool = False + ) -> None: + """Adds a new ``AbstractDataset`` object to the ``KedroDataCatalog``.""" + if ds_name in self._datasets: + if replace: + self._logger.warning("Replacing dataset '%s'", ds_name) + else: + raise DatasetAlreadyExistsError( + f"Dataset '{ds_name}' has already been registered" + ) + self._datasets[ds_name] = dataset + + def list(self, regex_search: str | None = None) -> list[str]: + """ + List of all dataset names registered in the catalog. + This can be filtered by providing an optional regular expression + which will only return matching keys. + """ + + if regex_search is None: + return list(self._datasets.keys()) + + if not regex_search.strip(): + self._logger.warning("The empty string will not match any datasets") + return [] + + try: + pattern = re.compile(regex_search, flags=re.IGNORECASE) + except re.error as exc: + raise SyntaxError( + f"Invalid regular expression provided: '{regex_search}'" + ) from exc + return [ds_name for ds_name in self._datasets if pattern.search(ds_name)] + + def save(self, name: str, data: Any) -> None: + """Save data to a registered dataset.""" + dataset = self.get_dataset(name) + + self._logger.info( + "Saving data to %s (%s)...", + _format_rich(name, "dark_orange") if self._use_rich_markup else name, + type(dataset).__name__, + extra={"markup": True}, + ) + + dataset.save(data) + + def load(self, name: str, version: str | None = None) -> Any: + """Loads a registered dataset.""" + load_version = Version(version, None) if version else None + dataset = self.get_dataset(name, version=load_version) + + self._logger.info( + "Loading data from %s (%s)...", + _format_rich(name, "dark_orange") if self._use_rich_markup else name, + type(dataset).__name__, + extra={"markup": True}, + ) + + return dataset.load() + + def release(self, name: str) -> None: + """Release any cached data associated with a dataset + Args: + name: A dataset to be checked. + Raises: + DatasetNotFoundError: When a dataset with the given name + has not yet been registered. + """ + dataset = self.get_dataset(name) + dataset.release() + + def confirm(self, name: str) -> None: + """Confirm a dataset by its name. + Args: + name: Name of the dataset. + Raises: + DatasetError: When the dataset does not have `confirm` method. 
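# --- Usage sketch (not part of this diff) ---
# Raw values passed via ``raw_data`` become ``MemoryDataset`` entries, and
# ``get_dataset`` still offers fuzzy-matching suggestions for missing names.
from kedro.io import DatasetNotFoundError
from kedro.io.kedro_data_catalog import KedroDataCatalog

catalog = KedroDataCatalog(raw_data={"reviews": [4, 5, 3]})
assert catalog.load("reviews") == [4, 5, 3]

catalog.save("reviews", [5, 5, 5])
assert catalog.load("reviews") == [5, 5, 5]

try:
    catalog.get_dataset("review")        # deliberate typo
except DatasetNotFoundError as err:
    print(err)                           # "... - did you mean one of these instead: reviews"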
+ """ + self._logger.info("Confirming dataset '%s'", name) + dataset = self.get_dataset(name) + + if hasattr(dataset, "confirm"): + dataset.confirm() + else: + raise DatasetError(f"Dataset '{name}' does not have 'confirm' method") + + def add_data(self, data: dict[str, Any], replace: bool = False) -> None: + # This method was simplified to add memory datasets only, since + # adding AbstractDataset can be done via add() method + for ds_name, ds_data in data.items(): + self.add(ds_name, MemoryDataset(data=ds_data), replace) # type: ignore[abstract] + + def add_feed_dict(self, feed_dict: dict[str, Any], replace: bool = False) -> None: + # TODO: remove when removing old catalog + return self.add_data(feed_dict, replace) + + def shallow_copy( + self, extra_dataset_patterns: Patterns | None = None + ) -> KedroDataCatalog: + # TODO: remove when removing old catalog + """Returns a shallow copy of the current object. + + Returns: + Copy of the current object. + """ + if extra_dataset_patterns: + self._config_resolver.add_runtime_patterns(extra_dataset_patterns) + return self + + def exists(self, name: str) -> bool: + """Checks whether registered dataset exists by calling its `exists()` + method. Raises a warning and returns False if `exists()` is not + implemented. + + Args: + name: A dataset to be checked. + + Returns: + Whether the dataset output exists. + + """ + try: + dataset = self._get_dataset(name) + except DatasetNotFoundError: + return False + return dataset.exists() diff --git a/kedro/io/memory_dataset.py b/kedro/io/memory_dataset.py index 56ad92b7f2..1b4bb8a371 100644 --- a/kedro/io/memory_dataset.py +++ b/kedro/io/memory_dataset.py @@ -59,7 +59,7 @@ def __init__( if data is not _EMPTY: self.save.__wrapped__(self, data) # type: ignore[attr-defined] - def _load(self) -> Any: + def load(self) -> Any: if self._data is _EMPTY: raise DatasetError("Data for MemoryDataset has not been saved yet.") @@ -67,7 +67,7 @@ def _load(self) -> Any: data = _copy_with_mode(self._data, copy_mode=copy_mode) return data - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: copy_mode = self._copy_mode or _infer_copy_mode(data) self._data = _copy_with_mode(data, copy_mode=copy_mode) diff --git a/kedro/io/shared_memory_dataset.py b/kedro/io/shared_memory_dataset.py index abe89ff2b9..139180b578 100644 --- a/kedro/io/shared_memory_dataset.py +++ b/kedro/io/shared_memory_dataset.py @@ -36,10 +36,10 @@ def __getattr__(self, name: str) -> Any: raise AttributeError() return getattr(self.shared_memory_dataset, name) # pragma: no cover - def _load(self) -> Any: + def load(self) -> Any: return self.shared_memory_dataset.load() - def _save(self, data: Any) -> None: + def save(self, data: Any) -> None: """Calls save method of a shared MemoryDataset in SyncManager.""" try: self.shared_memory_dataset.save(data) diff --git a/kedro/runner/parallel_runner.py b/kedro/runner/parallel_runner.py index 62d7e1216b..d09601ff7e 100644 --- a/kedro/runner/parallel_runner.py +++ b/kedro/runner/parallel_runner.py @@ -22,7 +22,7 @@ ) from kedro.framework.project import settings from kedro.io import ( - DataCatalog, + CatalogProtocol, DatasetNotFoundError, MemoryDataset, SharedMemoryDataset, @@ -60,7 +60,7 @@ def _bootstrap_subprocess( def _run_node_synchronization( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, is_async: bool = False, session_id: str | None = None, package_name: str | None = None, @@ -73,7 +73,7 @@ def _run_node_synchronization( # noqa: PLR0913 Args: node: The ``Node`` 
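# --- Usage sketch (not part of this diff) ---
# ``MemoryDataset`` (and ``SharedMemoryDataset``) now override the public
# ``load``/``save`` methods rather than ``_load``/``_save``; callers are unaffected.
from kedro.io import MemoryDataset

ds = MemoryDataset(data={"a": 1})
assert ds.load() == {"a": 1}
ds.save({"a": 2})
assert ds.load() == {"a": 2}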
to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. session_id: The session id of the pipeline run. @@ -118,7 +118,7 @@ def __init__( cannot be larger than 61 and will be set to min(61, max_workers). is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets to SharedMemoryDataset for `ParallelRunner`. @@ -168,7 +168,7 @@ def _validate_nodes(cls, nodes: Iterable[Node]) -> None: ) @classmethod - def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline) -> None: + def _validate_catalog(cls, catalog: CatalogProtocol, pipeline: Pipeline) -> None: """Ensure that all data sets are serialisable and that we do not have any non proxied memory data sets being used as outputs as their content will not be synchronized across threads. @@ -213,7 +213,9 @@ def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline) -> None: f"MemoryDatasets" ) - def _set_manager_datasets(self, catalog: DataCatalog, pipeline: Pipeline) -> None: + def _set_manager_datasets( + self, catalog: CatalogProtocol, pipeline: Pipeline + ) -> None: for dataset in pipeline.datasets(): try: catalog.exists(dataset) @@ -240,7 +242,7 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int: def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -248,7 +250,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. diff --git a/kedro/runner/runner.py b/kedro/runner/runner.py index 2ffd0389e4..f3a0889909 100644 --- a/kedro/runner/runner.py +++ b/kedro/runner/runner.py @@ -21,7 +21,7 @@ from more_itertools import interleave from kedro.framework.hooks.manager import _NullPluginManager -from kedro.io import DataCatalog, MemoryDataset +from kedro.io import CatalogProtocol, MemoryDataset from kedro.pipeline import Pipeline if TYPE_CHECKING: @@ -45,7 +45,7 @@ def __init__( Args: is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets on the Runner instances. """ @@ -59,7 +59,7 @@ def _logger(self) -> logging.Logger: def run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager | None = None, session_id: str | None = None, ) -> dict[str, Any]: @@ -68,7 +68,7 @@ def run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. 
session_id: The id of the session. @@ -76,14 +76,13 @@ def run( ValueError: Raised when ``Pipeline`` inputs cannot be satisfied. Returns: - Any node outputs that cannot be processed by the ``DataCatalog``. + Any node outputs that cannot be processed by the catalog. These are returned in a dictionary, where the keys are defined by the node outputs. """ hook_or_null_manager = hook_manager or _NullPluginManager() - catalog = catalog.shallow_copy() # Check which datasets used in the pipeline are in the catalog or match # a pattern in the catalog @@ -95,7 +94,7 @@ def run( if unsatisfied: raise ValueError( - f"Pipeline input(s) {unsatisfied} not found in the DataCatalog" + f"Pipeline input(s) {unsatisfied} not found in the {catalog.__class__.__name__}" ) # Identify MemoryDataset in the catalog @@ -125,7 +124,7 @@ def run( return {ds_name: catalog.load(ds_name) for ds_name in free_outputs} def run_only_missing( - self, pipeline: Pipeline, catalog: DataCatalog, hook_manager: PluginManager + self, pipeline: Pipeline, catalog: CatalogProtocol, hook_manager: PluginManager ) -> dict[str, Any]: """Run only the missing outputs from the ``Pipeline`` using the datasets provided by ``catalog``, and save results back to the @@ -133,7 +132,7 @@ def run_only_missing( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. Raises: ValueError: Raised when ``Pipeline`` inputs cannot be @@ -141,7 +140,7 @@ def run_only_missing( Returns: Any node outputs that cannot be processed by the - ``DataCatalog``. These are returned in a dictionary, where + catalog. These are returned in a dictionary, where the keys are defined by the node outputs. """ @@ -165,7 +164,7 @@ def run_only_missing( def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -174,7 +173,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. @@ -185,7 +184,7 @@ def _suggest_resume_scenario( self, pipeline: Pipeline, done_nodes: Iterable[Node], - catalog: DataCatalog, + catalog: CatalogProtocol, ) -> None: """ Suggest a command to the user to resume a run after it fails. @@ -195,7 +194,7 @@ def _suggest_resume_scenario( Args: pipeline: the ``Pipeline`` of the run. done_nodes: the ``Node``s that executed successfully. - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. """ remaining_nodes = set(pipeline.nodes) - set(done_nodes) @@ -224,7 +223,7 @@ def _suggest_resume_scenario( def _find_nodes_to_resume_from( - pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: DataCatalog + pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol ) -> set[str]: """Given a collection of unfinished nodes in a pipeline using a certain catalog, find the node names to pass to pipeline.from_nodes() @@ -234,7 +233,7 @@ def _find_nodes_to_resume_from( Args: pipeline: the ``Pipeline`` to find starting nodes for. unfinished_nodes: collection of ``Node``s that have not finished yet - catalog: the ``DataCatalog`` of the run. 
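# --- Usage sketch (not part of this diff) ---
# The runners are now typed against ``CatalogProtocol`` instead of the concrete
# ``DataCatalog``, so any conforming implementation can be passed. The helper below is
# hypothetical; it only uses protocol methods (``exists``, ``release``).
from kedro.io import CatalogProtocol, DataCatalog, MemoryDataset
from kedro.io.kedro_data_catalog import KedroDataCatalog

def free_memory(catalog: CatalogProtocol, names: list[str]) -> None:
    """Release cached data for the given datasets on any protocol-compliant catalog."""
    for name in names:
        if catalog.exists(name):
            catalog.release(name)

free_memory(DataCatalog(datasets={"cars": MemoryDataset(data=[1])}), ["cars"])
free_memory(KedroDataCatalog(raw_data={"cars": [1]}), ["cars"])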
+ catalog: an implemented instance of ``CatalogProtocol`` of the run. Returns: Set of node names to pass to pipeline.from_nodes() to continue @@ -252,7 +251,7 @@ def _find_nodes_to_resume_from( def _find_all_nodes_for_resumed_pipeline( - pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: DataCatalog + pipeline: Pipeline, unfinished_nodes: Iterable[Node], catalog: CatalogProtocol ) -> set[Node]: """Breadth-first search approach to finding the complete set of ``Node``s which need to run to cover all unfinished nodes, @@ -262,7 +261,7 @@ def _find_all_nodes_for_resumed_pipeline( Args: pipeline: the ``Pipeline`` to analyze. unfinished_nodes: the iterable of ``Node``s which have not finished yet. - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. Returns: A set containing all input unfinished ``Node``s and all remaining @@ -310,12 +309,12 @@ def _nodes_with_external_inputs(nodes_of_interest: Iterable[Node]) -> set[Node]: return set(p_nodes_with_external_inputs.nodes) -def _enumerate_non_persistent_inputs(node: Node, catalog: DataCatalog) -> set[str]: +def _enumerate_non_persistent_inputs(node: Node, catalog: CatalogProtocol) -> set[str]: """Enumerate non-persistent input datasets of a ``Node``. Args: node: the ``Node`` to check the inputs of. - catalog: the ``DataCatalog`` of the run. + catalog: an implemented instance of ``CatalogProtocol`` of the run. Returns: Set of names of non-persistent inputs of given ``Node``. @@ -380,7 +379,7 @@ def _find_initial_node_group(pipeline: Pipeline, nodes: Iterable[Node]) -> list[ def run_node( node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, is_async: bool = False, session_id: str | None = None, @@ -389,7 +388,7 @@ def run_node( Args: node: The ``Node`` to run. - catalog: A ``DataCatalog`` containing the node's inputs and outputs. + catalog: An implemented instance of ``CatalogProtocol`` containing the node's inputs and outputs. hook_manager: The ``PluginManager`` to activate hooks. is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. 
@@ -423,7 +422,7 @@ def run_node( def _collect_inputs_from_hook( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, @@ -456,7 +455,7 @@ def _collect_inputs_from_hook( # noqa: PLR0913 def _call_node_run( # noqa: PLR0913 node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, inputs: dict[str, Any], is_async: bool, hook_manager: PluginManager, @@ -487,7 +486,7 @@ def _call_node_run( # noqa: PLR0913 def _run_node_sequential( node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> Node: @@ -534,7 +533,7 @@ def _run_node_sequential( def _run_node_async( node: Node, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> Node: diff --git a/kedro/runner/sequential_runner.py b/kedro/runner/sequential_runner.py index 48dac3cd54..c888e737cf 100644 --- a/kedro/runner/sequential_runner.py +++ b/kedro/runner/sequential_runner.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from pluggy import PluginManager - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline @@ -34,7 +34,7 @@ def __init__( Args: is_async: If True, the node inputs and outputs are loaded and saved asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets to MemoryDataset for `SequentialRunner`. @@ -48,7 +48,7 @@ def __init__( def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -56,7 +56,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. session_id: The id of the session. diff --git a/kedro/runner/thread_runner.py b/kedro/runner/thread_runner.py index b4751a602a..5ad13b9153 100644 --- a/kedro/runner/thread_runner.py +++ b/kedro/runner/thread_runner.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from pluggy import PluginManager - from kedro.io import DataCatalog + from kedro.io import CatalogProtocol from kedro.pipeline import Pipeline from kedro.pipeline.node import Node @@ -43,7 +43,7 @@ def __init__( is_async: If True, set to False, because `ThreadRunner` doesn't support loading and saving the node inputs and outputs asynchronously with threads. Defaults to False. - extra_dataset_patterns: Extra dataset factory patterns to be added to the DataCatalog + extra_dataset_patterns: Extra dataset factory patterns to be added to the catalog during the run. This is used to set the default datasets to MemoryDataset for `ThreadRunner`. @@ -87,7 +87,7 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int: def _run( self, pipeline: Pipeline, - catalog: DataCatalog, + catalog: CatalogProtocol, hook_manager: PluginManager, session_id: str | None = None, ) -> None: @@ -95,7 +95,7 @@ def _run( Args: pipeline: The ``Pipeline`` to run. - catalog: The ``DataCatalog`` from which to fetch data. + catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data. hook_manager: The ``PluginManager`` to activate hooks. 
session_id: The id of the session. diff --git a/kedro/templates/project/hooks/utils.py b/kedro/templates/project/hooks/utils.py index 947c234f19..c112d72bbb 100644 --- a/kedro/templates/project/hooks/utils.py +++ b/kedro/templates/project/hooks/utils.py @@ -20,7 +20,9 @@ ] # Configuration key for documentation dependencies -docs_pyproject_requirements = ["project.optional-dependencies"] # For pyproject.toml +docs_pyproject_requirements = ["project.optional-dependencies.docs"] # For pyproject.toml +# Configuration key for linting and testing dependencies +dev_pyproject_requirements = ["project.optional-dependencies.dev"] # For pyproject.toml # Requirements for example pipelines example_pipeline_requirements = "seaborn~=0.12.1\nscikit-learn~=1.0\n" @@ -191,12 +193,14 @@ def setup_template_tools( python_package_name (str): The name of the python package. example_pipeline (str): 'True' if example pipeline was selected """ + + if "Linting" not in selected_tools_list and "Testing" not in selected_tools_list: + _remove_from_toml(pyproject_file_path, dev_pyproject_requirements) + if "Linting" not in selected_tools_list: - _remove_from_file(requirements_file_path, lint_requirements) _remove_from_toml(pyproject_file_path, lint_pyproject_requirements) if "Testing" not in selected_tools_list: - _remove_from_file(requirements_file_path, test_requirements) _remove_from_toml(pyproject_file_path, test_pyproject_requirements) _remove_dir(current_dir / "tests") diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml b/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml index f22a91242f..b2ab54c3bb 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/pyproject.toml @@ -24,6 +24,12 @@ docs = [ "Jinja2<3.2.0", "myst-parser>=1.0,<2.1" ] +dev = [ + "pytest-cov~=3.0", + "pytest-mock>=1.7.1, <2.0", + "pytest~=7.2", + "ruff~=0.1.8" +] [tool.setuptools.dynamic] diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt b/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt index 9301f4e3f3..1be43016fb 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/requirements.txt @@ -2,7 +2,3 @@ ipython>=8.10 jupyterlab>=3.0 notebook kedro~={{ cookiecutter.kedro_version }} -pytest-cov~=3.0 -pytest-mock>=1.7.1, <2.0 -pytest~=7.2 -ruff~=0.1.8 diff --git a/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py b/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py index eb57d1908e..c7b3cf08a8 100644 --- a/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py +++ b/kedro/templates/project/{{ cookiecutter.repo_name }}/tests/test_run.py @@ -12,7 +12,7 @@ import pytest -from kedro.config import ConfigLoader +from kedro.config import OmegaConfigLoader from kedro.framework.context import KedroContext from kedro.framework.hooks import _create_hook_manager from kedro.framework.project import settings @@ -20,7 +20,7 @@ @pytest.fixture def config_loader(): - return ConfigLoader(conf_source=str(Path.cwd() / settings.CONF_SOURCE)) + return OmegaConfigLoader(conf_source=str(Path.cwd() / settings.CONF_SOURCE)) @pytest.fixture @@ -28,6 +28,7 @@ def project_context(config_loader): return KedroContext( package_name="{{ cookiecutter.python_package }}", project_path=Path.cwd(), + env="local", config_loader=config_loader, 
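# --- Usage sketch (not part of this diff) ---
# In a freshly generated project template, the lint/test requirements now live under
# ``[project.optional-dependencies].dev`` in pyproject.toml rather than in
# requirements.txt. Reading them back requires Python 3.11+ for ``tomllib``.
import tomllib

with open("pyproject.toml", "rb") as f:
    pyproject = tomllib.load(f)

dev_deps = pyproject["project"]["optional-dependencies"]["dev"]
assert "ruff~=0.1.8" in dev_deps and "pytest~=7.2" in dev_deps
# They can then be installed with: pip install ".[dev]"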
hook_manager=_create_hook_manager(), ) diff --git a/pyproject.toml b/pyproject.toml index 8b7b4cb09b..d9ebbfd70b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,7 +134,7 @@ omit = [ "kedro/runner/parallel_runner.py", "*/site-packages/*", ] -exclude_also = ["raise NotImplementedError", "if TYPE_CHECKING:"] +exclude_also = ["raise NotImplementedError", "if TYPE_CHECKING:", "class CatalogProtocol"] [tool.pytest.ini_options] addopts=""" diff --git a/tests/framework/cli/test_catalog.py b/tests/framework/cli/test_catalog.py index f34034296e..8905da9c94 100644 --- a/tests/framework/cli/test_catalog.py +++ b/tests/framework/cli/test_catalog.py @@ -490,7 +490,6 @@ def test_rank_catalog_factories( mocked_context.catalog = DataCatalog.from_config( fake_catalog_with_overlapping_factories ) - print("!!!!", mocked_context.catalog._dataset_patterns) result = CliRunner().invoke( fake_project_cli, ["catalog", "rank"], obj=fake_metadata ) @@ -544,10 +543,11 @@ def test_catalog_resolve( "catalog": fake_catalog_config, "credentials": fake_credentials_config, } + mocked_context._get_config_credentials.return_value = fake_credentials_config mocked_context.catalog = DataCatalog.from_config( catalog=fake_catalog_config, credentials=fake_credentials_config ) - placeholder_ds = mocked_context.catalog._dataset_patterns.keys() + placeholder_ds = mocked_context.catalog.config_resolver.list_patterns() pipeline_datasets = {"csv_example", "parquet_example", "explicit_dataset"} mocker.patch.object( diff --git a/tests/framework/cli/test_starters.py b/tests/framework/cli/test_starters.py index 32f618d68f..7f2641da10 100644 --- a/tests/framework/cli/test_starters.py +++ b/tests/framework/cli/test_starters.py @@ -147,17 +147,11 @@ def _assert_requirements_ok( assert "Congratulations!" 
in result.output assert f"has been created in the directory \n{root_path}" in result.output - requirements_file_path = root_path / "requirements.txt" pyproject_file_path = root_path / "pyproject.toml" tools_list = _parse_tools_input(tools) if "1" in tools_list: - with open(requirements_file_path) as requirements_file: - requirements = requirements_file.read() - - assert "ruff" in requirements - pyproject_config = toml.load(pyproject_file_path) expected = { "tool": { @@ -171,15 +165,11 @@ def _assert_requirements_ok( } } assert expected["tool"]["ruff"] == pyproject_config["tool"]["ruff"] + assert ( + "ruff~=0.1.8" in pyproject_config["project"]["optional-dependencies"]["dev"] + ) if "2" in tools_list: - with open(requirements_file_path) as requirements_file: - requirements = requirements_file.read() - - assert "pytest-cov~=3.0" in requirements - assert "pytest-mock>=1.7.1, <2.0" in requirements - assert "pytest~=7.2" in requirements - pyproject_config = toml.load(pyproject_file_path) expected = { "pytest": { @@ -198,6 +188,18 @@ def _assert_requirements_ok( assert expected["pytest"] == pyproject_config["tool"]["pytest"] assert expected["coverage"] == pyproject_config["tool"]["coverage"] + assert ( + "pytest-cov~=3.0" + in pyproject_config["project"]["optional-dependencies"]["dev"] + ) + assert ( + "pytest-mock>=1.7.1, <2.0" + in pyproject_config["project"]["optional-dependencies"]["dev"] + ) + assert ( + "pytest~=7.2" in pyproject_config["project"]["optional-dependencies"]["dev"] + ) + if "4" in tools_list: pyproject_config = toml.load(pyproject_file_path) expected = { diff --git a/tests/framework/context/test_context.py b/tests/framework/context/test_context.py index 61e4bbaa6f..ea62cb04c9 100644 --- a/tests/framework/context/test_context.py +++ b/tests/framework/context/test_context.py @@ -261,7 +261,7 @@ def test_wrong_catalog_type(self, mock_settings_file_bad_data_catalog_class): pattern = ( "Invalid value 'tests.framework.context.test_context.BadCatalog' received " "for setting 'DATA_CATALOG_CLASS'. " - "It must be a subclass of 'kedro.io.data_catalog.DataCatalog'." + "It must implement 'kedro.io.core.CatalogProtocol'." 
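# --- Usage sketch (not part of this diff) ---
# As the updated error message above suggests, ``DATA_CATALOG_CLASS`` is now validated
# against ``CatalogProtocol`` rather than requiring a ``DataCatalog`` subclass, so a
# project's settings.py could, for example, point at the new implementation:
from kedro.io.kedro_data_catalog import KedroDataCatalog

DATA_CATALOG_CLASS = KedroDataCatalog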
) mock_settings = _ProjectSettings( settings_file=str(mock_settings_file_bad_data_catalog_class) diff --git a/tests/framework/session/test_session.py b/tests/framework/session/test_session.py index 83550f3a56..086d581045 100644 --- a/tests/framework/session/test_session.py +++ b/tests/framework/session/test_session.py @@ -693,7 +693,7 @@ def test_run_thread_runner( } mocker.patch("kedro.framework.session.session.pipelines", pipelines_ret) mocker.patch( - "kedro.io.data_catalog.DataCatalog._match_pattern", + "kedro.io.data_catalog.CatalogConfigResolver.match_pattern", return_value=match_pattern, ) diff --git a/tests/io/conftest.py b/tests/io/conftest.py index 2cc38aa1ea..9abce4c83e 100644 --- a/tests/io/conftest.py +++ b/tests/io/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest +from kedro_datasets.pandas import CSVDataset @pytest.fixture @@ -21,3 +22,68 @@ def input_data(request): @pytest.fixture def new_data(): return pd.DataFrame({"col1": ["a", "b"], "col2": ["c", "d"], "col3": ["e", "f"]}) + + +@pytest.fixture +def filepath(tmp_path): + return (tmp_path / "some" / "dir" / "test.csv").as_posix() + + +@pytest.fixture +def dataset(filepath): + return CSVDataset(filepath=filepath, save_args={"index": False}) + + +@pytest.fixture +def correct_config(filepath): + return { + "catalog": { + "boats": {"type": "pandas.CSVDataset", "filepath": filepath}, + "cars": { + "type": "pandas.CSVDataset", + "filepath": "s3://test_bucket/test_file.csv", + "credentials": "s3_credentials", + }, + }, + "credentials": { + "s3_credentials": {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + }, + } + + +@pytest.fixture +def correct_config_with_nested_creds(correct_config): + correct_config["catalog"]["cars"]["credentials"] = { + "client_kwargs": {"credentials": "other_credentials"}, + "key": "secret", + } + correct_config["credentials"]["other_credentials"] = { + "client_kwargs": { + "aws_access_key_id": "OTHER_FAKE_ACCESS_KEY", + "aws_secret_access_key": "OTHER_FAKE_SECRET_KEY", + } + } + return correct_config + + +@pytest.fixture +def bad_config(filepath): + return { + "bad": {"type": "tests.io.test_data_catalog.BadDataset", "filepath": filepath} + } + + +@pytest.fixture +def correct_config_with_tracking_ds(tmp_path): + boat_path = (tmp_path / "some" / "dir" / "test.csv").as_posix() + plane_path = (tmp_path / "some" / "dir" / "metrics.json").as_posix() + return { + "catalog": { + "boats": { + "type": "pandas.CSVDataset", + "filepath": boat_path, + "versioned": True, + }, + "planes": {"type": "tracking.MetricsDataset", "filepath": plane_path}, + }, + } diff --git a/tests/io/test_data_catalog.py b/tests/io/test_data_catalog.py index dbec57e64d..a552d8959c 100644 --- a/tests/io/test_data_catalog.py +++ b/tests/io/test_data_catalog.py @@ -29,64 +29,6 @@ ) -@pytest.fixture -def filepath(tmp_path): - return (tmp_path / "some" / "dir" / "test.csv").as_posix() - - -@pytest.fixture -def dummy_dataframe(): - return pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}) - - -@pytest.fixture -def sane_config(filepath): - return { - "catalog": { - "boats": {"type": "pandas.CSVDataset", "filepath": filepath}, - "cars": { - "type": "pandas.CSVDataset", - "filepath": "s3://test_bucket/test_file.csv", - "credentials": "s3_credentials", - }, - }, - "credentials": { - "s3_credentials": {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} - }, - } - - -@pytest.fixture -def sane_config_with_nested_creds(sane_config): - sane_config["catalog"]["cars"]["credentials"] = { - 
"client_kwargs": {"credentials": "other_credentials"}, - "key": "secret", - } - sane_config["credentials"]["other_credentials"] = { - "client_kwargs": { - "aws_access_key_id": "OTHER_FAKE_ACCESS_KEY", - "aws_secret_access_key": "OTHER_FAKE_SECRET_KEY", - } - } - return sane_config - - -@pytest.fixture -def sane_config_with_tracking_ds(tmp_path): - boat_path = (tmp_path / "some" / "dir" / "test.csv").as_posix() - plane_path = (tmp_path / "some" / "dir" / "metrics.json").as_posix() - return { - "catalog": { - "boats": { - "type": "pandas.CSVDataset", - "filepath": boat_path, - "versioned": True, - }, - "planes": {"type": "tracking.MetricsDataset", "filepath": plane_path}, - }, - } - - @pytest.fixture def config_with_dataset_factories(): return { @@ -180,11 +122,6 @@ def config_with_dataset_factories_only_patterns_no_default( return config_with_dataset_factories_only_patterns -@pytest.fixture -def dataset(filepath): - return CSVDataset(filepath=filepath, save_args={"index": False}) - - @pytest.fixture def multi_catalog(): csv = CSVDataset(filepath="abc.csv") @@ -220,21 +157,14 @@ def _describe(self): return {} -@pytest.fixture -def bad_config(filepath): - return { - "bad": {"type": "tests.io.test_data_catalog.BadDataset", "filepath": filepath} - } - - @pytest.fixture def data_catalog(dataset): return DataCatalog(datasets={"test": dataset}) @pytest.fixture -def data_catalog_from_config(sane_config): - return DataCatalog.from_config(**sane_config) +def data_catalog_from_config(correct_config): + return DataCatalog.from_config(**correct_config) class TestDataCatalog: @@ -468,78 +398,78 @@ def test_key_completions(self, data_catalog_from_config): class TestDataCatalogFromConfig: - def test_from_sane_config(self, data_catalog_from_config, dummy_dataframe): + def test_from_correct_config(self, data_catalog_from_config, dummy_dataframe): """Test populating the data catalog from config""" data_catalog_from_config.save("boats", dummy_dataframe) reloaded_df = data_catalog_from_config.load("boats") assert_frame_equal(reloaded_df, dummy_dataframe) - def test_config_missing_type(self, sane_config): + def test_config_missing_type(self, correct_config): """Check the error if type attribute is missing for some data set(s) in the config""" - del sane_config["catalog"]["boats"]["type"] + del correct_config["catalog"]["boats"]["type"] pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" "'type' is missing from dataset catalog configuration" ) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_module(self, sane_config): + def test_config_invalid_module(self, correct_config): """Check the error if the type points to nonexistent module""" - sane_config["catalog"]["boats"]["type"] = ( + correct_config["catalog"]["boats"]["type"] = ( "kedro.invalid_module_name.io.CSVDataset" ) error_msg = "Class 'kedro.invalid_module_name.io.CSVDataset' not found" with pytest.raises(DatasetError, match=re.escape(error_msg)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_relative_import(self, sane_config): + def test_config_relative_import(self, correct_config): """Check the error if the type points to a relative import""" - sane_config["catalog"]["boats"]["type"] = ".CSVDatasetInvalid" + correct_config["catalog"]["boats"]["type"] = ".CSVDatasetInvalid" pattern = "'type' class path does not support relative paths" with 
pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_import_kedro_datasets(self, sane_config, mocker): + def test_config_import_kedro_datasets(self, correct_config, mocker): """Test kedro_datasets default path to the dataset class""" # Spy _load_obj because kedro_datasets is not installed and we can't import it. import kedro.io.core spy = mocker.spy(kedro.io.core, "_load_obj") - parse_dataset_definition(sane_config["catalog"]["boats"]) + parse_dataset_definition(correct_config["catalog"]["boats"]) for prefix, call_args in zip(_DEFAULT_PACKAGES, spy.call_args_list): # In Python 3.7 call_args.args is not available thus we access the call # arguments with less meaningful index. # The 1st index returns a tuple, the 2nd index return the name of module. assert call_args[0][0] == f"{prefix}pandas.CSVDataset" - def test_config_import_extras(self, sane_config): + def test_config_import_extras(self, correct_config): """Test kedro_datasets default path to the dataset class""" - sane_config["catalog"]["boats"]["type"] = "pandas.CSVDataset" - assert DataCatalog.from_config(**sane_config) + correct_config["catalog"]["boats"]["type"] = "pandas.CSVDataset" + assert DataCatalog.from_config(**correct_config) - def test_config_missing_class(self, sane_config): + def test_config_missing_class(self, correct_config): """Check the error if the type points to nonexistent class""" - sane_config["catalog"]["boats"]["type"] = "kedro.io.CSVDatasetInvalid" + correct_config["catalog"]["boats"]["type"] = "kedro.io.CSVDatasetInvalid" pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" "Class 'kedro.io.CSVDatasetInvalid' not found, is this a typo?" ) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) @pytest.mark.skipif( sys.version_info < (3, 9), reason="for python 3.8 kedro-datasets version 1.8 is used which has the old spelling", ) - def test_config_incorrect_spelling(self, sane_config): + def test_config_incorrect_spelling(self, correct_config): """Check hint if the type uses the old DataSet spelling""" - sane_config["catalog"]["boats"]["type"] = "pandas.CSVDataSet" + correct_config["catalog"]["boats"]["type"] = "pandas.CSVDataSet" pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" @@ -548,63 +478,63 @@ def test_config_incorrect_spelling(self, sane_config): " make sure that the dataset name uses the `Dataset` spelling instead of `DataSet`." 
) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_dataset(self, sane_config): + def test_config_invalid_dataset(self, correct_config): """Check the error if the type points to invalid class""" - sane_config["catalog"]["boats"]["type"] = "DataCatalog" + correct_config["catalog"]["boats"]["type"] = "DataCatalog" pattern = ( "An exception occurred when parsing config for dataset 'boats':\n" "Dataset type 'kedro.io.data_catalog.DataCatalog' is invalid: " "all data set types must extend 'AbstractDataset'" ) with pytest.raises(DatasetError, match=re.escape(pattern)): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_arguments(self, sane_config): + def test_config_invalid_arguments(self, correct_config): """Check the error if the data set config contains invalid arguments""" - sane_config["catalog"]["boats"]["save_and_load_args"] = False + correct_config["catalog"]["boats"]["save_and_load_args"] = False pattern = ( r"Dataset 'boats' must only contain arguments valid for " r"the constructor of '.*CSVDataset'" ) with pytest.raises(DatasetError, match=pattern): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_config_invalid_dataset_config(self, sane_config): - sane_config["catalog"]["invalid_entry"] = "some string" + def test_config_invalid_dataset_config(self, correct_config): + correct_config["catalog"]["invalid_entry"] = "some string" pattern = ( "Catalog entry 'invalid_entry' is not a valid dataset configuration. " "\nHint: If this catalog entry is intended for variable interpolation, " "make sure that the key is preceded by an underscore." 
) with pytest.raises(DatasetError, match=pattern): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) def test_empty_config(self): """Test empty config""" assert DataCatalog.from_config(None) - def test_missing_credentials(self, sane_config): + def test_missing_credentials(self, correct_config): """Check the error if credentials can't be located""" - sane_config["catalog"]["cars"]["credentials"] = "missing" + correct_config["catalog"]["cars"]["credentials"] = "missing" with pytest.raises(KeyError, match=r"Unable to find credentials \'missing\'"): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_link_credentials(self, sane_config, mocker): + def test_link_credentials(self, correct_config, mocker): """Test credentials being linked to the relevant data set""" mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") - config = deepcopy(sane_config) + config = deepcopy(correct_config) del config["catalog"]["boats"] DataCatalog.from_config(**config) - expected_client_kwargs = sane_config["credentials"]["s3_credentials"] + expected_client_kwargs = correct_config["credentials"]["s3_credentials"] mock_client.filesystem.assert_called_with("s3", **expected_client_kwargs) - def test_nested_credentials(self, sane_config_with_nested_creds, mocker): + def test_nested_credentials(self, correct_config_with_nested_creds, mocker): mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") - config = deepcopy(sane_config_with_nested_creds) + config = deepcopy(correct_config_with_nested_creds) del config["catalog"]["boats"] DataCatalog.from_config(**config) @@ -621,13 +551,13 @@ def test_nested_credentials(self, sane_config_with_nested_creds, mocker): } mock_client.filesystem.assert_called_once_with("s3", **expected_client_kwargs) - def test_missing_nested_credentials(self, sane_config_with_nested_creds): - del sane_config_with_nested_creds["credentials"]["other_credentials"] + def test_missing_nested_credentials(self, correct_config_with_nested_creds): + del correct_config_with_nested_creds["credentials"]["other_credentials"] pattern = "Unable to find credentials 'other_credentials'" with pytest.raises(KeyError, match=pattern): - DataCatalog.from_config(**sane_config_with_nested_creds) + DataCatalog.from_config(**correct_config_with_nested_creds) - def test_missing_dependency(self, sane_config, mocker): + def test_missing_dependency(self, correct_config, mocker): """Test that dependency is missing.""" pattern = "dependency issue" @@ -639,12 +569,12 @@ def dummy_load(obj_path, *args, **kwargs): mocker.patch("kedro.io.core.load_obj", side_effect=dummy_load) with pytest.raises(DatasetError, match=pattern): - DataCatalog.from_config(**sane_config) + DataCatalog.from_config(**correct_config) - def test_idempotent_catalog(self, sane_config): + def test_idempotent_catalog(self, correct_config): """Test that data catalog instantiations are idempotent""" - _ = DataCatalog.from_config(**sane_config) - catalog = DataCatalog.from_config(**sane_config) + _ = DataCatalog.from_config(**correct_config) + catalog = DataCatalog.from_config(**correct_config) assert catalog def test_error_dataset_init(self, bad_config): @@ -684,18 +614,18 @@ def test_confirm(self, tmp_path, caplog, mocker): ("boats", "Dataset 'boats' does not have 'confirm' method"), ], ) - def test_bad_confirm(self, sane_config, dataset_name, pattern): + def test_bad_confirm(self, correct_config, dataset_name, pattern): """Test confirming non existent 
dataset or the one that does not have `confirm` method""" - data_catalog = DataCatalog.from_config(**sane_config) + data_catalog = DataCatalog.from_config(**correct_config) with pytest.raises(DatasetError, match=re.escape(pattern)): data_catalog.confirm(dataset_name) class TestDataCatalogVersioned: - def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): + def test_from_correct_config_versioned(self, correct_config, dummy_dataframe): """Test load and save of versioned data sets from config""" - sane_config["catalog"]["boats"]["versioned"] = True + correct_config["catalog"]["boats"]["versioned"] = True # Decompose `generate_timestamp` to keep `current_ts` reference. current_ts = datetime.now(tz=timezone.utc) @@ -706,13 +636,13 @@ def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): version = fmt.format(d=current_ts, ms=current_ts.microsecond // 1000) catalog = DataCatalog.from_config( - **sane_config, + **correct_config, load_versions={"boats": version}, save_version=version, ) catalog.save("boats", dummy_dataframe) - path = Path(sane_config["catalog"]["boats"]["filepath"]) + path = Path(correct_config["catalog"]["boats"]["filepath"]) path = path / version / path.name assert path.is_file() @@ -733,12 +663,14 @@ def test_from_sane_config_versioned(self, sane_config, dummy_dataframe): assert actual_timestamp == expected_timestamp @pytest.mark.parametrize("versioned", [True, False]) - def test_from_sane_config_versioned_warn(self, caplog, sane_config, versioned): + def test_from_correct_config_versioned_warn( + self, caplog, correct_config, versioned + ): """Check the warning if `version` attribute was added to the data set config""" - sane_config["catalog"]["boats"]["versioned"] = versioned - sane_config["catalog"]["boats"]["version"] = True - DataCatalog.from_config(**sane_config) + correct_config["catalog"]["boats"]["versioned"] = versioned + correct_config["catalog"]["boats"]["version"] = True + DataCatalog.from_config(**correct_config) log_record = caplog.records[0] expected_log_message = ( "'version' attribute removed from data set configuration since it " @@ -747,21 +679,21 @@ def test_from_sane_config_versioned_warn(self, caplog, sane_config, versioned): assert log_record.levelname == "WARNING" assert expected_log_message in log_record.message - def test_from_sane_config_load_versions_warn(self, sane_config): - sane_config["catalog"]["boats"]["versioned"] = True + def test_from_correct_config_load_versions_warn(self, correct_config): + correct_config["catalog"]["boats"]["versioned"] = True version = generate_timestamp() - load_version = {"non-boart": version} - pattern = r"\'load_versions\' keys \[non-boart\] are not found in the catalog\." + load_version = {"non-boat": version} + pattern = r"\'load_versions\' keys \[non-boat\] are not found in the catalog\." 
with pytest.raises(DatasetNotFoundError, match=pattern): - DataCatalog.from_config(**sane_config, load_versions=load_version) + DataCatalog.from_config(**correct_config, load_versions=load_version) def test_compare_tracking_and_other_dataset_versioned( - self, sane_config_with_tracking_ds, dummy_dataframe + self, correct_config_with_tracking_ds, dummy_dataframe ): """Test saving of tracking data sets from config results in the same save version as other versioned datasets.""" - catalog = DataCatalog.from_config(**sane_config_with_tracking_ds) + catalog = DataCatalog.from_config(**correct_config_with_tracking_ds) catalog.save("boats", dummy_dataframe) dummy_data = {"col1": 1, "col2": 2, "col3": 3} @@ -779,20 +711,20 @@ def test_compare_tracking_and_other_dataset_versioned( assert tracking_timestamp == csv_timestamp - def test_load_version(self, sane_config, dummy_dataframe, mocker): + def test_load_version(self, correct_config, dummy_dataframe, mocker): """Test load versioned data sets from config""" new_dataframe = pd.DataFrame({"col1": [0, 0], "col2": [0, 0], "col3": [0, 0]}) - sane_config["catalog"]["boats"]["versioned"] = True + correct_config["catalog"]["boats"]["versioned"] = True mocker.patch( "kedro.io.data_catalog.generate_timestamp", side_effect=["first", "second"] ) # save first version of the dataset - catalog = DataCatalog.from_config(**sane_config) + catalog = DataCatalog.from_config(**correct_config) catalog.save("boats", dummy_dataframe) # save second version of the dataset - catalog = DataCatalog.from_config(**sane_config) + catalog = DataCatalog.from_config(**correct_config) catalog.save("boats", new_dataframe) assert_frame_equal(catalog.load("boats", version="first"), dummy_dataframe) @@ -800,11 +732,11 @@ def test_load_version(self, sane_config, dummy_dataframe, mocker): assert_frame_equal(catalog.load("boats"), new_dataframe) def test_load_version_on_unversioned_dataset( - self, sane_config, dummy_dataframe, mocker + self, correct_config, dummy_dataframe, mocker ): mocker.patch("kedro.io.data_catalog.generate_timestamp", return_value="first") - catalog = DataCatalog.from_config(**sane_config) + catalog = DataCatalog.from_config(**correct_config) catalog.save("boats", dummy_dataframe) with pytest.raises(DatasetError): @@ -846,7 +778,7 @@ def test_match_added_to_datasets_on_get(self, config_with_dataset_factories): catalog = DataCatalog.from_config(**config_with_dataset_factories) assert "{brand}_cars" not in catalog._datasets assert "tesla_cars" not in catalog._datasets - assert "{brand}_cars" in catalog._dataset_patterns + assert "{brand}_cars" in catalog.config_resolver._dataset_patterns tesla_cars = catalog._get_dataset("tesla_cars") assert isinstance(tesla_cars, CSVDataset) @@ -875,8 +807,8 @@ def test_patterns_not_in_catalog_datasets(self, config_with_dataset_factories): catalog = DataCatalog.from_config(**config_with_dataset_factories) assert "audi_cars" in catalog._datasets assert "{brand}_cars" not in catalog._datasets - assert "audi_cars" not in catalog._dataset_patterns - assert "{brand}_cars" in catalog._dataset_patterns + assert "audi_cars" not in catalog.config_resolver._dataset_patterns + assert "{brand}_cars" in catalog.config_resolver._dataset_patterns def test_explicit_entry_not_overwritten(self, config_with_dataset_factories): """Check that the existing catalog entry is not overwritten by config in pattern""" @@ -909,11 +841,7 @@ def test_sorting_order_patterns(self, config_with_dataset_factories_only_pattern "{dataset}s", "{user_default}", ] - assert ( 
- list(catalog._dataset_patterns.keys()) - + list(catalog._default_pattern.keys()) - == sorted_keys_expected - ) + assert catalog.config_resolver.list_patterns() == sorted_keys_expected def test_multiple_catch_all_patterns_not_allowed( self, config_with_dataset_factories @@ -953,13 +881,13 @@ def test_sorting_order_with_other_dataset_through_extra_pattern( ) sorted_keys_expected = [ "{country}_companies", - "{another}#csv", "{namespace}_{dataset}", "{dataset}s", + "{another}#csv", "{default}", ] assert ( - list(catalog_with_default._dataset_patterns.keys()) == sorted_keys_expected + catalog_with_default.config_resolver.list_patterns() == sorted_keys_expected ) def test_user_default_overwrites_runner_default(self): @@ -988,11 +916,15 @@ def test_user_default_overwrites_runner_default(self): sorted_keys_expected = [ "{dataset}s", "{a_default}", + "{another}#csv", + "{default}", ] - assert "{a_default}" in catalog_with_runner_default._default_pattern assert ( - list(catalog_with_runner_default._dataset_patterns.keys()) - + list(catalog_with_runner_default._default_pattern.keys()) + "{a_default}" + in catalog_with_runner_default.config_resolver._default_pattern + ) + assert ( + catalog_with_runner_default.config_resolver.list_patterns() == sorted_keys_expected ) diff --git a/tests/io/test_kedro_data_catalog.py b/tests/io/test_kedro_data_catalog.py new file mode 100644 index 0000000000..b98e8fae83 --- /dev/null +++ b/tests/io/test_kedro_data_catalog.py @@ -0,0 +1,650 @@ +import logging +import re +import sys +from copy import deepcopy +from datetime import datetime, timezone +from pathlib import Path + +import pandas as pd +import pytest +from kedro_datasets.pandas import CSVDataset, ParquetDataset +from pandas.testing import assert_frame_equal + +from kedro.io import ( + DatasetAlreadyExistsError, + DatasetError, + DatasetNotFoundError, + KedroDataCatalog, + LambdaDataset, + MemoryDataset, +) +from kedro.io.core import ( + _DEFAULT_PACKAGES, + VERSION_FORMAT, + generate_timestamp, + parse_dataset_definition, +) + + +@pytest.fixture +def data_catalog(dataset): + return KedroDataCatalog(datasets={"test": dataset}) + + +@pytest.fixture +def memory_catalog(): + ds1 = MemoryDataset({"data": 42}) + ds2 = MemoryDataset([1, 2, 3, 4, 5]) + return KedroDataCatalog({"ds1": ds1, "ds2": ds2}) + + +@pytest.fixture +def conflicting_feed_dict(): + return {"ds1": 0, "ds3": 1} + + +@pytest.fixture +def multi_catalog(): + csv = CSVDataset(filepath="abc.csv") + parq = ParquetDataset(filepath="xyz.parq") + return KedroDataCatalog({"abc": csv, "xyz": parq}) + + +@pytest.fixture +def data_catalog_from_config(correct_config): + return KedroDataCatalog.from_config(**correct_config) + + +class TestKedroDataCatalog: + def test_save_and_load(self, data_catalog, dummy_dataframe): + """Test saving and reloading the dataset""" + data_catalog.save("test", dummy_dataframe) + reloaded_df = data_catalog.load("test") + + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_add_save_and_load(self, dataset, dummy_dataframe): + """Test adding and then saving and reloading the dataset""" + catalog = KedroDataCatalog(datasets={}) + catalog.add("test", dataset) + catalog.save("test", dummy_dataframe) + reloaded_df = catalog.load("test") + + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_load_error(self, data_catalog): + """Check the error when attempting to load a dataset + from nonexistent source""" + pattern = r"Failed while loading data from data set CSVDataset" + with pytest.raises(DatasetError, 
match=pattern): + data_catalog.load("test") + + def test_add_dataset_twice(self, data_catalog, dataset): + """Check the error when attempting to add the dataset twice""" + pattern = r"Dataset 'test' has already been registered" + with pytest.raises(DatasetAlreadyExistsError, match=pattern): + data_catalog.add("test", dataset) + + def test_load_from_unregistered(self): + """Check the error when attempting to load unregistered dataset""" + catalog = KedroDataCatalog(datasets={}) + pattern = r"Dataset 'test' not found in the catalog" + with pytest.raises(DatasetNotFoundError, match=pattern): + catalog.load("test") + + def test_save_to_unregistered(self, dummy_dataframe): + """Check the error when attempting to save to unregistered dataset""" + catalog = KedroDataCatalog(datasets={}) + pattern = r"Dataset 'test' not found in the catalog" + with pytest.raises(DatasetNotFoundError, match=pattern): + catalog.save("test", dummy_dataframe) + + def test_feed_dict(self, memory_catalog, conflicting_feed_dict): + """Test feed dict overriding some of the datasets""" + assert "data" in memory_catalog.load("ds1") + memory_catalog.add_feed_dict(conflicting_feed_dict, replace=True) + assert memory_catalog.load("ds1") == 0 + assert isinstance(memory_catalog.load("ds2"), list) + assert memory_catalog.load("ds3") == 1 + + def test_exists(self, data_catalog, dummy_dataframe): + """Test `exists` method invocation""" + assert not data_catalog.exists("test") + data_catalog.save("test", dummy_dataframe) + assert data_catalog.exists("test") + + def test_exists_not_implemented(self, caplog): + """Test calling `exists` on the dataset, which didn't implement it""" + catalog = KedroDataCatalog(datasets={"test": LambdaDataset(None, None)}) + result = catalog.exists("test") + + log_record = caplog.records[0] + assert log_record.levelname == "WARNING" + assert ( + "'exists()' not implemented for 'LambdaDataset'. " + "Assuming output does not exist." 
in log_record.message + ) + assert result is False + + def test_exists_invalid(self, data_catalog): + """Check the error when calling `exists` on invalid dataset""" + assert not data_catalog.exists("wrong_key") + + def test_release_unregistered(self, data_catalog): + """Check the error when calling `release` on unregistered dataset""" + pattern = r"Dataset \'wrong_key\' not found in the catalog" + with pytest.raises(DatasetNotFoundError, match=pattern) as e: + data_catalog.release("wrong_key") + assert "did you mean" not in str(e.value) + + def test_release_unregistered_typo(self, data_catalog): + """Check the error when calling `release` on mistyped dataset""" + pattern = ( + "Dataset 'text' not found in the catalog" + " - did you mean one of these instead: test" + ) + with pytest.raises(DatasetNotFoundError, match=re.escape(pattern)): + data_catalog.release("text") + + def test_multi_catalog_list(self, multi_catalog): + """Test data catalog which contains multiple datasets""" + entries = multi_catalog.list() + assert "abc" in entries + assert "xyz" in entries + + @pytest.mark.parametrize( + "pattern,expected", + [ + ("^a", ["abc"]), + ("a|x", ["abc", "xyz"]), + ("^(?!(a|x))", []), + ("def", []), + ("", []), + ], + ) + def test_multi_catalog_list_regex(self, multi_catalog, pattern, expected): + """Test that regex patterns filter datasets accordingly""" + assert multi_catalog.list(regex_search=pattern) == expected + + def test_multi_catalog_list_bad_regex(self, multi_catalog): + """Test that bad regex is caught accordingly""" + escaped_regex = r"\(\(" + pattern = f"Invalid regular expression provided: '{escaped_regex}'" + with pytest.raises(SyntaxError, match=pattern): + multi_catalog.list("((") + + def test_eq(self, multi_catalog, data_catalog): + assert multi_catalog == multi_catalog.shallow_copy() + assert multi_catalog != data_catalog + + def test_datasets_on_init(self, data_catalog_from_config): + """Check datasets are loaded correctly on construction""" + assert isinstance(data_catalog_from_config.datasets["boats"], CSVDataset) + assert isinstance(data_catalog_from_config.datasets["cars"], CSVDataset) + + def test_datasets_on_add(self, data_catalog_from_config): + """Check datasets are updated correctly after adding""" + data_catalog_from_config.add("new_dataset", CSVDataset(filepath="some_path")) + assert isinstance(data_catalog_from_config.datasets["new_dataset"], CSVDataset) + assert isinstance(data_catalog_from_config.datasets["boats"], CSVDataset) + + def test_adding_datasets_not_allowed(self, data_catalog_from_config): + """Check error if user tries to update the datasets attribute""" + pattern = r"Operation not allowed. Please use KedroDataCatalog.add\(\) instead." 
+ with pytest.raises(AttributeError, match=pattern): + data_catalog_from_config.datasets = None + + def test_confirm(self, mocker, caplog): + """Confirm the dataset""" + with caplog.at_level(logging.INFO): + mock_ds = mocker.Mock() + data_catalog = KedroDataCatalog(datasets={"mocked": mock_ds}) + data_catalog.confirm("mocked") + mock_ds.confirm.assert_called_once_with() + assert caplog.record_tuples == [ + ( + "kedro.io.kedro_data_catalog", + logging.INFO, + "Confirming dataset 'mocked'", + ) + ] + + @pytest.mark.parametrize( + "dataset_name,error_pattern", + [ + ("missing", "Dataset 'missing' not found in the catalog"), + ("test", "Dataset 'test' does not have 'confirm' method"), + ], + ) + def test_bad_confirm(self, data_catalog, dataset_name, error_pattern): + """Test confirming a non-existent dataset or one that + does not have `confirm` method""" + with pytest.raises(DatasetError, match=re.escape(error_pattern)): + data_catalog.confirm(dataset_name) + + def test_shallow_copy_returns_correct_class_type( + self, + ): + class MyDataCatalog(KedroDataCatalog): + pass + + data_catalog = MyDataCatalog() + copy = data_catalog.shallow_copy() + assert isinstance(copy, MyDataCatalog) + + @pytest.mark.parametrize( + "runtime_patterns,sorted_keys_expected", + [ + ( + { + "{default}": {"type": "MemoryDataset"}, + "{another}#csv": { + "type": "pandas.CSVDataset", + "filepath": "data/{another}.csv", + }, + }, + ["{another}#csv", "{default}"], + ) + ], + ) + def test_shallow_copy_adds_patterns( + self, data_catalog, runtime_patterns, sorted_keys_expected + ): + assert not data_catalog.config_resolver.list_patterns() + data_catalog = data_catalog.shallow_copy(runtime_patterns) + assert data_catalog.config_resolver.list_patterns() == sorted_keys_expected + + def test_init_with_raw_data(self, dummy_dataframe, dataset): + """Test catalog initialisation with raw data""" + catalog = KedroDataCatalog( + datasets={"ds": dataset}, raw_data={"df": dummy_dataframe} + ) + assert "ds" in catalog + assert "df" in catalog + assert isinstance(catalog.datasets["ds"], CSVDataset) + assert isinstance(catalog.datasets["df"], MemoryDataset) + + def test_repr(self, data_catalog): + assert data_catalog.__repr__() == str(data_catalog) + + def test_missing_keys_from_load_versions(self, correct_config): + """Test load versions include keys missing in the catalog""" + pattern = "'load_versions' keys [version] are not found in the catalog." 
+ with pytest.raises(DatasetNotFoundError, match=re.escape(pattern)): + KedroDataCatalog.from_config( + **correct_config, load_versions={"version": "test_version"} + ) + + def test_get_dataset_matching_pattern(self, data_catalog): + """Test get_dataset() when dataset is not in the catalog but pattern matches""" + match_pattern_ds = "match_pattern_ds" + assert match_pattern_ds not in data_catalog + data_catalog.config_resolver.add_runtime_patterns( + {"{default}": {"type": "MemoryDataset"}} + ) + ds = data_catalog.get_dataset(match_pattern_ds) + assert isinstance(ds, MemoryDataset) + + def test_release(self, data_catalog): + """Test release is called without errors""" + data_catalog.release("test") + + class TestKedroDataCatalogFromConfig: + def test_from_correct_config(self, data_catalog_from_config, dummy_dataframe): + """Test populating the data catalog from config""" + data_catalog_from_config.save("boats", dummy_dataframe) + reloaded_df = data_catalog_from_config.load("boats") + assert_frame_equal(reloaded_df, dummy_dataframe) + + def test_config_missing_type(self, correct_config): + """Check the error if type attribute is missing for some dataset(s) + in the config""" + del correct_config["catalog"]["boats"]["type"] + pattern = ( + "An exception occurred when parsing config for dataset 'boats':\n" + "'type' is missing from dataset catalog configuration" + ) + with pytest.raises(DatasetError, match=re.escape(pattern)): + KedroDataCatalog.from_config(**correct_config) + + def test_config_invalid_module(self, correct_config): + """Check the error if the type points to nonexistent module""" + correct_config["catalog"]["boats"]["type"] = ( + "kedro.invalid_module_name.io.CSVDataset" + ) + + error_msg = "Class 'kedro.invalid_module_name.io.CSVDataset' not found" + with pytest.raises(DatasetError, match=re.escape(error_msg)): + KedroDataCatalog.from_config(**correct_config) + + def test_config_relative_import(self, correct_config): + """Check the error if the type points to a relative import""" + correct_config["catalog"]["boats"]["type"] = ".CSVDatasetInvalid" + + pattern = "'type' class path does not support relative paths" + with pytest.raises(DatasetError, match=re.escape(pattern)): + KedroDataCatalog.from_config(**correct_config) + + def test_config_import_kedro_datasets(self, correct_config, mocker): + """Test kedro_datasets default path to the dataset class""" + # Spy _load_obj because kedro_datasets is not installed and we can't import it. + + import kedro.io.core + + spy = mocker.spy(kedro.io.core, "_load_obj") + parse_dataset_definition(correct_config["catalog"]["boats"]) + for prefix, call_args in zip(_DEFAULT_PACKAGES, spy.call_args_list): + assert call_args.args[0] == f"{prefix}pandas.CSVDataset" + + def test_config_missing_class(self, correct_config): + """Check the error if the type points to nonexistent class""" + correct_config["catalog"]["boats"]["type"] = "kedro.io.CSVDatasetInvalid" + + pattern = ( + "An exception occurred when parsing config for dataset 'boats':\n" + "Class 'kedro.io.CSVDatasetInvalid' not found, is this a typo?" 
+ ) + with pytest.raises(DatasetError, match=re.escape(pattern)): + KedroDataCatalog.from_config(**correct_config) + + @pytest.mark.skipif( + sys.version_info < (3, 9), + reason="for python 3.8 kedro-datasets version 1.8 is used which has the old spelling", + ) + def test_config_incorrect_spelling(self, correct_config): + """Check hint if the type uses the old DataSet spelling""" + correct_config["catalog"]["boats"]["type"] = "pandas.CSVDataSet" + + pattern = ( + "An exception occurred when parsing config for dataset 'boats':\n" + "Class 'pandas.CSVDataSet' not found, is this a typo?" + "\nHint: If you are trying to use a dataset from `kedro-datasets`>=2.0.0," + " make sure that the dataset name uses the `Dataset` spelling instead of `DataSet`." + ) + with pytest.raises(DatasetError, match=re.escape(pattern)): + KedroDataCatalog.from_config(**correct_config) + + def test_config_invalid_dataset(self, correct_config): + """Check the error if the type points to invalid class""" + correct_config["catalog"]["boats"]["type"] = "KedroDataCatalog" + pattern = ( + "An exception occurred when parsing config for dataset 'boats':\n" + "Dataset type 'kedro.io.kedro_data_catalog.KedroDataCatalog' is invalid: " + "all data set types must extend 'AbstractDataset'" + ) + with pytest.raises(DatasetError, match=re.escape(pattern)): + KedroDataCatalog.from_config(**correct_config) + + def test_config_invalid_arguments(self, correct_config): + """Check the error if the dataset config contains invalid arguments""" + correct_config["catalog"]["boats"]["save_and_load_args"] = False + pattern = ( + r"Dataset 'boats' must only contain arguments valid for " + r"the constructor of '.*CSVDataset'" + ) + with pytest.raises(DatasetError, match=pattern): + KedroDataCatalog.from_config(**correct_config) + + def test_config_invalid_dataset_config(self, correct_config): + correct_config["catalog"]["invalid_entry"] = "some string" + pattern = ( + "Catalog entry 'invalid_entry' is not a valid dataset configuration. " + "\nHint: If this catalog entry is intended for variable interpolation, " + "make sure that the key is preceded by an underscore." 
+ ) + with pytest.raises(DatasetError, match=pattern): + KedroDataCatalog.from_config(**correct_config) + + def test_empty_config(self): + """Test empty config""" + assert KedroDataCatalog.from_config(None) + + def test_missing_credentials(self, correct_config): + """Check the error if credentials can't be located""" + correct_config["catalog"]["cars"]["credentials"] = "missing" + with pytest.raises( + KeyError, match=r"Unable to find credentials \'missing\'" + ): + KedroDataCatalog.from_config(**correct_config) + + def test_link_credentials(self, correct_config, mocker): + """Test credentials being linked to the relevant dataset""" + mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") + config = deepcopy(correct_config) + del config["catalog"]["boats"] + + KedroDataCatalog.from_config(**config) + + expected_client_kwargs = correct_config["credentials"]["s3_credentials"] + mock_client.filesystem.assert_called_with("s3", **expected_client_kwargs) + + def test_nested_credentials(self, correct_config_with_nested_creds, mocker): + mock_client = mocker.patch("kedro_datasets.pandas.csv_dataset.fsspec") + config = deepcopy(correct_config_with_nested_creds) + del config["catalog"]["boats"] + KedroDataCatalog.from_config(**config) + + expected_client_kwargs = { + "client_kwargs": { + "credentials": { + "client_kwargs": { + "aws_access_key_id": "OTHER_FAKE_ACCESS_KEY", + "aws_secret_access_key": "OTHER_FAKE_SECRET_KEY", + } + } + }, + "key": "secret", + } + mock_client.filesystem.assert_called_once_with( + "s3", **expected_client_kwargs + ) + + def test_missing_nested_credentials(self, correct_config_with_nested_creds): + del correct_config_with_nested_creds["credentials"]["other_credentials"] + pattern = "Unable to find credentials 'other_credentials'" + with pytest.raises(KeyError, match=pattern): + KedroDataCatalog.from_config(**correct_config_with_nested_creds) + + def test_missing_dependency(self, correct_config, mocker): + """Test that dependency is missing.""" + pattern = "dependency issue" + + def dummy_load(obj_path, *args, **kwargs): + if obj_path == "kedro_datasets.pandas.CSVDataset": + raise AttributeError(pattern) + if obj_path == "kedro_datasets.pandas.__all__": + return ["CSVDataset"] + + mocker.patch("kedro.io.core.load_obj", side_effect=dummy_load) + with pytest.raises(DatasetError, match=pattern): + KedroDataCatalog.from_config(**correct_config) + + def test_idempotent_catalog(self, correct_config): + """Test that data catalog instantiations are idempotent""" + _ = KedroDataCatalog.from_config(**correct_config) + catalog = KedroDataCatalog.from_config(**correct_config) + assert catalog + + def test_error_dataset_init(self, bad_config): + """Check the error when trying to instantiate erroneous dataset""" + pattern = r"Failed to instantiate dataset \'bad\' of type '.*BadDataset'" + with pytest.raises(DatasetError, match=pattern): + KedroDataCatalog.from_config(bad_config, None) + + def test_validate_dataset_config(self): + """Test _validate_dataset_config raises error when wrong dataset config type is passed""" + pattern = ( + "Catalog entry 'bad' is not a valid dataset configuration. \n" + "Hint: If this catalog entry is intended for variable interpolation, make sure that the key is preceded by an underscore." 
+ ) + with pytest.raises(DatasetError, match=pattern): + KedroDataCatalog._validate_dataset_config( + ds_name="bad", ds_config="not_dict" + ) + + def test_confirm(self, tmp_path, caplog, mocker): + """Confirm the dataset""" + with caplog.at_level(logging.INFO): + mock_confirm = mocker.patch( + "kedro_datasets.partitions.incremental_dataset.IncrementalDataset.confirm" + ) + catalog = { + "ds_to_confirm": { + "type": "kedro_datasets.partitions.incremental_dataset.IncrementalDataset", + "dataset": "pandas.CSVDataset", + "path": str(tmp_path), + } + } + data_catalog = KedroDataCatalog.from_config(catalog=catalog) + data_catalog.confirm("ds_to_confirm") + assert caplog.record_tuples == [ + ( + "kedro.io.kedro_data_catalog", + logging.INFO, + "Confirming dataset 'ds_to_confirm'", + ) + ] + mock_confirm.assert_called_once_with() + + @pytest.mark.parametrize( + "dataset_name,pattern", + [ + ("missing", "Dataset 'missing' not found in the catalog"), + ("boats", "Dataset 'boats' does not have 'confirm' method"), + ], + ) + def test_bad_confirm(self, correct_config, dataset_name, pattern): + """Test confirming non existent dataset or the one that + does not have `confirm` method""" + data_catalog = KedroDataCatalog.from_config(**correct_config) + with pytest.raises(DatasetError, match=re.escape(pattern)): + data_catalog.confirm(dataset_name) + + class TestDataCatalogVersioned: + def test_from_correct_config_versioned(self, correct_config, dummy_dataframe): + """Test load and save of versioned datasets from config""" + correct_config["catalog"]["boats"]["versioned"] = True + + # Decompose `generate_timestamp` to keep `current_ts` reference. + current_ts = datetime.now(tz=timezone.utc) + fmt = ( + "{d.year:04d}-{d.month:02d}-{d.day:02d}T{d.hour:02d}" + ".{d.minute:02d}.{d.second:02d}.{ms:03d}Z" + ) + version = fmt.format(d=current_ts, ms=current_ts.microsecond // 1000) + + catalog = KedroDataCatalog.from_config( + **correct_config, + load_versions={"boats": version}, + save_version=version, + ) + + catalog.save("boats", dummy_dataframe) + path = Path(correct_config["catalog"]["boats"]["filepath"]) + path = path / version / path.name + assert path.is_file() + + reloaded_df = catalog.load("boats") + assert_frame_equal(reloaded_df, dummy_dataframe) + + reloaded_df_version = catalog.load("boats", version=version) + assert_frame_equal(reloaded_df_version, dummy_dataframe) + + # Verify that `VERSION_FORMAT` can help regenerate `current_ts`. 
+ actual_timestamp = datetime.strptime( + catalog.datasets["boats"].resolve_load_version(), + VERSION_FORMAT, + ) + expected_timestamp = current_ts.replace( + microsecond=current_ts.microsecond // 1000 * 1000, tzinfo=None + ) + assert actual_timestamp == expected_timestamp + + @pytest.mark.parametrize("versioned", [True, False]) + def test_from_correct_config_versioned_warn( + self, caplog, correct_config, versioned + ): + """Check the warning if `version` attribute was added + to the dataset config""" + correct_config["catalog"]["boats"]["versioned"] = versioned + correct_config["catalog"]["boats"]["version"] = True + KedroDataCatalog.from_config(**correct_config) + log_record = caplog.records[0] + expected_log_message = ( + "'version' attribute removed from data set configuration since it " + "is a reserved word and cannot be directly specified" + ) + assert log_record.levelname == "WARNING" + assert expected_log_message in log_record.message + + def test_from_correct_config_load_versions_warn(self, correct_config): + correct_config["catalog"]["boats"]["versioned"] = True + version = generate_timestamp() + load_version = {"non-boat": version} + pattern = ( + r"\'load_versions\' keys \[non-boat\] are not found in the catalog\." + ) + with pytest.raises(DatasetNotFoundError, match=pattern): + KedroDataCatalog.from_config( + **correct_config, load_versions=load_version + ) + + def test_compare_tracking_and_other_dataset_versioned( + self, correct_config_with_tracking_ds, dummy_dataframe + ): + """Test saving of tracking datasets from config results in the same + save version as other versioned datasets.""" + + catalog = KedroDataCatalog.from_config(**correct_config_with_tracking_ds) + + catalog.save("boats", dummy_dataframe) + dummy_data = {"col1": 1, "col2": 2, "col3": 3} + catalog.save("planes", dummy_data) + + # Verify that saved version on tracking dataset is the same as on the CSV dataset + csv_timestamp = datetime.strptime( + catalog.datasets["boats"].resolve_save_version(), + VERSION_FORMAT, + ) + tracking_timestamp = datetime.strptime( + catalog.datasets["planes"].resolve_save_version(), + VERSION_FORMAT, + ) + + assert tracking_timestamp == csv_timestamp + + def test_load_version(self, correct_config, dummy_dataframe, mocker): + """Test load versioned datasets from config""" + new_dataframe = pd.DataFrame( + {"col1": [0, 0], "col2": [0, 0], "col3": [0, 0]} + ) + correct_config["catalog"]["boats"]["versioned"] = True + mocker.patch( + "kedro.io.kedro_data_catalog.generate_timestamp", + side_effect=["first", "second"], + ) + + # save first version of the dataset + catalog = KedroDataCatalog.from_config(**correct_config) + catalog.save("boats", dummy_dataframe) + + # save second version of the dataset + catalog = KedroDataCatalog.from_config(**correct_config) + catalog.save("boats", new_dataframe) + + assert_frame_equal(catalog.load("boats", version="first"), dummy_dataframe) + assert_frame_equal(catalog.load("boats", version="second"), new_dataframe) + assert_frame_equal(catalog.load("boats"), new_dataframe) + + def test_load_version_on_unversioned_dataset( + self, correct_config, dummy_dataframe, mocker + ): + mocker.patch( + "kedro.io.kedro_data_catalog.generate_timestamp", return_value="first" + ) + + catalog = KedroDataCatalog.from_config(**correct_config) + catalog.save("boats", dummy_dataframe) + + with pytest.raises(DatasetError): + catalog.load("boats", version="first") diff --git a/tests/runner/test_sequential_runner.py b/tests/runner/test_sequential_runner.py index 
dbc73a30f0..4f22bab296 100644 --- a/tests/runner/test_sequential_runner.py +++ b/tests/runner/test_sequential_runner.py @@ -130,7 +130,9 @@ def test_conflict_feed_catalog( def test_unsatisfied_inputs(self, is_async, unfinished_outputs_pipeline, catalog): """ds1, ds2 and ds3 were not specified.""" - with pytest.raises(ValueError, match=r"not found in the DataCatalog"): + with pytest.raises( + ValueError, match=rf"not found in the {catalog.__class__.__name__}" + ): SequentialRunner(is_async=is_async).run( unfinished_outputs_pipeline, catalog )
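
For reviewers who want to try the new API by hand, below is a minimal usage sketch of the `KedroDataCatalog` behaviour exercised by the tests added in this diff. It only uses calls that appear in the tests above (`add`, `save`/`load`, the `datasets` property, `list(regex_search=...)`, `raw_data=`, and `config_resolver`); the dataset names, file paths, and the sample DataFrame are illustrative and not taken from the PR.

from kedro.io import KedroDataCatalog
from kedro_datasets.pandas import CSVDataset
import pandas as pd

df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]})

# Construct the catalog from dataset instances plus raw in-memory data
# (raw_data entries are wrapped in MemoryDataset, as test_init_with_raw_data checks).
catalog = KedroDataCatalog(
    datasets={"boats": CSVDataset(filepath="boats.csv")},
    raw_data={"raw_df": df},
)

# Register a dataset by name, then save and reload it.
catalog.add("cars", CSVDataset(filepath="cars.csv"))
catalog.save("boats", df)
reloaded = catalog.load("boats")

# Datasets are exposed as a property, and list() accepts a regex filter.
assert isinstance(catalog.datasets["boats"], CSVDataset)
print(catalog.list(regex_search="^b"))  # only 'boats' matches

# Dataset factory patterns are handled by the catalog's CatalogConfigResolver.
catalog.config_resolver.add_runtime_patterns({"{default}": {"type": "MemoryDataset"}})
print(catalog.config_resolver.list_patterns())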