Skip to content

Commit c297916

Browse files
committed
Fix axes computation in reorder_v2 wrapper code
1 parent b8b15e7 commit c297916

File tree

2 files changed

+38
-6
lines changed

2 files changed

+38
-6
lines changed

src/data/mod.rs

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,16 @@ where
529529
///# Return Values
530530
///
531531
/// Array with data reordered as per the new axes order
532+
///
533+
///# Examples
534+
///
535+
/// ```rust
536+
/// use arrayfire::{Array, Dim4, print, randu, reorder_v2};
537+
/// let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
538+
/// let b = reorder_v2(&a, 1, 0, None);
539+
/// print(&a);
540+
/// print(&b);
541+
/// ```
532542
pub fn reorder_v2<T>(
533543
input: &Array<T>,
534544
new_axis0: u64,
@@ -538,16 +548,25 @@ pub fn reorder_v2<T>(
538548
where
539549
T: HasAfEnum,
540550
{
541-
let mut new_axes = vec![new_axis0, new_axis1];
551+
let mut new_axes = vec![0, 1, 2, 3];
552+
new_axes[0] = new_axis0;
553+
new_axes[1] = new_axis1;
542554
match next_axes {
543-
Some(v) => {
544-
for axis in v {
545-
new_axes.push(axis);
555+
Some(left_over_new_axes) => {
556+
// At the moment of writing this comment, ArrayFire could
557+
// handle only a maximum of 4 dimensions. Hence, excluding
558+
// the two explicit axes arguments to this function, a maximum
559+
// of only two more axes can be provided. Hence the below condition.
560+
assert!(left_over_new_axes.len() <= 2);
561+
562+
for a_idx in 0..left_over_new_axes.len() {
563+
new_axes[2 + a_idx] = left_over_new_axes[a_idx];
546564
}
547565
}
548566
None => {
549-
new_axes.push(2);
550-
new_axes.push(3);
567+
for a_idx in 2..4 {
568+
new_axes[a_idx] = a_idx as u64;
569+
}
551570
}
552571
};
553572

tests/data.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use ::arrayfire::*;
2+
use float_cmp::approx_eq;
3+
4+
#[test]
5+
fn check_reorder_api() {
6+
let dims = Dim4::new(&[4, 5, 2, 3]);
7+
let A = randu::<f32>(dims);
8+
9+
let transposedA = reorder_v2(&A, 1, 0, None);
10+
let swap_0_2 = reorder_v2(&A, 2, 1, Some(vec![0]));
11+
let swap_1_2 = reorder_v2(&A, 0, 2, Some(vec![1]));
12+
let swap_0_3 = reorder_v2(&A, 3, 1, Some(vec![2, 0]));
13+
}

0 commit comments

Comments
 (0)