Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 63 additions & 63 deletions docs/source/examples/data_drift/MMD_advance.ipynb

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions frouros/callbacks/batch/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np # type: ignore
from scipy.stats import norm # type: ignore

from frouros.callbacks.batch.base import BatchCallback
from frouros.utils.stats import permutation
from frouros.utils.stats import permutation, z_score


class PermutationTestOnBatchData(BatchCallback):
Expand Down Expand Up @@ -120,7 +121,14 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
random_state=random_state,
verbose=verbose,
)
p_value = (observed_statistic < permuted_statistic).mean() # type: ignore
permuted_statistic = np.array(permuted_statistic)
# Use z-score to calculate p-value
observed_z_score = z_score(
value=observed_statistic,
mean=permuted_statistic.mean(), # type: ignore
std=permuted_statistic.std(), # type: ignore
)
p_value = norm.sf(np.abs(observed_z_score)) * 2
return permuted_statistic, p_value

def on_compare_end(self, **kwargs) -> None:
Expand Down
20 changes: 12 additions & 8 deletions frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@
"detector_class, expected_distance, expected_p_value",
[
(BhattacharyyaDistance, 0.55516059, 0.0),
(EMD, 3.85346006, 0.0),
(HellingerDistance, 0.74509099, 0.0),
(HINormalizedComplement, 0.78, 0.0),
(JS, 0.67010107, 0.0),
(KL, np.inf, 0.0),
(MMD, 0.69509004, 0.0),
(PSI, 461.20379435, 0.0),
(EMD, 3.85346006, 9.21632493e-101),
(HellingerDistance, 0.74509099, 3.13808126e-50),
(HINormalizedComplement, 0.78, 1.31340683e-55),
(JS, 0.67010107, 2.30485343e-63),
(KL, np.inf, np.nan),
(MMD, 0.69509004, 2.53277069e-137),
(PSI, 461.20379435, 4.45088795e-238),
],
)
def test_batch_permutation_test_data_univariate_different_distribution(
Expand Down Expand Up @@ -100,7 +100,11 @@ def test_batch_permutation_test_data_univariate_different_distribution(
distance, callback_logs = detector.compare(X=X_test_univariate)

assert np.isclose(distance, expected_distance)
assert np.isclose(callback_logs[permutation_test_name]["p_value"], expected_p_value)
assert np.isclose(
callback_logs[permutation_test_name]["p_value"],
expected_p_value,
equal_nan=True,
)


@pytest.mark.parametrize(
Expand Down
20 changes: 19 additions & 1 deletion frouros/utils/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,5 +262,23 @@ def permutation( # pylint: disable=too-many-arguments,too-many-locals
iterable=tqdm(permuted_data) if verbose else permuted_data,
).get()

# FIXME: explore if abs must be used in permuted_statistic # pylint: disable=fixme
return permuted_statistics


def z_score(
    value: np.ndarray,
    mean: float,
    std: float,
) -> np.ndarray:
    """Z-score method.

    Computes ``(value - mean) / std`` element-wise. The division follows
    NumPy floating-point rules for every input type: a zero ``std`` yields
    ``inf`` or ``nan`` instead of raising ``ZeroDivisionError``, and
    non-finite inputs (``inf``/``nan``) propagate to the result.

    :param value: value to use to compute the z-score
    :type value: np.ndarray
    :param mean: mean value
    :type mean: float
    :param std: standard deviation value
    :type std: float
    :return: z-score
    :rtype: np.ndarray
    """
    # Use np.divide instead of the ``/`` operator: with plain Python floats
    # ``/`` raises ZeroDivisionError when std == 0, whereas NumPy scalars
    # return inf/nan. np.divide gives the same result in both cases and
    # always returns a NumPy scalar/array.
    return np.divide(value - mean, std)