diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index 1199153a5bd..63731342423 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -12,9 +12,6 @@ )] mod mask_impl; -mod to_bitmask; -pub use to_bitmask::{ToBitMask, ToBitMaskArray}; - use crate::simd::{ cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount, }; @@ -262,6 +259,45 @@ where pub fn all(self) -> bool { self.0.all() } + + /// Create a bitmask from a mask. + /// + /// Each bit is set if the corresponding element in the mask is `true`. + /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64. + #[inline] + #[must_use = "method returns a new integer and does not mutate the original value"] + pub fn to_bitmask(self) -> u64 { + self.0.to_bitmask_integer() + } + + /// Create a mask from a bitmask. + /// + /// For each bit, if it is set, the corresponding element in the mask is set to `true`. + /// If the mask contains more than 64 elements, the remainder are set to `false`. + #[inline] + #[must_use = "method returns a new mask and does not mutate the original value"] + pub fn from_bitmask(bitmask: u64) -> Self { + Self(mask_impl::Mask::from_bitmask_integer(bitmask)) + } + + /// Create a bitmask vector from a mask. + /// + /// Each bit is set if the corresponding element in the mask is `true`. + /// The remaining bits are unset. + #[inline] + #[must_use = "method returns a new integer and does not mutate the original value"] + pub fn to_bitmask_vector(self) -> Simd { + self.0.to_bitmask_vector() + } + + /// Create a mask from a bitmask vector. + /// + /// For each bit, if it is set, the corresponding element in the mask is set to `true`. + #[inline] + #[must_use = "method returns a new mask and does not mutate the original value"] + pub fn from_bitmask_vector(bitmask: Simd) -> Self { + Self(mask_impl::Mask::from_bitmask_vector(bitmask)) + } } // vector/array conversion diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs index aaae28a07be..6ddff07fea2 100644 --- a/crates/core_simd/src/masks/bitmask.rs +++ b/crates/core_simd/src/masks/bitmask.rs @@ -1,7 +1,7 @@ #![allow(unused_imports)] use super::MaskElement; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; +use crate::simd::{LaneCount, Simd, SupportedLaneCount}; use core::marker::PhantomData; /// A mask where each lane is represented by a single bit. @@ -120,39 +120,37 @@ where } #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask_array(self) -> [u8; M] { - assert!(core::mem::size_of::() == M); - - // Safety: converting an integer to an array of bytes of the same size is safe - unsafe { core::mem::transmute_copy(&self.0) } + #[must_use = "method returns a new vector and does not mutate the original value"] + pub fn to_bitmask_vector(self) -> Simd { + let mut bitmask = Simd::splat(0); + bitmask.as_mut_array()[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref()); + bitmask } #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask_array(bitmask: [u8; M]) -> Self { - assert!(core::mem::size_of::() == M); - - // Safety: converting an array of bytes to an integer of the same size is safe - Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData) + pub fn from_bitmask_vector(bitmask: Simd) -> Self { + let mut bytes = as SupportedLaneCount>::BitMask::default(); + let len = bytes.as_ref().len(); + bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]); + Self(bytes, PhantomData) } #[inline] - pub fn to_bitmask_integer(self) -> U - where - super::Mask: ToBitMask, - { - // Safety: these are the same types - unsafe { core::mem::transmute_copy(&self.0) } + pub fn to_bitmask_integer(self) -> u64 { + let mut bitmask = [0u8; 8]; + bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref()); + u64::from_ne_bytes(bitmask) } #[inline] - pub fn from_bitmask_integer(bitmask: U) -> Self - where - super::Mask: ToBitMask, - { - // Safety: these are the same types - unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) } + pub fn from_bitmask_integer(bitmask: u64) -> Self { + let mut bytes = as SupportedLaneCount>::BitMask::default(); + let len = bytes.as_mut().len(); + bytes + .as_mut() + .copy_from_slice(&bitmask.to_ne_bytes()[..len]); + Self(bytes, PhantomData) } #[inline] diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs index 2aa9272ab46..0d17e90c128 100644 --- a/crates/core_simd/src/masks/full_masks.rs +++ b/crates/core_simd/src/masks/full_masks.rs @@ -1,8 +1,7 @@ //! Masks that take up full SIMD vector registers. -use super::{to_bitmask::ToBitMaskArray, MaskElement}; use crate::simd::intrinsics; -use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask}; +use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount}; #[repr(transparent)] pub struct Mask(Simd) @@ -143,53 +142,49 @@ where } #[inline] - #[must_use = "method returns a new array and does not mutate the original value"] - pub fn to_bitmask_array(self) -> [u8; M] - where - super::Mask: ToBitMaskArray, - { + #[must_use = "method returns a new vector and does not mutate the original value"] + pub fn to_bitmask_vector(self) -> Simd { + let mut bitmask = Simd::splat(0); + // Safety: Bytes is the right size array unsafe { // Compute the bitmask - let bitmask: as ToBitMaskArray>::BitMaskArray = + let mut bytes: as SupportedLaneCount>::BitMask = intrinsics::simd_bitmask(self.0); - // Transmute to the return type - let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask); - // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { - *x = x.reverse_bits(); + for x in bytes.as_mut() { + *x = x.reverse_bits() } - }; + } - bitmask + bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref()); } + + bitmask } #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] - pub fn from_bitmask_array(mut bitmask: [u8; M]) -> Self - where - super::Mask: ToBitMaskArray, - { + pub fn from_bitmask_vector(bitmask: Simd) -> Self { + let mut bytes = as SupportedLaneCount>::BitMask::default(); + // Safety: Bytes is the right size array unsafe { + let len = bytes.as_ref().len(); + bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]); + // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - for x in bitmask.as_mut() { + for x in bytes.as_mut() { *x = x.reverse_bits(); } } - // Transmute to the bitmask - let bitmask: as ToBitMaskArray>::BitMaskArray = - core::mem::transmute_copy(&bitmask); - // Compute the regular mask Self::from_int_unchecked(intrinsics::simd_select_bitmask( - bitmask, + bytes, Self::splat(true).to_int(), Self::splat(false).to_int(), )) @@ -197,40 +192,107 @@ where } #[inline] - pub(crate) fn to_bitmask_integer(self) -> U + unsafe fn to_bitmask_impl(self) -> U where - super::Mask: ToBitMask, + LaneCount: SupportedLaneCount, { - // Safety: U is required to be the appropriate bitmask type - let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) }; + let resized = self.to_int().resize::(T::FALSE); + + // Safety: `resized` is an integer vector with length M, which must match T + let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) }; // LLVM assumes bit order should match endianness if cfg!(target_endian = "big") { - bitmask.reverse_bits(N) + bitmask.reverse_bits(M) } else { bitmask } } #[inline] - pub(crate) fn from_bitmask_integer(bitmask: U) -> Self + unsafe fn from_bitmask_impl(bitmask: U) -> Self where - super::Mask: ToBitMask, + LaneCount: SupportedLaneCount, { // LLVM assumes bit order should match endianness let bitmask = if cfg!(target_endian = "big") { - bitmask.reverse_bits(N) + bitmask.reverse_bits(M) } else { bitmask }; - // Safety: U is required to be the appropriate bitmask type - unsafe { - Self::from_int_unchecked(intrinsics::simd_select_bitmask( + // SAFETY: `mask` is the correct bitmask type for a u64 bitmask + let mask: Simd = unsafe { + intrinsics::simd_select_bitmask( bitmask, - Self::splat(true).to_int(), - Self::splat(false).to_int(), - )) + Simd::::splat(T::TRUE), + Simd::::splat(T::FALSE), + ) + }; + + // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE` + unsafe { Self::from_int_unchecked(mask.resize::(T::FALSE)) } + } + + #[inline] + pub(crate) fn to_bitmask_integer(self) -> u64 { + // TODO modify simd_bitmask to zero-extend output, making this unnecessary + macro_rules! bitmask { + { $($ty:ty: $($len:literal),*;)* } => { + match N { + $($( + // Safety: bitmask matches length + $len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 }, + )*)* + // Safety: bitmask matches length + _ => unsafe { self.to_bitmask_impl::() }, + } + } + } + #[cfg(all_lane_counts)] + bitmask! { + u8: 1, 2, 3, 4, 5, 6, 7, 8; + u16: 9, 10, 11, 12, 13, 14, 15, 16; + u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32; + u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64; + } + #[cfg(not(all_lane_counts))] + bitmask! { + u8: 1, 2, 4, 8; + u16: 16; + u32: 32; + u64: 64; + } + } + + #[inline] + pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { + // TODO modify simd_bitmask_select to truncate input, making this unnecessary + macro_rules! bitmask { + { $($ty:ty: $($len:literal),*;)* } => { + match N { + $($( + // Safety: bitmask matches length + $len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) }, + )*)* + // Safety: bitmask matches length + _ => unsafe { Self::from_bitmask_impl::(bitmask) }, + } + } + } + #[cfg(all_lane_counts)] + bitmask! { + u8: 1, 2, 3, 4, 5, 6, 7, 8; + u16: 9, 10, 11, 12, 13, 14, 15, 16; + u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32; + u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64; + } + #[cfg(not(all_lane_counts))] + bitmask! { + u8: 1, 2, 4, 8; + u16: 16; + u32: 32; + u64: 64; } } diff --git a/crates/core_simd/src/masks/to_bitmask.rs b/crates/core_simd/src/masks/to_bitmask.rs deleted file mode 100644 index 06f09c65aca..00000000000 --- a/crates/core_simd/src/masks/to_bitmask.rs +++ /dev/null @@ -1,111 +0,0 @@ -use super::{mask_impl, Mask, MaskElement}; -use crate::simd::{LaneCount, SupportedLaneCount}; -use core::borrow::{Borrow, BorrowMut}; - -mod sealed { - pub trait Sealed {} -} -pub use sealed::Sealed; - -impl Sealed for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -/// Converts masks to and from integer bitmasks. -/// -/// Each bit of the bitmask corresponds to a mask element, starting with the LSB. -pub trait ToBitMask: Sealed { - /// The integer bitmask type. - type BitMask; - - /// Converts a mask to a bitmask. - fn to_bitmask(self) -> Self::BitMask; - - /// Converts a bitmask to a mask. - fn from_bitmask(bitmask: Self::BitMask) -> Self; -} - -/// Converts masks to and from byte array bitmasks. -/// -/// Each bit of the bitmask corresponds to a mask element, starting with the LSB of the first byte. -pub trait ToBitMaskArray: Sealed { - /// The bitmask array. - type BitMaskArray: Copy - + Unpin - + Send - + Sync - + AsRef<[u8]> - + AsMut<[u8]> - + Borrow<[u8]> - + BorrowMut<[u8]> - + 'static; - - /// Converts a mask to a bitmask. - fn to_bitmask_array(self) -> Self::BitMaskArray; - - /// Converts a bitmask to a mask. - fn from_bitmask_array(bitmask: Self::BitMaskArray) -> Self; -} - -macro_rules! impl_integer { - { $(impl ToBitMask for Mask<_, $lanes:literal>)* } => { - $( - impl ToBitMask for Mask { - type BitMask = $int; - - #[inline] - fn to_bitmask(self) -> $int { - self.0.to_bitmask_integer() - } - - #[inline] - fn from_bitmask(bitmask: $int) -> Self { - Self(mask_impl::Mask::from_bitmask_integer(bitmask)) - } - } - )* - } -} - -macro_rules! impl_array { - { $(impl ToBitMaskArray for Mask<_, $lanes:literal>)* } => { - $( - impl ToBitMaskArray for Mask { - type BitMaskArray = [u8; $int]; - - #[inline] - fn to_bitmask_array(self) -> Self::BitMaskArray { - self.0.to_bitmask_array() - } - - #[inline] - fn from_bitmask_array(bitmask: Self::BitMaskArray) -> Self { - Self(mask_impl::Mask::from_bitmask_array(bitmask)) - } - } - )* - } -} - -impl_integer! { - impl ToBitMask for Mask<_, 1> - impl ToBitMask for Mask<_, 2> - impl ToBitMask for Mask<_, 4> - impl ToBitMask for Mask<_, 8> - impl ToBitMask for Mask<_, 16> - impl ToBitMask for Mask<_, 32> - impl ToBitMask for Mask<_, 64> -} - -impl_array! { - impl ToBitMaskArray for Mask<_, 1> - impl ToBitMaskArray for Mask<_, 2> - impl ToBitMaskArray for Mask<_, 4> - impl ToBitMaskArray for Mask<_, 8> - impl ToBitMaskArray for Mask<_, 16> - impl ToBitMaskArray for Mask<_, 32> - impl ToBitMaskArray for Mask<_, 64> -} diff --git a/crates/core_simd/src/swizzle.rs b/crates/core_simd/src/swizzle.rs index 6af882c0a0e..ec8548d5574 100644 --- a/crates/core_simd/src/swizzle.rs +++ b/crates/core_simd/src/swizzle.rs @@ -349,4 +349,39 @@ where Odd::concat_swizzle(self, other), ) } + + /// Resize a vector. + /// + /// If `M` > `N`, extends the length of a vector, setting the new elements to `value`. + /// If `M` < `N`, truncates the vector to the first `M` elements. + /// + /// ``` + /// # #![feature(portable_simd)] + /// # #[cfg(feature = "as_crate")] use core_simd::simd; + /// # #[cfg(not(feature = "as_crate"))] use core::simd; + /// # use simd::u32x4; + /// let x = u32x4::from_array([0, 1, 2, 3]); + /// assert_eq!(x.resize::<8>(9).to_array(), [0, 1, 2, 3, 9, 9, 9, 9]); + /// assert_eq!(x.resize::<2>(9).to_array(), [0, 1]); + /// ``` + #[inline] + #[must_use = "method returns a new vector and does not mutate the original inputs"] + pub fn resize(self, value: T) -> Simd + where + LaneCount: SupportedLaneCount, + { + struct Resize; + impl Swizzle for Resize { + const INDEX: [usize; M] = const { + let mut index = [0; M]; + let mut i = 0; + while i < M { + index[i] = if i < N { i } else { N }; + i += 1; + } + index + }; + } + Resize::::concat_swizzle(self, Simd::splat(value)) + } } diff --git a/crates/core_simd/tests/masks.rs b/crates/core_simd/tests/masks.rs index 7c1d4c7dd3f..00fc2a24e27 100644 --- a/crates/core_simd/tests/masks.rs +++ b/crates/core_simd/tests/masks.rs @@ -72,7 +72,6 @@ macro_rules! test_mask_api { #[test] fn roundtrip_bitmask_conversion() { - use core_simd::simd::ToBitMask; let values = [ true, false, false, true, false, false, true, false, true, true, false, false, false, false, false, true, @@ -85,8 +84,6 @@ macro_rules! test_mask_api { #[test] fn roundtrip_bitmask_conversion_short() { - use core_simd::simd::ToBitMask; - let values = [ false, false, false, true, ]; @@ -126,16 +123,16 @@ macro_rules! test_mask_api { } #[test] - fn roundtrip_bitmask_array_conversion() { - use core_simd::simd::ToBitMaskArray; + fn roundtrip_bitmask_vector_conversion() { + use core_simd::simd::ToBytes; let values = [ true, false, false, true, false, false, true, false, true, true, false, false, false, false, false, true, ]; let mask = Mask::<$type, 16>::from_array(values); - let bitmask = mask.to_bitmask_array(); - assert_eq!(bitmask, [0b01001001, 0b10000011]); - assert_eq!(Mask::<$type, 16>::from_bitmask_array(bitmask), mask); + let bitmask = mask.to_bitmask_vector(); + assert_eq!(bitmask.resize::<2>(0).to_ne_bytes()[..2], [0b01001001, 0b10000011]); + assert_eq!(Mask::<$type, 16>::from_bitmask_vector(bitmask), mask); } } }