Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add features required by rusty-machine. #499

Merged
merged 14 commits into from
Feb 3, 2019
Merged
Prev Previous commit
Next Next commit
Add select_rows and select_columns.
  • Loading branch information
sebcrozet committed Feb 3, 2019
commit bba1f48e810c4e2391ec82b88d7f948961ad5463
6 changes: 3 additions & 3 deletions src/base/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ where
///
/// let m = Matrix3::from_diagonal(&Vector3::new(1.0, 2.0, 3.0));
/// // The two additional arguments represent the matrix dimensions.
/// let dm = DMatrix::from_diagonal(&DVector::from_row_slice(3, &[1.0, 2.0, 3.0]));
/// let dm = DMatrix::from_diagonal(&DVector::from_row_slice(&[1.0, 2.0, 3.0]));
///
/// assert!(m.m11 == 1.0 && m.m12 == 0.0 && m.m13 == 0.0 &&
/// m.m21 == 0.0 && m.m22 == 2.0 && m.m23 == 0.0 &&
Expand Down Expand Up @@ -616,7 +616,7 @@ macro_rules! impl_constructors_from_data(
///
/// let v = Vector3::from_row_slice(&[0, 1, 2]);
/// // The additional argument represents the vector dimension.
/// let dv = DVector::from_row_slice(3, &[0, 1, 2]);
/// let dv = DVector::from_row_slice(&[0, 1, 2]);
/// let m = Matrix2x3::from_row_slice(&[0, 1, 2, 3, 4, 5]);
/// // The two additional arguments represent the matrix dimensions.
/// let dm = DMatrix::from_row_slice(2, 3, &[0, 1, 2, 3, 4, 5]);
Expand All @@ -643,7 +643,7 @@ macro_rules! impl_constructors_from_data(
///
/// let v = Vector3::from_column_slice(&[0, 1, 2]);
/// // The additional argument represents the vector dimension.
/// let dv = DVector::from_column_slice(3, &[0, 1, 2]);
/// let dv = DVector::from_column_slice(&[0, 1, 2]);
/// let m = Matrix2x3::from_column_slice(&[0, 1, 2, 3, 4, 5]);
/// // The two additional arguments represent the matrix dimensions.
/// let dm = DMatrix::from_column_slice(2, 3, &[0, 1, 2, 3, 4, 5]);
Expand Down
45 changes: 43 additions & 2 deletions src/base/edition.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use num::{One, Zero};
use std::cmp;
use std::ptr;
use std::iter::ExactSizeIterator;

use base::allocator::{Allocator, Reallocator};
use base::constraint::{DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint};
Expand Down Expand Up @@ -32,6 +33,46 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {

res
}

/// Creates a new matrix by extracting the given set of rows from `self`.
pub fn select_rows(&self, irows: impl ExactSizeIterator<Item = usize> + Clone) -> MatrixMN<N, Dynamic, C>
where DefaultAllocator: Allocator<N, Dynamic, C> {
let ncols = self.data.shape().1;
let mut res = unsafe { MatrixMN::new_uninitialized_generic(Dynamic::new(irows.len()), ncols) };

// First, check that all the indices from irows are valid.
// This will allow us to use unchecked access in the inner loop.
for i in irows.clone() {
assert!(i < self.nrows(), "Row index out of bounds.")
}

for j in 0..ncols.value() {
// FIXME: use unchecked column indexing
let mut res = res.column_mut(j);
let mut src = self.column(j);

for (destination, source) in irows.clone().enumerate() {
unsafe {
*res.vget_unchecked_mut(destination) = *src.vget_unchecked(source)
}
}
}

res
}

/// Creates a new matrix by extracting the given set of columns from `self`.
pub fn select_columns(&self, icols: impl ExactSizeIterator<Item = usize>) -> MatrixMN<N, R, Dynamic>
where DefaultAllocator: Allocator<N, R, Dynamic> {
let nrows = self.data.shape().0;
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, Dynamic::new(icols.len())) };

for (destination, source) in icols.enumerate() {
res.column_mut(destination).copy_from(&self.column(source))
}

res
}
}

impl<N: Scalar, R: Dim, C: Dim, S: StorageMut<N, R, C>> Matrix<N, R, C, S> {
Expand Down Expand Up @@ -764,9 +805,9 @@ where
/// # Example
/// ```
/// # use nalgebra::DVector;
/// let mut vector = DVector::from_vec(3, vec![0, 1, 2]);
/// let mut vector = DVector::from_vec(vec![0, 1, 2]);
/// vector.extend(vec![3, 4, 5]);
/// assert!(vector.eq(&DVector::from_vec(6, vec![0, 1, 2, 3, 4, 5])));
/// assert!(vector.eq(&DVector::from_vec(vec![0, 1, 2, 3, 4, 5])));
/// ```
fn extend<I: IntoIterator<Item=N>>(&mut self, iter: I) {
self.data.extend(iter);
Expand Down