Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify bitmasks #375

Merged
merged 3 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions crates/core_simd/src/masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
)]
mod mask_impl;

mod to_bitmask;
pub use to_bitmask::{ToBitMask, ToBitMaskArray};

use crate::simd::{
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
};
Expand Down Expand Up @@ -262,6 +259,45 @@ where
pub fn all(self) -> bool {
self.0.all()
}

/// Create a bitmask from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask(self) -> u64 {
self.0.to_bitmask_integer()
}

/// Create a mask from a bitmask.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
/// If the mask contains more than 64 elements, the remainder are set to `false`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask(bitmask: u64) -> Self {
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}

/// Create a bitmask vector from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// The remaining bits are unset.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
self.0.to_bitmask_vector()
}

/// Create a mask from a bitmask vector.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
}
}

// vector/array conversion
Expand Down
46 changes: 22 additions & 24 deletions crates/core_simd/src/masks/bitmask.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(unused_imports)]
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use core::marker::PhantomData;

/// A mask where each lane is represented by a single bit.
Expand Down Expand Up @@ -120,39 +120,37 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
assert!(core::mem::size_of::<Self>() == M);

// Safety: converting an integer to an array of bytes of the same size is safe
unsafe { core::mem::transmute_copy(&self.0) }
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
let mut bitmask = Simd::splat(0);
bitmask.as_mut_array()[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
assert!(core::mem::size_of::<Self>() == M);

// Safety: converting an array of bytes to an integer of the same size is safe
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_ref().len();
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
Self(bytes, PhantomData)
}

#[inline]
pub fn to_bitmask_integer<U>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { core::mem::transmute_copy(&self.0) }
pub fn to_bitmask_integer(self) -> u64 {
let mut bitmask = [0u8; 8];
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
u64::from_ne_bytes(bitmask)
}

#[inline]
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
pub fn from_bitmask_integer(bitmask: u64) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_mut().len();
bytes
.as_mut()
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
Self(bytes, PhantomData)
}

#[inline]
Expand Down
140 changes: 101 additions & 39 deletions crates/core_simd/src/masks/full_masks.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! Masks that take up full SIMD vector registers.

use super::{to_bitmask::ToBitMaskArray, MaskElement};
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
Expand Down Expand Up @@ -143,94 +142,157 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
where
super::Mask<T, N>: ToBitMaskArray,
{
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
let mut bitmask = Simd::splat(0);

// Safety: Bytes is the right size array
unsafe {
// Compute the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
intrinsics::simd_bitmask(self.0);

// Transmute to the return type
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
*x = x.reverse_bits();
for x in bytes.as_mut() {
*x = x.reverse_bits()
}
};
}

bitmask
bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
}

bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
where
super::Mask<T, N>: ToBitMaskArray,
{
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

// Safety: Bytes is the right size array
unsafe {
let len = bytes.as_ref().len();
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
for x in bytes.as_mut() {
*x = x.reverse_bits();
}
}

// Transmute to the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
core::mem::transmute_copy(&bitmask);

// Compute the regular mask
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
bitmask,
bytes,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
}
}

#[inline]
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
LaneCount<M>: SupportedLaneCount,
{
// Safety: U is required to be the appropriate bitmask type
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
let resized = self.to_int().resize::<M>(T::FALSE);

// Safety: `resized` is an integer vector with length M, which must match T
let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits(M)
} else {
bitmask
}
}

#[inline]
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
LaneCount<M>: SupportedLaneCount,
{
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits(M)
} else {
bitmask
};

// Safety: U is required to be the appropriate bitmask type
unsafe {
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, M> = unsafe {
intrinsics::simd_select_bitmask(
bitmask,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
Simd::<T, M>::splat(T::TRUE),
Simd::<T, M>::splat(T::FALSE),
)
};

// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
}

#[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 {
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
)*)*
// Safety: bitmask matches length
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}

#[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
)*)*
// Safety: bitmask matches length
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}

Expand Down
Loading
Loading