Skip to content

Commit

Permalink
Allow paramter_overview to work with JAX ShapeDtypeStruct.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595317017
  • Loading branch information
jpuigcerver authored and copybara-github committed Jan 3, 2024
1 parent 8a01f21 commit 574d4c9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion clu/parameter_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _count_parameters(params: _ParamsContainer) -> int:
def _parameters_size(params: _ParamsContainer) -> int:
"""Returns total size (bytes) for the module or parameter dictionary."""
params = flatten_dict(params)
return sum(v.nbytes for v in params.values())
return sum(np.prod(v.shape) * v.dtype.itemsize for v in params.values())


def count_parameters(params: _ParamsContainer) -> int:
Expand Down

0 comments on commit 574d4c9

Please sign in to comment.