// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Implementations of a specialized in-place filter for mutable buffers using AVX-512.

#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use crate::filter::slice::SimdCompress;
use crate::filter::slice::in_place::filter_in_place_scalar;

/// Filter a mutable slice of elements in place according to the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This function automatically dispatches to the most efficient implementation based on the
/// CPU features detected at runtime.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
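///
/// # Example
///
/// A minimal sketch of the LSB-first mask layout, assuming `SimdCompress` is
/// implemented for `u32`:
///
/// ```ignore
/// let mut data = [10u32, 20, 30, 40, 50];
/// // Bit `i` of the mask selects element `i`, so 0b0000_1101 keeps 10, 30, and 40.
/// let kept = filter_in_place(&mut data, &[0b0000_1101]);
/// assert_eq!(kept, 3);
/// assert_eq!(data[..kept], [10, 30, 40]);
/// ```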
#[inline]
pub fn filter_in_place<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let use_simd = if T::WIDTH >= 32 {
            // 32-bit and 64-bit types only need AVX-512F.
            is_x86_feature_detected!("avx512f")
        } else {
            // 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
            is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
        };

        if use_simd {
            return unsafe { filter_in_place_avx512(data, mask) };
        }
    }

    // Fall back to the scalar implementation on non-x86 targets or when the
    // required SIMD features are unavailable.
    filter_in_place_scalar(data, mask)
}

/// Filter a mutable slice of elements in place according to the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask.
///
/// This function uses AVX-512 SIMD instructions for high-performance filtering.
///
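/// Each full 512-bit chunk is compacted with AVX-512's compress operation:
/// lanes whose mask bit is set are packed toward the low end of the vector.
/// For example, a mask of `0b1010` over lanes `[a, b, c, d]` yields
/// `[b, d, _, _]`, where the tail lanes are unspecified but harmless, since
/// only `popcount(mask)` elements advance the write cursor.
///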
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
///
/// # Safety
///
/// The caller must ensure the required SIMD instruction sets are available:
/// AVX-512F for 32-bit and 64-bit element types, plus AVX-512VBMI2 for 8-bit
/// and 16-bit element types.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
pub unsafe fn filter_in_place_avx512<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let data_len = data.len();
    let mut write_pos = 0;

    // Pre-compute the loop bounds so the hot loop needs no tail-handling branch.
    let full_chunks = data_len / T::ELEMENTS_PER_VECTOR;
    let remainder = data_len % T::ELEMENTS_PER_VECTOR;

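    // For example, with `u32` elements a 512-bit vector holds 512 / 32 = 16 lanes,
    // so each chunk consumes 16 elements and 16 / 8 = 2 mask bytes; this is the
    // arithmetic the `ELEMENTS_PER_VECTOR` and `MASK_BYTES` constants encode.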
    // Process full chunks with no branches in the loop.
    for chunk_idx in 0..full_chunks {
        let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
        let mask_byte_offset = chunk_idx * T::MASK_BYTES;

        // Read the mask for this chunk.
        // SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks.
        let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };

        // Load elements into the SIMD register.
        // SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
        let vector = unsafe { _mm512_loadu_si512(data.as_ptr().add(read_pos) as *const __m512i) };

        // Pack all elements whose mask bit is set toward the front of the vector.
        let filtered = unsafe { T::compress_vector(mask_value, vector) };

        // Write the filtered vector back to memory.
        // SAFETY: the store writes a full `T::ELEMENTS_PER_VECTOR` elements at
        // `write_pos`. Compaction guarantees `write_pos <= read_pos`, and
        // `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks,
        // so the store stays in bounds. Garbage lanes past the kept elements are
        // overwritten by later iterations or fall beyond the returned length.
        unsafe { _mm512_storeu_si512(data.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };

        // Count the kept elements, using the hardware `popcnt` instruction
        // (enabled via the target_feature above).
        let count = T::count_ones(mask_value);
        write_pos += count;
    }

    // Handle the final partial chunk with simple scalar processing.
    let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
    for i in 0..remainder {
        let read_idx = read_pos + i;
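        // The mask is LSB-first: the bit for element `i` lives at bit `i % 8`
        // of byte `i / 8`.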
        let bit_idx = read_idx % 8;
        let byte_idx = read_idx / 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            data[write_pos] = data[read_idx];
            write_pos += 1;
        }
    }

    write_pos
}
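
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity-check sketch: it assumes `SimdCompress` is implemented
    // for `u32` and compares the dispatching entry point against the scalar
    // reference implementation, exercising both the full-vector loop and the
    // scalar tail.
    #[test]
    fn filter_in_place_matches_scalar_reference() {
        let original: Vec<u32> = (0..100).collect();
        // 0b0101_0101 keeps the even-indexed element of every 8-element group.
        let mask = vec![0b0101_0101u8; original.len().div_ceil(8)];

        let mut simd_data = original.clone();
        let kept = filter_in_place(&mut simd_data, &mask);

        let mut scalar_data = original.clone();
        let expected = filter_in_place_scalar(&mut scalar_data, &mask);

        assert_eq!(kept, expected);
        assert_eq!(simd_data[..kept], scalar_data[..expected]);
    }
}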