Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions frouros/callbacks/batch/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class PermutationTestDistanceBased(BaseCallbackBatch):
:type num_permutations: int
:param num_jobs: number of jobs, defaults to -1
:type num_jobs: int
:param conservative: conservative flag, defaults to False. If False, the p-value can be zero `(#permuted_statistics >= observed_statistic) / num_permutations`. If True, uses the conservative approach to avoid zero p-value `((#permuted_statistics >= observed_statistic) + 1) / (num_permutations + 1)`.
:type conservative: bool
:param random_state: random state, defaults to None
:type random_state: Optional[int]
:param verbose: verbose flag, defaults to False
:type verbose: bool
:param name: name value, defaults to None. If None, the name will be set to `PermutationTestDistanceBased`.
Expand Down Expand Up @@ -49,15 +53,17 @@ def __init__( # noqa: D107
self,
num_permutations: int,
num_jobs: int = -1,
conservative: bool = False,
random_state: Optional[int] = None,
verbose: bool = False,
name: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(name=name)
self.num_permutations = num_permutations
self.num_jobs = num_jobs
self.conservative = conservative
self.random_state = random_state
self.verbose = verbose
self.permutation_kwargs = kwargs

@property
def num_permutations(self) -> int:
Expand Down Expand Up @@ -101,6 +107,27 @@ def num_jobs(self, value: int) -> None:
raise ValueError("value must be greater than 0 or -1.")
self._num_jobs = multiprocessing.cpu_count() if value == -1 else value

@property
def conservative(self) -> bool:
"""Conservative (avoid zero p-value) flag property.

:return: conservative flag
:rtype: bool
"""
return self._conservative

@conservative.setter
def conservative(self, value: bool) -> None:
"""Conservative (avoid zero p-value) flag setter.

:param value: value to be set
:type value: bool
:raises TypeError: Type error exception
"""
if not isinstance(value, bool):
raise TypeError("value must of type bool.")
self._conservative = value

@property
def verbose(self) -> bool:
"""Verbose flag property.
Expand Down Expand Up @@ -131,7 +158,8 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
observed_statistic: float,
num_permutations: int,
num_jobs: int,
random_state: int,
conservative: bool,
random_state: Optional[int],
verbose: bool,
) -> Tuple[List[float], float]:
permuted_statistic = permutation(
Expand All @@ -145,7 +173,12 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
verbose=verbose,
)
permuted_statistic = np.array(permuted_statistic)
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
p_value = (
((permuted_statistic >= observed_statistic).sum() + 1) # type: ignore
/ (num_permutations + 1)
if conservative
else (permuted_statistic >= observed_statistic).mean() # type: ignore
)
return permuted_statistic, p_value

def on_compare_end(
Expand All @@ -172,8 +205,9 @@ def on_compare_end(
observed_statistic=observed_statistic,
num_permutations=self.num_permutations,
num_jobs=self.num_jobs,
conservative=self.conservative,
random_state=self.random_state,
verbose=self.verbose,
**self.permutation_kwargs,
)
self.logs.update(
{
Expand Down
36 changes: 35 additions & 1 deletion frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_batch_permutation_test_data_univariate_different_distribution(
expected_distance: float,
expected_p_value: float,
) -> None:
"""Test batch permutation test on data callback.
"""Test batch permutation test on data drift callback.

:param X_ref_univariate: reference univariate data
:type X_ref_univariate: numpy.ndarray
Expand Down Expand Up @@ -101,6 +101,40 @@ def test_batch_permutation_test_data_univariate_different_distribution(
)


def test_batch_permutation_test_conservative(
X_ref_univariate: np.ndarray, # noqa: N803
X_test_univariate: np.ndarray,
) -> None:
"""Test batch permutation test on data drift callback using conservative flag.

:param X_ref_univariate: reference univariate data
:type X_ref_univariate: numpy.ndarray
:param X_test_univariate: test univariate data
:type X_test_univariate: numpy.ndarray
"""
np.random.seed(seed=31)

permutation_test_name = "permutation_test"
detector = MMD( # type: ignore
callbacks=[
PermutationTestDistanceBased(
num_permutations=100,
conservative=True,
random_state=31,
num_jobs=-1,
name=permutation_test_name,
)
]
)
_ = detector.fit(X=X_ref_univariate)
_, callback_logs = detector.compare(X=X_test_univariate)

assert np.isclose(
callback_logs[permutation_test_name]["p_value"],
0.00990099,
)


@pytest.mark.parametrize(
"detector_class",
[AndersonDarlingTest, CVMTest, KSTest, MannWhitneyUTest, WelchTTest],
Expand Down