Skip to content

Commit 1e85bee

Browse files
author
Montana Low
committed
create defaults for enums
1 parent f34d2dc commit 1e85bee

File tree

12 files changed

+49
-11
lines changed

12 files changed

+49
-11
lines changed

src/algorithm/neighbour/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ pub enum KNNAlgorithmName {
5959
CoverTree,
6060
}
6161

62+
impl Default for KNNAlgorithmName {
63+
fn default() -> Self {
64+
KNNAlgorithmName::CoverTree
65+
}
66+
}
67+
6268
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6369
#[derive(Debug)]
6470
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {

src/cluster/dbscan.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub struct DBSCAN<T: RealNumber, D: Distance<Vec<T>, T>> {
6565
eps: T,
6666
}
6767

68+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
6869
#[derive(Debug, Clone)]
6970
/// DBSCAN clustering algorithm parameters
7071
pub struct DBSCANParameters<T: RealNumber, D: Distance<Vec<T>, T>> {
@@ -229,7 +230,7 @@ impl<T: RealNumber> Default for DBSCANParameters<T, Euclidian> {
229230
distance: Distances::euclidian(),
230231
min_samples: 5,
231232
eps: T::half(),
232-
algorithm: KNNAlgorithmName::CoverTree,
233+
algorithm: KNNAlgorithmName::default(),
233234
}
234235
}
235236
}

src/cluster/kmeans.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ impl<T: RealNumber> PartialEq for KMeans<T> {
102102
}
103103
}
104104

105+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
105106
#[derive(Debug, Clone)]
106107
/// K-Means clustering algorithm parameters
107108
pub struct KMeansParameters {

src/decomposition/pca.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for PCA<T, M> {
8383
}
8484
}
8585

86+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8687
#[derive(Debug, Clone)]
8788
/// PCA parameters
8889
pub struct PCAParameters {

src/decomposition/svd.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ impl<T: RealNumber, M: Matrix<T>> PartialEq for SVD<T, M> {
6969
}
7070
}
7171

72+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
7273
#[derive(Debug, Clone)]
7374
/// SVD parameters
7475
pub struct SVDParameters {

src/linear/linear_regression.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ pub enum LinearRegressionSolverName {
8080
SVD,
8181
}
8282

83+
impl Default for LinearRegressionSolverName {
84+
fn default() -> Self {
85+
LinearRegressionSolverName::SVD
86+
}
87+
}
88+
8389
/// Linear Regression parameters
8490
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8591
#[derive(Debug, Clone)]
@@ -109,7 +115,7 @@ impl LinearRegressionParameters {
109115
impl Default for LinearRegressionParameters {
110116
fn default() -> Self {
111117
LinearRegressionParameters {
112-
solver: LinearRegressionSolverName::SVD,
118+
solver: LinearRegressionSolverName::default(),
113119
}
114120
}
115121
}

src/linear/logistic_regression.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ pub enum LogisticRegressionSolverName {
7575
LBFGS,
7676
}
7777

78+
impl Default for LogisticRegressionSolverName {
79+
fn default() -> Self {
80+
LogisticRegressionSolverName::LBFGS
81+
}
82+
}
83+
7884
/// Logistic Regression parameters
7985
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8086
#[derive(Debug, Clone)]
@@ -208,7 +214,7 @@ impl<T: RealNumber> LogisticRegressionParameters<T> {
208214
impl<T: RealNumber> Default for LogisticRegressionParameters<T> {
209215
fn default() -> Self {
210216
LogisticRegressionParameters {
211-
solver: LogisticRegressionSolverName::LBFGS,
217+
solver: LogisticRegressionSolverName::default(),
212218
alpha: T::zero(),
213219
}
214220
}

src/linear/ridge_regression.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,18 @@ use crate::math::num::RealNumber;
7171
#[derive(Debug, Clone, Eq, PartialEq)]
7272
/// Approach to use for estimation of regression coefficients. Cholesky is more efficient but SVD is more stable.
7373
pub enum RidgeRegressionSolverName {
74-
#[cfg_attr(feature = "serde", serde(default))]
7574
/// Cholesky decomposition, see [Cholesky](../../linalg/cholesky/index.html)
7675
Cholesky,
77-
#[cfg_attr(feature = "serde", serde(default))]
7876
/// SVD decomposition, see [SVD](../../linalg/svd/index.html)
7977
SVD,
8078
}
8179

80+
impl Default for RidgeRegressionSolverName {
81+
fn default() -> Self {
82+
RidgeRegressionSolverName::Cholesky
83+
}
84+
}
85+
8286
/// Ridge Regression parameters
8387
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
8488
#[derive(Debug, Clone)]
@@ -209,7 +213,7 @@ impl<T: RealNumber> RidgeRegressionParameters<T> {
209213
impl<T: RealNumber> Default for RidgeRegressionParameters<T> {
210214
fn default() -> Self {
211215
RidgeRegressionParameters {
212-
solver: RidgeRegressionSolverName::Cholesky,
216+
solver: RidgeRegressionSolverName::default(),
213217
alpha: T::one(),
214218
normalize: true,
215219
}

src/neighbors/knn_classifier.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ impl<T: RealNumber> Default for KNNClassifierParameters<T, Euclidian> {
116116
fn default() -> Self {
117117
KNNClassifierParameters {
118118
distance: Distances::euclidian(),
119-
algorithm: KNNAlgorithmName::CoverTree,
120-
weight: KNNWeightFunction::Uniform,
119+
algorithm: KNNAlgorithmName::default(),
120+
weight: KNNWeightFunction::default(),
121121
k: 3,
122122
t: PhantomData,
123123
}

src/neighbors/knn_regressor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ impl<T: RealNumber> Default for KNNRegressorParameters<T, Euclidian> {
118118
fn default() -> Self {
119119
KNNRegressorParameters {
120120
distance: Distances::euclidian(),
121-
algorithm: KNNAlgorithmName::CoverTree,
122-
weight: KNNWeightFunction::Uniform,
121+
algorithm: KNNAlgorithmName::default(),
122+
weight: KNNWeightFunction::default(),
123123
k: 3,
124124
t: PhantomData,
125125
}

src/neighbors/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ pub enum KNNWeightFunction {
5858
Distance,
5959
}
6060

61+
impl Default for KNNWeightFunction {
62+
fn default() -> Self {
63+
KNNWeightFunction::Uniform
64+
}
65+
}
66+
6167
impl KNNWeightFunction {
6268
fn calc_weights<T: RealNumber>(&self, distances: Vec<T>) -> std::vec::Vec<T> {
6369
match *self {

src/tree/decision_tree_classifier.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ pub enum SplitCriterion {
123123
ClassificationError,
124124
}
125125

126+
impl Default for SplitCriterion {
127+
fn default() -> Self {
128+
SplitCriterion::Gini
129+
}
130+
}
131+
126132
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
127133
#[derive(Debug)]
128134
struct Node<T: RealNumber> {
@@ -201,7 +207,7 @@ impl DecisionTreeClassifierParameters {
201207
impl Default for DecisionTreeClassifierParameters {
202208
fn default() -> Self {
203209
DecisionTreeClassifierParameters {
204-
criterion: SplitCriterion::Gini,
210+
criterion: SplitCriterion::default(),
205211
max_depth: None,
206212
min_samples_leaf: 1,
207213
min_samples_split: 2,

0 commit comments

Comments
 (0)