Skip to content

Commit 18e9c67

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #172 from IFCA/fix-msprt-fit
Add missing on_fit_end in mSPRT
2 parents bdc4fa0 + f90c517 commit 18e9c67

File tree

5 files changed

+11
-7
lines changed

5 files changed

+11
-7
lines changed

frouros/callbacks/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ def set_detector(self, detector) -> None:
5555
# )
5656
# self._detector = value
5757

58-
def on_fit_start(self) -> None:
58+
def on_fit_start(self, **kwargs) -> None:
5959
"""On fit start method."""
6060

61-
def on_fit_end(self) -> None:
61+
def on_fit_end(self, **kwargs) -> None:
6262
"""On fit end method."""
6363

64-
def on_drift_detected(self) -> None:
64+
def on_drift_detected(self, **kwargs) -> None:
6565
"""On drift detected method."""
6666

6767
@abc.abstractmethod

frouros/callbacks/streaming/msprt.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class mSPRT(StreamingCallback): # noqa: N801 # pylint: disable=invalid-name
2121

2222
def __init__(
2323
self,
24-
alpha: float = 0.05,
24+
alpha: float,
2525
sigma: float = 1.0,
2626
tau: Optional[float] = None,
2727
truncation: int = 1,
@@ -97,6 +97,10 @@ def tau(self, value: Optional[float]) -> None:
9797
raise TypeError("tau must be a float or None")
9898
self._tau = value
9999

100+
def on_fit_end(self, **kwargs) -> None:
101+
"""On fit end method."""
102+
self.incremental_mean.num_values = len(kwargs["X"])
103+
100104
def on_update_end(self, value: Union[int, float], **kwargs) -> None:
101105
"""On update end method.
102106

frouros/detectors/data_drift/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def fit(self, X: np.ndarray, **kwargs) -> Dict[str, Any]: # noqa: N803
188188
callback.on_fit_start()
189189
self._fit(X=X, **kwargs)
190190
for callback in self.callbacks: # type: ignore
191-
callback.on_fit_end()
191+
callback.on_fit_end(X=X, **kwargs)
192192

193193
logs = self._get_callbacks_logs()
194194
return logs

frouros/tests/integration/test_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_streaming_warning_samples_buffer_on_concept_drift(
242242
" expected_p_value,"
243243
" expected_likelihood",
244244
[
245-
(MMDStreaming, 70, 0.3622854, 0.0324733, 30.79452443),
245+
(MMDStreaming, 40, 0.08821576, 0.00494882, 202.06836342),
246246
],
247247
)
248248
def test_streaming_msprt_multivariate_different_distribution(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "frouros"
3-
version = "0.2.4"
3+
version = "0.2.5"
44
description = "A Python library for drift detection in Machine Learning problems"
55
authors = [
66
{name = "Jaime Céspedes Sisniega", email = "cespedes@ifca.unican.es"}

0 commit comments

Comments
 (0)