Skip to content

Commit

Permalink
Simplify bitmasks
Browse files Browse the repository at this point in the history
  • Loading branch information
calebzulawski committed Nov 17, 2023
1 parent 6b1e7f6 commit 75943f7
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 189 deletions.
42 changes: 39 additions & 3 deletions crates/core_simd/src/masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
)]
mod mask_impl;

mod to_bitmask;
pub use to_bitmask::{ToBitMask, ToBitMaskArray};

use crate::simd::{
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
};
Expand Down Expand Up @@ -262,6 +259,45 @@ where
pub fn all(self) -> bool {
    // Hand the query to the backend mask representation wrapped by this type.
    let inner = self.0;
    inner.all()
}

/// Packs this mask into a `u64` bitmask.
///
/// A bit of the result is set exactly when the corresponding mask element is `true`.
/// Only the first 64 elements are represented: for a mask with more than 64 elements,
/// the bitmask is truncated and the later elements are dropped.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask(self) -> u64 {
    // The wrapped backend mask performs the actual conversion.
    let inner = self.0;
    inner.to_bitmask_integer()
}

/// Builds a mask from a `u64` bitmask.
///
/// Every set bit makes the corresponding mask element `true`.
/// For a mask with more than 64 elements, the elements beyond the first 64 are
/// set to `false`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask(bitmask: u64) -> Self {
    // Let the backend mask decode the integer, then wrap it in the public type.
    let inner = mask_impl::Mask::from_bitmask_integer(bitmask);
    Self(inner)
}

/// Create a bitmask vector from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// The remaining bits are unset.
///
/// The result is a vector with the same element type and length as the mask's
/// integer representation, not a scalar integer; use `to_bitmask` for a `u64`.
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<T, N> {
    // The backend mask representation produces the packed bits.
    self.0.to_bitmask_vector()
}

/// Builds a mask from a bitmask vector.
///
/// Every set bit in the bitmask makes the corresponding mask element `true`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
    // Let the backend mask decode the vector, then wrap it in the public type.
    let inner = mask_impl::Mask::from_bitmask_vector(bitmask);
    Self(inner)
}
}

// vector/array conversion
Expand Down
69 changes: 47 additions & 22 deletions crates/core_simd/src/masks/bitmask.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(unused_imports)]
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use core::marker::PhantomData;

/// A mask where each lane is represented by a single bit.
Expand Down Expand Up @@ -120,39 +120,64 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
assert!(core::mem::size_of::<Self>() == M);
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<T, N> {
let mut bitmask = Self::splat(false).to_int();

assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);

// Safety: converting an integer to an array of bytes of the same size is safe
unsafe { core::mem::transmute_copy(&self.0) }
// Safety: the bitmask vector is big enough to hold the bitmask
unsafe {
core::ptr::copy_nonoverlapping(
self.0.as_ref().as_ptr(),
bitmask.as_mut_array().as_mut_ptr() as _,
self.0.as_ref().len(),
);
}

bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
assert!(core::mem::size_of::<Self>() == M);
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);

// Safety: converting an array of bytes to an integer of the same size is safe
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
// Safety: the bitmask vector is big enough to hold the bitmask
unsafe {
core::ptr::copy_nonoverlapping(
bitmask.as_array().as_ptr() as _,
bytes.as_mut().as_mut_ptr(),
bytes.as_ref().len(),
);
}

Self(bytes, PhantomData)
}

#[inline]
pub fn to_bitmask_integer<U>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { core::mem::transmute_copy(&self.0) }
pub fn to_bitmask_integer(self) -> u64 {
let mut bitmask = [0u8; 8];
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
u64::from_ne_bytes(bitmask)
}

#[inline]
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
pub fn from_bitmask_integer(bitmask: u64) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_mut().len();
bytes
.as_mut()
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
Self(bytes, PhantomData)
}

#[inline]
Expand Down
97 changes: 53 additions & 44 deletions crates/core_simd/src/masks/full_masks.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Masks that take up full SIMD vector registers.

use super::{to_bitmask::ToBitMaskArray, MaskElement};
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
Expand Down Expand Up @@ -143,95 +142,105 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
where
super::Mask<T, N>: ToBitMaskArray,
{
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<T, N> {
let mut bitmask = Self::splat(false).to_int();

// Safety: Bytes is the right size array
unsafe {
// Compute the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
intrinsics::simd_bitmask(self.0);

// Transmute to the return type
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
*x = x.reverse_bits();
for x in bytes.as_mut() {
*x = x.reverse_bits()
}
};
}

bitmask
assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);
core::ptr::copy_nonoverlapping(
bytes.as_ref().as_ptr(),
bitmask.as_mut_array().as_mut_ptr() as _,
bytes.as_ref().len(),
);
}

bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
where
super::Mask<T, N>: ToBitMaskArray,
{
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

// Safety: Bytes is the right size array
unsafe {
assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);
core::ptr::copy_nonoverlapping(
bitmask.as_array().as_ptr() as _,
bytes.as_mut().as_mut_ptr(),
bytes.as_mut().len(),
);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
for x in bytes.as_mut() {
*x = x.reverse_bits();
}
}

// Transmute to the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
core::mem::transmute_copy(&bitmask);

// Compute the regular mask
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
bitmask,
bytes,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
}
}

#[inline]
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: U is required to be the appropriate bitmask type
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
pub(crate) fn to_bitmask_integer(self) -> u64 {
let resized = self.to_int().extend::<64>(T::FALSE);

// SAFETY: `resized` is an integer vector with length 64
let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) };

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits()
} else {
bitmask
}
}

#[inline]
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits()
} else {
bitmask
};

// Safety: U is required to be the appropriate bitmask type
unsafe {
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, 64> = unsafe {
intrinsics::simd_select_bitmask(
bitmask,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
}
Simd::<T, 64>::splat(T::TRUE),
Simd::<T, 64>::splat(T::FALSE),
)
};

// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) }
}

#[inline]
Expand Down
Loading

0 comments on commit 75943f7

Please sign in to comment.