Skip to content

Commit

Permalink
Dataset improvements (#74)
Browse files Browse the repository at this point in the history
* Renamed Dataset to DatasetBase in all subcrates. Added new types Dataset and DatasetView

* Implemented all methods except k-fold; updated the dataset crate

* Dropped the Vec implementation. Adjusted SVM accordingly

* k-folding with axis_chunks_iter

* renamed Dataset to DatasetBase in naive bayes

* Implemented k-fold via swapping; dropped the Vec implementation
  • Loading branch information
Sauro98 authored Jan 5, 2021
1 parent ab55fb9 commit 21dd579
Show file tree
Hide file tree
Showing 28 changed files with 690 additions and 316 deletions.
6 changes: 3 additions & 3 deletions datasets/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn array_from_buf(buf: &[u8]) -> Array2<f64> {
#[cfg(feature = "iris")]
/// Read in the iris-flower dataset from dataset path
/// The `.csv` data is two dimensional: Axis(0) denotes y-axis (rows), Axis(1) denotes x-axis (columns)
pub fn iris() -> Dataset<Array2<f64>, Vec<usize>> {
pub fn iris() -> Dataset<f64, usize> {
let data = include_bytes!("../data/iris.csv.gz");
let array = array_from_buf(&data[..]);

Expand All @@ -34,7 +34,7 @@ pub fn iris() -> Dataset<Array2<f64>, Vec<usize>> {
}

#[cfg(feature = "diabetes")]
pub fn diabetes() -> Dataset<Array2<f64>, Array1<f64>> {
pub fn diabetes() -> Dataset<f64, f64> {
let data = include_bytes!("../data/diabetes_data.csv.gz");
let data = array_from_buf(&data[..]);

Expand All @@ -45,7 +45,7 @@ pub fn diabetes() -> Dataset<Array2<f64>, Array1<f64>> {
}

#[cfg(feature = "winequality")]
pub fn winequality() -> Dataset<Array2<f64>, Vec<usize>> {
pub fn winequality() -> Dataset<f64, usize> {
let data = include_bytes!("../data/winequality-red.csv.gz");
let array = array_from_buf(&data[..]);

Expand Down
20 changes: 10 additions & 10 deletions linfa-bayes/src/gaussian_nb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ndarray_stats::QuantileExt;
use std::collections::HashMap;

use crate::error::Result;
use linfa::dataset::{Dataset, Labels};
use linfa::dataset::{DatasetBase, Labels};
use linfa::traits::{Fit, IncrementalFit, Predict};
use linfa::Float;

Expand Down Expand Up @@ -55,7 +55,7 @@ where
///
/// ```no_run
/// # use ndarray::array;
/// # use linfa::Dataset;
/// # use linfa::DatasetBase;
/// # use linfa_bayes::GaussianNbParams;
/// # use linfa::traits::{Fit, Predict};
/// # use std::error::Error;
Expand All @@ -70,15 +70,15 @@ where
/// ];
/// let y = vec![1, 1, 1, 2, 2, 2];
///
/// let data = Dataset::new(x.view(), &y);
/// let data = DatasetBase::new(x.view(), &y);
/// let model = GaussianNbParams::params().fit(&data)?;
/// let pred = model.predict(x.view());
///
/// assert_eq!(pred.to_vec(), y);
/// # Ok(())
/// # }
/// ```
fn fit(&self, dataset: &'a Dataset<ArrayView2<A>, L>) -> Self::Object {
fn fit(&self, dataset: &'a DatasetBase<ArrayView2<A>, L>) -> Self::Object {
// We extract the unique classes in sorted order
let mut unique_classes = dataset.targets.labels();
unique_classes.sort_unstable();
Expand Down Expand Up @@ -106,7 +106,7 @@ where
///
/// ```no_run
/// # use ndarray::{array, Axis};
/// # use linfa::Dataset;
/// # use linfa::DatasetBase;
/// # use linfa_bayes::GaussianNbParams;
/// # use linfa::traits::{Predict, IncrementalFit};
/// # use std::error::Error;
Expand All @@ -128,7 +128,7 @@ where
/// .axis_chunks_iter(Axis(0), 2)
/// .zip(y.axis_chunks_iter(Axis(0), 2))
/// {
/// model = clf.fit_with(model, &Dataset::new(x, y))?;
/// model = clf.fit_with(model, &DatasetBase::new(x, y))?;
/// }
///
/// let pred = model.as_ref().unwrap().predict(x.view());
Expand All @@ -140,7 +140,7 @@ where
fn fit_with(
&self,
model_in: Self::ObjectIn,
dataset: &Dataset<ArrayView2<A>, L>,
dataset: &DatasetBase<ArrayView2<A>, L>,
) -> Self::ObjectOut {
let x = dataset.records();
let y = dataset.targets();
Expand Down Expand Up @@ -358,7 +358,7 @@ impl<A: Float> GaussianNb<A> {
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use linfa::Dataset;
use linfa::DatasetBase;
use ndarray::array;

#[test]
Expand All @@ -374,7 +374,7 @@ mod tests {
let y = array![1, 1, 1, 2, 2, 2];

let clf = GaussianNbParams::params();
let data = Dataset::new(x.view(), y.view());
let data = DatasetBase::new(x.view(), y.view());
let fitted_clf = clf.fit(&data).unwrap();
let pred = fitted_clf.predict(x.view());
assert_eq!(pred, y);
Expand Down Expand Up @@ -424,7 +424,7 @@ mod tests {
let model = x
.axis_chunks_iter(Axis(0), 2)
.zip(y.axis_chunks_iter(Axis(0), 2))
.map(|(a, b)| Dataset::new(a, b))
.map(|(a, b)| DatasetBase::new(a, b))
.fold(None, |current, d| clf.fit_with(current, &d).unwrap())
.unwrap();

Expand Down
5 changes: 3 additions & 2 deletions linfa-clustering/benches/gaussian_mixture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use criterion::{
PlotConfiguration,
};
use linfa::traits::Fit;
use linfa::Dataset;
use linfa::DatasetBase;
use linfa_clustering::{generate_blobs, GaussianMixtureModel};
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
Expand All @@ -22,7 +22,8 @@ fn gaussian_mixture_bench(c: &mut Criterion) {
let n_features = 3;
let centroids =
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset = Dataset::from(generate_blobs(cluster_size, &centroids, &mut rng));
let dataset: DatasetBase<_, _> =
(generate_blobs(cluster_size, &centroids, &mut rng), ()).into();
bencher.iter(|| {
black_box(
GaussianMixtureModel::params(n_clusters)
Expand Down
4 changes: 2 additions & 2 deletions linfa-clustering/benches/k_means.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use criterion::{
PlotConfiguration,
};
use linfa::traits::Fit;
use linfa::Dataset;
use linfa::DatasetBase;
use linfa_clustering::{generate_blobs, KMeans};
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
Expand All @@ -22,7 +22,7 @@ fn k_means_bench(c: &mut Criterion) {
let n_features = 3;
let centroids =
Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), &mut rng);
let dataset = Dataset::from(generate_blobs(cluster_size, &centroids, &mut rng));
let dataset = DatasetBase::from(generate_blobs(cluster_size, &centroids, &mut rng));
bencher.iter(|| {
black_box(
KMeans::params_with_rng(n_clusters, rng.clone())
Expand Down
6 changes: 3 additions & 3 deletions linfa-clustering/examples/kmeans.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use linfa::traits::{Fit, Predict};
use linfa::Dataset;
use linfa::DatasetBase;
use linfa_clustering::{generate_blobs, KMeans};
use ndarray::{array, Axis};
use ndarray_npy::write_npy;
Expand All @@ -15,7 +15,7 @@ fn main() {
// For each our expected centroids, generate `n` data points around it (a "blob")
let expected_centroids = array![[10., 10.], [1., 12.], [20., 30.], [-20., 30.],];
let n = 10000;
let dataset = Dataset::from(generate_blobs(n, &expected_centroids, &mut rng));
let dataset = DatasetBase::from(generate_blobs(n, &expected_centroids, &mut rng));

// Configure our training algorithm
let n_clusters = expected_centroids.len_of(Axis(0));
Expand All @@ -27,7 +27,7 @@ fn main() {

// Assign each point to a cluster using the set of centroids found using `fit`
let dataset = model.predict(dataset);
let Dataset {
let DatasetBase {
records, targets, ..
} = dataset;

Expand Down
22 changes: 13 additions & 9 deletions linfa-clustering/src/appx_dbscan/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::appx_dbscan::clustering::AppxDbscanLabeler;
use crate::appx_dbscan::hyperparameters::{AppxDbscanHyperParams, AppxDbscanHyperParamsBuilder};
use linfa::dataset::Targets;
use linfa::traits::Transformer;
use linfa::{Dataset, Float};
use linfa::{DatasetBase, Float};
use ndarray::{Array1, ArrayBase, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
Expand Down Expand Up @@ -107,13 +107,15 @@ impl<F: Float, D: Data<Elem = F>> Transformer<&ArrayBase<D, Ix2>, Array1<Option<
}

impl<F: Float, D: Data<Elem = F>, T: Targets>
Transformer<Dataset<ArrayBase<D, Ix2>, T>, Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>>>
for AppxDbscanHyperParams<F>
Transformer<
DatasetBase<ArrayBase<D, Ix2>, T>,
DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>>,
> for AppxDbscanHyperParams<F>
{
fn transform(
&self,
dataset: Dataset<ArrayBase<D, Ix2>, T>,
) -> Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
dataset: DatasetBase<ArrayBase<D, Ix2>, T>,
) -> DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
let predicted = self.transform(dataset.records());
dataset.with_targets(predicted)
}
Expand All @@ -128,13 +130,15 @@ impl<F: Float, D: Data<Elem = F>> Transformer<&ArrayBase<D, Ix2>, Array1<Option<
}

impl<F: Float, D: Data<Elem = F>, T: Targets>
Transformer<Dataset<ArrayBase<D, Ix2>, T>, Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>>>
for AppxDbscanHyperParamsBuilder<F>
Transformer<
DatasetBase<ArrayBase<D, Ix2>, T>,
DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>>,
> for AppxDbscanHyperParamsBuilder<F>
{
fn transform(
&self,
dataset: Dataset<ArrayBase<D, Ix2>, T>,
) -> Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
dataset: DatasetBase<ArrayBase<D, Ix2>, T>,
) -> DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
self.build().transform(dataset)
}
}
22 changes: 13 additions & 9 deletions linfa-clustering/src/dbscan/algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ndarray_stats::DeviationExt;

use linfa::dataset::Targets;
use linfa::traits::Transformer;
use linfa::{Dataset, Float};
use linfa::{DatasetBase, Float};

#[derive(Clone, Debug, PartialEq)]
/// DBSCAN (Density-based Spatial Clustering of Applications with Noise)
Expand Down Expand Up @@ -115,13 +115,15 @@ impl<F: Float, D: Data<Elem = F>> Transformer<&ArrayBase<D, Ix2>, Array1<Option<
}

impl<F: Float, D: Data<Elem = F>, T: Targets>
Transformer<Dataset<ArrayBase<D, Ix2>, T>, Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>>>
for DbscanHyperParams<F>
Transformer<
DatasetBase<ArrayBase<D, Ix2>, T>,
DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>>,
> for DbscanHyperParams<F>
{
fn transform(
&self,
dataset: Dataset<ArrayBase<D, Ix2>, T>,
) -> Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
dataset: DatasetBase<ArrayBase<D, Ix2>, T>,
) -> DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
let predicted = self.transform(dataset.records());
dataset.with_targets(predicted)
}
Expand All @@ -136,13 +138,15 @@ impl<F: Float, D: Data<Elem = F>> Transformer<&ArrayBase<D, Ix2>, Array1<Option<
}

impl<F: Float, D: Data<Elem = F>, T: Targets>
Transformer<Dataset<ArrayBase<D, Ix2>, T>, Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>>>
for DbscanHyperParamsBuilder<F>
Transformer<
DatasetBase<ArrayBase<D, Ix2>, T>,
DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>>,
> for DbscanHyperParamsBuilder<F>
{
fn transform(
&self,
dataset: Dataset<ArrayBase<D, Ix2>, T>,
) -> Dataset<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
dataset: DatasetBase<ArrayBase<D, Ix2>, T>,
) -> DatasetBase<ArrayBase<D, Ix2>, Array1<Option<usize>>> {
self.build().transform(dataset)
}
}
Expand Down
Loading

0 comments on commit 21dd579

Please sign in to comment.