
grid search #154

Merged
merged 2 commits into from
Sep 19, 2022
138 changes: 138 additions & 0 deletions src/linear/elastic_net.rs
@@ -135,6 +135,121 @@ impl<T: RealNumber> Default for ElasticNetParameters<T> {
}
}

/// ElasticNet grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct ElasticNetSearchParameters<T: RealNumber> {
/// Regularization parameter.
pub alpha: Vec<T>,
/// The elastic net mixing parameter, with 0 <= l1_ratio <= 1.
/// For l1_ratio = 0 the penalty is an L2 penalty.
/// For l1_ratio = 1 it is an L1 penalty. For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
pub l1_ratio: Vec<T>,
/// If true, the regressors X will be normalized before regression by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
/// The tolerance for the optimization
pub tol: Vec<T>,
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}

/// ElasticNet grid search iterator
pub struct ElasticNetSearchParametersIterator<T: RealNumber> {
lasso_regression_search_parameters: ElasticNetSearchParameters<T>,
current_alpha: usize,
current_l1_ratio: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
}

impl<T: RealNumber> IntoIterator for ElasticNetSearchParameters<T> {
type Item = ElasticNetParameters<T>;
type IntoIter = ElasticNetSearchParametersIterator<T>;

fn into_iter(self) -> Self::IntoIter {
ElasticNetSearchParametersIterator {
lasso_regression_search_parameters: self,
current_alpha: 0,
current_l1_ratio: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
}
}
}

impl<T: RealNumber> Iterator for ElasticNetSearchParametersIterator<T> {
type Item = ElasticNetParameters<T>;

fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
&& self.current_l1_ratio == self.lasso_regression_search_parameters.l1_ratio.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
{
return None;
}

let next = ElasticNetParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
l1_ratio: self.lasso_regression_search_parameters.l1_ratio[self.current_l1_ratio],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
};

if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_l1_ratio + 1 < self.lasso_regression_search_parameters.l1_ratio.len()
{
self.current_alpha = 0;
self.current_l1_ratio += 1;
} else if self.current_normalize + 1
< self.lasso_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
{
self.current_alpha = 0;
self.current_l1_ratio = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else {
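// All index combinations have been emitted: push every index past the
// end of its Vec so the exhaustion check at the top of next() returns
// None on the following call.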
self.current_alpha += 1;
self.current_l1_ratio += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
}

Some(next)
}
}
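The iterator walks the grid like a mixed-radix counter (an odometer): alpha is the fastest-moving digit and max_iter the slowest; a digit that cannot advance resets to zero and carries into the next one. A condensed sketch of the same idea over an index vector, to make the carry and the exhaustion-by-overflow trick explicit (a hypothetical helper for illustration, not part of this crate; like the iterators here, it assumes every length is non-zero):

fn advance(idx: &mut [usize], lens: &[usize]) -> bool {
    for (i, &len) in idx.iter_mut().zip(lens) {
        if *i + 1 < len {
            *i += 1; // this digit advances; lower digits were reset on the way up
            return true;
        }
        *i = 0; // overflow: reset this digit and carry into the next one
    }
    // Every digit overflowed: park all indices at their lengths so the
    // caller can detect exhaustion, mirroring the final else-branch above.
    for (i, &len) in idx.iter_mut().zip(lens) {
        *i = len;
    }
    false
}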

impl<T: RealNumber> Default for ElasticNetSearchParameters<T> {
fn default() -> Self {
let default_params = ElasticNetParameters::default();

ElasticNetSearchParameters {
alpha: vec![default_params.alpha],
l1_ratio: vec![default_params.l1_ratio],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
}
}
}
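With IntoIterator in place, a grid search is just a for-loop over parameter candidates. Below is a minimal sketch, assuming the smartcore DenseMatrix API and the mean_absolute_error metric that the tests in this file already use; the toy data is illustrative only, and in practice candidates would be scored with cross-validation rather than in-sample error:

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::elastic_net::{ElasticNet, ElasticNetSearchParameters};
use smartcore::metrics::mean_absolute_error;

fn grid_search() {
    // Toy data: 4 observations, 2 features (illustrative values only).
    let x = DenseMatrix::from_2d_array(&[&[1., 2.], &[2., 3.], &[3., 5.], &[4., 7.]]);
    let y = vec![1., 2., 3., 4.];

    let grid = ElasticNetSearchParameters {
        alpha: vec![0.1, 1.0],
        l1_ratio: vec![0.2, 0.8],
        ..Default::default()
    };

    // Fit one model per combination and keep the parameters with the lowest MAE.
    let mut best = (f64::INFINITY, None);
    for params in grid {
        let model = ElasticNet::fit(&x, &y, params.clone()).unwrap();
        let mae = mean_absolute_error(&y, &model.predict(&x).unwrap());
        if mae < best.0 {
            best = (mae, Some(params));
        }
    }
    println!("best MAE {} with {:?}", best.0, best.1);
}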

impl<T: RealNumber, M: Matrix<T>> PartialEq for ElasticNet<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
@@ -291,6 +406,29 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;

#[test]
fn search_parameters() {
let parameters = ElasticNetSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn elasticnet_longley() {
122 changes: 122 additions & 0 deletions src/linear/lasso.rs
@@ -112,6 +112,105 @@ impl<T: RealNumber, M: Matrix<T>> Predictor<M, M::RowVector> for Lasso<T, M> {
}
}

/// Lasso grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LassoSearchParameters<T: RealNumber> {
/// Controls the strength of the penalty to the loss function.
pub alpha: Vec<T>,
/// If true the regressors X will be normalized before regression
/// by subtracting the mean and dividing by the standard deviation.
pub normalize: Vec<bool>,
/// The tolerance for the optimization
pub tol: Vec<T>,
/// The maximum number of iterations
pub max_iter: Vec<usize>,
}

/// Lasso grid search iterator
pub struct LassoSearchParametersIterator<T: RealNumber> {
lasso_regression_search_parameters: LassoSearchParameters<T>,
current_alpha: usize,
current_normalize: usize,
current_tol: usize,
current_max_iter: usize,
}

impl<T: RealNumber> IntoIterator for LassoSearchParameters<T> {
type Item = LassoParameters<T>;
type IntoIter = LassoSearchParametersIterator<T>;

fn into_iter(self) -> Self::IntoIter {
LassoSearchParametersIterator {
lasso_regression_search_parameters: self,
current_alpha: 0,
current_normalize: 0,
current_tol: 0,
current_max_iter: 0,
}
}
}

impl<T: RealNumber> Iterator for LassoSearchParametersIterator<T> {
type Item = LassoParameters<T>;

fn next(&mut self) -> Option<Self::Item> {
if self.current_alpha == self.lasso_regression_search_parameters.alpha.len()
&& self.current_normalize == self.lasso_regression_search_parameters.normalize.len()
&& self.current_tol == self.lasso_regression_search_parameters.tol.len()
&& self.current_max_iter == self.lasso_regression_search_parameters.max_iter.len()
{
return None;
}

let next = LassoParameters {
alpha: self.lasso_regression_search_parameters.alpha[self.current_alpha],
normalize: self.lasso_regression_search_parameters.normalize[self.current_normalize],
tol: self.lasso_regression_search_parameters.tol[self.current_tol],
max_iter: self.lasso_regression_search_parameters.max_iter[self.current_max_iter],
};

if self.current_alpha + 1 < self.lasso_regression_search_parameters.alpha.len() {
self.current_alpha += 1;
} else if self.current_normalize + 1
< self.lasso_regression_search_parameters.normalize.len()
{
self.current_alpha = 0;
self.current_normalize += 1;
} else if self.current_tol + 1 < self.lasso_regression_search_parameters.tol.len() {
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol += 1;
} else if self.current_max_iter + 1 < self.lasso_regression_search_parameters.max_iter.len()
{
self.current_alpha = 0;
self.current_normalize = 0;
self.current_tol = 0;
self.current_max_iter += 1;
} else {
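// Same exhaustion trick as in ElasticNetSearchParametersIterator: move
// every index out of range so the next call returns None.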
self.current_alpha += 1;
self.current_normalize += 1;
self.current_tol += 1;
self.current_max_iter += 1;
}

Some(next)
}
}

impl<T: RealNumber> Default for LassoSearchParameters<T> {
fn default() -> Self {
let default_params = LassoParameters::default();

LassoSearchParameters {
alpha: vec![default_params.alpha],
normalize: vec![default_params.normalize],
tol: vec![default_params.tol],
max_iter: vec![default_params.max_iter],
}
}
}
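The number of candidates an iterator yields is the product of its field lengths. A small sketch under the same API assumptions as the ElasticNet example above:

use smartcore::linear::lasso::LassoSearchParameters;

fn count_candidates() {
    let grid = LassoSearchParameters {
        alpha: vec![0.1, 0.5, 1.0],
        normalize: vec![true, false],
        ..Default::default()
    };
    // 3 alphas x 2 normalize flags x 1 tol x 1 max_iter = 6 combinations.
    assert_eq!(grid.into_iter().count(), 6);
}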

impl<T: RealNumber, M: Matrix<T>> Lasso<T, M> {
/// Fits Lasso regression to your data.
/// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
@@ -226,6 +325,29 @@ mod tests {
use crate::linalg::naive::dense_matrix::*;
use crate::metrics::mean_absolute_error;

#[test]
fn search_parameters() {
let parameters = LassoSearchParameters {
alpha: vec![0., 1.],
max_iter: vec![10, 100],
..Default::default()
};
let mut iter = parameters.into_iter();
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 10);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 0.);
assert_eq!(next.max_iter, 100);
let next = iter.next().unwrap();
assert_eq!(next.alpha, 1.);
assert_eq!(next.max_iter, 100);
assert!(iter.next().is_none());
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn lasso_fit_predict() {
70 changes: 69 additions & 1 deletion src/linear/linear_regression.rs
@@ -71,7 +71,7 @@ use crate::linalg::Matrix;
use crate::math::num::RealNumber;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Eq, PartialEq)]
/// Approach to use for estimation of regression coefficients. QR is more efficient but SVD is more stable.
pub enum LinearRegressionSolverName {
/// QR decomposition, see [QR](../../linalg/qr/index.html)
@@ -113,6 +113,60 @@ impl Default for LinearRegressionParameters {
}
}

/// Linear Regression grid search parameters
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct LinearRegressionSearchParameters {
/// Solver to use for estimation of regression coefficients.
pub solver: Vec<LinearRegressionSolverName>,
}

/// Linear Regression grid search iterator
pub struct LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: LinearRegressionSearchParameters,
current_solver: usize,
}

impl IntoIterator for LinearRegressionSearchParameters {
type Item = LinearRegressionParameters;
type IntoIter = LinearRegressionSearchParametersIterator;

fn into_iter(self) -> Self::IntoIter {
LinearRegressionSearchParametersIterator {
linear_regression_search_parameters: self,
current_solver: 0,
}
}
}

impl Iterator for LinearRegressionSearchParametersIterator {
type Item = LinearRegressionParameters;

fn next(&mut self) -> Option<Self::Item> {
if self.current_solver == self.linear_regression_search_parameters.solver.len() {
return None;
}

let next = LinearRegressionParameters {
solver: self.linear_regression_search_parameters.solver[self.current_solver].clone(),
};

self.current_solver += 1;

Some(next)
}
}

impl Default for LinearRegressionSearchParameters {
fn default() -> Self {
let default_params = LinearRegressionParameters::default();

LinearRegressionSearchParameters {
solver: vec![default_params.solver],
}
}
}
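The Eq, PartialEq derive added to LinearRegressionSolverName above is what lets the search_parameters test below compare solver variants with assert_eq!. Solver choice is the only tunable here, so the grid is one-dimensional. A minimal end-to-end sketch, again assuming the smartcore API used in this file's tests (toy data that is exactly linear, in-sample checking for brevity):

use smartcore::linalg::naive::dense_matrix::DenseMatrix;
use smartcore::linear::linear_regression::{
    LinearRegression, LinearRegressionSearchParameters, LinearRegressionSolverName,
};

fn try_solvers() {
    // y = x1 + 2*x2 + 3, so both solvers should fit it exactly.
    let x = DenseMatrix::from_2d_array(&[&[1., 1.], &[1., 2.], &[2., 2.], &[2., 3.]]);
    let y = vec![6., 8., 9., 11.];

    let grid = LinearRegressionSearchParameters {
        solver: vec![
            LinearRegressionSolverName::QR,
            LinearRegressionSolverName::SVD,
        ],
    };
    for params in grid {
        // QR is cheaper; SVD is more robust to ill-conditioned X.
        let model = LinearRegression::fit(&x, &y, params).unwrap();
        let y_hat = model.predict(&x).unwrap();
        assert_eq!(y_hat.len(), y.len());
    }
}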

impl<T: RealNumber, M: Matrix<T>> PartialEq for LinearRegression<T, M> {
fn eq(&self, other: &Self) -> bool {
self.coefficients == other.coefficients
@@ -200,6 +254,20 @@ mod tests {
use super::*;
use crate::linalg::naive::dense_matrix::*;

#[test]
fn search_parameters() {
let parameters = LinearRegressionSearchParameters {
solver: vec![
LinearRegressionSolverName::QR,
LinearRegressionSolverName::SVD,
],
};
let mut iter = parameters.into_iter();
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::QR);
assert_eq!(iter.next().unwrap().solver, LinearRegressionSolverName::SVD);
assert!(iter.next().is_none());
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn ols_fit_predict() {