Linear decision trees improvements (#60)
* add `Debug` trait to structs
* initial commit for random forest implementation
* WIP first iteration of random forest
* add benches
* set up bench and cleanup
* cleanup
* add `max_n_rows` for single decision tree fitting
* implement random forest feature importance as a collection of features from decision trees
* remove unused variables
* run clippy
* assert test success for feature importance
* clippy and fmt
* store references of nodes to queue
* WIP voting classifier and predictor trait
* implement and test `VotingClassifier` hard voting
* implement `predict_proba` for random forest, with tests
* documentation, examples, cleanup
* implement `LinfaError` for `Predictor` trait
* fix tests and CI/CD pipeline
* rename `predict_classes` to `predict` in logreg for consistency
* implement `ProbabilisticPredictor` wherever needed
* `VotingClassifier` implements the `Predictor` trait
* address PR-43 review comments from Moss
* switch `linfa-trees` to new infrastructure
* experiment with interface
* add argmax ensemble classifier
* run fmt
* add test with random noise
* customize decision trees with weights
* use weighting of dataset
* remove unnecessary casting
* compare weight in splits with hyperparams
* rename `_samples` hyperparams
* fix cargo fmt lint
* silence random forest example for the time being
* add new test for perfectly separable data
* appease clippy
* fix error in test
* use midpoint
* skip equal values until new value is encountered
* run cargo fmt
* add `max_depth` function for decision trees
* add impurity decrease function
* add mean impurity decrease
* add more tests to linear decision trees
* use toy test from sklearn
* use four perfectly separable uniform blobs
* remove the number-of-classes hyper-parameter: it can be estimated from the input data and is therefore unnecessary in the API
* remove ensemble algorithm
* address issue with toy test
* simplify tree inspection methods
* introduce node iterator
* rewrite `max_depth`, `num_leaves`, `features` in iterator syntax
* add max depth testing and hyperparameter validation
* fix parameter syntax in benchmarks
* add tikz export builder
* improve decision tree formatting
* add pruning
* adjust syntax of tikz snippet
* run cargo fmt

Co-authored-by: Francesco Gadaleta <francesco@amethix.com>
Co-authored-by: francesco <francesco.gadaleta@gmail.com>
Co-authored-by: moss <mossbanay@gmail.com>
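For orientation before reading the diff, here is a condensed sketch of the reworked `linfa-trees` interface this commit arrives at. The builder calls (`params()`, `split_quality`, `max_depth`, `min_weight_split`, `min_weight_leaf`, `fit`) and the `features()` helper all appear in the example diff below; `max_depth()` and `num_leaves()` are only named in the change list above, so their zero-argument signatures here are an assumption, not a confirmed API.

use linfa::prelude::*;
use linfa_trees::{DecisionTree, SplitQuality};
use ndarray::{array, Array1, Array2};

fn main() {
    // Tiny stand-in dataset: two well-separated blobs, two features.
    let records: Array2<f64> = array![
        [0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [0.0, 0.1],
        [5.0, 5.0], [5.1, 5.2], [5.2, 5.1], [5.0, 5.1]
    ];
    let targets: Array1<usize> = array![0, 0, 0, 0, 1, 1, 1, 1];
    let dataset = Dataset::new(records, targets);

    // Builder-style hyperparameters; dataset weights replace raw sample counts.
    let model = DecisionTree::params()
        .split_quality(SplitQuality::Gini)
        .max_depth(Some(100))
        .min_weight_split(2.0) // formerly `min_samples_split`
        .min_weight_leaf(1.0)  // formerly `min_samples_leaf`
        .fit(&dataset);

    // Inspection helpers; `max_depth()`/`num_leaves()` signatures are assumed
    // from the change list, not confirmed against the crate.
    println!("depth: {}", model.max_depth());
    println!("leaves: {}", model.num_leaves());
    println!("features used: {:?}", model.features());
}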
1 parent a3eede5 · commit bfa5aeb · 17 changed files with 941 additions and 413 deletions.
@@ -1,122 +1,87 @@
-use linfa_trees::{DecisionTree, DecisionTreeParams, SplitQuality};
-use ndarray::{array, s, Array, Array2, ArrayBase, Data, Ix1, Ix2};
-use ndarray_rand::rand::Rng;
+use std::fs::File;
+use std::io::Write;
+
+use ndarray::{array, stack, Array, Array1, Array2, Axis};
 use ndarray_rand::rand::SeedableRng;
 use ndarray_rand::rand_distr::StandardNormal;
 use ndarray_rand::RandomExt;
 use rand_isaac::Isaac64Rng;
 use std::iter::FromIterator;
 
-/// Given an input matrix `blob_centroids`, with shape `(n_blobs, n_features)`,
-/// generate `blob_size` data points (a "blob") around each of the blob centroids.
-///
-/// More specifically, each blob is formed by `blob_size` points sampled from a normal
-/// distribution centered in the blob centroid with unit variance.
-///
-/// `generate_blobs` can be used to quickly assemble a synthetic dataset to test or
-/// benchmark various clustering algorithms on a best-case scenario input.
-pub fn generate_blobs(
-    blob_size: usize,
-    blob_centroids: &ArrayBase<impl Data<Elem = f64>, Ix2>,
-    rng: &mut impl Rng,
-) -> Array2<f64> {
-    let (n_centroids, n_features) = blob_centroids.dim();
-    let mut blobs: Array2<f64> = Array2::zeros((n_centroids * blob_size, n_features));
-
-    for (blob_index, blob_centroid) in blob_centroids.genrows().into_iter().enumerate() {
-        let blob = generate_blob(blob_size, &blob_centroid, rng);
-
-        let indexes = s![blob_index * blob_size..(blob_index + 1) * blob_size, ..];
-        blobs.slice_mut(indexes).assign(&blob);
-    }
-    blobs
-}
-
-/// Generate `blob_size` data points (a "blob") around `blob_centroid`.
-///
-/// More specifically, the blob is formed by `blob_size` points sampled from a normal
-/// distribution centered in `blob_centroid` with unit variance.
-///
-/// `generate_blob` can be used to quickly assemble a synthetic stereotypical cluster.
-pub fn generate_blob(
-    blob_size: usize,
-    blob_centroid: &ArrayBase<impl Data<Elem = f64>, Ix1>,
-    rng: &mut impl Rng,
-) -> Array2<f64> {
-    let shape = (blob_size, blob_centroid.len());
-    let origin_blob: Array2<f64> = Array::random_using(shape, StandardNormal, rng);
-    origin_blob + blob_centroid
-}
+use linfa::prelude::*;
+use linfa_trees::{DecisionTree, SplitQuality};
 
-fn accuracy(
-    labels: &ArrayBase<impl Data<Elem = u64>, Ix1>,
-    pred: &ArrayBase<impl Data<Elem = u64>, Ix1>,
-) -> f64 {
-    let true_positive: f64 = labels
-        .iter()
-        .zip(pred.iter())
-        .filter(|(x, y)| x == y)
-        .map(|_| 1.0)
-        .sum();
-    true_positive / labels.len() as f64
+fn generate_blobs(means: &[(f64, f64)], samples: usize, mut rng: &mut Isaac64Rng) -> Array2<f64> {
+    let out = means
+        .into_iter()
+        .map(|mean| {
+            Array::random_using((samples, 2), StandardNormal, &mut rng) + array![mean.0, mean.1]
+        })
+        .collect::<Vec<_>>();
+    let out2 = out.iter().map(|x| x.view()).collect::<Vec<_>>();
+
+    stack(Axis(0), &out2).unwrap()
 }
 
 fn main() {
     // Our random number generator, seeded for reproducibility
     let mut rng = Isaac64Rng::seed_from_u64(42);
 
     // For each of our expected centroids, generate `n` data points around it (a "blob")
-    let n_classes: u64 = 4;
-    let expected_centroids = array![[0., 0.], [1., 4.], [-5., 0.], [4., 4.]];
-    let n = 100;
+    let n_classes: usize = 4;
+    let n = 300;
 
     println!("Generating training data");
 
-    let train_x = generate_blobs(n, &expected_centroids, &mut rng);
-    let train_y = Array::from_iter(
-        (0..n_classes)
-            .map(|x| std::iter::repeat(x).take(n).collect::<Vec<u64>>())
-            .flatten(),
-    );
+    let train_x = generate_blobs(&[(0., 0.), (1., 4.), (-5., 0.), (4., 4.)], n, &mut rng);
+    let train_y = (0..n_classes)
+        .map(|x| std::iter::repeat(x).take(n).collect::<Vec<_>>())
+        .flatten()
+        .collect::<Array1<_>>();
 
-    let test_x = generate_blobs(n, &expected_centroids, &mut rng);
-    let test_y = Array::from_iter(
-        (0..n_classes)
-            .map(|x| std::iter::repeat(x).take(n).collect::<Vec<u64>>())
-            .flatten(),
-    );
-
-    println!("Generated training data");
+    let dataset = Dataset::new(train_x, train_y).shuffle(&mut rng);
+    let (train, test) = dataset.split_with_ratio(0.9);
 
     println!("Training model with Gini criterion ...");
-    let gini_hyperparams = DecisionTreeParams::new(n_classes)
+    let gini_model = DecisionTree::params()
         .split_quality(SplitQuality::Gini)
         .max_depth(Some(100))
-        .min_samples_split(10)
-        .min_samples_leaf(10)
-        .build();
-
-    let gini_model = DecisionTree::fit(gini_hyperparams, &train_x, &train_y);
-
-    let gini_pred_y = gini_model.predict(&test_x);
+        .min_weight_split(10.0)
+        .min_weight_leaf(10.0)
+        .fit(&train);
+
+    let gini_pred_y = gini_model.predict(test.records().view());
+    let cm = gini_pred_y.confusion_matrix(&test);
+
+    println!("{:?}", cm);
 
     println!(
         "Test accuracy with Gini criterion: {:.2}%",
-        100.0 * accuracy(&test_y, &gini_pred_y)
+        100.0 * cm.accuracy()
     );
 
     println!("Training model with entropy criterion ...");
-    let entropy_hyperparams = DecisionTreeParams::new(n_classes)
+    let entropy_model = DecisionTree::params()
         .split_quality(SplitQuality::Entropy)
         .max_depth(Some(100))
-        .min_samples_split(10)
-        .min_samples_leaf(10)
-        .build();
-
-    let entropy_model = DecisionTree::fit(entropy_hyperparams, &train_x, &train_y);
-
-    let entropy_pred_y = entropy_model.predict(&test_x);
+        .min_weight_split(10.0)
+        .min_weight_leaf(10.0)
+        .fit(&train);
+
+    let entropy_pred_y = entropy_model.predict(test.records().view());
+    let cm = entropy_pred_y.confusion_matrix(&test);
+
+    println!("{:?}", cm);
 
     println!(
         "Test accuracy with Entropy criterion: {:.2}%",
-        100.0 * accuracy(&test_y, &entropy_pred_y)
+        100.0 * cm.accuracy()
     );
+
+    let feats = entropy_model.features();
+    println!("Features trained in this tree {:?}", feats);
+
+    let mut tikz = File::create("decision_tree_example.tex").unwrap();
+    tikz.write(gini_model.export_to_tikz().to_string().as_bytes())
+        .unwrap();
+    println!(" => generate tree description with `latex decision_tree_example.tex`!");
 }
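Two change-list items above, "use midpoint" and "skip equal values until new value is encountered", describe how candidate split thresholds are chosen along a sorted feature column. The standalone snippet below is a plain-Rust illustration of that rule, not linfa's internal code: runs of equal values produce no candidate, and each threshold sits at the midpoint between consecutive distinct values.

/// Standalone illustration of midpoint split candidates (not linfa internals).
/// Returns one candidate threshold per pair of consecutive distinct values.
fn candidate_thresholds(mut feature: Vec<f64>) -> Vec<f64> {
    feature.sort_by(|a, b| a.partial_cmp(b).unwrap());
    feature
        .windows(2)
        .filter(|w| w[0] < w[1]) // skip equal values until a new value appears
        .map(|w| (w[0] + w[1]) / 2.0) // place the threshold at the midpoint
        .collect()
}

fn main() {
    // [1.0, 1.0, 2.0, 4.0, 4.0, 7.0] yields thresholds [1.5, 3.0, 5.5]
    println!("{:?}", candidate_thresholds(vec![1.0, 1.0, 2.0, 4.0, 4.0, 7.0]));
}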