Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions frouros/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
History,
PermutationTestOnBatchData,
ResetOnBatchDataDrift,
WarningSamplesBuffer,
)

# Names exported by this package when it is star-imported.
__all__ = [
    "Callback",
    "History",
    "PermutationTestOnBatchData",
    "ResetOnBatchDataDrift",
    "WarningSamplesBuffer",
]
140 changes: 137 additions & 3 deletions frouros/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
"""Callback module."""

import abc
import copy
import multiprocessing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np # type: ignore

# FIXME: set_detector method as a workaround to # pylint: disable=fixme
# avoid circular import problem. Make it an abstract method and
# uncomment commented code when it is solved

# from frouros.detectors.concept_drift.base import ConceptDriftBase
# from frouros.detectors.concept_drift.ddm_based.base import DDMBased
# from frouros.detectors.data_drift.batch.base import DataDriftBatchBase
from frouros.utils.stats import permutation_test, Stat


Expand Down Expand Up @@ -42,7 +51,6 @@ def name(self, value: Optional[str]) -> None:
raise TypeError("name must be of type str or None.")
self._name = self.__class__.__name__ if value is None else value

# FIXME: Workaround to avoid circular import problem # pylint: disable=fixme
def set_detector(self, detector) -> None:
    """Set detector method.

    Store the given detector on the callback so that it is reachable
    through ``self.detector``. No type check is performed here.

    :param detector: detector the callback is attached to
    """
    self.detector = detector
Expand All @@ -69,16 +77,28 @@ def on_fit_end(self) -> None:
def on_drift_detected(self) -> None:
    """On drift detected method.

    Does nothing by default; subclasses may override it.
    """

@abc.abstractmethod
def reset(self) -> None:
    """Reset method.

    Declared abstract: each concrete callback provides its own
    implementation.
    """


class StreamingCallback(Callback):
    """Streaming callback class.

    Base class for callbacks whose hooks receive the value passed to a
    detector update. Both update hooks do nothing by default.
    """

    def on_update_start(self, value: Union[int, float], **kwargs) -> None:
        """On update start method.

        :param value: value used to update the detector
        :type value: Union[int, float]
        :param kwargs: extra keyword arguments forwarded by the caller
        """

    def on_update_end(self, value: Union[int, float], **kwargs) -> None:
        """On update end method.

        :param value: value used to update the detector
        :type value: Union[int, float]
        :param kwargs: extra keyword arguments forwarded by the caller
        """

    # @abc.abstractmethod
    # def set_detector(self, detector) -> None:
    #     """Set detector method."""

    @abc.abstractmethod
    def reset(self) -> None:
        """Reset method."""


class BatchCallback(Callback):
"""Batch callback class."""
Expand All @@ -89,6 +109,14 @@ def on_compare_start(self) -> None:
def on_compare_end(self, **kwargs) -> None:
    """On compare end method.

    Does nothing by default; subclasses may override it.

    :param kwargs: keyword arguments forwarded by the caller
    """

# @abc.abstractmethod
# def set_detector(self, detector) -> None:
# """Set detector method."""

@abc.abstractmethod
def reset(self) -> None:
    """Reset method.

    Declared abstract: each concrete batch callback provides its own
    implementation.
    """


class History(StreamingCallback):
"""History callback class."""
Expand Down Expand Up @@ -139,6 +167,24 @@ def on_update_end(self, value: Union[int, float], **kwargs) -> None:

self.logs.update(**self.history)

# def set_detector(self, detector) -> None:
# """Set detector method.
#
# :raises TypeError: Type error exception
# """
# if not isinstance(detector, ConceptDriftBase):
# raise TypeError(
# f"callback {self.__class__.__name__} cannot be used with detector"
# f" {detector.__class__.__name__}. Must be used with a detector of "
# f"type ConceptDriftBase."
# )
# self.detector = detector

def reset(self) -> None:
    """Reset method.

    Empty, in place, every container stored in ``self.history`` while
    keeping its keys.
    """
    for stored_values in self.history.values():
        stored_values.clear()


class PermutationTestOnBatchData(BatchCallback):
"""Permutation test on batch data callback class."""
Expand All @@ -149,7 +195,7 @@ def __init__(
num_jobs: int = -1,
name: Optional[str] = None,
verbose: bool = False,
**kwargs
**kwargs,
) -> None:
"""Init method.

Expand Down Expand Up @@ -277,6 +323,22 @@ def on_compare_end(self, **kwargs) -> None:
},
)

# def set_detector(self, detector) -> None:
# """Set detector method.
#
# :raises TypeError: Type error exception
# """
# if not isinstance(detector, DataDriftBatchBase):
# raise TypeError(
# f"callback {self.__class__.__name__} cannot be used with detector"
# f" {detector.__class__.__name__}. Must be used with a detector of "
# f"type DataDriftBatchBase."
# )
# self.detector = detector

def reset(self) -> None:
    """Reset method.

    Currently a no-op: the method body is empty.
    """


class ResetOnBatchDataDrift(BatchCallback):
"""Reset on batch data drift callback class."""
Expand Down Expand Up @@ -319,3 +381,75 @@ def on_compare_end(self, **kwargs) -> None:
if p_value < self.alpha:
print("Drift detected. Resetting detector.")
self.detector.reset() # type: ignore

# def set_detector(self, detector) -> None:
# """Set detector method.
#
# :raises TypeError: Type error exception
# """
# if not isinstance(detector, DataDriftBatchBase):
# raise TypeError(
# f"callback {self.__class__.__name__} cannot be used with detector"
# f" {detector.__class__.__name__}. Must be used with a detector of "
# f"type DataDriftBatchBase."
# )
# self.detector = detector

def reset(self) -> None:
    """Reset method.

    Currently a no-op: the method body is empty.
    """


class WarningSamplesBuffer(StreamingCallback):
    """Store warning samples as a buffer callback class.

    Collects the ``X`` and ``y`` keyword arguments received through
    :meth:`on_warning_detected`. The buffers are emptied whenever a warning
    is raised on an update that started with the detector not in warning,
    so they only hold samples of the current run of warnings.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        """Init method.

        :param name: name to be used
        :type name: Optional[str]
        """
        super().__init__(name=name)
        self.X: List[Any] = []
        self.y: List[Any] = []
        # True when the detector was not in warning at the start of the
        # update currently being processed.
        self._start_warning = False

    def on_update_start(self, value: Union[int, float], **kwargs) -> None:
        """On update start method.

        Record whether the detector is outside a warning before the update.

        :param value: value to update detector
        :type value: Union[int, float]
        """
        self._start_warning = not self.detector.warning  # type: ignore

    def on_update_end(self, value: Union[int, float], **kwargs) -> None:
        """On update end method.

        Publish deep copies of the buffers in ``self.logs`` so that later
        changes to the buffers do not alter logs already handed out.

        :param value: value to update detector
        :type value: Union[int, float]
        """
        self.logs = {
            "X": copy.deepcopy(self.X),
            "y": copy.deepcopy(self.y),
        }

    def on_warning_detected(self, **kwargs) -> None:
        """On warning detected method.

        :param kwargs: must contain the ``X`` and ``y`` entries to store
        :raises KeyError: if ``X`` or ``y`` is missing from ``kwargs``
        """
        if self._start_warning:
            # Clear explicitly: map() returns a lazy iterator in Python 3,
            # so an unconsumed map(...clear...) would never empty the
            # buffers.
            self.X.clear()
            self.y.clear()
        self.X.append(kwargs["X"])
        self.y.append(kwargs["y"])

    # def set_detector(self, detector) -> None:
    #     """Set detector method.
    #
    #     :raises TypeError: Type error exception
    #     """
    #     if not isinstance(detector, DDMBased):
    #         raise TypeError(
    #             f"callback {self.__class__.__name__} cannot be used with detector"
    #             f" {detector.__class__.__name__}. Must be used with a detector of "
    #             f"type DDMBased."
    #         )
    #     self.detector = detector

    def reset(self) -> None:
        """Reset method.

        Empty both buffers and clear the start-of-warning flag.
        """
        self.X.clear()
        self.y.clear()
        self._start_warning = False
10 changes: 7 additions & 3 deletions frouros/detectors/concept_drift/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,13 @@ def __init__(

self.num_instances = 0
self.drift = False
for callback in self.callbacks: # type: ignore
callback.set_detector(detector=self)

def _set_additional_vars_callback(self) -> None:
for callback in self.callbacks: # type: ignore
if isinstance(callback, History):
callback.set_detector(detector=self)
# callback.set_detector(detector=self)
callback.add_additional_vars(
vars_=self.additional_vars.keys(), # type: ignore
)
Expand Down Expand Up @@ -145,6 +147,8 @@ def reset(self) -> None:
"""Reset method."""
self.num_instances = 0
self.drift = False
for callback in self.callbacks: # type: ignore
callback.reset()

@property
def status(self) -> Dict[str, bool]:
Expand All @@ -162,10 +166,10 @@ def update(self, value: Union[int, float], **kwargs) -> Dict[str, Any]:
:type value: Union[int, float]
"""
for callback in self.callbacks: # type: ignore
callback.on_update_start() # type: ignore
callback.on_update_start(value=value, **kwargs) # type: ignore
self._update(value=value, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_update_end(value=value) # type: ignore
callback.on_update_end(value=value, **kwargs) # type: ignore

callbacks_logs = self._get_callbacks_logs()
return callbacks_logs
Expand Down
4 changes: 4 additions & 0 deletions frouros/detectors/concept_drift/ddm_based/ddm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""DDM (Drift detection method) module."""

from contextlib import suppress
from typing import Union

from frouros.detectors.concept_drift.ddm_based.base import DDMBaseConfig, DDMErrorBased
Expand Down Expand Up @@ -43,6 +44,9 @@ def _update(self, value: Union[int, float], **kwargs) -> None:
if warning_flag:
# Warning
self.warning = True
for callback in self.callbacks: # type: ignore
with suppress(AttributeError):
callback.on_warning_detected(**kwargs) # type: ignore
else:
# In-Control
self.warning = False
Expand Down
41 changes: 30 additions & 11 deletions frouros/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,19 +400,18 @@ def concept_drift_dataset_simple() -> Tuple[
return (X_ref, y_ref), (X_test, y_test)


@pytest.fixture(scope="module", name="model_errors")
def concept_drift_model_errors_simple(
dataset_simple: Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
) -> List[int]:
"""Compute model errors given a dataset with concept drift.
@pytest.fixture(scope="module", name="model")
def concept_drift_model(
dataset_simple: Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]],
) -> sklearn.pipeline.Pipeline:
"""Model used for concept drift.

:param dataset_simple: Dataset with concept drift
:param dataset_simple: dataset with concept drift
:type dataset_simple: Tuple[Tuple[numpy.ndarray, numpy.ndarray],
Tuple[numpy.ndarray, numpy.ndarray]]
:return: model errors
:rtype: List[int]
:return: trained model
:rtype: sklearn.pipeline.Pipeline
"""
(X_ref, y_ref), (X_test, y_test) = dataset_simple # noqa: N806
(X_ref, y_ref), _ = dataset_simple # noqa: N806

pipeline = sklearn.pipeline.Pipeline(
[
Expand All @@ -422,7 +421,27 @@ def concept_drift_model_errors_simple(
)
pipeline.fit(X=X_ref, y=y_ref)

y_test_pred = pipeline.predict(X_test)
return pipeline


@pytest.fixture(scope="module", name="model_errors")
def concept_drift_model_errors_simple(
    dataset_simple: Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]],
    model: sklearn.pipeline.Pipeline,
) -> List[int]:
    """Compute model errors given a dataset with concept drift.

    :param dataset_simple: dataset with concept drift
    :type dataset_simple: Tuple[Tuple[numpy.ndarray, numpy.ndarray],
    Tuple[numpy.ndarray, numpy.ndarray]]
    :param model: trained model
    :type model: sklearn.pipeline.Pipeline
    :return: model errors, 1 for a misclassified sample and 0 otherwise
    :rtype: List[int]
    """
    _, (X_test, y_test) = dataset_simple  # noqa: N806

    y_test_pred = model.predict(X_test)
    # Explicit inequality instead of ``1 - y_test_pred == y_test``: that form
    # parses as ``(1 - y_test_pred) == y_test`` and only flags errors
    # correctly when labels are exactly 0 and 1. For such labels both
    # expressions give the same result.
    error = (y_test_pred != y_test).astype(int).tolist()

    return error
Loading