Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 85 additions & 10 deletions library/kani/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,92 @@ mod intrinsics {
T: MaskElement,
{
let mut mask_array = [0; mask_len(LANES)];
for lane in (0..input.len()).rev() {
let byte = lane / 8;
let mask = &mut mask_array[byte];
let shift_mask = *mask << 1;
*mask = if input[lane] == T::TRUE {
shift_mask | 0x1
} else {
assert_eq!(input[lane], T::FALSE, "Masks values should either be 0 or -1");
shift_mask
};

// The implementation below is the equivalent of the following:
// ```rust
// for lane in (0..input.len()).rev() {
// let byte = lane / 8;
// let mask = &mut mask_array[byte];
// let shift_mask = *mask << 1;
// *mask = if input[lane] == T::TRUE {
// shift_mask | 0x1
// } else {
// assert_eq!(input[lane], T::FALSE, "Masks values should either be 0 or -1");
// shift_mask
// };
// }
// ```
// but is intentionally written in a way that minimizes the number of
// loop iterations. In particular, it's implemented as a nested loop
// where the outer loop iterates over bytes and the inner "loop" (which
// is manually unwound) iterates over bits in a byte. This is to avoid
// needing a high unwind value for harnesses that invoke this code (e.g.
// through the `HashSet` data structure).
for (byte_idx, byte) in mask_array.iter_mut().enumerate() {
// Calculate the starting lane for this byte
let start_lane = byte_idx << 3;
// Calculate how many bits to process (handle the last byte which might be partial)
let bits_to_process = (LANES - start_lane).min(8);

*byte = if bits_to_process > 0 && input[start_lane] == T::TRUE { 1 << 0 } else { 0 }
| if bits_to_process > 1 && input[start_lane + 1] == T::TRUE { 1 << 1 } else { 0 }
| if bits_to_process > 2 && input[start_lane + 2] == T::TRUE { 1 << 2 } else { 0 }
| if bits_to_process > 3 && input[start_lane + 3] == T::TRUE { 1 << 3 } else { 0 }
| if bits_to_process > 4 && input[start_lane + 4] == T::TRUE { 1 << 4 } else { 0 }
| if bits_to_process > 5 && input[start_lane + 5] == T::TRUE { 1 << 5 } else { 0 }
| if bits_to_process > 6 && input[start_lane + 6] == T::TRUE { 1 << 6 } else { 0 }
| if bits_to_process > 7 && input[start_lane + 7] == T::TRUE { 1 << 7 } else { 0 };

assert!(
bits_to_process < 1
|| input[start_lane] == T::TRUE
|| input[start_lane] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 2
|| input[start_lane + 1] == T::TRUE
|| input[start_lane + 1] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 3
|| input[start_lane + 2] == T::TRUE
|| input[start_lane + 2] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 4
|| input[start_lane + 3] == T::TRUE
|| input[start_lane + 3] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 5
|| input[start_lane + 4] == T::TRUE
|| input[start_lane + 4] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 6
|| input[start_lane + 5] == T::TRUE
|| input[start_lane + 5] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 7
|| input[start_lane + 6] == T::TRUE
|| input[start_lane + 6] == T::FALSE,
"Masks values should either be 0 or -1"
);
assert!(
bits_to_process < 8
|| input[start_lane + 7] == T::TRUE
|| input[start_lane + 7] == T::FALSE,
"Masks values should either be 0 or -1"
);
}

mask_array
}

Expand Down
80 changes: 80 additions & 0 deletions tests/kani/SIMD/simd_bitmask_equiv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// Copyright Kani Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT
#![feature(repr_simd, core_intrinsics)]
#![feature(generic_const_exprs)]
#![feature(portable_simd)]

// This test checks the equivalence of Kani's old and new implementations of the
// `simd_bitmask` intrinsic

use std::fmt::Debug;

pub trait MaskElement: PartialEq + Debug {
const TRUE: Self;
const FALSE: Self;
}

impl MaskElement for i32 {
const TRUE: Self = -1;
const FALSE: Self = 0;
}

/// Calculate the minimum number of lanes to represent a mask
/// Logic similar to `bitmask_len` from `portable_simd`.
/// <https://github.com/rust-lang/portable-simd/blob/490b5cf/crates/core_simd/src/masks/to_bitmask.rs#L75-L79>
const fn mask_len(len: usize) -> usize {
len.div_ceil(8)
}

fn simd_bitmask_impl_old<T, const LANES: usize>(input: &[T; LANES]) -> [u8; mask_len(LANES)]
where
T: MaskElement,
{
let mut mask_array = [0; mask_len(LANES)];
for lane in (0..input.len()).rev() {
let byte = lane / 8;
let mask = &mut mask_array[byte];
let shift_mask = *mask << 1;
*mask = if input[lane] == T::TRUE {
shift_mask | 0x1
} else {
assert_eq!(input[lane], T::FALSE, "Masks values should either be 0 or -1");
shift_mask
};
}
mask_array
}

unsafe fn simd_bitmask<T, U, E, const LANES: usize>(input: T) -> U
where
[u8; mask_len(LANES)]: Sized,
E: MaskElement,
{
let data = &*(&input as *const T as *const [E; LANES]);
let mask = simd_bitmask_impl_old(data);
(&mask as *const [u8; mask_len(LANES)] as *const U).read()
}

#[repr(simd)]
#[derive(Clone, Debug)]
struct CustomMask<const LANES: usize>([i32; LANES]);

impl<const LANES: usize> kani::Arbitrary for CustomMask<LANES>
where
[bool; LANES]: Sized + kani::Arbitrary,
{
fn any() -> Self {
CustomMask(kani::any::<[bool; LANES]>().map(|v| if v { i32::FALSE } else { i32::TRUE }))
}
}

#[kani::proof]
#[kani::solver(kissat)]
fn check_equiv() {
let mask = kani::any::<CustomMask<8>>();
unsafe {
let result1 = simd_bitmask::<_, u8, i32, 8>(mask.clone());
let result2 = std::intrinsics::simd::simd_bitmask::<_, u8>(mask);
assert_eq!(result1, result2);
}
}
Loading