Skip to content

Commit 2a2287c

Browse files
In-place axis permutation with cycle detection (#1505)
* draft for inplace reverse, permute * add test cases * formatter * cycle detection logic with bitmask * formatter * satisfying CI(cargo doc, etc.) * add comments from doc, to describe how the logic works
1 parent 0a8498a commit 2a2287c

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

src/impl_methods.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2542,6 +2542,72 @@ where
25422542
unsafe { self.with_strides_dim(new_strides, new_dim) }
25432543
}
25442544

2545+
/// Permute the axes in-place.
2546+
///
2547+
/// This does not move any data, it just adjusts the array's dimensions
2548+
/// and strides.
2549+
///
2550+
/// *i* in the *j*-th place in the axes sequence means `self`'s *i*-th axis
2551+
/// becomes `self`'s *j*-th axis
2552+
///
2553+
/// **Panics** if any of the axes are out of bounds, if an axis is missing,
2554+
/// or if an axis is repeated more than once.
2555+
///
2556+
/// # Example
2557+
/// ```rust
2558+
/// use ndarray::{arr2, Array3};
2559+
///
2560+
/// let mut a = arr2(&[[0, 1], [2, 3]]);
2561+
/// a.permute_axes([1, 0]);
2562+
/// assert_eq!(a, arr2(&[[0, 2], [1, 3]]));
2563+
///
2564+
/// let mut b = Array3::<u8>::zeros((1, 2, 3));
2565+
/// b.permute_axes([1, 0, 2]);
2566+
/// assert_eq!(b.shape(), &[2, 1, 3]);
2567+
/// ```
2568+
#[track_caller]
2569+
pub fn permute_axes<T>(&mut self, axes: T)
2570+
where T: IntoDimension<Dim = D>
2571+
{
2572+
let axes = axes.into_dimension();
2573+
// Ensure that each axis is used exactly once.
2574+
let mut usage_counts = D::zeros(self.ndim());
2575+
for axis in axes.slice() {
2576+
usage_counts[*axis] += 1;
2577+
}
2578+
for count in usage_counts.slice() {
2579+
assert_eq!(*count, 1, "each axis must be listed exactly once");
2580+
}
2581+
2582+
let dim = self.layout.dim.slice_mut();
2583+
let strides = self.layout.strides.slice_mut();
2584+
let axes = axes.slice();
2585+
2586+
// The cycle detection is done using a bitmask to track visited positions.
2587+
// For example, axes from [0,1,2] to [2, 0, 1]
2588+
// For axis values [1, 0, 2]:
2589+
// 1 << 1 // 0b0001 << 1 = 0b0010 (decimal 2)
2590+
// 1 << 0 // 0b0001 << 0 = 0b0001 (decimal 1)
2591+
// 1 << 2 // 0b0001 << 2 = 0b0100 (decimal 4)
2592+
//
2593+
// Each axis gets its own unique bit position in the bitmask:
2594+
// - Axis 0: bit 0 (rightmost)
2595+
// - Axis 1: bit 1
2596+
// - Axis 2: bit 2
2597+
//
2598+
let mut visited = 0usize;
2599+
for (new_axis, &axis) in axes.iter().enumerate() {
2600+
if (visited & (1 << axis)) != 0 {
2601+
continue;
2602+
}
2603+
2604+
dim.swap(axis, new_axis);
2605+
strides.swap(axis, new_axis);
2606+
2607+
visited |= (1 << axis) | (1 << new_axis);
2608+
}
2609+
}
2610+
25452611
/// Transpose the array by reversing axes.
25462612
///
25472613
/// Transposition reverses the order of the axes (dimensions and strides)
@@ -2552,6 +2618,16 @@ where
25522618
self.layout.strides.slice_mut().reverse();
25532619
self
25542620
}
2621+
2622+
/// Reverse the axes of the array in-place.
2623+
///
2624+
/// This does not move any data, it just adjusts the array's dimensions
2625+
/// and strides.
2626+
pub fn reverse_axes(&mut self)
2627+
{
2628+
self.layout.dim.slice_mut().reverse();
2629+
self.layout.strides.slice_mut().reverse();
2630+
}
25552631
}
25562632

25572633
impl<A, D: Dimension> ArrayRef<A, D>

tests/array.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,3 +2828,81 @@ fn test_slice_assign()
28282828
*a.slice_mut(s![1..3]) += 1;
28292829
assert_eq!(a, array![0, 2, 3, 3, 4]);
28302830
}
2831+
2832+
#[test]
2833+
fn reverse_axes()
2834+
{
2835+
let mut a = arr2(&[[1, 2], [3, 4]]);
2836+
a.reverse_axes();
2837+
assert_eq!(a, arr2(&[[1, 3], [2, 4]]));
2838+
2839+
let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]);
2840+
a.reverse_axes();
2841+
assert_eq!(a, arr2(&[[1, 4], [2, 5], [3, 6]]));
2842+
2843+
let mut a = Array::from_iter(0..24)
2844+
.into_shape_with_order((2, 3, 4))
2845+
.unwrap();
2846+
let original = a.clone();
2847+
a.reverse_axes();
2848+
for ((i0, i1, i2), elem) in original.indexed_iter() {
2849+
assert_eq!(*elem, a[(i2, i1, i0)]);
2850+
}
2851+
}
2852+
2853+
#[test]
2854+
fn permute_axes()
2855+
{
2856+
let mut a = arr2(&[[1, 2], [3, 4]]);
2857+
a.permute_axes([1, 0]);
2858+
assert_eq!(a, arr2(&[[1, 3], [2, 4]]));
2859+
2860+
let mut a = Array::from_iter(0..24)
2861+
.into_shape_with_order((2, 3, 4))
2862+
.unwrap();
2863+
let original = a.clone();
2864+
a.permute_axes([2, 1, 0]);
2865+
for ((i0, i1, i2), elem) in original.indexed_iter() {
2866+
assert_eq!(*elem, a[(i2, i1, i0)]);
2867+
}
2868+
2869+
let mut a = Array::from_iter(0..120)
2870+
.into_shape_with_order((2, 3, 4, 5))
2871+
.unwrap();
2872+
let original = a.clone();
2873+
a.permute_axes([1, 0, 3, 2]);
2874+
for ((i0, i1, i2, i3), elem) in original.indexed_iter() {
2875+
assert_eq!(*elem, a[(i1, i0, i3, i2)]);
2876+
}
2877+
}
2878+
2879+
#[should_panic]
2880+
#[test]
2881+
fn permute_axes_repeated_axis()
2882+
{
2883+
let mut a = Array::from_iter(0..24)
2884+
.into_shape_with_order((2, 3, 4))
2885+
.unwrap();
2886+
a.permute_axes([1, 0, 1]);
2887+
}
2888+
2889+
#[should_panic]
2890+
#[test]
2891+
fn permute_axes_missing_axis()
2892+
{
2893+
let mut a = Array::from_iter(0..24)
2894+
.into_shape_with_order((2, 3, 4))
2895+
.unwrap()
2896+
.into_dyn();
2897+
a.permute_axes(&[2, 0][..]);
2898+
}
2899+
2900+
#[should_panic]
2901+
#[test]
2902+
fn permute_axes_oob()
2903+
{
2904+
let mut a = Array::from_iter(0..24)
2905+
.into_shape_with_order((2, 3, 4))
2906+
.unwrap();
2907+
a.permute_axes([1, 0, 3]);
2908+
}

0 commit comments

Comments
 (0)