Skip to content

Commit 43acbbd

Browse files
CLU Authorscopybara-github
authored andcommitted
Fix when variable is None for include_stats is True or False
PiperOrigin-RevId: 721354960
1 parent 77b0602 commit 43acbbd

File tree

1 file changed

+31
-23
lines changed

1 file changed

+31
-23
lines changed

clu/parameter_overview.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,17 @@ def flatten_dict(
8282
def _count_parameters(params: _ParamsContainer) -> int:
8383
"""Returns the count of variables for the module or parameter dictionary."""
8484
params = flatten_dict(params)
85-
return sum(np.prod(v.shape) for v in params.values())
85+
return sum(np.prod(v.shape) for v in params.values() if v is not None)
8686

8787

8888
def _parameters_size(params: _ParamsContainer) -> int:
8989
"""Returns total size (bytes) for the module or parameter dictionary."""
9090
params = flatten_dict(params)
91-
return sum(np.prod(v.shape) * v.dtype.itemsize for v in params.values())
91+
return sum(
92+
np.prod(v.shape) * v.dtype.itemsize
93+
for v in params.values()
94+
if v is not None
95+
)
9296

9397

9498
def count_parameters(params: _ParamsContainer) -> int:
@@ -127,6 +131,8 @@ def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
127131

128132
def _make_row_with_stats(name, value, mean, std) -> _ParamRowWithStats:
129133
row = _make_row(name, value)
134+
mean = mean or 0.0
135+
std = std or 0.0
130136
return _ParamRowWithStats(
131137
**dataclasses.asdict(row),
132138
mean=float(jax.device_get(mean)),
@@ -156,12 +162,11 @@ def _get_parameter_rows(
156162
params: Dictionary with parameters as NumPy arrays. The dictionary can be
157163
nested. Alternatively a `tf.Module` can be provided, in which case the
158164
`trainable_variables` of the module will be used.
159-
include_stats: If True, add columns with mean and std for each variable.
160-
If the string "sharding", add column a column with the sharding of the
161-
variable.
162-
If the string "global", params are sharded global arrays and this
163-
function assumes it is called on every host, i.e. can use collectives.
164-
The sharding of the variables is also added as a column.
165+
include_stats: If True, add columns with mean and std for each variable. If
166+
the string "sharding", add column a column with the sharding of the
167+
variable. If the string "global", params are sharded global arrays and
168+
this function assumes it is called on every host, i.e. can use
169+
collectives. The sharding of the variables is also added as a column.
165170
166171
Returns:
167172
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
@@ -185,12 +190,14 @@ def _get_parameter_rows(
185190
case True:
186191
mean_and_std = _mean_std(values)
187192
return jax.tree_util.tree_map(
188-
_make_row_with_stats, names, values, *mean_and_std)
193+
_make_row_with_stats, names, values, *mean_and_std
194+
)
189195

190196
case "global":
191197
mean_and_std = _mean_std_jit(values)
192198
return jax.tree_util.tree_map(
193-
_make_row_with_stats_and_sharding, names, values, *mean_and_std)
199+
_make_row_with_stats_and_sharding, names, values, *mean_and_std
200+
)
194201

195202
case "sharding":
196203
return jax.tree_util.tree_map(_make_row_with_sharding, names, values)
@@ -256,8 +263,7 @@ def __init__(self, name, values):
256263
column_names = [field.name for field in dataclasses.fields(rows[0])]
257264

258265
columns = [
259-
Column(name, [value_formatter(getattr(row, name))
260-
for row in rows])
266+
Column(name, [value_formatter(getattr(row, name)) for row in rows])
261267
for name in column_names
262268
]
263269

@@ -312,12 +318,11 @@ def get_parameter_overview(
312318
Args:
313319
params: Dictionary with parameters as NumPy arrays. The dictionary can be
314320
nested.
315-
include_stats: If True, add columns with mean and std for each variable.
316-
If the string "sharding", add column a column with the sharding of the
317-
variable.
318-
If the string "global", params are sharded global arrays and this
319-
function assumes it is called on every host, i.e. can use collectives.
320-
The sharding of the variables is also added as a column.
321+
include_stats: If True, add columns with mean and std for each variable. If
322+
the string "sharding", add column a column with the sharding of the
323+
variable. If the string "global", params are sharded global arrays and
324+
this function assumes it is called on every host, i.e. can use
325+
collectives. The sharding of the variables is also added as a column.
321326
max_lines: If not `None`, the maximum number of variables to include.
322327
323328
Returns:
@@ -375,16 +380,19 @@ def log_parameter_overview(
375380
Args:
376381
params: Dictionary with parameters as NumPy arrays. The dictionary can be
377382
nested.
378-
include_stats: If True, add columns with mean and std for each variable.
379-
If the string "global", params are sharded global arrays and this
380-
function assumes it is called on every host, i.e. can use collectives.
383+
include_stats: If True, add columns with mean and std for each variable. If
384+
the string "global", params are sharded global arrays and this function
385+
assumes it is called on every host, i.e. can use collectives.
381386
max_lines: If not `None`, the maximum number of variables to include.
382387
msg: Message to be logged before the overview.
383388
jax_logging_process: Which JAX process ID should do the logging. None = all.
384389
Use this to avoid logspam when include_stats="global".
385390
"""
386391

387392
_log_parameter_overview(
388-
params, include_stats=include_stats, max_lines=max_lines, msg=msg,
389-
jax_logging_process=jax_logging_process
393+
params,
394+
include_stats=include_stats,
395+
max_lines=max_lines,
396+
msg=msg,
397+
jax_logging_process=jax_logging_process,
390398
)

0 commit comments

Comments
 (0)