Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 97 additions & 52 deletions diskann-wide/src/arch/x86_64/algorithms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,52 @@
// x86 intrinsics
use std::arch::x86_64::*;

use super::{V3, v3::i32x4};
use crate::SIMDVector;
use super::V3;

/// Efficiently load the first `8 < bytes < 16` bytes from `ptr` without accessing memory
/// outside of `[ptr, ptr + bytes)`.
///
/// # Safety
///
/// * `bytes` must be in the range `(8, 16)`.
/// * The memory in `[ptr, ptr + bytes)` must be readable and valid.
#[inline(always)]
unsafe fn __load_8_to_16_bytes(_: V3, ptr: *const u8, bytes: usize) -> __m128i {
debug_assert!(bytes > 8 && bytes < 16);

// The trick here is to use 2 8-byte loads. One (call it X) beginning at `ptr` loading
// `[ptr, ptr + 8)` and the other (call it Y) loading `[ptr + bytes - 8, ptr + bytes)`.
//
// Then, we need a way to glue Y after the first `bytes - 8` bytes of X (formulating the
// problem this way is done intentionally as we'll see below).
//
// We do this using the powerful `_mm_shuffle_epi8` instruction.
//
// This is set up by using an identity shuffle adjusted by subtracting the shift amount.
// Lanes that underflow become negative (high bit set), which `_mm_shuffle_epi8` zeroes.
// Lanes beyond the loaded 8 bytes read from the zero-extended upper half of
// `_mm_loadl_epi64`, producing zeros that are harmless under OR.
//
// For example, if `bytes` is 13 (8 + 5), the adjusted shuffle mask is
// ```
// [-X, -X, -X, -X, -X, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
// |-----------------|
// output lanes here
// will be zeroed
// ```
// This will effectively move the 8 bytes of Y each over by 5 lanes. When OR'ed with X,
// this becomes the 13 bytes we want.
//
// SAFETY: Both reads are within `[ptr, ptr + bytes)`. The intrinsics require SSSE3/SSE2,
// available on V3.
unsafe {
let base = _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
let lo = _mm_loadl_epi64(ptr as *const __m128i);
let hi = _mm_loadl_epi64(ptr.add(bytes - 8) as *const __m128i);
let mask = _mm_sub_epi8(base, _mm_set1_epi8((bytes - 8) as i8));
_mm_or_si128(lo, _mm_shuffle_epi8(hi, mask))
}
}

/// Perform a load of the first `first` bytes beginning at `ptr` into
/// an unsigned 128-bit integer.
Expand All @@ -25,9 +69,8 @@ use crate::SIMDVector;
///
/// Guarantee: Memory addresses in `[ptr + first, ptr + 16)` will not be accessed.
#[inline(always)]
pub(crate) unsafe fn __load_first_of_16_bytes(_: V3, mut ptr: *const u8, first: usize) -> u128 {
let mut remaining = first;
if remaining >= 16 {
pub(crate) unsafe fn __load_first_of_16_bytes(arch: V3, ptr: *const u8, first: usize) -> u128 {
if first >= 16 {
// SAFETY:
// * Pointer Cast: The instruction `_mm_loadu_si128` does not have any alignment
// restrictions, so if `[ptr, ptr + first)` is valid, the cast will be valid.
Expand All @@ -40,41 +83,36 @@ pub(crate) unsafe fn __load_first_of_16_bytes(_: V3, mut ptr: *const u8, first:
};
}

// Move the pointer to one-past the end of the memory we are going to load.
//
// SAFETY: The caller asserts that the memory in `[ptr, ptr + first)` is valid.
ptr = unsafe { ptr.add(first) };

let mut buffer: u128 = 0;
// For `first > 8`, use the optimized two-load method.
if first > 8 {
// SAFETY: `first` is in `(8, 16)` and `[ptr, ptr + first)` is valid.
return unsafe {
std::mem::transmute::<__m128i, u128>(__load_8_to_16_bytes(arch, ptr, first))
};
}

// SAFETY: We emit in-bounds unaligned reads that are in the range specified by the
// caller to be safe.
// For `first <= 8`, everything fits in general purpose registers.
//
// Use two overlapping reads whose results are combined with a single shift + OR.
//
// SAFETY: All reads are within `[ptr, ptr + first)`, which the caller asserts is valid.
unsafe {
if remaining >= 8 {
ptr = ptr.sub(8);
let v: u64 = std::ptr::read_unaligned(ptr as *const u64);
buffer |= v as u128;
remaining -= 8;
}
if remaining >= 4 {
ptr = ptr.sub(4);
let v: u32 = std::ptr::read_unaligned(ptr as *const u32);
buffer = (buffer << (8 * std::mem::size_of::<u32>())) | (v as u128);
remaining -= 4;
}
if remaining >= 2 {
ptr = ptr.sub(2);
let v: u16 = std::ptr::read_unaligned(ptr as *const u16);
buffer = (buffer << (8 * std::mem::size_of::<u16>())) | (v as u128);
remaining -= 2;
}
if remaining >= 1 {
ptr = ptr.sub(1);
let v: u8 = std::ptr::read(ptr);
buffer = (buffer << 8) | (v as u128);
if first == 8 {
std::ptr::read_unaligned(ptr as *const u64) as u128
} else if first >= 4 {
let lo = std::ptr::read_unaligned(ptr as *const u32) as u64;
let hi = std::ptr::read_unaligned(ptr.add(first - 4) as *const u32) as u64;
(lo | (hi << ((first - 4) * 8))) as u128
} else if first >= 2 {
let lo = std::ptr::read_unaligned(ptr as *const u16) as u64;
let hi = std::ptr::read_unaligned(ptr.add(first - 2) as *const u16) as u64;
(lo | (hi << ((first - 2) * 8))) as u128
} else if first == 1 {
std::ptr::read(ptr) as u128
} else {
0
}
}
buffer
}

/// Load the first `first` 16-bit words from `ptr` and return the result as a `__m128i`.
Expand All @@ -96,24 +134,31 @@ pub(crate) unsafe fn __load_first_u16_of_16_bytes(
return unsafe { _mm_loadu_si128(ptr as *const __m128i) };
}

// Strategy: Use a masked load to load elements at 4-byte granularities.
// Then, we have at most one 2-byte load left.
//
// We use `_mm_insert_epi16` to insert the last element.
let byte_ptr = ptr as *const u8;
let bytes = first * 2;

// For `bytes > 8` (i.e., `first > 4`), use the optimized two-load method.
if bytes > 8 {
// SAFETY: `bytes` is in `(8, 16)` and `[byte_ptr, byte_ptr + bytes)` is valid.
return unsafe { __load_8_to_16_bytes(arch, byte_ptr, bytes) };
}

// For `bytes <= 8`, everything fits in general purpose registers.
//
// SAFETY: The reads emitted are in the range `[ptr, ptr + first)` asserted by the caller
// to be safe. The use of the intrinsic is safe by the presence of `arch`.
// SAFETY: All reads are within `[ptr, ptr + first)`, which the caller
// asserts is valid.
unsafe {
let mut reg = i32x4::load_simd_first(arch, ptr as *const i32, first / 2).to_underlying();
if first == 1 {
reg = _mm_insert_epi16::<0>(reg, std::ptr::read_unaligned(ptr.add(first - 1)).into());
} else if first == 3 {
reg = _mm_insert_epi16::<2>(reg, std::ptr::read_unaligned(ptr.add(first - 1)).into());
} else if first == 5 {
reg = _mm_insert_epi16::<4>(reg, std::ptr::read_unaligned(ptr.add(first - 1)).into());
} else if first == 7 {
reg = _mm_insert_epi16::<6>(reg, std::ptr::read_unaligned(ptr.add(first - 1)).into());
if bytes == 8 {
let v = std::ptr::read_unaligned(byte_ptr as *const u64);
_mm_cvtsi64_si128(v as i64)
} else if bytes >= 4 {
let lo = std::ptr::read_unaligned(byte_ptr as *const u32) as u64;
let hi = std::ptr::read_unaligned(byte_ptr.add(bytes - 4) as *const u32) as u64;
_mm_cvtsi64_si128((lo | (hi << ((bytes - 4) * 8))) as i64)
} else if bytes >= 2 {
_mm_cvtsi32_si128(std::ptr::read_unaligned(byte_ptr as *const u16) as i32)
} else {
_mm_setzero_si128()
}
reg
}
}