Skip to content

Lmm/add seeds in more algorithms #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@ name: CI

on:
push:
branches: [ main, development ]
branches: [main, development]
pull_request:
branches: [ development ]
branches: [development]

jobs:
tests:
runs-on: "${{ matrix.platform.os }}-latest"
strategy:
matrix:
platform: [
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
platform:
[
{ os: "windows", target: "x86_64-pc-windows-msvc" },
{ os: "windows", target: "i686-pc-windows-msvc" },
{ os: "ubuntu", target: "x86_64-unknown-linux-gnu" },
{ os: "ubuntu", target: "i686-unknown-linux-gnu" },
{ os: "ubuntu", target: "wasm32-unknown-unknown" },
{ os: "macos", target: "aarch64-apple-darwin" },
]
env:
TZ: "/usr/share/zoneinfo/your/location"
steps:
Expand All @@ -40,7 +41,7 @@ jobs:
default: true
- name: Install test runner for wasm
if: matrix.platform.target == 'wasm32-unknown-unknown'
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh
- name: Stable Build
uses: actions-rs/cargo@v1
with:
Expand Down
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## Added
- Seeds to multiple algorithms that depend on random number generation.
- Added feature `js` to use WASM in the browser

## BREAKING CHANGE
- Added a new parameter to `train_test_split` to define the seed.

## [0.2.1] - 2022-05-10

## Added
- L2 regularization penalty to the Logistic Regression
- Getters for the naive bayes structs
Expand Down
10 changes: 7 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,25 @@ categories = ["science"]
default = ["datasets"]
ndarray-bindings = ["ndarray"]
nalgebra-bindings = ["nalgebra"]
datasets = ["rand_distr"]
datasets = ["rand_distr", "std"]
fp_bench = ["itertools"]
std = ["rand/std", "rand/std_rng"]
# wasm32 only
js = ["getrandom/js"]

[dependencies]
ndarray = { version = "0.15", optional = true }
nalgebra = { version = "0.31", optional = true }
num-traits = "0.2"
num = "0.4"
rand = "0.8"
rand = { version = "0.8", default-features = false, features = ["small_rng"] }
rand_distr = { version = "0.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
itertools = { version = "0.10.3", optional = true }
cfg-if = "1.0.0"

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
getrandom = { version = "0.2", optional = true }

[dev-dependencies]
smartcore = { path = ".", features = ["fp_bench"] }
Expand Down
13 changes: 9 additions & 4 deletions src/cluster/kmeans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
//! * ["An Introduction to Statistical Learning", James G., Witten D., Hastie T., Tibshirani R., 10.3.1 K-Means Clustering](http://faculty.marshall.usc.edu/gareth-james/ISL/)
//! * ["k-means++: The Advantages of Careful Seeding", Arthur D., Vassilvitskii S.](http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf)

use rand::Rng;
use std::fmt::Debug;
use std::iter::Sum;

use ::rand::Rng;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand All @@ -65,6 +65,7 @@ use crate::error::Failed;
use crate::linalg::Matrix;
use crate::math::distance::euclidian::*;
use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;

/// K-Means clustering algorithm
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -108,6 +109,9 @@ pub struct KMeansParameters {
pub k: usize,
/// Maximum number of iterations of the k-means algorithm for a single run.
pub max_iter: usize,
/// Determines random number generation for centroid initialization.
/// Use an int to make the randomness deterministic
pub seed: Option<u64>,
}

impl KMeansParameters {
Expand All @@ -128,6 +132,7 @@ impl Default for KMeansParameters {
KMeansParameters {
k: 2,
max_iter: 100,
seed: None,
}
}
}
Expand Down Expand Up @@ -168,7 +173,7 @@ impl<T: RealNumber + Sum> KMeans<T> {
let (n, d) = data.shape();

let mut distortion = T::max_value();
let mut y = KMeans::kmeans_plus_plus(data, parameters.k);
let mut y = KMeans::kmeans_plus_plus(data, parameters.k, parameters.seed);
let mut size = vec![0; parameters.k];
let mut centroids = vec![vec![T::zero(); d]; parameters.k];

Expand Down Expand Up @@ -241,8 +246,8 @@ impl<T: RealNumber + Sum> KMeans<T> {
Ok(result.to_row_vector())
}

fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize) -> Vec<usize> {
let mut rng = rand::thread_rng();
fn kmeans_plus_plus<M: Matrix<T>>(data: &M, k: usize, seed: Option<u64>) -> Vec<usize> {
let mut rng = get_rng_impl(seed);
let (n, m) = data.shape();
let mut y = vec![0; n];
let mut centroid = data.get_row_as_vec(rng.gen_range(0..n));
Expand Down
11 changes: 6 additions & 5 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
//!
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::Rng;

use std::default::Default;
use std::fmt::Debug;

Expand All @@ -57,6 +57,7 @@ use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;
use crate::tree::decision_tree_classifier::{
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
};
Expand Down Expand Up @@ -221,7 +222,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
.unwrap()
});

let mut rng = StdRng::seed_from_u64(parameters.seed);
let mut rng = get_rng_impl(Some(parameters.seed));
let classes = y_m.unique();
let k = classes.len();
let mut trees: Vec<DecisionTreeClassifier<T>> = Vec::new();
Expand All @@ -242,9 +243,9 @@ impl<T: RealNumber> RandomForestClassifier<T> {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
};
let tree =
DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
let tree = DecisionTreeClassifier::fit_weak_learner(x, y, samples, mtry, params)?;
trees.push(tree);
}

Expand Down
11 changes: 6 additions & 5 deletions src/ensemble/random_forest_regressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use rand::Rng;

use std::default::Default;
use std::fmt::Debug;

Expand All @@ -55,6 +55,7 @@ use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::rand::get_rng_impl;
use crate::tree::decision_tree_regressor::{
DecisionTreeRegressor, DecisionTreeRegressorParameters,
};
Expand Down Expand Up @@ -191,7 +192,7 @@ impl<T: RealNumber> RandomForestRegressor<T> {
.m
.unwrap_or((num_attributes as f64).sqrt().floor() as usize);

let mut rng = StdRng::seed_from_u64(parameters.seed);
let mut rng = get_rng_impl(Some(parameters.seed));
let mut trees: Vec<DecisionTreeRegressor<T>> = Vec::new();

let mut maybe_all_samples: Option<Vec<Vec<bool>>> = Option::None;
Expand All @@ -208,9 +209,9 @@ impl<T: RealNumber> RandomForestRegressor<T> {
max_depth: parameters.max_depth,
min_samples_leaf: parameters.min_samples_leaf,
min_samples_split: parameters.min_samples_split,
seed: Some(parameters.seed),
};
let tree =
DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params, &mut rng)?;
let tree = DecisionTreeRegressor::fit_weak_learner(x, y, samples, mtry, params)?;
trees.push(tree);
}

Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,5 @@ pub mod readers;
pub mod svm;
/// Supervised tree-based learning methods
pub mod tree;

pub(crate) mod rand;
6 changes: 4 additions & 2 deletions src/math/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use std::iter::{Product, Sum};
use std::ops::{AddAssign, DivAssign, MulAssign, SubAssign};
use std::str::FromStr;

use crate::rand::get_rng_impl;

/// Defines real number
/// <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS_CHTML"></script>
pub trait RealNumber:
Expand Down Expand Up @@ -79,7 +81,7 @@ impl RealNumber for f64 {
}

fn rand() -> f64 {
let mut rng = rand::thread_rng();
let mut rng = get_rng_impl(None);
rng.gen()
}

Expand Down Expand Up @@ -124,7 +126,7 @@ impl RealNumber for f32 {
}

fn rand() -> f32 {
let mut rng = rand::thread_rng();
let mut rng = get_rng_impl(None);
rng.gen()
}

Expand Down
21 changes: 19 additions & 2 deletions src/model_selection/kfold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
use crate::linalg::Matrix;
use crate::math::num::RealNumber;
use crate::model_selection::BaseKFold;
use crate::rand::get_rng_impl;
use rand::seq::SliceRandom;
use rand::thread_rng;

/// K-Folds cross-validator
pub struct KFold {
/// Number of folds. Must be at least 2.
pub n_splits: usize, // cannot exceed std::usize::MAX
/// Whether to shuffle the data before splitting into batches
pub shuffle: bool,
/// When shuffle is True, seed affects the ordering of the indices.
/// Which controls the randomness of each fold
pub seed: Option<u64>,
}

impl KFold {
Expand All @@ -23,8 +26,10 @@ impl KFold {

// initialise indices
let mut indices: Vec<usize> = (0..n_samples).collect();
let mut rng = get_rng_impl(self.seed);

if self.shuffle {
indices.shuffle(&mut thread_rng());
indices.shuffle(&mut rng);
}
// return a new array of given shape n_split, filled with each element of n_samples divided by n_splits.
let mut fold_sizes = vec![n_samples / self.n_splits; self.n_splits];
Expand Down Expand Up @@ -66,6 +71,7 @@ impl Default for KFold {
KFold {
n_splits: 3,
shuffle: true,
seed: None,
}
}
}
Expand All @@ -81,6 +87,12 @@ impl KFold {
self.shuffle = shuffle;
self
}

/// When shuffle is True, random_state affects the ordering of the indices.
pub fn with_seed(mut self, seed: Option<u64>) -> Self {
self.seed = seed;
self
}
}

/// An iterator over indices that split data into training and test set.
Expand Down Expand Up @@ -150,6 +162,7 @@ mod tests {
let k = KFold {
n_splits: 3,
shuffle: false,
seed: None,
};
let x: DenseMatrix<f64> = DenseMatrix::rand(33, 100);
let test_indices = k.test_indices(&x);
Expand All @@ -165,6 +178,7 @@ mod tests {
let k = KFold {
n_splits: 3,
shuffle: false,
seed: None,
};
let x: DenseMatrix<f64> = DenseMatrix::rand(34, 100);
let test_indices = k.test_indices(&x);
Expand All @@ -180,6 +194,7 @@ mod tests {
let k = KFold {
n_splits: 2,
shuffle: false,
seed: None,
};
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
let test_masks = k.test_masks(&x);
Expand All @@ -206,6 +221,7 @@ mod tests {
let k = KFold {
n_splits: 2,
shuffle: false,
seed: None,
};
let x: DenseMatrix<f64> = DenseMatrix::rand(22, 100);
let train_test_splits: Vec<(Vec<usize>, Vec<usize>)> = k.split(&x).collect();
Expand Down Expand Up @@ -238,6 +254,7 @@ mod tests {
let k = KFold {
n_splits: 3,
shuffle: false,
seed: None,
};
let x: DenseMatrix<f64> = DenseMatrix::rand(10, 4);
let expected: Vec<(Vec<usize>, Vec<usize>)> = vec![
Expand Down
Loading