Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify bitmasks #375

Merged
merged 3 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions crates/core_simd/src/masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
)]
mod mask_impl;

mod to_bitmask;
pub use to_bitmask::{ToBitMask, ToBitMaskArray};

use crate::simd::{
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
};
Expand Down Expand Up @@ -262,6 +259,45 @@ where
pub fn all(self) -> bool {
self.0.all()
}

/// Create a bitmask from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask(self) -> u64 {
self.0.to_bitmask_integer()
}

/// Create a mask from a bitmask.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
/// If the mask contains more than 64 elements, the remainder are set to `false`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask(bitmask: u64) -> Self {
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}

/// Create a bitmask vector from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// The remaining bits are unset.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<T, N> {
self.0.to_bitmask_vector()
}

/// Create a mask from a bitmask vector.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
}
}

// vector/array conversion
Expand Down
69 changes: 47 additions & 22 deletions crates/core_simd/src/masks/bitmask.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(unused_imports)]
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use core::marker::PhantomData;

/// A mask where each lane is represented by a single bit.
Expand Down Expand Up @@ -120,39 +120,64 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
assert!(core::mem::size_of::<Self>() == M);
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<T, N> {
calebzulawski marked this conversation as resolved.
Show resolved Hide resolved
let mut bitmask = Self::splat(false).to_int();

assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);

// Safety: converting an integer to an array of bytes of the same size is safe
unsafe { core::mem::transmute_copy(&self.0) }
// Safety: the bitmask vector is big enough to hold the bitmask
unsafe {
core::ptr::copy_nonoverlapping(
self.0.as_ref().as_ptr(),
bitmask.as_mut_array().as_mut_ptr() as _,
self.0.as_ref().len(),
);
}

bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
assert!(core::mem::size_of::<Self>() == M);
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);

// Safety: converting an array of bytes to an integer of the same size is safe
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
// Safety: the bitmask vector is big enough to hold the bitmask
unsafe {
core::ptr::copy_nonoverlapping(
bitmask.as_array().as_ptr() as _,
bytes.as_mut().as_mut_ptr(),
bytes.as_ref().len(),
);
}

Self(bytes, PhantomData)
}

#[inline]
pub fn to_bitmask_integer<U>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { core::mem::transmute_copy(&self.0) }
pub fn to_bitmask_integer(self) -> u64 {
let mut bitmask = [0u8; 8];
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
u64::from_ne_bytes(bitmask)
}

#[inline]
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
pub fn from_bitmask_integer(bitmask: u64) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_mut().len();
bytes
.as_mut()
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
Self(bytes, PhantomData)
}

#[inline]
Expand Down
97 changes: 53 additions & 44 deletions crates/core_simd/src/masks/full_masks.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Masks that take up full SIMD vector registers.

use super::{to_bitmask::ToBitMaskArray, MaskElement};
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
Expand Down Expand Up @@ -143,95 +142,105 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
where
super::Mask<T, N>: ToBitMaskArray,
{
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<T, N> {
let mut bitmask = Self::splat(false).to_int();

// Safety: Bytes is the right size array
unsafe {
// Compute the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
intrinsics::simd_bitmask(self.0);

// Transmute to the return type
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
*x = x.reverse_bits();
for x in bytes.as_mut() {
*x = x.reverse_bits()
}
};
}

bitmask
assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);
core::ptr::copy_nonoverlapping(
bytes.as_ref().as_ptr(),
bitmask.as_mut_array().as_mut_ptr() as _,
bytes.as_ref().len(),
);
}

bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
where
super::Mask<T, N>: ToBitMaskArray,
{
pub fn from_bitmask_vector(bitmask: Simd<T, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

// Safety: Bytes is the right size array
unsafe {
assert!(
core::mem::size_of::<Simd<T, N>>()
>= core::mem::size_of::<<LaneCount<N> as SupportedLaneCount>::BitMask>()
);
core::ptr::copy_nonoverlapping(
bitmask.as_array().as_ptr() as _,
bytes.as_mut().as_mut_ptr(),
bytes.as_mut().len(),
);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
for x in bytes.as_mut() {
*x = x.reverse_bits();
}
}

// Transmute to the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
core::mem::transmute_copy(&bitmask);

// Compute the regular mask
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
bitmask,
bytes,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
}
}

#[inline]
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: U is required to be the appropriate bitmask type
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
pub(crate) fn to_bitmask_integer(self) -> u64 {
let resized = self.to_int().extend::<64>(T::FALSE);

// SAFETY: `resized` is an integer vector with length 64
let bitmask: u64 = unsafe { intrinsics::simd_bitmask(resized) };

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits()
} else {
bitmask
}
}

#[inline]
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits()
} else {
bitmask
};

// Safety: U is required to be the appropriate bitmask type
unsafe {
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, 64> = unsafe {
intrinsics::simd_select_bitmask(
bitmask,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
}
Simd::<T, 64>::splat(T::TRUE),
Simd::<T, 64>::splat(T::FALSE),
)
};

// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.extend::<N>(T::FALSE)) }
}

#[inline]
Expand Down
Loading
Loading