Skip to content

Commit

Permalink
Make parameter_overview work on global arrays.
Browse files Browse the repository at this point in the history
`include_stats` now accepts a new value, "global", which assumes the arrays may be distributed (even across hosts) and uses a single jit'ed call to compute mean/std globally instead of transferring the parameters to the host.

PiperOrigin-RevId: 559677900
  • Loading branch information
Lucas Beyer authored and copybara-github committed Aug 24, 2023
1 parent efa68db commit 1ae207c
Showing 1 changed file with 45 additions and 21 deletions.
66 changes: 45 additions & 21 deletions clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import flax
import jax
import jax.numpy as jnp
import numpy as np

_ParamsContainer = dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]]
Expand All @@ -40,6 +41,15 @@ class _ParamRowWithStats(_ParamRow):
std: float


@jax.jit
def _mean_std_jit(x):
  """Computes the per-leaf mean and standard deviation of a pytree under jit.

  Args:
    x: A pytree whose leaves are arrays accepted by `jnp.mean` / `jnp.std`.

  Returns:
    A tuple `(means, stds)` of two pytrees with the same structure as `x`,
    where each leaf is replaced by its `jnp.mean` and `jnp.std` respectively.
  """
  means = jax.tree_util.tree_map(jnp.mean, x)
  stds = jax.tree_util.tree_map(jnp.std, x)
  return means, stds


def _mean_std_np(x):
  """Computes the per-leaf mean and standard deviation of a pytree with NumPy.

  Args:
    x: A pytree whose leaves are arrays accepted by `np.mean` / `np.std`.

  Returns:
    A tuple `(means, stds)` of two pytrees with the same structure as `x`,
    where each leaf is replaced by its `np.mean` and `np.std` respectively.
  """
  means = jax.tree_util.tree_map(np.mean, x)
  stds = jax.tree_util.tree_map(np.std, x)
  return means, stds


def flatten_dict(
input_dict: dict[str, Any], *, prefix: str = "", delimiter: str = "/"
) -> dict[str, Any]:
Expand Down Expand Up @@ -71,17 +81,17 @@ def count_parameters(params: _ParamsContainer) -> int:
def _get_parameter_rows(
params: _ParamsContainer,
*,
include_stats: bool = False,
include_stats: bool | str = False,
) -> list[_ParamRow | _ParamRowWithStats]:
"""Returns information about parameters as a list of dictionaries.
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested. Alternatively a `tf.Module` can be provided, in which case the
`trainable_variables` of the module will be used.
include_stats: If True add columns with mean and std for each variable. Note
that this can be considerably more compute intensive and cause a lot of
memory to be transferred to the host (with `tf.Module`).
include_stats: If True, add columns with mean and std for each variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
Returns:
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
Expand All @@ -91,27 +101,33 @@ def _get_parameter_rows(
raise ValueError(
f"Expected `params` to be a dictionary but got {type(params)}"
)
if include_stats not in (False, True, "global"): # Avoid typos.
raise ValueError(
f"include_stats must be False, True or 'global', got {include_stats!r}")

if params:
params = flatten_dict(params)
names, values = map(list, tuple(zip(*sorted(params.items()))))
else:
names, values = [], []

def make_row(name, value):
if include_stats:
if include_stats:
def make_row(name, value, mean, std):
return _ParamRowWithStats(
name=name,
shape=value.shape,
size=int(np.prod(value.shape)),
mean=float(value.mean()),
std=float(value.std()),
mean=float(mean),
std=float(std),
)
else:
mean_std_fn = _mean_std_jit if include_stats == "global" else _mean_std_np
return jax.tree_util.tree_map(make_row, names, values, *mean_std_fn(values))
else:
def make_row(name, value):
return _ParamRow(
name=name, shape=value.shape, size=int(np.prod(value.shape))
)

return [make_row(name, value) for name, value in zip(names, values)]
return jax.tree_util.tree_map(make_row, names, values)


def _default_table_value_formatter(value):
Expand Down Expand Up @@ -195,11 +211,11 @@ def __init__(self, name, values):
def _get_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool = True,
include_stats: bool | str = True,
max_lines: int | None = None,
) -> str:
"""See get_parameter_overview()."""
if include_stats and isinstance(params, (dict, flax.core.FrozenDict)):
if include_stats is True and isinstance(params, (dict, flax.core.FrozenDict)): # pylint: disable=g-bool-id-comparison
params = jax.tree_map(np.asarray, params)
rows = _get_parameter_rows(params, include_stats=include_stats)
total_weights = _count_parameters(params)
Expand Down Expand Up @@ -246,28 +262,31 @@ def get_parameter_overview(
def _log_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool = True,
include_stats: bool | str = True,
max_lines: int | None = None,
msg: str | None = None,
jax_logging_process: int | None = None,
):
"""See log_parameter_overview()."""

table = _get_parameter_overview(
params, include_stats=include_stats, max_lines=max_lines
)
lines = [msg] if msg else []
lines += table.split("\n")
# The table can be too large to fit into one log entry.
for i in range(0, len(lines), 80):
logging.info("\n%s", "\n".join(lines[i : i + 80]))
if jax_logging_process is None or jax_logging_process == jax.process_index():
lines = [msg] if msg else []
lines += table.split("\n")
# The table can be too large to fit into one log entry.
for i in range(0, len(lines), 80):
logging.info("\n%s", "\n".join(lines[i : i + 80]))


def log_parameter_overview(
params: _ParamsContainer,
*,
include_stats: bool = True,
include_stats: bool | str = True,
max_lines: int | None = None,
msg: str | None = None,
jax_logging_process: int | None = None,
):
"""Writes a table with variables name and shapes to INFO log.
Expand All @@ -277,10 +296,15 @@ def log_parameter_overview(
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
max_lines: If not `None`, the maximum number of variables to include.
msg: Message to be logged before the overview.
jax_logging_process: Which JAX process ID should do the logging. None = all.
Use this to avoid logspam when include_stats="global".
"""

_log_parameter_overview(
params, include_stats=include_stats, max_lines=max_lines, msg=msg
params, include_stats=include_stats, max_lines=max_lines, msg=msg,
jax_logging_process=jax_logging_process
)

0 comments on commit 1ae207c

Please sign in to comment.