@dataclasses.dataclass
class _ParamRowWithSharding(_ParamRow):
  """A `_ParamRow` extended with a column describing the value's sharding.

  The `sharding` field holds either the entries of the value's
  `jax.sharding.PartitionSpec` (mesh axis names, tuples of axis names, or
  `None`; `()` when the value has no sharding at all), or the `str()` of the
  sharding object when it does not expose a `spec`.
  """

  # NOTE(review): the entries of a PartitionSpec are axis-name strings,
  # tuples of axis names, or None -- never ints. The previous annotation
  # `tuple[int | None, ...] | str` did not match the values actually stored
  # by `_make_row_with_sharding` below.
  sharding: tuple[str | tuple[str, ...] | None, ...] | str


def _make_row(name, value) -> _ParamRow:
  """Returns a `_ParamRow` with the name, shape, dtype and size of `value`."""
  return _ParamRow(
      name=name,
      shape=value.shape,
      dtype=str(value.dtype),
      # Total number of elements; int() because np.prod(()) returns 1.0.
      size=int(np.prod(value.shape)),
  )


def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
  """Returns a `_ParamRowWithSharding` for `value`.

  The sharding column is:
    - `tuple(value.sharding.spec)` when the value's sharding exposes a
      `spec` (e.g. `jax.sharding.NamedSharding`),
    - `str(value.sharding)` for sharding objects without a `spec`,
    - `()` when `value` has no `sharding` attribute (e.g. NumPy arrays).
  """
  row = _make_row(name, value)
  if hasattr(value, "sharding"):
    if hasattr(value.sharding, "spec"):
      sharding = tuple(value.sharding.spec)
    else:
      sharding = str(value.sharding)
  else:
    sharding = ()
  return _ParamRowWithSharding(**dataclasses.asdict(row), sharding=sharding)


def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
  """Returns a `_ParamRowWithStats` for `value`.

  `mean` and `std` may live on device; they are fetched with
  `jax.device_get` and converted to host Python floats.
  """
  row = _make_row(name, value)
  return _ParamRowWithStats(
      **dataclasses.asdict(row),
      mean=float(jax.device_get(mean)),
      std=float(jax.device_get(std)),
  )


def _make_row_with_stats_and_sharding(
    name, value, mean, std
) -> _ParamRowWithStatsAndSharding:
  """Returns a `_ParamRowWithStatsAndSharding` (sharding plus mean/std)."""
  row = _make_row_with_sharding(name, value)
  return _ParamRowWithStatsAndSharding(
      **dataclasses.asdict(row),
      mean=float(jax.device_get(mean)),
      std=float(jax.device_get(std)),
  )
_get_parameter_rows( nested. Alternatively a `tf.Module` can be provided, in which case the `trainable_variables` of the module will be used. include_stats: If True, add columns with mean and std for each variable. + If the string "sharding", add column a column with the sharding of the + variable. If the string "global", params are sharded global arrays and this function assumes it is called on every host, i.e. can use collectives. + The sharding of the variables is also added as a column. Returns: A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value @@ -122,37 +171,25 @@ def _get_parameter_rows( else: names, values = [], [] - if include_stats: - def make_row(name, value, mean, std): - kw = dict( - name=name, - shape=value.shape, - dtype=str(value.dtype), - size=int(np.prod(value.shape)), - mean=float(jax.device_get(mean)), - std=float(jax.device_get(std)), - ) - if include_stats == "global" and hasattr(value, "sharding"): - if hasattr(value.sharding, "spec"): - return _ParamRowWithStatsAndSharding( - sharding=tuple(value.sharding.spec), **kw - ) - else: - return _ParamRowWithStatsAndSharding( - sharding=str(value.sharding), **kw - ) - return _ParamRowWithStats(**kw) - mean_std_fn = _mean_std_jit if include_stats == "global" else _mean_std - return jax.tree_util.tree_map(make_row, names, values, *mean_std_fn(values)) - else: - def make_row(name, value): - return _ParamRow( - name=name, - shape=value.shape, - dtype=str(value.dtype), - size=int(np.prod(value.shape)), - ) - return jax.tree_util.tree_map(make_row, names, values) + match include_stats: + case False: + return jax.tree_util.tree_map(_make_row, names, values) + + case True: + mean_and_std = _mean_std(values) + return jax.tree_util.tree_map( + _make_row_with_stats, names, values, *mean_and_std) + + case "global": + mean_and_std = _mean_std_jit(values) + return jax.tree_util.tree_map( + _make_row_with_stats_and_sharding, names, values, *mean_and_std) + + case "sharding": + return 
jax.tree_util.tree_map(_make_row_with_sharding, names, values) + + case _: + raise ValueError(f"Unknown `include_stats`: {include_stats}") def _default_table_value_formatter(value): @@ -247,6 +284,7 @@ def _get_parameter_overview( False: _ParamRow, True: _ParamRowWithStats, "global": _ParamRowWithStatsAndSharding, + "sharding": _ParamRowWithSharding, }[include_stats] # Pass in `column_names` to enable rendering empty tables. column_names = [field.name for field in dataclasses.fields(RowType)] @@ -267,9 +305,12 @@ def get_parameter_overview( Args: params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. - include_stats: If True, add columns with mean and std for each variable. If - the string "global", params are sharded global arrays and this function - assumes it is called on every host, i.e. can use collectives. + include_stats: If True, add columns with mean and std for each variable. + If the string "sharding", add column a column with the sharding of the + variable. + If the string "global", params are sharded global arrays and this + function assumes it is called on every host, i.e. can use collectives. + The sharding of the variables is also added as a column. max_lines: If not `None`, the maximum number of variables to include. 
Returns: diff --git a/clu/parameter_overview_test.py b/clu/parameter_overview_test.py index 0cf9068..72bf59c 100644 --- a/clu/parameter_overview_test.py +++ b/clu/parameter_overview_test.py @@ -19,6 +19,7 @@ from flax import linen as nn import jax import jax.numpy as jnp +import numpy as np EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+ @@ -35,6 +36,14 @@ +-------------+--------------+---------+------+ Total: 56 -- 224 bytes""" +FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+ +| Name | Shape | Dtype | Size | Sharding | ++-------------+--------------+---------+------+----------+ +| conv/bias | (2,) | float32 | 2 | () | +| conv/kernel | (3, 3, 3, 2) | float32 | 54 | () | ++-------------+--------------+---------+------+----------+ +Total: 56 -- 224 bytes""" + FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+ | Name | Shape | Dtype | Size | Mean | Std | +-------------+--------------+---------+------+------+-----+ @@ -43,6 +52,14 @@ +-------------+--------------+---------+------+------+-----+ Total: 56 -- 224 bytes""" +FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+ +| Name | Shape | Dtype | Size | Mean | Std | Sharding | ++-------------+--------------+---------+------+------+-----+----------+ +| conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | () | +| conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | () | ++-------------+--------------+---------+------+------+-----+----------+ +Total: 56 -- 224 bytes""" + FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+ | Name | Shape | Dtype | Size | Mean | Std | +--------------------+--------------+---------+------+------+-----+ @@ -66,7 +83,7 @@ def test_count_parameters_empty(self): def test_count_parameters(self): rng = 
jax.random.PRNGKey(42) - # Weights of a 2D convolution with 2 filters.. + # Weights of a 2D convolution with 2 filters. variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3))) # 3 * 3*3 * 2 + 2 (bias) = 56 parameters self.assertEqual(56, @@ -78,7 +95,7 @@ def test_get_parameter_overview_empty(self): def test_get_parameter_overview(self): rng = jax.random.PRNGKey(42) - # Weights of a 2D convolution with 2 filters.. + # Weights of a 2D convolution with 2 filters. variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3))) variables = jax.tree_map(jnp.ones_like, variables) self.assertEqual( @@ -91,6 +108,18 @@ def test_get_parameter_overview(self): self.assertEqual( FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS, parameter_overview.get_parameter_overview(variables)) + # Add sharding with PartitionSpecs. + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d") + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + variables = jax.jit(lambda x: x, out_shardings=sharding)(variables) + self.assertEqual( + FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING, + parameter_overview.get_parameter_overview( + variables["params"], include_stats="sharding")) + self.assertEqual( + FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING, + parameter_overview.get_parameter_overview( + variables["params"], include_stats="global")) def test_get_parameter_overview_shape_dtype_struct(self): variables_shape_dtype_struct = jax.eval_shape(