Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Export for TF backend #692

Merged
merged 35 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
88e9b30
Add saved model test
nkovela1 Jul 17, 2023
dfd404f
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Jul 17, 2023
19b0e39
Add TF tracking attribute
nkovela1 Jul 17, 2023
0be8fcc
Add tests for functional and subclassed
nkovela1 Jul 17, 2023
8908273
Fix saving trackables
nkovela1 Jul 18, 2023
0418c60
Fix test assertions
nkovela1 Jul 18, 2023
82c3af1
Fix formatting
nkovela1 Jul 18, 2023
6c8731d
Add comments for attribute tracking
nkovela1 Jul 18, 2023
ac35c30
Merge branch 'keras-team:main' into main
nkovela1 Jul 18, 2023
d341600
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Jul 18, 2023
8c1d954
Change saved model test description
nkovela1 Jul 18, 2023
9751d02
Add backend conditional for attribute
nkovela1 Jul 18, 2023
c1391cb
Change package name
nkovela1 Jul 18, 2023
1e7df16
Change epoch nums
nkovela1 Jul 18, 2023
51410fe
Revert epochs
nkovela1 Jul 18, 2023
1f11c1a
Add set verbose logging utility and debug callback tests
nkovela1 Jul 18, 2023
e93a4a6
Fix formatting
nkovela1 Jul 18, 2023
99301e1
Sync with main repo
nkovela1 Jul 26, 2023
a6eda55
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Aug 7, 2023
2902b72
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Aug 7, 2023
49b74b8
Merge remote-tracking branch 'refs/remotes/origin/main'
nkovela1 Aug 9, 2023
1f446c0
Initial port of model export
nkovela1 Aug 9, 2023
adbc885
Fix imports
nkovela1 Aug 9, 2023
ede9bd7
Add save spec methods to TF layer
nkovela1 Aug 9, 2023
d4f2b1a
Add export function to Keras Core base model
nkovela1 Aug 9, 2023
766b9f6
Downgrade naming error to warning and debug TF variable collections c…
nkovela1 Aug 9, 2023
b6990f8
Simplify weight reloading
nkovela1 Aug 10, 2023
316fdc7
Fix formatting, add TODOs
nkovela1 Aug 10, 2023
02df6af
Unify tf_utils under backend/tensorflow
nkovela1 Aug 10, 2023
8f3f3c9
Fix docstring and import
nkovela1 Aug 10, 2023
82bf3a3
Fix module utils import
nkovela1 Aug 10, 2023
9fdb0dc
Fix lookup layers export and add test
nkovela1 Aug 10, 2023
796b466
Change naming to TFSMLayer
nkovela1 Aug 10, 2023
db80dc9
Remove parameterized
nkovela1 Aug 11, 2023
4861175
Comment out failing test
nkovela1 Aug 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add save spec methods to TF layer
  • Loading branch information
nkovela1 committed Aug 9, 2023
commit ede9bd71a77e40c334f44e6fc047c4d19dd3a14e
50 changes: 49 additions & 1 deletion keras_core/backend/tensorflow/layer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,59 @@
import tensorflow as tf

from keras_core.backend.tensorflow import utils as tf_utils

class TFLayer(tf.__internal__.tracking.AutoTrackable):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Export-related attributes
self._saved_model_inputs_spec = None
self._saved_model_arg_spec = None

def _post_build(self):
"""Can be overriden to perform post-build actions."""
pass

@tf.__internal__.tracking.no_automatic_dependency_tracking
def _set_save_spec(self, inputs, args=None, kwargs=None):
"""Defines the save spec so that serialization can trace layer calls.

The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are
saved into a tuple of `([inputs] + args, kwargs)`.

Args:
inputs: possibly nested inputs passed into the call function.
args: a list of positional arguments passed into call.
kwargs: a dictionary of keyword arguments passed into call.
"""
if self._saved_model_inputs_spec is not None:
return # Already set.

inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
kwargs_spec = {}
# Filter out non-tensor arguments from kwargs.
for key, kwarg in kwargs.items():
flat_kwarg = tf.nest.flatten(kwarg)
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
if any(s is None for s in flat_specs):
continue
kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)

self._saved_model_inputs_spec = inputs_spec
self._saved_model_arg_spec = (
[inputs_spec] + list(args_spec),
kwargs_spec,
)

def _get_save_spec(self, dynamic_batch=True, inputs_only=True):
if self._saved_model_inputs_spec is None:
return None

spec = tf.nest.map_structure(
lambda t: tf_utils.get_tensor_spec(t, dynamic_batch=dynamic_batch),
self._saved_model_arg_spec,
)
return spec[0][0] if inputs_only else spec

def _trackable_children(self, save_type="checkpoint", **kwargs):
if save_type == "savedmodel":
# SavedModel needs to ignore the execution functions.
Expand Down
26 changes: 26 additions & 0 deletions keras_core/backend/tensorflow/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import tensorflow as tf

def get_tensor_spec(t, dynamic_batch=False, name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only have one tf utils file. We already have utils/tf_utils.py. We need to consolidate in one of them. I think backend/tf_utils.py is probably the best choice: 1. explicit name (in general avoid utils.py it's too generic), 2. confined to the TF backend folder.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I've merged them into tf_utils.py under the TF backend folder.

"""Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`."""
if isinstance(t, tf.TypeSpec):
spec = t
elif isinstance(tensor, tf.__internal__.CompositeTensor):
# Check for ExtensionTypes
spec = t._type_spec
elif hasattr(t, "shape") and hasattr(t, "dtype"):
spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
else:
return None # Allow non-Tensors to pass through.

if not dynamic_batch:
return spec

shape = spec.shape
if shape.rank is None or shape.rank == 0:
return spec

shape_list = shape.as_list()
shape_list[0] = None
shape = tf.TensorShape(shape_list)
spec._shape = shape
return spec
9 changes: 9 additions & 0 deletions keras_core/export/export_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from absl.testing import parameterized

from keras_core import backend
from keras_core import testing
from keras_core import layers
from keras_core import models
Expand All @@ -23,6 +24,10 @@ def get_model():
return model


@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Export only currently supports the TF backend.",
)
class ExportArchiveTest(testing.TestCase, parameterized.TestCase):
def test_standard_model_export(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
Expand Down Expand Up @@ -452,6 +457,10 @@ def test_model_export_method(self):
)


@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="Export only currently supports the TF backend.",
)
class TestReloadedLayer(tf.test.TestCase, parameterized.TestCase):
def test_reloading_export_archive(self):
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
Expand Down