File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -784,20 +784,21 @@ class Average(Metric):
784
784
785
785
@classmethod
786
786
def empty (cls ) -> Average :
787
- return cls (total = jnp .array (0 , jnp .float32 ), count = jnp .array (0 , jnp .int32 ))
787
+ return cls (total = jnp .array (0 , jnp .float32 ), count = jnp .array (0 , jnp .float32 ))
788
788
789
789
@classmethod
790
790
def from_model_output (
791
791
cls , values : jnp .ndarray , mask : jnp .ndarray | None = None , ** _
792
792
) -> Average :
793
793
values , mask = _broadcast_masks (values , mask )
794
794
return cls (
795
- total = jnp .where (mask , values , jnp .zeros_like (values )).sum (),
795
+ total = jnp .where (mask , values , jnp .zeros_like (values )).sum ().astype (
796
+ jnp .float32 ),
796
797
count = jnp .where (
797
798
mask ,
798
799
jnp .ones_like (values , dtype = jnp .int32 ),
799
800
jnp .zeros_like (values , dtype = jnp .int32 ),
800
- ).sum (),
801
+ ).sum (). astype ( jnp . float32 ) ,
801
802
)
802
803
803
804
def merge (self , other : Average ) -> Average :
You can’t perform that action at this time.
0 commit comments