[KED-2140] Fix issue with saving versioned HDF5 models. (#519)
djpetti authored Oct 19, 2020
1 parent cd2687e commit 4aa8337
Showing 3 changed files with 31 additions and 2 deletions.
3 changes: 2 additions & 1 deletion RELEASE.md
@@ -18,11 +18,12 @@
* Fixed `kedro install` for an Anaconda environment defined in `environment.yml`.
* Fixed backwards compatibility with templates generated with older Kedro versions <0.16.5. No longer need to update `.kedro.yml` to use `kedro lint` and `kedro jupyter notebook convert`.
* Improved documentation.
+* Fixed issue with saving a `TensorFlowModelDataset` in the HDF5 format with versioning enabled.

## Breaking changes to the API

## Thanks for supporting contributions
-[Deepyaman Datta](https://github.com/deepyaman), [Bhavya Merchant](https://github.com/bnmerchant), [Lovkush Agarwal](https://github.com/Lovkush-A), [Varun Krishna S](https://github.com/vhawk19), [Sebastian Bertoli](https://github.com/sebastianbertoli)
+[Deepyaman Datta](https://github.com/deepyaman), [Bhavya Merchant](https://github.com/bnmerchant), [Lovkush Agarwal](https://github.com/Lovkush-A), [Varun Krishna S](https://github.com/vhawk19), [Sebastian Bertoli](https://github.com/sebastianbertoli), [Daniel Petti](https://github.com/djpetti)

# Release 0.16.5

6 changes: 5 additions & 1 deletion kedro/extras/datasets/tensorflow/tensorflow_model_dataset.py
@@ -31,7 +31,7 @@
"""
import copy
import tempfile
-from pathlib import PurePath, PurePosixPath
+from pathlib import Path, PurePath, PurePosixPath
from typing import Any, Dict

import fsspec
@@ -151,6 +151,10 @@ def _load(self) -> tf.keras.Model:
     def _save(self, data: tf.keras.Model) -> None:
         save_path = get_filepath_str(self._get_save_path(), self._protocol)

+        # Make sure all intermediate directories are created.
+        save_dir = Path(save_path).parent
+        save_dir.mkdir(parents=True, exist_ok=True)
+
         with tempfile.TemporaryDirectory(prefix=self._tmp_prefix) as path:
             if self._is_h5:
                 path = str(PurePath(path) / TEMPORARY_H5_FILE)
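
Note on the hunk above: the following is a minimal standalone sketch of why the added `mkdir(parents=True, exist_ok=True)` call matters, not the dataset's actual code. It assumes that a versioned save path nests the file under a version directory (the layout and timestamp below are hypothetical) and that the HDF5 branch finishes by writing a single file to that path, a step the truncated hunk does not show. The helper `copy_into_versioned_path` is illustrative only.

```python
import shutil
import tempfile
from pathlib import Path


def copy_into_versioned_path(src: Path, save_path: Path, create_parents: bool) -> None:
    """Copy `src` to `save_path`, optionally creating missing parent directories."""
    if create_parents:
        # Same pathlib call the patch adds before the model is written.
        save_path.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(src, save_path)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)

    # Stand-in for the temporary HDF5 file produced by the model save.
    src = root / "tmp_model.h5"
    src.write_bytes(b"fake hdf5 payload")

    # Hypothetical versioned layout: <filepath>/<save_version>/<filename>.
    save_path = root / "model.h5" / "2020-10-19T00.00.00.000Z" / "model.h5"

    # Without the parent directories, writing a single file fails.
    try:
        copy_into_versioned_path(src, save_path, create_parents=False)
        failed_without_mkdir = False
    except FileNotFoundError:
        failed_without_mkdir = True

    # With the parents created first, the same copy succeeds.
    copy_into_versioned_path(src, save_path, create_parents=True)
    saved_ok = save_path.read_bytes() == b"fake hdf5 payload"
```

Because `exist_ok=True` makes the call a no-op when the directory already exists, running it unconditionally before every save should be harmless for non-versioned paths as well.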
24 changes: 24 additions & 0 deletions tests/extras/datasets/tensorflow/test_tensorflow_model_dataset.py
@@ -317,6 +317,30 @@ def test_save_and_load(
         new_predictions = reloaded.predict(dummy_x_test)
         np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)

+    def test_hdf5_save_format(
+        self,
+        dummy_tf_base_model,
+        dummy_x_test,
+        filepath,
+        tensorflow_model_dataset,
+        load_version,
+        save_version,
+    ):
+        """Test versioned TensorflowModelDataset can save TF graph models in
+        HDF5 format"""
+        hdf5_dataset = tensorflow_model_dataset(
+            filepath=filepath,
+            save_args={"save_format": "h5"},
+            version=Version(load_version, save_version),
+        )
+
+        predictions = dummy_tf_base_model.predict(dummy_x_test)
+        hdf5_dataset.save(dummy_tf_base_model)
+
+        reloaded = hdf5_dataset.load()
+        new_predictions = reloaded.predict(dummy_x_test)
+        np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)
+
     def test_prevent_overwrite(self, dummy_tf_base_model, versioned_tf_model_dataset):
         """Check the error when attempting to override the data set if the
         corresponding file for a given save version already exists."""