Skip to content

Commit

Permalink
Adds Dtype column and total size (in bytes).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 593053061
  • Loading branch information
andsteing authored and copybara-github committed Dec 22, 2023
1 parent d61576a commit 8a01f21
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 32 deletions.
18 changes: 15 additions & 3 deletions clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
class _ParamRow:
name: str
shape: tuple[int, ...]
dtype: str
size: int


Expand Down Expand Up @@ -79,6 +80,12 @@ def _count_parameters(params: _ParamsContainer) -> int:
return sum(np.prod(v.shape) for v in params.values())


def _parameters_size(params: _ParamsContainer) -> int:
"""Returns total size (bytes) for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(v.nbytes for v in params.values())


def count_parameters(params: _ParamsContainer) -> int:
"""Returns the count of variables for the module or parameter dictionary."""

Expand Down Expand Up @@ -120,6 +127,7 @@ def make_row(name, value, mean, std):
kw = dict(
name=name,
shape=value.shape,
dtype=str(value.dtype),
size=int(np.prod(value.shape)),
mean=float(jax.device_get(mean)),
std=float(jax.device_get(std)),
Expand All @@ -139,7 +147,10 @@ def make_row(name, value, mean, std):
else:
def make_row(name, value):
return _ParamRow(
name=name, shape=value.shape, size=int(np.prod(value.shape))
name=name,
shape=value.shape,
dtype=str(value.dtype),
size=int(np.prod(value.shape)),
)
return jax.tree_util.tree_map(make_row, names, values)

Expand Down Expand Up @@ -232,7 +243,6 @@ def _get_parameter_overview(
if include_stats is True and isinstance(params, (dict, flax.core.FrozenDict)): # pylint: disable=g-bool-id-comparison
params = jax.device_get(params) # A no-op if already numpy array.
rows = _get_parameter_rows(params, include_stats=include_stats)
total_weights = _count_parameters(params)
RowType = { # pylint: disable=invalid-name
False: _ParamRow,
True: _ParamRowWithStats,
Expand All @@ -241,7 +251,9 @@ def _get_parameter_overview(
# Pass in `column_names` to enable rendering empty tables.
column_names = [field.name for field in dataclasses.fields(RowType)]
table = make_table(rows, max_lines=max_lines, column_names=column_names)
return table + f"\nTotal: {total_weights:,}"
total_weights = _count_parameters(params)
total_size = _parameters_size(params)
return table + f"\nTotal: {total_weights:,} -- {total_size:,} bytes"


def get_parameter_overview(
Expand Down
58 changes: 29 additions & 29 deletions clu/parameter_overview_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,35 +21,35 @@
import jax.numpy as jnp


EMPTY_PARAMETER_OVERVIEW = """+------+-------+------+------+-----+
| Name | Shape | Size | Mean | Std |
+------+-------+------+------+-----+
+------+-------+------+------+-----+
Total: 0"""

FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+------+
| Name | Shape | Size |
+-------------+--------------+------+
| conv/bias | (2,) | 2 |
| conv/kernel | (3, 3, 3, 2) | 54 |
+-------------+--------------+------+
Total: 56"""

FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+------+------+-----+
| Name | Shape | Size | Mean | Std |
+-------------+--------------+------+------+-----+
| conv/bias | (2,) | 2 | 1.0 | 0.0 |
| conv/kernel | (3, 3, 3, 2) | 54 | 1.0 | 0.0 |
+-------------+--------------+------+------+-----+
Total: 56"""

FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+------+------+-----+
| Name | Shape | Size | Mean | Std |
+--------------------+--------------+------+------+-----+
| params/conv/bias | (2,) | 2 | 1.0 | 0.0 |
| params/conv/kernel | (3, 3, 3, 2) | 54 | 1.0 | 0.0 |
+--------------------+--------------+------+------+-----+
Total: 56"""
EMPTY_PARAMETER_OVERVIEW = """+------+-------+-------+------+------+-----+
| Name | Shape | Dtype | Size | Mean | Std |
+------+-------+-------+------+------+-----+
+------+-------+-------+------+------+-----+
Total: 0 -- 0 bytes"""

FLAX_CONV2D_PARAMETER_OVERVIEW = """+-------------+--------------+---------+------+
| Name | Shape | Dtype | Size |
+-------------+--------------+---------+------+
| conv/bias | (2,) | float32 | 2 |
| conv/kernel | (3, 3, 3, 2) | float32 | 54 |
+-------------+--------------+---------+------+
Total: 56 -- 224 bytes"""

FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS = """+-------------+--------------+---------+------+------+-----+
| Name | Shape | Dtype | Size | Mean | Std |
+-------------+--------------+---------+------+------+-----+
| conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 |
| conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 |
+-------------+--------------+---------+------+------+-----+
Total: 56 -- 224 bytes"""

FLAX_CONV2D_MAPPING_PARAMETER_OVERVIEW_WITH_STATS = """+--------------------+--------------+---------+------+------+-----+
| Name | Shape | Dtype | Size | Mean | Std |
+--------------------+--------------+---------+------+------+-----+
| params/conv/bias | (2,) | float32 | 2 | 1.0 | 0.0 |
| params/conv/kernel | (3, 3, 3, 2) | float32 | 54 | 1.0 | 0.0 |
+--------------------+--------------+---------+------+------+-----+
Total: 56 -- 224 bytes"""


class CNN(nn.Module):
Expand Down

0 comments on commit 8a01f21

Please sign in to comment.