@@ -109,6 +109,48 @@ def test_rsquared_empty(self):
109109 self .assertEqual (m .sum_of_squared_error , jnp .array (0 , jnp .float32 ))
110110 self .assertEqual (m .sum_of_squared_label , jnp .array (0 , jnp .float32 ))
111111
112+ @parameterized .named_parameters (
113+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
114+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
115+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 , None ),
116+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , None ),
117+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , SAMPLE_WEIGHTS ),
118+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , SAMPLE_WEIGHTS ),
119+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 , SAMPLE_WEIGHTS ),
120+ )
121+ def test_mae (self , y_true , y_pred , sample_weights ):
122+ """Test that `MAE` Metric computes correct values."""
123+ y_true = y_true .astype (y_pred .dtype )
124+ y_pred = y_pred .astype (y_true .dtype )
125+ if sample_weights is None :
126+ sample_weights = np .ones_like (y_true )
127+
128+ metric = None
129+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
130+ update = metrax .MAE .from_model_output (
131+ predictions = logits ,
132+ labels = labels ,
133+ sample_weights = weights ,
134+ )
135+ metric = update if metric is None else metric .merge (update )
136+
137+ # TODO(jiwonshin): Use `keras.metrics.MeanAbsoluteError` once it supports
138+ # sample weights.
139+ expected = sklearn_metrics .mean_absolute_error (
140+ y_true .flatten (),
141+ y_pred .flatten (),
142+ sample_weight = sample_weights .flatten (),
143+ )
144+ # Use lower tolerance for lower precision dtypes.
145+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
146+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
147+ np .testing .assert_allclose (
148+ metric .compute (),
149+ expected ,
150+ rtol = rtol ,
151+ atol = atol ,
152+ )
153+
112154 @parameterized .named_parameters (
113155 ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
114156 ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
0 commit comments