@@ -58,30 +58,24 @@ def __init__(self, value):
5858 def get_config (self ):
5959 return {"value" : self ._value }
6060
61+ @classmethod
62+ def from_config (cls , config ):
63+ return cls (** config )
64+
6165 serialized_name = "Custom>TestClass"
6266 inst = TestClass (value = 10 )
6367 class_name = object_registration ._GLOBAL_CUSTOM_NAMES [TestClass ]
6468 self .assertEqual (serialized_name , class_name )
6569 config = serialization_lib .serialize_keras_object (inst )
66- self .assertEqual (class_name , config ["class_name" ])
70+ if tf .__internal__ .tf2 .enabled ():
71+ self .assertEqual (class_name , config ["registered_name" ])
72+ else :
73+ self .assertEqual (class_name , config ["class_name" ])
6774 new_inst = serialization_lib .deserialize_keras_object (config )
6875 self .assertIsNot (inst , new_inst )
6976 self .assertIsInstance (new_inst , TestClass )
7077 self .assertEqual (10 , new_inst ._value )
7178
72- # Make sure registering a new class with same name will fail.
73- with self .assertRaisesRegex (
74- ValueError , ".*has already been registered.*"
75- ):
76-
77- @object_registration .register_keras_serializable ()
78- class TestClass :
79- def __init__ (self , value ):
80- self ._value = value
81-
82- def get_config (self ):
83- return {"value" : self ._value }
84-
8579 def test_serialize_custom_class_with_custom_name (self ):
8680 @object_registration .register_keras_serializable (
8781 "TestPackage" , "CustomName"
@@ -93,6 +87,10 @@ def __init__(self, val):
9387 def get_config (self ):
9488 return {"val" : self ._val }
9589
90+ @classmethod
91+ def from_config (cls , config ):
92+ return cls (** config )
93+
9694 serialized_name = "TestPackage>CustomName"
9795 inst = OtherTestClass (val = 5 )
9896 class_name = object_registration ._GLOBAL_CUSTOM_NAMES [OtherTestClass ]
@@ -103,9 +101,12 @@ def get_config(self):
103101 cls = object_registration .get_registered_object (fn_class_name )
104102 self .assertEqual (OtherTestClass , cls )
105103
106- config = keras .utils .serialization .serialize_keras_object (inst )
107- self .assertEqual (class_name , config ["class_name" ])
108- new_inst = keras .utils .serialization .deserialize_keras_object (config )
104+ config = serialization_lib .serialize_keras_object (inst )
105+ if tf .__internal__ .tf2 .enabled ():
106+ self .assertEqual (class_name , config ["registered_name" ])
107+ else :
108+ self .assertEqual (class_name , config ["class_name" ])
109+ new_inst = serialization_lib .deserialize_keras_object (config )
109110 self .assertIsNot (inst , new_inst )
110111 self .assertIsInstance (new_inst , OtherTestClass )
111112 self .assertEqual (5 , new_inst ._val )
@@ -121,9 +122,12 @@ def my_fn():
121122 fn_class_name = object_registration .get_registered_name (my_fn )
122123 self .assertEqual (fn_class_name , class_name )
123124
124- config = keras .utils .serialization .serialize_keras_object (my_fn )
125- self .assertEqual (class_name , config )
126- fn = keras .utils .serialization .deserialize_keras_object (config )
125+ config = serialization_lib .serialize_keras_object (my_fn )
126+ if tf .__internal__ .tf2 .enabled ():
127+ self .assertEqual ("my_fn" , config ["config" ])
128+ else :
129+ self .assertEqual (class_name , config )
130+ fn = serialization_lib .deserialize_keras_object (config )
127131 self .assertEqual (42 , fn ())
128132
129133 fn_2 = object_registration .get_registered_object (fn_class_name )
@@ -141,3 +145,7 @@ def test_serialize_custom_class_without_get_config_fails(self):
141145 class TestClass :
142146 def __init__ (self , value ):
143147 self ._value = value
148+
149+
150+ if __name__ == "__main__" :
151+ tf .test .main ()
0 commit comments