4 changes: 4 additions & 0 deletions src/metrax/__init__.py
@@ -20,10 +20,12 @@

AUCPR = classification_metrics.AUCPR
AUCROC = classification_metrics.AUCROC
Accuracy = classification_metrics.Accuracy
Average = base.Average
AveragePrecisionAtK = ranking_metrics.AveragePrecisionAtK
BLEU = nlp_metrics.BLEU
DCGAtK = ranking_metrics.DCGAtK
MAE = regression_metrics.MAE
MRR = ranking_metrics.MRR
MSE = regression_metrics.MSE
NDCGAtK = ranking_metrics.NDCGAtK
@@ -42,10 +44,12 @@
__all__ = [
"AUCPR",
"AUCROC",
"Accuracy",
"Average",
"AveragePrecisionAtK",
"BLEU",
"DCGAtK",
"MAE",
"MRR",
"MSE",
"NDCGAtK",
67 changes: 67 additions & 0 deletions src/metrax/classification_metrics.py
@@ -43,6 +43,73 @@ def _default_threshold(num_thresholds: int) -> jax.Array:
return thresholds


@flax.struct.dataclass
class Accuracy(base.Average):
r"""Computes accuracy, which is the frequency with which `predictions` match `labels`.

This metric calculates the proportion of correct predictions by comparing
`predictions` and `labels` element-wise. It is the ratio of the sum of
weighted correct predictions to the sum of all corresponding weights.
If no `sample_weights` are provided, weights default to 1 for each element.

The calculation is as follows:

.. math::
\text{Accuracy} = \frac{\sum (\text{weight} \times \text{correct})}{\sum
\text{weight}}

where `correct` is 1 if `prediction == label` for an element, and 0 otherwise.
`weight` is the `sample_weight` for that element, or 1 if no weights are
given.
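
A minimal usage sketch (the array values below are purely illustrative; batched
accumulation with ``empty`` and ``merge`` follows the same pattern as the unit
tests)::

    import jax.numpy as jnp
    import metrax

    metric = metrax.Accuracy.from_model_output(
        predictions=jnp.array([1, 0, 1, 1]),
        labels=jnp.array([1, 0, 0, 1]),
    )
    metric.compute()  # 3 of 4 predictions match the labels -> 0.75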
"""

@classmethod
def from_model_output(
cls,
predictions: jax.Array,
labels: jax.Array,
sample_weights: jax.Array | None = None,
) -> 'Accuracy':
"""Updates the metric state with new `predictions` and `labels`.

This method computes element-wise equality between `predictions` and
`labels`. The result of this comparison (a boolean array, treated as 1 for
True and 0 for False) is then used to update the metric's `total` and
`count`.

Args:
predictions: JAX array of predicted values. Expected to have a shape
compatible with `labels` for element-wise comparison (e.g.,
`(batch_size,)`, `(batch_size, num_classes)`, `(batch_size,
sequence_length, num_features)`).
labels: JAX array of true values. Expected to have a shape compatible with
`predictions` for element-wise comparison.
sample_weights: Optional JAX array of weights. If provided, it must be
broadcastable to the shape of `labels` (which should also be compatible
with `predictions`' shape).

Returns:
An updated instance of `Accuracy` metric.

Raises:
ValueError: If JAX operations (like broadcasting or arithmetic) fail due
to incompatible shapes or types among `predictions`, `labels`, and
`sample_weights`. For instance, if `predictions` and `labels` shapes
are not identical and not broadcastable to a common shape for
comparison, or if `sample_weights` cannot be broadcast to `labels`'
shape.
"""
correct = predictions == labels
count = jnp.ones_like(labels, dtype=jnp.int32)
if sample_weights is not None:
correct = correct * sample_weights
count = count * sample_weights
return cls(
total=correct.sum(),
count=count.sum(),
)


@flax.struct.dataclass
class Precision(clu_metrics.Metric):
r"""Computes precision for binary classification given `predictions` and `labels`.
31 changes: 31 additions & 0 deletions src/metrax/classification_metrics_test.py
@@ -79,6 +79,37 @@ def test_aucroc_empty(self):
self.assertEqual(m.false_negatives, jnp.array(0, jnp.float32))
self.assertEqual(m.num_thresholds, 0)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
)
def test_accuracy(self, y_true, y_pred, sample_weights):
"""Test that `Accuracy` metric computes correct values."""
if sample_weights is None:
sample_weights = np.ones_like(y_true)
metrax_accuracy = metrax.Accuracy.empty()
keras_accuracy = keras.metrics.Accuracy()
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
update = metrax.Accuracy.from_model_output(
predictions=logits,
labels=labels,
sample_weights=weights,
)
metrax_accuracy = metrax_accuracy.merge(update)
keras_accuracy.update_state(labels, logits, weights)

# Use a looser tolerance for lower-precision dtypes.
rtol = 1e-2 if y_pred.dtype in (jnp.float16, jnp.bfloat16) else 1e-5
atol = 1e-2 if y_pred.dtype in (jnp.float16, jnp.bfloat16) else 1e-5
np.testing.assert_allclose(
metrax_accuracy.compute(),
keras_accuracy.result(),
rtol=rtol,
atol=atol,
)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.5),
('high_threshold_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, 0.7),
10 changes: 10 additions & 0 deletions src/metrax/metrax_test.py
@@ -46,6 +46,11 @@
class MetraxTest(parameterized.TestCase):

@parameterized.named_parameters(
(
'accuracy',
metrax.Accuracy,
{'predictions': OUTPUT_PREDS, 'labels': OUTPUT_LABELS},
),
(
'aucpr',
metrax.AUCPR,
@@ -88,6 +93,11 @@ class MetraxTest(parameterized.TestCase):
'ks': KS,
},
),
(
'mae',
metrax.MAE,
{'predictions': OUTPUT_PREDS, 'labels': OUTPUT_LABELS},
),
(
'mse',
metrax.MSE,
3 changes: 3 additions & 0 deletions src/metrax/nnx/__init__.py
@@ -16,10 +16,12 @@

AUCPR = nnx_metrics.AUCPR
AUCROC = nnx_metrics.AUCROC
Accuracy = nnx_metrics.Accuracy
Average = nnx_metrics.Average
AveragePrecisionAtK = nnx_metrics.AveragePrecisionAtK
BLEU = nnx_metrics.BLEU
DCGAtK = nnx_metrics.DCGAtK
MAE = nnx_metrics.MAE
MRR = nnx_metrics.MRR
MSE = nnx_metrics.MSE
NDCGAtK = nnx_metrics.NDCGAtK
@@ -43,6 +45,7 @@
"BLEU",
"DCGAtK",
"MRR",
"MAE"
"MSE",
"NDCGAtK",
"Perplexity",
13 changes: 13 additions & 0 deletions src/metrax/nnx/nnx_metrics.py
@@ -32,6 +32,12 @@ def __init__(self):
super().__init__(metrax.AUCROC)


class Accuracy(NnxWrapper):
"""An NNX class for the Metrax metric Accuracy."""

def __init__(self):
super().__init__(metrax.Accuracy)


class Average(NnxWrapper):
"""An NNX class for the Metrax metric Average."""

@@ -60,6 +66,13 @@ def __init__(self):
super().__init__(metrax.DCGAtK)


class MAE(NnxWrapper):
"""An NNX class for the Metrax metric MAE."""

def __init__(self):
super().__init__(metrax.MAE)


class MRR(NnxWrapper):
"""An NNX class for the Metrax metric MRR."""

56 changes: 56 additions & 0 deletions src/metrax/regression_metrics.py
@@ -21,6 +21,62 @@
from metrax import base


@flax.struct.dataclass
class MAE(base.Average):
r"""Computes the mean absolute error for regression problems given `predictions` and `labels`.

The mean absolute error without sample weights is defined as:

.. math::
MAE = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|

When sample weights :math:`w_i` are provided, the weighted mean absolute error
is:

.. math::
MAE = \frac{\sum_{i=1}^{N} w_i|y_i - \hat{y}_i|}{\sum_{i=1}^{N} w_i}

where:
- :math:`y_i` are true values
- :math:`\hat{y}_i` are predictions
- :math:`w_i` are sample weights
- :math:`N` is the number of samples
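
A minimal usage sketch of the weighted form (the array values below are purely
illustrative)::

    import jax.numpy as jnp
    import metrax

    metric = metrax.MAE.from_model_output(
        predictions=jnp.array([2.0, 0.0, 3.0]),
        labels=jnp.array([1.0, 0.0, 5.0]),
        sample_weights=jnp.array([1.0, 1.0, 2.0]),
    )
    metric.compute()  # (1*1 + 1*0 + 2*2) / (1 + 1 + 2) = 1.25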
"""

@classmethod
def from_model_output(
cls,
predictions: jax.Array,
labels: jax.Array,
sample_weights: jax.Array | None = None,
) -> 'MAE':
"""Updates the metric.

Args:
predictions: A floating point 1D vector representing the prediction
generated from the model. The shape should be (batch_size,).
labels: True value. The shape should be (batch_size,).
sample_weights: An optional floating point 1D vector representing the
weight of each sample. The shape should be (batch_size,).

Returns:
Updated MAE metric. The shape should be a single scalar.

Raises:
ValueError: If type of `labels` is wrong or the shapes of `predictions`
and `labels` are incompatible.
"""
absolute_error = jnp.abs(predictions - labels)
count = jnp.ones_like(labels, dtype=jnp.int32)
if sample_weights is not None:
absolute_error = absolute_error * sample_weights
count = count * sample_weights
return cls(
total=absolute_error.sum(),
count=count.sum(),
)


@flax.struct.dataclass
class MSE(base.Average):
r"""Computes the mean squared error for regression problems given `predictions` and `labels`.
42 changes: 42 additions & 0 deletions src/metrax/regression_metrics_test.py
@@ -109,6 +109,48 @@ def test_rsquared_empty(self):
self.assertEqual(m.sum_of_squared_error, jnp.array(0, jnp.float32))
self.assertEqual(m.sum_of_squared_label, jnp.array(0, jnp.float32))

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_mae(self, y_true, y_pred, sample_weights):
"""Test that `MAE` Metric computes correct values."""
y_true = y_true.astype(y_pred.dtype)
if sample_weights is None:
sample_weights = np.ones_like(y_true)

metric = None
for labels, logits, weights in zip(y_true, y_pred, sample_weights):
update = metrax.MAE.from_model_output(
predictions=logits,
labels=labels,
sample_weights=weights,
)
metric = update if metric is None else metric.merge(update)

# TODO(jiwonshin): Use `keras.metrics.MeanAbsoluteError` once it supports
# sample weights.
expected = sklearn_metrics.mean_absolute_error(
y_true.flatten(),
y_pred.flatten(),
sample_weight=sample_weights.flatten(),
)
# Use a looser tolerance for lower-precision dtypes.
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
np.testing.assert_allclose(
metric.compute(),
expected,
rtol=rtol,
atol=atol,
)

@parameterized.named_parameters(
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),