diff --git a/javascript/lib.cpp b/javascript/lib.cpp index 5034dd5f..afdb1bb6 100644 --- a/javascript/lib.cpp +++ b/javascript/lib.cpp @@ -37,6 +37,8 @@ class Index : public Napi::ObjectWrap { void Add(Napi::CallbackInfo const& ctx); Napi::Value Search(Napi::CallbackInfo const& ctx); + Napi::Value Remove(Napi::CallbackInfo const& ctx); + Napi::Value Contains(Napi::CallbackInfo const& ctx); std::unique_ptr native_; }; @@ -51,6 +53,8 @@ Napi::Object Index::Init(Napi::Env env, Napi::Object exports) { InstanceMethod("connectivity", &Index::GetConnectivity), InstanceMethod("add", &Index::Add), InstanceMethod("search", &Index::Search), + InstanceMethod("remove", &Index::Remove), + InstanceMethod("contains", &Index::Contains), InstanceMethod("save", &Index::Save), InstanceMethod("load", &Index::Load), InstanceMethod("view", &Index::View), @@ -291,6 +295,64 @@ Napi::Value Index::Search(Napi::CallbackInfo const& ctx) { } } +Napi::Value Index::Remove(Napi::CallbackInfo const& ctx) { + Napi::Env env = ctx.Env(); + if (ctx.Length() < 1 || !ctx[0].IsBigInt()) { + Napi::TypeError::New(env, "Expects an entry identifier").ThrowAsJavaScriptException(); + return {}; + } + + Napi::BigInt key_js = ctx[0].As(); + bool lossless = true; + std::uint64_t key = key_js.Uint64Value(&lossless); + if (!lossless) { + Napi::TypeError::New(env, "Identifier must be an unsigned integer").ThrowAsJavaScriptException(); + return {}; + } + + try { + auto result = native_->remove(key); + if (!result) { + Napi::TypeError::New(env, "Removal has failed").ThrowAsJavaScriptException(); + return {}; + } + return Napi::Boolean::New(env, result.completed); + } catch (std::bad_alloc const&) { + Napi::TypeError::New(env, "Out of memory").ThrowAsJavaScriptException(); + return {}; + } catch (...) { + Napi::TypeError::New(env, "Search failed").ThrowAsJavaScriptException(); + return {}; + } +} + +Napi::Value Index::Contains(Napi::CallbackInfo const& ctx) { + Napi::Env env = ctx.Env(); + if (ctx.Length() < 1 || !ctx[0].IsBigInt()) { + Napi::TypeError::New(env, "Expects an entry identifier").ThrowAsJavaScriptException(); + return {}; + } + + Napi::BigInt key_js = ctx[0].As(); + bool lossless = true; + std::uint64_t key = key_js.Uint64Value(&lossless); + if (!lossless) { + Napi::TypeError::New(env, "Identifier must be an unsigned integer").ThrowAsJavaScriptException(); + return {}; + } + + try { + bool result = native_->contains(key); + return Napi::Boolean::New(env, result); + } catch (std::bad_alloc const&) { + Napi::TypeError::New(env, "Out of memory").ThrowAsJavaScriptException(); + return {}; + } catch (...) { + Napi::TypeError::New(env, "Search failed").ThrowAsJavaScriptException(); + return {}; + } +} + Napi::Object InitAll(Napi::Env env, Napi::Object exports) { return Index::Init(env, exports); } NODE_API_MODULE(usearch, InitAll) diff --git a/javascript/test.js b/javascript/test.js index a6332ca8..f190af89 100644 --- a/javascript/test.js +++ b/javascript/test.js @@ -11,6 +11,7 @@ assert.equal(index.size(), 0n, 'initial size should be 0'); index.add(15n, new Float32Array([10, 20])); index.add(16n, new Float32Array([10, 25])); assert.equal(index.size(), 2n, 'size after adding elements should be 2'); +assert.equal(index.contains(15n), true, 'entry must be present after insertion'); var results = index.search(new Float32Array([13, 14]), 2n); assert.deepEqual(results.keys, new BigUint64Array([15n, 16n]), 'keys should be 15 and 16'); diff --git a/javascript/usearch.d.ts b/javascript/usearch.d.ts index 50cb411a..bd7a6bd1 100644 --- a/javascript/usearch.d.ts +++ b/javascript/usearch.d.ts @@ -11,7 +11,7 @@ export interface Matches { /** K-Approximate Nearest Neighbors search index. */ export class Index { - + /** * Constructs a new index. * @@ -42,7 +42,7 @@ export class Index { * @return {bigints} The capacity of index. */ capacity(): bigint; - + /** * Returns connectivity. * @return {bigint} The connectivity of index. @@ -84,4 +84,19 @@ export class Index { * @return {Matches} Output of the search result. */ search(mat: Float32Array, k: bigint): Matches; + + /** + * Check if an entry is contained in the index. + * + * @param {bigint} key Identifier to look up. + */ + contains(key: bigint): boolean; + + /** + * Remove a vector from the index. + * + * @param {bigint} key Input identifier for every vector to be removed. + */ + remove(key: bigint): boolean; + } \ No newline at end of file diff --git a/rust/lib.cpp b/rust/lib.cpp index 1afd67e8..603cc5b7 100644 --- a/rust/lib.cpp +++ b/rust/lib.cpp @@ -4,9 +4,10 @@ using namespace unum::usearch; using namespace unum; -using index_t = typename Index::index_t; +using index_t = index_dense_t; using add_result_t = typename index_t::add_result_t; using search_result_t = typename index_t::search_result_t; +using labeling_result_t = typename index_t::labeling_result_t; Index::Index(std::unique_ptr index) : index_(std::move(index)) {} @@ -28,14 +29,22 @@ Matches Index::search_in_thread(rust::Slice vector, size_t count, s config.expansion = index_->expansion_search(); search_result_t result = index_->search(vector.data(), count, config); result.error.raise(); - matches.count = result.dump_to(matches.keys.data(), matches.distances.data()); - matches.keys.truncate(matches.count); - matches.distances.truncate(matches.count); + count = result.dump_to(matches.keys.data(), matches.distances.data()); + matches.keys.truncate(count); + matches.distances.truncate(count); return matches; } void Index::add(key_t key, rust::Slice vector) const { index_->add(key, vector.data()).error.raise(); } +bool Index::remove(key_t key) const { + labeling_result_t result = index_->remove(key); + result.error.raise(); + return result.completed; +} + +bool Index::contains(key_t key) const { return index_->contains(key); } + Matches Index::search(rust::Slice vector, size_t count) const { Matches matches; matches.keys.reserve(count); @@ -44,9 +53,9 @@ Matches Index::search(rust::Slice vector, size_t count) const { matches.keys.push_back(0), matches.distances.push_back(0); search_result_t result = index_->search(vector.data(), count); result.error.raise(); - matches.count = result.dump_to(matches.keys.data(), matches.distances.data()); - matches.keys.truncate(matches.count); - matches.distances.truncate(matches.count); + count = result.dump_to(matches.keys.data(), matches.distances.data()); + matches.keys.truncate(count); + matches.distances.truncate(count); return matches; } @@ -74,7 +83,7 @@ std::unique_ptr wrap(index_t&& index) { metric_kind_t rust_to_cpp_metric(MetricKind value) { switch (value) { case MetricKind::IP: return metric_kind_t::ip_k; - case MetricKind::L2Sq: return metric_kind_t::l2sq_k; + case MetricKind::L2sq: return metric_kind_t::l2sq_k; case MetricKind::Cos: return metric_kind_t::cos_k; case MetricKind::Pearson: return metric_kind_t::pearson_k; case MetricKind::Haversine: return metric_kind_t::haversine_k; diff --git a/rust/lib.hpp b/rust/lib.hpp index 42192b7c..705f0d17 100644 --- a/rust/lib.hpp +++ b/rust/lib.hpp @@ -27,6 +27,9 @@ class Index { Matches search(rust::Slice vector, size_t count) const; Matches search_in_thread(rust::Slice vector, size_t count, size_t thread) const; + bool remove(key_t key) const; + bool contains(key_t key) const; + size_t dimensions() const; size_t connectivity() const; size_t size() const; diff --git a/rust/lib.rs b/rust/lib.rs index 46aeabb7..1c9eb9d2 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -3,14 +3,13 @@ pub mod ffi { // Shared structs with fields visible to both languages. struct Matches { - count: usize, keys: Vec, distances: Vec, } enum MetricKind { IP, - L2Sq, + L2sq, Cos, Pearson, Haversine, @@ -53,6 +52,8 @@ pub mod ffi { pub fn add(self: &Index, key: u64, vector: &[f32]) -> Result<()>; pub fn search(self: &Index, query: &[f32], count: usize) -> Result; + pub fn remove(self: &Index, key: u64) -> Result; + pub fn contains(self: &Index, key: u64) -> bool; pub fn save(self: &Index, path: &str) -> Result<()>; pub fn load(self: &Index, path: &str) -> Result<()>; @@ -97,7 +98,7 @@ mod tests { // Read back the tags let results = index.search(&first, 10).unwrap(); - assert_eq!(results.count, 2); + assert_eq!(results.keys.len(), 2); // Validate serialization assert!(index.save("index.rust.usearch").is_ok()); @@ -106,7 +107,7 @@ mod tests { // Make sure every function is called at least once assert!(new_index(&options).is_ok()); - options.metric = MetricKind::L2Sq; + options.metric = MetricKind::L2sq; assert!(new_index(&options).is_ok()); options.metric = MetricKind::Cos; assert!(new_index(&options).is_ok());