Skip to content

Commit 325e16e

Browse files
committed
add mae to metrax (#79)
* add mae to metrax * modify regression_metrics * revert ranking metrics changes
1 parent e81881d commit 325e16e

File tree

6 files changed

+114
-0
lines changed

6 files changed

+114
-0
lines changed

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
2626
BLEU = nlp_metrics.BLEU
2727
DCGAtK = ranking_metrics.DCGAtK
28+
MAE = regression_metrics.MAE
2829
MRR = ranking_metrics.MRR
2930
MSE = regression_metrics.MSE
3031
NDCGAtK = ranking_metrics.NDCGAtK
@@ -48,6 +49,7 @@
4849
"AveragePrecisionAtK",
4950
"BLEU",
5051
"DCGAtK",
52+
"MAE",
5153
"MRR",
5254
"MSE",
5355
"NDCGAtK",

src/metrax/metrax_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ class MetraxTest(parameterized.TestCase):
9393
'ks': KS,
9494
},
9595
),
96+
(
97+
'mae',
98+
metrax.MAE,
99+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
100+
),
96101
(
97102
'mse',
98103
metrax.MSE,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
2222
BLEU = nnx_metrics.BLEU
2323
DCGAtK = nnx_metrics.DCGAtK
24+
MAE = nnx_metrics.MAE
2425
MRR = nnx_metrics.MRR
2526
MSE = nnx_metrics.MSE
2627
NDCGAtK = nnx_metrics.NDCGAtK
@@ -44,6 +45,7 @@
4445
"BLEU",
4546
"DCGAtK",
4647
"MRR",
48+
"MAE",
4749
"MSE",
4850
"NDCGAtK",
4951
"Perplexity",

src/metrax/nnx/nnx_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def __init__(self):
6666
super().__init__(metrax.DCGAtK)
6767

6868

69+
class MAE(NnxWrapper):
  """NNX-facing wrapper that exposes the Metrax `MAE` metric."""

  def __init__(self):
    # Hand the Metrax metric class to the shared wrapper base class.
    super().__init__(metrax.MAE)
74+
75+
6976
class MRR(NnxWrapper):
7077
"""An NNX class for the Metrax metric MRR."""
7178

src/metrax/regression_metrics.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,62 @@
2121
from metrax import base
2222

2323

24+
@flax.struct.dataclass
class MAE(base.Average):
  r"""Mean absolute error (MAE) between `predictions` and `labels`.

  Without sample weights the metric is:

  .. math::
      MAE = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|

  With sample weights :math:`w_i` it becomes the weighted average:

  .. math::
      MAE = \frac{\sum_{i=1}^{N} w_i|y_i - \hat{y}_i|}{\sum_{i=1}^{N} w_i}

  where:
      - :math:`y_i` are the true values
      - :math:`\hat{y}_i` are the predictions
      - :math:`w_i` are the sample weights
      - :math:`N` is the number of samples

  This class only accumulates the numerator (`total`) and the denominator
  (`count`); the final division is presumably done by `base.Average` -- confirm
  against that class.
  """

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'MAE':
    """Builds the metric state for one batch.

    Args:
      predictions: Model predictions. Documented by the original author as a
        floating point vector of shape (batch_size,).
      labels: True values. Documented by the original author as having shape
        (batch_size,).
      sample_weights: Optional per-sample weights, documented as shape
        (batch_size,). When given, both the absolute errors and the per-sample
        counts are multiplied by the weights before being summed.

    Returns:
      A new `MAE` instance whose `total` is the (weighted) sum of absolute
      errors and whose `count` is the (weighted) number of samples.

    Note:
      No validation is performed here; any error for incompatible inputs comes
      from the underlying array operations.
    """
    abs_diff = jnp.abs(predictions - labels)
    # One unit per label element; int32 until (optionally) scaled by weights.
    ones = jnp.ones_like(labels, dtype=jnp.int32)

    if sample_weights is None:
      return cls(total=abs_diff.sum(), count=ones.sum())

    weighted_diff = abs_diff * sample_weights
    weighted_ones = ones * sample_weights
    return cls(total=weighted_diff.sum(), count=weighted_ones.sum())
78+
79+
2480
@flax.struct.dataclass
2581
class MSE(base.Average):
2682
r"""Computes the mean squared error for regression problems given `predictions` and `labels`.

src/metrax/regression_metrics_test.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,48 @@ def test_rsquared_empty(self):
109109
self.assertEqual(m.sum_of_squared_error, jnp.array(0, jnp.float32))
110110
self.assertEqual(m.sum_of_squared_label, jnp.array(0, jnp.float32))
111111

112+
@parameterized.named_parameters(
    ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
    ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
    ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
    ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
    ('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
    ('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
    ('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_mae(self, y_true, y_pred, sample_weights):
  """Checks that `MAE` matches sklearn's `mean_absolute_error`.

  The metric is built incrementally: one update per slice along the leading
  axis of the inputs, merged together, then compared against the reference
  value computed over the flattened arrays.

  Args:
    y_true: Array of true values; cast to the dtype of `y_pred`.
    y_pred: Array of predictions; its dtype selects the comparison tolerance.
    sample_weights: Optional array of weights. `None` is replaced by all-ones
      weights so both the metric and the reference take the weighted path.
  """
  # Compare in the predictions' dtype so the low-precision cases are really
  # exercised. (y_pred already has this dtype, so it needs no cast.)
  y_true = y_true.astype(y_pred.dtype)
  if sample_weights is None:
    sample_weights = np.ones_like(y_true)

  metric = None
  for labels, preds, weights in zip(y_true, y_pred, sample_weights):
    update = metrax.MAE.from_model_output(
        predictions=preds,
        labels=labels,
        sample_weights=weights,
    )
    metric = update if metric is None else metric.merge(update)

  # TODO(jiwonshin): Use `keras.metrics.MeanAbsoluteError` once it supports
  # sample weights.
  expected = sklearn_metrics.mean_absolute_error(
      y_true.flatten(),
      y_pred.flatten(),
      sample_weight=sample_weights.flatten(),
  )
  # Use a looser tolerance for lower precision dtypes.
  tol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
  np.testing.assert_allclose(
      metric.compute(),
      expected,
      rtol=tol,
      atol=tol,
  )
153+
112154
@parameterized.named_parameters(
113155
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
114156
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),

0 commit comments

Comments
 (0)