Skip to content

Commit

Permalink
Fix: #[repr(transparent)] for f16 and b1x8
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 11, 2024
1 parent aadb717 commit e182d77
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
use std::boxed::Box;
//! # USearch Crate for Rust
//!
//! `usearch` is a high-performance library for Approximate Nearest Neighbor (ANN) search in high-dimensional spaces.
//! It offers efficient and scalable solutions for indexing and querying dense vector spaces with support for multiple distance metrics and vector types.
//!
//! This crate wraps the native functionalities of USearch, providing Rust-friendly interfaces and integration capabilities.
//! It is designed to facilitate rapid development and deployment of applications requiring fast and accurate vector search functionalities, such as recommendation systems, image retrieval systems, and natural language processing tasks.
//!
//! ## Features
//!
//! - SIMD-accelerated distance calculations for various metrics.
//! - Support for `f32`, `f64`, `i8`, custom `f16`, and binary (`b1x8`) vector types.
//! - Extensible with custom distance metrics and filtering predicates.
//! - Efficient serialization and deserialization for persistence and network transfers.
//!
//! ## Quick Start
//!
//! Refer to the `Index` struct for detailed usage examples.

/// The key type used to identify vectors in the index.
/// It is a 64-bit unsigned integer.
Expand Down Expand Up @@ -56,6 +73,7 @@ pub trait BitAddressable {
fn get_bit(&self, index: usize) -> Result<bool, BitAddressableError>;
}

#[repr(transparent)]
#[allow(non_camel_case_types)]
pub struct b1x8(pub u8);
impl b1x8 {
Expand Down Expand Up @@ -149,6 +167,7 @@ impl BitAddressable for [b1x8] {
}
}

#[repr(transparent)]
#[allow(non_camel_case_types)]
pub struct f16(i16);
impl f16 {
Expand Down Expand Up @@ -347,8 +366,8 @@ pub use ffi::{IndexOptions, MetricKind, ScalarKind};
///
/// This enum allows the encapsulation of custom distance calculation logic for vectors of different
/// data types, facilitating the use of custom metrics in vector space operations. Each variant of this
/// enum holds a boxed function pointer (`Box<dyn Fn(...) -> Distance + Send + Sync>`) that defines the
/// distance calculation between two vectors of a specific type. The function returns a `Distance`, which
/// enum holds a boxed function pointer (`std::boxed::Box<dyn Fn(...) -> Distance + Send + Sync>`) that defines
/// the distance calculation between two vectors of a specific type. The function returns a `Distance`, which
/// is typically a floating-point value representing the calculated distance between the two vectors.
///
/// # Variants
Expand Down Expand Up @@ -394,11 +413,11 @@ pub use ffi::{IndexOptions, MetricKind, ScalarKind};
///
/// In this example, `dimensions` should be defined and valid for the vectors `a` and `b`.
pub enum MetricFunction {
B1X8Metric(Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>),
I8Metric(Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>),
F16Metric(Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>),
F32Metric(Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>),
F64Metric(Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>),
B1X8Metric(std::boxed::Box<dyn Fn(*const b1x8, *const b1x8) -> Distance + Send + Sync>),
I8Metric(std::boxed::Box<dyn Fn(*const i8, *const i8) -> Distance + Send + Sync>),
F16Metric(std::boxed::Box<dyn Fn(*const f16, *const f16) -> Distance + Send + Sync>),
F32Metric(std::boxed::Box<dyn Fn(*const f32, *const f32) -> Distance + Send + Sync>),
F64Metric(std::boxed::Box<dyn Fn(*const f64, *const f64) -> Distance + Send + Sync>),
}

/// Approximate Nearest Neighbors search index for dense vectors.
Expand All @@ -414,7 +433,6 @@ pub enum MetricFunction {
/// ```rust
/// use usearch::{Index, IndexOptions, MetricKind, ScalarKind};
///
/// // Create an index with specific options
/// let mut options = IndexOptions::default();
/// options.dimensions = 4; // Set the number of dimensions for vectors
/// options.metric = MetricKind::Cos; // Use cosine similarity for distance measurement
Expand All @@ -436,7 +454,8 @@ pub enum MetricFunction {
/// println!("Key: {}, Distance: {}", key, distance);
/// }
/// ```
///
/// For more examples, including how to add vectors to the index and perform searches,
/// refer to the individual method documentation.
pub struct Index {
inner: cxx::UniquePtr<ffi::NativeIndex>,
metric_fn: Option<MetricFunction>,
Expand Down Expand Up @@ -553,7 +572,7 @@ pub trait VectorType {
/// - `Err(cxx::Exception)` if an error occurred during the operation.
fn change_metric(
index: &mut Index,
metric: Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception>
where
Self: Sized;
Expand Down Expand Up @@ -597,7 +616,7 @@ impl VectorType for f32 {

fn change_metric(
index: &mut Index,
metric: Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
// Store the metric function in the Index.
type MetricFn = fn(*const f32, *const f32) -> Distance;
Expand Down Expand Up @@ -664,7 +683,7 @@ impl VectorType for i8 {
}
fn change_metric(
index: &mut Index,
metric: Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
// Store the metric function in the Index.
type MetricFn = fn(*const i8, *const i8) -> Distance;
Expand Down Expand Up @@ -731,7 +750,7 @@ impl VectorType for f64 {
}
fn change_metric(
index: &mut Index,
metric: Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
// Store the metric function in the Index.
type MetricFn = fn(*const f64, *const f64) -> Distance;
Expand Down Expand Up @@ -802,7 +821,7 @@ impl VectorType for f16 {

fn change_metric(
index: &mut Index,
metric: Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
// Store the metric function in the Index.
type MetricFn = fn(*const f16, *const f16) -> Distance;
Expand Down Expand Up @@ -873,7 +892,7 @@ impl VectorType for b1x8 {

fn change_metric(
index: &mut Index,
metric: Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const Self, *const Self) -> Distance + Send + Sync>,
) -> Result<(), cxx::Exception> {
// Store the metric function in the Index.
type MetricFn = fn(*const b1x8, *const b1x8) -> Distance;
Expand Down Expand Up @@ -942,7 +961,7 @@ impl Index {
/// Overrides the metric function used to calculate the distance between vectors.
pub fn change_metric<T: VectorType>(
self: &mut Index,
metric: Box<dyn Fn(*const T, *const T) -> Distance + Send + Sync>,
metric: std::boxed::Box<dyn Fn(*const T, *const T) -> Distance + Send + Sync>,
) {
T::change_metric(self, metric).unwrap();
}
Expand Down

0 comments on commit e182d77

Please sign in to comment.