Skip to content

Commit 0773892

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #306 from IFCA-Advanced-Computing/feature-bws-test
Add Baumgartner-Weiss-Schindler test data drift method
2 parents 622949f + 786a01c commit 0773892

File tree

9 files changed

+128
-20
lines changed

9 files changed

+128
-20
lines changed

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,8 @@ The currently implemented detectors are listed in the following table.
335335
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1007/978-3-540-75488-6_27">Nishida and Yamauchi (2007)</a></td>
336336
</tr>
337337
<tr>
338-
<td rowspan="17" style="text-align: center; border: 1px solid grey; padding: 8px;">Data drift</td>
339-
<td rowspan="15" style="text-align: center; border: 1px solid grey; padding: 8px;">Batch</td>
338+
<td rowspan="18" style="text-align: center; border: 1px solid grey; padding: 8px;">Data drift</td>
339+
<td rowspan="16" style="text-align: center; border: 1px solid grey; padding: 8px;">Batch</td>
340340
<td rowspan="10" style="text-align: center; border: 1px solid grey; padding: 8px;">Distance based</td>
341341
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
342342
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
@@ -398,12 +398,18 @@ The currently implemented detectors are listed in the following table.
398398
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1057/jors.2008.144">Wu and Olson (2010)</a></td>
399399
</tr>
400400
<tr>
401-
<td rowspan="5" style="text-align: center; border: 1px solid grey; padding: 8px;">Statistical test</td>
401+
<td rowspan="6" style="text-align: center; border: 1px solid grey; padding: 8px;">Statistical test</td>
402402
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
403403
<td style="text-align: center; border: 1px solid grey; padding: 8px;">C</td>
404404
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Chi-square test</td>
405405
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.1080/14786440009463897">Pearson (1900)</a></td>
406406
</tr>
407+
<tr>
408+
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
409+
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>
410+
<td style="text-align: center; border: 1px solid grey; padding: 8px;">Baumgartner-Weiss-Schindler test</td>
411+
<td style="text-align: center; border: 1px solid grey; padding: 8px;"><a href="https://doi.org/10.2307/2533862">Baumgartner et al. (1998)</a></td>
412+
</tr>
407413
<tr>
408414
<td style="text-align: center; border: 1px solid grey; padding: 8px;">U</td>
409415
<td style="text-align: center; border: 1px solid grey; padding: 8px;">N</td>

docs/source/api_reference/detectors/data_drift/batch.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ The {mod}`frouros.detectors.data_drift.batch` module contains batch data drift d
4949
:template: class.md
5050
5151
AndersonDarlingTest
52+
BWSTest
5253
ChiSquareTest
5354
CVMTest
5455
KSTest

frouros/detectors/data_drift/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .batch import ( # noqa: F401
44
AndersonDarlingTest,
55
BhattacharyyaDistance,
6+
BWSTest,
67
ChiSquareTest,
78
CVMTest,
89
EMD,

frouros/detectors/data_drift/batch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from .statistical_test import (
1515
AndersonDarlingTest,
16+
BWSTest,
1617
ChiSquareTest,
1718
CVMTest,
1819
KSTest,
@@ -22,6 +23,7 @@
2223

2324
__all__ = [
2425
"AndersonDarlingTest",
26+
"BWSTest",
2527
"BhattacharyyaDistance",
2628
"ChiSquareTest",
2729
"CVMTest",

frouros/detectors/data_drift/batch/statistical_test/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Data drift batch statistical test detection methods' init."""
22

33
from .anderson_darling import AndersonDarlingTest
4+
from .bws import BWSTest
45
from .chisquare import ChiSquareTest
56
from .cvm import CVMTest
67
from .ks import KSTest
@@ -9,6 +10,7 @@
910

1011
__all__ = [
1112
"AndersonDarlingTest",
13+
"BWSTest",
1214
"ChiSquareTest",
1315
"CVMTest",
1416
"KSTest",
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""BWSTest (Baumgartner-Weiss-Schindler test) module."""
2+
3+
from typing import Optional, Union
4+
5+
import numpy as np # type: ignore
6+
from scipy.stats import bws_test # type: ignore
7+
8+
from frouros.callbacks.batch.base import BaseCallbackBatch
9+
from frouros.detectors.data_drift.base import NumericalData, UnivariateData
10+
from frouros.detectors.data_drift.batch.statistical_test.base import (
11+
BaseStatisticalTest,
12+
StatisticalResult,
13+
)
14+
15+
16+
class BWSTest(BaseStatisticalTest):
17+
"""BWSTest (Baumgartner-Weiss-Schindler test) [baumgartner1998nonparametric]_ detector.
18+
19+
:param callbacks: callbacks, defaults to None
20+
:type callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]]
21+
22+
:Note:
23+
- Passing additional arguments to `scipy.stats.bws_test <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.bws_test.html>`__ can be done using :func:`compare` kwargs.
24+
25+
:References:
26+
27+
.. [baumgartner1998nonparametric] Baumgartner, W., P. Weiß, and H. Schindler.
28+
"A nonparametric test for the general two-sample problem."
29+
Biometrics (1998): 1129-1135.
30+
31+
:Example:
32+
33+
>>> from frouros.detectors.data_drift import BWSTest
34+
>>> import numpy as np
35+
>>> np.random.seed(seed=31)
36+
>>> X = np.random.normal(loc=0, scale=1, size=100)
37+
>>> Y = np.random.normal(loc=1, scale=1, size=100)
38+
>>> detector = BWSTest()
39+
>>> _ = detector.fit(X=X)
40+
>>> detector.compare(X=Y)[0]
41+
StatisticalResult(statistic=29.942072035675395, p_value=0.0001)
42+
""" # noqa: E501 # pylint: disable=line-too-long
43+
44+
def __init__( # noqa: D107
45+
self,
46+
callbacks: Optional[Union[BaseCallbackBatch, list[BaseCallbackBatch]]] = None,
47+
) -> None:
48+
super().__init__(
49+
data_type=NumericalData(),
50+
statistical_type=UnivariateData(),
51+
callbacks=callbacks,
52+
)
53+
54+
@staticmethod
55+
def _statistical_test(
56+
X_ref: np.ndarray, # noqa: N803
57+
X: np.ndarray,
58+
**kwargs,
59+
) -> StatisticalResult:
60+
test = bws_test(
61+
x=X_ref,
62+
y=X,
63+
alternative=kwargs.get("alternative", "two-sided"),
64+
method=kwargs.get("method", None),
65+
)
66+
test = StatisticalResult(
67+
statistic=test.statistic,
68+
p_value=test.pvalue,
69+
)
70+
return test

frouros/tests/integration/test_callback.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from frouros.detectors.concept_drift.base import BaseConceptDrift
2828
from frouros.detectors.data_drift.batch import (
2929
AndersonDarlingTest,
30+
BWSTest,
3031
BhattacharyyaDistance,
3132
CVMTest,
3233
EMD,
@@ -137,7 +138,14 @@ def test_batch_permutation_test_conservative(
137138

138139
@pytest.mark.parametrize(
139140
"detector_class",
140-
[AndersonDarlingTest, CVMTest, KSTest, MannWhitneyUTest, WelchTTest],
141+
[
142+
AndersonDarlingTest,
143+
BWSTest,
144+
CVMTest,
145+
KSTest,
146+
MannWhitneyUTest,
147+
WelchTTest,
148+
],
141149
)
142150
def test_batch_reset_on_statistical_test_data_drift(
143151
X_ref_univariate, # noqa: N803
@@ -153,6 +161,8 @@ def test_batch_reset_on_statistical_test_data_drift(
153161
:type X_test_univariate: numpy.ndarray
154162
:param detector_class: detector distance
155163
:type detector_class: BaseDataDriftBatch
164+
:param mocker: mocker
165+
:type mocker: pytest_mock.mocker
156166
"""
157167
mocker.patch("frouros.detectors.data_drift.batch.base.BaseDataDriftBatch.reset")
158168

frouros/tests/integration/test_data_drift.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,17 @@
44

55
import numpy as np # type: ignore
66
import pytest # type: ignore
7+
from scipy.stats import PermutationMethod # type: ignore
78

9+
from frouros.detectors.data_drift.batch import (
10+
AndersonDarlingTest,
11+
BWSTest,
12+
ChiSquareTest,
13+
CVMTest,
14+
KSTest,
15+
MannWhitneyUTest,
16+
WelchTTest,
17+
)
818
from frouros.detectors.data_drift.batch import (
919
BhattacharyyaDistance,
1020
EMD,
@@ -16,14 +26,6 @@
1626
KL,
1727
MMD,
1828
)
19-
from frouros.detectors.data_drift.batch import (
20-
AndersonDarlingTest,
21-
ChiSquareTest,
22-
CVMTest,
23-
KSTest,
24-
MannWhitneyUTest,
25-
WelchTTest,
26-
)
2729
from frouros.detectors.data_drift.batch.base import BaseDataDriftBatch
2830
from frouros.detectors.data_drift.streaming import ( # noqa: N811
2931
IncrementalKSTest,
@@ -161,20 +163,32 @@ def test_batch_distance_bins_based_univariate_same_distribution(
161163

162164

163165
@pytest.mark.parametrize(
164-
"detector, expected_statistic, expected_p_value",
166+
"detector, expected_statistic, expected_p_value, kwargs",
165167
[
166-
(AndersonDarlingTest(), 23171.19994366, 0.001),
167-
(CVMTest(), 3776.09848103, 5.38105056e-07),
168-
(KSTest(), 0.99576271, 0.0),
169-
(MannWhitneyUTest(), 6912.0, 0.0),
170-
(WelchTTest(), -287.92032554, 0.0),
168+
(AndersonDarlingTest(), 23171.19994366, 0.001, {}),
169+
(
170+
BWSTest(),
171+
108757.63520694,
172+
0.00990099,
173+
{
174+
"method": PermutationMethod(
175+
n_resamples=100,
176+
random_state=31,
177+
),
178+
},
179+
),
180+
(CVMTest(), 3776.09848103, 5.38105056e-07, {}),
181+
(KSTest(), 0.99576271, 0.0, {}),
182+
(MannWhitneyUTest(), 6912.0, 0.0, {}),
183+
(WelchTTest(), -287.92032554, 0.0, {}),
171184
],
172185
)
173186
def test_batch_statistical_univariate(
174187
elec2_dataset: Tuple[np.ndarray, np.ndarray, np.ndarray],
175188
detector: BaseDataDriftBatch,
176189
expected_statistic: float,
177190
expected_p_value: float,
191+
kwargs: dict,
178192
) -> None:
179193
"""Test statistical univariate method.
180194
@@ -186,11 +200,13 @@ def test_batch_statistical_univariate(
186200
:type expected_statistic: float
187201
:param expected_p_value: expected p-value value
188202
:type expected_p_value: float
203+
:param kwargs: additional arguments
204+
:type kwargs: dict
189205
"""
190206
X_ref, _, X_test = elec2_dataset # noqa: N806
191207

192208
_ = detector.fit(X=X_ref[:, 0])
193-
(statistic, p_value), _ = detector.compare(X=X_test[:, 0])
209+
(statistic, p_value), _ = detector.compare(X=X_test[:, 0], **kwargs)
194210

195211
assert np.isclose(statistic, expected_statistic)
196212
assert np.isclose(p_value, expected_p_value)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dependencies = [
3939
"matplotlib>=3.8.2,<3.9",
4040
"numpy>=1.26.3,<1.27",
4141
"requests>=2.31.0,<2.32",
42-
"scipy>=1.11.4,<1.13",
42+
"scipy>=1.12.0,<1.13",
4343
"tqdm>=4.66.1,<5.0",
4444
]
4545

0 commit comments

Comments
 (0)