Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Export for TF backend #692

Merged
merged 35 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
88e9b30
Add saved model test
nkovela1 Jul 17, 2023
dfd404f
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Jul 17, 2023
19b0e39
Add TF tracking attribute
nkovela1 Jul 17, 2023
0be8fcc
Add tests for functional and subclassed
nkovela1 Jul 17, 2023
8908273
Fix saving trackables
nkovela1 Jul 18, 2023
0418c60
Fix test assertions
nkovela1 Jul 18, 2023
82c3af1
Fix formatting
nkovela1 Jul 18, 2023
6c8731d
Add comments for attribute tracking
nkovela1 Jul 18, 2023
ac35c30
Merge branch 'keras-team:main' into main
nkovela1 Jul 18, 2023
d341600
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Jul 18, 2023
8c1d954
Change saved model test description
nkovela1 Jul 18, 2023
9751d02
Add backend conditional for attribute
nkovela1 Jul 18, 2023
c1391cb
Change package name
nkovela1 Jul 18, 2023
1e7df16
Change epoch nums
nkovela1 Jul 18, 2023
51410fe
Revert epochs
nkovela1 Jul 18, 2023
1f11c1a
Add set verbose logging utility and debug callback tests
nkovela1 Jul 18, 2023
e93a4a6
Fix formatting
nkovela1 Jul 18, 2023
99301e1
Sync with main repo
nkovela1 Jul 26, 2023
a6eda55
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Aug 7, 2023
2902b72
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Aug 7, 2023
49b74b8
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Aug 9, 2023
1f446c0
Initial port of model export
nkovela1 Aug 9, 2023
adbc885
Fix imports
nkovela1 Aug 9, 2023
ede9bd7
Add save spec methods to TF layer
nkovela1 Aug 9, 2023
d4f2b1a
Add export function to Keras Core base model
nkovela1 Aug 9, 2023
766b9f6
Downgrade naming error to warning and debug TF variable collections c…
nkovela1 Aug 9, 2023
b6990f8
Simplify weight reloading
nkovela1 Aug 10, 2023
316fdc7
Fix formatting, add TODOs
nkovela1 Aug 10, 2023
02df6af
Unify tf_utils under backend/tensorflow
nkovela1 Aug 10, 2023
8f3f3c9
Fix docstring and import
nkovela1 Aug 10, 2023
82bf3a3
Fix module utils import
nkovela1 Aug 10, 2023
9fdb0dc
Fix lookup layers export and add test
nkovela1 Aug 10, 2023
796b466
Change naming to TFSMLayer
nkovela1 Aug 10, 2023
db80dc9
Remove parameterized
nkovela1 Aug 11, 2023
4861175
Comment out failing test
nkovela1 Aug 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix imports
  • Loading branch information
nkovela1 committed Aug 9, 2023
commit adbc88507528df03fd577bb18e93f64439f4b534
10 changes: 5 additions & 5 deletions keras_core/export/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def add_endpoint(self, name, fn, input_signature=None):
a Functional model with 2 inputs):

```python
model = keras.Model(inputs=[x1, x2], outputs=outputs)
model = keras_core.Model(inputs=[x1, x2], outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
Expand All @@ -192,7 +192,7 @@ def add_endpoint(self, name, fn, input_signature=None):
This also works with dictionary inputs:

```python
model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)
model = keras_core.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

export_archive = ExportArchive()
export_archive.track(model)
Expand Down Expand Up @@ -370,20 +370,20 @@ def _filter_and_track_resources(self):
# Next, track lookup tables.
# Hopefully, one day this will be automated at the tf.function level.
self._misc_assets = []
from keras.layers.preprocessing.index_lookup import IndexLookup
from keras_core.layers import IntegerLookup

if hasattr(self, "_tracked"):
for root in self._tracked:
descendants = tf.train.TrackableView(root).descendants()
for trackable in descendants:
if isinstance(trackable, IndexLookup):
if isinstance(trackable, IntegerLookup):
self._misc_assets.append(trackable)


def export_model(model, filepath):
export_archive = ExportArchive()
export_archive.track(model)
if isinstance(model, (functional.Functional, sequential.Sequential)):
if isinstance(model, (Functional, Sequential)):
input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
Expand Down
84 changes: 40 additions & 44 deletions keras_core/export/export_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,23 @@
from absl.testing import parameterized

from keras_core import testing
from keras_core import layers
from keras_core import models
from keras_core import utils
from keras_core.export import export_lib


def get_model():
layers = [
keras_core.layers.Dense(10, activation="relu"),
keras_core.layers.BatchNormalization(),
keras_core.layers.Dense(1, activation="sigmoid"),
layer_list = [
layers.Dense(10, activation="relu"),
layers.BatchNormalization(),
layers.Dense(1, activation="sigmoid"),
]
model = keras_core.Sequential(layers, input_shape=(10,))
model = models.Sequential(layer_list)
return model


class ExportArchiveTest(testing.TestCase, parameterized.TestCase):
@test_combinations.run_with_all_model_types
def test_standard_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
Expand All @@ -34,7 +36,6 @@ def test_standard_model_export(self):
ref_output, revived_model.serve(ref_input).numpy(), atol=1e-6
)

@test_combinations.run_with_all_model_types
def test_low_level_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

Expand Down Expand Up @@ -93,7 +94,7 @@ def my_endpoint(x):
def test_layer_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer")

layer = keras_core.layers.BatchNormalization()
layer = layers.BatchNormalization()
ref_input = tf.random.normal((3, 10))
ref_output = layer(ref_input).numpy() # Build layer (important)

Expand All @@ -117,11 +118,11 @@ def test_layer_export(self):

def test_multi_input_output_functional_model(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
x1 = keras_core.Input((2,))
x2 = keras_core.Input((2,))
y1 = keras_core.layers.Dense(3)(x1)
y2 = keras_core.layers.Dense(3)(x2)
model = keras_core.Model([x1, x2], [y1, y2])
x1 = layers.Input((2,))
x2 = layers.Input((2,))
y1 = layers.Dense(3)(x1)
y2 = layers.Dense(3)(x2)
model = models.Model([x1, x2], [y1, y2])

ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))]
ref_outputs = model(ref_inputs)
Expand Down Expand Up @@ -158,7 +159,7 @@ def test_multi_input_output_functional_model(self):
)

# Now test dict inputs
model = keras_core.Model({"x1": x1, "x2": x2}, [y1, y2])
model = models.Model({"x1": x1, "x2": x2}, [y1, y2])

ref_inputs = {
"x1": tf.random.normal((3, 2)),
Expand Down Expand Up @@ -199,13 +200,13 @@ def test_multi_input_output_functional_model(self):

def test_model_with_lookup_table(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
text_vectorization = keras_core.layers.TextVectorization()
text_vectorization = layers.TextVectorization()
text_vectorization.adapt(["one two", "three four", "five six"])
model = keras_core.Sequential(
model = models.Sequential(
[
text_vectorization,
keras_core.layers.Embedding(10, 32),
keras_core.layers.Dense(1),
layers.Embedding(10, 32),
layers.Dense(1),
]
)
ref_input = tf.convert_to_tensor(["one two three four"])
Expand All @@ -219,10 +220,10 @@ def test_model_with_lookup_table(self):

def test_track_multiple_layers(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
layer_1 = keras_core.layers.Dense(2)
layer_1 = layers.Dense(2)
ref_input_1 = tf.random.normal((3, 4))
ref_output_1 = layer_1(ref_input_1).numpy()
layer_2 = keras_core.layers.Dense(3)
layer_2 = layers.Dense(3)
ref_input_2 = tf.random.normal((3, 5))
ref_output_2 = layer_2(ref_input_2).numpy()

Expand Down Expand Up @@ -263,7 +264,7 @@ def test_track_multiple_layers(self):
def test_non_standard_layer_signature(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_layer")

layer = keras_core.layers.MultiHeadAttention(2, 2)
layer = layers.MultiHeadAttention(2, 2)
x1 = tf.random.normal((3, 2, 2))
x2 = tf.random.normal((3, 2, 2))
ref_output = layer(x1, x2).numpy() # Build layer (important)
Expand Down Expand Up @@ -294,11 +295,11 @@ def test_non_standard_layer_signature(self):
def test_variable_collection(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

model = keras_core.Sequential(
model = models.Sequential(
[
keras_core.Input((10,)),
keras_core.layers.Dense(2),
keras_core.layers.Dense(2),
layers.Input((10,)),
layers.Dense(2),
layers.Dense(2),
]
)

Expand Down Expand Up @@ -327,15 +328,15 @@ def test_export_model_errors(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

# Model has not been built
model = keras_core.Sequential([keras_core.layers.Dense(2)])
model = models.Sequential([layers.Dense(2)])
with self.assertRaisesRegex(ValueError, "It must be built"):
export_lib.export_model(model, temp_filepath)

# Subclassed model has not been called
class MyModel(keras_core.Model):
class MyModel(models.Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.dense = keras_core.layers.Dense(2)
self.dense = layers.Dense(2)

def build(self, input_shape):
self.dense.build(input_shape)
Expand All @@ -351,7 +352,7 @@ def call(self, x):

def test_export_archive_errors(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = keras_core.Sequential([keras_core.layers.Dense(2)])
model = models.Sequential([layers.Dense(2)])
model(tf.random.normal((2, 3)))

# Endpoint name reuse
Expand Down Expand Up @@ -423,7 +424,7 @@ def test_export_no_assets(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")

# Case where there are legitimately no assets.
model = keras_core.Sequential([keras_core.layers.Flatten()])
model = models.Sequential([layers.Flatten()])
model(tf.random.normal((2, 3)))
export_archive = export_lib.ExportArchive()
export_archive.add_endpoint(
Expand All @@ -438,7 +439,6 @@ def test_export_no_assets(self):
)
export_archive.write_out(temp_filepath)

@test_combinations.run_with_all_model_types
def test_model_export_method(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
Expand All @@ -452,9 +452,7 @@ def test_model_export_method(self):
)


@test_utils.run_v2_only
class TestReloadedLayer(tf.test.TestCase, parameterized.TestCase):
@test_combinations.run_with_all_model_types
def test_reloading_export_archive(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
Expand All @@ -476,7 +474,7 @@ def test_reloading_export_archive(self):
)

# Test fine-tuning
new_model = keras_core.Sequential([reloaded_layer])
new_model = models.Sequential([reloaded_layer])
new_model.compile(optimizer="rmsprop", loss="mse")
x = tf.random.normal((32, 10))
y = tf.random.normal((32, 1))
Expand All @@ -495,7 +493,6 @@ def test_reloading_export_archive(self):
reloaded_layer(ref_input).numpy(), new_output, atol=1e-7
)

@test_combinations.run_with_all_model_types
def test_reloading_default_saved_model(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
Expand Down Expand Up @@ -524,12 +521,12 @@ def test_reloading_default_saved_model(self):

def test_call_training(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
keras_core.utils.set_random_seed(1337)
model = keras_core.Sequential(
utils.set_random_seed(1337)
model = models.Sequential(
[
keras_core.Input((10,)),
keras_core.layers.Dense(10),
keras_core.layers.Dropout(0.99999),
layers.Input((10,)),
layers.Dense(10),
layers.Dropout(0.99999),
]
)
export_archive = export_lib.ExportArchive()
Expand Down Expand Up @@ -559,7 +556,6 @@ def test_call_training(self):
self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7)
self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7)

@test_combinations.run_with_all_model_types
def test_serialization(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = get_model()
Expand All @@ -577,10 +573,10 @@ def test_serialization(self):
)

# Test whole model saving with reloaded layer inside
model = keras_core.Sequential([reloaded_layer])
model = models.Sequential([reloaded_layer])
temp_model_filepath = os.path.join(self.get_temp_dir(), "m.keras")
model.save(temp_model_filepath, save_format="keras_v3")
reloaded_model = keras_core.models.load_model(
reloaded_model = models.Models.load_model(
temp_model_filepath,
custom_objects={"ReloadedLayer": export_lib.ReloadedLayer},
)
Expand All @@ -591,7 +587,7 @@ def test_serialization(self):
def test_errors(self):
# Test missing call endpoint
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
model = keras_core.Sequential([keras_core.Input((2,)), keras_core.layers.Dense(3)])
model = models.Sequential([layers.Input((2,)), layers.Dense(3)])
export_lib.export_model(model, temp_filepath)
with self.assertRaisesRegex(ValueError, "The endpoint 'wrong'"):
export_lib.ReloadedLayer(temp_filepath, call_endpoint="wrong")
Expand Down