Skip to content

Support Keras 3 models in from_keras #2398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion tf2onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tf2onnx import constants, logging, utils, optimizer
from tf2onnx import tf_loader
from tf2onnx.graph import ExternalTensorStorage
from tf2onnx.tf_utils import compress_graph_def, get_tf_version
from tf2onnx.tf_utils import compress_graph_def, get_tf_version, get_keras_version



Expand Down Expand Up @@ -408,6 +408,106 @@ def _from_keras_tf1(model, opset=None, custom_ops=None, custom_op_handlers=None,

return model_proto, external_tensor_storage

def from_keras3(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
target=None, large_model=False, output_path=None, optimizers=None):
"""
Convert a Keras 3 model to ONNX using tf2onnx.

Args:
model: Keras 3 Functional or Sequential model
name: Name for the converted model
input_signature: Optional list of tf.TensorSpec
opset: ONNX opset version
custom_ops: Dictionary of custom ops
custom_op_handlers: Dictionary of custom op handlers
custom_rewriter: List of graph rewriters
inputs_as_nchw: List of input names to convert to NCHW
extra_opset: Additional opset imports
shape_override: Dictionary to override input shapes
target: Target platforms (for workarounds)
large_model: Whether to use external tensor storage
output_path: Optional path to write ONNX model to file

Returns:
A tuple (model_proto, external_tensor_storage_dict)
"""


if not input_signature:

input_signature = [
tf.TensorSpec(tensor.shape, tensor.dtype, name=tensor.name.split(":")[0])
for tensor in model.inputs
]

# Trace model
function = tf.function(model)
concrete_func = function.get_concrete_function(*input_signature)

# These inputs will be removed during freezing (includes resources, etc.)
if hasattr(concrete_func.graph, '_captures'):
graph_captures = concrete_func.graph._captures # pylint: disable=protected-access
captured_inputs = [t_name.name for _, t_name in graph_captures.values()]
else:
graph_captures = concrete_func.graph.function_captures.by_val_internal
captured_inputs = [t.name for t in graph_captures.values()]
input_names = [input_tensor.name for input_tensor in concrete_func.inputs
if input_tensor.name not in captured_inputs]
output_names = [output_tensor.name for output_tensor in concrete_func.outputs
if output_tensor.dtype != tf.dtypes.resource]


tensors_to_rename = tensor_names_from_structed(concrete_func, input_names, output_names)
reverse_lookup = {v: k for k, v in tensors_to_rename.items()}



valid_names = []
for out in model.output_names:
if out in reverse_lookup:
valid_names.append(reverse_lookup[out])
else:
print(f"Warning: Output name '{out}' not found in reverse_lookup.")
# Fallback: verwende TensorFlow-Ausgangsnamen direkt
valid_names = [t.name for t in concrete_func.outputs if t.dtype != tf.dtypes.resource]
break
output_names = valid_names


#if old_out_names is not None:
#model.output_names = old_out_names

with tf.device("/cpu:0"):
frozen_graph, initialized_tables = \
tf_loader.from_trackable(model, concrete_func, input_names, output_names, large_model)

for node in frozen_graph.node:
print(node.name, node.op)
model_proto, external_tensor_storage = _convert_common(
frozen_graph,
name=model.name,
continue_on_error=True,
target=target,
opset=opset,
custom_ops=custom_ops,
custom_op_handlers=custom_op_handlers,
optimizers=optimizers,
custom_rewriter=custom_rewriter,
extra_opset=extra_opset,
shape_override=shape_override,
input_names=input_names,
output_names=output_names,
inputs_as_nchw=inputs_as_nchw,
outputs_as_nchw=outputs_as_nchw,
large_model=large_model,
tensors_to_rename=tensors_to_rename,
initialized_tables=initialized_tables,
output_path=output_path)

#print(model_proto)

return model_proto, external_tensor_storage

def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_op_handlers=None,
custom_rewriter=None, inputs_as_nchw=None, outputs_as_nchw=None, extra_opset=None, shape_override=None,
Expand Down Expand Up @@ -438,6 +538,10 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
if get_tf_version() < Version("2.0"):
return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw,
outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
if get_keras_version() > Version("3.0"):
return from_keras3(model, input_signature, opset, custom_ops, custom_op_handlers,
custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override,
target, large_model, output_path, optimizers)

old_out_names = _rename_duplicate_keras_model_names(model)
from tensorflow.python.keras.saving import saving_utils as _saving_utils # pylint: disable=import-outside-toplevel
Expand Down
4 changes: 4 additions & 0 deletions tf2onnx/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import tensorflow as tf
import keras

from tensorflow.core.framework import types_pb2, tensor_pb2, graph_pb2
from tensorflow.python.framework import tensor_util
Expand Down Expand Up @@ -124,6 +125,9 @@ def get_tf_node_attr(node, name):
def get_tf_version():
return Version(tf.__version__)

def get_keras_version():
return Version(keras.__version__)

def compress_graph_def(graph_def):
"""
Remove large const values from graph. This lets us import the graph and run shape inference without TF crashing.
Expand Down