Skip to content

Fix callbacks classes #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 7, 2023
4 changes: 2 additions & 2 deletions docs/source/api_reference/callbacks/batch.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@
:toctree: auto_generated/
:template: class.md

PermutationTestOnBatchData
ResetOnBatchDataDrift
PermutationTestDistanceBased
ResetStatisticalTest
```
2 changes: 1 addition & 1 deletion docs/source/api_reference/callbacks/streaming.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
:toctree: auto_generated/
:template: class.md

History
HistoryConceptDrift
mSPRT
WarningSamplesBuffer
```
4 changes: 2 additions & 2 deletions docs/source/examples/data_drift/MMD_advance.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"from torch.utils.data import Dataset\n",
"import torchvision\n",
"\n",
"from frouros.callbacks import PermutationTestOnBatchData\n",
"from frouros.callbacks import PermutationTestDistanceBased\n",
"from frouros.detectors.data_drift import MMD"
],
"metadata": {
Expand Down Expand Up @@ -877,7 +877,7 @@
"\n",
"detector = MMD(\n",
" callbacks=[\n",
" PermutationTestOnBatchData(\n",
" PermutationTestDistanceBased(\n",
" num_permutations=1000,\n",
" random_state=seed,\n",
" num_jobs=-1,\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/source/examples/data_drift/MMD_simple.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"from scipy.spatial.distance import pdist\n",
"from scipy.stats import multivariate_normal\n",
"\n",
"from frouros.callbacks import PermutationTestOnBatchData\n",
"from frouros.callbacks import PermutationTestDistanceBased\n",
"from frouros.detectors.data_drift import MMD"
],
"metadata": {
Expand Down Expand Up @@ -264,7 +264,7 @@
"source": [
"detector = MMD(\n",
" callbacks=[\n",
" PermutationTestOnBatchData(\n",
" PermutationTestDistanceBased(\n",
" num_permutations=1000,\n",
" random_state=seed,\n",
" num_jobs=-1,\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/source/examples/data_drift/multivariate_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"from sklearn.metrics import accuracy_score\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"from frouros.callbacks.batch import PermutationTestOnBatchData\n",
"from frouros.callbacks.batch import PermutationTestDistanceBased\n",
"from frouros.detectors.data_drift import MMD"
]
},
Expand Down Expand Up @@ -163,7 +163,7 @@
"source": [
"detector = MMD(\n",
" callbacks=[\n",
" PermutationTestOnBatchData(\n",
" PermutationTestDistanceBased(\n",
" num_permutations=1000,\n",
" random_state=31,\n",
" num_jobs=-1,\n",
Expand Down
10 changes: 5 additions & 5 deletions frouros/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Callbacks init."""

from .batch import PermutationTestOnBatchData, ResetOnBatchDataDrift
from .streaming import History, mSPRT, WarningSamplesBuffer
from .batch import PermutationTestDistanceBased, ResetStatisticalTest
from .streaming import HistoryConceptDrift, mSPRT, WarningSamplesBuffer

__all__ = [
"History",
"HistoryConceptDrift",
"mSPRT",
"PermutationTestOnBatchData",
"ResetOnBatchDataDrift",
"PermutationTestDistanceBased",
"ResetStatisticalTest",
"WarningSamplesBuffer",
]
3 changes: 0 additions & 3 deletions frouros/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ def on_fit_start(self, **kwargs) -> None:
def on_fit_end(self, **kwargs) -> None:
"""On fit end method."""

def on_drift_detected(self, **kwargs) -> None:
"""On drift detected method."""

@abc.abstractmethod
def reset(self) -> None:
"""Reset method."""
Expand Down
8 changes: 4 additions & 4 deletions frouros/callbacks/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Batch callbacks init."""

from .permutation_test import PermutationTestOnBatchData
from .reset_drift import ResetOnBatchDataDrift
from .permutation_test import PermutationTestDistanceBased
from .reset import ResetStatisticalTest

__all__ = [
"PermutationTestOnBatchData",
"ResetOnBatchDataDrift",
"PermutationTestDistanceBased",
"ResetStatisticalTest",
]
6 changes: 3 additions & 3 deletions frouros/callbacks/batch/permutation_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Permutation test on batch data callback module."""
"""Permutation test batch callback module."""

import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple
Expand All @@ -10,8 +10,8 @@
from frouros.utils.stats import permutation, z_score


class PermutationTestOnBatchData(BaseCallbackBatch):
"""Permutation test on batch data callback class."""
class PermutationTestDistanceBased(BaseCallbackBatch):
"""Permutation test on distance based batch callback class."""

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Reset on batch data drift callback module."""
"""Reset batch callback module."""

from typing import Optional

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.utils.logger import logger


class ResetOnBatchDataDrift(BaseCallbackBatch):
"""Reset on batch data drift callback class."""
class ResetStatisticalTest(BaseCallbackBatch):
"""Reset on statistical test batch callback class."""

def __init__(self, alpha: float, name: Optional[str] = None) -> None:
"""Init method.
Expand Down Expand Up @@ -44,7 +45,7 @@ def on_compare_end(self, **kwargs) -> None:
"""On compare end method."""
p_value = kwargs["result"].p_value
if p_value < self.alpha:
print("Drift detected. Resetting detector.")
logger.info("Drift detected. Resetting detector...")
self.detector.reset() # type: ignore

# FIXME: set_detector method as a workaround to # pylint: disable=fixme
Expand Down
4 changes: 2 additions & 2 deletions frouros/callbacks/streaming/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Streaming callbacks init."""

from .history import History
from .history import HistoryConceptDrift
from .msprt import mSPRT
from .warning_samples import WarningSamplesBuffer

__all__ = [
"History",
"HistoryConceptDrift",
"mSPRT",
"WarningSamplesBuffer",
]
4 changes: 2 additions & 2 deletions frouros/callbacks/streaming/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from frouros.utils.stats import BaseStat


class History(BaseCallbackStreaming):
"""History callback class."""
class HistoryConceptDrift(BaseCallbackStreaming):
"""HistoryConceptDrift callback class."""

def __init__(self, name: Optional[str] = None) -> None:
"""Init method.
Expand Down
14 changes: 10 additions & 4 deletions frouros/detectors/concept_drift/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
from typing import Any, Dict, List, Optional, Union

from frouros.callbacks import History
from frouros.callbacks import HistoryConceptDrift
from frouros.callbacks.streaming.base import BaseCallbackStreaming
from frouros.detectors.base import BaseDetector
from frouros.utils.checks import check_callbacks
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(

def _set_additional_vars_callback(self) -> None:
for callback in self.callbacks: # type: ignore
if isinstance(callback, History):
if isinstance(callback, HistoryConceptDrift):
# callback.set_detector(detector=self)
callback.add_additional_vars(
vars_=self.additional_vars.keys(), # type: ignore
Expand Down Expand Up @@ -186,10 +186,16 @@ def update(self, value: Union[int, float], **kwargs) -> Dict[str, Any]:
:type value: Union[int, float]
"""
for callback in self.callbacks: # type: ignore
callback.on_update_start(value=value, **kwargs) # type: ignore
callback.on_update_start( # type: ignore
value=value,
**kwargs,
)
self._update(value=value, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_update_end(value=value, **kwargs) # type: ignore
callback.on_update_end( # type: ignore
value=value,
**kwargs,
)

callbacks_logs = self._get_callbacks_logs()
return callbacks_logs
Expand Down
10 changes: 8 additions & 2 deletions frouros/detectors/data_drift/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,16 @@ def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
"""
self._check_fit_dimensions(X=X)
for callback in self.callbacks: # type: ignore
callback.on_fit_start()
callback.on_fit_start(
X=X,
**kwargs,
)
self._fit(X=X, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_fit_end(X=X, **kwargs)
callback.on_fit_end(
X=X,
**kwargs,
)

logs = self._get_callbacks_logs()
return logs
Expand Down
7 changes: 5 additions & 2 deletions frouros/detectors/data_drift/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ def compare(
:param X: feature data
:type X: numpy.ndarray
:return: compare result and callbacks logs
:rtype: Tuple[np.ndarray, Dict[str, Any]]
:rtype: Tuple[numpy.ndarray, Dict[str, Any]]
"""
for callback in self.callbacks: # type: ignore
callback.on_compare_start() # type: ignore
callback.on_compare_start( # type: ignore
X_ref=self.X_ref,
X_test=X,
)
result = self._compare(X=X, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_compare_end( # type: ignore
Expand Down
16 changes: 8 additions & 8 deletions frouros/tests/integration/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import sklearn # type: ignore # pylint: disable=import-error

from frouros.callbacks.batch import (
PermutationTestOnBatchData,
ResetOnBatchDataDrift,
PermutationTestDistanceBased,
ResetStatisticalTest,
)
from frouros.callbacks.streaming import (
History,
HistoryConceptDrift,
mSPRT,
WarningSamplesBuffer,
)
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_batch_permutation_test_data_univariate_different_distribution(
permutation_test_name = "permutation_test"
detector = detector_class( # type: ignore
callbacks=[
PermutationTestOnBatchData(
PermutationTestDistanceBased(
num_permutations=100,
random_state=31,
num_jobs=-1,
Expand All @@ -111,13 +111,13 @@ def test_batch_permutation_test_data_univariate_different_distribution(
"detector_class",
[CVMTest, KSTest, WelchTTest],
)
def test_batch_reset_on_data_drift(
def test_batch_reset_on_statistical_test_data_drift(
X_ref_univariate, # noqa: N803
X_test_univariate,
detector_class: BaseDataDriftBatch,
mocker,
) -> None:
"""Test batch reset on data drift callback.
"""Test batch reset on statistical test data drift callback.

:param X_ref_univariate: reference univariate data
:type X_ref_univariate: numpy.ndarray
Expand All @@ -130,7 +130,7 @@ def test_batch_reset_on_data_drift(

detector = detector_class( # type: ignore
callbacks=[
ResetOnBatchDataDrift(
ResetStatisticalTest(
alpha=0.01,
),
],
Expand Down Expand Up @@ -169,7 +169,7 @@ def test_streaming_history_on_concept_drift(
:type detector_class: BaseConceptDrift
"""
name = "history"
detector = detector_class(callbacks=History(name=name)) # type: ignore
detector = detector_class(callbacks=HistoryConceptDrift(name=name)) # type: ignore

for error in model_errors:
history = detector.update(value=error)
Expand Down
24 changes: 12 additions & 12 deletions frouros/tests/unit/utils/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest # type: ignore

from frouros.callbacks.base import BaseCallback
from frouros.callbacks.batch import PermutationTestOnBatchData, ResetOnBatchDataDrift
from frouros.callbacks.batch import PermutationTestDistanceBased, ResetStatisticalTest
from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.callbacks.streaming import History, WarningSamplesBuffer
from frouros.callbacks.streaming import HistoryConceptDrift, WarningSamplesBuffer
from frouros.callbacks.streaming.base import BaseCallbackStreaming
from frouros.utils.checks import check_callbacks

Expand All @@ -20,7 +20,7 @@
BaseCallbackBatch,
),
(
PermutationTestOnBatchData(
PermutationTestDistanceBased(
num_permutations=10,
),
BaseCallbackBatch,
Expand All @@ -31,21 +31,21 @@
),
(
[
PermutationTestOnBatchData(
PermutationTestDistanceBased(
num_permutations=10,
),
ResetOnBatchDataDrift(
ResetStatisticalTest(
alpha=0.05,
),
],
BaseCallbackBatch,
),
(
History(),
HistoryConceptDrift(),
BaseCallbackStreaming,
),
(
[History(), WarningSamplesBuffer()],
[HistoryConceptDrift(), WarningSamplesBuffer()],
BaseCallbackStreaming,
),
],
Expand All @@ -71,28 +71,28 @@ def test_check_callbacks(
"callbacks, expected_cls",
[
(
PermutationTestOnBatchData(
PermutationTestDistanceBased(
num_permutations=10,
),
BaseCallbackStreaming,
),
(
[
PermutationTestOnBatchData(
PermutationTestDistanceBased(
num_permutations=10,
),
ResetOnBatchDataDrift(
ResetStatisticalTest(
alpha=0.05,
),
],
BaseCallbackStreaming,
),
(
History(),
HistoryConceptDrift(),
BaseCallbackBatch,
),
(
[History(), WarningSamplesBuffer()],
[HistoryConceptDrift(), WarningSamplesBuffer()],
BaseCallbackBatch,
),
],
Expand Down