Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ax/adapter/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,9 @@ def test_compute_diagnostics(self) -> None:
self.assertAlmostEqual(diag["Log likelihood"]["m2"], -25.82334285505847)
self.assertEqual(diag["MSE"]["m1"], 18.75)
self.assertEqual(diag["MSE"]["m2"], 38.75)
# Kendall tau rank correlation (NaN because y_pred is constant)
self.assertTrue(np.isnan(diag["Kendall tau rank correlation"]["m1"]))
self.assertTrue(np.isnan(diag["Kendall tau rank correlation"]["m2"]))

def test_assess_model_fit(self) -> None:
# Construct diagnostics
Expand Down
14 changes: 13 additions & 1 deletion ax/utils/stats/model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import numpy.typing as npt

from ax.utils.common.logger import get_logger
from scipy.stats import fisher_exact, norm, pearsonr, spearmanr
from scipy.stats import fisher_exact, kendalltau, norm, pearsonr, spearmanr
from sklearn.neighbors import KernelDensity


Expand All @@ -31,6 +31,7 @@
FISHER_EXACT_TEST_P = "Fisher exact test p"
LOG_LIKELIHOOD = "Log likelihood"
MSE = "MSE"
KENDALL_TAU_RANK_CORRELATION = "Kendall tau rank correlation"


class ModelFitMetricDirection(Enum):
Expand Down Expand Up @@ -277,6 +278,16 @@ def _rank_correlation(
return float(rho)


def _kendall_tau_rank_correlation(
y_obs: npt.NDArray,
y_pred: npt.NDArray,
se_pred: npt.NDArray,
) -> float:
with np.errstate(invalid="ignore"):
rho, _ = kendalltau(x=y_pred, y=y_obs)
return float(rho)


def _fisher_exact_test_p(
y_obs: npt.NDArray,
y_pred: npt.NDArray,
Expand Down Expand Up @@ -325,6 +336,7 @@ def _fisher_exact_test_p(
TOTAL_RAW_EFFECT: _total_raw_effect,
CORRELATION_COEFFICIENT: _correlation_coefficient,
RANK_CORRELATION: _rank_correlation,
KENDALL_TAU_RANK_CORRELATION: _kendall_tau_rank_correlation,
FISHER_EXACT_TEST_P: _fisher_exact_test_p,
LOG_LIKELIHOOD: _log_likelihood,
MSE: _mse,
Expand Down
43 changes: 41 additions & 2 deletions ax/utils/stats/tests/test_model_fit_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,50 @@

import numpy as np
from ax.utils.common.testutils import TestCase
from ax.utils.stats.model_fit_stats import _fisher_exact_test_p, entropy_of_observations
from scipy.stats import fisher_exact
from ax.utils.stats.model_fit_stats import (
_fisher_exact_test_p,
_kendall_tau_rank_correlation,
entropy_of_observations,
)
from scipy.stats import fisher_exact, kendalltau


class TestModelFitStats(TestCase):
def test_kendall_tau_rank_correlation(self) -> None:
# Create a dummy set of observations and predictions
y_obs = np.array([1.0, 3.0, 2.0, 5.0, 7.0, 3.0])
y_pred = np.array([2.0, 4.0, 1.0, 6.0, 8.0, 2.5])
se_pred = np.full(len(y_obs), np.nan) # not used for kendall tau

# Compute expected result using scipy
expected_tau, _ = kendalltau(x=y_pred, y=y_obs)

# Compute result using ax
ax_result = _kendall_tau_rank_correlation(y_obs, y_pred, se_pred)

self.assertEqual(expected_tau, ax_result)

def test_kendall_tau_rank_correlation_with_ties(self) -> None:
# Test with tied values
y_obs = np.array([1.0, 2.0, 2.0, 3.0, 3.0, 3.0])
y_pred = np.array([1.0, 2.0, 2.0, 3.0, 3.0, 3.0])
se_pred = np.full(len(y_obs), np.nan)

expected_tau, _ = kendalltau(x=y_pred, y=y_obs)
ax_result = _kendall_tau_rank_correlation(y_obs, y_pred, se_pred)

self.assertEqual(expected_tau, ax_result)

def test_kendall_tau_rank_correlation_perfect_negative(self) -> None:
# Test with perfectly negatively correlated data
y_obs = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = np.array([5.0, 4.0, 3.0, 2.0, 1.0])
se_pred = np.full(len(y_obs), np.nan)

ax_result = _kendall_tau_rank_correlation(y_obs, y_pred, se_pred)

self.assertAlmostEqual(ax_result, -1.0)

def test_entropy_of_observations(self) -> None:
np.random.seed(1234)
n = 16
Expand Down