diff --git a/clu/metrics.py b/clu/metrics.py index 613c284..5c15ba3 100644 --- a/clu/metrics.py +++ b/clu/metrics.py @@ -403,7 +403,7 @@ def reduce(self) -> CollectingMetric: # Note that this is usually called from inside a `pmap()` via # `Collection.gather_from_model_output()` so we concatenate using jnp. return type(self)( - {name: jnp.concatenate(values) for name, values in self.values.items()}) + {name: jnp.concatenate(values) for name, values in self.values.items()}) # pytype: disable=wrong-arg-types # jnp-types def compute(self): # No return type annotation, so subclasses can override return {k: np.concatenate(v) for k, v in self.values.items()}