@@ -1272,6 +1272,57 @@ def call(self, inputs):
12721272 self .assertAllEqual (out_false , sample_input )
12731273
12741274
1275+ @test_utils .run_v2_only
1276+ class BaseRandomLayerTest (test_combinations .TestCase ):
1277+ def teardown (self ):
1278+ backend .disable_tf_random_generator ()
1279+
1280+ def test_rng_type_is_saved_in_config (self ):
1281+ backend .disable_tf_random_generator ()
1282+
1283+ layer = base_layer .BaseRandomLayer (rng_type = "stateful" )
1284+ config = layer .get_config ()
1285+ self .assertEqual (config ["rng_type" ], "stateful" )
1286+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1287+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateful" )
1288+
1289+ layer = base_layer .BaseRandomLayer (rng_type = "stateless" )
1290+ config = layer .get_config ()
1291+ self .assertEqual (config ["rng_type" ], "stateless" )
1292+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1293+ self .assertEqual (
1294+ reloaded_layer ._random_generator ._rng_type , "stateless"
1295+ )
1296+
1297+ layer = base_layer .BaseRandomLayer ()
1298+ config = layer .get_config ()
1299+ self .assertNotIn ("rng_type" , config )
1300+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1301+ self .assertEqual (
1302+ reloaded_layer ._random_generator ._rng_type , "legacy_stateful"
1303+ )
1304+
1305+ layer = base_layer .BaseRandomLayer (rng_type = "legacy_stateful" )
1306+ config = layer .get_config ()
1307+ self .assertNotIn ("rng_type" , config )
1308+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1309+ self .assertEqual (
1310+ reloaded_layer ._random_generator ._rng_type , "legacy_stateful"
1311+ )
1312+
1313+ def test_rng_type_with_tf_random_generator (self ):
1314+ # Test `rng_type` is still serialized when global stateful mode is on.
1315+ backend .enable_tf_random_generator ()
1316+
1317+ layer = base_layer .BaseRandomLayer ()
1318+ config = layer .get_config ()
1319+ self .assertEqual (config ["rng_type" ], "stateful" )
1320+
1321+ backend .disable_tf_random_generator ()
1322+ reloaded_layer = base_layer .BaseRandomLayer .from_config (config )
1323+ self .assertEqual (reloaded_layer ._random_generator ._rng_type , "stateful" )
1324+
1325+
12751326@test_utils .run_v2_only
12761327class SymbolicSupportTest (test_combinations .TestCase ):
12771328 def test_using_symbolic_tensors_with_tf_ops (self ):
0 commit comments