@@ -69,19 +69,25 @@ pub struct DBSCAN<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Dista
69
69
_phantom_y : PhantomData < Y > ,
70
70
}
71
71
72
+ #[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
72
73
#[ derive( Debug , Clone ) ]
73
74
/// DBSCAN clustering algorithm parameters
74
75
pub struct DBSCANParameters < T : Number , D : Distance < Vec < T > > > {
76
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
75
77
/// a function that defines a distance between each pair of point in training data.
76
78
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
77
79
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
78
80
pub distance : D ,
81
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
79
82
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
80
83
pub min_samples : usize ,
84
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
81
85
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
82
86
pub eps : f64 ,
87
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
83
88
/// KNN algorithm to use.
84
89
pub algorithm : KNNAlgorithmName ,
90
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
85
91
_phantom_t : PhantomData < T > ,
86
92
}
87
93
@@ -115,6 +121,110 @@ impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
115
121
}
116
122
}
117
123
124
+ /// DBSCAN grid search parameters
125
+ #[ cfg_attr( feature = "serde" , derive( Serialize , Deserialize ) ) ]
126
+ #[ derive( Debug , Clone ) ]
127
+ pub struct DBSCANSearchParameters < T : Number , D : Distance < Vec < T > > > {
128
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
129
+ /// a function that defines a distance between each pair of point in training data.
130
+ /// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
131
+ /// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
132
+ pub distance : Vec < D > ,
133
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
134
+ /// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
135
+ pub min_samples : Vec < usize > ,
136
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
137
+ /// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
138
+ pub eps : Vec < f64 > ,
139
+ #[ cfg_attr( feature = "serde" , serde( default ) ) ]
140
+ /// KNN algorithm to use.
141
+ pub algorithm : Vec < KNNAlgorithmName > ,
142
+ _phantom_t : PhantomData < T > ,
143
+ }
144
+
145
+ /// DBSCAN grid search iterator
146
+ pub struct DBSCANSearchParametersIterator < T : Number , D : Distance < Vec < T > > > {
147
+ dbscan_search_parameters : DBSCANSearchParameters < T , D > ,
148
+ current_distance : usize ,
149
+ current_min_samples : usize ,
150
+ current_eps : usize ,
151
+ current_algorithm : usize ,
152
+ }
153
+
154
+ impl < T : Number , D : Distance < Vec < T > > > IntoIterator for DBSCANSearchParameters < T , D > {
155
+ type Item = DBSCANParameters < T , D > ;
156
+ type IntoIter = DBSCANSearchParametersIterator < T , D > ;
157
+
158
+ fn into_iter ( self ) -> Self :: IntoIter {
159
+ DBSCANSearchParametersIterator {
160
+ dbscan_search_parameters : self ,
161
+ current_distance : 0 ,
162
+ current_min_samples : 0 ,
163
+ current_eps : 0 ,
164
+ current_algorithm : 0 ,
165
+ }
166
+ }
167
+ }
168
+
169
+ impl < T : Number , D : Distance < Vec < T > > > Iterator for DBSCANSearchParametersIterator < T , D > {
170
+ type Item = DBSCANParameters < T , D > ;
171
+
172
+ fn next ( & mut self ) -> Option < Self :: Item > {
173
+ if self . current_distance == self . dbscan_search_parameters . distance . len ( )
174
+ && self . current_min_samples == self . dbscan_search_parameters . min_samples . len ( )
175
+ && self . current_eps == self . dbscan_search_parameters . eps . len ( )
176
+ && self . current_algorithm == self . dbscan_search_parameters . algorithm . len ( )
177
+ {
178
+ return None ;
179
+ }
180
+
181
+ let next = DBSCANParameters {
182
+ distance : self . dbscan_search_parameters . distance [ self . current_distance ] . clone ( ) ,
183
+ min_samples : self . dbscan_search_parameters . min_samples [ self . current_min_samples ] ,
184
+ eps : self . dbscan_search_parameters . eps [ self . current_eps ] ,
185
+ algorithm : self . dbscan_search_parameters . algorithm [ self . current_algorithm ] . clone ( ) ,
186
+ _phantom_t : PhantomData ,
187
+ } ;
188
+
189
+ if self . current_distance + 1 < self . dbscan_search_parameters . distance . len ( ) {
190
+ self . current_distance += 1 ;
191
+ } else if self . current_min_samples + 1 < self . dbscan_search_parameters . min_samples . len ( ) {
192
+ self . current_distance = 0 ;
193
+ self . current_min_samples += 1 ;
194
+ } else if self . current_eps + 1 < self . dbscan_search_parameters . eps . len ( ) {
195
+ self . current_distance = 0 ;
196
+ self . current_min_samples = 0 ;
197
+ self . current_eps += 1 ;
198
+ } else if self . current_algorithm + 1 < self . dbscan_search_parameters . algorithm . len ( ) {
199
+ self . current_distance = 0 ;
200
+ self . current_min_samples = 0 ;
201
+ self . current_eps = 0 ;
202
+ self . current_algorithm += 1 ;
203
+ } else {
204
+ self . current_distance += 1 ;
205
+ self . current_min_samples += 1 ;
206
+ self . current_eps += 1 ;
207
+ self . current_algorithm += 1 ;
208
+ }
209
+
210
+ Some ( next)
211
+ }
212
+ }
213
+
214
+ impl < T : Number > Default for DBSCANSearchParameters < T , Euclidian < T > > {
215
+ fn default ( ) -> Self {
216
+ let default_params = DBSCANParameters :: default ( ) ;
217
+
218
+ DBSCANSearchParameters {
219
+ distance : vec ! [ default_params. distance] ,
220
+ min_samples : vec ! [ default_params. min_samples] ,
221
+ eps : vec ! [ default_params. eps] ,
222
+ algorithm : vec ! [ default_params. algorithm] ,
223
+ _phantom_t : PhantomData ,
224
+ }
225
+ }
226
+ }
227
+
118
228
impl < TX : Number , TY : Number , X : Array2 < TX > , Y : Array1 < TY > , D : Distance < Vec < TX > > > PartialEq
119
229
for DBSCAN < TX , TY , X , Y , D >
120
230
{
@@ -132,7 +242,7 @@ impl<T: Number> Default for DBSCANParameters<T, Euclidian<T>> {
132
242
distance : Distances :: euclidian ( ) ,
133
243
min_samples : 5 ,
134
244
eps : 0.5f64 ,
135
- algorithm : KNNAlgorithmName :: CoverTree ,
245
+ algorithm : KNNAlgorithmName :: default ( ) ,
136
246
_phantom_t : PhantomData ,
137
247
}
138
248
}
@@ -292,6 +402,29 @@ mod tests {
292
402
#[ cfg( feature = "serde" ) ]
293
403
use crate :: metrics:: distance:: euclidian:: Euclidian ;
294
404
405
+ #[ test]
406
+ fn search_parameters ( ) {
407
+ let parameters = DBSCANSearchParameters {
408
+ min_samples : vec ! [ 10 , 100 ] ,
409
+ eps : vec ! [ 1. , 2. ] ,
410
+ ..Default :: default ( )
411
+ } ;
412
+ let mut iter = parameters. into_iter ( ) ;
413
+ let next = iter. next ( ) . unwrap ( ) ;
414
+ assert_eq ! ( next. min_samples, 10 ) ;
415
+ assert_eq ! ( next. eps, 1. ) ;
416
+ let next = iter. next ( ) . unwrap ( ) ;
417
+ assert_eq ! ( next. min_samples, 100 ) ;
418
+ assert_eq ! ( next. eps, 1. ) ;
419
+ let next = iter. next ( ) . unwrap ( ) ;
420
+ assert_eq ! ( next. min_samples, 10 ) ;
421
+ assert_eq ! ( next. eps, 2. ) ;
422
+ let next = iter. next ( ) . unwrap ( ) ;
423
+ assert_eq ! ( next. min_samples, 100 ) ;
424
+ assert_eq ! ( next. eps, 2. ) ;
425
+ assert ! ( iter. next( ) . is_none( ) ) ;
426
+ }
427
+
295
428
#[ cfg_attr( target_arch = "wasm32" , wasm_bindgen_test:: wasm_bindgen_test) ]
296
429
#[ test]
297
430
fn fit_predict_dbscan ( ) {
0 commit comments