[KED-2349] Bug in versioned SparkDataSet (#929)
andrii-ivaniuk authored Jan 13, 2021
1 parent dfd5440 commit 3ece4be
Showing 5 changed files with 30 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .circleci/config.yml
@@ -462,7 +462,8 @@ workflows:
- win_pip_compile_38
- win_unit_tests_36
- win_unit_tests_37
- win_unit_tests_38
# Skipped due to Windows fatal exception: stack overflow
# - win_unit_tests_38
- win_e2e_tests_36
- win_e2e_tests_37
- win_e2e_tests_38
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -5,10 +5,12 @@
## Bug fixes and other changes

* The version of a packaged modular pipeline now defaults to the version of the project package.
* Fixed issue with loading a versioned `SparkDataSet` in the interactive workflow.
* Kedro CLI now checks `pyproject.toml` for a `tool.kedro` section before treating the project as a Kedro project

## Breaking changes to the API


## Thanks for supporting contributions

# Release 0.17.0
6 changes: 1 addition & 5 deletions kedro/extras/datasets/spark/spark_dataset.py
@@ -29,7 +29,6 @@
"""``AbstractDataSet`` implementation to access Spark dataframes using
``pyspark``
"""

from copy import deepcopy
from fnmatch import fnmatch
from functools import partial
@@ -184,6 +183,7 @@ class SparkDataSet(AbstractVersionedDataSet):
>>> reloaded.take(4)
"""

_SINGLE_PROCESS = True
DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any]

@@ -289,10 +289,6 @@ def __init__( # pylint: disable=too-many-arguments
self._file_format = file_format
self._fs_prefix = fs_prefix

def __getstate__(self):
# SparkDataSet cannot be used with ParallelRunner
raise AttributeError(f"{self.__class__.__name__} cannot be serialized!")

def _describe(self) -> Dict[str, Any]:
return dict(
filepath=self._fs_prefix + str(self._filepath),
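The net effect of the spark_dataset.py hunks above: the poison-pill `__getstate__` (which raised on any attempt to pickle or deep-copy the dataset, presumably what broke loading a versioned `SparkDataSet` in the interactive workflow) is replaced by a declarative `_SINGLE_PROCESS = True` class attribute that `ParallelRunner` checks up front (see the parallel_runner.py hunk below). A minimal standalone sketch of the difference, using toy classes rather than Kedro code:

    import copy

    class OldStyleDataSet:
        # Removed approach: raising here blocks every pickle/deepcopy of the
        # object, not just its use with ParallelRunner.
        def __getstate__(self):
            raise AttributeError(f"{self.__class__.__name__} cannot be serialized!")

    class NewStyleDataSet:
        # New approach: copying and pickling keep working; the runner rejects
        # the dataset explicitly by looking for this flag.
        _SINGLE_PROCESS = True

    copy.deepcopy(NewStyleDataSet())    # fine
    try:
        # deepcopy falls back to __reduce_ex__, which calls __getstate__ and raises
        copy.deepcopy(OldStyleDataSet())
    except AttributeError as exc:
        print(exc)                      # "OldStyleDataSet cannot be serialized!"
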
11 changes: 7 additions & 4 deletions kedro/runner/parallel_runner.py
@@ -229,17 +229,20 @@ def _validate_catalog(cls, catalog: DataCatalog, pipeline: Pipeline):

unserializable = []
for name, data_set in data_sets.items():
if getattr(data_set, "_SINGLE_PROCESS", False): # SKIP_IF_NO_SPARK
unserializable.append(name)
continue
try:
ForkingPickler.dumps(data_set)
except (AttributeError, PicklingError):
unserializable.append(name)

if unserializable:
raise AttributeError(
"The following data_sets cannot be serialized: {}\nIn order "
"to utilize multiprocessing you need to make sure all data "
"sets are serializable, i.e. data sets should not make use of "
"lambda functions, nested functions, closures etc.\nIf you "
"The following data sets cannot be used with multiprocessing: "
"{}\nIn order to utilize multiprocessing you need to make sure "
"all data sets are serializable, i.e. data sets should not make "
"use of lambda functions, nested functions, closures etc.\nIf you "
"are using custom decorators ensure they are correctly using "
"functools.wraps().".format(sorted(unserializable))
)
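Two things change in this hunk: datasets can now opt out of multiprocessing via the `_SINGLE_PROCESS` flag instead of having to be un-picklable, and the error message is reworded from "cannot be serialized" to "cannot be used with multiprocessing" (the updated test below matches the new wording). A hedged, self-contained sketch of the same two-step check; `unusable_with_multiprocessing` and the stand-in classes are made up for illustration, and the real `_validate_catalog` also receives the catalog and pipeline:

    from multiprocessing.reduction import ForkingPickler
    from pickle import PicklingError

    class FlaggedDataSet:
        _SINGLE_PROCESS = True   # e.g. SparkDataSet after this commit

    class PlainDataSet:
        pass                     # an ordinary, picklable dataset

    def unusable_with_multiprocessing(data_sets):
        # Honour the explicit flag first, then fall back to an actual pickling attempt.
        rejected = []
        for name, data_set in data_sets.items():
            if getattr(data_set, "_SINGLE_PROCESS", False):
                rejected.append(name)
                continue
            try:
                ForkingPickler.dumps(data_set)
            except (AttributeError, PicklingError):
                rejected.append(name)
        return sorted(rejected)

    print(unusable_with_multiprocessing(
        {"spark_in": FlaggedDataSet(), "csv_in": PlainDataSet()}
    ))  # -> ['spark_in']
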
19 changes: 18 additions & 1 deletion tests/extras/datasets/spark/test_spark_dataset.py
@@ -287,14 +287,31 @@ def test_parallel_runner(self, is_async, spark_in):
"""
catalog = DataCatalog(data_sets={"spark_in": spark_in})
pipeline = Pipeline([node(identity, "spark_in", "spark_out")])
pattern = r"The following data_sets cannot be serialized: \['spark_in'\]"
pattern = (
r"The following data sets cannot be used with "
r"multiprocessing: \['spark_in'\]"
)
with pytest.raises(AttributeError, match=pattern):
ParallelRunner(is_async=is_async).run(pipeline, catalog)

def test_s3_glob_refresh(self):
spark_dataset = SparkDataSet(filepath="s3a://bucket/data")
assert spark_dataset._glob_function.keywords == {"refresh": True}

def test_copy(self):
spark_dataset = SparkDataSet(
filepath="/tmp/data", save_args={"mode": "overwrite"}
)
assert spark_dataset._file_format == "parquet"

spark_dataset_copy = spark_dataset._copy(_file_format="csv")

assert spark_dataset is not spark_dataset_copy
assert spark_dataset._file_format == "parquet"
assert spark_dataset._save_args == {"mode": "overwrite"}
assert spark_dataset_copy._file_format == "csv"
assert spark_dataset_copy._save_args == {"mode": "overwrite"}


class TestSparkDataSetVersionedLocal:
def test_no_version(self, versioned_dataset_local):
