Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions frouros/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import abc
from typing import Optional

import numpy as np # type: ignore


class BaseCallback(abc.ABC):
"""Abstract class representing a callback."""
Expand Down Expand Up @@ -55,11 +57,19 @@ def set_detector(self, detector) -> None:
# )
# self._detector = value

def on_fit_start(self, **kwargs) -> None:
"""On fit start method."""
def on_fit_start(self, X: np.ndarray) -> None: # noqa: N803
"""On fit start method.

:param X: reference data
:type X: numpy.ndarray
"""

def on_fit_end(self, **kwargs) -> None:
"""On fit end method."""
def on_fit_end(self, X: np.ndarray) -> None: # noqa: N803
"""On fit end method.

:param X: reference data
:type X: numpy.ndarray
"""

@abc.abstractmethod
def reset(self) -> None:
Expand Down
36 changes: 31 additions & 5 deletions frouros/callbacks/batch/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
"""Base callback batch module."""

import abc
from typing import Any

import numpy as np # type: ignore

from frouros.callbacks.base import BaseCallback


class BaseCallbackBatch(BaseCallback):
"""Callback batch class."""

def on_compare_start(self, **kwargs) -> None:
"""On compare start method."""

def on_compare_end(self, **kwargs) -> None:
"""On compare end method."""
def on_compare_start(
self,
X_ref: np.ndarray, # noqa: N803
X_test: np.ndarray,
) -> None:
"""On compare start method.

:param X_ref: reference data
:type X_ref: numpy.ndarray
:param X_test: test data
:type X_test: numpy.ndarray
"""

def on_compare_end(
self,
result: Any,
X_ref: np.ndarray, # noqa: N803
X_test: np.ndarray,
) -> None:
"""On compare end method.

:param result: result obtained from the `compare` method
:type result: Any
:param X_ref: reference data
:type X_ref: numpy.ndarray
:param X_test: test data
:type X_test: numpy.ndarray
"""

# FIXME: set_detector method as a workaround to # pylint: disable=fixme
# avoid circular import problem. Make it an abstract method and
Expand Down
20 changes: 16 additions & 4 deletions frouros/callbacks/batch/permutation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,22 @@ def _calculate_p_value( # pylint: disable=too-many-arguments
p_value = (permuted_statistic >= observed_statistic).mean() # type: ignore
return permuted_statistic, p_value

def on_compare_end(self, **kwargs) -> None:
"""On compare end method."""
X_ref, X_test = kwargs["X_ref"], kwargs["X_test"] # noqa: N806
observed_statistic = kwargs["result"][0]
def on_compare_end(
self,
result: Any,
X_ref: np.ndarray, # noqa: N803
X_test: np.ndarray,
) -> None:
"""On compare end method.

:param result: result obtained from the `compare` method
:type result: Any
:param X_ref: reference data
:type X_ref: numpy.ndarray
:param X_test: test data
:type X_test: numpy.ndarray
"""
observed_statistic = result.distance
permuted_statistics, p_value = self._calculate_p_value(
X_ref=X_ref,
X_test=X_test,
Expand Down
25 changes: 20 additions & 5 deletions frouros/callbacks/batch/reset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Reset batch callback module."""

from typing import Optional
from typing import Any, Optional

import numpy as np # type: ignore

from frouros.callbacks.batch.base import BaseCallbackBatch
from frouros.utils.logger import logger
Expand Down Expand Up @@ -58,10 +60,23 @@ def alpha(self, value: float) -> None:
raise ValueError("value must be greater than 0.")
self._alpha = value

def on_compare_end(self, **kwargs) -> None:
"""On compare end method."""
p_value = kwargs["result"].p_value
if p_value < self.alpha:
def on_compare_end(
self,
result: Any,
X_ref: np.ndarray, # noqa: N803
X_test: np.ndarray,
) -> None:
"""On compare end method.

:param result: result obtained from the `compare` method
:type result: Any
:param X_ref: reference data
:type X_ref: numpy.ndarray
:param X_test: test data
:type X_test: numpy.ndarray
"""
p_value = result.p_value
if p_value <= self.alpha:
logger.info("Drift detected. Resetting detector...")
self.detector.reset() # type: ignore

Expand Down
17 changes: 13 additions & 4 deletions frouros/callbacks/streaming/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
"""Base callback streaming module."""

import abc
from typing import Union

from frouros.callbacks.base import BaseCallback


class BaseCallbackStreaming(BaseCallback):
"""Callback streaming class."""

def on_update_start(self, **kwargs) -> None:
"""On update start method."""
def on_update_start(self, value: Union[int, float]) -> None:
"""On update start method.

def on_update_end(self, **kwargs) -> None:
"""On update end method."""
:param value: value used to update the detector
:type value: Union[int, float]
"""

def on_update_end(self, value: Union[int, float]) -> None:
"""On update end method.

:param value: value used to update the detector
:type value: Union[int, float]
"""

# FIXME: set_detector method as a workaround to # pylint: disable=fixme
# avoid circular import problem. Make it an abstract method and
Expand Down
12 changes: 8 additions & 4 deletions frouros/callbacks/streaming/history.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""History callback module."""

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from frouros.callbacks.streaming.base import BaseCallbackStreaming
from frouros.utils.stats import BaseStat
Expand Down Expand Up @@ -62,9 +62,13 @@ def add_additional_vars(self, vars_: List[str]) -> None:
self.additional_vars.extend(vars_)
self.history = {**self.history, **{var: [] for var in self.additional_vars}}

def on_update_end(self, **kwargs) -> None:
"""On update end method."""
self.history["value"].append(kwargs["value"])
def on_update_end(self, value: Union[int, float]) -> None:
"""On update end method.

:param value: value used to update the detector
:type value: Union[int, float]
"""
self.history["value"].append(value)
self.history["num_instances"].append(
self.detector.num_instances # type: ignore
)
Expand Down
4 changes: 2 additions & 2 deletions frouros/detectors/concept_drift/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,17 +184,17 @@ def update(self, value: Union[int, float], **kwargs) -> Dict[str, Any]:

:param value: value to update detector
:type value: Union[int, float]
:return: callbacks logs
:rtype: Dict[str, Any]]
"""
for callback in self.callbacks: # type: ignore
callback.on_update_start( # type: ignore
value=value,
**kwargs,
)
self._update(value=value, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_update_end( # type: ignore
value=value,
**kwargs,
)

callbacks_logs = self._get_callbacks_logs()
Expand Down
6 changes: 2 additions & 4 deletions frouros/detectors/data_drift/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,15 @@ def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
for callback in self.callbacks: # type: ignore
callback.on_fit_start(
X=X,
**kwargs,
)
self._fit(X=X, **kwargs)
for callback in self.callbacks: # type: ignore
callback.on_fit_end(
X=X,
**kwargs,
)

logs = self._get_callbacks_logs()
return logs
callbacks_logs = self._get_callbacks_logs()
return callbacks_logs

def reset(self) -> None:
"""Reset method."""
Expand Down
2 changes: 1 addition & 1 deletion frouros/detectors/data_drift/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def compare(
) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Compare values.

:param X: feature data
:param X: test data
:type X: numpy.ndarray
:return: compare result and callbacks logs
:rtype: Tuple[numpy.ndarray, Dict[str, Any]]
Expand Down
20 changes: 11 additions & 9 deletions frouros/detectors/data_drift/streaming/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,27 +56,29 @@ def reset(self) -> None:
self._reset()

def update(
self, value: Union[int, float]
self,
value: Union[int, float],
) -> Tuple[Optional[BaseResult], Dict[str, Any]]:
"""Update detector.

:param value: value to use to update the detector
:type value: Union[int, float]
:return: update result
:rtype: Optional[BaseResult]
:return: update result and callbacks logs
:rtype: Tuple[Optional[BaseResult], Dict[str, Any]]
"""
self._common_checks() # noqa: N806
self._specific_checks(X=value) # noqa: N806
self.num_instances += 1

for callback in self.callbacks: # type: ignore
callback.on_update_start(value=value) # type: ignore
callback.on_update_start( # type: ignore
value=value, # type: ignore
)
result = self._update(value=value)
if result is not None:
for callback in self.callbacks: # type: ignore
callback.on_update_end( # type: ignore
value=result.distance, # type: ignore
)
for callback in self.callbacks: # type: ignore
callback.on_update_end( # type: ignore
value=result, # type: ignore
)

callbacks_logs = self._get_callbacks_logs()
return result, callbacks_logs
Expand Down