-
-
Notifications
You must be signed in to change notification settings - Fork 249
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat: Implement Random Projections (#332)
* Add random projections algorithms for dimensionality reduction. Contains two algorithms based on variants on the Johnson-lindenstrauss lemma: - Random projections with Gaussian coefficients - Sparse random projections with +/- 1 coefficients (multiplied by a scaling factor). * Update readme * Add RNG to random projection structs RNG defaults to Xoshiro256Plus if not provided by user. Also added tests for minimum dimension using values from scikit-learn. * Check that random projections actually reduce the dimension of the data. * Use fixed dimension in error tests * Refactor random projections code
- Loading branch information
Showing
11 changed files
with
726 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
algorithms/linfa-reduction/examples/gaussian_projection.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
use std::{error::Error, time::Instant}; | ||
|
||
use linfa::prelude::*; | ||
use linfa_reduction::random_projection::GaussianRandomProjection; | ||
use linfa_trees::{DecisionTree, SplitQuality}; | ||
|
||
use mnist::{MnistBuilder, NormalizedMnist}; | ||
use ndarray::{Array1, Array2}; | ||
use rand::SeedableRng; | ||
use rand_xoshiro::Xoshiro256Plus; | ||
|
||
/// Train a Decision tree on the MNIST data set, with and without dimensionality reduction. | ||
fn main() -> Result<(), Box<dyn Error>> { | ||
// Parameters | ||
let train_sz = 10_000usize; | ||
let test_sz = 1_000usize; | ||
let reduced_dim = 100; | ||
let rng = Xoshiro256Plus::seed_from_u64(42); | ||
|
||
let NormalizedMnist { | ||
trn_img, | ||
trn_lbl, | ||
tst_img, | ||
tst_lbl, | ||
.. | ||
} = MnistBuilder::new() | ||
.label_format_digit() | ||
.training_set_length(train_sz as u32) | ||
.test_set_length(test_sz as u32) | ||
.download_and_extract() | ||
.finalize() | ||
.normalize(); | ||
|
||
let train_data = Array2::from_shape_vec((train_sz, 28 * 28), trn_img)?; | ||
let train_labels: Array1<usize> = | ||
Array1::from_shape_vec(train_sz, trn_lbl)?.map(|x| *x as usize); | ||
let train_dataset = Dataset::new(train_data, train_labels); | ||
|
||
let test_data = Array2::from_shape_vec((test_sz, 28 * 28), tst_img)?; | ||
let test_labels: Array1<usize> = Array1::from_shape_vec(test_sz, tst_lbl)?.map(|x| *x as usize); | ||
|
||
let params = DecisionTree::params() | ||
.split_quality(SplitQuality::Gini) | ||
.max_depth(Some(10)); | ||
|
||
println!("Training non-reduced model..."); | ||
let start = Instant::now(); | ||
let model: DecisionTree<f32, usize> = params.fit(&train_dataset)?; | ||
|
||
let end = start.elapsed(); | ||
let pred_y = model.predict(&test_data); | ||
let cm = pred_y.confusion_matrix(&test_labels)?; | ||
println!("Non-reduced model precision: {}%", 100.0 * cm.accuracy()); | ||
println!("Training time: {:.2}s\n", end.as_secs_f32()); | ||
|
||
println!("Training reduced model..."); | ||
let start = Instant::now(); | ||
// Compute the random projection and train the model on the reduced dataset. | ||
let proj = GaussianRandomProjection::<f32>::params_with_rng(rng) | ||
.target_dim(reduced_dim) | ||
.fit(&train_dataset)?; | ||
let reduced_train_ds = proj.transform(&train_dataset); | ||
let reduced_test_data = proj.transform(&test_data); | ||
let model_reduced: DecisionTree<f32, usize> = params.fit(&reduced_train_ds)?; | ||
|
||
let end = start.elapsed(); | ||
let pred_reduced = model_reduced.predict(&reduced_test_data); | ||
let cm_reduced = pred_reduced.confusion_matrix(&test_labels)?; | ||
println!( | ||
"Reduced model precision: {}%", | ||
100.0 * cm_reduced.accuracy() | ||
); | ||
println!("Reduction + training time: {:.2}s", end.as_secs_f32()); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
use std::{error::Error, time::Instant}; | ||
|
||
use linfa::prelude::*; | ||
use linfa_reduction::random_projection::SparseRandomProjection; | ||
use linfa_trees::{DecisionTree, SplitQuality}; | ||
|
||
use mnist::{MnistBuilder, NormalizedMnist}; | ||
use ndarray::{Array1, Array2}; | ||
use rand::SeedableRng; | ||
use rand_xoshiro::Xoshiro256Plus; | ||
|
||
/// Train a Decision tree on the MNIST data set, with and without dimensionality reduction. | ||
fn main() -> Result<(), Box<dyn Error>> { | ||
// Parameters | ||
let train_sz = 10_000usize; | ||
let test_sz = 1_000usize; | ||
let reduced_dim = 100; | ||
let rng = Xoshiro256Plus::seed_from_u64(42); | ||
|
||
let NormalizedMnist { | ||
trn_img, | ||
trn_lbl, | ||
tst_img, | ||
tst_lbl, | ||
.. | ||
} = MnistBuilder::new() | ||
.label_format_digit() | ||
.training_set_length(train_sz as u32) | ||
.test_set_length(test_sz as u32) | ||
.download_and_extract() | ||
.finalize() | ||
.normalize(); | ||
|
||
let train_data = Array2::from_shape_vec((train_sz, 28 * 28), trn_img)?; | ||
let train_labels: Array1<usize> = | ||
Array1::from_shape_vec(train_sz, trn_lbl)?.map(|x| *x as usize); | ||
let train_dataset = Dataset::new(train_data, train_labels); | ||
|
||
let test_data = Array2::from_shape_vec((test_sz, 28 * 28), tst_img)?; | ||
let test_labels: Array1<usize> = Array1::from_shape_vec(test_sz, tst_lbl)?.map(|x| *x as usize); | ||
|
||
let params = DecisionTree::params() | ||
.split_quality(SplitQuality::Gini) | ||
.max_depth(Some(10)); | ||
|
||
println!("Training non-reduced model..."); | ||
let start = Instant::now(); | ||
let model: DecisionTree<f32, usize> = params.fit(&train_dataset)?; | ||
|
||
let end = start.elapsed(); | ||
let pred_y = model.predict(&test_data); | ||
let cm = pred_y.confusion_matrix(&test_labels)?; | ||
println!("Non-reduced model precision: {}%", 100.0 * cm.accuracy()); | ||
println!("Training time: {:.2}s\n", end.as_secs_f32()); | ||
|
||
println!("Training reduced model..."); | ||
let start = Instant::now(); | ||
// Compute the random projection and train the model on the reduced dataset. | ||
let proj = SparseRandomProjection::<f32>::params_with_rng(rng) | ||
.target_dim(reduced_dim) | ||
.fit(&train_dataset)?; | ||
let reduced_train_ds = proj.transform(&train_dataset); | ||
let reduced_test_data = proj.transform(&test_data); | ||
let model_reduced: DecisionTree<f32, usize> = params.fit(&reduced_train_ds)?; | ||
|
||
let end = start.elapsed(); | ||
let pred_reduced = model_reduced.predict(&reduced_test_data); | ||
let cm_reduced = pred_reduced.confusion_matrix(&test_labels)?; | ||
println!( | ||
"Reduced model precision: {}%", | ||
100.0 * cm_reduced.accuracy() | ||
); | ||
println!("Reduction + training time: {:.2}s", end.as_secs_f32()); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
169 changes: 169 additions & 0 deletions
169
algorithms/linfa-reduction/src/random_projection/algorithms.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
use std::marker::PhantomData; | ||
|
||
use linfa::{ | ||
dataset::{AsTargets, FromTargetArray}, | ||
prelude::Records, | ||
traits::{Fit, Transformer}, | ||
DatasetBase, Float, | ||
}; | ||
use ndarray::{linalg::Dot, Array2, ArrayBase, Data, Ix2}; | ||
|
||
use rand::{prelude::Distribution, Rng, SeedableRng}; | ||
use rand_xoshiro::Xoshiro256Plus; | ||
|
||
use super::hyperparams::RandomProjectionParamsInner; | ||
use super::{common::johnson_lindenstrauss_min_dim, methods::ProjectionMethod}; | ||
use super::{RandomProjectionParams, RandomProjectionValidParams}; | ||
use crate::ReductionError; | ||
|
||
/// Embedding via random projection | ||
pub struct RandomProjection<Proj: ProjectionMethod, F: Float> | ||
where | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
projection: Proj::ProjectionMatrix<F>, | ||
} | ||
|
||
impl<F, Proj, Rec, T, R> Fit<Rec, T, ReductionError> for RandomProjectionValidParams<Proj, R> | ||
where | ||
F: Float, | ||
Proj: ProjectionMethod, | ||
Rec: Records<Elem = F>, | ||
R: Rng + Clone, | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
type Object = RandomProjection<Proj, F>; | ||
|
||
fn fit(&self, dataset: &linfa::DatasetBase<Rec, T>) -> Result<Self::Object, ReductionError> { | ||
let n_samples = dataset.nsamples(); | ||
let n_features = dataset.nfeatures(); | ||
let mut rng = self.rng.clone(); | ||
|
||
let n_dims = match &self.params { | ||
RandomProjectionParamsInner::Dimension { target_dim } => *target_dim, | ||
RandomProjectionParamsInner::Epsilon { eps } => { | ||
johnson_lindenstrauss_min_dim(n_samples, *eps) | ||
} | ||
}; | ||
|
||
if n_dims > n_features { | ||
return Err(ReductionError::DimensionIncrease(n_dims, n_features)); | ||
} | ||
|
||
let projection = Proj::generate_matrix(n_features, n_dims, &mut rng)?; | ||
|
||
Ok(RandomProjection { projection }) | ||
} | ||
} | ||
|
||
impl<Proj: ProjectionMethod, F: Float> RandomProjection<Proj, F> | ||
where | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
/// Create new parameters for a [`RandomProjection`] with default value | ||
/// `eps = 0.1` and a [`Xoshiro256Plus`] RNG. | ||
pub fn params() -> RandomProjectionParams<Proj, Xoshiro256Plus> { | ||
RandomProjectionParams(RandomProjectionValidParams { | ||
params: RandomProjectionParamsInner::Epsilon { eps: 0.1 }, | ||
rng: Xoshiro256Plus::seed_from_u64(42), | ||
marker: PhantomData, | ||
}) | ||
} | ||
|
||
/// Create new parameters for a [`RandomProjection`] with default values | ||
/// `eps = 0.1` and the provided [`Rng`]. | ||
pub fn params_with_rng<R>(rng: R) -> RandomProjectionParams<Proj, R> | ||
where | ||
R: Rng + Clone, | ||
{ | ||
RandomProjectionParams(RandomProjectionValidParams { | ||
params: RandomProjectionParamsInner::Epsilon { eps: 0.1 }, | ||
rng, | ||
marker: PhantomData, | ||
}) | ||
} | ||
} | ||
|
||
impl<Proj, F, D> Transformer<&ArrayBase<D, Ix2>, Array2<F>> for RandomProjection<Proj, F> | ||
where | ||
Proj: ProjectionMethod, | ||
F: Float, | ||
D: Data<Elem = F>, | ||
ArrayBase<D, Ix2>: Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>, | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
/// Compute the embedding of a two-dimensional array | ||
fn transform(&self, x: &ArrayBase<D, Ix2>) -> Array2<F> { | ||
x.dot(&self.projection) | ||
} | ||
} | ||
|
||
impl<Proj, F, D> Transformer<ArrayBase<D, Ix2>, Array2<F>> for RandomProjection<Proj, F> | ||
where | ||
Proj: ProjectionMethod, | ||
F: Float, | ||
D: Data<Elem = F>, | ||
ArrayBase<D, Ix2>: Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>, | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
/// Compute the embedding of a two-dimensional array | ||
fn transform(&self, x: ArrayBase<D, Ix2>) -> Array2<F> { | ||
self.transform(&x) | ||
} | ||
} | ||
|
||
impl<Proj, F, T> Transformer<DatasetBase<Array2<F>, T>, DatasetBase<Array2<F>, T>> | ||
for RandomProjection<Proj, F> | ||
where | ||
Proj: ProjectionMethod, | ||
F: Float, | ||
T: AsTargets, | ||
for<'a> ArrayBase<ndarray::ViewRepr<&'a F>, Ix2>: | ||
Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>, | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
/// Compute the embedding of a dataset | ||
/// | ||
/// # Parameter | ||
/// | ||
/// * `data`: a dataset | ||
/// | ||
/// # Returns | ||
/// | ||
/// New dataset, with data equal to the embedding of the input data | ||
fn transform(&self, data: DatasetBase<Array2<F>, T>) -> DatasetBase<Array2<F>, T> { | ||
let new_records = self.transform(data.records().view()); | ||
|
||
DatasetBase::new(new_records, data.targets) | ||
} | ||
} | ||
|
||
impl<'a, Proj, F, L, T> Transformer<&'a DatasetBase<Array2<F>, T>, DatasetBase<Array2<F>, T::View>> | ||
for RandomProjection<Proj, F> | ||
where | ||
Proj: ProjectionMethod, | ||
F: Float, | ||
L: 'a, | ||
T: AsTargets<Elem = L> + FromTargetArray<'a>, | ||
for<'b> ArrayBase<ndarray::ViewRepr<&'b F>, Ix2>: | ||
Dot<Proj::ProjectionMatrix<F>, Output = Array2<F>>, | ||
Proj::RandomDistribution: Distribution<F>, | ||
{ | ||
/// Compute the embedding of a dataset | ||
/// | ||
/// # Parameter | ||
/// | ||
/// * `data`: a dataset | ||
/// | ||
/// # Returns | ||
/// | ||
/// New dataset, with data equal to the embedding of the input data | ||
fn transform(&self, data: &'a DatasetBase<Array2<F>, T>) -> DatasetBase<Array2<F>, T::View> { | ||
let new_records = self.transform(data.records().view()); | ||
|
||
DatasetBase::new( | ||
new_records, | ||
T::new_targets_view(AsTargets::as_targets(data)), | ||
) | ||
} | ||
} |
Oops, something went wrong.