From eb5920e829b1dc6ae60f875cddbd294c38e9e3ca Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:09:33 -0700 Subject: [PATCH] Improve: Cloning `b1x8` & `f16` --- .vscode/settings.json | 1 + rust/lib.rs | 97 ++++++++++++++++++++++++++++++++----------- 2 files changed, 74 insertions(+), 24 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 506b075e..cc55c504 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -163,6 +163,7 @@ "Println", "pytest", "Quickstart", + "repr", "rtype", "SIMD", "simsimd", diff --git a/rust/lib.rs b/rust/lib.rs index f3e2bad1..28a1b567 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -26,14 +26,14 @@ pub type Key = u64; pub type Distance = f32; /// Callback signature for custom metric functions, defined in the Rust layer and used in the C++ layer. -pub type StatefullMetric = unsafe extern "C" fn( +pub type StatefulMetric = unsafe extern "C" fn( *const std::ffi::c_void, *const std::ffi::c_void, *mut std::ffi::c_void, ) -> Distance; /// Callback signature for custom predicate functions, defined in the Rust layer and used in the C++ layer. -pub type StatefullPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool; +pub type StatefulPredicate = unsafe extern "C" fn(Key, *mut std::ffi::c_void) -> bool; /// Represents errors that can occur when addressing bits. #[derive(Debug)] @@ -73,31 +73,76 @@ pub trait BitAddressable { fn get_bit(&self, index: usize) -> Result; } +/// A byte-wide bit vector type that provides low-level control over individual bits. +/// +/// This struct represents a single byte (8 bits) and enables manipulation and +/// interpretation of individual bits via various utility functions. #[repr(transparent)] #[allow(non_camel_case_types)] +#[derive(Clone, Copy, Eq, PartialEq)] pub struct b1x8(pub u8); + impl b1x8 { - /// Casts a slice of `u8` to a slice of `b1x8`. + /// Casts a slice of `u8` bytes to a slice of `b1x8`, allowing bit-level operations on byte slices. pub fn from_u8s(slice: &[u8]) -> &[Self] { unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const Self, slice.len()) } } - /// Casts a mutable slice of `u8` to a mutable slice of `b1x8`. + /// Casts a mutable slice of `u8` bytes to a mutable slice of `b1x8`, enabling mutable + /// bit-level operations on byte slices. pub fn from_mut_u8s(slice: &mut [u8]) -> &mut [Self] { unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut Self, slice.len()) } } - /// Casts a slice of `b1x8` back to a slice of `u8`. + /// Converts a slice of `b1x8` back to a slice of `u8`, useful for reading bit-level manipulations + /// in byte-oriented contexts. pub fn to_u8s(slice: &[Self]) -> &[u8] { unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, slice.len()) } } - /// Casts a mutable slice of `b1x8` back to a mutable slice of `u8`. + /// Converts a mutable slice of `b1x8` back to a mutable slice of `u8`, enabling further + /// modifications on the original byte data after bit-level manipulations. pub fn to_mut_u8s(slice: &mut [Self]) -> &mut [u8] { unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, slice.len()) } } } +/// A struct representing a half-precision floating-point number based on the IEEE 754 standard. +/// +/// This struct uses an `i16` to store the half-precision floating-point data, which includes +/// 1 sign bit, 5 exponent bits, and 10 mantissa bits. +#[repr(transparent)] +#[allow(non_camel_case_types)] +#[derive(Clone, Copy)] +pub struct f16(i16); + +impl f16 { + /// Casts a slice of `i16` integers to a slice of `f16`, allowing operations on half-precision + /// floating-point data stored in standard 16-bit integer arrays. + pub fn from_i16s(slice: &[i16]) -> &[Self] { + unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const Self, slice.len()) } + } + + /// Casts a mutable slice of `i16` integers to a mutable slice of `f16`, enabling mutable operations + /// on half-precision floating-point data. + pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] { + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut Self, slice.len()) } + } + + /// Converts a slice of `f16` back to a slice of `i16`, useful for storage or manipulation in formats + /// that require standard integer types. + pub fn to_i16s(slice: &[Self]) -> &[i16] { + unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const i16, slice.len()) } + } + + /// Converts a mutable slice of `f16` back to a mutable slice of `i16`, enabling further + /// modifications on the original integer data after operations involving half-precision + /// floating-point numbers. + pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] { + unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut i16, slice.len()) } + } +} + impl BitAddressable for b1x8 { /// Sets a bit at a specific index within the byte. /// @@ -167,28 +212,32 @@ impl BitAddressable for [b1x8] { } } -#[repr(transparent)] -#[allow(non_camel_case_types)] -pub struct f16(i16); -impl f16 { - /// Casts a slice of `i16` to a slice of `f16`. - pub fn from_i16s(slice: &[i16]) -> &[Self] { - unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const Self, slice.len()) } - } +impl PartialEq for f16 { + fn eq(&self, other: &Self) -> bool { + // Check for NaN values first (exponent all ones and non-zero mantissa) + let nan_self = (self.0 & 0x7C00) == 0x7C00 && (self.0 & 0x03FF) != 0; + let nan_other = (other.0 & 0x7C00) == 0x7C00 && (other.0 & 0x03FF) != 0; + if nan_self || nan_other { + return false; + } - /// Casts a mutable slice of `i16` to a mutable slice of `f16`. - pub fn from_mut_i16s(slice: &mut [i16]) -> &mut [Self] { - unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut Self, slice.len()) } + self.0 == other.0 } +} - /// Casts a slice of `f16` back to a slice of `i16`. - pub fn to_i16s(slice: &[Self]) -> &[i16] { - unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const i16, slice.len()) } +impl std::fmt::Debug for b1x8 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:08b}", self.0) } +} - /// Casts a mutable slice of `f16` back to a mutable slice of `i16`. - pub fn to_mut_i16s(slice: &mut [Self]) -> &mut [i16] { - unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut i16, slice.len()) } +impl std::fmt::Debug for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let bits = self.0; + let sign = (bits >> 15) & 1; + let exponent = (bits >> 10) & 0x1F; + let mantissa = bits & 0x3FF; + write!(f, "{}|{:05b}|{:010b}", sign, exponent, mantissa) } } @@ -273,7 +322,7 @@ pub mod ffi { pub fn change_metric_kind(self: &NativeIndex, metric: MetricKind); /// Changes the metric function used to calculate the distance between vectors. - /// Avoids the `std::ffi::c_void` type and the `StatefullMetric` type, that the FFI + /// Avoids the `std::ffi::c_void` type and the `StatefulMetric` type, that the FFI /// does not support, replacing them with basic pointer-sized integer types. /// The first two arguments are the pointers to the vectors to compare, and the third /// argument is the `metric_state` propagated from the Rust layer.