Merge pull request #375 from rust-lang/bitmask
Simplify bitmasks
calebzulawski authored Nov 19, 2023
2 parents 8d9bcda + 0ad68db commit 7e5c03a
Showing 6 changed files with 202 additions and 185 deletions.
42 changes: 39 additions & 3 deletions crates/core_simd/src/masks.rs
@@ -12,9 +12,6 @@
)]
mod mask_impl;

mod to_bitmask;
pub use to_bitmask::{ToBitMask, ToBitMaskArray};

use crate::simd::{
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
};
@@ -262,6 +259,45 @@ where
pub fn all(self) -> bool {
self.0.all()
}

/// Create a bitmask from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask(self) -> u64 {
self.0.to_bitmask_integer()
}

/// Create a mask from a bitmask.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
/// If the mask contains more than 64 elements, the remainder are set to `false`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask(bitmask: u64) -> Self {
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}
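Usage sketch (not from the patch): the new u64-based round trip on an 8-lane mask, assuming nightly Rust with the portable_simd feature and this crate's current module layout (simd::cmp::SimdPartialOrd).

    #![feature(portable_simd)]
    use core::simd::{cmp::SimdPartialOrd, Mask, Simd};

    fn main() {
        let v = Simd::from_array([3_i32, -1, 7, 0, -5, 2, -8, 4]);
        let mask = v.simd_gt(Simd::splat(0)); // lanes 0, 2, 5, 7 are true

        // Bit i of the bitmask mirrors lane i; bits past lane 7 stay zero.
        let bits: u64 = mask.to_bitmask();
        assert_eq!(bits, 0b1010_0101);

        // Round-tripping recovers the mask.
        let back: Mask<i32, 8> = Mask::from_bitmask(bits);
        assert_eq!(back, mask);
    }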

/// Create a bitmask vector from a mask.
///
/// Each bit is set if the corresponding element in the mask is `true`.
/// The remaining bits are unset.
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
self.0.to_bitmask_vector()
}

/// Create a mask from a bitmask vector.
///
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
}
}
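The vector form packs the same bits into the low bytes of a Simd<u8, N>; continuing the sketch above:

    let bytes = mask.to_bitmask_vector();
    assert_eq!(bytes.as_array()[0], 0b1010_0101); // same bit pattern; remaining bytes are zero
    assert_eq!(Mask::<i32, 8>::from_bitmask_vector(bytes), mask);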

// vector/array conversion
46 changes: 22 additions & 24 deletions crates/core_simd/src/masks/bitmask.rs
@@ -1,7 +1,7 @@
#![allow(unused_imports)]
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use core::marker::PhantomData;

/// A mask where each lane is represented by a single bit.
@@ -120,39 +120,37 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
assert!(core::mem::size_of::<Self>() == M);

// Safety: converting an integer to an array of bytes of the same size is safe
unsafe { core::mem::transmute_copy(&self.0) }
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
let mut bitmask = Simd::splat(0);
bitmask.as_mut_array()[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
assert!(core::mem::size_of::<Self>() == M);

// Safety: converting an array of bytes to an integer of the same size is safe
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_ref().len();
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
Self(bytes, PhantomData)
}

#[inline]
pub fn to_bitmask_integer<U>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { core::mem::transmute_copy(&self.0) }
pub fn to_bitmask_integer(self) -> u64 {
let mut bitmask = [0u8; 8];
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
u64::from_ne_bytes(bitmask)
}

#[inline]
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
pub fn from_bitmask_integer(bitmask: u64) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
let len = bytes.as_mut().len();
bytes
.as_mut()
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
Self(bytes, PhantomData)
}

#[inline]
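The bitmask representation stores a mask as a <LaneCount<N> as SupportedLaneCount>::BitMask byte array, so all four conversions above reduce to byte copies. A standalone illustration of that pattern with hypothetical helpers (not the crate's internals):

    // Widen the stored bytes (N.div_ceil(8) of them, at most 8) into a native-endian u64.
    fn bytes_to_u64(stored: &[u8]) -> u64 {
        let mut buf = [0u8; 8];
        buf[..stored.len()].copy_from_slice(stored);
        u64::from_ne_bytes(buf)
    }

    // Narrow a u64 back down to the stored byte length.
    fn u64_to_bytes(bits: u64, len: usize) -> Vec<u8> {
        bits.to_ne_bytes()[..len].to_vec()
    }

    // On a little-endian target, a 16-lane mask stored as two bytes:
    // bytes_to_u64(&[0b1010_0101, 0b0000_0011]) == 0x03a5
    // u64_to_bytes(0x03a5, 2) == [0xa5, 0x03]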
140 changes: 101 additions & 39 deletions crates/core_simd/src/masks/full_masks.rs
@@ -1,8 +1,7 @@
//! Masks that take up full SIMD vector registers.

use super::{to_bitmask::ToBitMaskArray, MaskElement};
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};

#[repr(transparent)]
pub struct Mask<T, const N: usize>(Simd<T, N>)
@@ -143,94 +142,157 @@ where
}

#[inline]
#[must_use = "method returns a new array and does not mutate the original value"]
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
where
super::Mask<T, N>: ToBitMaskArray,
{
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
let mut bitmask = Simd::splat(0);

// Safety: Bytes is the right size array
unsafe {
// Compute the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
intrinsics::simd_bitmask(self.0);

// Transmute to the return type
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
*x = x.reverse_bits();
for x in bytes.as_mut() {
*x = x.reverse_bits()
}
};
}

bitmask
bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
}

bitmask
}

#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
where
super::Mask<T, N>: ToBitMaskArray,
{
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();

// Safety: Bytes is the right size array
unsafe {
let len = bytes.as_ref().len();
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
for x in bitmask.as_mut() {
for x in bytes.as_mut() {
*x = x.reverse_bits();
}
}

// Transmute to the bitmask
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
core::mem::transmute_copy(&bitmask);

// Compute the regular mask
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
bitmask,
bytes,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
}
}
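Both vector conversions bit-reverse each byte on big-endian targets because, per the comments above, LLVM lays the bitmask out to match endianness while the portable API keeps lane 0 in bit 0. For reference, a minimal check of u8::reverse_bits, which mirrors bit positions within a byte:

    fn main() {
        // The most and least significant bits swap places, and so on inward.
        assert_eq!(0b1000_0000u8.reverse_bits(), 0b0000_0001);
    }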

#[inline]
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
LaneCount<M>: SupportedLaneCount,
{
// Safety: U is required to be the appropriate bitmask type
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
let resized = self.to_int().resize::<M>(T::FALSE);

// Safety: `resized` is an integer vector with length M, which must match T
let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };

// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits(M)
} else {
bitmask
}
}

#[inline]
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
where
super::Mask<T, N>: ToBitMask<BitMask = U>,
LaneCount<M>: SupportedLaneCount,
{
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits(N)
bitmask.reverse_bits(M)
} else {
bitmask
};

// Safety: U is required to be the appropriate bitmask type
unsafe {
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
let mask: Simd<T, M> = unsafe {
intrinsics::simd_select_bitmask(
bitmask,
Self::splat(true).to_int(),
Self::splat(false).to_int(),
))
Simd::<T, M>::splat(T::TRUE),
Simd::<T, M>::splat(T::FALSE),
)
};

// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
}

#[inline]
pub(crate) fn to_bitmask_integer(self) -> u64 {
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
)*)*
// Safety: bitmask matches length
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}
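The macro dispatches on N so that simd_bitmask produces the narrowest supported integer, which is then zero-extended to u64. A hypothetical model of just the width-selection step (the real code also resizes the mask to the matching lane count M):

    // Hypothetical helper; not part of the crate.
    fn bitmask_bits(n: usize) -> u32 {
        match n {
            1..=8 => 8,    // u8
            9..=16 => 16,  // u16
            17..=32 => 32, // u32
            _ => 64,       // u64
        }
    }
    // bitmask_bits(4) == 8, bitmask_bits(16) == 16, bitmask_bits(33) == 64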

#[inline]
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
// TODO modify simd_select_bitmask to truncate input, making this unnecessary
macro_rules! bitmask {
{ $($ty:ty: $($len:literal),*;)* } => {
match N {
$($(
// Safety: bitmask matches length
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
)*)*
// Safety: bitmask matches length
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
}
}
}
#[cfg(all_lane_counts)]
bitmask! {
u8: 1, 2, 3, 4, 5, 6, 7, 8;
u16: 9, 10, 11, 12, 13, 14, 15, 16;
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
}
#[cfg(not(all_lane_counts))]
bitmask! {
u8: 1, 2, 4, 8;
u16: 16;
u32: 32;
u64: 64;
}
}

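Both *_impl helpers above lean on Simd::resize to pad or truncate the mask to the lane count M matching the chosen integer width. A usage sketch of resize on its own (nightly portable_simd; values chosen arbitrarily):

    #![feature(portable_simd)]
    use core::simd::Simd;

    fn main() {
        let small = Simd::from_array([1_i32, 2, 3, 4]);
        let padded: Simd<i32, 8> = small.resize::<8>(0); // new lanes take the given value
        assert_eq!(padded.to_array(), [1, 2, 3, 4, 0, 0, 0, 0]);

        let truncated: Simd<i32, 2> = padded.resize::<2>(0); // trailing lanes are dropped
        assert_eq!(truncated.to_array(), [1, 2]);
    }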
