From f30bc441a14f0ccf8eaff79800f486a846613a8c Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Wed, 3 Jan 2024 00:53:22 -0800 Subject: [PATCH] Add unit test passing ShapeDtypeStruct to get_parameter_overview. PiperOrigin-RevId: 595321929 --- clu/parameter_overview_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/clu/parameter_overview_test.py b/clu/parameter_overview_test.py index 6cd73b2..670ddf7 100644 --- a/clu/parameter_overview_test.py +++ b/clu/parameter_overview_test.py @@ -92,6 +92,14 @@ def test_get_parameter_overview(self): FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS, parameter_overview.get_parameter_overview(variables)) + def test_get_parameter_overview_shape_dtype_struct(self): + variables_shape_dtype_struct = jax.eval_shape( + lambda: CNN().init(jax.random.PRNGKey(42), jnp.zeros((2, 5, 5, 3)))) + self.assertEqual( + FLAX_CONV2D_PARAMETER_OVERVIEW, + parameter_overview.get_parameter_overview( + variables_shape_dtype_struct["params"], include_stats=False)) + def test_printing_bool(self): self.assertEqual( parameter_overview._default_table_value_formatter(True), "True")