Skip to content

Commit 088f7e3

Browse files
author
Montana Low
committed
merge development
2 parents ba9398a + 764309e commit 088f7e3

33 files changed

+2362
-67
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,24 @@ name: CI
22

33
on:
44
push:
5-
branches: [ main, development ]
5+
branches: [main, development]
66
pull_request:
7-
branches: [ development ]
7+
branches: [development]
88

99
jobs:
1010
tests:
1111
runs-on: "${{ matrix.platform.os }}-latest"
1212
strategy:
1313
matrix:
14-
platform: [
15-
{ os: "windows", target: "x86_64-pc-windows-msvc" },
16-
{ os: "windows", target: "i686-pc-windows-msvc" },
17-
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
18-
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
19-
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
20-
{ os: "macos", target: "aarch64-apple-darwin" },
21-
]
14+
platform:
15+
[
16+
{ os: "windows", target: "x86_64-pc-windows-msvc" },
17+
{ os: "windows", target: "i686-pc-windows-msvc" },
18+
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
19+
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
20+
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
21+
{ os: "macos", target: "aarch64-apple-darwin" },
22+
]
2223
env:
2324
TZ: "/usr/share/zoneinfo/your/location"
2425
steps:
@@ -40,7 +41,7 @@ jobs:
4041
default: true
4142
- name: Install test runner for wasm
4243
if: matrix.platform.target == 'wasm32-unknown-unknown'
43-
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
44+
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
4445
- name: Stable Build
4546
uses: actions-rs/cargo@v1
4647
with:

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88

9+
## Added
10+
- Seeds to multiple algorithims that depend on random number generation.
11+
- Added feature `js` to use WASM in browser
12+
13+
## BREAKING CHANGE
14+
- Added a new parameter to `train_test_split` to define the seed.
15+
16+
## [0.2.1] - 2022-05-10
17+
918
## Added
1019
- L2 regularization penalty to the Logistic Regression
1120
- Getters for the naive bayes structs

Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,26 @@ categories = ["science"]
1616
default = ["datasets", "serde"]
1717
ndarray-bindings = ["ndarray"]
1818
nalgebra-bindings = ["nalgebra"]
19-
datasets = ["rand_distr"]
19+
datasets = ["rand_distr", "std"]
2020
fp_bench = ["itertools"]
21+
std = ["rand/std", "rand/std_rng"]
22+
# wasm32 only
23+
js = ["getrandom/js"]
2124

2225
[dependencies]
2326
approx = "0.5.1"
27+
cfg-if = "1.0.0"
2428
itertools = { version = "0.10.3", optional = true }
2529
ndarray = { version = "0.15", optional = true }
2630
nalgebra = { version = "0.31", optional = true }
2731
num-traits = "0.2.12"
2832
num = "0.4"
29-
rand = "0.8.3"
33+
rand = { version = "0.8", default-features = false, features = ["small_rng"] }
3034
rand_distr = { version = "0.4", optional = true }
3135
serde = { version = "1", features = ["derive"], optional = true }
36+
3237
[target.'cfg(target_arch = "wasm32")'.dependencies]
33-
getrandom = { version = "0.2", features = ["js"] }
38+
getrandom = { version = "0.2", optional = true }
3439

3540
[dev-dependencies]
3641
smartcore = { path = ".", features = ["fp_bench"] }

src/algorithm/neighbour/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ pub enum KNNAlgorithmName {
5959
CoverTree,
6060
}
6161

62+
impl Default for KNNAlgorithmName {
63+
fn default() -> Self {
64+
KNNAlgorithmName::CoverTree
65+
}
66+
}
67+
6268
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6369
#[derive(Debug)]
6470
pub(crate) enum KNNAlgorithm<T: Number, D: Distance<Vec<T>>> {

src/cluster/dbscan.rs

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,25 @@ pub struct DBSCAN<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Dista
6969
_phantom_y: PhantomData<Y>,
7070
}
7171

72+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7273
#[derive(Debug, Clone)]
7374
/// DBSCAN clustering algorithm parameters
7475
pub struct DBSCANParameters<T: Number, D: Distance<Vec<T>>> {
76+
#[cfg_attr(feature = "serde", serde(default))]
7577
/// a function that defines a distance between each pair of point in training data.
7678
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
7779
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
7880
pub distance: D,
81+
#[cfg_attr(feature = "serde", serde(default))]
7982
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
8083
pub min_samples: usize,
84+
#[cfg_attr(feature = "serde", serde(default))]
8185
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
8286
pub eps: f64,
87+
#[cfg_attr(feature = "serde", serde(default))]
8388
/// KNN algorithm to use.
8489
pub algorithm: KNNAlgorithmName,
90+
#[cfg_attr(feature = "serde", serde(default))]
8591
_phantom_t: PhantomData<T>,
8692
}
8793

@@ -115,6 +121,110 @@ impl<T: Number, D: Distance<Vec<T>>> DBSCANParameters<T, D> {
115121
}
116122
}
117123

124+
/// DBSCAN grid search parameters
125+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
126+
#[derive(Debug, Clone)]
127+
pub struct DBSCANSearchParameters<T: Number, D: Distance<Vec<T>>> {
128+
#[cfg_attr(feature = "serde", serde(default))]
129+
/// a function that defines a distance between each pair of point in training data.
130+
/// This function should extend [`Distance`](../../math/distance/trait.Distance.html) trait.
131+
/// See [`Distances`](../../math/distance/struct.Distances.html) for a list of available functions.
132+
pub distance: Vec<D>,
133+
#[cfg_attr(feature = "serde", serde(default))]
134+
/// The number of samples (or total weight) in a neighborhood for a point to be considered as a core point.
135+
pub min_samples: Vec<usize>,
136+
#[cfg_attr(feature = "serde", serde(default))]
137+
/// The maximum distance between two samples for one to be considered as in the neighborhood of the other.
138+
pub eps: Vec<f64>,
139+
#[cfg_attr(feature = "serde", serde(default))]
140+
/// KNN algorithm to use.
141+
pub algorithm: Vec<KNNAlgorithmName>,
142+
_phantom_t: PhantomData<T>,
143+
}
144+
145+
/// DBSCAN grid search iterator
146+
pub struct DBSCANSearchParametersIterator<T: Number, D: Distance<Vec<T>>> {
147+
dbscan_search_parameters: DBSCANSearchParameters<T, D>,
148+
current_distance: usize,
149+
current_min_samples: usize,
150+
current_eps: usize,
151+
current_algorithm: usize,
152+
}
153+
154+
impl<T: Number, D: Distance<Vec<T>>> IntoIterator for DBSCANSearchParameters<T, D> {
155+
type Item = DBSCANParameters<T, D>;
156+
type IntoIter = DBSCANSearchParametersIterator<T, D>;
157+
158+
fn into_iter(self) -> Self::IntoIter {
159+
DBSCANSearchParametersIterator {
160+
dbscan_search_parameters: self,
161+
current_distance: 0,
162+
current_min_samples: 0,
163+
current_eps: 0,
164+
current_algorithm: 0,
165+
}
166+
}
167+
}
168+
169+
impl<T: Number, D: Distance<Vec<T>>> Iterator for DBSCANSearchParametersIterator<T, D> {
170+
type Item = DBSCANParameters<T, D>;
171+
172+
fn next(&mut self) -> Option<Self::Item> {
173+
if self.current_distance == self.dbscan_search_parameters.distance.len()
174+
&& self.current_min_samples == self.dbscan_search_parameters.min_samples.len()
175+
&& self.current_eps == self.dbscan_search_parameters.eps.len()
176+
&& self.current_algorithm == self.dbscan_search_parameters.algorithm.len()
177+
{
178+
return None;
179+
}
180+
181+
let next = DBSCANParameters {
182+
distance: self.dbscan_search_parameters.distance[self.current_distance].clone(),
183+
min_samples: self.dbscan_search_parameters.min_samples[self.current_min_samples],
184+
eps: self.dbscan_search_parameters.eps[self.current_eps],
185+
algorithm: self.dbscan_search_parameters.algorithm[self.current_algorithm].clone(),
186+
_phantom_t: PhantomData,
187+
};
188+
189+
if self.current_distance + 1 < self.dbscan_search_parameters.distance.len() {
190+
self.current_distance += 1;
191+
} else if self.current_min_samples + 1 < self.dbscan_search_parameters.min_samples.len() {
192+
self.current_distance = 0;
193+
self.current_min_samples += 1;
194+
} else if self.current_eps + 1 < self.dbscan_search_parameters.eps.len() {
195+
self.current_distance = 0;
196+
self.current_min_samples = 0;
197+
self.current_eps += 1;
198+
} else if self.current_algorithm + 1 < self.dbscan_search_parameters.algorithm.len() {
199+
self.current_distance = 0;
200+
self.current_min_samples = 0;
201+
self.current_eps = 0;
202+
self.current_algorithm += 1;
203+
} else {
204+
self.current_distance += 1;
205+
self.current_min_samples += 1;
206+
self.current_eps += 1;
207+
self.current_algorithm += 1;
208+
}
209+
210+
Some(next)
211+
}
212+
}
213+
214+
impl<T: Number> Default for DBSCANSearchParameters<T, Euclidian<T>> {
215+
fn default() -> Self {
216+
let default_params = DBSCANParameters::default();
217+
218+
DBSCANSearchParameters {
219+
distance: vec![default_params.distance],
220+
min_samples: vec![default_params.min_samples],
221+
eps: vec![default_params.eps],
222+
algorithm: vec![default_params.algorithm],
223+
_phantom_t: PhantomData,
224+
}
225+
}
226+
}
227+
118228
impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: Distance<Vec<TX>>> PartialEq
119229
for DBSCAN<TX, TY, X, Y, D>
120230
{
@@ -132,7 +242,7 @@ impl<T: Number> Default for DBSCANParameters<T, Euclidian<T>> {
132242
distance: Distances::euclidian(),
133243
min_samples: 5,
134244
eps: 0.5f64,
135-
algorithm: KNNAlgorithmName::CoverTree,
245+
algorithm: KNNAlgorithmName::default(),
136246
_phantom_t: PhantomData,
137247
}
138248
}
@@ -292,6 +402,29 @@ mod tests {
292402
#[cfg(feature = "serde")]
293403
use crate::metrics::distance::euclidian::Euclidian;
294404

405+
#[test]
406+
fn search_parameters() {
407+
let parameters = DBSCANSearchParameters {
408+
min_samples: vec![10, 100],
409+
eps: vec![1., 2.],
410+
..Default::default()
411+
};
412+
let mut iter = parameters.into_iter();
413+
let next = iter.next().unwrap();
414+
assert_eq!(next.min_samples, 10);
415+
assert_eq!(next.eps, 1.);
416+
let next = iter.next().unwrap();
417+
assert_eq!(next.min_samples, 100);
418+
assert_eq!(next.eps, 1.);
419+
let next = iter.next().unwrap();
420+
assert_eq!(next.min_samples, 10);
421+
assert_eq!(next.eps, 2.);
422+
let next = iter.next().unwrap();
423+
assert_eq!(next.min_samples, 100);
424+
assert_eq!(next.eps, 2.);
425+
assert!(iter.next().is_none());
426+
}
427+
295428
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
296429
#[test]
297430
fn fit_predict_dbscan() {

0 commit comments

Comments
 (0)