diff --git a/.circleci/config.yml b/.circleci/config.yml index b2595155ec..90de36ed1f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,12 +7,15 @@ executors: py36: docker: - image: 350138855857.dkr.ecr.eu-west-2.amazonaws.com/kedro-builder:3.6 + resource_class: medium+ py37: docker: - image: 350138855857.dkr.ecr.eu-west-2.amazonaws.com/kedro-builder:3.7 + resource_class: medium+ py38: docker: - image: 350138855857.dkr.ecr.eu-west-2.amazonaws.com/kedro-builder:3.8 + resource_class: medium+ commands: setup_conda: @@ -38,6 +41,14 @@ commands: - run: name: Install requirements and test requirements command: pip install --upgrade -r test_requirements.txt + - run: + # this is needed to fix java cacerts so + # spark can automatically download packages from mvn + # https://stackoverflow.com/a/50103533/1684058 + name: Fix cacerts + command: | + sudo rm /etc/ssl/certs/java/cacerts + sudo update-ca-certificates -f - run: # Since recently Spark installation for some reason does not have enough permissions to execute # /home/circleci/miniconda/envs/kedro_builder/lib/python3.X/site-packages/pyspark/bin/spark-class. diff --git a/RELEASE.md b/RELEASE.md index 895c5f1beb..ef832b034b 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -5,13 +5,14 @@ * Enabled overriding nested parameters with `params` in CLI, i.e. `kedro run --params="model.model_tuning.booster:gbtree"` updates parameters to `{"model": {"model_tuning": {"booster": "gbtree"}}}`. * Added option to `pandas.SQLQueryDataSet` to specify a `filepath` with a SQL query, in addition to the current method of supplying the query itself in the `sql` argument. * Extended `ExcelDataSet` to support saving Excel files with multiple sheets. -* Added the following new dataset (see ([Issue #839](https://github.com/quantumblacklabs/kedro/issues/839)): +* Added the following new datasets (see [Issue #839](https://github.com/quantumblacklabs/kedro/issues/839)): | Type | Description | Location | | --------------------------- | ---------------------------------------------------- | --------------------------------- | | `plotly.JSONDataSet` | Works with plotly graph object Figures (saves as json file) | `kedro.extras.datasets.plotly` | | `pandas.GenericDataSet` | Provides a 'best effort' facility to read / write any format provided by the `pandas` library | `kedro.extras.datasets.pandas` | | `pandas.GBQQueryDataSet` | Loads data from a Google Bigquery table using provided SQL query | `kedro.extras.datasets.pandas` | +| `spark.DeltaTableDataSet` | Dataset designed to handle Delta Lake Tables and their CRUD-style operations, including `update`, `merge` and `delete` | `kedro.extras.datasets.spark` | ## Bug fixes and other changes * Fixed an issue where `kedro new --config config.yml` was ignoring the config file when `prompts.yml` didn't exist.
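The `spark.DeltaTableDataSet` listed above loads a `DeltaTable` object rather than a DataFrame, so the CRUD-style operations happen inside the node. A minimal sketch of such a node, using the standard `delta-spark` API but with hypothetical column and function names (`update_weather`, `station_id`, etc.) that are not part of this PR, might look like this:

```python
# Hypothetical node body sketching the CRUD-style operations available on the
# DeltaTable object loaded by spark.DeltaTableDataSet (names are illustrative only).
from delta.tables import DeltaTable
from pyspark.sql import DataFrame


def update_weather(weather: DeltaTable, new_readings: DataFrame) -> None:
    # update: clamp negative humidity readings in place
    weather.update(condition="humidity < 0", set={"humidity": "0"})

    # delete: drop rows with no temperature recorded
    weather.delete(condition="temperature IS NULL")

    # merge: upsert the latest readings by station id
    (
        weather.alias("current")
        .merge(new_readings.alias("new"), "current.station_id = new.station_id")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )
```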
diff --git a/docs/conf.py b/docs/conf.py index c07aac182b..c54f59b925 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -192,19 +192,19 @@ # too many requests, or forbidden URL linkcheck_ignore = [ "https://datacamp.com/community/tutorials/docstrings-python", # "forbidden" url - "https://setuptools.readthedocs.io/en/latest/setuptools.html#dynamic-discovery-of-services-and-plugins", "https://github.com/argoproj/argo/blob/master/README.md#quickstart", "https://console.aws.amazon.com/batch/home#/jobs", "https://github.com/EbookFoundation/free-programming-books/blob/master/books/free-programming-books-langs.md#python", "https://github.com/jazzband/pip-tools#example-usage-for-pip-compile", "https://www.astronomer.io/docs/cloud/stable/get-started/quickstart#", - "https://github.com/quantumblacklabs/private-kedro/blob/main/kedro/templates/project/*", "https://eternallybored.org/misc/wget/", "https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.from_pandas", "https://www.oracle.com/java/technologies/javase-downloads.html", # "forbidden" url "https://towardsdatascience.com/the-importance-of-layered-thinking-in-data-engineering-a09f685edc71", "https://medium.com/quantumblack/beyond-the-notebook-and-into-the-data-science-framework-revolution-a7fd364ab9c4", "https://www.java.com/en/download/help/download_options.html", # "403 Client Error: Forbidden for url" + # "anchor not found" but it's a valid selector for code examples + "https://docs.delta.io/latest/delta-update.html#language-python", ] # retry before render a link broken (fix for "too many requests") diff --git a/docs/source/07_extend_kedro/04_plugins.md b/docs/source/07_extend_kedro/04_plugins.md index 1ce8fd82e8..10dda1436c 100644 --- a/docs/source/07_extend_kedro/04_plugins.md +++ b/docs/source/07_extend_kedro/04_plugins.md @@ -4,7 +4,7 @@ Kedro plugins allow you to create new features for Kedro and inject additional c ## Overview -Kedro uses [`setuptools`](https://setuptools.readthedocs.io/en/latest/setuptools.html), which is a collection of enhancements to the Python `distutils` to allow developers to build and distribute Python packages. Kedro uses various entry points in [`pkg_resources`](https://setuptools.readthedocs.io/en/latest/setuptools.html#dynamic-discovery-of-services-and-plugins) to provide plugin functionality. +Kedro uses [`setuptools`](https://setuptools.readthedocs.io/en/latest/setuptools.html), which is a collection of enhancements to the Python `distutils` to allow developers to build and distribute Python packages. Kedro uses various entry points in [`pkg_resources`](https://setuptools.readthedocs.io/en/latest/setuptools.html) to provide plugin functionality. ## Example of a simple plugin @@ -148,7 +148,7 @@ When you are ready to submit your code: 2. Choose a command approach: `global` and / or `project` commands: - All `global` commands should be provided as a single `click` group - All `project` commands should be provided as another `click` group - - The `click` groups are declared through the [`pkg_resources` entry_point system](https://setuptools.readthedocs.io/en/latest/setuptools.html#dynamic-discovery-of-services-and-plugins) + - The `click` groups are declared through the [`pkg_resources` entry_point system](https://setuptools.readthedocs.io/en/latest/setuptools.html) 3. Include a `README.md` describing your plugin's functionality and all dependencies that should be included 4. 
Use GitHub tagging to tag your plugin as a `kedro-plugin` so that we can find it diff --git a/docs/source/11_tools_integration/01_pyspark.md b/docs/source/11_tools_integration/01_pyspark.md index e730da0af2..14a75115d8 100644 --- a/docs/source/11_tools_integration/01_pyspark.md +++ b/docs/source/11_tools_integration/01_pyspark.md @@ -77,6 +77,7 @@ CONTEXT_CLASS = CustomContext We recommend using Kedro's built-in Spark datasets to load raw data into Spark's [DataFrame](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.DataFrame.html), as well as to write them back to storage. Some of our built-in Spark datasets include: +* [spark.DeltaTableDataSet](/kedro.extras.datasets.spark.DeltaTableDataSet) * [spark.SparkDataSet](/kedro.extras.datasets.spark.SparkDataSet) * [spark.SparkJDBCDataSet](/kedro.extras.datasets.spark.SparkJDBCDataSet) * [spark.SparkHiveDataSet](/kedro.extras.datasets.spark.SparkHiveDataSet) @@ -115,6 +116,86 @@ df = catalog.load("weather") assert isinstance(df, pyspark.sql.DataFrame) ``` +## Spark and Delta Lake interaction + +[Delta Lake](https://delta.io/) is an open-source project that enables building a Lakehouse architecture on top of data lakes. It provides ACID transactions and unifies streaming and batch data processing on top of existing data lakes, such as S3, ADLS, GCS, and HDFS. +To set up PySpark with Delta Lake, have a look at [the recommendations in Delta Lake's documentation](https://docs.delta.io/latest/quick-start.html#python). + +We recommend the following workflow, which makes use of the [Transcoding](../05_data/01_data_catalog.md) feature in Kedro: + +* To create a Delta table, use a `SparkDataSet` with `file_format="delta"`. You can also use this type of dataset to read from a Delta table and/or overwrite it. +* To perform [Delta table deletes, updates, and merges](https://docs.delta.io/latest/delta-update.html#language-python), load the data using a `DeltaTableDataSet` and perform the write operations within the node function. + +As a result, we end up with a catalog that looks like this: + +```yaml +temperature: + type: spark.SparkDataSet + filepath: data/01_raw/data.csv + file_format: "csv" + load_args: + header: True + inferSchema: True + save_args: + sep: '|' + header: True + +weather@spark: + type: spark.SparkDataSet + filepath: s3a://my_bucket/03_primary/weather + file_format: "delta" + save_args: + mode: "overwrite" + dfwriter_options: + versionAsOf: 0 + +weather@delta: + type: spark.DeltaTableDataSet + filepath: s3a://my_bucket/03_primary/weather +``` + +The `DeltaTableDataSet` does not support the `save()` operation, as the updates happen in place inside the node function, i.e. through `DeltaTable.update()`, `DeltaTable.delete()`, `DeltaTable.merge()`. + + +> Note: If you have defined an implementation for the Kedro `before_dataset_saved`/`after_dataset_saved` hook, the hook will not be triggered. This is because the save operation happens within the `node` itself, via the DeltaTable API. + +```python +Pipeline( + [ + node( + func=process_barometer_data, inputs="temperature", outputs="weather@spark" + ), + node( + func=update_meteorological_state, + inputs="weather@delta", + outputs="first_operation_complete", + ), + node( + func=estimate_weather_trend, + inputs=["first_operation_complete", "weather@delta"], + outputs="second_operation_complete", + ), + ] +) +``` + +`first_operation_complete` is a `MemoryDataSet` and it signals that any Delta operations which occur "outside" the Kedro DAG are complete.
This can be used as input to a downstream node, to preserve the shape of the DAG. Otherwise, if no downstream nodes need to run after this, the node can simply not return anything: +```python +Pipeline( + [ + node(func=..., inputs="temperature", outputs="weather@spark"), + node(func=..., inputs="weather@delta", outputs=None), + ] +) +``` + +The following diagram is the visual representation of the workflow explained above: + +![Spark and Delta Lake workflow](../meta/images/spark_delta_workflow.png) + +> Note: This pattern of creating "dummy" datasets to preserve the data flow also applies to other "out of DAG" execution operations such as SQL operations within a node. + ## Use `MemoryDataSet` for intermediary `DataFrame` For nodes operating on `DataFrame` that doesn't need to perform Spark actions such as writing the `DataFrame` to storage, we recommend using the default `MemoryDataSet` to hold the `DataFrame`. In other words, there is no need to specify it in the `DataCatalog` or `catalog.yml`. This allows you to take advantage of Spark's optimiser and lazy evaluation. diff --git a/docs/source/15_api_docs/kedro.extras.datasets.rst b/docs/source/15_api_docs/kedro.extras.datasets.rst index 6150ca44eb..c49217f409 100644 --- a/docs/source/15_api_docs/kedro.extras.datasets.rst +++ b/docs/source/15_api_docs/kedro.extras.datasets.rst @@ -36,6 +36,7 @@ kedro.extras.datasets kedro.extras.datasets.pillow.ImageDataSet kedro.extras.datasets.plotly.JSONDataSet kedro.extras.datasets.plotly.PlotlyDataSet + kedro.extras.datasets.spark.DeltaTableDataSet kedro.extras.datasets.spark.SparkDataSet kedro.extras.datasets.spark.SparkHiveDataSet kedro.extras.datasets.spark.SparkJDBCDataSet diff --git a/docs/source/meta/images/spark_delta_workflow.png b/docs/source/meta/images/spark_delta_workflow.png new file mode 100644 index 0000000000..66d71a7fc6 Binary files /dev/null and b/docs/source/meta/images/spark_delta_workflow.png differ diff --git a/kedro/extras/datasets/spark/__init__.py b/kedro/extras/datasets/spark/__init__.py index acbd3b8ab6..3dede09aa8 100644 --- a/kedro/extras/datasets/spark/__init__.py +++ b/kedro/extras/datasets/spark/__init__.py @@ -1,6 +1,6 @@ """Provides I/O modules for Apache Spark.""" -__all__ = ["SparkDataSet", "SparkHiveDataSet", "SparkJDBCDataSet"] +__all__ = ["SparkDataSet", "SparkHiveDataSet", "SparkJDBCDataSet", "DeltaTableDataSet"] from contextlib import suppress @@ -10,3 +10,5 @@ from .spark_hive_dataset import SparkHiveDataSet with suppress(ImportError): from .spark_jdbc_dataset import SparkJDBCDataSet +with suppress(ImportError): + from .deltatable_dataset import DeltaTableDataSet diff --git a/kedro/extras/datasets/spark/deltatable_dataset.py b/kedro/extras/datasets/spark/deltatable_dataset.py new file mode 100644 index 0000000000..59a6810908 --- /dev/null +++ b/kedro/extras/datasets/spark/deltatable_dataset.py @@ -0,0 +1,110 @@ +"""``AbstractDataSet`` implementation to access DeltaTables using +``delta-spark`` +""" +from copy import deepcopy +from pathlib import PurePosixPath +from typing import Any, Dict + +from delta.tables import DeltaTable +from pyspark.sql import SparkSession +from pyspark.sql.utils import AnalysisException + +from kedro.extras.datasets.spark.spark_dataset import ( + _split_filepath, + _strip_dbfs_prefix, +) +from kedro.io.core import AbstractDataSet, DataSetError + + +class DeltaTableDataSet(AbstractDataSet): + """``DeltaTableDataSet`` loads data into DeltaTable objects. + + Example adding a catalog entry with + `YAML API `_: + + ..
code-block:: yaml + + >>> weather@spark: + >>> type: spark.SparkDataSet + >>> filepath: data/02_intermediate/data.parquet + >>> file_format: "delta" + >>> + >>> weather@delta: + >>> type: spark.DeltaTableDataSet + >>> filepath: data/02_intermediate/data.parquet + + Example using Python API: + :: + + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import (StructField, StringType, + >>> IntegerType, StructType) + >>> + >>> from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet + >>> + >>> schema = StructType([StructField("name", StringType(), True), + >>> StructField("age", IntegerType(), True)]) + >>> + >>> data = [('Alex', 31), ('Bob', 12), ('Clarke', 65), ('Dave', 29)] + >>> + >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + >>> + >>> data_set = SparkDataSet(filepath="test_data", file_format="delta") + >>> data_set.save(spark_df) + >>> deltatable_dataset = DeltaTableDataSet(filepath="test_data") + >>> delta_table = deltatable_dataset.load() + >>> + >>> delta_table.update() + """ + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # using ``ThreadRunner`` instead + _SINGLE_PROCESS = True + + def __init__(self, filepath: str, credentials: Dict[str, Any] = None) -> None: + """Creates a new instance of ``DeltaTableDataSet``. + + Args: + filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks + and working with data written to mount path points, + specify ``filepath``s for (versioned) ``SparkDataSet``s + starting with ``/dbfs/mnt``. + credentials: Credentials to access the S3 bucket, such as + ``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``. + Optional keyword arguments passed to ``hdfs.client.InsecureClient`` + if ``filepath`` prefix is ``hdfs://``. Ignored otherwise. + """ + credentials = deepcopy(credentials) or {} # not currently used elsewhere in this dataset
+ fs_prefix, filepath = _split_filepath(filepath) + + self._fs_prefix = fs_prefix + self._filepath = PurePosixPath(filepath) + + @staticmethod + def _get_spark(): + return SparkSession.builder.getOrCreate() + + def _load(self) -> DeltaTable: + load_path = self._fs_prefix + str(self._filepath) + return DeltaTable.forPath(self._get_spark(), load_path) + + def _save(self, data: Any) -> None: + raise DataSetError(f"{self.__class__.__name__} is a read only dataset type") + + def _exists(self) -> bool: + load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._filepath)) + + try: + self._get_spark().read.load(path=load_path, format="delta") + except AnalysisException as exception: + if "is not a Delta table" in exception.desc: + return False + raise + + return True + + def _describe(self): + return dict(filepath=str(self._filepath), fs_prefix=self._fs_prefix) diff --git a/kedro/extras/datasets/spark/spark_dataset.py b/kedro/extras/datasets/spark/spark_dataset.py index 8d4755cf49..0fec5a7cfa 100644 --- a/kedro/extras/datasets/spark/spark_dataset.py +++ b/kedro/extras/datasets/spark/spark_dataset.py @@ -1,4 +1,4 @@ -"""``AbstractDataSet`` implementation to access Spark dataframes using +"""``AbstractVersionedDataSet`` implementation to access Spark dataframes using ``pyspark`` """ from copy import deepcopy @@ -13,7 +13,7 @@ from pyspark.sql.utils import AnalysisException from s3fs import S3FileSystem -from kedro.io.core import AbstractVersionedDataSet, Version +from kedro.io.core import AbstractVersionedDataSet, DataSetError, Version def _parse_glob_pattern(pattern: str) -> str: @@ -223,7 +223,7 @@ def __init__( # pylint: disable=too-many-arguments starting with ``/dbfs/mnt``. file_format: File format used during load and save operations. These are formats supported by the running - SparkContext include parquet, csv. For a list of supported + SparkContext and include parquet, csv and delta. For a list of supported formats please refer to Apache Spark documentation at https://spark.apache.org/docs/latest/sql-programming-guide.html load_args: Load args passed to Spark DataFrameReader load method. @@ -304,9 +304,13 @@ def __init__( # pylint: disable=too-many-arguments if save_args is not None: self._save_args.update(save_args) + # extract DataFrameWriter options from save_args; these are passed to df.write.options() on save
+ self._dfwriter_options = self._save_args.pop("dfwriter_options", {}) or {} self._file_format = file_format self._fs_prefix = fs_prefix + self._handle_delta_format() + def _describe(self) -> Dict[str, Any]: return dict( filepath=self._fs_prefix + str(self._filepath), @@ -329,7 +333,9 @@ def _load(self) -> DataFrame: def _save(self, data: DataFrame) -> None: save_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path())) - data.write.save(save_path, self._file_format, **self._save_args) + data.write.options(**self._dfwriter_options).save( + save_path, self._file_format, **self._save_args + ) def _exists(self) -> bool: load_path = _strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) @@ -337,7 +343,20 @@ def _exists(self) -> bool: try: self._get_spark().read.load(load_path, self._file_format) except AnalysisException as exception: - if exception.desc.startswith("Path does not exist:"): + if ( + exception.desc.startswith("Path does not exist:") + or "is not a Delta table" in exception.desc + ): return False raise return True + + def _handle_delta_format(self) -> None: + unsupported_modes = {"merge", "delete", "update"} + write_mode = self._save_args.get("mode") or "" + if self._file_format == "delta" and write_mode.lower() in unsupported_modes: + raise DataSetError( + f"It is not possible to perform `save()` for file format `delta` " + f"with mode `{write_mode}` on `SparkDataSet`. " + f"Please use `spark.DeltaTableDataSet` instead." + ) diff --git a/setup.py b/setup.py index 262df43429..7a1e9d5c16 100644 --- a/setup.py +++ b/setup.py @@ -93,6 +93,7 @@ def _collect_requirements(requires): "spark.SparkDataSet": [SPARK, HDFS, S3FS], "spark.SparkHiveDataSet": [SPARK, HDFS, S3FS], "spark.SparkJDBCDataSet": [SPARK, HDFS, S3FS], + "spark.DeltaTableDataSet": [SPARK, HDFS, S3FS, "delta-spark~=1.0"], } tensorflow_required = { "tensorflow.TensorflowModelDataset": [ diff --git a/test_requirements.txt b/test_requirements.txt index 96e215083e..e9d785ad75 100644 --- a/test_requirements.txt +++ b/test_requirements.txt @@ -8,7 +8,9 @@ blacken-docs==1.9.2 compress-pickle[lz4]~=1.2.0 dask>=2021.10.0, <2022.01; python_version > '3.6' # not directly required, pinned by Snyk to avoid a vulnerability dask[complete]~=2.6; python_version == '3.6' +delta-spark~=1.0 dill~=0.3.1 +filelock>=3.4.0, <4.0 gcsfs>=2021.04, <2022.01 # Upper bound set arbitrarily, to be reassessed in early 2022 geopandas>=0.6.0, <1.0 hdfs>=2.5.8, <3.0 diff --git a/tests/extras/datasets/spark/conftest.py b/tests/extras/datasets/spark/conftest.py index bac9d88880..8c30ae50f2 100644 --- a/tests/extras/datasets/spark/conftest.py +++ b/tests/extras/datasets/spark/conftest.py @@ -4,52 +4,38 @@ discover them automatically. 
More info here: https://docs.pytest.org/en/latest/fixture.html """ -import gc -from subprocess import Popen - import pytest +from delta import configure_spark_with_delta_pip +from filelock import FileLock try: - from pyspark import SparkContext from pyspark.sql import SparkSession except ImportError: # pragma: no cover pass # this is only for test discovery to succeed on Python 3.8 -the_real_getOrCreate = None - - -class UseTheSparkSessionFixtureOrMock: # pylint: disable=too-few-public-methods - pass - -# prevent using spark without going through the spark_session fixture -@pytest.fixture(scope="session", autouse=True) -def replace_spark_default_getorcreate(): - global the_real_getOrCreate # pylint: disable=global-statement - the_real_getOrCreate = SparkSession.builder.getOrCreate - SparkSession.builder.getOrCreate = UseTheSparkSessionFixtureOrMock - return the_real_getOrCreate - - -# clean up pyspark after the test module finishes -@pytest.fixture(scope="module") -def spark_session(): # SKIP_IF_NO_SPARK - SparkSession.builder.getOrCreate = the_real_getOrCreate - spark = SparkSession.builder.getOrCreate() +def _setup_spark_session(): + return configure_spark_with_delta_pip( + SparkSession.builder.appName("MyApp") + .master("local[*]") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + ).getOrCreate() + + +@pytest.fixture(scope="module", autouse=True) +def spark_session(tmp_path_factory): # SKIP_IF_NO_SPARK + # When running these spark tests with pytest-xdist, we need to make sure + # that the spark sessions set up on each test process don't interfere with each other. + # Therefore, we block the process during the spark session setup.
+ # Locking procedure comes from pytest-xdist's own recommendation: + # https://github.com/pytest-dev/pytest-xdist#making-session-scoped-fixtures-execute-only-once + root_tmp_dir = tmp_path_factory.getbasetemp().parent + lock = root_tmp_dir / "semaphore.lock" + with FileLock(lock): # pylint: disable=abstract-class-instantiated + spark = _setup_spark_session() yield spark spark.stop() - SparkSession.builder.getOrCreate = UseTheSparkSessionFixtureOrMock - - # remove the cached JVM vars - SparkContext._jvm = None # pylint: disable=protected-access - SparkContext._gateway = None # pylint: disable=protected-access - - # py4j doesn't shutdown properly so kill the actual JVM process - for obj in gc.get_objects(): - try: - if isinstance(obj, Popen) and "pyspark" in obj.args[0]: - obj.terminate() - except ReferenceError: # pragma: no cover - # gc.get_objects may return dead weak proxy objects that will raise - # ReferenceError when you isinstance them - pass diff --git a/tests/extras/datasets/spark/test_deltatable_dataset.py b/tests/extras/datasets/spark/test_deltatable_dataset.py new file mode 100644 index 0000000000..73fcc9c537 --- /dev/null +++ b/tests/extras/datasets/spark/test_deltatable_dataset.py @@ -0,0 +1,89 @@ +import pytest +from delta import DeltaTable +from pyspark.sql import SparkSession +from pyspark.sql.types import IntegerType, StringType, StructField, StructType +from pyspark.sql.utils import AnalysisException + +from kedro.extras.datasets.spark import DeltaTableDataSet, SparkDataSet +from kedro.io import DataCatalog, DataSetError +from kedro.pipeline import Pipeline, node +from kedro.runner import ParallelRunner + + +@pytest.fixture +def sample_spark_df(): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return SparkSession.builder.getOrCreate().createDataFrame(data, schema) + + +class TestDeltaTableDataSet: + def test_load(self, tmp_path, sample_spark_df): + filepath = (tmp_path / "test_data").as_posix() + spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta") + spark_delta_ds.save(sample_spark_df) + loaded_with_spark = spark_delta_ds.load() + assert loaded_with_spark.exceptAll(sample_spark_df).count() == 0 + + delta_ds = DeltaTableDataSet(filepath=filepath) + delta_table = delta_ds.load() + + assert isinstance(delta_table, DeltaTable) + loaded_with_deltalake = delta_table.toDF() + assert loaded_with_deltalake.exceptAll(loaded_with_spark).count() == 0 + + def test_save(self, tmp_path, sample_spark_df): + filepath = (tmp_path / "test_data").as_posix() + delta_ds = DeltaTableDataSet(filepath=filepath) + assert not delta_ds.exists() + + pattern = "DeltaTableDataSet is a read only dataset type" + with pytest.raises(DataSetError, match=pattern): + delta_ds.save(sample_spark_df) + + # check that indeed nothing is written + assert not delta_ds.exists() + + def test_exists(self, tmp_path, sample_spark_df): + filepath = (tmp_path / "test_data").as_posix() + delta_ds = DeltaTableDataSet(filepath=filepath) + + assert not delta_ds.exists() + + spark_delta_ds = SparkDataSet(filepath=filepath, file_format="delta") + spark_delta_ds.save(sample_spark_df) + + assert delta_ds.exists() + + def test_exists_raises_error(self, mocker): + delta_ds = DeltaTableDataSet(filepath="") + mocker.patch.object( + delta_ds, "_get_spark", side_effect=AnalysisException("Other Exception", []) + ) + + with pytest.raises(DataSetError, match="Other 
Exception"): + delta_ds.exists() + + @pytest.mark.parametrize("is_async", [False, True]) + def test_parallel_runner(self, is_async): + """Test ParallelRunner with SparkDataSet fails.""" + + def no_output(x): + _ = x + 1 # pragma: no cover + + delta_ds = DeltaTableDataSet(filepath="") + catalog = DataCatalog(data_sets={"delta_in": delta_ds}) + pipeline = Pipeline([node(no_output, "delta_in", None)]) + pattern = ( + r"The following data sets cannot be used with " + r"multiprocessing: \['delta_in'\]" + ) + with pytest.raises(AttributeError, match=pattern): + ParallelRunner(is_async=is_async).run(pipeline, catalog) diff --git a/tests/extras/datasets/spark/test_spark_dataset.py b/tests/extras/datasets/spark/test_spark_dataset.py index f2e6673547..d86c9e3aea 100644 --- a/tests/extras/datasets/spark/test_spark_dataset.py +++ b/tests/extras/datasets/spark/test_spark_dataset.py @@ -1,10 +1,11 @@ +import re import sys import tempfile from pathlib import Path, PurePosixPath import pandas as pd import pytest -from pyspark.sql import SparkSession +from pyspark.sql import DataFrame, SparkSession from pyspark.sql.functions import col from pyspark.sql.types import IntegerType, StringType, StructField, StructType from pyspark.sql.utils import AnalysisException @@ -50,12 +51,6 @@ ] -@pytest.fixture(autouse=True) -def spark_session_autouse(spark_session): - # all the tests in this file require Spark - return spark_session - - @pytest.fixture def sample_pandas_df() -> pd.DataFrame: return pd.DataFrame( @@ -192,9 +187,7 @@ def test_str_representation(self): with tempfile.NamedTemporaryFile() as temp_data_file: filepath = Path(temp_data_file.name).as_posix() spark_data_set = SparkDataSet( - filepath=filepath, - file_format="csv", - load_args={"header": True}, + filepath=filepath, file_format="csv", load_args={"header": True} ) assert "SparkDataSet" in str(spark_data_set) assert f"filepath={filepath}" in str(spark_data_set) @@ -218,6 +211,49 @@ def test_save_overwrite_mode(self, tmp_path, sample_spark_df): spark_data_set.save(sample_spark_df) spark_data_set.save(sample_spark_df) + @pytest.mark.parametrize( + "save_args,expected_options", + [ + ({"mode": "overwrite"}, {}), + ({"mode": "overwrite", "dfwriter_options": {}}, {}), + ({"mode": "overwrite", "dfwriter_options": None}, {}), + ( + {"mode": "overwrite", "dfwriter_options": {"versionAsOf": 0}}, + {"versionAsOf": 0}, + ), + ], + ) + def test_save_dfwriter_options( + self, tmp_path, save_args, expected_options, sample_spark_df, mocker + ): + filepath = (tmp_path / "test_data").as_posix() + spark_data_set = SparkDataSet( + filepath=filepath, file_format="delta", save_args=save_args + ) + + assert spark_data_set._dfwriter_options == expected_options + + mock_writer = mocker.patch.object(DataFrame, "write") + spark_data_set.save(sample_spark_df) + if expected_options: + mock_writer.options.assert_called_once_with(versionAsOf=0) + else: + mock_writer.options.assert_called_once_with() + + @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) + def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): + filepath = (tmp_path / "test_data").as_posix() + pattern = ( + f"It is not possible to perform `save()` for file format `delta` " + f"with mode `{mode}` on `SparkDataSet`. " + f"Please use `spark.DeltaTableDataSet` instead." 
+ ) + + with pytest.raises(DataSetError, match=re.escape(pattern)): + _ = SparkDataSet( + filepath=filepath, file_format="delta", save_args={"mode": mode} + ) + def test_save_partition(self, tmp_path, sample_spark_df): # To verify partitioning this test will partition the data by one # of the columns and then check whether partitioned column is added @@ -235,7 +271,7 @@ def test_save_partition(self, tmp_path, sample_spark_df): assert expected_path.exists() - @pytest.mark.parametrize("file_format", ["csv", "parquet"]) + @pytest.mark.parametrize("file_format", ["csv", "parquet", "delta"]) def test_exists(self, file_format, tmp_path, sample_spark_df): filepath = (tmp_path / "test_data").as_posix() spark_data_set = SparkDataSet(filepath=filepath, file_format=file_format) @@ -580,7 +616,7 @@ def test_save(self, versioned_dataset_s3, version, mocker): ) versioned_dataset_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( + mocked_spark_df.write.options().save.assert_called_once_with( "s3a://{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v=version.save), "parquet", ) @@ -600,7 +636,7 @@ def test_save_version_warning(self, mocker): ) with pytest.warns(UserWarning, match=pattern): ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( + mocked_spark_df.write.options().save.assert_called_once_with( "s3a://{b}/{f}/{v}/{f}".format( b=BUCKET_NAME, f=FILENAME, v=exact_version.save ), @@ -712,7 +748,7 @@ def test_save(self, mocker, version): "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME), strict=False, ) - mocked_spark_df.write.save.assert_called_once_with( + mocked_spark_df.write.options().save.assert_called_once_with( "hdfs://{fn}/{f}/{v}/{f}".format( fn=FOLDER_NAME, v=version.save, f=FILENAME ), @@ -734,7 +770,7 @@ def test_save_version_warning(self, mocker): with pytest.warns(UserWarning, match=pattern): versioned_hdfs.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( + mocked_spark_df.write.options().save.assert_called_once_with( "hdfs://{fn}/{f}/{sv}/{f}".format( fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save ), diff --git a/tests/extras/datasets/spark/test_spark_hive_dataset.py b/tests/extras/datasets/spark/test_spark_hive_dataset.py index c6b1eb31f6..c30ca49cd3 100644 --- a/tests/extras/datasets/spark/test_spark_hive_dataset.py +++ b/tests/extras/datasets/spark/test_spark_hive_dataset.py @@ -10,15 +10,12 @@ from kedro.extras.datasets.spark import SparkHiveDataSet from kedro.io import DataSetError -from tests.extras.datasets.spark.conftest import UseTheSparkSessionFixtureOrMock TESTSPARKDIR = "test_spark_dir" -# clean up pyspark after the test module finishes @pytest.fixture(scope="module") -def spark_hive_session(replace_spark_default_getorcreate): - SparkSession.builder.getOrCreate = replace_spark_default_getorcreate +def spark_session(): try: with TemporaryDirectory(TESTSPARKDIR) as tmpdir: spark = ( @@ -48,8 +45,6 @@ def spark_hive_session(replace_spark_default_getorcreate): # files are still used by Java process. 
pass - SparkSession.builder.getOrCreate = UseTheSparkSessionFixtureOrMock - # remove the cached JVM vars SparkContext._jvm = None # pylint: disable=protected-access SparkContext._gateway = None # pylint: disable=protected-access @@ -66,7 +61,7 @@ def spark_hive_session(replace_spark_default_getorcreate): @pytest.fixture(scope="module", autouse=True) -def spark_test_databases(spark_hive_session): +def spark_test_databases(spark_session): """Setup spark test databases for all tests in this module.""" dataset = _generate_spark_df_one() dataset.createOrReplaceTempView("tmp") @@ -74,15 +69,15 @@ def spark_test_databases(spark_hive_session): # Setup the databases and test table before testing for database in databases: - spark_hive_session.sql(f"create database {database}") - spark_hive_session.sql("use default_1") - spark_hive_session.sql("create table table_1 as select * from tmp") + spark_session.sql(f"create database {database}") + spark_session.sql("use default_1") + spark_session.sql("create table table_1 as select * from tmp") - yield spark_hive_session + yield spark_session # Drop the databases after testing for database in databases: - spark_hive_session.sql(f"drop database {database} cascade") + spark_session.sql(f"drop database {database} cascade") def assert_df_equal(expected, result): @@ -150,8 +145,8 @@ def test_read_existing_table(self): ) assert_df_equal(_generate_spark_df_one(), dataset.load()) - def test_overwrite_empty_table(self, spark_hive_session): - spark_hive_session.sql( + def test_overwrite_empty_table(self, spark_session): + spark_session.sql( "create table default_1.test_overwrite_empty_table (name string, age integer)" ).take(1) dataset = SparkHiveDataSet( @@ -162,8 +157,8 @@ def test_overwrite_empty_table(self, spark_hive_session): dataset.save(_generate_spark_df_one()) assert_df_equal(dataset.load(), _generate_spark_df_one()) - def test_overwrite_not_empty_table(self, spark_hive_session): - spark_hive_session.sql( + def test_overwrite_not_empty_table(self, spark_session): + spark_session.sql( "create table default_1.test_overwrite_full_table (name string, age integer)" ).take(1) dataset = SparkHiveDataSet( @@ -175,8 +170,8 @@ def test_overwrite_not_empty_table(self, spark_hive_session): dataset.save(_generate_spark_df_one()) assert_df_equal(dataset.load(), _generate_spark_df_one()) - def test_insert_not_empty_table(self, spark_hive_session): - spark_hive_session.sql( + def test_insert_not_empty_table(self, spark_session): + spark_session.sql( "create table default_1.test_insert_not_empty_table (name string, age integer)" ).take(1) dataset = SparkHiveDataSet( @@ -197,8 +192,8 @@ def test_upsert_config_err(self): ): SparkHiveDataSet(database="default_1", table="table_1", write_mode="upsert") - def test_upsert_empty_table(self, spark_hive_session): - spark_hive_session.sql( + def test_upsert_empty_table(self, spark_session): + spark_session.sql( "create table default_1.test_upsert_empty_table (name string, age integer)" ).take(1) dataset = SparkHiveDataSet( @@ -212,8 +207,8 @@ def test_upsert_empty_table(self, spark_hive_session): dataset.load().sort("name"), _generate_spark_df_one().sort("name") ) - def test_upsert_not_empty_table(self, spark_hive_session): - spark_hive_session.sql( + def test_upsert_not_empty_table(self, spark_session): + spark_session.sql( "create table default_1.test_upsert_not_empty_table (name string, age integer)" ).take(1) dataset = SparkHiveDataSet( @@ -257,8 +252,8 @@ def test_invalid_write_mode_provided(self): table_pk=["name"], ) - def 
test_invalid_schema_insert(self, spark_hive_session): - spark_hive_session.sql( + def test_invalid_schema_insert(self, spark_session): + spark_session.sql( "create table default_1.test_invalid_schema_insert " "(name string, additional_column_on_hive integer)" ).take(1)
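A minimal sketch (not part of this diff) of the new `SparkDataSet` behaviour introduced above: `dfwriter_options` in `save_args` are forwarded to `DataFrameWriter.options()`, and delta-only write modes are rejected in favour of `spark.DeltaTableDataSet`. The filepath and the `mergeSchema` option are illustrative assumptions.

```python
from kedro.extras.datasets.spark import SparkDataSet
from kedro.io import DataSetError

# dfwriter_options are popped out of save_args and applied via df.write.options(...)
delta_ds = SparkDataSet(
    filepath="data/03_primary/weather",  # hypothetical local path
    file_format="delta",
    save_args={"mode": "overwrite", "dfwriter_options": {"mergeSchema": "true"}},
)
# delta_ds.save(df) would call df.write.options(mergeSchema="true").save(path, "delta", mode="overwrite")

# delta-specific write modes raise at construction time and point to DeltaTableDataSet
try:
    SparkDataSet(
        filepath="data/03_primary/weather",
        file_format="delta",
        save_args={"mode": "merge"},
    )
except DataSetError as error:
    print(error)
```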