-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Increase test coverage + Fix
save_model_to_hdf5
+ Improve `is_remot…
…e_path` + Fix `is_remote_path` (#900) * Increase test coverage in `saving` * Add FAILED tests TODO * Add tests for `LambdaCallback` * Add tests for `LambdaCallback` * Add test for saving_api.py#L96 * Increase test coverage in `saving` * Increase test coverage * refines the logic `os.makedirs` +Increase tests * Increase test coverage * Increase test coverage * More tests file_utils_test.py+fix bug `rmtree` * More tests `file_utils_test` + fix bug `rmtree` * More tests file_utils_test + fix bug rmtree * Increase test coverage * add tests to `lambda_callback_test` * Add tests in file_utils_test.py * Add tests in file_utils_test.py * Add more tests `file_utils_test` * add class TestValidateFile * Add tests for `TestIsRemotePath` * Add tests in file_utils_test.py * Add tests in file_utils_test.py * Add tests in file_utils_test.py * Add tests in `file_utils_test.py` * fix `is_remote_path` * improve `is_remote_path` * Add test for `raise_if_no_gfile_raises` * Add tests for file_utils.py * Add tests in `saving_api_test.py` * Add tests `saving_api_test.py` * Add tests saving_api_test.py * Add tests in `saving_api_test.py` * Add test `test_directory_creation_on_save` * Add test `legacy_h5_format_test.py` * Flake8 for `LambdaCallbackTest` * use `get_model` and `self.get_temp_dir` * Fix format * Improve `is_remote_path` + Add tests * Fix `is_remote_path`
- Loading branch information
1 parent
5af4344
commit a354797
Showing
6 changed files
with
947 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
import os | ||
import unittest.mock as mock | ||
|
||
import numpy as np | ||
from absl import logging | ||
|
||
from keras_core import layers | ||
from keras_core.models import Sequential | ||
from keras_core.saving import saving_api | ||
from keras_core.testing import test_case | ||
|
||
|
||
class SaveModelTests(test_case.TestCase): | ||
def get_model(self): | ||
return Sequential( | ||
[ | ||
layers.Dense(5, input_shape=(3,)), | ||
layers.Softmax(), | ||
] | ||
) | ||
|
||
def test_basic_saving(self): | ||
"""Test basic model saving and loading.""" | ||
model = self.get_model() | ||
filepath = os.path.join(self.get_temp_dir(), "test_model.keras") | ||
saving_api.save_model(model, filepath) | ||
|
||
loaded_model = saving_api.load_model(filepath) | ||
x = np.random.uniform(size=(10, 3)) | ||
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) | ||
|
||
def test_invalid_save_format(self): | ||
"""Test deprecated save_format argument.""" | ||
model = self.get_model() | ||
with self.assertRaisesRegex( | ||
ValueError, "The `save_format` argument is deprecated" | ||
): | ||
saving_api.save_model(model, "model.txt", save_format=True) | ||
|
||
def test_unsupported_arguments(self): | ||
"""Test unsupported argument during model save.""" | ||
model = self.get_model() | ||
filepath = os.path.join(self.get_temp_dir(), "test_model.keras") | ||
with self.assertRaisesRegex( | ||
ValueError, r"The following argument\(s\) are not supported" | ||
): | ||
saving_api.save_model(model, filepath, random_arg=True) | ||
|
||
def test_save_h5_format(self): | ||
"""Test saving model in h5 format.""" | ||
model = self.get_model() | ||
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5") | ||
saving_api.save_model(model, filepath_h5) | ||
self.assertTrue(os.path.exists(filepath_h5)) | ||
os.remove(filepath_h5) | ||
|
||
def test_save_unsupported_extension(self): | ||
"""Test saving model with unsupported extension.""" | ||
model = self.get_model() | ||
with self.assertRaisesRegex( | ||
ValueError, "Invalid filepath extension for saving" | ||
): | ||
saving_api.save_model(model, "model.png") | ||
|
||
|
||
class LoadModelTests(test_case.TestCase): | ||
def get_model(self): | ||
return Sequential( | ||
[ | ||
layers.Dense(5, input_shape=(3,)), | ||
layers.Softmax(), | ||
] | ||
) | ||
|
||
def test_basic_load(self): | ||
"""Test basic model loading.""" | ||
model = self.get_model() | ||
filepath = os.path.join(self.get_temp_dir(), "test_model.keras") | ||
saving_api.save_model(model, filepath) | ||
|
||
loaded_model = saving_api.load_model(filepath) | ||
x = np.random.uniform(size=(10, 3)) | ||
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) | ||
|
||
def test_load_unsupported_format(self): | ||
"""Test loading model with unsupported format.""" | ||
with self.assertRaisesRegex(ValueError, "File format not supported"): | ||
saving_api.load_model("model.pkl") | ||
|
||
def test_load_keras_not_zip(self): | ||
"""Test loading keras file that's not a zip.""" | ||
with self.assertRaisesRegex(ValueError, "File not found"): | ||
saving_api.load_model("not_a_zip.keras") | ||
|
||
def test_load_h5_format(self): | ||
"""Test loading model in h5 format.""" | ||
model = self.get_model() | ||
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5") | ||
saving_api.save_model(model, filepath_h5) | ||
loaded_model = saving_api.load_model(filepath_h5) | ||
x = np.random.uniform(size=(10, 3)) | ||
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x))) | ||
os.remove(filepath_h5) | ||
|
||
def test_load_model_with_custom_objects(self): | ||
"""Test loading model with custom objects.""" | ||
|
||
class CustomLayer(layers.Layer): | ||
def call(self, inputs): | ||
return inputs | ||
|
||
model = Sequential([CustomLayer(input_shape=(3,))]) | ||
filepath = os.path.join(self.get_temp_dir(), "custom_model.keras") | ||
model.save(filepath) | ||
loaded_model = saving_api.load_model( | ||
filepath, custom_objects={"CustomLayer": CustomLayer} | ||
) | ||
self.assertIsInstance(loaded_model.layers[0], CustomLayer) | ||
os.remove(filepath) | ||
|
||
|
||
class LoadWeightsTests(test_case.TestCase): | ||
def get_model(self): | ||
return Sequential( | ||
[ | ||
layers.Dense(5, input_shape=(3,)), | ||
layers.Softmax(), | ||
] | ||
) | ||
|
||
def test_load_keras_weights(self): | ||
"""Test loading keras weights.""" | ||
model = self.get_model() | ||
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5") | ||
model.save_weights(filepath) | ||
original_weights = model.get_weights() | ||
model.load_weights(filepath) | ||
loaded_weights = model.get_weights() | ||
for orig, loaded in zip(original_weights, loaded_weights): | ||
self.assertTrue(np.array_equal(orig, loaded)) | ||
|
||
def test_load_h5_weights_by_name(self): | ||
"""Test loading h5 weights by name.""" | ||
model = self.get_model() | ||
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5") | ||
model.save_weights(filepath) | ||
with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"): | ||
model.load_weights(filepath, by_name=True) | ||
|
||
def test_load_weights_invalid_extension(self): | ||
"""Test loading weights with unsupported extension.""" | ||
model = self.get_model() | ||
with self.assertRaisesRegex(ValueError, "File format not supported"): | ||
model.load_weights("invalid_extension.pkl") | ||
|
||
|
||
class SaveModelTestsWarning(test_case.TestCase): | ||
def get_model(self): | ||
return Sequential( | ||
[ | ||
layers.Dense(5, input_shape=(3,)), | ||
layers.Softmax(), | ||
] | ||
) | ||
|
||
def test_h5_deprecation_warning(self): | ||
"""Test deprecation warning for h5 format.""" | ||
model = self.get_model() | ||
filepath = os.path.join(self.get_temp_dir(), "test_model.h5") | ||
|
||
with mock.patch.object(logging, "warning") as mock_warn: | ||
saving_api.save_model(model, filepath) | ||
mock_warn.assert_called_once_with( | ||
"You are saving your model as an HDF5 file via `model.save()`. " | ||
"This file format is considered legacy. " | ||
"We recommend using instead the native Keras format, " | ||
"e.g. `model.save('my_model.keras')`." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.