@@ -82,13 +82,17 @@ def flatten_dict(
82
82
def _count_parameters(params: _ParamsContainer) -> int:
  """Returns the count of variables for the module or parameter dictionary.

  Args:
    params: Parameter container accepted by `flatten_dict`; it is flattened
      before counting.

  Returns:
    The total number of elements over all leaves that are not `None`, as a
    Python `int`.
  """
  params = flatten_dict(params)
  # `np.prod(())` is the float `1.0` and non-empty shapes yield fixed-width
  # NumPy integers, so convert each product to a Python int before summing to
  # honour the declared `int` return type.
  return sum(int(np.prod(v.shape)) for v in params.values() if v is not None)
86
86
87
87
88
88
def _parameters_size(params: _ParamsContainer) -> int:
  """Returns total size (bytes) for the module or parameter dictionary.

  Args:
    params: Parameter container accepted by `flatten_dict`; it is flattened
      before the sizes are accumulated.

  Returns:
    The sum, over all leaves that are not `None`, of the leaf's element count
    multiplied by its `dtype.itemsize`.
  """
  params = flatten_dict(params)
  # `np.prod(())` is the float `1.0`, so convert the element count to a Python
  # int first; otherwise scalar leaves would turn the byte total into a float.
  return sum(
      int(np.prod(v.shape)) * v.dtype.itemsize
      for v in params.values()
      if v is not None
  )
92
96
93
97
94
98
def count_parameters (params : _ParamsContainer ) -> int :
@@ -127,6 +131,8 @@ def _make_row_with_sharding(name, value) -> _ParamRowWithSharding:
127
131
128
132
def _make_row_with_stats (name , value , mean , std ) -> _ParamRowWithStats :
129
133
row = _make_row (name , value )
134
+ mean = mean or 0.0
135
+ std = std or 0.0
130
136
return _ParamRowWithStats (
131
137
** dataclasses .asdict (row ),
132
138
mean = float (jax .device_get (mean )),
@@ -156,12 +162,11 @@ def _get_parameter_rows(
156
162
params: Dictionary with parameters as NumPy arrays. The dictionary can be
157
163
nested. Alternatively a `tf.Module` can be provided, in which case the
158
164
`trainable_variables` of the module will be used.
159
- include_stats: If True, add columns with mean and std for each variable.
160
- If the string "sharding", add column a column with the sharding of the
161
- variable.
162
- If the string "global", params are sharded global arrays and this
163
- function assumes it is called on every host, i.e. can use collectives.
164
- The sharding of the variables is also added as a column.
165
+ include_stats: If True, add columns with mean and std for each variable. If
166
+ the string "sharding", add a column with the sharding of the
167
+ variable. If the string "global", params are sharded global arrays and
168
+ this function assumes it is called on every host, i.e. can use
169
+ collectives. The sharding of the variables is also added as a column.
165
170
166
171
Returns:
167
172
A list of `ParamRow`, or `ParamRowWithStats`, depending on the passed value
@@ -185,12 +190,14 @@ def _get_parameter_rows(
185
190
case True :
186
191
mean_and_std = _mean_std (values )
187
192
return jax .tree_util .tree_map (
188
- _make_row_with_stats , names , values , * mean_and_std )
193
+ _make_row_with_stats , names , values , * mean_and_std
194
+ )
189
195
190
196
case "global" :
191
197
mean_and_std = _mean_std_jit (values )
192
198
return jax .tree_util .tree_map (
193
- _make_row_with_stats_and_sharding , names , values , * mean_and_std )
199
+ _make_row_with_stats_and_sharding , names , values , * mean_and_std
200
+ )
194
201
195
202
case "sharding" :
196
203
return jax .tree_util .tree_map (_make_row_with_sharding , names , values )
@@ -256,8 +263,7 @@ def __init__(self, name, values):
256
263
column_names = [field .name for field in dataclasses .fields (rows [0 ])]
257
264
258
265
columns = [
259
- Column (name , [value_formatter (getattr (row , name ))
260
- for row in rows ])
266
+ Column (name , [value_formatter (getattr (row , name )) for row in rows ])
261
267
for name in column_names
262
268
]
263
269
@@ -312,12 +318,11 @@ def get_parameter_overview(
312
318
Args:
313
319
params: Dictionary with parameters as NumPy arrays. The dictionary can be
314
320
nested.
315
- include_stats: If True, add columns with mean and std for each variable.
316
- If the string "sharding", add column a column with the sharding of the
317
- variable.
318
- If the string "global", params are sharded global arrays and this
319
- function assumes it is called on every host, i.e. can use collectives.
320
- The sharding of the variables is also added as a column.
321
+ include_stats: If True, add columns with mean and std for each variable. If
322
+ the string "sharding", add a column with the sharding of the
323
+ variable. If the string "global", params are sharded global arrays and
324
+ this function assumes it is called on every host, i.e. can use
325
+ collectives. The sharding of the variables is also added as a column.
321
326
max_lines: If not `None`, the maximum number of variables to include.
322
327
323
328
Returns:
@@ -375,16 +380,19 @@ def log_parameter_overview(
375
380
Args:
376
381
params: Dictionary with parameters as NumPy arrays. The dictionary can be
377
382
nested.
378
- include_stats: If True, add columns with mean and std for each variable.
379
- If the string "global", params are sharded global arrays and this
380
- function assumes it is called on every host, i.e. can use collectives.
383
+ include_stats: If True, add columns with mean and std for each variable. If
384
+ the string "global", params are sharded global arrays and this function
385
+ assumes it is called on every host, i.e. can use collectives.
381
386
max_lines: If not `None`, the maximum number of variables to include.
382
387
msg: Message to be logged before the overview.
383
388
jax_logging_process: Which JAX process ID should do the logging. None = all.
384
389
Use this to avoid logspam when include_stats="global".
385
390
"""
386
391
387
392
_log_parameter_overview (
388
- params , include_stats = include_stats , max_lines = max_lines , msg = msg ,
389
- jax_logging_process = jax_logging_process
393
+ params ,
394
+ include_stats = include_stats ,
395
+ max_lines = max_lines ,
396
+ msg = msg ,
397
+ jax_logging_process = jax_logging_process ,
390
398
)
0 commit comments