Skip to content

Commit

Permalink
renamed Dataset to DatasetBase in naive bayes
Browse files Browse the repository at this point in the history
  • Loading branch information
Sauro98 committed Jan 2, 2021
1 parent 6e91574 commit 9f0839f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
20 changes: 10 additions & 10 deletions linfa-bayes/src/gaussian_nb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ndarray_stats::QuantileExt;
use std::collections::HashMap;

use crate::error::Result;
use linfa::dataset::{Dataset, Labels};
use linfa::dataset::{DatasetBase, Labels};
use linfa::traits::{Fit, IncrementalFit, Predict};
use linfa::Float;

Expand Down Expand Up @@ -55,7 +55,7 @@ where
///
/// ```no_run
/// # use ndarray::array;
/// # use linfa::Dataset;
/// # use linfa::DatasetBase;
/// # use linfa_bayes::GaussianNbParams;
/// # use linfa::traits::{Fit, Predict};
/// # use std::error::Error;
Expand All @@ -70,15 +70,15 @@ where
/// ];
/// let y = vec![1, 1, 1, 2, 2, 2];
///
/// let data = Dataset::new(x.view(), &y);
/// let data = DatasetBase::new(x.view(), &y);
/// let model = GaussianNbParams::params().fit(&data)?;
/// let pred = model.predict(x.view());
///
/// assert_eq!(pred.to_vec(), y);
/// # Ok(())
/// # }
/// ```
fn fit(&self, dataset: &'a Dataset<ArrayView2<A>, L>) -> Self::Object {
fn fit(&self, dataset: &'a DatasetBase<ArrayView2<A>, L>) -> Self::Object {
// We extract the unique classes in sorted order
let mut unique_classes = dataset.targets.labels();
unique_classes.sort_unstable();
Expand Down Expand Up @@ -106,7 +106,7 @@ where
///
/// ```no_run
/// # use ndarray::{array, Axis};
/// # use linfa::Dataset;
/// # use linfa::DatasetBase;
/// # use linfa_bayes::GaussianNbParams;
/// # use linfa::traits::{Predict, IncrementalFit};
/// # use std::error::Error;
Expand All @@ -128,7 +128,7 @@ where
/// .axis_chunks_iter(Axis(0), 2)
/// .zip(y.axis_chunks_iter(Axis(0), 2))
/// {
/// model = clf.fit_with(model, &Dataset::new(x, y))?;
/// model = clf.fit_with(model, &DatasetBase::new(x, y))?;
/// }
///
/// let pred = model.as_ref().unwrap().predict(x.view());
Expand All @@ -140,7 +140,7 @@ where
fn fit_with(
&self,
model_in: Self::ObjectIn,
dataset: &Dataset<ArrayView2<A>, L>,
dataset: &DatasetBase<ArrayView2<A>, L>,
) -> Self::ObjectOut {
let x = dataset.records();
let y = dataset.targets();
Expand Down Expand Up @@ -358,7 +358,7 @@ impl<A: Float> GaussianNb<A> {
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use linfa::Dataset;
use linfa::DatasetBase;
use ndarray::array;

#[test]
Expand All @@ -374,7 +374,7 @@ mod tests {
let y = array![1, 1, 1, 2, 2, 2];

let clf = GaussianNbParams::params();
let data = Dataset::new(x.view(), y.view());
let data = DatasetBase::new(x.view(), y.view());
let fitted_clf = clf.fit(&data).unwrap();
let pred = fitted_clf.predict(x.view());
assert_eq!(pred, y);
Expand Down Expand Up @@ -424,7 +424,7 @@ mod tests {
let model = x
.axis_chunks_iter(Axis(0), 2)
.zip(y.axis_chunks_iter(Axis(0), 2))
.map(|(a, b)| Dataset::new(a, b))
.map(|(a, b)| DatasetBase::new(a, b))
.fold(None, |current, d| clf.fit_with(current, &d).unwrap())
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion src/dataset/impl_dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ impl<'a, F: Float, E: Copy> DatasetView<'a, F, E> {
/// ```
///
pub fn fold(&self, k: usize) -> Vec<(Dataset<F, E>, Dataset<F, E>)> {
let fold_size = self.targets().dim() / k;
let fold_size = self.targets().len() / k;
let mut res = Vec::new();

// Generates all k folds of records and targets
Expand Down

0 comments on commit 9f0839f

Please sign in to comment.