Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 70 additions & 31 deletions examples/sort-axis.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use ndarray::prelude::*;
use ndarray::{Data, RemoveAxis, Zip};

use rawpointer::PointerExt;

use std::cmp::Ordering;
use std::ptr::copy_nonoverlapping;

Expand Down Expand Up @@ -97,8 +99,8 @@ where
where
D: RemoveAxis,
{
let axis = axis;
let axis_len = self.len_of(axis);
let axis_stride = self.stride_of(axis);
assert_eq!(axis_len, perm.indices.len());
debug_assert!(perm.correct());

Expand All @@ -112,26 +114,48 @@ where
// logically move ownership of all elements from self into result
// the result realizes this ownership at .assume_init() further down
let mut moved_elements = 0;

// the permutation vector is used like this:
//
// index: 0 1 2 3 (index in result)
// permut: 2 3 0 1 (index in the source)
//
// move source 2 -> result 0,
// move source 3 -> result 1,
// move source 0 -> result 2,
// move source 1 -> result 3,
// et.c.

let source_0 = self.raw_view().index_axis_move(axis, 0);

Zip::from(&perm.indices)
.and(result.axis_iter_mut(axis))
.for_each(|&perm_i, result_pane| {
// possible improvement: use unchecked indexing for `index_axis`
// Use a shortcut to avoid bounds checking in `index_axis` for the source.
//
// It works because for any given element pointer in the array we have the
// relationship:
//
// .index_axis(axis, 0) + .stride_of(axis) * j == .index_axis(axis, j)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coding style wise this is a bit ugh, since it depends on manual pointer offsetting instead of using an abstraction. Probably this would be ok as an implementation but not great as an example and that's where it is now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I've never really thought this was a great example to showcase the library, although I'm not sure where else to put it. I do think it would be worth considering adding some of this functionality to ndarray itself.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, of course, it's just that it needs API design to be added to ndarray

//
// where + is pointer arithmetic on the element pointers.
//
// Here source_0 and the offset is equivalent to self.index_axis(axis, perm_i)
Zip::from(result_pane)
.and(self.index_axis(axis, perm_i))
.for_each(|to, from| {
.and(source_0.clone())
.for_each(|to, from_0| {
let from = from_0.stride_offset(axis_stride, perm_i);
copy_nonoverlapping(from, to.as_mut_ptr(), 1);
moved_elements += 1;
});
});
debug_assert_eq!(result.len(), moved_elements);
// panic-critical begin: we must not panic
// forget moved array elements but not its vec
// old_storage drops empty
// forget the old elements but not the allocation
let mut old_storage = self.into_raw_vec();
old_storage.set_len(0);

// transfer ownership of the elements into the result
result.assume_init()
// panic-critical end
}
}
}
Expand Down Expand Up @@ -179,31 +203,46 @@ mod tests {
[75600.94, 17.],
[75601.06, 18.],
];
let answer = array![
[75600.09, 10.],
[75600.21, 11.],
[75600.45, 13.],
[75600.58, 14.],
[75600.82, 16.],
[75600.94, 17.],
[75601.06, 18.],
[75601.33, 12.],
[107998.96, 1.],
[107999.08, 2.],
[107999.20, 3.],
[107999.45, 5.],
[107999.57, 6.],
[107999.81, 8.],
[107999.94, 9.],
[108000.33, 4.],
[108010.69, 7.],
[109000.70, 15.],
];

// f layout copy of a
let mut af = Array::zeros(a.dim().f());
af.assign(&a);

// transposed copy of a
let at = a.t().to_owned();

// c layout permute
let perm = a.sort_axis_by(Axis(0), |i, j| a[[i, 0]] < a[[j, 0]]);

let b = a.permute_axis(Axis(0), &perm);
assert_eq!(
b,
array![
[75600.09, 10.],
[75600.21, 11.],
[75600.45, 13.],
[75600.58, 14.],
[75600.82, 16.],
[75600.94, 17.],
[75601.06, 18.],
[75601.33, 12.],
[107998.96, 1.],
[107999.08, 2.],
[107999.20, 3.],
[107999.45, 5.],
[107999.57, 6.],
[107999.81, 8.],
[107999.94, 9.],
[108000.33, 4.],
[108010.69, 7.],
[109000.70, 15.],
]
);
assert_eq!(b, answer);

// f layout permute
let bf = af.permute_axis(Axis(0), &perm);
assert_eq!(bf, answer);

// transposed permute
let bt = at.permute_axis(Axis(1), &perm);
assert_eq!(bt, answer.t());
}
}