Add Export for TF backend #692

Merged
merged 35 commits into keras-team:main on Aug 11, 2023
Conversation

nkovela1
Collaborator

This PR adapts the export library from tf.keras to Keras Core, along with its corresponding test suite.
Note that this PR only implements export for the TF backend; expansion to JAX is still in the planning stage.

I have included some TODOs for further test coverage or debugging of certain edge use cases.
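For reference, end-to-end usage of the adapted export flow looks roughly like the sketch below. This is based on the tf.keras ExportArchive API this PR adapts (the keras_core.export.ExportArchive path matches the export decorator in the diff; the exact surface may still change):

import numpy as np
import tensorflow as tf
import keras_core as keras

# A small model to export; any built Keras Core model should work similarly.
model = keras.Sequential([keras.layers.Dense(8), keras.layers.Dense(1)])
model(np.zeros((1, 4)))  # build the model by calling it once

export_archive = keras.export.ExportArchive()
export_archive.track(model)  # track variables and tracked resources
export_archive.add_endpoint(
    name="serve",
    fn=model.call,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
)
export_archive.write_out("/tmp/my_export")  # writes a TF SavedModel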

@fchollet fchollet left a comment

Thanks for the PR!

import tensorflow as tf


def get_tensor_spec(t, dynamic_batch=False, name=None):
Contributor

We should only have one TF utils file. We already have utils/tf_utils.py, so we need to consolidate into one of them. I think backend/tf_utils.py is probably the best choice: 1. explicit name (in general, avoid utils.py; it's too generic), 2. confined to the TF backend folder.

Collaborator Author

Good idea, I've merged them into tf_utils.py under the TF backend folder.

@@ -0,0 +1,570 @@
"""Library for exporting inference-only Keras models/layers."""

import tensorflow as tf
Contributor

from keras_core.module_utils import tensorflow as tf

Collaborator Author

Fixed, thanks.
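For context, the suggested import goes through a lazy-module shim so that TensorFlow is only imported when actually used. A minimal sketch of that pattern (the LazyModule class here is illustrative, not necessarily the actual keras_core implementation):

import importlib


class LazyModule:
    """Defers importing a heavy optional dependency until first attribute access."""

    def __init__(self, name):
        self._name = name
        self._module = None

    def __getattr__(self, attr):
        if self._module is None:
            self._module = importlib.import_module(self._name)
        return getattr(self._module, attr)


# Call sites can then write: from keras_core.module_utils import tensorflow as tf
tensorflow = LazyModule("tensorflow")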



@keras_core_export("keras_core.export.ExportArchive")
class ExportArchive(tf.__internal__.tracking.AutoTrackable):
Contributor

Since it is fundamentally SavedModel-specific, I wonder if we should rename it to something that makes that explicit.

Alternatively, we can have: 1. a backend-agnostic ExportArchive used for configuring endpoints, and 2. a TFSavedModelExportArchive subclass that is TF-specific.
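A rough sketch of that split, purely for illustration (the class names and method bodies here are hypothetical, not part of this PR):

import tensorflow as tf


class ExportArchive:
    """Backend-agnostic part: collects endpoint configuration."""

    def __init__(self):
        self._endpoints = {}

    def add_endpoint(self, name, fn, input_signature=None):
        self._endpoints[name] = (fn, input_signature)

    def write_out(self, filepath):
        raise NotImplementedError


class TFSavedModelExportArchive(ExportArchive):
    """TF-specific part: wraps endpoints in tf.functions and writes a SavedModel."""

    def write_out(self, filepath):
        module = tf.Module()
        for name, (fn, signature) in self._endpoints.items():
            setattr(module, name, tf.function(fn, input_signature=signature))
        tf.saved_model.save(module, filepath)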

Collaborator Author

I think we can rename it once we either have a backend-agnostic abstraction or decide how we want to tie in the JAX ExportArchive. Let's discuss further on chat after this.

Contributor

Is the class even reusable with something other than SavedModel in the future? It seems pretty baked into the APIs (e.g. TensorSpecs, tf.function)

Collaborator Author

We would have to come up with our own flow, separate from this one, to allow export to ONNX or other formats, since these APIs are essential to how SavedModel works. It's harder to create an open-ended abstraction here, IMO.

There are existing TF to ONNX model converters like this one: https://github.com/onnx/tensorflow-onnx
Perhaps we can integrate some of their logic later or leave it to an OSS contribution?
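For reference, tf2onnx already exposes a Python conversion entry point for tf.keras models, roughly like the following (per its documentation; whether it works on Keras Core models out of the box is untested here):

import tensorflow as tf
import tf2onnx

# A small tf.keras model, just for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

spec = (tf.TensorSpec((None, 4), tf.float32, name="inputs"),)
# The CLI equivalent operates directly on a SavedModel directory.
onnx_model, _ = tf2onnx.convert.from_keras(
    model, input_signature=spec, output_path="model.onnx"
)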

atol=1e-6,
)

# def test_model_with_lookup_table(self):
Contributor

What was the issue?

Collaborator Author

The issue was calling a model containing a TextVectorization layer with string tensors (since we try to standardize them to float on model call): https://chat.google.com/room/AAAADD9-Dbs/mZnA6hEpffQ

Since JAX and Torch do not support string tensors, we're trying a workaround of calling the layer first, as Matt suggested (but there are some strange SavedModel tracked-resource errors that need to be resolved with this workaround; we can revisit this in a separate PR).

Collaborator Author

^ Disregard that

Thanks for the workaround using the Input layer. I've added the test back and changed some of the lookup layer export logic, and it works now!
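A rough sketch of the Input-layer workaround being referenced (TF backend only; the exact shape/dtype handling for string inputs was part of what needed debugging here, so treat this as illustrative):

import keras_core as keras

# Lookup-style layer whose vocabulary is a TF tracked resource.
vectorizer = keras.layers.TextVectorization(output_mode="int")
vectorizer.adapt(["the quick brown fox", "jumped over the lazy dog"])

# Wrap the layer in a functional model with an explicit string Input, so the
# exported signature is defined on string tensors rather than going through
# the backend's float standardization on model call.
inputs = keras.Input(shape=(1,), dtype="string")
outputs = vectorizer(inputs)
model = keras.Model(inputs, outputs)

model.export("/tmp/text_model")  # TF backend only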

Collaborator Author

Interesting, I'm running into another one of those errors where a test fails in CI but passes when run directly with pytest.

)
"""Create a SavedModel artifact for inference (e.g. via TF-Serving).

This method lets you export a model to a lightweight SavedModel artifact
Contributor

We should first mention that it only works with the TF backend.

Collaborator Author

Fixed, thanks
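For context, the simplest path through this method looks roughly as follows; the default endpoint name "serve" and the reload pattern are assumed from the tf.keras export flow this PR adapts:

import numpy as np
import tensorflow as tf
import keras_core as keras

model = keras.Sequential([keras.layers.Dense(1)])
model(np.zeros((1, 4)))  # build the model

# One-call export; internally this configures an ExportArchive and writes it out.
model.export("/tmp/exported_model")

# Reload outside of Keras, the way TF-Serving would consume it.
reloaded = tf.saved_model.load("/tmp/exported_model")
preds = reloaded.serve(tf.zeros((1, 4)))  # "serve" is the assumed default endpoint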

export_archive.write_out(filepath)


class ReloadedLayer(Layer):
Contributor

Upon reflection, I think this is a pretty unfortunate name. I wonder if we should rename it to TFSMLayer or something to that effect? In the future we could have other formats as well, like ONNXLayer. It would make sense since it can work with more than what's saved with ExportArchive: any SavedModel will do.

Note that this object doesn't seem to be exported to the public API yet.

Collaborator Author

Good idea, I've changed it to TFSMLayer, thanks.
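For reference, usage of the renamed layer might look roughly like this (the public keras_core.layers.TFSMLayer path and the call_endpoint argument are assumptions; as noted above, the class is not exported to the public API yet at this point):

import keras_core as keras

# Reload an arbitrary SavedModel as a Keras layer for inference.
reloaded_layer = keras.layers.TFSMLayer(
    "/tmp/exported_model", call_endpoint="serve"
)

inputs = keras.Input(shape=(4,))
outputs = reloaded_layer(inputs)
wrapper = keras.Model(inputs, outputs)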




"""[TF backend only]* Create a TF SavedModel artifact for inference
(e.g. via TF-Serving).

*Note: This can currently only be used with the TF backend.
Contributor

**Note:**

Collaborator Author

Fixed

fchollet commented Aug 11, 2023

The lookup table test is failing, though it isn't entirely clear what the issue is.

@nkovela1
Collaborator Author

I've decided to comment out the lookup test for now, as after investigating it's still unclear why it fails only in CI. It's blocking this PR, and the underlying issue can be addressed separately.

@fchollet fchollet left a comment

LGTM, thank you!

@fchollet fchollet merged commit 138157b into keras-team:main Aug 11, 2023