Skip to content

Commit

Permalink
Add: filtered_search API in C & C++
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 7, 2024
1 parent c4f9e65 commit 5bb38aa
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 22 deletions.
45 changes: 35 additions & 10 deletions c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,20 @@ std::size_t get_(index_dense_t* index, usearch_key_t key, size_t count, void* ve
}
}

search_result_t search_(index_dense_t* index, void const* vector, scalar_kind_t kind, size_t n) {
template <typename predicate_at = dummy_predicate_t>
search_result_t search_(index_dense_t* index, void const* vector, scalar_kind_t kind, size_t n,
predicate_at&& predicate = predicate_at{}) {
switch (kind) {
case scalar_kind_t::f32_k: return index->search((f32_t const*)vector, n);
case scalar_kind_t::f64_k: return index->search((f64_t const*)vector, n);
case scalar_kind_t::f16_k: return index->search((f16_t const*)vector, n);
case scalar_kind_t::i8_k: return index->search((i8_t const*)vector, n);
case scalar_kind_t::b1x8_k: return index->search((b1x8_t const*)vector, n);
case scalar_kind_t::f32_k:
return index->filtered_search((f32_t const*)vector, n, std::forward<predicate_at>(predicate));
case scalar_kind_t::f64_k:
return index->filtered_search((f64_t const*)vector, n, std::forward<predicate_at>(predicate));
case scalar_kind_t::f16_k:
return index->filtered_search((f16_t const*)vector, n, std::forward<predicate_at>(predicate));
case scalar_kind_t::i8_k:
return index->filtered_search((i8_t const*)vector, n, std::forward<predicate_at>(predicate));
case scalar_kind_t::b1x8_k:
return index->filtered_search((b1x8_t const*)vector, n, std::forward<predicate_at>(predicate));
default: return search_result_t().failed("Unknown scalar kind!");
}
}
Expand All @@ -123,9 +130,9 @@ USEARCH_EXPORT usearch_index_t usearch_init(usearch_init_options_t* options, use
metric_punned_t metric = //
!options->metric ? metric_punned_t::builtin(options->dimensions, metric_kind, scalar_kind)
: metric_punned_t::stateless(options->dimensions, //
reinterpret_cast<std::uintptr_t>(options->metric), //
metric_punned_signature_t::array_array_k, //
metric_kind, scalar_kind);
reinterpret_cast<std::uintptr_t>(options->metric), //
metric_punned_signature_t::array_array_k, //
metric_kind, scalar_kind);
if (metric.missing()) {
*error = "Unknown metric kind!";
return NULL;
Expand Down Expand Up @@ -295,7 +302,7 @@ USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_
reinterpret_cast<std::uintptr_t>(state), metric_kind_to_cpp(kind),
index_dense.scalar_kind())
: metric_punned_t::stateless(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),
metric_punned_signature_t::array_array_k, metric_kind_to_cpp(kind),
metric_punned_signature_t::array_array_k, metric_kind_to_cpp(kind),
index_dense.scalar_kind());
index_dense.change_metric(std::move(metric_punned));
}
Expand Down Expand Up @@ -341,6 +348,24 @@ USEARCH_EXPORT size_t usearch_search(
return result.dump_to(found_keys, found_distances);
}

// C entry point for a predicate-restricted nearest-neighbors search.
// The user callback is invoked as `filter(key, filter_state)` and its `int` result
// is what the underlying search receives as the predicate value for that key.
// If the search result evaluates as failed, its error message is handed over
// through `*error` and 0 is returned; otherwise the function returns whatever
// `dump_to(found_keys, found_distances)` reports.
USEARCH_EXPORT size_t usearch_filtered_search( //
    usearch_index_t index,                     //
    void const* query, usearch_scalar_kind_t query_kind,                      //
    int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, //
    size_t results_limit, usearch_key_t* found_keys, usearch_distance_t* found_distances, usearch_error_t* error) {

    USEARCH_ASSERT(index && query && filter && error && "Missing arguments");

    // Bundle the C callback with its opaque state into a single-argument callable.
    auto key_passes_filter = [=](usearch_key_t key) noexcept { return filter(key, filter_state); };

    // The opaque C handle is the address of the dense index object.
    index_dense_t* dense_index = reinterpret_cast<index_dense_t*>(index);

    search_result_t result =
        search_(dense_index, query, scalar_kind_to_cpp(query_kind), results_limit, key_passes_filter);
    if (!result) {
        // Ownership of the error message moves to the caller-provided slot.
        *error = result.error.release();
        return 0;
    }

    return result.dump_to(found_keys, found_distances);
}

USEARCH_EXPORT size_t usearch_get( //
usearch_index_t index, usearch_key_t key, size_t count, //
void* vectors, usearch_scalar_kind_t kind, usearch_error_t*) {
Expand Down
24 changes: 22 additions & 2 deletions c/usearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,11 +272,12 @@ USEARCH_EXPORT void usearch_change_metric_kind(usearch_index_t index, usearch_me
* @brief Updates the custom metric function used for distance calculation between vectors.
* @param[in] index The handle to the USearch index to be queried.
* @param[in] metric The custom metric function used for distance calculation between vectors.
* @param[in] state The @b optional state pointer to be passed to the custom metric function.
* @param[in] kind The metric kind used for distance calculation between vectors. Needed for serialization.
* @param[out] error Pointer to a string where the error message will be stored, if an error occurs.
*/
USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_t metric, usearch_metric_kind_t kind,
usearch_error_t* error);
USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_t metric, void* state,
usearch_metric_kind_t kind, usearch_error_t* error);

/**
* @brief Adds a vector with a key to the index.
Expand Down Expand Up @@ -324,6 +325,25 @@ USEARCH_EXPORT size_t usearch_search( //
void const* query_vector, usearch_scalar_kind_t query_kind, //
size_t count, usearch_key_t* keys, usearch_distance_t* distances, usearch_error_t* error);

/**
* @brief Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to query,
* predicated on a custom function that returns `true` for vectors to be included.
*
* @param[in] index The handle to the USearch index to be queried.
* @param[in] query_vector Pointer to the query vector data.
* @param[in] query_kind The scalar type used in the query vector data.
* @param[in] filter Callback invoked with a candidate key and `filter_state`; a non-zero return value
* keeps the candidate, zero excludes it.
* @param[in] filter_state Opaque pointer passed unchanged as the second argument of every `filter` call.
* @param[in] count Upper bound on the number of neighbors to search, the "k" in "kANN".
* @param[out] keys Output buffer for up to `count` nearest neighbors keys.
* @param[out] distances Output buffer for up to `count` distances to nearest neighbors.
* @param[out] error Pointer to a string where the error message will be stored, if an error occurs.
* @return Number of found matches.
*/
USEARCH_EXPORT size_t usearch_filtered_search( //
usearch_index_t index, //
void const* query_vector, usearch_scalar_kind_t query_kind, //
int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, //
size_t count, usearch_key_t* keys, usearch_distance_t* distances, usearch_error_t* error);

/**
* @brief Retrieves the vector associated with the given key from the index.
* @param[in] index The handle to the USearch index to be queried.
Expand Down
11 changes: 11 additions & 0 deletions cpp/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ void test_cosine(index_at& index, std::vector<std::vector<scalar_at>> const& vec

// Perform exact search
matched_count = index.search(vector_first, 5, args...).dump_to(matched_keys, matched_distances);
expect(matched_count != 0);

// Perform filtered exact search, keeping only odd values
if constexpr (punned_ak) {
auto is_odd = [](vector_key_t key) -> bool { return (key & 1) != 0; };
matched_count =
index.filtered_search(vector_first, is_odd, 5, args...).dump_to(matched_keys, matched_distances);
expect(matched_count != 0);
for (std::size_t i = 0; i < matched_count; i++)
expect(is_odd(matched_keys[i]));
}

// Validate scans
std::size_t count = 0;
Expand Down
33 changes: 23 additions & 10 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,17 @@ class index_dense_gt {
add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f32); }
add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f64); }

search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_b1x8); }
search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_i8); }
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f16); }
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f32); }
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f64); }
search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8); }
search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8); }
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16); }
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32); }
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64); }

// Predicate-restricted search overloads, one per supported query scalar type
// (`b1x8_t`, `i8_t`, `f16_t`, `f32_t`, `f64_t`). Each one perfect-forwards `predicate`
// to `search_` together with the `casts_` entry that matches the query's scalar type.
// Argument order is `(vector, wanted, predicate, thread, exact)`; `thread` defaults to
// `any_thread()` and `exact` defaults to `false`.
template <typename predicate_at> search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_b1x8); }
template <typename predicate_at> search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_i8); }
template <typename predicate_at> search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f16); }
template <typename predicate_at> search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f32); }
template <typename predicate_at> search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f64); }

std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); }
std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); }
Expand Down Expand Up @@ -1767,9 +1773,9 @@ class index_dense_gt {
: typed_->add(key, vector_data, metric, update_config, on_success);
}

template <typename scalar_at>
search_result_t search_( //
scalar_at const* vector, std::size_t wanted, //
template <typename scalar_at, typename predicate_at>
search_result_t search_(
scalar_at const* vector, std::size_t wanted, predicate_at&& predicate,
std::size_t thread, bool exact, cast_t const& cast) const {

// Cast the vector, if needed for compatibility with `metric_`
Expand All @@ -1787,8 +1793,15 @@ class index_dense_gt {
search_config.expansion = config_.expansion_search;
search_config.exact = exact;

auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; };
return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow);
if (std::is_same<typename std::decay<predicate_at>::type, dummy_predicate_t>::value) {
auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; };
return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow);
} else {
auto allow = [=, &predicate](member_cref_t const& member) noexcept {
return member.key != free_key_ && predicate(member.key);
};
return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow);
}
}

template <typename scalar_at>
Expand Down

0 comments on commit 5bb38aa

Please sign in to comment.