Skip to content

Commit cfac12b

Browse files
CLU Authorscopybara-github
authored andcommitted
Change the count type from int32 to float32 to avoid overflows.
PiperOrigin-RevId: 723026512
1 parent 43acbbd commit cfac12b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

clu/metrics.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -784,20 +784,21 @@ class Average(Metric):
784784

785785
@classmethod
786786
def empty(cls) -> Average:
787-
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.int32))
787+
return cls(total=jnp.array(0, jnp.float32), count=jnp.array(0, jnp.float32))
788788

789789
@classmethod
790790
def from_model_output(
791791
cls, values: jnp.ndarray, mask: jnp.ndarray | None = None, **_
792792
) -> Average:
793793
values, mask = _broadcast_masks(values, mask)
794794
return cls(
795-
total=jnp.where(mask, values, jnp.zeros_like(values)).sum(),
795+
total=jnp.where(mask, values, jnp.zeros_like(values)).sum().astype(
796+
jnp.float32),
796797
count=jnp.where(
797798
mask,
798799
jnp.ones_like(values, dtype=jnp.int32),
799800
jnp.zeros_like(values, dtype=jnp.int32),
800-
).sum(),
801+
).sum().astype(jnp.float32),
801802
)
802803

803804
def merge(self, other: Average) -> Average:

0 commit comments

Comments
 (0)