Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions frouros/detectors/data_drift/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def _specific_checks(self, X: np.ndarray) -> None: # noqa: N803

@abc.abstractmethod
def _apply_method(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
self,
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> Any:
pass

Expand All @@ -110,9 +113,13 @@ def _compare(
pass

def _get_result(
self, X: np.ndarray, **kwargs # noqa: N803
self,
X: np.ndarray, # noqa: N803
**kwargs,
) -> Union[List[float], List[Tuple[float, float]], Tuple[float, float]]:
result = self._apply_method( # type: ignore # pylint: disable=not-callable
X_ref=self.X_ref, X=X, **kwargs
X_ref=self.X_ref,
X=X,
**kwargs,
)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@ class AndersonDarlingTest(BaseStatisticalTest):

:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
:param kwargs: additional keyword arguments to pass to scipy.stats.anderson_ksamp
:type kwargs: Dict[str, Any]

:Note:
p-values are bounded between 0.001 and 0.25 according to scipy documentation [1]_.
- Passing additional arguments to `scipy.stats.anderson_ksamp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.anderson_ksamp.html>`__ can be done using :func:`compare` kwargs.
- p-values are bounded between 0.001 and 0.25 according to `scipy.stats.anderson_ksamp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.anderson_ksamp.html>`__.

:References:

.. [scholz1987k] Scholz, Fritz W., and Michael A. Stephens.
"K-sample Anderson–Darling tests."
Journal of the American Statistical Association 82.399 (1987): 918-924.
[1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.anderson_ksamp.html # noqa: E501 # pylint: disable=line-too-long

:Example:

Expand All @@ -42,29 +40,30 @@ class AndersonDarlingTest(BaseStatisticalTest):
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)[0]
StatisticalResult(statistic=32.40316586267425, p_value=0.001)
"""
""" # noqa: E501 # pylint: disable=line-too-long

def __init__( # noqa: D107
self,
callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
**kwargs,
) -> None:
super().__init__(
data_type=NumericalData(),
statistical_type=UnivariateData(),
callbacks=callbacks,
)
self.kwargs = kwargs

@staticmethod
def _statistical_test(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> StatisticalResult:
test = anderson_ksamp(
samples=[
X_ref,
X,
],
**self.kwargs,
**kwargs,
)
test = StatisticalResult(
statistic=test.statistic,
Expand Down
18 changes: 14 additions & 4 deletions frouros/detectors/data_drift/batch/statistical_test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@ class BaseStatisticalTest(BaseDataDriftBatch):
"""Abstract class representing a statistical test."""

def _apply_method(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
self,
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> Tuple[float, float]:
statistical_test = self._statistical_test(X_ref=X_ref, X=X, **kwargs)
statistical_test = self._statistical_test(
X_ref=X_ref,
X=X,
**kwargs,
)
return statistical_test

def _compare(
Expand All @@ -30,8 +37,11 @@ def _compare(
result = self._get_result(X=X, **kwargs)
return result # type: ignore

@staticmethod
@abc.abstractmethod
def _statistical_test(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
) -> Tuple[float, float]:
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> StatisticalResult:
pass
30 changes: 20 additions & 10 deletions frouros/detectors/data_drift/batch/statistical_test/chisquare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ class ChiSquareTest(BaseStatisticalTest):

:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
:param kwargs: additional keyword arguments to pass to scipy.stats.chi2_contingency
:type kwargs: Dict[str, Any]

:Note:
- Passing additional arguments to `scipy.stats.chi2_contingency <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chi2_contingency.html>`__ can be done using :func:`compare` kwargs.

:References:

Expand All @@ -42,34 +43,43 @@ class ChiSquareTest(BaseStatisticalTest):
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)[0]
StatisticalResult(statistic=9.81474665685192, p_value=0.0017311812135839511)
"""
""" # noqa: E501 # pylint: disable=line-too-long

def __init__( # noqa: D107
self,
callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
**kwargs,
) -> None:
super().__init__(
data_type=CategoricalData(),
statistical_type=UnivariateData(),
callbacks=callbacks,
)
self.kwargs = kwargs

@staticmethod
def _statistical_test(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> StatisticalResult:
f_exp, f_obs = self._calculate_frequencies(X_ref=X_ref, X=X)
f_exp, f_obs = ChiSquareTest._calculate_frequencies(
X_ref=X_ref,
X=X,
)
statistic, p_value, _, _ = chi2_contingency(
observed=np.array([f_obs, f_exp]), **self.kwargs
observed=np.array([f_obs, f_exp]),
**kwargs,
)

test = StatisticalResult(statistic=statistic, p_value=p_value)
test = StatisticalResult(
statistic=statistic,
p_value=p_value,
)
return test

@staticmethod
def _calculate_frequencies(
X_ref: np.ndarray, X: np.ndarray # noqa: N803
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
) -> Tuple[List[int], List[int]]:
X_ref_counter, X_counter = [ # noqa: N806
*map(collections.Counter, [X_ref, X]) # noqa: N806
Expand Down
21 changes: 13 additions & 8 deletions frouros/detectors/data_drift/batch/statistical_test/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ class CVMTest(BaseStatisticalTest):

:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
:param kwargs: additional keyword arguments to pass to scipy.stats.cramervonmises_2samp
:type kwargs: Dict[str, Any]

:Note:
- Passing additional arguments to `scipy.stats.cramervonmises_2samp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.cramervonmises_2samp.html>`__ can be done using :func:`compare` kwargs.

:References:

Expand All @@ -39,19 +40,17 @@ class CVMTest(BaseStatisticalTest):
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)[0]
StatisticalResult(statistic=5.331699999999998, p_value=1.7705426014202885e-10)
""" # noqa: E501
""" # noqa: E501 # pylint: disable=line-too-long

def __init__( # noqa: D107
self,
callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
**kwargs,
) -> None:
super().__init__(
data_type=NumericalData(),
statistical_type=UnivariateData(),
callbacks=callbacks,
)
self.kwargs = kwargs

@BaseStatisticalTest.X_ref.setter # type: ignore[attr-defined]
def X_ref(self, value: Optional[np.ndarray]) -> None: # noqa: N802
Expand All @@ -75,13 +74,19 @@ def _check_sufficient_samples(X: np.ndarray) -> None: # noqa: N803
if X.shape[0] < 2:
raise InsufficientSamplesError("Number of samples must be at least 2.")

@staticmethod
def _statistical_test(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> StatisticalResult:
test = cramervonmises_2samp(
x=X_ref,
y=X,
method=self.kwargs.get("method", "auto"),
**kwargs,
)
test = StatisticalResult(
statistic=test.statistic,
p_value=test.pvalue,
)
test = StatisticalResult(statistic=test.statistic, p_value=test.pvalue)
return test
23 changes: 14 additions & 9 deletions frouros/detectors/data_drift/batch/statistical_test/ks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class KSTest(BaseStatisticalTest):

:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
:param kwargs: additional keyword arguments to pass to scipy.stats.ks_2samp
:type kwargs: Dict[str, Any]

:Note:
- Only the ``alternative`` and ``method`` arguments of `scipy.stats.ks_2samp <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ks_2samp.html>`__ can be set using :func:`compare` kwargs; any other keyword argument is not forwarded to the test.

:References:

Expand All @@ -38,28 +39,32 @@ class KSTest(BaseStatisticalTest):
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)[0]
StatisticalResult(statistic=0.55, p_value=3.0406585087050305e-14)
"""
""" # noqa: E501 # pylint: disable=line-too-long

def __init__( # noqa: D107
self,
callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
**kwargs,
) -> None:
super().__init__(
data_type=NumericalData(),
statistical_type=UnivariateData(),
callbacks=callbacks,
)
self.kwargs = kwargs

@staticmethod
def _statistical_test(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> StatisticalResult:
test = ks_2samp(
data1=X_ref,
data2=X,
alternative=self.kwargs.get("alternative", "two-sided"),
method=self.kwargs.get("method", "auto"),
alternative=kwargs.get("alternative", "two-sided"),
method=kwargs.get("method", "auto"),
)
test = StatisticalResult(
statistic=test.statistic,
p_value=test.pvalue,
)
test = StatisticalResult(statistic=test.statistic, p_value=test.pvalue)
return test
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class MannWhitneyUTest(BaseStatisticalTest):

:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
:param kwargs: additional keyword arguments to pass to scipy.stats.mannwhitneyu
:type kwargs: Dict[str, Any]

:Note:
- Passing additional arguments to `scipy.stats.mannwhitneyu <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.mannwhitneyu.html>`__ can be done using :func:`compare` kwargs.

:References:

Expand All @@ -39,29 +40,30 @@ class MannWhitneyUTest(BaseStatisticalTest):
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)[0]
StatisticalResult(statistic=2139.0, p_value=2.7623373527697943e-12)
"""
""" # noqa: E501 # pylint: disable=line-too-long

def __init__( # noqa: D107
self,
callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
**kwargs,
) -> None:
super().__init__(
data_type=NumericalData(),
statistical_type=UnivariateData(),
callbacks=callbacks,
)
self.kwargs = kwargs

@staticmethod
def _statistical_test(
self, X_ref: np.ndarray, X: np.ndarray, **kwargs # noqa: N803
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
) -> StatisticalResult:
test = mannwhitneyu( # pylint: disable=unexpected-keyword-arg
x=X_ref,
y=X,
alternative="two-sided",
nan_policy="raise",
**self.kwargs,
alternative=kwargs.get("alternative", "two-sided"),
nan_policy=kwargs.get("nan_policy", "raise"),
**kwargs,
)
test = StatisticalResult(
statistic=test.statistic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class WelchTTest(BaseStatisticalTest):

:param callbacks: callbacks, defaults to None
:type callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]]
:param kwargs: additional keyword arguments to pass to scipy.stats.ttest_ind
:type kwargs: Dict[str, Any]

:Note:
- Passing additional arguments to `scipy.stats.ttest_ind <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ttest_ind.html>`__ can be done using :func:`compare` kwargs.

:References:

Expand All @@ -39,22 +40,20 @@ class WelchTTest(BaseStatisticalTest):
>>> _ = detector.fit(X=X)
>>> detector.compare(X=Y)[0]
StatisticalResult(statistic=-7.651304662806378, p_value=8.685225410826823e-13)
"""
""" # noqa: E501 # pylint: disable=line-too-long

def __init__( # noqa: D107
self,
callbacks: Optional[Union[BaseCallbackBatch, List[BaseCallbackBatch]]] = None,
**kwargs,
) -> None:
super().__init__(
data_type=NumericalData(),
statistical_type=UnivariateData(),
callbacks=callbacks,
)
self.kwargs = kwargs

@staticmethod
def _statistical_test(
self,
X_ref: np.ndarray, # noqa: N803
X: np.ndarray,
**kwargs,
Expand All @@ -63,8 +62,8 @@ def _statistical_test(
a=X_ref,
b=X,
equal_var=False,
alternative="two-sided",
**self.kwargs,
alternative=kwargs.get("alternative", "two-sided"),
**kwargs,
)
test = StatisticalResult(
statistic=test.statistic,
Expand Down