Skip to content

Commit

Permalink
Add: b1x8 for Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 8, 2024
1 parent e3cf772 commit 540fc75
Show file tree
Hide file tree
Showing 4 changed files with 586 additions and 50 deletions.
163 changes: 151 additions & 12 deletions rust/README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,53 @@
# USearch for Rust

Detailed symbol list and documentation for USearch Rust SDK can be found on [docs.rs](https://docs.rs/usearch/latest/usearch/struct.Index.html).

## Installation

```sh
cargo add usearch
```

You can find [interface documentation on docs.rs](https://docs.rs/usearch/2.3.0/usearch/ffi/struct.Index.html).
This will add a USearch dependency to your `Cargo.toml` file.

```toml
[dependencies]
usearch = "..."
```

By default, [SimSIMD](https://github.com/ashvardanian/simsimd) is used to provide dynamic dispatch for SIMD operations.
You can, however, override that by specifying custom features in your `Cargo.toml` file.
To disable all features, use the following configuration:

```toml
[dependencies]
usearch = { version = "...", default-features = false }
```

To enable specific features, use the following configuration:

```toml
[dependencies]
usearch = { version = "...", features = ["simsimd", "openmp", "fp16lib"] }
```

OpenMP will use the OpenMP runtime for parallelism, and `fp16lib` will use the C-layer `fp16` library to emulate half-precision floating point operations on older hardware.

## Quickstart

```rust
use usearch::{Index, IndexOptions, MetricKind, ScalarKind, new_index};

let options = IndexOptions {
dimensions: 3,
metric: MetricKind::IP,
quantization: ScalarKind::F16,
connectivity: 0,
expansion_add: 0,
expansion_search: 0
dimensions: 3, // necessary for most metric kinds
metric: MetricKind::IP, // or MetricKind::L2sq, MetricKind::Cos ...
quantization: ScalarKind::F16, // or ScalarKind::F32, ScalarKind::I8, ScalarKind::B1x8 ...
connectivity: 0, // zero for auto
expansion_add: 0, // zero for auto
expansion_search: 0, // zero for auto
};

let index = new_index(&options).unwrap();
let index: Index = new_index(&options).unwrap();

assert!(index.reserve(10).is_ok());
assert!(index.capacity() >= 10);
Expand All @@ -43,13 +69,17 @@ assert_eq!(results.keys.len(), 2);

## Serialization

To save and load the index from disk, use the following methods:

```rust
assert!(index.save("index.usearch").is_ok());
assert!(index.load("index.usearch").is_ok());
assert!(index.view("index.usearch").is_ok());
```

Similarly, serializing to/from buffers is supported.
Viewing the index does not load the data into memory, but allows you to inspect and traverse the index structure from external memory using memory-mapping.
Similarly, serializing to/from in-memory buffers is supported.
So you can memory-map the index file manually, and later call `view_from_buffer` or one of its siblings.

```rust
assert!(index.save_to_buffer(&mut serialization_buffer).is_ok());
Expand All @@ -59,8 +89,117 @@ assert!(index.view_from_buffer(&serialization_buffer).is_ok());

## Metrics

USearch comes pre-packaged with SimSIMD, bringing over 100 SIMD-accelerated distance kernels for x86 and ARM architectures.
That includes:

- `MetricKind::IP` - Inner Product metric, defined as `IP = 1 - sum(a[i] * b[i])`.
- `MetricKind::L2sq` - Squared Euclidean Distance metric, defined as `L2 = sum((a[i] - b[i])^2)`.
- `MetricKind::Cos` - Cosine Similarity metric, defined as `Cos = 1 - sum(a[i] * b[i]) / (sqrt(sum(a[i]^2) * sqrt(sum(b[i]^2)))`.
- `MetricKind::Pearson` - Pearson Correlation metric.
- `MetricKind::Haversine` - Haversine (Great Circle) Distance metric.
- `MetricKind::Divergence` - Jensen Shannon Divergence metric.
- `MetricKind::Hamming` - Bit-level Hamming Distance metric, defined as the number of differing bits.
- `MetricKind::Tanimoto` - Bit-level Tanimoto (Jaccard) metric, defined as the number of intersecting bits divided by the number of union bits.
- `MetricKind::Sorensen` - Bit-level Sorensen metric.

### User-Defined Metrics

Custom metrics allow for the implementation of specific algorithms to measure the distance or similarity between vectors in the index.
To use a custom metric with USearch, define a function that matches the expected signature and pass it to your index on creation, or later with `change_metric`.
Let's say you are implementing a weighted distance function to search through joint embeddings of images and textual descriptions of some products in a catalog, taking some [UForm](https://github.com/unum-cloud/uform) or CLIP-like models.

```rust
use simsimd::SpatialSimilarity;

let image_dimensions: usize = 768;
let text_dimensions: usize = 512;
let image_weights: f32 = 0.7;
let text_weights: f32 = 0.9;

let weighted_distance = Box::new(move |a: *const f32, b: *const f32| unsafe {
let a_slice = std::slice::from_raw_parts(a, image_dimensions + text_dimensions);
let b_slice = std::slice::from_raw_parts(b, image_dimensions + text_dimensions);

let image_similarity = f32::cosine(a_slice[0..image_dimensions], b_slice[0..image_dimensions]);
let text_similarity = f32::cosine(a_slice[image_dimensions..], b_slice[image_dimensions..]);

(1.0 - image_similarity) * image_weights + (1.0 - text_similarity) * text_weights
});
index.change_metric(weighted_distance);
```

## Filtering with Predicates

Sometimes you may want to cross-reference search-results against some external database or filter them based on some criteria.
In most engines, you'd have to manually perform paging requests, successively filtering the results.
In USearch you can simply pass a predicate function to the search method, which will be applied directly during graph traversal.

```rust
assert!(new_l2sq(3, &quant, 0, 0, 0).is_ok());
assert!(new_cos(3, &quant, 0, 0, 0).is_ok());
assert!(new_haversine(&quant, 0, 0, 0).is_ok());
let is_odd = |key: Key| key % 2 == 1;
let query = vec![0.2, 0.1, 0.2, 0.1, 0.3];
let results = index.filtered_search(&query, 10, is_odd).unwrap();
assert!(
results.keys.iter().all(|&key| key % 2 == 1),
"All keys must be odd"
);
```

## Quantization and Custom Scalar Types

USearch supports the Rust-native `f32` and `f64` scalar types, as well as the `i8` for quantized 8-bit scalars.
Goign beyond that, USearch supports non-native `f16` and `b1x8` for half-precision floating point and binary vectors, respectively.

### Half Precision Floating Point

Rust has no native support for half-precision floating-point numbers, but USearch provides a `f16` type.
It has no advanced functionality - it is a transparent wrapper around `i16` and can be used with `half` or any other half-precision library.
Assuming USearch uses the IEEE 754 no conversion is needed, you can `unsafe`-cast the outputs of other IEEE-compliant libraries to `usearch::f16`.

```rust
use usearch::f16 as USearchF16;
use half::f16 as HalfF16;

let vector_a: Vec<HalfF16> = ...
let vector_b: Vec<HalfF16> = ...

let buffer_a: &[USearchF16] = unsafe { std::slice::from_raw_parts(a_half.as_ptr() as *const SimF16, a_half.len()) };
let buffer_b: &[USearchF16] = unsafe { std::slice::from_raw_parts(b_half.as_ptr() as *const SimF16, b_half.len()) };

index.add(42, buffer_a);
index.add(43, buffer_b);
```

### Binary Vectors

USearch also implement binary distance functions and natively supports bit-vectors.
If you initialize the index with `quantization: ScalarKind::B1`, you can add floating-point vectors and they will be quantized mapping positive values to `1` and negative and zero values to `0`.
Alternatively, you can use the `b1x8` type to represent packed binary vectors directly.

```rs
let index = Index::new(&IndexOptions {
dimensions: 8,
metric: MetricKind::Hamming,
quantization: ScalarKind::B1,
..Default::default()
})
.unwrap();

// Binary vectors represented as `b1x8` slices
let vector42: Vec<b1x8> = vec![b1x8(0b00001111)];
let vector43: Vec<b1x8> = vec![b1x8(0b11110000)];
let query: Vec<b1x8> = vec![b1x8(0b01111000)];

// Adding binary vectors to the index
index.reserve(10).unwrap();
index.add(42, &vector42).unwrap();
index.add(43, &vector43).unwrap();

let results = index.search(&query, 5).unwrap();

// Validate the search results based on Hamming distance
assert_eq!(results.keys.len(), 2);
assert_eq!(results.keys[0], 43);
assert_eq!(results.distances[0], 2.0); // 2 bits differ between query and vector43
assert_eq!(results.keys[1], 42);
assert_eq!(results.distances[1], 6.0); // 6 bits differ between query and vector42
```
12 changes: 8 additions & 4 deletions rust/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,23 +62,27 @@ auto make_predicate(uptr_t metric, uptr_t metric_state) {
}

// clang-format off
void NativeIndex::add_b1x8(vector_key_t key, rust::Slice<uint8_t const> vec) const { index_->add(key, (b1x8_t const*)vec.data()).error.raise(); }
void NativeIndex::add_i8(vector_key_t key, rust::Slice<int8_t const> vec) const { index_->add(key, vec.data()).error.raise(); }
void NativeIndex::add_f16(vector_key_t key, rust::Slice<uint16_t const> vec) const { index_->add(key, (f16_t const*)vec.data()).error.raise(); }
void NativeIndex::add_f16(vector_key_t key, rust::Slice<int16_t const> vec) const { index_->add(key, (f16_t const*)vec.data()).error.raise(); }
void NativeIndex::add_f32(vector_key_t key, rust::Slice<float const> vec) const { index_->add(key, vec.data()).error.raise(); }
void NativeIndex::add_f64(vector_key_t key, rust::Slice<double const> vec) const { index_->add(key, vec.data()).error.raise(); }

Matches NativeIndex::search_b1x8(rust::Slice<uint8_t const> vec, size_t count) const { return search_(*index_, (b1x8_t const*)vec.data(), count); }
Matches NativeIndex::search_i8(rust::Slice<int8_t const> vec, size_t count) const { return search_(*index_, vec.data(), count); }
Matches NativeIndex::search_f16(rust::Slice<uint16_t const> vec, size_t count) const { return search_(*index_, (f16_t const*)vec.data(), count); }
Matches NativeIndex::search_f16(rust::Slice<int16_t const> vec, size_t count) const { return search_(*index_, (f16_t const*)vec.data(), count); }
Matches NativeIndex::search_f32(rust::Slice<float const> vec, size_t count) const { return search_(*index_, vec.data(), count); }
Matches NativeIndex::search_f64(rust::Slice<double const> vec, size_t count) const { return search_(*index_, vec.data(), count); }

Matches NativeIndex::filtered_search_b1x8(rust::Slice<uint8_t const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, (b1x8_t const*)vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_i8(rust::Slice<int8_t const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f16(rust::Slice<uint16_t const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, (f16_t const*)vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f16(rust::Slice<int16_t const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, (f16_t const*)vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f32(rust::Slice<float const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); }
Matches NativeIndex::filtered_search_f64(rust::Slice<double const> vec, size_t count, uptr_t metric, uptr_t metric_state) const { return search_(*index_, vec.data(), count, make_predicate(metric, metric_state)); }

size_t NativeIndex::get_b1x8(vector_key_t key, rust::Slice<uint8_t> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, (b1x8_t*)vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_i8(vector_key_t key, rust::Slice<int8_t> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_f16(vector_key_t key, rust::Slice<uint16_t> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, (f16_t*)vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_f16(vector_key_t key, rust::Slice<int16_t> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, (f16_t*)vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_f32(vector_key_t key, rust::Slice<float> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); }
size_t NativeIndex::get_f64(vector_key_t key, rust::Slice<double> vec) const { if (vec.size() % dimensions()) throw std::invalid_argument("Vector length must be a multiple of index dimensionality"); return index_->get(key, vec.data(), vec.size() / dimensions()); }
// clang-format on
Expand Down
24 changes: 13 additions & 11 deletions rust/lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,29 @@ class NativeIndex {

void reserve(size_t) const;

void add_b1x8(vector_key_t key, rust::Slice<uint8_t const> vector) const;
void add_i8(vector_key_t key, rust::Slice<int8_t const> vector) const;
void add_f16(vector_key_t key, rust::Slice<uint16_t const> vector) const;
void add_f16(vector_key_t key, rust::Slice<int16_t const> vector) const;
void add_f32(vector_key_t key, rust::Slice<float const> vector) const;
void add_f64(vector_key_t key, rust::Slice<double const> vector) const;

Matches search_b1x8(rust::Slice<uint8_t const> query, size_t count) const;
Matches search_i8(rust::Slice<int8_t const> query, size_t count) const;
Matches search_f16(rust::Slice<uint16_t const> query, size_t count) const;
Matches search_f16(rust::Slice<int16_t const> query, size_t count) const;
Matches search_f32(rust::Slice<float const> query, size_t count) const;
Matches search_f64(rust::Slice<double const> query, size_t count) const;

Matches filtered_search_i8(rust::Slice<int8_t const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f16(rust::Slice<uint16_t const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f32(rust::Slice<float const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f64(rust::Slice<double const> query, size_t count, //
uptr_t filter_function, uptr_t filter_state) const;
// clang-format off
Matches filtered_search_b1x8(rust::Slice<uint8_t const> query, size_t count, uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_i8(rust::Slice<int8_t const> query, size_t count, uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f16(rust::Slice<int16_t const> query, size_t count, uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f32(rust::Slice<float const> query, size_t count, uptr_t filter_function, uptr_t filter_state) const;
Matches filtered_search_f64(rust::Slice<double const> query, size_t count, uptr_t filter_function, uptr_t filter_state) const;
// clang-format on

size_t get_b1x8(vector_key_t key, rust::Slice<uint8_t> vector) const;
size_t get_i8(vector_key_t key, rust::Slice<int8_t> vector) const;
size_t get_f16(vector_key_t key, rust::Slice<uint16_t> vector) const;
size_t get_f16(vector_key_t key, rust::Slice<int16_t> vector) const;
size_t get_f32(vector_key_t key, rust::Slice<float> vector) const;
size_t get_f64(vector_key_t key, rust::Slice<double> vector) const;

Expand Down
Loading

0 comments on commit 540fc75

Please sign in to comment.