Unify the handling of tensor-valued metrics of Std with Average. In particular, this removes the ndim restriction on the Std metric.

- `Average` has no such restriction, so `Std` does not need one either.
- Inside `pmap`, if each device outputs per-sample loss values (ndim=1), the result of `lax.all_gather` is a tensor with ndim=2, where the extra dimension comes from the devices. `Average` would work in that case, but `Std` would fail.
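The shape issue described above can be illustrated without any accelerator. The following is a minimal sketch in plain NumPy, not CLU or JAX code: `np.stack` stands in for `lax.all_gather`, and `old_std_check`, `NUM_DEVICES`, and `PER_DEVICE_BATCH` are hypothetical names introduced only for illustration.

```python
import numpy as np

NUM_DEVICES = 4       # Assumed device count, for illustration only.
PER_DEVICE_BATCH = 8  # Assumed per-device batch size.

# Each device produces per-sample losses with ndim == 1.
per_device_losses = [
    np.arange(PER_DEVICE_BATCH, dtype=np.float32) + d
    for d in range(NUM_DEVICES)
]

# Stand-in for lax.all_gather: stacking adds a leading device axis.
gathered = np.stack(per_device_losses)


def old_std_check(values: np.ndarray) -> None:
  """Hypothetical mimic of the removed restriction (values must have ndim 1)."""
  if values.ndim != 1:
    raise ValueError(f"Expected ndim=1, got ndim={values.ndim}.")


print(gathered.shape)  # (4, 8)
old_std_check(per_device_losses[0])  # A single device's output passes.
try:
  old_std_check(gathered)  # The gathered tensor has ndim == 2.
except ValueError as e:
  print("rejected:", e)

# Without the restriction, reducing over every element is well defined.
print(float(gathered.std()))
```

The point of the sketch is only that the gathered tensor gains a device axis, so any check that insists on `ndim=1` rejects it even though a reduction over all elements is perfectly meaningful.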

PiperOrigin-RevId: 603275875
lzlarryli authored and copybara-github committed Feb 1, 2024
1 parent f30bc44 commit 1368e52
Showing 43 changed files with 117 additions and 71 deletions.
2 changes: 1 addition & 1 deletion clu/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2023 The CLU Authors.
+# Copyright 2024 The CLU Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
The same one-line change (copyright year updated from 2023 to 2024 in the license header, 1 addition & 1 deletion) applies to each of the following files:
- clu/asynclib.py
- clu/asynclib_test.py
- clu/checkpoint.py
- clu/checkpoint_test.py
- clu/data/__init__.py
- clu/data/dataset_iterator.py
- clu/data/dataset_iterator_test.py
- clu/deterministic_data.py
- clu/deterministic_data_test.py
- clu/internal/__init__.py
- clu/internal/utils.py
- clu/internal/utils_test.py
- clu/metric_writers/__init__.py
- clu/metric_writers/async_writer.py
- clu/metric_writers/async_writer_test.py
- clu/metric_writers/interface.py
- clu/metric_writers/logging_writer.py
- clu/metric_writers/logging_writer_test.py
- clu/metric_writers/multi_writer.py
- clu/metric_writers/multi_writer_test.py
- clu/metric_writers/summary_writer.py
- clu/metric_writers/tf/__init__.py
- clu/metric_writers/tf/summary_writer.py
- clu/metric_writers/tf/summary_writer_test.py
- clu/metric_writers/torch_tensorboard_writer.py
- clu/metric_writers/torch_tensorboard_writer_test.py
- clu/metric_writers/utils.py
- clu/metric_writers/utils_test.py
66 changes: 37 additions & 29 deletions clu/metrics.py
@@ -1,4 +1,4 @@
-# Copyright 2023 The CLU Authors.
+# Copyright 2024 The CLU Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -742,6 +742,27 @@ def compute(self) -> Any:
     return self.value


+def _broadcast_masks(values: jnp.ndarray, mask: jnp.ndarray | None):
+  """Checks and broadcasts mask for aggregating values."""
+  if values.ndim == 0:
+    values = values[None]
+  if mask is None:
+    mask = jnp.ones_like(values)
+  # Leading dimensions of mask and values must match.
+  if mask.shape[0] != values.shape[0]:
+    raise ValueError(
+        "Argument `mask` must have the same leading dimension as `values`. "
+        f"Received mask of dimension {mask.shape} "
+        f"and values of dimension {values.shape}."
+    )
+  # Broadcast mask to the same number of dimensions as values.
+  if mask.ndim < values.ndim:
+    mask = jnp.expand_dims(mask, axis=tuple(np.arange(mask.ndim, values.ndim)))
+  mask = mask.astype(bool)
+  utils.check_param(mask, dtype=bool, ndim=values.ndim)
+  return values, mask
+
+
 @flax.struct.dataclass
 class Average(Metric):
   """Computes the average of a scalar or a batch of tensors.
@@ -769,26 +790,14 @@ def empty(cls) -> Average:
   def from_model_output(
       cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
   ) -> Average:
-    if values.ndim == 0:
-      values = values[None]
-    if mask is None:
-      mask = jnp.ones_like(values)
-    # Leading dimensions of mask and values must match.
-    if mask.shape[0] != values.shape[0]:
-      raise ValueError(
-          f"Argument `mask` must have the same leading dimension as `values`. "
-          f"Received mask of dimension {mask.shape} "
-          f"and values of dimension {values.shape}.")
-    # Broadcast mask to the same number of dimensions as values.
-    if mask.ndim < values.ndim:
-      mask = jnp.expand_dims(
-          mask, axis=tuple(np.arange(mask.ndim, values.ndim)))
-    mask = mask.astype(bool)
-    utils.check_param(mask, dtype=bool, ndim=values.ndim)
+    values, mask = _broadcast_masks(values, mask)
     return cls(
         total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
-        count=jnp.where(mask, jnp.ones_like(values, dtype=jnp.int32),
-                        jnp.zeros_like(values, dtype=jnp.int32)).sum(),
+        count=jnp.where(
+            mask,
+            jnp.ones_like(values, dtype=jnp.int32),
+            jnp.zeros_like(values, dtype=jnp.int32),
+        ).sum(),
     )

   def merge(self, other: Average) -> Average:
@@ -804,9 +813,10 @@ def compute(self) -> Any:

 @flax.struct.dataclass
 class Std(Metric):
-  """Computes the standard deviation of a scalar or a batch of scalars.
+  """Computes the standard deviation of a scalar or a batch of tensors.

-  See also documentation of `Metric`.
+  The result is always a single scalar. See also the documentation of `Average`
+  for the mask handling.
   """

   total: jnp.ndarray
@@ -824,17 +834,15 @@ def empty(cls) -> Std:
   def from_model_output(
       cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
   ) -> Std:
-    if values.ndim == 0:
-      values = values[None]
-    utils.check_param(values, ndim=1)
-    if mask is None:
-      mask = jnp.ones(values.shape[0], dtype=jnp.int32)
+    values, mask = _broadcast_masks(values, mask)
     return cls(
         total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
-        sum_of_squares=jnp.where(
-            mask, values**2, jnp.zeros_like(values)
+        sum_of_squares=jnp.where(mask, values**2, jnp.zeros_like(values)).sum(),
+        count=jnp.where(
+            mask,
+            jnp.ones_like(values, dtype=jnp.int32),
+            jnp.zeros_like(values, dtype=jnp.int32),
         ).sum(),
-        count=mask.sum(),
     )

   def merge(self, other: Std) -> Std:
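To make the mask handling in the `clu/metrics.py` diff concrete, here is a rough NumPy port of the broadcasting logic. It is a sketch rather than the library code: `broadcast_mask` is a hypothetical name, NumPy replaces `jnp`, and the `utils.check_param` validation is omitted. It also shows why the diff computes `count` from values-shaped ones instead of `mask.sum()`.

```python
import numpy as np


def broadcast_mask(values, mask=None):
  """Hypothetical NumPy port of the `_broadcast_masks` logic in the diff."""
  values = np.asarray(values)
  if values.ndim == 0:
    values = values[None]
  if mask is None:
    mask = np.ones_like(values)
  mask = np.asarray(mask)
  # Leading dimensions of mask and values must match.
  if mask.shape[0] != values.shape[0]:
    raise ValueError(
        f"mask {mask.shape} and values {values.shape} differ in leading dim."
    )
  if mask.ndim < values.ndim:
    # Append trailing singleton axes so the mask broadcasts over each sample.
    mask = np.expand_dims(mask, axis=tuple(range(mask.ndim, values.ndim)))
  return values, mask.astype(bool)


values = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
values, mask = broadcast_mask(values, [False, True, True])

print(mask.shape)  # (3, 1)

total = np.where(mask, values, np.zeros_like(values)).sum()
# Counting values-shaped ones counts elements, not rows.
count = np.where(
    mask,
    np.ones_like(values, dtype=np.int32),
    np.zeros_like(values, dtype=np.int32),
).sum()

print(total)            # 12.0
print(count)            # 4
print(int(mask.sum()))  # 2, which would undercount for tensor-valued samples
```

With a per-sample mask of shape `(3, 1)` and values of shape `(3, 2)`, `mask.sum()` counts the two unmasked rows, while the `where`-based count yields the four unmasked elements that actually enter `total`.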
40 changes: 39 additions & 1 deletion clu/metrics_test.py
@@ -1,4 +1,4 @@
-# Copyright 2023 The CLU Authors.
+# Copyright 2024 The CLU Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -526,6 +526,44 @@ def merge_collection(model_output, collection):
     # If it does have a weak type the second call will cause a re-trace
     collection = merge_collection(model_output, collection)

+  @parameterized.product(
+      value_mask_pair=[
+          (1, None),
+          ([1, 2, 3], None),
+          ([1, 2, 3], [True, True, False]),
+          ([[1, 2], [2, 3], [3, 4]], None),
+          ([[1, 2], [2, 3], [3, 4]], [False, True, True]),
+          (
+              [[1, 2], [2, 3], [3, 4]],
+              [[False, True], [True, True], [True, True]],
+          ),
+          ([[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]], None),
+          (
+              [[[1, 2], [2, 3]], [[2, 1], [3, 4]], [[3, 1], [4, 1]]],
+              [False, True, True],
+          ),
+      ],
+      metric_np_equivalent_pair=[
+          (metrics.Average, jnp.mean),
+          (metrics.Std, jnp.std),
+      ],
+  )
+  def test_tensor_aggregation_metrics_with_masks(
+      self, value_mask_pair, metric_np_equivalent_pair
+  ):
+    values, mask = value_mask_pair
+    metric, np_equivalent = metric_np_equivalent_pair
+    values = jnp.asarray(values)
+    masked = values
+    if mask is not None:
+      mask = jnp.asarray(mask)
+      masked = values[mask]
+    expected = np_equivalent(masked)
+
+    result = metric.from_model_output(values, mask=mask).compute()
+    # The lower precision is needed for the lower precision jitted version.
+    chex.assert_trees_all_close(result, expected, atol=1e-4, rtol=1e-4)
+

 if __name__ == "__main__":
   absltest.main()
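The test above compares each metric against a NumPy-style reduction over the masked elements. The sketch below walks through one `Std` case by hand in plain NumPy. It assumes, since `Std.compute` is not shown in this diff, that the standard deviation is derived from `total`, `sum_of_squares`, and `count` with the usual population formula; the variable names are illustrative.

```python
import numpy as np

values = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
mask = np.array([[False, True], [True, True], [True, True]])

# Sufficient statistics, accumulated as in `Std.from_model_output`.
total = np.where(mask, values, 0.0).sum()
sum_of_squares = np.where(mask, values**2, 0.0).sum()
count = mask.sum()

# Assumed reduction (not shown in the diff): population standard deviation.
mean = total / count
variance = max(sum_of_squares / count - mean**2, 0.0)
std_from_stats = np.sqrt(variance)

# Reference value, built the same way as `expected` in the test.
expected = np.std(values[mask])

np.testing.assert_allclose(std_from_stats, expected, rtol=1e-6)
print(std_from_stats, expected)
```

Keeping only three running sums is what lets such a metric be merged across batches and devices; the clamp at zero in the sketch guards against small negative variances from floating-point cancellation.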
The following files receive only the copyright-year update (2023 to 2024) in the license header, 1 addition & 1 deletion each:
- clu/parameter_overview.py
- clu/parameter_overview_test.py
- clu/periodic_actions.py
- clu/periodic_actions_test.py
- clu/platform/__init__.py
- clu/platform/interface.py
- clu/platform/local.py
- clu/preprocess_spec.py
- clu/preprocess_spec_test.py
- clu/profiler.py
- clu/values.py