From 1ae207ca9ef5190a6cd64c1bc52ee520ed4f7096 Mon Sep 17 00:00:00 2001 From: Lucas Beyer Date: Thu, 24 Aug 2023 01:26:06 -0700 Subject: [PATCH] Make parameter_overview work on global arrays. include_stats has a new value "global" which assumes arrays may be distributed even across hosts, and uses one jit'ed call to globally compute mean/std instead of transferring parameters to host. PiperOrigin-RevId: 559677900 --- clu/parameter_overview.py | 66 ++++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/clu/parameter_overview.py b/clu/parameter_overview.py index 1ac1324..49e19eb 100644 --- a/clu/parameter_overview.py +++ b/clu/parameter_overview.py @@ -22,6 +22,7 @@ import flax import jax +import jax.numpy as jnp import numpy as np _ParamsContainer = dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]] @@ -40,6 +41,15 @@ class _ParamRowWithStats(_ParamRow): std: float +@jax.jit +def _mean_std_jit(x): + return jax.tree_util.tree_map(jnp.mean, x), jax.tree_util.tree_map(jnp.std, x) + + +def _mean_std_np(x): + return jax.tree_util.tree_map(np.mean, x), jax.tree_util.tree_map(np.std, x) + + def flatten_dict( input_dict: dict[str, Any], *, prefix: str = "", delimiter: str = "/" ) -> dict[str, Any]: @@ -71,7 +81,7 @@ def count_parameters(params: _ParamsContainer) -> int: def _get_parameter_rows( params: _ParamsContainer, *, - include_stats: bool = False, + include_stats: bool | str = False, ) -> list[_ParamRow | _ParamRowWithStats]: """Returns information about parameters as a list of dictionaries. @@ -79,9 +89,9 @@ def _get_parameter_rows( params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. Alternatively a `tf.Module` can be provided, in which case the `trainable_variables` of the module will be used. - include_stats: If True add columns with mean and std for each variable. Note - that this can be considerably more compute intensive and cause a lot of - memory to be transferred to the host (with `tf.Module`). + include_stats: If True, add columns with mean and std for each variable. + If the string "global", params are sharded global arrays and this + function assumes it is called on every host, i.e. can use collectives. Returns: A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value @@ -91,27 +101,33 @@ def _get_parameter_rows( raise ValueError( f"Expected `params` to be a dictionary but got {type(params)}" ) + if include_stats not in (False, True, "global"): # Avoid typos. + raise ValueError( + f"include_stats must be False, True or 'global', got {include_stats!r}") + if params: params = flatten_dict(params) names, values = map(list, tuple(zip(*sorted(params.items())))) else: names, values = [], [] - def make_row(name, value): - if include_stats: + if include_stats: + def make_row(name, value, mean, std): return _ParamRowWithStats( name=name, shape=value.shape, size=int(np.prod(value.shape)), - mean=float(value.mean()), - std=float(value.std()), + mean=float(mean), + std=float(std), ) - else: + mean_std_fn = _mean_std_jit if include_stats == "global" else _mean_std_np + return jax.tree_util.tree_map(make_row, names, values, *mean_std_fn(values)) + else: + def make_row(name, value): return _ParamRow( name=name, shape=value.shape, size=int(np.prod(value.shape)) ) - - return [make_row(name, value) for name, value in zip(names, values)] + return jax.tree_util.tree_map(make_row, names, values) def _default_table_value_formatter(value): @@ -195,11 +211,11 @@ def __init__(self, name, values): def _get_parameter_overview( params: _ParamsContainer, *, - include_stats: bool = True, + include_stats: bool | str = True, max_lines: int | None = None, ) -> str: """See get_parameter_overview().""" - if include_stats and isinstance(params, (dict, flax.core.FrozenDict)): + if include_stats is True and isinstance(params, (dict, flax.core.FrozenDict)): # pylint: disable=g-bool-id-comparison params = jax.tree_map(np.asarray, params) rows = _get_parameter_rows(params, include_stats=include_stats) total_weights = _count_parameters(params) @@ -246,28 +262,31 @@ def get_parameter_overview( def _log_parameter_overview( params: _ParamsContainer, *, - include_stats: bool = True, + include_stats: bool | str = True, max_lines: int | None = None, msg: str | None = None, + jax_logging_process: int | None = None, ): """See log_parameter_overview().""" table = _get_parameter_overview( params, include_stats=include_stats, max_lines=max_lines ) - lines = [msg] if msg else [] - lines += table.split("\n") - # The table can be too large to fit into one log entry. - for i in range(0, len(lines), 80): - logging.info("\n%s", "\n".join(lines[i : i + 80])) + if jax_logging_process is None or jax_logging_process == jax.process_index(): + lines = [msg] if msg else [] + lines += table.split("\n") + # The table can be too large to fit into one log entry. + for i in range(0, len(lines), 80): + logging.info("\n%s", "\n".join(lines[i : i + 80])) def log_parameter_overview( params: _ParamsContainer, *, - include_stats: bool = True, + include_stats: bool | str = True, max_lines: int | None = None, msg: str | None = None, + jax_logging_process: int | None = None, ): """Writes a table with variables name and shapes to INFO log. @@ -277,10 +296,15 @@ def log_parameter_overview( params: Dictionary with parameters as NumPy arrays. The dictionary can be nested. include_stats: If True, add columns with mean and std for each variable. + If the string "global", params are sharded global arrays and this + function assumes it is called on every host, i.e. can use collectives. max_lines: If not `None`, the maximum number of variables to include. msg: Message to be logged before the overview. + jax_logging_process: Which JAX process ID should do the logging. None = all. + Use this to avoid logspam when include_stats="global". """ _log_parameter_overview( - params, include_stats=include_stats, max_lines=max_lines, msg=msg + params, include_stats=include_stats, max_lines=max_lines, msg=msg, + jax_logging_process=jax_logging_process )