Skip to content

Commit

Permalink
Improve: metric_punned_t static methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Apr 6, 2024
1 parent 4ff568b commit 220ef57
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 50 deletions.
27 changes: 16 additions & 11 deletions c/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ USEARCH_EXPORT usearch_index_t usearch_init(usearch_init_options_t* options, use
scalar_kind_t scalar_kind = scalar_kind_to_cpp(options->quantization);

metric_punned_t metric = //
!options->metric ? metric_punned_t(options->dimensions, metric_kind, scalar_kind)
: metric_punned_t(options->dimensions, //
!options->metric ? metric_punned_t::builtin(options->dimensions, metric_kind, scalar_kind)
: metric_punned_t::stateless(options->dimensions, //
reinterpret_cast<std::uintptr_t>(options->metric), //
metric_punned_signature_t::array_array_k, //
metric_kind, scalar_kind);
Expand Down Expand Up @@ -283,16 +283,21 @@ USEARCH_EXPORT void usearch_change_metric_kind(usearch_index_t index, usearch_me
USEARCH_ASSERT(index && error && "Missing arguments");
auto& index_dense = *reinterpret_cast<index_dense_t*>(index);
index_dense.change_metric(
metric_punned_t(index_dense.dimensions(), metric_kind_to_cpp(kind), index_dense.scalar_kind()));
metric_punned_t::builtin(index_dense.dimensions(), metric_kind_to_cpp(kind), index_dense.scalar_kind()));
}

USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_t metric, usearch_metric_kind_t kind,
usearch_error_t* error) {
USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_t metric, void* state,
usearch_metric_kind_t kind, usearch_error_t* error) {
USEARCH_ASSERT(index && error && "Missing arguments");
auto& index_dense = *reinterpret_cast<index_dense_t*>(index);
index_dense.change_metric(metric_punned_t(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),
auto metric_punned =
state ? metric_punned_t::statefull(reinterpret_cast<std::uintptr_t>(metric),
reinterpret_cast<std::uintptr_t>(state), metric_kind_to_cpp(kind),
index_dense.scalar_kind())
: metric_punned_t::stateless(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),
metric_punned_signature_t::array_array_k, metric_kind_to_cpp(kind),
index_dense.scalar_kind()));
index_dense.scalar_kind());
index_dense.change_metric(std::move(metric_punned));
}

USEARCH_EXPORT void usearch_reserve(usearch_index_t index, size_t capacity, usearch_error_t* error) {
Expand Down Expand Up @@ -321,13 +326,13 @@ USEARCH_EXPORT size_t usearch_count(usearch_index_t index, usearch_key_t key, us
return reinterpret_cast<index_dense_t*>(index)->count(key);
}

USEARCH_EXPORT size_t usearch_search( //
usearch_index_t index, void const* vector, usearch_scalar_kind_t kind, size_t results_limit, //
USEARCH_EXPORT size_t usearch_search( //
usearch_index_t index, void const* query, usearch_scalar_kind_t query_kind, size_t results_limit, //
usearch_key_t* found_keys, usearch_distance_t* found_distances, usearch_error_t* error) {

USEARCH_ASSERT(index && vector && error && "Missing arguments");
USEARCH_ASSERT(index && query && error && "Missing arguments");
search_result_t result =
search_(reinterpret_cast<index_dense_t*>(index), vector, scalar_kind_to_cpp(kind), results_limit);
search_(reinterpret_cast<index_dense_t*>(index), query, scalar_kind_to_cpp(query_kind), results_limit);
if (!result) {
*error = result.error.release();
return 0;
Expand Down
4 changes: 2 additions & 2 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ class index_dense_gt {
return result.failed("Slot type doesn't match, consider rebuilding");

config_.multi = head.multi;
metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar);
metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar);
cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector());
casts_ = make_casts_(head.kind_scalar);
}
Expand Down Expand Up @@ -1016,7 +1016,7 @@ class index_dense_gt {
return result.failed("Slot type doesn't match, consider rebuilding");

config_.multi = head.multi;
metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar);
metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar);
cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector());
casts_ = make_casts_(head.kind_scalar);
offset += sizeof(buffer);
Expand Down
116 changes: 89 additions & 27 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ using span_punned_t = span_gt<byte_t const>;
enum class metric_punned_signature_t {
array_array_k = 0,
array_array_size_k,
array_array_state_k,
};

/**
Expand All @@ -1339,12 +1340,14 @@ class metric_punned_t {
using metric_array_array_t = result_t (*)(uptr_t, uptr_t);
/// Distance function that takes two arrays and their length and returns a scalar.
using metric_array_array_size_t = result_t (*)(uptr_t, uptr_t, uptr_t);
/// Distance function that takes two arrays and some callback state and returns a scalar.
using metric_array_array_state_t = result_t (*)(uptr_t, uptr_t, uptr_t);
/// Distance function callback, like `metric_array_array_size_t`, but depends on member variables.
using metric_rounted_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const;

metric_rounted_t metric_routed_ = nullptr;
uptr_t metric_ptr_ = 0;
uptr_t metric_size_arg_ = 0;
uptr_t metric_third_arg_ = 0;

std::size_t dimensions_ = 0;
metric_kind_t metric_kind_ = metric_kind_t::unknown_k;
Expand All @@ -1367,37 +1370,95 @@ class metric_punned_t {
inline metric_punned_t() noexcept = default;
inline metric_punned_t(metric_punned_t const&) noexcept = default;
inline metric_punned_t& operator=(metric_punned_t const&) noexcept = default;
inline metric_punned_t( //
std::size_t dimensions, //
metric_kind_t metric_kind = metric_kind_t::l2sq_k, //
scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept
: metric_routed_(&metric_punned_t::invoke_array_array_size), metric_size_arg_(dimensions),
dimensions_(dimensions), metric_kind_(metric_kind), scalar_kind_(scalar_kind) {

inline metric_punned_t(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k,
scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept
: metric_punned_t(builtin(dimensions, metric_kind, scalar_kind)) {}

inline metric_punned_t(std::size_t dimensions, std::uintptr_t metric_uintptr, metric_punned_signature_t signature,
metric_kind_t metric_kind, scalar_kind_t scalar_kind) noexcept
: metric_punned_t(stateless(dimensions, metric_uintptr, signature, metric_kind, scalar_kind)) {}

/**
* @brief Creates a metric of a natively supported kind, choosing the best
* available backend internally or from SimSIMD.
*
* @param dimensions The number of elements in the input arrays.
* @param metric_kind The kind of metric to use.
* @param scalar_kind The kind of scalar to use.
* @return A metric object that can be used to compute distances between vectors.
*/
inline static metric_punned_t builtin(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k,
scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept {
metric_punned_t metric;
metric.metric_routed_ = &metric_punned_t::invoke_array_array_third;
metric.metric_ptr_ = 0;
metric.metric_third_arg_ =
scalar_kind == scalar_kind_t::b1x8_k ? divide_round_up<CHAR_BIT>(dimensions) : dimensions;
metric.dimensions_ = dimensions;
metric.metric_kind_ = metric_kind;
metric.scalar_kind_ = scalar_kind;

#if USEARCH_USE_SIMSIMD
if (!configure_with_simsimd())
configure_with_autovec();
if (!metric.configure_with_simsimd())
metric.configure_with_autovec();
#else
configure_with_autovec();
metric.configure_with_autovec();
#endif

if (scalar_kind == scalar_kind_t::b1x8_k)
metric_size_arg_ = divide_round_up<CHAR_BIT>(dimensions_);
return metric;
}

inline metric_punned_t( //
std::size_t dimensions, //
std::uintptr_t metric_uintptr, metric_punned_signature_t signature, //
metric_kind_t metric_kind, //
scalar_kind_t scalar_kind) noexcept
: metric_routed_(signature == metric_punned_signature_t::array_array_k
? &metric_punned_t::invoke_array_array
: &metric_punned_t::invoke_array_array_size),
metric_ptr_(metric_uintptr), metric_size_arg_(dimensions), dimensions_(dimensions), metric_kind_(metric_kind),
scalar_kind_(scalar_kind) {
/**
* @brief Creates a metric using the provided function pointer for a stateless metric.
* So the provided ::metric_uintptr is a pointer to a function that takes two arrays
* and returns a scalar. If the ::signature is metric_punned_signature_t::array_array_size_k,
* then the third argument is the number of scalar words in the input vectors.
*
* @param dimensions The number of elements in the input arrays.
* @param metric_uintptr The function pointer to the metric function.
* @param signature The signature of the metric function.
* @param metric_kind The kind of metric to use.
* @param scalar_kind The kind of scalar to use.
* @return A metric object that can be used to compute distances between vectors.
*/
inline static metric_punned_t stateless(std::size_t dimensions, std::uintptr_t metric_uintptr,
metric_punned_signature_t signature, metric_kind_t metric_kind,
scalar_kind_t scalar_kind) noexcept {
metric_punned_t metric;
metric.metric_routed_ = signature == metric_punned_signature_t::array_array_k
? &metric_punned_t::invoke_array_array
: &metric_punned_t::invoke_array_array_third;
metric.metric_ptr_ = metric_uintptr;
metric.metric_third_arg_ =
scalar_kind == scalar_kind_t::b1x8_k ? divide_round_up<CHAR_BIT>(dimensions) : dimensions;
metric.dimensions_ = dimensions;
metric.metric_kind_ = metric_kind;
metric.scalar_kind_ = scalar_kind;
return metric;
}

if (scalar_kind == scalar_kind_t::b1x8_k)
metric_size_arg_ = divide_round_up<CHAR_BIT>(dimensions_);
/**
* @brief Creates a metric using the provided function pointer for a statefull metric.
* The third argument is the state that will be passed to the metric function.
*
* @param metric_uintptr The function pointer to the metric function.
* @param metric_state The state to pass to the metric function.
* @param metric_kind The kind of metric to use.
* @param scalar_kind The kind of scalar to use.
* @return A metric object that can be used to compute distances between vectors.
*/
inline static metric_punned_t statefull(std::uintptr_t metric_uintptr, std::uintptr_t metric_state,
metric_kind_t metric_kind = metric_kind_t::unknown_k,
scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept {
metric_punned_t metric;
metric.metric_routed_ = &metric_punned_t::invoke_array_array_third;
metric.metric_ptr_ = metric_uintptr;
metric.metric_third_arg_ = metric_state;
metric.dimensions_ = 0;
metric.metric_kind_ = metric_kind;
metric.scalar_kind_ = scalar_kind;
return metric;
}

inline std::size_t dimensions() const noexcept { return dimensions_; }
Expand Down Expand Up @@ -1485,16 +1546,17 @@ class metric_punned_t {
simsimd_distance_t result;
// Here `reinterpret_cast` raises warning... we know what we are doing!
auto function_pointer = (simsimd_metric_punned_t)(metric_ptr_);
function_pointer(reinterpret_cast<void const*>(a), reinterpret_cast<void const*>(b), metric_size_arg_, &result);
function_pointer(reinterpret_cast<void const*>(a), reinterpret_cast<void const*>(b), metric_third_arg_,
&result);
return (result_t)result;
}
result_t invoke_simsimd_reverse(uptr_t a, uptr_t b) const noexcept { return 1 - invoke_simsimd(a, b); }
#else
bool configure_with_simsimd() noexcept { return false; }
#endif
result_t invoke_array_array_size(uptr_t a, uptr_t b) const noexcept {
result_t invoke_array_array_third(uptr_t a, uptr_t b) const noexcept {
auto function_pointer = (metric_array_array_size_t)(metric_ptr_);
result_t result = function_pointer(a, b, metric_size_arg_);
result_t result = function_pointer(a, b, metric_third_arg_);
return result;
}
result_t invoke_array_array(uptr_t a, uptr_t b) const noexcept {
Expand Down
14 changes: 7 additions & 7 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ static dense_index_py_t make_index( //

metric_t metric = //
metric_uintptr //
? metric_t(dimensions, metric_uintptr, metric_signature, metric_kind, scalar_kind)
: metric_t(dimensions, metric_kind, scalar_kind);
? metric_t::stateless(dimensions, metric_uintptr, metric_signature, metric_kind, scalar_kind)
: metric_t::builtin(dimensions, metric_kind, scalar_kind);
if (metric.missing())
throw std::invalid_argument("Unsupported metric!");

Expand Down Expand Up @@ -486,8 +486,8 @@ static py::tuple search_many_brute_force( //
std::size_t dimensions = static_cast<std::size_t>(queries_dimensions);
metric_t metric = //
metric_uintptr //
? metric_t(dimensions, metric_uintptr, metric_signature, metric_kind, queries_kind)
: metric_t(dimensions, metric_kind, queries_kind);
? metric_t::stateless(dimensions, metric_uintptr, metric_signature, metric_kind, queries_kind)
: metric_t::builtin(dimensions, metric_kind, queries_kind);
if (!metric)
throw std::invalid_argument("Unsupported metric!");

Expand Down Expand Up @@ -940,7 +940,7 @@ PYBIND11_MODULE(compiled, m) {
m.def(
"hardware_acceleration",
[](scalar_kind_t scalar_kind, std::size_t dimensions, metric_kind_t metric_kind) -> py::str {
return metric_t(dimensions, metric_kind, scalar_kind).isa_name();
return metric_t::builtin(dimensions, metric_kind, scalar_kind).isa_name();
},
py::kw_only(), //
py::arg("dtype") = scalar_kind_t::f32_k, //
Expand Down Expand Up @@ -1101,8 +1101,8 @@ PYBIND11_MODULE(compiled, m) {
std::size_t dimensions = index.dimensions();
metric_t metric = //
metric_uintptr //
? metric_t(dimensions, metric_uintptr, metric_signature, metric_kind, scalar_kind)
: metric_t(dimensions, metric_kind, scalar_kind);
? metric_t::stateless(dimensions, metric_uintptr, metric_signature, metric_kind, scalar_kind)
: metric_t::builtin(dimensions, metric_kind, scalar_kind);
if (!metric)
throw std::invalid_argument("Unsupported metric kind!");
index.change_metric(std::move(metric));
Expand Down
6 changes: 3 additions & 3 deletions sqlite/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ static void sqlite_dense(sqlite3_context* context, int argc, sqlite3_value** arg
void const* vec2 = sqlite3_value_blob(argv[1]);

std::size_t dimensions = (size_t)(bytes1)*CHAR_BIT / bits_per_scalar(scalar_kind_ak);
metric_t metric = metric_t(dimensions, metric_kind_ak, scalar_kind_ak);
metric_t metric = metric_t::builtin(dimensions, metric_kind_ak, scalar_kind_ak);
distance_punned_t distance =
metric(reinterpret_cast<byte_t const*>(vec1), reinterpret_cast<byte_t const*>(vec2));
sqlite3_result_double(context, distance);
Expand Down Expand Up @@ -148,7 +148,7 @@ static void sqlite_dense(sqlite3_context* context, int argc, sqlite3_value** arg
}

// Compute the distance itself
metric_t metric = metric_t(dimensions, metric_kind_ak, parsed_scalar_kind_gt<scalar_kind_ak>::kind);
metric_t metric = metric_t::builtin(dimensions, metric_kind_ak, parsed_scalar_kind_gt<scalar_kind_ak>::kind);
distance_punned_t distance =
metric(reinterpret_cast<byte_t const*>(parsed1), reinterpret_cast<byte_t const*>(parsed2));
sqlite3_result_double(context, distance);
Expand Down Expand Up @@ -187,7 +187,7 @@ static void sqlite_dense(sqlite3_context* context, int argc, sqlite3_value** arg
}

// Compute the distance itself
metric_t metric = metric_t(dimensions, metric_kind_ak, parsed_scalar_kind_gt<scalar_kind_ak>::kind);
metric_t metric = metric_t::builtin(dimensions, metric_kind_ak, parsed_scalar_kind_gt<scalar_kind_ak>::kind);
distance_punned_t distance =
metric(reinterpret_cast<byte_t const*>(parsed1), reinterpret_cast<byte_t const*>(parsed2));
sqlite3_result_double(context, distance);
Expand Down

0 comments on commit 220ef57

Please sign in to comment.