Commit c50acb7

No public description

PiperOrigin-RevId: 615750878

jpuigcerver authored and copybara-github committed Mar 14, 2024
1 parent 0b961e5 · commit c50acb7
Showing 2 changed files with 106 additions and 36 deletions.
109 changes: 75 additions & 34 deletions clu/parameter_overview.py
@@ -36,6 +36,11 @@ class _ParamRow:
size: int


@dataclasses.dataclass
class _ParamRowWithSharding(_ParamRow):
sharding: tuple[int | None, ...] | str


@dataclasses.dataclass
class _ParamRowWithStats(_ParamRow):
mean: float
@@ -92,6 +97,47 @@ def count_parameters(params: _ParamsContainer) -> int:
return _count_parameters(params)


def _make_row(name, value) -> _ParamRow:
return _ParamRow(
name=name,
shape=value.shape,
dtype=str(value.dtype),
size=int(np.prod(value.shape)),
)


def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
row = _make_row(name, value)
if hasattr(value, "sharding"):
if hasattr(value.sharding, "spec"):
sharding = tuple(value.sharding.spec)
else:
sharding = str(value.sharding)
else:
sharding = ()
return _ParamRowWithSharding(**dataclasses.asdict(row), sharding=sharding)


def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
row = _make_row(name, value)
return _ParamRowWithStats(
**dataclasses.asdict(row),
mean=float(jax.device_get(mean)),
std=float(jax.device_get(std)),
)


def _make_row_with_stats_and_sharding(
name, value, mean, std
) -> _ParamRowWithStatsAndSharding:
row = _make_row_with_sharding(name, value)
return _ParamRowWithStatsAndSharding(
**dataclasses.asdict(row),
mean=float(jax.device_get(mean)),
std=float(jax.device_get(std)),
)


def _get_parameter_rows(
params: _ParamsContainer,
*,
@@ -104,8 +150,11 @@ def _get_parameter_rows(
nested. Alternatively a `tf.Module` can be provided, in which case the
`trainable_variables` of the module will be used.
include_stats: If True, add columns with mean and std for each variable.
      If the string "sharding", add a column with the sharding of the variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
The sharding of the variables is also added as a column.
Returns:
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
Expand All @@ -122,37 +171,25 @@ def _get_parameter_rows(
else:
names, values = [], []

if include_stats:
def make_row(name, value, mean, std):
kw = dict(
name=name,
shape=value.shape,
dtype=str(value.dtype),
size=int(np.prod(value.shape)),
mean=float(jax.device_get(mean)),
std=float(jax.device_get(std)),
)
if include_stats == "global" and hasattr(value, "sharding"):
if hasattr(value.sharding, "spec"):
return _ParamRowWithStatsAndSharding(
sharding=tuple(value.sharding.spec), **kw
)
else:
return _ParamRowWithStatsAndSharding(
sharding=str(value.sharding), **kw
)
return _ParamRowWithStats(**kw)
mean_std_fn = _mean_std_jit if include_stats == "global" else _mean_std
return jax.tree_util.tree_map(make_row, names, values, *mean_std_fn(values))
else:
def make_row(name, value):
return _ParamRow(
name=name,
shape=value.shape,
dtype=str(value.dtype),
size=int(np.prod(value.shape)),
)
return jax.tree_util.tree_map(make_row, names, values)
match include_stats:
case False:
return jax.tree_util.tree_map(_make_row, names, values)

case True:
mean_and_std = _mean_std(values)
return jax.tree_util.tree_map(
_make_row_with_stats, names, values, *mean_and_std)

case "global":
mean_and_std = _mean_std_jit(values)
return jax.tree_util.tree_map(
_make_row_with_stats_and_sharding, names, values, *mean_and_std)

case "sharding":
return jax.tree_util.tree_map(_make_row_with_sharding, names, values)

case _:
raise ValueError(f"Unknown `include_stats`: {include_stats}")


def _default_table_value_formatter(value):
@@ -247,6 +284,7 @@ def _get_parameter_overview(
False: _ParamRow,
True: _ParamRowWithStats,
"global": _ParamRowWithStatsAndSharding,
"sharding": _ParamRowWithSharding,
}[include_stats]
# Pass in `column_names` to enable rendering empty tables.
column_names = [field.name for field in dataclasses.fields(RowType)]
@@ -267,9 +305,12 @@ def get_parameter_overview(
Args:
params: Dictionary with parameters as NumPy arrays. The dictionary can be
nested.
include_stats: If True, add columns with mean and std for each variable. If
the string "global", params are sharded global arrays and this function
assumes it is called on every host, i.e. can use collectives.
include_stats: If True, add columns with mean and std for each variable.
      If the string "sharding", add a column with the sharding of the variable.
If the string "global", params are sharded global arrays and this
function assumes it is called on every host, i.e. can use collectives.
The sharding of the variables is also added as a column.
max_lines: If not `None`, the maximum number of variables to include.
Returns:
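For reference, a minimal usage sketch of the four `include_stats` modes wired up above. This assumes a clu build that includes this change; the `params` dict below is illustrative, not taken from the commit.

from clu import parameter_overview
import jax.numpy as jnp

params = {"conv": {"kernel": jnp.ones((3, 3, 3, 2)), "bias": jnp.ones((2,))}}

# False: name/shape/dtype/size columns only.
print(parameter_overview.get_parameter_overview(params, include_stats=False))
# True: adds Mean and Std columns, computed on the host.
print(parameter_overview.get_parameter_overview(params, include_stats=True))
# "sharding": adds a Sharding column, no stats.
print(parameter_overview.get_parameter_overview(params, include_stats="sharding"))
# "global": jitted mean/std (usable with sharded global arrays) plus Sharding.
print(parameter_overview.get_parameter_overview(params, include_stats="global"))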
33 changes: 31 additions & 2 deletions clu/parameter_overview_test.py
@@ -19,6 +19,7 @@
from flax import linen as nn
import jax
import jax.numpy as jnp
import numpy as np


EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+
@@ -35,6 +36,14 @@
+-------------+--------------+---------+------+
Total: 56 -- 224 bytes"""

FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING = """+-------------+--------------+---------+------+----------+
| Name | Shape | Dtype | Size | Sharding |
+-------------+--------------+---------+------+----------+
| conv/bias | (2,) | float32 | 2 | () |
| conv/kernel | (3, 3, 3, 2) | float32 | 54 | () |
+-------------+--------------+---------+------+----------+
Total: 56 -- 224 bytes"""

FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+
| Name | Shape | Dtype | Size | Mean | Std |
+-------------+--------------+---------+------+------+-----+
@@ -43,6 +52,14 @@
+-------------+--------------+---------+------+------+-----+
Total: 56 -- 224 bytes"""

FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING = """+-------------+--------------+---------+------+------+-----+----------+
| Name | Shape | Dtype | Size | Mean | Std | Sharding |
+-------------+--------------+---------+------+------+-----+----------+
| conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 | () |
| conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 | () |
+-------------+--------------+---------+------+------+-----+----------+
Total: 56 -- 224 bytes"""

FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+
| Name | Shape | Dtype | Size | Mean | Std |
+--------------------+--------------+---------+------+------+-----+
@@ -66,7 +83,7 @@ def test_count_parameters_empty(self):

def test_count_parameters(self):
rng = jax.random.PRNGKey(42)
# Weights of a 2D convolution with 2 filters..
# Weights of a 2D convolution with 2 filters.
variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
# 3 * 3*3 * 2 + 2 (bias) = 56 parameters
self.assertEqual(56,
@@ -78,7 +95,7 @@ def test_get_parameter_overview_empty(self):

def test_get_parameter_overview(self):
rng = jax.random.PRNGKey(42)
# Weights of a 2D convolution with 2 filters..
# Weights of a 2D convolution with 2 filters.
variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
variables = jax.tree_map(jnp.ones_like, variables)
self.assertEqual(
@@ -91,6 +108,18 @@ def test_get_parameter_overview(self):
self.assertEqual(
FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS,
parameter_overview.get_parameter_overview(variables))
# Add sharding with PartitionSpecs.
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d")
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
variables = jax.jit(lambda x: x, out_shardings=sharding)(variables)
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_SHARDING,
parameter_overview.get_parameter_overview(
variables["params"], include_stats="sharding"))
self.assertEqual(
FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS_AND_SHARDING,
parameter_overview.get_parameter_overview(
variables["params"], include_stats="global"))

def test_get_parameter_overview_shape_dtype_struct(self):
variables_shape_dtype_struct = jax.eval_shape(

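A minimal sketch of where the Sharding column values come from, mirroring `_make_row_with_sharding` above and the replicated `NamedSharding` used in the test. A single-host JAX setup is assumed.

import jax
import numpy as np

# Replicate an array over a 1-D device mesh, as the test does.
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), "d")
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
x = jax.device_put(np.ones((3, 3, 3, 2), np.float32), sharding)

# NamedSharding exposes .spec, so the column shows tuple(x.sharding.spec),
# which is () for a fully replicated array -- matching the expected tables.
print(tuple(x.sharding.spec))  # ()

# A sharding without a .spec attribute falls back to str(value.sharding),
# and a plain NumPy array (no .sharding at all) is reported as ().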