Skip to content

Commit

Permalink
Merge pull request #185 from unum-cloud/main-dev
Browse files Browse the repository at this point in the history
Rust and JavaScript: `remove` and `contains`
  • Loading branch information
ashvardanian authored Aug 5, 2023
2 parents 565b89a + be42532 commit c8de1d1
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 14 deletions.
62 changes: 62 additions & 0 deletions javascript/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class Index : public Napi::ObjectWrap<Index> {

void Add(Napi::CallbackInfo const& ctx);
Napi::Value Search(Napi::CallbackInfo const& ctx);
Napi::Value Remove(Napi::CallbackInfo const& ctx);
Napi::Value Contains(Napi::CallbackInfo const& ctx);

std::unique_ptr<index_dense_t> native_;
};
Expand All @@ -51,6 +53,8 @@ Napi::Object Index::Init(Napi::Env env, Napi::Object exports) {
InstanceMethod("connectivity", &Index::GetConnectivity),
InstanceMethod("add", &Index::Add),
InstanceMethod("search", &Index::Search),
InstanceMethod("remove", &Index::Remove),
InstanceMethod("contains", &Index::Contains),
InstanceMethod("save", &Index::Save),
InstanceMethod("load", &Index::Load),
InstanceMethod("view", &Index::View),
Expand Down Expand Up @@ -291,6 +295,64 @@ Napi::Value Index::Search(Napi::CallbackInfo const& ctx) {
}
}

Napi::Value Index::Remove(Napi::CallbackInfo const& ctx) {
Napi::Env env = ctx.Env();
if (ctx.Length() < 1 || !ctx[0].IsBigInt()) {
Napi::TypeError::New(env, "Expects an entry identifier").ThrowAsJavaScriptException();
return {};
}

Napi::BigInt key_js = ctx[0].As<Napi::BigInt>();
bool lossless = true;
std::uint64_t key = key_js.Uint64Value(&lossless);
if (!lossless) {
Napi::TypeError::New(env, "Identifier must be an unsigned integer").ThrowAsJavaScriptException();
return {};
}

try {
auto result = native_->remove(key);
if (!result) {
Napi::TypeError::New(env, "Removal has failed").ThrowAsJavaScriptException();
return {};
}
return Napi::Boolean::New(env, result.completed);
} catch (std::bad_alloc const&) {
Napi::TypeError::New(env, "Out of memory").ThrowAsJavaScriptException();
return {};
} catch (...) {
Napi::TypeError::New(env, "Search failed").ThrowAsJavaScriptException();
return {};
}
}

Napi::Value Index::Contains(Napi::CallbackInfo const& ctx) {
Napi::Env env = ctx.Env();
if (ctx.Length() < 1 || !ctx[0].IsBigInt()) {
Napi::TypeError::New(env, "Expects an entry identifier").ThrowAsJavaScriptException();
return {};
}

Napi::BigInt key_js = ctx[0].As<Napi::BigInt>();
bool lossless = true;
std::uint64_t key = key_js.Uint64Value(&lossless);
if (!lossless) {
Napi::TypeError::New(env, "Identifier must be an unsigned integer").ThrowAsJavaScriptException();
return {};
}

try {
bool result = native_->contains(key);
return Napi::Boolean::New(env, result);
} catch (std::bad_alloc const&) {
Napi::TypeError::New(env, "Out of memory").ThrowAsJavaScriptException();
return {};
} catch (...) {
Napi::TypeError::New(env, "Search failed").ThrowAsJavaScriptException();
return {};
}
}

Napi::Object InitAll(Napi::Env env, Napi::Object exports) { return Index::Init(env, exports); }

NODE_API_MODULE(usearch, InitAll)
1 change: 1 addition & 0 deletions javascript/test.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ assert.equal(index.size(), 0n, 'initial size should be 0');
index.add(15n, new Float32Array([10, 20]));
index.add(16n, new Float32Array([10, 25]));
assert.equal(index.size(), 2n, 'size after adding elements should be 2');
assert.equal(index.contains(15n), true, 'entry must be present after insertion');

var results = index.search(new Float32Array([13, 14]), 2n);
assert.deepEqual(results.keys, new BigUint64Array([15n, 16n]), 'keys should be 15 and 16');
Expand Down
19 changes: 17 additions & 2 deletions javascript/usearch.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export interface Matches {

/** K-Approximate Nearest Neighbors search index. */
export class Index {

/**
* Constructs a new index.
*
Expand Down Expand Up @@ -42,7 +42,7 @@ export class Index {
* @return {bigints} The capacity of index.
*/
capacity(): bigint;

/**
* Returns connectivity.
* @return {bigint} The connectivity of index.
Expand Down Expand Up @@ -84,4 +84,19 @@ export class Index {
* @return {Matches} Output of the search result.
*/
search(mat: Float32Array, k: bigint): Matches;

/**
* Check if an entry is contained in the index.
*
* @param {bigint} key Identifier to look up.
*/
contains(key: bigint): boolean;

/**
* Remove a vector from the index.
*
* @param {bigint} key Input identifier for every vector to be removed.
*/
remove(key: bigint): boolean;

}
25 changes: 17 additions & 8 deletions rust/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
using namespace unum::usearch;
using namespace unum;

using index_t = typename Index::index_t;
using index_t = index_dense_t;
using add_result_t = typename index_t::add_result_t;
using search_result_t = typename index_t::search_result_t;
using labeling_result_t = typename index_t::labeling_result_t;

Index::Index(std::unique_ptr<index_t> index) : index_(std::move(index)) {}

Expand All @@ -28,14 +29,22 @@ Matches Index::search_in_thread(rust::Slice<float const> vector, size_t count, s
config.expansion = index_->expansion_search();
search_result_t result = index_->search(vector.data(), count, config);
result.error.raise();
matches.count = result.dump_to(matches.keys.data(), matches.distances.data());
matches.keys.truncate(matches.count);
matches.distances.truncate(matches.count);
count = result.dump_to(matches.keys.data(), matches.distances.data());
matches.keys.truncate(count);
matches.distances.truncate(count);
return matches;
}

void Index::add(key_t key, rust::Slice<float const> vector) const { index_->add(key, vector.data()).error.raise(); }

bool Index::remove(key_t key) const {
labeling_result_t result = index_->remove(key);
result.error.raise();
return result.completed;
}

bool Index::contains(key_t key) const { return index_->contains(key); }

Matches Index::search(rust::Slice<float const> vector, size_t count) const {
Matches matches;
matches.keys.reserve(count);
Expand All @@ -44,9 +53,9 @@ Matches Index::search(rust::Slice<float const> vector, size_t count) const {
matches.keys.push_back(0), matches.distances.push_back(0);
search_result_t result = index_->search(vector.data(), count);
result.error.raise();
matches.count = result.dump_to(matches.keys.data(), matches.distances.data());
matches.keys.truncate(matches.count);
matches.distances.truncate(matches.count);
count = result.dump_to(matches.keys.data(), matches.distances.data());
matches.keys.truncate(count);
matches.distances.truncate(count);
return matches;
}

Expand Down Expand Up @@ -74,7 +83,7 @@ std::unique_ptr<Index> wrap(index_t&& index) {
metric_kind_t rust_to_cpp_metric(MetricKind value) {
switch (value) {
case MetricKind::IP: return metric_kind_t::ip_k;
case MetricKind::L2Sq: return metric_kind_t::l2sq_k;
case MetricKind::L2sq: return metric_kind_t::l2sq_k;
case MetricKind::Cos: return metric_kind_t::cos_k;
case MetricKind::Pearson: return metric_kind_t::pearson_k;
case MetricKind::Haversine: return metric_kind_t::haversine_k;
Expand Down
3 changes: 3 additions & 0 deletions rust/lib.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class Index {
Matches search(rust::Slice<float const> vector, size_t count) const;
Matches search_in_thread(rust::Slice<float const> vector, size_t count, size_t thread) const;

bool remove(key_t key) const;
bool contains(key_t key) const;

size_t dimensions() const;
size_t connectivity() const;
size_t size() const;
Expand Down
9 changes: 5 additions & 4 deletions rust/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ pub mod ffi {

// Shared structs with fields visible to both languages.
struct Matches {
count: usize,
keys: Vec<u64>,
distances: Vec<f32>,
}

enum MetricKind {
IP,
L2Sq,
L2sq,
Cos,
Pearson,
Haversine,
Expand Down Expand Up @@ -53,6 +52,8 @@ pub mod ffi {

pub fn add(self: &Index, key: u64, vector: &[f32]) -> Result<()>;
pub fn search(self: &Index, query: &[f32], count: usize) -> Result<Matches>;
pub fn remove(self: &Index, key: u64) -> Result<bool>;
pub fn contains(self: &Index, key: u64) -> bool;

pub fn save(self: &Index, path: &str) -> Result<()>;
pub fn load(self: &Index, path: &str) -> Result<()>;
Expand Down Expand Up @@ -97,7 +98,7 @@ mod tests {

// Read back the tags
let results = index.search(&first, 10).unwrap();
assert_eq!(results.count, 2);
assert_eq!(results.keys.len(), 2);

// Validate serialization
assert!(index.save("index.rust.usearch").is_ok());
Expand All @@ -106,7 +107,7 @@ mod tests {

// Make sure every function is called at least once
assert!(new_index(&options).is_ok());
options.metric = MetricKind::L2Sq;
options.metric = MetricKind::L2sq;
assert!(new_index(&options).is_ok());
options.metric = MetricKind::Cos;
assert!(new_index(&options).is_ok());
Expand Down

0 comments on commit c8de1d1

Please sign in to comment.