2 changes: 2 additions & 0 deletions src/metrax/__init__.py
@@ -24,6 +24,7 @@
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
BLEU = nlp_metrics.BLEU
DCGAtK = ranking_metrics.DCGAtK
MAE = regression_metrics.MAE
MRR = ranking_metrics.MRR
MSE = regression_metrics.MSE
NDCGAtK = ranking_metrics.NDCGAtK
@@ -46,6 +47,7 @@
"AveragePrecisionAtK",
"BLEU",
"DCGAtK",
"MAE",
"MRR",
"MSE",
"NDCGAtK",
5 changes: 5 additions & 0 deletions src/metrax/metrax_test.py
@@ -88,6 +88,11 @@ class MetraxTest(parameterized.TestCase):
'ks': KS,
},
),
(
'mae',
metrax.MAE,
{'predictions': OUTPUT_PREDS, 'labels': OUTPUT_LABELS},
),
(
'mse',
metrax.MSE,
2 changes: 2 additions & 0 deletions src/metrax/nnx/__init__.py
@@ -20,6 +20,7 @@
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
BLEU = nnx_metrics.BLEU
DCGAtK = nnx_metrics.DCGAtK
MAE = nnx_metrics.MAE
MRR = nnx_metrics.MRR
MSE = nnx_metrics.MSE
NDCGAtK = nnx_metrics.NDCGAtK
@@ -43,6 +44,7 @@
"BLEU",
"DCGAtK",
"MRR",
"MAE"
"MSE",
"NDCGAtK",
"Perplexity",
7 changes: 7 additions & 0 deletions src/metrax/nnx/nnx_metrics.py
@@ -60,6 +60,13 @@ def __init__(self):
super().__init__(metrax.DCGAtK)


class MAE(NnxWrapper):
"""An NNX class for the Metrax metric MAE."""

def __init__(self):
super().__init__(metrax.MAE)


class MRR(NnxWrapper):
"""An NNX class for the Metrax metric MRR."""

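For reference, a minimal usage sketch of the wrapper added above. It assumes `NnxWrapper` exposes the usual `flax.nnx` metric interface (`update` forwarding keyword arguments to `from_model_output`, plus `compute`); the arrays are made-up toy data, not part of this PR.

import jax.numpy as jnp
from metrax import nnx as metrax_nnx

# Accumulate MAE over two batches, then reduce to a scalar.
mae = metrax_nnx.MAE()
mae.update(predictions=jnp.array([1.0, 2.0]), labels=jnp.array([0.0, 2.0]))
mae.update(predictions=jnp.array([3.0]), labels=jnp.array([5.0]))
print(mae.compute())  # (|1-0| + |2-2| + |3-5|) / 3 = 1.0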
56 changes: 56 additions & 0 deletions src/metrax/regression_metrics.py
@@ -21,6 +21,62 @@
from metrax import base


@flax.struct.dataclass
class MAE(base.Average):
r"""Computes the mean absolute error for regression problems given `predictions` and `labels`.

The mean absolute error without sample weights is defined as:

.. math::
MAE = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|

When sample weights :math:`w_i` are provided, the weighted mean absolute error
is:

.. math::
MAE = \frac{\sum_{i=1}^{N} w_i|y_i - \hat{y}_i|}{\sum_{i=1}^{N} w_i}

where:
- :math:`y_i` are true values
- :math:`\hat{y}_i` are predictions
- :math:`w_i` are sample weights
- :math:`N` is the number of samples
"""

@classmethod
def from_model_output(
cls,
predictions: jax.Array,
labels: jax.Array,
sample_weights: jax.Array | None = None,
) -> 'MAE':
"""Updates the metric.

Args:
predictions: A floating point 1D vector of predictions generated by the
model. The shape should be (batch_size,).
labels: True value. The shape should be (batch_size,).
sample_weights: An optional floating point 1D vector representing the
weight of each sample. The shape should be (batch_size,).

Returns:
An `MAE` metric update; calling `compute()` yields a single scalar.

Raises:
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
absolute_error = jnp.abs(predictions - labels)
count = jnp.ones_like(labels, dtype=jnp.int32)
if sample_weights is not None:
absolute_error = absolute_error * sample_weights
# With weights, `count` becomes the sum of weights, so `total / count`
# is the weighted mean absolute error.
count = count * sample_weights
return cls(
total=absolute_error.sum(),
count=count.sum(),
)


@flax.struct.dataclass
class MSE(base.Average):
r"""Computes the mean squared error for regression problems given `predictions` and `labels`.
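To make the weighted MAE formula above concrete, here is a minimal sketch with made-up numbers (not from this PR) showing how `total` and `count` reproduce the weighted mean; per-batch updates built this way combine via `merge`, exactly as the test below does.

import jax.numpy as jnp
import metrax

preds = jnp.array([2.0, 4.0, 6.0])
labels = jnp.array([1.0, 5.0, 6.0])
weights = jnp.array([1.0, 2.0, 1.0])

m = metrax.MAE.from_model_output(
    predictions=preds, labels=labels, sample_weights=weights
)
# total = 1*|2-1| + 2*|4-5| + 1*|6-6| = 3.0
# count = 1 + 2 + 1 = 4.0
print(m.compute())  # 3.0 / 4.0 = 0.75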
42 changes: 42 additions & 0 deletions src/metrax/regression_metrics_test.py
@@ -109,6 +109,48 @@ def test_rsquared_empty(self):
self.assertEqual(m.sum_of_squared_error, jnp.array(0, jnp.float32))
self.assertEqual(m.sum_of_squared_label, jnp.array(0, jnp.float32))

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_mae(self, y_true, y_pred, sample_weights):
"""Test that `MAE` Metric computes correct values."""
# Cast labels to the prediction dtype so both operands share precision.
y_true = y_true.astype(y_pred.dtype)
if sample_weights is None:
sample_weights = np.ones_like(y_true)

metric = None
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
update = metrax.MAE.from_model_output(
predictions=logits,
labels=labels,
sample_weights=weights,
)
metric = update if metric is None else metric.merge(update)

# TODO(jiwonshin): Use `keras.metrics.MeanAbsoluteError` once it supports
# sample weights.
expected = sklearn_metrics.mean_absolute_error(
y_true.flatten(),
y_pred.flatten(),
sample_weight=sample_weights.flatten(),
)
# Use looser tolerances for the lower-precision dtypes.
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-5
np.testing.assert_allclose(
metric.compute(),
expected,
rtol=rtol,
atol=atol,
)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),