Skip to content

Commit f1c655e

Browse files
author
Jaime Céspedes Sisniega
authored
Merge pull request #177 from IFCA/fix-MMD-mSPRT
Fix MMD and mSPRT
2 parents 9857029 + 3b712de commit f1c655e

File tree

7 files changed

+244
-125
lines changed

7 files changed

+244
-125
lines changed

frouros/callbacks/streaming/msprt.py

Lines changed: 83 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,9 @@
33
from typing import Union, Tuple, Optional
44

55
import numpy as np # type: ignore
6-
from scipy.stats import norm # type: ignore
76

87
from frouros.callbacks.streaming.base import StreamingCallback
9-
from frouros.utils.stats import Mean
8+
from frouros.utils.stats import CircularMean
109

1110

1211
class mSPRT(StreamingCallback): # noqa: N801 # pylint: disable=invalid-name
@@ -22,37 +21,34 @@ class mSPRT(StreamingCallback): # noqa: N801 # pylint: disable=invalid-name
2221
def __init__(
    self,
    alpha: float,
    sigma: Union[int, float] = 1.0,
    tau: Union[int, float] = 1.0,
    lambda_: Union[int, float] = 1.0,
    name: Optional[str] = None,
) -> None:
    """Init method.

    Stores the hyperparameters through their property setters and caches
    the squared terms that the likelihood computation reuses on every
    update.

    :param alpha: alpha value
    :type alpha: float
    :param sigma: sigma value (its square and twice its square are cached)
    :type sigma: Union[int, float]
    :param tau: tau value (its square is cached)
    :type tau: Union[int, float]
    :param lambda_: lambda value
    :type lambda_: Union[int, float]
    :param name: name value
    :type name: Optional[str]
    """
    super().__init__(name=name)
    self.alpha = alpha

    # Sigma and the powers derived from it.
    self.sigma = sigma
    sigma_squared = self.sigma**2
    self.sigma_squared = sigma_squared
    self.two_sigma_squared = 2 * sigma_squared

    # Tau and its square.
    self.tau = tau
    self.tau_squared = self.tau**2

    self.lambda_ = lambda_

    # Both are populated by on_fit_end; until then there is nothing to track.
    self.mean = self.theta = None
    self.p_value = 1.0
5753

5854
@property
@@ -78,66 +74,104 @@ def alpha(self, value: float) -> None:
7874
self._alpha = value
7975

8076
@property
def sigma(self) -> Optional[Union[int, float]]:
    """Sigma property.

    :return: sigma value
    :rtype: Optional[Union[int, float]]
    """
    return self._sigma

@sigma.setter
def sigma(self, value: Optional[Union[int, float]]) -> None:
    """Sigma setter.

    :param value: value to be set
    :type value: Optional[Union[int, float]]
    :raises TypeError: if value is not None, int or float
    """
    accepted = value is None or isinstance(value, (int, float))
    if not accepted:
        raise TypeError("sigma must be int, float or None")
    self._sigma = value
95+
96+
@property
def tau(self) -> Optional[Union[int, float]]:
    """Tau property.

    :return: tau value
    :rtype: Optional[Union[int, float]]
    """
    return self._tau

@tau.setter
def tau(self, value: Union[int, float]) -> None:
    """Tau setter.

    :param value: value to be set
    :type value: Union[int, float]
    :raises TypeError: if value is not int or float (None is rejected)
    """
    if not isinstance(value, (int, float)):
        # None is not an accepted value here, so the message must not
        # advertise it.
        raise TypeError("tau must be int or float")
    self._tau = value
99115

116+
@property
def lambda_(self) -> Optional[Union[int, float]]:
    """Lambda property.

    :return: lambda value
    :rtype: Optional[Union[int, float]]
    """
    return self._lambda_

@lambda_.setter
def lambda_(self, value: Union[int, float]) -> None:
    """Lambda setter.

    :param value: value to be set
    :type value: Union[int, float]
    :raises TypeError: if value is not int or float
    :raises ValueError: if value is not greater than 0
    """
    # The two checks are independent: the type check guards the comparison,
    # and the range check must run for every numeric value.
    if not isinstance(value, (int, float)):
        raise TypeError("lambda_ must be int or float")
    if value <= 0.0:
        raise ValueError("lambda_ must be greater than 0")
    self._lambda_ = value
136+
137+
def on_fit_end(self, **kwargs) -> None:
    """On fit end method.

    Creates a fresh circular mean sized to the detector's window and stores
    as ``theta`` the ``distance`` of the first element returned by the
    detector's ``compare`` call on the fitted data.

    :param kwargs: must contain the key ``"X"`` with the data passed to
        the detector's ``compare`` method
    """
    window_size = self.detector.window_size  # type: ignore
    self.mean = CircularMean(size=window_size)  # type: ignore
    comparison = self.detector.compare(X=kwargs["X"])  # type: ignore
    self.theta = comparison[0].distance  # type: ignore
141+
100142
def on_update_end(self, value: Union[int, float], **kwargs) -> None:
    """On update end method.

    Feeds the new value into the running mean, recomputes the p-value and
    likelihood, and records all three quantities in the logs.

    :param value: value to update detector
    :type value: Union[int, float]
    """
    self.mean.update(value=value)  # type: ignore
    p_value, likelihood = self._calculate_p_value()
    self.p_value = p_value

    entries = {
        "distance_mean": self.mean.get(),  # type: ignore
        "likelihood": likelihood,
        "p_value": self.p_value,
    }
    self.logs.update(entries)
116158

117159
def reset(self) -> None:
    """Reset method.

    Resets the parent callback and restores the state this callback holds
    to the values assigned in ``__init__``.
    """
    super().reset()
    # mean and theta are both produced by on_fit_end; clear them together
    # so no value from a previous fit lingers after a reset.
    self.mean = None
    self.theta = None
    self.p_value = 1.0
121164

122-
@staticmethod
123-
def _calculate_tau_squared(
124-
alpha: float,
125-
sigma_squared: float,
126-
truncation: int,
127-
) -> float:
128-
b = 2 * np.log(1 / alpha) / ((truncation * sigma_squared) ** 0.5)
129-
minus_b_cdf = norm.cdf(-b)
130-
tau_squared = sigma_squared * minus_b_cdf / (1 / b * norm.pdf(b) - minus_b_cdf)
131-
return tau_squared
132-
133165
def _calculate_p_value(self) -> Tuple[float, float]:
134166
likelihood = self._likelihood_normal_mixing_distribution(
135-
mean=self.incremental_mean.get(),
136-
sigma=self.sigma,
167+
mean=self.mean.get(), # type: ignore
168+
sigma=self.sigma, # type: ignore
137169
sigma_squared=self.sigma_squared,
138170
tau_squared=self.tau_squared,
139171
two_sigma_squared=self.two_sigma_squared,
140-
n=self.detector.num_instances, # type: ignore
172+
n=self.detector.window_size, # type: ignore
173+
theta=self.theta, # type: ignore
174+
lambda_=self.lambda_, # type: ignore
141175
)
142176
p_value = min(
143177
self.p_value,
@@ -146,20 +180,27 @@ def _calculate_p_value(self) -> Tuple[float, float]:
146180
return p_value, likelihood
147181

148182
@staticmethod
def _likelihood_normal_mixing_distribution(  # pylint: disable=too-many-arguments
    mean: float,
    sigma: float,
    sigma_squared: float,
    tau_squared: float,
    two_sigma_squared: float,
    n: int,
    theta: float,
    lambda_: float,
) -> float:
    """Compute the likelihood from the mean, variance terms, n, theta and lambda.

    The result is ``sigma / sqrt(sigma_squared + n * tau_squared)`` times the
    exponential of ``n * (n * tau_squared) * lambda_ * (mean - theta) ** 2``
    divided by ``two_sigma_squared * (sigma_squared + n * tau_squared)``.

    :param mean: mean value compared against theta
    :type mean: float
    :param sigma: sigma value
    :type sigma: float
    :param sigma_squared: sigma squared value
    :type sigma_squared: float
    :param tau_squared: tau squared value
    :type tau_squared: float
    :param two_sigma_squared: two times sigma squared value
    :type two_sigma_squared: float
    :param n: multiplier applied to tau squared and to the exponent
    :type n: int
    :param theta: reference value subtracted from the mean
    :type theta: float
    :param lambda_: extra factor on the squared difference (not part of
        mSPRT; added as a hyperparameter to control its influence)
    :type lambda_: float
    :return: likelihood value
    :rtype: float
    """
    # FIXME: Explore lambda_ influence  # pylint: disable=fixme
    # and redesign the likelihood formula
    scaled_tau = n * tau_squared
    total_variance = sigma_squared + scaled_tau
    prefactor = sigma / np.sqrt(total_variance)
    # Multiplication order is kept left-to-right so floating-point results
    # match the single-expression form exactly.
    exponent = (
        n * scaled_tau * lambda_ * (mean - theta) ** 2
        / (two_sigma_squared * total_variance)
    )
    return prefactor * np.exp(exponent)

0 commit comments

Comments
 (0)