Skip to content

Commit

Permalink
Increase test coverage + Fix save_model_to_hdf5 + Improve `is_remot…
Browse files Browse the repository at this point in the history
…e_path` + Fix `is_remote_path` (#900)

* Increase test coverage in `saving`

* Add FAILED tests TODO

* Add tests for `LambdaCallback`

* Add tests for `LambdaCallback`

* Add test for saving_api.py#L96

* Increase test coverage in `saving`

* Increase test coverage

* refines the logic `os.makedirs` +Increase tests

* Increase test coverage

* Increase test coverage

* More tests file_utils_test.py+fix bug `rmtree`

* More tests `file_utils_test` + fix bug `rmtree`

* More tests file_utils_test + fix bug rmtree

* Increase test coverage

* add tests to `lambda_callback_test`

* Add tests in file_utils_test.py

* Add tests in file_utils_test.py

* Add more tests `file_utils_test`

* add class TestValidateFile

* Add tests for `TestIsRemotePath`

* Add tests in file_utils_test.py

* Add tests in file_utils_test.py

* Add tests in file_utils_test.py

* Add tests in `file_utils_test.py`

* fix `is_remote_path`

* improve `is_remote_path`

* Add test for `raise_if_no_gfile_raises`

* Add tests for file_utils.py

* Add tests in `saving_api_test.py`

* Add tests `saving_api_test.py`

* Add tests saving_api_test.py

* Add tests in `saving_api_test.py`

* Add test `test_directory_creation_on_save`

* Add test `legacy_h5_format_test.py`

* Flake8  for `LambdaCallbackTest`

* use `get_model` and `self.get_temp_dir`

* Fix format

* Improve `is_remote_path` + Add tests

* Fix `is_remote_path`
  • Loading branch information
Faisal-Alsrheed authored Sep 20, 2023
1 parent 5af4344 commit a354797
Show file tree
Hide file tree
Showing 6 changed files with 947 additions and 104 deletions.
127 changes: 123 additions & 4 deletions keras_core/callbacks/lambda_callback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@

class LambdaCallbackTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_LambdaCallback(self):
BATCH_SIZE = 4
def test_lambda_callback(self):
"""Test standard LambdaCallback functionalities with training."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=BATCH_SIZE), layers.Dense(1)]
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
Expand All @@ -34,7 +35,7 @@ def test_LambdaCallback(self):
model.fit(
x,
y,
batch_size=BATCH_SIZE,
batch_size=batch_size,
validation_split=0.2,
callbacks=[lambda_log_callback],
epochs=5,
Expand All @@ -44,3 +45,121 @@ def test_LambdaCallback(self):
self.assertTrue(any("on_epoch_begin" in log for log in logs.output))
self.assertTrue(any("on_epoch_end" in log for log in logs.output))
self.assertTrue(any("on_train_end" in log for log in logs.output))

@pytest.mark.requires_trainable_backend
def test_lambda_callback_with_batches(self):
"""Test LambdaCallback's behavior with batch-level callbacks."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)
y = np.random.randn(16, 1)
lambda_log_callback = callbacks.LambdaCallback(
on_train_batch_begin=lambda batch, logs: logging.warning(
"on_train_batch_begin"
),
on_train_batch_end=lambda batch, logs: logging.warning(
"on_train_batch_end"
),
)
with self.assertLogs(level="WARNING") as logs:
model.fit(
x,
y,
batch_size=batch_size,
validation_split=0.2,
callbacks=[lambda_log_callback],
epochs=5,
verbose=0,
)
self.assertTrue(
any("on_train_batch_begin" in log for log in logs.output)
)
self.assertTrue(
any("on_train_batch_end" in log for log in logs.output)
)

@pytest.mark.requires_trainable_backend
def test_lambda_callback_with_kwargs(self):
"""Test LambdaCallback's behavior with custom defined callback."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)
y = np.random.randn(16, 1)
model.fit(
x, y, batch_size=batch_size, epochs=1, verbose=0
) # Train briefly for evaluation to work.

def custom_on_test_begin(logs):
logging.warning("custom_on_test_begin_executed")

lambda_log_callback = callbacks.LambdaCallback(
on_test_begin=custom_on_test_begin
)
with self.assertLogs(level="WARNING") as logs:
model.evaluate(
x,
y,
batch_size=batch_size,
callbacks=[lambda_log_callback],
verbose=0,
)
self.assertTrue(
any(
"custom_on_test_begin_executed" in log
for log in logs.output
)
)

@pytest.mark.requires_trainable_backend
def test_lambda_callback_no_args(self):
"""Test initializing LambdaCallback without any arguments."""
lambda_callback = callbacks.LambdaCallback()
self.assertIsInstance(lambda_callback, callbacks.LambdaCallback)

@pytest.mark.requires_trainable_backend
def test_lambda_callback_with_additional_kwargs(self):
"""Test initializing LambdaCallback with non-predefined kwargs."""

def custom_callback(logs):
pass

lambda_callback = callbacks.LambdaCallback(
custom_method=custom_callback
)
self.assertTrue(hasattr(lambda_callback, "custom_method"))

@pytest.mark.requires_trainable_backend
def test_lambda_callback_during_prediction(self):
"""Test LambdaCallback's functionality during model prediction."""
batch_size = 4
model = Sequential(
[layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]
)
model.compile(
optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()
)
x = np.random.randn(16, 2)

def custom_on_predict_begin(logs):
logging.warning("on_predict_begin_executed")

lambda_callback = callbacks.LambdaCallback(
on_predict_begin=custom_on_predict_begin
)
with self.assertLogs(level="WARNING") as logs:
model.predict(
x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0
)
self.assertTrue(
any("on_predict_begin_executed" in log for log in logs.output)
)
5 changes: 2 additions & 3 deletions keras_core/legacy/saving/legacy_h5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
if not proceed:
return

# Try creating dir if not exist
dirpath = os.path.dirname(filepath)
if not os.path.exists(dirpath):
os.path.makedirs(dirpath)
if dirpath and not os.path.exists(dirpath):
os.makedirs(dirpath, exist_ok=True)

f = h5py.File(filepath, mode="w")
opened_new_file = True
Expand Down
16 changes: 16 additions & 0 deletions keras_core/legacy/saving/legacy_h5_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,19 @@ def call(self, x):

# Compare output
self.assertAllClose(ref_output, output, atol=1e-5)


@pytest.mark.requires_trainable_backend
class DirectoryCreationTest(testing.TestCase):
def test_directory_creation_on_save(self):
"""Test if directory is created on model save."""
model = get_sequential_model(keras_core)
nested_dirpath = os.path.join(
self.get_temp_dir(), "dir1", "dir2", "dir3"
)
filepath = os.path.join(nested_dirpath, "model.h5")
self.assertFalse(os.path.exists(nested_dirpath))
legacy_h5_format.save_model_to_hdf5(model, filepath)
self.assertTrue(os.path.exists(nested_dirpath))
loaded_model = legacy_h5_format.load_model_from_hdf5(filepath)
self.assertEqual(model.to_json(), loaded_model.to_json())
178 changes: 178 additions & 0 deletions keras_core/saving/saving_api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import os
import unittest.mock as mock

import numpy as np
from absl import logging

from keras_core import layers
from keras_core.models import Sequential
from keras_core.saving import saving_api
from keras_core.testing import test_case


class SaveModelTests(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_basic_saving(self):
"""Test basic model saving and loading."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
saving_api.save_model(model, filepath)

loaded_model = saving_api.load_model(filepath)
x = np.random.uniform(size=(10, 3))
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))

def test_invalid_save_format(self):
"""Test deprecated save_format argument."""
model = self.get_model()
with self.assertRaisesRegex(
ValueError, "The `save_format` argument is deprecated"
):
saving_api.save_model(model, "model.txt", save_format=True)

def test_unsupported_arguments(self):
"""Test unsupported argument during model save."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
with self.assertRaisesRegex(
ValueError, r"The following argument\(s\) are not supported"
):
saving_api.save_model(model, filepath, random_arg=True)

def test_save_h5_format(self):
"""Test saving model in h5 format."""
model = self.get_model()
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
saving_api.save_model(model, filepath_h5)
self.assertTrue(os.path.exists(filepath_h5))
os.remove(filepath_h5)

def test_save_unsupported_extension(self):
"""Test saving model with unsupported extension."""
model = self.get_model()
with self.assertRaisesRegex(
ValueError, "Invalid filepath extension for saving"
):
saving_api.save_model(model, "model.png")


class LoadModelTests(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_basic_load(self):
"""Test basic model loading."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
saving_api.save_model(model, filepath)

loaded_model = saving_api.load_model(filepath)
x = np.random.uniform(size=(10, 3))
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))

def test_load_unsupported_format(self):
"""Test loading model with unsupported format."""
with self.assertRaisesRegex(ValueError, "File format not supported"):
saving_api.load_model("model.pkl")

def test_load_keras_not_zip(self):
"""Test loading keras file that's not a zip."""
with self.assertRaisesRegex(ValueError, "File not found"):
saving_api.load_model("not_a_zip.keras")

def test_load_h5_format(self):
"""Test loading model in h5 format."""
model = self.get_model()
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
saving_api.save_model(model, filepath_h5)
loaded_model = saving_api.load_model(filepath_h5)
x = np.random.uniform(size=(10, 3))
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))
os.remove(filepath_h5)

def test_load_model_with_custom_objects(self):
"""Test loading model with custom objects."""

class CustomLayer(layers.Layer):
def call(self, inputs):
return inputs

model = Sequential([CustomLayer(input_shape=(3,))])
filepath = os.path.join(self.get_temp_dir(), "custom_model.keras")
model.save(filepath)
loaded_model = saving_api.load_model(
filepath, custom_objects={"CustomLayer": CustomLayer}
)
self.assertIsInstance(loaded_model.layers[0], CustomLayer)
os.remove(filepath)


class LoadWeightsTests(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_load_keras_weights(self):
"""Test loading keras weights."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
model.save_weights(filepath)
original_weights = model.get_weights()
model.load_weights(filepath)
loaded_weights = model.get_weights()
for orig, loaded in zip(original_weights, loaded_weights):
self.assertTrue(np.array_equal(orig, loaded))

def test_load_h5_weights_by_name(self):
"""Test loading h5 weights by name."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
model.save_weights(filepath)
with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"):
model.load_weights(filepath, by_name=True)

def test_load_weights_invalid_extension(self):
"""Test loading weights with unsupported extension."""
model = self.get_model()
with self.assertRaisesRegex(ValueError, "File format not supported"):
model.load_weights("invalid_extension.pkl")


class SaveModelTestsWarning(test_case.TestCase):
def get_model(self):
return Sequential(
[
layers.Dense(5, input_shape=(3,)),
layers.Softmax(),
]
)

def test_h5_deprecation_warning(self):
"""Test deprecation warning for h5 format."""
model = self.get_model()
filepath = os.path.join(self.get_temp_dir(), "test_model.h5")

with mock.patch.object(logging, "warning") as mock_warn:
saving_api.save_model(model, filepath)
mock_warn.assert_called_once_with(
"You are saving your model as an HDF5 file via `model.save()`. "
"This file format is considered legacy. "
"We recommend using instead the native Keras format, "
"e.g. `model.save('my_model.keras')`."
)
18 changes: 14 additions & 4 deletions keras_core/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,19 @@ def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):


def is_remote_path(filepath):
"""Returns `True` for paths that represent a remote GCS location."""
# TODO: improve generality.
if re.match(r"^(/cns|/cfs|/gcs|.*://).*$", str(filepath)):
"""
Determines if a given filepath indicates a remote location.
This function checks if the filepath represents a known remote pattern
such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`)
Args:
filepath (str): The path to be checked.
Returns:
bool: True if the filepath is a recognized remote path, otherwise False
"""
if re.match(r"^(/cns|/cfs|/gcs|/hdfs|.*://).*$", str(filepath)):
return True
return False

Expand Down Expand Up @@ -445,7 +455,7 @@ def rmtree(path):
return gfile.rmtree(path)
else:
_raise_if_no_gfile(path)
return shutil.rmtree
return shutil.rmtree(path)


def listdir(path):
Expand Down
Loading

0 comments on commit a354797

Please sign in to comment.