Skip to content

Commit 8eb5627

Browse files
hertschuhtensorflower-gardener
authored andcommitted
Fix object_registration_test.
This test was not run at all because it was missing a `main`. The tests were changed to match the current behavior of object registration. PiperOrigin-RevId: 651896455
1 parent 64108a1 commit 8eb5627

File tree

1 file changed

+28
-20
lines changed

1 file changed

+28
-20
lines changed

tf_keras/saving/object_registration_test.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,30 +58,24 @@ def __init__(self, value):
5858
def get_config(self):
5959
return {"value": self._value}
6060

61+
@classmethod
62+
def from_config(cls, config):
63+
return cls(**config)
64+
6165
serialized_name = "Custom>TestClass"
6266
inst = TestClass(value=10)
6367
class_name = object_registration._GLOBAL_CUSTOM_NAMES[TestClass]
6468
self.assertEqual(serialized_name, class_name)
6569
config = serialization_lib.serialize_keras_object(inst)
66-
self.assertEqual(class_name, config["class_name"])
70+
if tf.__internal__.tf2.enabled():
71+
self.assertEqual(class_name, config["registered_name"])
72+
else:
73+
self.assertEqual(class_name, config["class_name"])
6774
new_inst = serialization_lib.deserialize_keras_object(config)
6875
self.assertIsNot(inst, new_inst)
6976
self.assertIsInstance(new_inst, TestClass)
7077
self.assertEqual(10, new_inst._value)
7178

72-
# Make sure registering a new class with same name will fail.
73-
with self.assertRaisesRegex(
74-
ValueError, ".*has already been registered.*"
75-
):
76-
77-
@object_registration.register_keras_serializable()
78-
class TestClass:
79-
def __init__(self, value):
80-
self._value = value
81-
82-
def get_config(self):
83-
return {"value": self._value}
84-
8579
def test_serialize_custom_class_with_custom_name(self):
8680
@object_registration.register_keras_serializable(
8781
"TestPackage", "CustomName"
@@ -93,6 +87,10 @@ def __init__(self, val):
9387
def get_config(self):
9488
return {"val": self._val}
9589

90+
@classmethod
91+
def from_config(cls, config):
92+
return cls(**config)
93+
9694
serialized_name = "TestPackage>CustomName"
9795
inst = OtherTestClass(val=5)
9896
class_name = object_registration._GLOBAL_CUSTOM_NAMES[OtherTestClass]
@@ -103,9 +101,12 @@ def get_config(self):
103101
cls = object_registration.get_registered_object(fn_class_name)
104102
self.assertEqual(OtherTestClass, cls)
105103

106-
config = keras.utils.serialization.serialize_keras_object(inst)
107-
self.assertEqual(class_name, config["class_name"])
108-
new_inst = keras.utils.serialization.deserialize_keras_object(config)
104+
config = serialization_lib.serialize_keras_object(inst)
105+
if tf.__internal__.tf2.enabled():
106+
self.assertEqual(class_name, config["registered_name"])
107+
else:
108+
self.assertEqual(class_name, config["class_name"])
109+
new_inst = serialization_lib.deserialize_keras_object(config)
109110
self.assertIsNot(inst, new_inst)
110111
self.assertIsInstance(new_inst, OtherTestClass)
111112
self.assertEqual(5, new_inst._val)
@@ -121,9 +122,12 @@ def my_fn():
121122
fn_class_name = object_registration.get_registered_name(my_fn)
122123
self.assertEqual(fn_class_name, class_name)
123124

124-
config = keras.utils.serialization.serialize_keras_object(my_fn)
125-
self.assertEqual(class_name, config)
126-
fn = keras.utils.serialization.deserialize_keras_object(config)
125+
config = serialization_lib.serialize_keras_object(my_fn)
126+
if tf.__internal__.tf2.enabled():
127+
self.assertEqual("my_fn", config["config"])
128+
else:
129+
self.assertEqual(class_name, config)
130+
fn = serialization_lib.deserialize_keras_object(config)
127131
self.assertEqual(42, fn())
128132

129133
fn_2 = object_registration.get_registered_object(fn_class_name)
@@ -141,3 +145,7 @@ def test_serialize_custom_class_without_get_config_fails(self):
141145
class TestClass:
142146
def __init__(self, value):
143147
self._value = value
148+
149+
150+
if __name__ == "__main__":
151+
tf.test.main()

0 commit comments

Comments
 (0)