[KED-1668] Fix caching on PartitionedDataSet and `IncrementalDataSet` (#593)
DmitriiDeriabinQB authored May 14, 2020
1 parent 671c57c commit 7b00f80
Showing 5 changed files with 65 additions and 30 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
@@ -44,6 +44,7 @@
* Bug in `SparkDataSet` not allowing for loading data from DBFS in a Windows machine using Databricks-connect.
* Added option to lint the project without applying the formatting changes (`kedro lint --check-only`).
* Improved the error message for `DataSetNotFoundError` to suggest possible dataset names user meant to type.
* Replaced `functools.lru_cache` with `cachetools.cachedmethod` in `PartitionedDataSet` and `IncrementalDataSet` for per-instance cache invalidation.

## Breaking changes to the API
* Made `invalidate_cache` method on datasets private.
@@ -60,6 +61,7 @@
* Made constant `PARAMETER_KEYWORDS` private, and moved it from `kedro.pipeline.pipeline` to `kedro.pipeline.modular_pipeline`.
* Removed `CSVBlobDataSet` and `JSONBlobDataSet` as redundant.
* Layers are no longer part of the dataset object, as they've moved to the `DataCatalog`.
* `PartitionedDataSet` and `IncrementalDataSet` method `invalidate_cache` was made private: `_invalidate_caches`.

### Migration guide from Kedro 0.15.* to Upcoming Release
#### Migration for datasets
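
For context on the release note above about replacing `functools.lru_cache` with `cachetools.cachedmethod`: the sketch below is illustrative only (not Kedro code) and shows why the swap enables per-instance invalidation. `lru_cache` attaches a single cache to the method, shared by every instance, so clearing it affects all instances; `cachedmethod` fetches the cache from each instance via `operator.attrgetter`, so each object can clear only its own entries.

```python
import operator
from functools import lru_cache

from cachetools import Cache, cachedmethod


class SharedCache:
    """`lru_cache` attaches ONE cache to the method, shared by every instance."""

    @lru_cache(maxsize=None)
    def compute(self):
        return object()


class PerInstanceCache:
    """`cachedmethod` looks the cache up on the instance, so each object owns its own."""

    def __init__(self):
        self._cache = Cache(maxsize=1)

    @cachedmethod(cache=operator.attrgetter("_cache"))
    def compute(self):
        return object()


a, b = PerInstanceCache(), PerInstanceCache()
a.compute()
b.compute()
a._cache.clear()                 # invalidates only `a`
assert len(b._cache) == 1        # `b` still holds its cached value

x, y = SharedCache(), SharedCache()
x.compute()
y.compute()
x.compute.cache_clear()          # clears the shared cache...
assert SharedCache.compute.cache_info().currsize == 0   # ...dropping `y`'s entry too
```
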
2 changes: 1 addition & 1 deletion docs/source/04_user_guide/08_advanced_io.md
@@ -408,7 +408,7 @@ As you can see from the example above, on load `PartitionedDataSet` _does not_ a
>
> Example 2: if `path="s3://my-bucket-name/folder"` and `filename_suffix=".csv"` and partition is stored in `s3://my-bucket-name/folder/2019-12-04/data.csv` then its Partition ID is `2019-12-04/data`.
> *Note:* `PartitionedDataSet` implements caching on load operation, which means that if multiple nodes consume the same `PartitionedDataSet`, they will all receive the same partition dictionary even if some new partitions were added to the folder after the first load has been completed. This is done deliberately to guarantee the consistency of load operations between the nodes and avoid race conditions. You can reset cache by calling `.invalidate_cache()` method of the partitioned dataset object.
> *Note:* `PartitionedDataSet` implements caching on load operation, which means that if multiple nodes consume the same `PartitionedDataSet`, they will all receive the same partition dictionary even if some new partitions were added to the folder after the first load has been completed. This is done deliberately to guarantee the consistency of load operations between the nodes and avoid race conditions. You can reset the cache by calling `release()` method of the partitioned dataset object.
### Partitioned dataset save

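
The documentation change above points users at `release()` instead of the removed `invalidate_cache()`. A minimal usage sketch, assuming a local folder of CSV partitions (the path is hypothetical) and the `kedro.io.PartitionedDataSet` import path:

```python
from kedro.io import PartitionedDataSet

# Hypothetical local folder of CSV partitions; adjust the path for your project.
pds = PartitionedDataSet(path="data/01_raw/orders", dataset="pandas.CSVDataSet")

first = pds.load()        # dict of partition IDs -> load callables; the listing is cached
# ... something else writes a new partition into the folder here ...
assert pds.load().keys() == first.keys()   # still served from the cached listing

pds.release()             # clears the partition cache (and the filesystem cache)
refreshed = pds.load()    # re-lists the folder, so new partitions now appear
```
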
20 changes: 11 additions & 9 deletions kedro/io/partitioned_data_set.py
@@ -31,11 +31,12 @@
"""
import operator
from copy import deepcopy
from functools import lru_cache
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from urllib.parse import urlparse
from warnings import warn

from cachetools import Cache, cachedmethod

from kedro.io.core import (
VERSION_KEY,
VERSIONED_FLAG_KEY,
@@ -147,6 +148,7 @@ def __init__( # pylint: disable=too-many-arguments
self._path = path
self._filename_suffix = filename_suffix
self._protocol = infer_storage_options(self._path)["protocol"]
self._partition_cache = Cache(maxsize=1)

dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
@@ -179,7 +181,7 @@ def __init__( # pylint: disable=too-many-arguments
self._load_args = deepcopy(load_args) or {}
self._sep = self._filesystem.sep
# since some filesystem implementations may implement a global cache
self.invalidate_cache()
self._invalidate_caches()

@property
def _filesystem(self):
@@ -195,7 +197,7 @@ def _normalized_path(self) -> str:
return urlparse(self._path)._replace(scheme="s3").geturl()
return self._path

@lru_cache(maxsize=None)
@cachedmethod(cache=operator.attrgetter("_partition_cache"))
def _list_partitions(self) -> List[str]:
return [
path
@@ -247,7 +249,7 @@ def _save(self, data: Dict[str, Any]) -> None:
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
dataset.save(partition_data)
self.invalidate_cache()
self._invalidate_caches()

def _describe(self) -> Dict[str, Any]:
clean_dataset_config = (
@@ -261,16 +263,16 @@ def _describe(self) -> Dict[str, Any]:
dataset_config=clean_dataset_config,
)

def invalidate_cache(self):
"""Invalidate `_list_partitions` method and underlying filesystem caches."""
self._list_partitions.cache_clear()
def _invalidate_caches(self):
self._partition_cache.clear()
self._filesystem.invalidate_cache(self._normalized_path)

def _exists(self) -> bool:
return bool(self._list_partitions())

def _release(self) -> None:
self.invalidate_cache()
super()._release()
self._invalidate_caches()


def _split_credentials(
@@ -434,7 +436,7 @@ def _parse_checkpoint_config(

return {**default_config, **checkpoint_config}

@lru_cache(maxsize=None)
@cachedmethod(cache=operator.attrgetter("_partition_cache"))
def _list_partitions(self) -> List[str]:
checkpoint = self._read_checkpoint()
checkpoint_path = self._filesystem._strip_protocol( # pylint: disable=protected-access
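
The diff above replaces `lru_cache` (whose cache is shared by all instances of the class) with a `cachetools.Cache` stored on each instance and cleared by `_invalidate_caches()`. A condensed sketch of that pattern follows; the class name, `print` call and in-memory listing are stand-ins, and the real method also invalidates the fsspec filesystem cache.

```python
import operator
from typing import List

from cachetools import Cache, cachedmethod


class PartitionListing:
    """Condensed illustration of the caching pattern above (not the real Kedro class)."""

    def __init__(self, paths: List[str]):
        self._paths = paths                        # stand-in for a filesystem listing
        self._partition_cache = Cache(maxsize=1)   # one cached listing per instance

    @cachedmethod(cache=operator.attrgetter("_partition_cache"))
    def _list_partitions(self) -> List[str]:
        print("listing partitions...")             # only runs on a cache miss
        return list(self._paths)

    def _invalidate_caches(self):
        # clear the per-instance partition cache, as `_invalidate_caches` does above
        self._partition_cache.clear()

    def release(self):
        self._invalidate_caches()


listing = PartitionListing(["2019-12-04/data", "2019-12-05/data"])
listing._list_partitions()   # miss: prints and caches the result
listing._list_partitions()   # hit: served from `_partition_cache`
listing.release()
listing._list_partitions()   # miss again after invalidation
```
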
29 changes: 16 additions & 13 deletions tests/extras/datasets/pandas/test_csv_dataset.py
@@ -141,8 +141,10 @@ def test_catalog_release(self, mocker):
fs_mock = mocker.patch("fsspec.filesystem").return_value
filepath = "test.csv"
data_set = CSVDataSet(filepath=filepath)
assert data_set._version_cache.currsize == 0 # no cache if unversioned
data_set.release()
fs_mock.invalidate_cache.assert_called_once_with(filepath)
assert data_set._version_cache.currsize == 0


class TestCSVDataSetVersioned:
@@ -222,26 +224,27 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv):
ds_new = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
assert ds_new.resolve_load_version() == second_load_version

def test_invalidate_version_cache(self, dummy_dataframe, filepath_csv):
"""Test that version cache invalidation in one instance doesn't affect others"""
def test_release_instance_cache(self, dummy_dataframe, filepath_csv):
"""Test that cache invalidation does not affect other instances"""
ds_a = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
assert ds_a._version_cache.currsize == 0
ds_a.save(dummy_dataframe) # create a version
ds_a_save_version = ds_a.resolve_save_version()
ds_a_load_version = ds_a.resolve_load_version()
assert ds_a._version_cache.currsize == 2

ds_b = CSVDataSet(filepath=filepath_csv, version=Version(None, None))
ds_b_save_version = ds_b.resolve_save_version()
ds_b_load_version = ds_b.resolve_load_version()
assert ds_b._version_cache.currsize == 0
ds_b.resolve_save_version()
assert ds_b._version_cache.currsize == 1
ds_b.resolve_load_version()
assert ds_b._version_cache.currsize == 2

ds_a.save(dummy_dataframe) # create a new version
ds_a.release()

# dataset A has been updated
assert ds_a.resolve_save_version() > ds_a_save_version
assert ds_a.resolve_load_version() > ds_a_load_version
# dataset A cache is cleared
assert ds_a._version_cache.currsize == 0

# dataset B versions are unaffected
assert ds_b.resolve_save_version() == ds_b_save_version
assert ds_b.resolve_load_version() == ds_b_load_version
# dataset B cache is unaffected
assert ds_b._version_cache.currsize == 2

def test_no_versions(self, versioned_csv_data_set):
"""Check the error if no versions are available for load."""
42 changes: 35 additions & 7 deletions tests/io/test_partitioned_dataset.py
@@ -106,16 +106,44 @@ def test_save(self, dataset, local_csvs, suffix):
reloaded_data = loaded_partitions[part_id]()
assert_frame_equal(reloaded_data, original_data)

@pytest.mark.parametrize("dataset", LOCAL_DATASET_DEFINITION)
def test_save_invalidates_cache(self, dataset, local_csvs):
pds = PartitionedDataSet(str(local_csvs), dataset)
def test_save_invalidates_cache(self, local_csvs, mocker):
"""Test that save calls invalidate partition cache"""
pds = PartitionedDataSet(str(local_csvs), "pandas.CSVDataSet")
mocked_fs_invalidate = mocker.patch.object(pds._filesystem, "invalidate_cache")
first_load = pds.load()
assert pds._partition_cache.currsize == 1
mocked_fs_invalidate.assert_not_called()

# save clears cache
data = pd.DataFrame({"foo": 42, "bar": ["a", "b", None]})
part_id = "new/data.csv"
pds.save({part_id: data})
assert part_id not in first_load
assert part_id in pds.load()
new_partition = "new/data.csv"
pds.save({new_partition: data})
assert pds._partition_cache.currsize == 0
# it seems that `_filesystem.invalidate_cache` calls itself inside,
# resulting in not one, but 2 mock calls
# hence using `assert_any_call` instead of `assert_called_once_with`
mocked_fs_invalidate.assert_any_call(pds._normalized_path)

# new load returns new partition too
second_load = pds.load()
assert new_partition not in first_load
assert new_partition in second_load

def test_release_instance_cache(self, local_csvs):
"""Test that cache invalidation does not affect other instances"""
ds_a = PartitionedDataSet(str(local_csvs), "pandas.CSVDataSet")
ds_a.load()
ds_b = PartitionedDataSet(str(local_csvs), "pandas.CSVDataSet")
ds_b.load()

assert ds_a._partition_cache.currsize == 1
assert ds_b._partition_cache.currsize == 1

# invalidate cache of the dataset A
ds_a.release()
assert ds_a._partition_cache.currsize == 0
# cache of the dataset B is unaffected
assert ds_b._partition_cache.currsize == 1

@pytest.mark.parametrize("dataset", ["pandas.CSVDataSet", "pandas.ParquetDataSet"])
def test_exists(self, local_csvs, dataset):
