Skip to content

Commit

Permalink
Add unit test passing ShapeDtypeStruct to get_parameter_overview.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595321929
  • Loading branch information
jpuigcerver authored and copybara-github committed Jan 3, 2024
1 parent 574d4c9 commit f30bc44
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions clu/parameter_overview_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def test_get_parameter_overview(self):
FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS,
parameter_overview.get_parameter_overview(variables))

def test_get_parameter_overview_shape_dtype_struct(self):
variables_shape_dtype_struct = jax.eval_shape(
lambda: CNN().init(jax.random.PRNGKey(42), jnp.zeros((2, 5, 5, 3))))
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW,
parameter_overview.get_parameter_overview(
variables_shape_dtype_struct["params"], include_stats=False))

def test_printing_bool(self):
self.assertEqual(
parameter_overview._default_table_value_formatter(True), "True")
Expand Down

0 comments on commit f30bc44

Please sign in to comment.