Merged
1 change: 0 additions & 1 deletion .github/workflows/actions.yml
@@ -54,7 +54,6 @@ jobs:
           fi
           pip install -r $REQUIREMENTS_FILE --progress-bar off --upgrade
           pip uninstall -y keras keras-nightly
-          pip install tf_keras==2.16.0 --progress-bar off --upgrade
           pip install -e "." --progress-bar off --upgrade
       - name: Test applications with pytest
         if: ${{ steps.filter.outputs.applications == 'true' }}
2 changes: 2 additions & 0 deletions .kokoro/github/ubuntu/gpu/build.sh
@@ -72,7 +72,9 @@ then
    # Raise error if GPU is not detected.
    python3 -c 'import torch;assert torch.cuda.is_available()'
 
+   # TODO: keras/src/export/export_lib_test.py update LD_LIBRARY_PATH
    pytest keras --ignore keras/src/applications \
+       --ignore keras/src/export/export_lib_test.py \
        --cov=keras \
        --cov-config=pyproject.toml
 
157 changes: 132 additions & 25 deletions keras/src/backend/torch/export.py
@@ -1,35 +1,142 @@
-from keras.src import layers
+import copy
+import warnings
+
+import torch
+
+from keras.src import backend
+from keras.src import ops
 from keras.src import tree
+from keras.src.utils.module_utils import tensorflow as tf
+from keras.src.utils.module_utils import torch_xla
 
 
 class TorchExportArchive:
     def track(self, resource):
-        if not isinstance(resource, layers.Layer):
-            raise ValueError(
-                "Invalid resource type. Expected an instance of a "
-                "JAX-based Keras `Layer` or `Model`. "
-                f"Received instead an object of type '{type(resource)}'. "
-                f"Object received: {resource}"
-            )
-
-        if isinstance(resource, layers.Layer):
-            # Variables in the lists below are actually part of the trackables
-            # that get saved, because the lists are created in __init__.
-            variables = resource.variables
-            trainable_variables = resource.trainable_variables
-            non_trainable_variables = resource.non_trainable_variables
-            self._tf_trackable.variables += tree.map_structure(
-                self._convert_to_tf_variable, variables
-            )
-            self._tf_trackable.trainable_variables += tree.map_structure(
-                self._convert_to_tf_variable, trainable_variables
-            )
-            self._tf_trackable.non_trainable_variables += tree.map_structure(
-                self._convert_to_tf_variable, non_trainable_variables
-            )
-
-    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
-        # TODO: torch-xla?
-        raise NotImplementedError(
-            "`add_endpoint` is not implemented in the torch backend."
-        )
+        raise NotImplementedError(
+            "`track` is not implemented in the torch backend. Use"
+            "`track_and_add_endpoint` instead."
+        )
+
+    def add_endpoint(self, name, fn, input_signature, **kwargs):
+        raise NotImplementedError(
+            "`add_endpoint` is not implemented in the torch backend. Use"
+            "`track_and_add_endpoint` instead."
+        )
+
+    def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
[Review thread on `track_and_add_endpoint` — a usage sketch follows this file's diff.]

Collaborator: Why was the new API needed? Can't we just make `track` and `add_endpoint` pure config ops, and make `write` do all the actual export logic?

james77777778 (Contributor, Dec 25, 2024): torch_xla relies on the `ExportedProgram` generated by `torch.export.export`, which requires a `torch.nn.Module`. As a result, it is difficult to decouple the logic into `track` and `add_endpoint` for torch_xla. Additionally, the `fn` in `add_endpoint` is not applicable for torch_xla.

EDITED: We already have a common API for all backends (`Model.export`), and `ExportArchive` is a low-level interface for advanced users. It should be acceptable to have some discrepancies across backends in `ExportArchive`.

Collaborator: It's certainly a bit confusing that `track`/`add_endpoint` only work in TF/JAX while `track_and_add_endpoint` only works with torch. But I see your point.

james77777778 (Contributor): `track_and_add_endpoint` works across all backends. I have combined the logic of `track` and `add_endpoint` in the API for TF/JAX.

+        # Disable false alarms related to lifting parameters.
+        warnings.filterwarnings("ignore", message=".*created when tracing.*")
+        warnings.filterwarnings(
+            "ignore", message=".*Unable to find the path of the module.*"
+        )
+
+        if not isinstance(resource, torch.nn.Module):
+            raise TypeError(
+                "`resource` must be an instance of `torch.nn.Module`. "
+                f"Received: resource={resource} (of type {type(resource)})"
+            )
+
+        def _check_input_signature(input_spec):
+            for s in tree.flatten(input_spec.shape):
+                if s is None:
+                    raise ValueError(
+                        "The shape in the `input_spec` must be fully "
+                        f"specified. Received: input_spec={input_spec}"
+                    )
+
+        def _to_torch_tensor(x, replace_none_number=1):
+            shape = backend.standardize_shape(x.shape)
+            shape = tuple(
+                s if s is not None else replace_none_number for s in shape
+            )
+            return ops.ones(shape, x.dtype)
+
+        tree.map_structure(_check_input_signature, input_signature)
+        sample_inputs = tree.map_structure(_to_torch_tensor, input_signature)
+        sample_inputs = tuple(sample_inputs)
+
+        # Ref: torch_xla.tf_saved_model_integration
+        # TODO: Utilize `dynamic_shapes`
+        exported = torch.export.export(
+            resource, sample_inputs, dynamic_shapes=None, strict=False
+        )
+        options = torch_xla.stablehlo.StableHLOExportOptions(
+            override_tracing_arguments=sample_inputs
+        )
+        stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo(
+            exported, options
+        )
+        state_dict_keys = list(stablehlo_model._bundle.state_dict.keys())
+
+        # Remove unused variables.
+        for k in state_dict_keys:
+            if "lifted" not in k:
+                stablehlo_model._bundle.state_dict.pop(k)
+
+        bundle = copy.deepcopy(stablehlo_model._bundle)
+        bundle.state_dict = {
+            k: tf.Variable(v, trainable=False, name=k)
+            for k, v in bundle.state_dict.items()
+        }
+        bundle.additional_constants = [
+            tf.Variable(v, trainable=False) for v in bundle.additional_constants
+        ]
+
+        # Track variables in `bundle` for `write_out`.
+        self._tf_trackable.variables += (
+            list(bundle.state_dict.values()) + bundle.additional_constants
+        )
+
+        # Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf
+        def make_tf_function(func, bundle):
+            from tensorflow.compiler.tf2xla.python import xla as tfxla
+
+            def _get_shape_with_dynamic(signature):
+                shape = copy.copy(signature.shape)
+                for i in signature.dynamic_dims:
+                    shape[i] = None
+                return shape
+
+            def _extract_call_parameters(args, meta, bundle):
+                call_args = []
+                if meta.input_pytree_spec is not None:
+                    args = tree.flatten(args)
+                for loc in meta.input_locations:
+                    if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER:
+                        call_args.append(bundle.state_dict[loc.name])
+                    elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT:
+                        call_args.append(
+                            bundle.additional_constants[loc.position]
+                        )
+                    else:
+                        call_args.append(args[loc.position])
+                return call_args
+
+            def inner(*args):
+                Touts = [sig.dtype for sig in func.meta.output_signature]
+                Souts = [
+                    _get_shape_with_dynamic(sig)
+                    for sig in func.meta.output_signature
+                ]
+                call_args = _extract_call_parameters(args, func.meta, bundle)
+                results = tfxla.call_module(
+                    tuple(call_args),
+                    version=5,
+                    Tout=Touts,  # dtype information
+                    Sout=Souts,  # Shape information
+                    function_list=[],
+                    module=func.bytecode,
+                )
+                if len(Souts) == 1:
+                    results = results[0]
+                return results
+
+            return inner
+
+        decorated_fn = tf.function(
+            make_tf_function(
+                stablehlo_model._bundle.stablehlo_funcs[0], bundle
+            ),
+            input_signature=input_signature,
+        )
+        return decorated_fn
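
As the review thread above explains, the torch path cannot split tracking and endpoint registration: `torch.export.export` needs the whole `torch.nn.Module`, whose `ExportedProgram` is then lowered to StableHLO via `torch_xla` and wrapped in a `tf.function` through `tfxla.call_module`. The sketch below shows the resulting user-facing flow. It is illustrative only: it assumes the torch backend is active and that both `tensorflow` and `torch_xla` are installed; the model, shapes, and output path are made up.

```python
import numpy as np

import keras
from keras import layers

# A small model. The torch export path does not yet support dynamic
# shapes, so any `None` dimension (e.g. the batch size) is replaced
# with 1 when no `input_signature` is given.
model = keras.Sequential([layers.Dense(4, activation="relu"), layers.Dense(1)])
model.build((1, 8))

# Export the forward pass as a TensorFlow SavedModel artifact.
# Internally this routes through `ExportArchive.track_and_add_endpoint`.
model.export("/tmp/torch_exported_model")

# The artifact can be reloaded and served with plain TensorFlow;
# neither Keras nor torch is needed at serving time.
import tensorflow as tf

reloaded = tf.saved_model.load("/tmp/torch_exported_model")
outputs = reloaded.serve(np.ones((1, 8), dtype="float32"))
```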
127 changes: 100 additions & 27 deletions keras/src/export/export_lib.py
@@ -91,7 +91,7 @@ class ExportArchive(BackendExportArchive):
 
     **Note on resource tracking:**
 
-    `ExportArchive` is able to automatically track all `tf.Variables` used
+    `ExportArchive` is able to automatically track all `keras.Variables` used
     by its endpoints, so most of the time calling `.track(model)`
     is not strictly required. However, if your model uses lookup layers such
     as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
@@ -104,9 +104,10 @@ class ExportArchive(BackendExportArchive):
 
     def __init__(self):
         super().__init__()
-        if backend.backend() not in ("tensorflow", "jax"):
+        if backend.backend() not in ("tensorflow", "jax", "torch"):
             raise NotImplementedError(
-                "The export API is only compatible with JAX and TF backends."
+                "`ExportArchive` is only compatible with TensorFlow, JAX and "
+                "Torch backends."
             )
 
         self._endpoint_names = []
@@ -141,8 +142,8 @@ def track(self, resource):
         (`TextVectorization`, `IntegerLookup`, `StringLookup`)
         are automatically tracked in `add_endpoint()`.
 
-        Arguments:
-            resource: A trackable TensorFlow resource.
+        Args:
+            resource: A trackable Keras resource, such as a layer or model.
         """
         if isinstance(resource, layers.Layer) and not resource.built:
             raise ValueError(
@@ -334,12 +335,78 @@ def serving_fn(x):
         self._endpoint_names.append(name)
         return decorated_fn
 
+    def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):
+        """Track the variables and register a new serving endpoint.
+
+        This function combines the functionality of `track` and `add_endpoint`.
+        It tracks the variables of the `resource` (either a layer or a model)
+        and registers a serving endpoint using `resource.__call__`.
+
+        Args:
+            name: `str`. The name of the endpoint.
+            resource: A trackable Keras resource, such as a layer or model.
+            input_signature: Optional. Specifies the shape and dtype of `fn`.
+                Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,
+                `backend.KerasTensor`, or backend tensor (see below for an
+                example showing a `Functional` model with 2 input arguments). If
+                not provided, `fn` must be a `tf.function` that has been called
+                at least once. Defaults to `None`.
+            **kwargs: Additional keyword arguments:
+                - Specific to the JAX backend:
+                    - `is_static`: Optional `bool`. Indicates whether `fn` is
+                        static. Set to `False` if `fn` involves state updates
+                        (e.g., RNG seeds).
+                    - `jax2tf_kwargs`: Optional `dict`. Arguments for
+                        `jax2tf.convert`. See [`jax2tf.convert`](
+                        https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
+                        If `native_serialization` and `polymorphic_shapes` are
+                        not provided, they are automatically computed.
+
+        """
+        if name in self._endpoint_names:
+            raise ValueError(f"Endpoint name '{name}' is already taken.")
+        if not isinstance(resource, layers.Layer):
+            raise ValueError(
+                "Invalid resource type. Expected an instance of a Keras "
+                "`Layer` or `Model`. "
+                f"Received: resource={resource} (of type {type(resource)})"
+            )
+        if not resource.built:
+            raise ValueError(
+                "The layer provided has not yet been built. "
+                "It must be built before export."
+            )
+        if backend.backend() != "jax":
+            if "jax2tf_kwargs" in kwargs or "is_static" in kwargs:
+                raise ValueError(
+                    "'jax2tf_kwargs' and 'is_static' are only supported with "
+                    f"the jax backend. Current backend: {backend.backend()}"
+                )
+
+        input_signature = tree.map_structure(_make_tensor_spec, input_signature)
+
+        if not hasattr(BackendExportArchive, "track_and_add_endpoint"):
+            # Default behavior.
+            self.track(resource)
+            return self.add_endpoint(
+                name, resource.__call__, input_signature, **kwargs
+            )
+        else:
+            # Special case for the torch backend.
+            decorated_fn = BackendExportArchive.track_and_add_endpoint(
+                self, name, resource, input_signature, **kwargs
+            )
+            self._endpoint_signatures[name] = input_signature
+            setattr(self._tf_trackable, name, decorated_fn)
+            self._endpoint_names.append(name)
+            return decorated_fn
+
     def add_variable_collection(self, name, variables):
         """Register a set of variables to be retrieved after reloading.
 
         Arguments:
             name: The string name for the collection.
-            variables: A tuple/list/set of `tf.Variable` instances.
+            variables: A tuple/list/set of `keras.Variable` instances.
 
         Example:
 
Expand Down Expand Up @@ -496,9 +563,6 @@ def export_saved_model(
):
"""Export the model as a TensorFlow SavedModel artifact for inference.

**Note:** This feature is currently supported only with TensorFlow and
JAX backends.

This method lets you export a model to a lightweight SavedModel artifact
that contains the model's forward pass only (its `call()` method)
and can be served via e.g. TensorFlow Serving. The forward pass is
@@ -527,6 +591,14 @@ def export_saved_model(
             If `native_serialization` and `polymorphic_shapes` are not
             provided, they are automatically computed.
 
+    **Note:** This feature is currently supported only with TensorFlow, JAX and
+    Torch backends. Support for the Torch backend is experimental.
+
+    **Note:** The dynamic shape feature is not yet supported with Torch
+    backend. As a result, you must fully define the shapes of the inputs using
+    `input_signature`. If `input_signature` is not provided, all instances of
+    `None` (such as the batch size) will be replaced with `1`.
+
     Example:
 
     ```python
@@ -543,28 +615,29 @@
     `export()` method relies on `ExportArchive` internally.
     """
     export_archive = ExportArchive()
-    export_archive.track(model)
-    if isinstance(model, (Functional, Sequential)):
-        if input_signature is None:
+    if input_signature is None:
+        if not model.built:
+            raise ValueError(
+                "The layer provided has not yet been built. "
+                "It must be built before export."
+            )
+        if isinstance(model, (Functional, Sequential)):
             input_signature = tree.map_structure(
                 _make_tensor_spec, model.inputs
             )
-        if isinstance(input_signature, list) and len(input_signature) > 1:
-            input_signature = [input_signature]
-        export_archive.add_endpoint(
-            "serve", model.__call__, input_signature, **kwargs
-        )
-    else:
-        if input_signature is None:
+            if isinstance(input_signature, list) and len(input_signature) > 1:
+                input_signature = [input_signature]
+        else:
             input_signature = _get_input_signature(model)
-            if not input_signature or not model._called:
-                raise ValueError(
-                    "The model provided has never called. "
-                    "It must be called at least once before export."
-                )
-        export_archive.add_endpoint(
-            "serve", model.__call__, input_signature, **kwargs
-        )
+        if not input_signature or not model._called:
+            raise ValueError(
+                "The model provided has never called. "
+                "It must be called at least once before export."
+            )
+
+    export_archive.track_and_add_endpoint(
+        "serve", model, input_signature, **kwargs
+    )
     export_archive.write_out(filepath, verbose=verbose)
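
For advanced users of the low-level interface, the combined method can also be called directly on an archive. A minimal sketch, assuming the TensorFlow backend and that `keras.export.ExportArchive` is the public alias of the class in this diff; the model, signature, and path are illustrative:

```python
import tensorflow as tf

import keras

model = keras.Sequential([keras.layers.Dense(2)])
model.build((None, 4))

archive = keras.export.ExportArchive()
# One call replaces the separate track() + add_endpoint() pair; on the
# TF/JAX backends it falls back to exactly that pair internally.
archive.track_and_add_endpoint(
    "serve",
    model,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],
)
archive.write_out("/tmp/my_export_archive")
```

On the torch backend the same call routes through `TorchExportArchive.track_and_add_endpoint` above, where the fully specified shape requirement applies.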

