|
16 | 16 | """Tests the JSON encoder and decoder.""" |
17 | 17 |
|
18 | 18 | import enum |
| 19 | +from typing import Mapping |
19 | 20 |
|
20 | 21 | import tensorflow.compat.v2 as tf |
21 | 22 |
|
|
24 | 25 | from tf_keras.testing_infra import test_utils |
25 | 26 |
|
26 | 27 |
|
| 28 | +class _ExtensionType(tf.experimental.ExtensionType): |
| 29 | + """An ExtensionType with multiple tuples and mappings.""" |
| 30 | + |
| 31 | + __name__ = "tf_keras.json_utils.test._ExtensionType" |
| 32 | + |
| 33 | + x: tf.Tensor |
| 34 | + xy: tuple[tf.Tensor, tf.Tensor] |
| 35 | + kv: Mapping[str, tf.Tensor] |
| 36 | + |
| 37 | + |
27 | 38 | class JsonUtilsTest(test_combinations.TestCase): |
28 | 39 | def test_encode_decode_tensor_shape(self): |
29 | 40 | metadata = { |
@@ -64,6 +75,17 @@ def test_encode_decode_type_spec(self): |
64 | 75 | ): |
65 | 76 | loaded = json_utils.decode(string) |
66 | 77 |
|
| 78 | + def test_encode_decode_extensiontype_spec(self): |
| 79 | + instance = _ExtensionType( |
| 80 | + x=tf.constant(1), |
| 81 | + xy=(tf.constant(2), tf.constant(True)), |
| 82 | + kv={"a": tf.constant("foo"), "b": tf.constant("bar")}, |
| 83 | + ) |
| 84 | + spec = tf.type_spec_from_value(instance) |
| 85 | + string = json_utils.Encoder().encode(spec) |
| 86 | + loaded = json_utils.decode(string) |
| 87 | + self.assertEqual(spec, loaded) |
| 88 | + |
67 | 89 | def test_encode_decode_enum(self): |
68 | 90 | class Enum(enum.Enum): |
69 | 91 | CLASS_A = "a" |
|
0 commit comments