Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tf_keras/saving/legacy/saved_model/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def get_json_type(obj):
return {
"class_name": "TypeSpec",
"type_spec": type_spec_name,
"serialized": obj._serialize(),
"serialized": _encode_tuple(obj._serialize()),
}
except ValueError:
raise ValueError(
Expand Down
22 changes: 22 additions & 0 deletions tf_keras/saving/legacy/saved_model/json_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Tests the JSON encoder and decoder."""

import enum
from typing import Mapping

import tensorflow.compat.v2 as tf

Expand All @@ -24,6 +25,16 @@
from tf_keras.testing_infra import test_utils


class _ExtensionType(tf.experimental.ExtensionType):
"""An ExtensionType with multiple tuples and mappings."""

__name__ = "tf_keras.json_utils.test._ExtensionType"

x: tf.Tensor
xy: tuple[tf.Tensor, tf.Tensor]
kv: Mapping[str, tf.Tensor]


class JsonUtilsTest(test_combinations.TestCase):
def test_encode_decode_tensor_shape(self):
metadata = {
Expand Down Expand Up @@ -64,6 +75,17 @@ def test_encode_decode_type_spec(self):
):
loaded = json_utils.decode(string)

def test_encode_decode_extensiontype_spec(self):
instance = _ExtensionType(
x=tf.constant(1),
xy=(tf.constant(2), tf.constant(True)),
kv={"a": tf.constant("foo"), "b": tf.constant("bar")},
)
spec = tf.type_spec_from_value(instance)
string = json_utils.Encoder().encode(spec)
loaded = json_utils.decode(string)
self.assertEqual(spec, loaded)

def test_encode_decode_enum(self):
class Enum(enum.Enum):
CLASS_A = "a"
Expand Down
Loading