Skip to content

Commit 02c00d3

Browse files
committed
add AVX512 support for filtering in place
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent b4c0f8d commit 02c00d3

File tree

11 files changed

+783
-69
lines changed

11 files changed

+783
-69
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-compute/Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ arrow = ["dep:arrow-array", "dep:arrow-buffer", "dep:arrow-schema"]
3737

3838
[dev-dependencies]
3939
divan = { workspace = true }
40+
itertools = { workspace = true }
41+
rand = { workspace = true }
4042

4143
[[bench]]
4244
name = "filter_buffer_mut"
@@ -45,3 +47,7 @@ harness = false
4547
[[bench]]
4648
name = "expand_buffer"
4749
harness = false
50+
51+
[[bench]]
52+
name = "avx512"
53+
harness = false

vortex-compute/benches/avx512.rs

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#![expect(clippy::cast_possible_truncation)]
5+
6+
use itertools::Itertools;
7+
use rand::Rng;
8+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
9+
use vortex_compute::filter::slice::in_place::avx512::filter_in_place_avx512;
10+
use vortex_compute::filter::slice::in_place::filter_in_place_scalar;
11+
12+
fn main() {
13+
divan::main();
14+
}
15+
16+
// Create a random mask where each bit has `probability` chance of being set.
17+
fn create_random_mask(size: usize, probability: f64) -> Vec<u8> {
18+
let mut rng = rand::rng();
19+
let num_bytes = size.div_ceil(8);
20+
let mut mask = Vec::with_capacity(num_bytes);
21+
22+
for _ in 0..num_bytes {
23+
let mut byte = 0u8;
24+
for bit in 0..8 {
25+
if rng.random::<f64>() < probability {
26+
byte |= 1 << bit;
27+
}
28+
}
29+
mask.push(byte);
30+
}
31+
32+
mask
33+
}
34+
35+
// Data sizes (in elements) to benchmark: 1 Ki, 16 Ki, and 128 Ki elements.
const SIZES: &[usize] = &[1 << 10, 1 << 14, 1 << 17];

// Selectivities to benchmark: probability that any given mask bit is set.
const PROBABILITIES: &[f64] = &[0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0];
41+
#[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
42+
fn random_probability_scalar(bencher: divan::Bencher, (size, probability): (usize, f64)) {
43+
let mask = create_random_mask(size, probability);
44+
bencher
45+
.with_inputs(|| (0..size as i32).collect::<Vec<_>>())
46+
.bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
47+
}
48+
49+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
50+
#[divan::bench(sample_size = 64, args = SIZES.iter().copied().cartesian_product(PROBABILITIES.iter().copied()))]
51+
fn random_probability_avx512(bencher: divan::Bencher, (size, probability): (usize, f64)) {
52+
let mask = create_random_mask(size, probability);
53+
bencher
54+
.with_inputs(|| (0..size as i32).collect::<Vec<_>>())
55+
.bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
56+
}
57+
58+
const LARGE_SIZE: usize = 1024 * 1024; // 4 MB
59+
60+
#[divan::bench(sample_size = 16, args = PROBABILITIES)]
61+
fn scalar_throughput(bencher: divan::Bencher, probability: f64) {
62+
let mask = create_random_mask(LARGE_SIZE, probability);
63+
bencher
64+
.counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
65+
.with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
66+
.bench_values(|mut data| filter_in_place_scalar(&mut data, &mask))
67+
}
68+
69+
#[divan::bench(sample_size = 16, args = PROBABILITIES)]
70+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
71+
fn avx512_throughput(bencher: divan::Bencher, probability: f64) {
72+
let mask = create_random_mask(LARGE_SIZE, probability);
73+
bencher
74+
.counter(divan::counter::BytesCount::new(LARGE_SIZE * 4))
75+
.with_inputs(|| (0..LARGE_SIZE as i32).collect::<Vec<_>>())
76+
.bench_values(|mut data| unsafe { filter_in_place_avx512(&mut data, &mask) })
77+
}

vortex-compute/src/filter/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
mod bitbuffer;
77
mod buffer;
88
mod mask;
9-
mod slice_mut;
9+
pub mod slice;
1010
mod vector;
1111

1212
/// Function for filtering based on a selection mask.
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! Implementations of a specialized in-place filter for mutable buffers using AVX512.
5+
6+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
7+
use std::arch::x86_64::*;
8+
9+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
10+
use crate::filter::slice::SimdCompress;
11+
use crate::filter::slice::in_place::filter_in_place_scalar;
12+
13+
/// Filter a mutable slice of elements in-place depending on the given mask.
14+
///
15+
/// The mask is represented as a slice of bytes (LSB is the first element).
16+
///
17+
/// Returns the true count of the mask.
18+
///
19+
/// This function automatically dispatches to the most efficient implementation based on the
20+
/// available CPU features at compile time.
21+
///
22+
/// # Panics
23+
///
24+
/// Panics if `mask.len() != data.len().div_ceil(8)`.
25+
#[inline]
26+
pub fn filter_in_place<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
27+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
28+
{
29+
let use_simd = if T::WIDTH >= 32 {
30+
// 32-bit and 64-bit types only need AVX-512F.
31+
is_x86_feature_detected!("avx512f")
32+
} else {
33+
// 8-bit and 16-bit types need both AVX-512F and AVX-512VBMI2.
34+
is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vbmi2")
35+
};
36+
37+
if use_simd {
38+
return unsafe { filter_in_place_avx512(data, mask) };
39+
}
40+
}
41+
42+
// Fall back to scalar implementation for non-x86 or when SIMD not available.
43+
filter_in_place_scalar(data, mask)
44+
}
45+
46+
/// Filter a mutable slice of elements in-place depending on the given mask.
///
/// The mask is represented as a slice of bytes (LSB is the first element).
///
/// Returns the true count of the mask; the retained elements occupy
/// `data[..return_value]` and the contents past that index are unspecified.
///
/// This function uses AVX-512 SIMD instructions for high-performance filtering.
///
/// NOTE(review): this function is gated on `any(target_arch = "x86", target_arch =
/// "x86_64")`, but the file imports `std::arch::x86_64::*`, which does not exist on
/// 32-bit `x86` targets — confirm whether the gate should be `x86_64` only.
///
/// # Panics
///
/// Panics if `mask.len() != data.len().div_ceil(8)`.
///
/// # Safety
///
/// The CPU must support *all* features this function is compiled with (`avx512f`,
/// `avx512vbmi2`, and `popcnt`). Calling it without them is undefined behavior,
/// even for element widths whose compress intrinsic only needs AVX-512F.
#[inline]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx512f,avx512vbmi2,popcnt")]
pub unsafe fn filter_in_place_avx512<T: SimdCompress>(data: &mut [T], mask: &[u8]) -> usize {
    assert_eq!(
        mask.len(),
        data.len().div_ceil(8),
        "Mask length must be data.len().div_ceil(8)"
    );

    let data_len = data.len();
    // Next index to write a kept element to; invariant: `write_pos <= read_pos` always.
    let mut write_pos = 0;

    // Pre-calculate loop bounds to eliminate branch misprediction in the hot loop.
    let full_chunks = data_len / T::ELEMENTS_PER_VECTOR;
    let remainder = data_len % T::ELEMENTS_PER_VECTOR;

    // Process full chunks with no branches in the loop.
    for chunk_idx in 0..full_chunks {
        let read_pos = chunk_idx * T::ELEMENTS_PER_VECTOR;
        let mask_byte_offset = chunk_idx * T::MASK_BYTES;

        // Read the mask for this chunk.
        // SAFETY: `mask_byte_offset + T::MASK_BYTES <= mask.len()` for all full chunks,
        // because `full_chunks * T::ELEMENTS_PER_VECTOR <= data_len` implies
        // `full_chunks * T::MASK_BYTES <= data_len / 8 <= mask.len()`. Full-chunk mask
        // bytes always cover real elements; padding bits can exist only in the final
        // (possibly partial) byte, which only the scalar tail below reads bit-by-bit.
        let mask_value = unsafe { T::read_mask(mask.as_ptr(), mask_byte_offset) };

        // Load elements into the SIMD register.
        // SAFETY: `read_pos + T::ELEMENTS_PER_VECTOR <= data.len()` for all full chunks.
        let vector = unsafe { _mm512_loadu_si512(data.as_ptr().add(read_pos) as *const __m512i) };

        // Moves all elements that have their bit set to 1 in the mask value to the left.
        let filtered = unsafe { T::compress_vector(mask_value, vector) };

        // Write the filtered result back. Note this stores a FULL vector, not just the
        // `count_ones(mask_value)` kept lanes — the trailing lanes are scratch that a
        // later store or the final-length contract makes irrelevant.
        // SAFETY: the store touches `[write_pos, write_pos + T::ELEMENTS_PER_VECTOR)`.
        // Since `write_pos <= chunk_idx * T::ELEMENTS_PER_VECTOR`, the store ends at or
        // before `(chunk_idx + 1) * T::ELEMENTS_PER_VECTOR <= data.len()`: it is
        // in-bounds and never clobbers elements a later iteration (or the scalar tail,
        // which starts at `full_chunks * T::ELEMENTS_PER_VECTOR`) still needs to read.
        unsafe { _mm512_storeu_si512(data.as_mut_ptr().add(write_pos) as *mut __m512i, filtered) };

        // Uses the hardware `popcnt` instruction if available.
        let count = T::count_ones(mask_value);
        write_pos += count;
    }

    // Handle the final partial chunk with simple scalar processing.
    let read_pos = full_chunks * T::ELEMENTS_PER_VECTOR;
    for i in 0..remainder {
        let read_idx = read_pos + i;
        // LSB-first layout: bit `read_idx % 8` of byte `read_idx / 8` selects element `read_idx`.
        let bit_idx = read_idx % 8;
        let byte_idx = read_idx / 8;

        if (mask[byte_idx] >> bit_idx) & 1 == 1 {
            // `write_pos <= read_idx` holds throughout, so the source slot is still intact.
            data[write_pos] = data[read_idx];
            write_pos += 1;
        }
    }

    write_pos
}

0 commit comments

Comments
 (0)