3
3
from typing import Union , Tuple , Optional
4
4
5
5
import numpy as np # type: ignore
6
- from scipy .stats import norm # type: ignore
7
6
8
7
from frouros .callbacks .streaming .base import StreamingCallback
9
- from frouros .utils .stats import Mean
8
+ from frouros .utils .stats import CircularMean
10
9
11
10
12
11
class mSPRT (StreamingCallback ): # noqa: N801 # pylint: disable=invalid-name
@@ -22,37 +21,34 @@ class mSPRT(StreamingCallback): # noqa: N801 # pylint: disable=invalid-name
22
21
def __init__ (
23
22
self ,
24
23
alpha : float ,
25
- sigma : float = 1.0 ,
26
- tau : Optional [ float ] = None ,
27
- truncation : int = 1 ,
24
+ sigma : Union [ int , float ] = 1.0 ,
25
+ tau : Union [ int , float ] = 1.0 ,
26
+ lambda_ : Union [ int , float ] = 1.0 ,
28
27
name : Optional [str ] = None ,
29
28
) -> None :
30
29
"""Init method.
31
30
32
31
:param alpha: alpha value
33
32
:type alpha: float
34
33
:param sigma: sigma value
35
- :type sigma: float
34
+ :type sigma: Union[int, float]
35
+ :param tau: tau value
36
+ :type tau: Union[int, float]
37
+ :param lambda_: lambda value
38
+ :type lambda_: Union[int, float]
36
39
:param name: name value
37
40
:type name: Optional[str]
38
41
"""
39
42
super ().__init__ (name = name )
40
43
self .alpha = alpha
41
44
self .sigma = sigma
42
- self .truncation = truncation
43
45
self .sigma_squared = self .sigma ** 2
44
46
self .two_sigma_squared = 2 * self .sigma_squared
45
47
self .tau = tau
46
- self .tau_squared = (
47
- self .tau ** 2
48
- if self .tau is not None
49
- else self ._calculate_tau_squared (
50
- alpha = self .alpha ,
51
- sigma_squared = self .sigma_squared ,
52
- truncation = self .truncation ,
53
- )
54
- )
55
- self .incremental_mean = Mean ()
48
+ self .tau_squared = self .tau ** 2
49
+ self .lambda_ = lambda_
50
+ self .mean = None
51
+ self .theta = None
56
52
self .p_value = 1.0
57
53
58
54
@property
@@ -78,66 +74,104 @@ def alpha(self, value: float) -> None:
78
74
self ._alpha = value
79
75
80
76
@property
81
- def tau (self ) -> Optional [float ]:
77
+ def sigma (self ) -> Optional [Union [int , float ]]:
78
+ """Sigma property.
79
+
80
+ :return: sigma value
81
+ :rtype: Optional[Union[int, float]]
82
+ """
83
+ return self ._sigma
84
+
85
+ @sigma .setter
86
+ def sigma (self , value : Optional [Union [int , float ]]) -> None :
87
+ """Sigma setter.
88
+
89
+ :param value: value to be set
90
+ :type value: Optional[float]
91
+ """
92
+ if value is not None and not isinstance (value , (int , float )):
93
+ raise TypeError ("sigma must be int, float or None" )
94
+ self ._sigma = value
95
+
96
+ @property
97
+ def tau (self ) -> Optional [Union [int , float ]]:
82
98
"""Tau property.
83
99
84
100
:return: tau squared value
85
- :rtype: Optional[float]
101
+ :rtype: Optional[Union[int, float] ]
86
102
"""
87
103
return self ._tau
88
104
89
105
@tau .setter
90
- def tau (self , value : Optional [ float ]) -> None :
106
+ def tau (self , value : Union [ int , float ]) -> None :
91
107
"""Tau setter.
92
108
93
109
:param value: value to be set
94
- :type value: Optional[ float]
110
+ :type value: Union[int, float]
95
111
"""
96
- if value is not None and not isinstance (value , float ):
97
- raise TypeError ("tau must be a float or None" )
112
+ if not isinstance (value , ( int , float ) ):
113
+ raise TypeError ("tau must be int, float or None" )
98
114
self ._tau = value
99
115
116
+ @property
117
+ def lambda_ (self ) -> Optional [Union [int , float ]]:
118
+ """Lambda property.
119
+
120
+ :return: lambda value
121
+ :rtype: Optional[Union[int, float]]
122
+ """
123
+ return self ._lambda_
124
+
125
+ @lambda_ .setter
126
+ def lambda_ (self , value : Union [int , float ]) -> None :
127
+ """Lambda setter.
128
+
129
+ :param value: value to be set
130
+ :type value: Union[int, float]
131
+ """
132
+ if not isinstance (value , (int , float )):
133
+ if value <= 0.0 :
134
+ raise ValueError ("lambda_ must be greater than 0" )
135
+ self ._lambda_ = value
136
+
137
+ def on_fit_end (self , ** kwargs ) -> None :
138
+ """On fit end method."""
139
+ self .mean = CircularMean (size = self .detector .window_size ) # type: ignore
140
+ self .theta = self .detector .compare (X = kwargs ["X" ])[0 ].distance # type: ignore
141
+
100
142
def on_update_end (self , value : Union [int , float ], ** kwargs ) -> None :
101
143
"""On update end method.
102
144
103
145
:param value: value to update detector
104
146
:type value: int
105
147
"""
106
- self .incremental_mean .update (value = value )
148
+ self .mean .update (value = value ) # type: ignore
107
149
self .p_value , likelihood = self ._calculate_p_value ()
108
150
109
151
self .logs .update (
110
152
{
111
- "distance_mean" : self .incremental_mean .get (),
153
+ "distance_mean" : self .mean .get (), # type: ignore
112
154
"likelihood" : likelihood ,
113
155
"p_value" : self .p_value ,
114
156
},
115
157
)
116
158
117
159
def reset (self ) -> None :
118
160
"""Reset method."""
119
- self .incremental_mean = Mean ()
161
+ super ().reset ()
162
+ self .mean = None
120
163
self .p_value = 1.0
121
164
122
- @staticmethod
123
- def _calculate_tau_squared (
124
- alpha : float ,
125
- sigma_squared : float ,
126
- truncation : int ,
127
- ) -> float :
128
- b = 2 * np .log (1 / alpha ) / ((truncation * sigma_squared ) ** 0.5 )
129
- minus_b_cdf = norm .cdf (- b )
130
- tau_squared = sigma_squared * minus_b_cdf / (1 / b * norm .pdf (b ) - minus_b_cdf )
131
- return tau_squared
132
-
133
165
def _calculate_p_value (self ) -> Tuple [float , float ]:
134
166
likelihood = self ._likelihood_normal_mixing_distribution (
135
- mean = self .incremental_mean .get (),
136
- sigma = self .sigma ,
167
+ mean = self .mean .get (), # type: ignore
168
+ sigma = self .sigma , # type: ignore
137
169
sigma_squared = self .sigma_squared ,
138
170
tau_squared = self .tau_squared ,
139
171
two_sigma_squared = self .two_sigma_squared ,
140
- n = self .detector .num_instances , # type: ignore
172
+ n = self .detector .window_size , # type: ignore
173
+ theta = self .theta , # type: ignore
174
+ lambda_ = self .lambda_ , # type: ignore
141
175
)
142
176
p_value = min (
143
177
self .p_value ,
@@ -146,20 +180,27 @@ def _calculate_p_value(self) -> Tuple[float, float]:
146
180
return p_value , likelihood
147
181
148
182
@staticmethod
149
- def _likelihood_normal_mixing_distribution (
183
+ def _likelihood_normal_mixing_distribution ( # pylint: disable=too-many-arguments
150
184
mean : float ,
151
185
sigma : float ,
152
186
sigma_squared : float ,
153
187
tau_squared : float ,
154
188
two_sigma_squared : float ,
155
189
n : int ,
190
+ theta : float ,
191
+ lambda_ : float ,
156
192
) -> float :
193
+ # FIXME: Explore lambda_ influence # pylint: disable=fixme
194
+ # and redesign the likelihood formula
157
195
n_tau_squared = n * tau_squared
158
196
sigma_squared_plus_n_tau_squared = sigma_squared + n_tau_squared
159
197
likelihood = (sigma / np .sqrt (sigma_squared_plus_n_tau_squared )) * np .exp (
160
198
n
161
199
* n_tau_squared
162
- * mean ** 2 # (mean - theta) ** 2, theta = 0 (H_0 value, no distance)
200
+ * lambda_ # Not present in mSPRT, added as a hyperparameter to control
201
+ # the influence of the distance difference
202
+ * (mean - theta )
203
+ ** 2 # (mean-theta) ** 2, theta=detector statistic (H_0 value, no distance)
163
204
/ (two_sigma_squared * sigma_squared_plus_n_tau_squared )
164
205
)
165
206
return likelihood
0 commit comments