From 7ba576df60d0a39068c5c949ba9300430127b4aa Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 8 Apr 2024 20:34:55 -0700 Subject: [PATCH] [#21830] DocDB: Import the usearch library Summary: Import the usearch header-only library and its dependency, the fp16 library. Both of these libraries are header-only. Introducing an inline-thirdparty directory in the src directory where we can easily import header-only libraries by copying the relevant header sources. We can switch to git subtrees, or create a tool automate pulling in upstream changes into our own git repository later. Jira: DB-10732 Test Plan: Jenkins: test regex: usearch New test: usearch_vector_index-test Reviewers: tnayak Reviewed By: tnayak Subscribers: jason, ybase Differential Revision: https://phorge.dev.yugabyte.com/D33682 --- CMakeLists.txt | 9 +- src/inline-thirdparty/README.md | 20 + src/inline-thirdparty/fp16/fp16.h | 11 + src/inline-thirdparty/fp16/fp16/bitcasts.h | 92 + src/inline-thirdparty/fp16/fp16/fp16.h | 451 ++ src/inline-thirdparty/fp16/fp16/psimd.h | 131 + .../usearch/usearch/index.hpp | 3852 +++++++++++++++++ .../usearch/usearch/index_dense.hpp | 2022 +++++++++ .../usearch/usearch/index_plugins.hpp | 2317 ++++++++++ src/yb/docdb/CMakeLists.txt | 1 + src/yb/docdb/usearch_vector_index-test.cc | 152 + src/yb/util/tsan_util.h | 14 + 12 files changed, 9070 insertions(+), 2 deletions(-) create mode 100644 src/inline-thirdparty/README.md create mode 100644 src/inline-thirdparty/fp16/fp16.h create mode 100644 src/inline-thirdparty/fp16/fp16/bitcasts.h create mode 100644 src/inline-thirdparty/fp16/fp16/fp16.h create mode 100644 src/inline-thirdparty/fp16/fp16/psimd.h create mode 100644 src/inline-thirdparty/usearch/usearch/index.hpp create mode 100644 src/inline-thirdparty/usearch/usearch/index_dense.hpp create mode 100644 src/inline-thirdparty/usearch/usearch/index_plugins.hpp create mode 100644 src/yb/docdb/usearch_vector_index-test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 90bfe3f70a92..3dbf3fdad79b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -409,7 +409,6 @@ if(IS_GCC) # https://gist.githubusercontent.com/mbautin/de18543ea85d46db49dfa4b4b7df082a/raw ADD_CXX_FLAGS("-Wno-use-after-free") endif() - endif() if(USING_LINUXBREW) @@ -600,10 +599,16 @@ file(MAKE_DIRECTORY "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}") set(EXECUTABLE_OUTPUT_PATH "${YB_BUILD_ROOT}/bin") file(MAKE_DIRECTORY "${EXECUTABLE_OUTPUT_PATH}") -# Generated sources always have higher priority. +# Generated sources always have higher priority than identically named sources in the source +# directory. include_directories(${CMAKE_CURRENT_BINARY_DIR}/src) + include_directories(src) +include_directories("src/inline-thirdparty/usearch") +include_directories("src/inline-thirdparty/fp16") + + enable_testing() if (USING_LINUXBREW) diff --git a/src/inline-thirdparty/README.md b/src/inline-thirdparty/README.md new file mode 100644 index 000000000000..2e34648c8c2e --- /dev/null +++ b/src/inline-thirdparty/README.md @@ -0,0 +1,20 @@ +# inline-thirdparty + +This is a directory where we copy some of the third-party header-only libraries, rather than adding +them to the yugabyte-db-thirdparty repo. We also only copy the relevant subdirectory of the upstream +repositories. Each library is copied in its own appropriately named directory, and each library's +directory is added separately to the list of include directories in CMakeLists.txt. 
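+
+With each library's directory on the include path, callers refer to these headers by their
+upstream-relative paths. A minimal illustration (pick whichever headers the calling code actually
+needs):
+
+```cpp
+#include <fp16.h>                  // src/inline-thirdparty/fp16/fp16.h
+#include <usearch/index_dense.hpp> // src/inline-thirdparty/usearch/usearch/index_dense.hpp
+```
+
+The libraries currently imported are: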
+ +* usearch + * Repo: https://github.com/yugabyte/usearch + * Description: Similarity search for vector and text + * Subdirectory: include + * Tag: v2.11.0-yb-1 + * License: Apache 2.0 + +* fp16 + * Repo: https://github.com/Maratyszcza/FP16/ + * Description: Header-only library for conversion to/from half-precision floating point formats + * Subdirectory: include + * Commit: 0a92994d729ff76a58f692d3028ca1b64b145d91 + * License: MIT diff --git a/src/inline-thirdparty/fp16/fp16.h b/src/inline-thirdparty/fp16/fp16.h new file mode 100644 index 000000000000..9d7366e997da --- /dev/null +++ b/src/inline-thirdparty/fp16/fp16.h @@ -0,0 +1,11 @@ +#pragma once +#ifndef FP16_H +#define FP16_H + +#include + +#if defined(PSIMD_H) +#include +#endif + +#endif /* FP16_H */ diff --git a/src/inline-thirdparty/fp16/fp16/bitcasts.h b/src/inline-thirdparty/fp16/fp16/bitcasts.h new file mode 100644 index 000000000000..86a4e22c48b2 --- /dev/null +++ b/src/inline-thirdparty/fp16/fp16/bitcasts.h @@ -0,0 +1,92 @@ +#pragma once +#ifndef FP16_BITCASTS_H +#define FP16_BITCASTS_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include +#elif !defined(__OPENCL_VERSION__) + #include +#endif + +#if defined(__INTEL_COMPILER) + #include +#endif + +#if defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + #include +#endif + + +static inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int) w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return _CopyFloatFromInt32((__int32) w); +#else + union { + uint32_t as_bits; + float as_value; + } fp32 = { w }; + return fp32.as_value; +#endif +} + +static inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t) __float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return (uint32_t) _CopyInt32FromFloat(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = { f }; + return fp32.as_bits; +#endif +} + +static inline double fp64_from_bits(uint64_t w) { +#if defined(__OPENCL_VERSION__) + return as_double(w); +#elif defined(__CUDA_ARCH__) + return __longlong_as_double((long long) w); +#elif defined(__INTEL_COMPILER) + return _castu64_f64(w); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return _CopyDoubleFromInt64((__int64) w); +#else + union { + uint64_t as_bits; + double as_value; + } fp64 = { w }; + return fp64.as_value; +#endif +} + +static inline uint64_t fp64_to_bits(double f) { +#if defined(__OPENCL_VERSION__) + return as_ulong(f); +#elif defined(__CUDA_ARCH__) + return (uint64_t) __double_as_longlong(f); +#elif defined(__INTEL_COMPILER) + return _castf64_u64(f); +#elif defined(_MSC_VER) && (defined(_M_ARM) || defined(_M_ARM64)) + return (uint64_t) _CopyInt64FromDouble(f); +#else + union { + double as_value; + uint64_t as_bits; + } fp64 = { f }; + return fp64.as_bits; +#endif +} + +#endif /* FP16_BITCASTS_H */ diff --git a/src/inline-thirdparty/fp16/fp16/fp16.h b/src/inline-thirdparty/fp16/fp16/fp16.h new file mode 100644 index 000000000000..2b61fff5c1b9 --- /dev/null +++ b/src/inline-thirdparty/fp16/fp16/fp16.h @@ -0,0 +1,451 @@ +#pragma once +#ifndef FP16_FP16_H +#define FP16_FP16_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include + #include 
+#elif !defined(__OPENCL_VERSION__) + #include + #include +#endif + +#ifdef _MSC_VER + #include +#endif + +#include + + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long) nonsign); + uint32_t renorm_shift = (uint32_t) nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has exponent of 15, the addition overflows it into bit 31, + * and the subsequent shift turns the high 9 bits into 1. Thus + * inf_nan_mask == + * 0x7F800000 if the half-precision number had exponent of 15 (i.e. was NaN or infinity) + * 0x00000000 otherwise + */ + const int32_t inf_nan_mask = ((int32_t) (nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 into 1. Otherwise, bit 31 remains 0. + * The signed shift right by 31 broadcasts bit 31 into all bits of the zero_mask. Thus + * zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t) (nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) becomes an 8-bit field and 10-bit mantissa + * shifts into the 10 high bits of the 23-bit mantissa of IEEE single-precision number. + * 3. 
Add 0x70 to the exponent (starting at bit 23) to compensate the different in exponent bias + * (0x7F for single-precision number less 0xF for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to account for renormalization. As renorm_shift + * is less than 0x70, this can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | inf_nan_mask) & ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. 
+ * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. 
+ */ +static inline uint16_t fp16_ieee_from_fp32_value(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +/* + * Convert a 16-bit floating-point number in ARM alternative half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ +static inline uint32_t fp16_alt_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 30 27-31 17-26 0-16 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long) nonsign); + uint32_t renorm_shift = (uint32_t) nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 into 1. Otherwise, bit 31 remains 0. + * The signed shift right by 31 broadcasts bit 31 into all bits of the zero_mask. 
Thus + * zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t) (nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) becomes an 8-bit field and 10-bit mantissa + * shifts into the 10 high bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the different in exponent bias + * (0x7F for single-precision number less 0xF for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to account for renormalization. As renorm_shift + * is less than 0x70, this can be combined with step 3. + * 5. Binary ANDNOT with zero_mask to turn the mantissa and exponent into zero if the input was zero. + * 6. Combine with the sign of the input number. + */ + return sign | (((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) & ~zero_mask); +} + +/* + * Convert a 16-bit floating-point number in ARM alternative half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline float fp16_alt_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, the exponent is adjusted for the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70). This operation never overflows or generates non-finite values, as the largest + * half-precision exponent is 0x1F and after the adjustment is can not exceed 0x8F < 0xFE (largest single-precision + * exponent for non-finite values). + * + * Note that this operation does not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. 
+ */ + const uint32_t exp_offset = UINT32_C(0x70) << 23; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset); + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +/* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * ARM alternative half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. + */ +static inline uint16_t fp16_alt_from_fp32_value(float f) { + const uint32_t w = fp32_to_bits(f); + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t shl1_w = w + w; + + const uint32_t shl1_max_fp16_fp32 = UINT32_C(0x8FFFC000); + const uint32_t shl1_base = shl1_w > shl1_max_fp16_fp32 ? 
shl1_max_fp16_fp32 : shl1_w; + uint32_t shl1_bias = shl1_base & UINT32_C(0xFF000000); + const uint32_t exp_difference = 23 - 10; + const uint32_t shl1_bias_min = (127 - 1 - exp_difference) << 24; + if (shl1_bias < shl1_bias_min) { + shl1_bias = shl1_bias_min; + } + + const float bias = fp32_from_bits((shl1_bias >> 1) + ((exp_difference + 2) << 23)); + const float base = fp32_from_bits((shl1_base >> 1) + (2 << 23)) + bias; + + const uint32_t exp_f = fp32_to_bits(base) >> 13; + return (sign >> 16) | ((exp_f & UINT32_C(0x00007C00)) + (fp32_to_bits(base) & UINT32_C(0x00000FFF))); +} + +#endif /* FP16_FP16_H */ diff --git a/src/inline-thirdparty/fp16/fp16/psimd.h b/src/inline-thirdparty/fp16/fp16/psimd.h new file mode 100644 index 000000000000..428ab0651de9 --- /dev/null +++ b/src/inline-thirdparty/fp16/fp16/psimd.h @@ -0,0 +1,131 @@ +#pragma once +#ifndef FP16_PSIMD_H +#define FP16_PSIMD_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include +#elif !defined(__OPENCL_VERSION__) + #include +#endif + +#include + + +PSIMD_INTRINSIC psimd_f32 fp16_ieee_to_fp32_psimd(psimd_u16 half) { + const psimd_u32 word = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + + const psimd_u32 sign = word & psimd_splat_u32(UINT32_C(0x80000000)); + const psimd_u32 shr3_nonsign = (word + word) >> psimd_splat_u32(4); + + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x70000000)); +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const psimd_f32 exp_scale = psimd_splat_f32(0x1.0p-112f); +#else + const psimd_f32 exp_scale = psimd_splat_f32(fp32_from_bits(UINT32_C(0x7800000))); +#endif + const psimd_f32 norm_nonsign = psimd_mul_f32((psimd_f32) (shr3_nonsign + exp_offset), exp_scale); + + const psimd_u16 magic_mask = psimd_splat_u16(UINT16_C(0x3E80)); + const psimd_f32 magic_bias = psimd_splat_f32(0.25f); + const psimd_f32 denorm_nonsign = psimd_sub_f32((psimd_f32) psimd_interleave_lo_u16(half + half, magic_mask), magic_bias); + + const psimd_s32 denorm_cutoff = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_s32 denorm_mask = (psimd_s32) shr3_nonsign < denorm_cutoff; + return (psimd_f32) (sign | (psimd_s32) psimd_blend_f32(denorm_mask, denorm_nonsign, norm_nonsign)); +} + +PSIMD_INTRINSIC psimd_f32x2 fp16_ieee_to_fp32x2_psimd(psimd_u16 half) { + const psimd_u32 word_lo = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + const psimd_u32 word_hi = (psimd_u32) psimd_interleave_hi_u16(psimd_zero_u16(), half); + + const psimd_u32 sign_mask = psimd_splat_u32(UINT32_C(0x80000000)); + const psimd_u32 sign_lo = word_lo & sign_mask; + const psimd_u32 sign_hi = word_hi & sign_mask; + const psimd_u32 shr3_nonsign_lo = (word_lo + word_lo) >> psimd_splat_u32(4); + const psimd_u32 shr3_nonsign_hi = (word_hi + word_hi) >> psimd_splat_u32(4); + + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x70000000)); +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const psimd_f32 exp_scale = psimd_splat_f32(0x1.0p-112f); +#else + const psimd_f32 exp_scale = psimd_splat_f32(fp32_from_bits(UINT32_C(0x7800000))); +#endif + const psimd_f32 norm_nonsign_lo = psimd_mul_f32((psimd_f32) (shr3_nonsign_lo + exp_offset), exp_scale); + const psimd_f32 norm_nonsign_hi = psimd_mul_f32((psimd_f32) (shr3_nonsign_hi + exp_offset), exp_scale); + + const psimd_u16 magic_mask = psimd_splat_u16(UINT16_C(0x3E80)); + const psimd_u16 shl1_half = half + half; + const psimd_f32 
magic_bias = psimd_splat_f32(0.25f); + const psimd_f32 denorm_nonsign_lo = psimd_sub_f32((psimd_f32) psimd_interleave_lo_u16(shl1_half, magic_mask), magic_bias); + const psimd_f32 denorm_nonsign_hi = psimd_sub_f32((psimd_f32) psimd_interleave_hi_u16(shl1_half, magic_mask), magic_bias); + + const psimd_s32 denorm_cutoff = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_s32 denorm_mask_lo = (psimd_s32) shr3_nonsign_lo < denorm_cutoff; + const psimd_s32 denorm_mask_hi = (psimd_s32) shr3_nonsign_hi < denorm_cutoff; + + psimd_f32x2 result; + result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_blend_f32(denorm_mask_lo, denorm_nonsign_lo, norm_nonsign_lo)); + result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_blend_f32(denorm_mask_hi, denorm_nonsign_hi, norm_nonsign_hi)); + return result; +} + +PSIMD_INTRINSIC psimd_f32 fp16_alt_to_fp32_psimd(psimd_u16 half) { + const psimd_u32 word = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + + const psimd_u32 sign = word & psimd_splat_u32(INT32_C(0x80000000)); + const psimd_u32 shr3_nonsign = (word + word) >> psimd_splat_u32(4); + +#if 0 + const psimd_s32 exp112_offset = psimd_splat_s32(INT32_C(0x38000000)); + const psimd_s32 nonsign_bits = (psimd_s32) shr3_nonsign + exp112_offset; + const psimd_s32 exp1_offset = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_f32 two_nonsign = (psimd_f32) (nonsign_bits + exp1_offset); + const psimd_s32 exp113_offset = exp112_offset | exp1_offset; + return (psimd_f32) (sign | (psimd_s32) psimd_sub_f32(two_nonsign, (psimd_f32) psimd_max_s32(nonsign_bits, exp113_offset))); +#else + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x38000000)); + const psimd_f32 nonsign = (psimd_f32) (shr3_nonsign + exp_offset); +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const psimd_f32 denorm_bias = psimd_splat_f32(0x1.0p-14f); +#else + const psimd_f32 denorm_bias = psimd_splat_f32(fp32_from_bits(UINT32_C(0x38800000))); +#endif + return (psimd_f32) (sign | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign, nonsign), psimd_max_f32(nonsign, denorm_bias))); +#endif +} + +PSIMD_INTRINSIC psimd_f32x2 fp16_alt_to_fp32x2_psimd(psimd_u16 half) { + const psimd_u32 word_lo = (psimd_u32) psimd_interleave_lo_u16(psimd_zero_u16(), half); + const psimd_u32 word_hi = (psimd_u32) psimd_interleave_hi_u16(psimd_zero_u16(), half); + + const psimd_u32 sign_mask = psimd_splat_u32(UINT32_C(0x80000000)); + const psimd_u32 sign_lo = word_lo & sign_mask; + const psimd_u32 sign_hi = word_hi & sign_mask; + const psimd_u32 shr3_nonsign_lo = (word_lo + word_lo) >> psimd_splat_u32(4); + const psimd_u32 shr3_nonsign_hi = (word_hi + word_hi) >> psimd_splat_u32(4); + +#if 1 + const psimd_s32 exp112_offset = psimd_splat_s32(INT32_C(0x38000000)); + const psimd_s32 nonsign_bits_lo = (psimd_s32) shr3_nonsign_lo + exp112_offset; + const psimd_s32 nonsign_bits_hi = (psimd_s32) shr3_nonsign_hi + exp112_offset; + const psimd_s32 exp1_offset = psimd_splat_s32(INT32_C(0x00800000)); + const psimd_f32 two_nonsign_lo = (psimd_f32) (nonsign_bits_lo + exp1_offset); + const psimd_f32 two_nonsign_hi = (psimd_f32) (nonsign_bits_hi + exp1_offset); + const psimd_s32 exp113_offset = exp1_offset | exp112_offset; + psimd_f32x2 result; + result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_sub_f32(two_nonsign_lo, (psimd_f32) psimd_max_s32(nonsign_bits_lo, exp113_offset))); + result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_sub_f32(two_nonsign_hi, (psimd_f32) psimd_max_s32(nonsign_bits_hi, 
exp113_offset))); + return result; +#else + const psimd_u32 exp_offset = psimd_splat_u32(UINT32_C(0x38000000)); + const psimd_f32 nonsign_lo = (psimd_f32) (shr3_nonsign_lo + exp_offset); + const psimd_f32 nonsign_hi = (psimd_f32) (shr3_nonsign_hi + exp_offset); + const psimd_f32 denorm_bias = psimd_splat_f32(0x1.0p-14f); + psimd_f32x2 result; + result.lo = (psimd_f32) (sign_lo | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign_lo, nonsign_lo), psimd_max_f32(nonsign_lo, denorm_bias))); + result.hi = (psimd_f32) (sign_hi | (psimd_s32) psimd_sub_f32(psimd_add_f32(nonsign_hi, nonsign_hi), psimd_max_f32(nonsign_hi, denorm_bias))); + return result; +#endif +} + +#endif /* FP16_PSIMD_H */ diff --git a/src/inline-thirdparty/usearch/usearch/index.hpp b/src/inline-thirdparty/usearch/usearch/index.hpp new file mode 100644 index 000000000000..d31c1554791d --- /dev/null +++ b/src/inline-thirdparty/usearch/usearch/index.hpp @@ -0,0 +1,3852 @@ +/** + * @file index.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search. + * @date 2023-04-26 + * + * @copyright Copyright (c) 2023 + */ +#ifndef UNUM_USEARCH_HPP +#define UNUM_USEARCH_HPP + +#define USEARCH_VERSION_MAJOR 2 +#define USEARCH_VERSION_MINOR 11 +#define USEARCH_VERSION_PATCH 0 + +// Inferring C++ version +// https://stackoverflow.com/a/61552074 +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) +#define USEARCH_DEFINED_CPP17 +#endif + +// Inferring target OS: Windows, MacOS, or Linux +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) +#define USEARCH_DEFINED_WINDOWS +#elif defined(__APPLE__) && defined(__MACH__) +#define USEARCH_DEFINED_APPLE +#elif defined(__linux__) +#define USEARCH_DEFINED_LINUX +#endif + +// Inferring the compiler: Clang vs GCC +#if defined(__clang__) +#define USEARCH_DEFINED_CLANG +#elif defined(__GNUC__) +#define USEARCH_DEFINED_GCC +#endif + +#if defined(__clang__) || defined(_MSC_VER) +#define USEARCH_USE_PRAGMA_REGION +#endif + +// Inferring hardware architecture: x86 vs Arm +#if defined(__x86_64__) +#define USEARCH_DEFINED_X86 +#elif defined(__aarch64__) +#define USEARCH_DEFINED_ARM +#endif + +// Inferring hardware bitness: 32 vs 64 +// https://stackoverflow.com/a/5273354 +#if INTPTR_MAX == INT64_MAX +#define USEARCH_64BIT_ENV +#elif INTPTR_MAX == INT32_MAX +#define USEARCH_32BIT_ENV +#else +#error Unknown pointer size or missing size macros! +#endif + +#if !defined(USEARCH_USE_OPENMP) +#define USEARCH_USE_OPENMP 0 +#endif + +// OS-specific includes +#if defined(USEARCH_DEFINED_WINDOWS) +#define _USE_MATH_DEFINES +#define NOMINMAX +#include +#include // `fstat` for file size +#undef NOMINMAX +#undef _USE_MATH_DEFINES +#else +#include // `fallocate` +#include // `posix_memalign` +#include // `mmap` +#include // `fstat` for file size +#include // `open`, `close` +#endif + +// STL includes +#include // `std::sort_heap` +#include // `std::atomic` +#include // `std::bitset` +#include // `CHAR_BIT` +#include // `std::sqrt` +#include // `std::memset` +#include // `std::reverse_iterator` +#include // `std::unique_lock` - replacement candidate +#include // `std::default_random_engine` - replacement candidate +#include // `std::runtime_exception` +#include // `std::thread` +#include // `std::pair` + +// Prefetching +#if defined(USEARCH_DEFINED_GCC) +// https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html +// Zero means we are only going to read from that memory. +// Three means high temporal locality and suggests to keep +// the data in all layers of cache. 
+#define prefetch_m(ptr) __builtin_prefetch((void*)(ptr), 0, 3) +#elif defined(USEARCH_DEFINED_X86) +#define prefetch_m(ptr) _mm_prefetch((void*)(ptr), _MM_HINT_T0) +#else +#define prefetch_m(ptr) +#endif + +// Alignment +#if defined(USEARCH_DEFINED_WINDOWS) +#define usearch_pack_m +#define usearch_align_m __declspec(align(64)) +#else +#define usearch_pack_m __attribute__((packed)) +#define usearch_align_m __attribute__((aligned(64))) +#endif + +// Debugging +#if defined(NDEBUG) +#define usearch_assert_m(must_be_true, message) +#define usearch_noexcept_m noexcept +#else +#define usearch_assert_m(must_be_true, message) \ + if (!(must_be_true)) { \ + throw std::runtime_error(message); \ + } +#define usearch_noexcept_m +#endif + +namespace unum { +namespace usearch { + +using byte_t = char; + +template std::size_t divide_round_up(std::size_t num) noexcept { + return (num + multiple_ak - 1) / multiple_ak; +} + +inline std::size_t divide_round_up(std::size_t num, std::size_t denominator) noexcept { + return (num + denominator - 1) / denominator; +} + +inline std::size_t ceil2(std::size_t v) noexcept { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; +#ifdef USEARCH_64BIT_ENV + v |= v >> 32; +#endif + v++; + return v; +} + +/// @brief Simply dereferencing misaligned pointers can be dangerous. +template void misaligned_store(void* ptr, at v) noexcept { + static_assert(!std::is_reference::value, "Can't store a reference"); + std::memcpy(ptr, &v, sizeof(at)); +} + +/// @brief Simply dereferencing misaligned pointers can be dangerous. +template at misaligned_load(void* ptr) noexcept { + static_assert(!std::is_reference::value, "Can't load a reference"); + at v; + std::memcpy(&v, ptr, sizeof(at)); + return v; +} + +/// @brief The `std::exchange` alternative for C++11. +template at exchange(at& obj, other_at&& new_value) { + at old_value = std::move(obj); + obj = std::forward(new_value); + return old_value; +} + +/// @brief The `std::destroy_at` alternative for C++11. +template +typename std::enable_if::value>::type destroy_at(at*) {} +template +typename std::enable_if::value>::type destroy_at(at* obj) { + obj->~sfinae_at(); +} + +/// @brief The `std::construct_at` alternative for C++11. +template +typename std::enable_if::value>::type construct_at(at*) {} +template +typename std::enable_if::value>::type construct_at(at* obj) { + new (obj) at(); +} + +/** + * @brief A reference to a misaligned memory location with a specific type. + * It is needed to avoid Undefined Behavior when dereferencing addresses + * indivisible by `sizeof(at)`. + */ +template class misaligned_ref_gt { + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; + + public: + misaligned_ref_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + operator mutable_t() const noexcept { return misaligned_load(ptr_); } + misaligned_ref_gt& operator=(mutable_t const& v) noexcept { + misaligned_store(ptr_, v); + return *this; + } + + void reset(byte_t* ptr) noexcept { ptr_ = ptr; } + byte_t* ptr() const noexcept { return ptr_; } +}; + +/** + * @brief A pointer to a misaligned memory location with a specific type. + * It is needed to avoid Undefined Behavior when dereferencing addresses + * indivisible by `sizeof(at)`. 
+ */ +template class misaligned_ptr_gt { + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + reference operator*() const noexcept { return {ptr_}; } + reference operator[](std::size_t i) noexcept { return reference(ptr_ + i * sizeof(element_t)); } + value_type operator[](std::size_t i) const noexcept { + return misaligned_load(ptr_ + i * sizeof(element_t)); + } + + misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + misaligned_ptr_gt operator++(int) noexcept { return misaligned_ptr_gt(ptr_ + sizeof(element_t)); } + misaligned_ptr_gt operator--(int) noexcept { return misaligned_ptr_gt(ptr_ - sizeof(element_t)); } + misaligned_ptr_gt operator+(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); } + misaligned_ptr_gt operator-(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); } + + // clang-format off + misaligned_ptr_gt& operator++() noexcept { ptr_ += sizeof(element_t); return *this; } + misaligned_ptr_gt& operator--() noexcept { ptr_ -= sizeof(element_t); return *this; } + misaligned_ptr_gt& operator+=(difference_type d) noexcept { ptr_ += d * sizeof(element_t); return *this; } + misaligned_ptr_gt& operator-=(difference_type d) noexcept { ptr_ -= d * sizeof(element_t); return *this; } + // clang-format on + + bool operator==(misaligned_ptr_gt const& other) noexcept { return ptr_ == other.ptr_; } + bool operator!=(misaligned_ptr_gt const& other) noexcept { return ptr_ != other.ptr_; } +}; + +/** + * @brief Non-owning memory range view, similar to `std::span`, but for C++11. + */ +template class span_gt { + scalar_at* data_; + std::size_t size_; + + public: + span_gt() noexcept : data_(nullptr), size_(0u) {} + span_gt(scalar_at* begin, scalar_at* end) noexcept : data_(begin), size_(end - begin) {} + span_gt(scalar_at* begin, std::size_t size) noexcept : data_(begin), size_(size) {} + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } +}; + +/** + * @brief Similar to `std::vector`, but doesn't support dynamic resizing. + * On the bright side, this can't throw exceptions. + */ +template > class buffer_gt { + scalar_at* data_; + std::size_t size_; + + public: + buffer_gt() noexcept : data_(nullptr), size_(0u) {} + buffer_gt(std::size_t size) noexcept : data_(allocator_at{}.allocate(size)), size_(data_ ? 
size : 0u) { + if (!std::is_trivially_default_constructible::value) + for (std::size_t i = 0; i != size_; ++i) + construct_at(data_ + i); + } + ~buffer_gt() noexcept { + if (!std::is_trivially_destructible::value) + for (std::size_t i = 0; i != size_; ++i) + destroy_at(data_ + i); + allocator_at{}.deallocate(data_, size_); + data_ = nullptr; + size_ = 0; + } + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } + scalar_at& operator[](std::size_t i) noexcept { return data_[i]; } + scalar_at const& operator[](std::size_t i) const noexcept { return data_[i]; } + explicit operator bool() const noexcept { return data_; } + scalar_at* release() noexcept { + size_ = 0; + return exchange(data_, nullptr); + } + + buffer_gt(buffer_gt const&) = delete; + buffer_gt& operator=(buffer_gt const&) = delete; + + buffer_gt(buffer_gt&& other) noexcept : data_(exchange(other.data_, nullptr)), size_(exchange(other.size_, 0)) {} + buffer_gt& operator=(buffer_gt&& other) noexcept { + std::swap(data_, other.data_); + std::swap(size_, other.size_); + return *this; + } +}; + +/** + * @brief A lightweight error class for handling error messages, + * which are expected to be allocated in static memory. + */ +class error_t { + char const* message_{}; + + public: + error_t(char const* message = nullptr) noexcept : message_(message) {} + error_t& operator=(char const* message) noexcept { + message_ = message; + return *this; + } + + error_t(error_t const&) = delete; + error_t& operator=(error_t const&) = delete; + error_t(error_t&& other) noexcept : message_(exchange(other.message_, nullptr)) {} + error_t& operator=(error_t&& other) noexcept { + std::swap(message_, other.message_); + return *this; + } + explicit operator bool() const noexcept { return message_ != nullptr; } + char const* what() const noexcept { return message_; } + char const* release() noexcept { return exchange(message_, nullptr); } + +#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) + ~error_t() noexcept(false) { +#if defined(USEARCH_DEFINED_CPP17) + if (message_ && std::uncaught_exceptions() == 0) +#else + if (message_ && std::uncaught_exception() == 0) +#endif + raise(); + } + void raise() noexcept(false) { + if (message_) + throw std::runtime_error(exchange(message_, nullptr)); + } +#else + ~error_t() noexcept { raise(); } + void raise() noexcept { + if (message_) + std::terminate(); + } +#endif +}; + +/** + * @brief Similar to `std::expected` in C++23, wraps a statement evaluation result, + * or an error. It's used to avoid raising exception, and gracefully propagate + * the error. + * + * @tparam result_at The type of the expected result. + */ +template struct expected_gt { + result_at result; + error_t error; + + operator result_at&() & { + error.raise(); + return result; + } + operator result_at&&() && { + error.raise(); + return std::move(result); + } + result_at const& operator*() const noexcept { return result; } + explicit operator bool() const noexcept { return !error; } + expected_gt failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Light-weight bitset implementation to sync nodes updates during graph mutations. + * Extends basic functionality with @b atomic operations. 
+ */ +template > class bitset_gt { + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + using compressed_slot_t = unsigned long; + + static constexpr std::size_t bits_per_slot() { return sizeof(compressed_slot_t) * CHAR_BIT; } + static constexpr compressed_slot_t bits_mask() { return sizeof(compressed_slot_t) * CHAR_BIT - 1; } + static constexpr std::size_t slots(std::size_t bits) { return divide_round_up(bits); } + + compressed_slot_t* slots_{}; + /// @brief Number of slots. + std::size_t count_{}; + + public: + bitset_gt() noexcept {} + ~bitset_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + void clear() noexcept { + if (slots_) + std::memset(slots_, 0, count_ * sizeof(compressed_slot_t)); + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, count_ * sizeof(compressed_slot_t)); + slots_ = nullptr; + count_ = 0; + } + + bitset_gt(std::size_t capacity) noexcept + : slots_((compressed_slot_t*)allocator_t{}.allocate(slots(capacity) * sizeof(compressed_slot_t))), + count_(slots_ ? slots(capacity) : 0u) { + clear(); + } + + bitset_gt(bitset_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + count_ = exchange(other.count_, 0); + } + + bitset_gt& operator=(bitset_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(count_, other.count_); + return *this; + } + + bitset_gt(bitset_gt const&) = delete; + bitset_gt& operator=(bitset_gt const&) = delete; + + inline bool test(std::size_t i) const noexcept { return slots_[i / bits_per_slot()] & (1ul << (i & bits_mask())); } + inline bool set(std::size_t i) noexcept { + compressed_slot_t& slot = slots_[i / bits_per_slot()]; + compressed_slot_t mask{1ul << (i & bits_mask())}; + bool value = slot & mask; + slot |= mask; + return value; + } + +#if defined(USEARCH_DEFINED_WINDOWS) + + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return InterlockedOr((long volatile*)&slots_[i / bits_per_slot()], mask) & mask; + } + + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + InterlockedAnd((long volatile*)&slots_[i / bits_per_slot()], ~mask); + } + +#else + + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return __atomic_fetch_or(&slots_[i / bits_per_slot()], mask, __ATOMIC_ACQUIRE) & mask; + } + + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + __atomic_fetch_and(&slots_[i / bits_per_slot()], ~mask, __ATOMIC_RELEASE); + } + +#endif + + class lock_t { + bitset_gt& bitset_; + std::size_t bit_offset_; + + public: + inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); } + inline lock_t(bitset_gt& bitset, std::size_t bit_offset) noexcept : bitset_(bitset), bit_offset_(bit_offset) { + while (bitset_.atomic_set(bit_offset_)) + ; + } + }; + + inline lock_t lock(std::size_t i) noexcept { return {*this, i}; } +}; + +using bitset_t = bitset_gt<>; + +/** + * @brief Similar to `std::priority_queue`, but allows raw access to underlying + * memory, in case you want to shuffle it or sort. Good for collections + * from 100s to 10'000s elements. + */ +template , // is needed before C++14. 
+ typename allocator_at = std::allocator> // +class max_heap_gt { + public: + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; + + using value_type = element_t; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + private: + element_t* elements_; + std::size_t size_; + std::size_t capacity_; + + public: + max_heap_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + max_heap_gt(max_heap_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + max_heap_gt& operator=(max_heap_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + max_heap_gt(max_heap_gt const&) = delete; + max_heap_gt& operator=(max_heap_gt const&) = delete; + + ~max_heap_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + + /// @brief Selects the largest element in the heap. + /// @return Reference to the stored element. + inline element_t const& top() const noexcept { return elements_[0]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (elements_) { + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + allocator.deallocate(elements_, capacity_); + } + elements_ = new_elements; + capacity_ = new_capacity; + return new_elements; + } + + bool insert(element_t&& element) noexcept { + if (!reserve(size_ + 1)) + return false; + + insert_reserved(std::move(element)); + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + new (&elements_[size_]) element_t(element); + size_++; + shift_up(size_ - 1); + } + + inline element_t pop() noexcept { + element_t result = top(); + std::swap(elements_[0], elements_[size_ - 1]); + size_--; + elements_[size_].~element_t(); + shift_down(0); + return result; + } + + /** @brief Invalidates the "max-heap" property, transforming into ascending range. 
*/ + inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } + + private: + inline std::size_t parent_idx(std::size_t i) const noexcept { return (i - 1u) / 2u; } + inline std::size_t left_child_idx(std::size_t i) const noexcept { return (i * 2u) + 1u; } + inline std::size_t right_child_idx(std::size_t i) const noexcept { return (i * 2u) + 2u; } + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } + + void shift_up(std::size_t i) noexcept { + for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) + std::swap(elements_[parent_idx(i)], elements_[i]); + } + + void shift_down(std::size_t i) noexcept { + std::size_t max_idx = i; + + std::size_t left = left_child_idx(i); + if (left < size_ && less(elements_[max_idx], elements_[left])) + max_idx = left; + + std::size_t right = right_child_idx(i); + if (right < size_ && less(elements_[max_idx], elements_[right])) + max_idx = right; + + if (i != max_idx) { + std::swap(elements_[i], elements_[max_idx]); + shift_down(max_idx); + } + } +}; + +/** + * @brief Similar to `std::priority_queue`, but allows raw access to underlying + * memory and always keeps the data sorted. Ideal for small collections + * under 128 elements. + */ +template , // is needed before C++14. + typename allocator_at = std::allocator> // +class sorted_buffer_gt { + public: + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + using value_type = element_t; + + private: + element_t* elements_; + std::size_t size_; + std::size_t capacity_; + + public: + sorted_buffer_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + sorted_buffer_gt(sorted_buffer_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + sorted_buffer_gt& operator=(sorted_buffer_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + sorted_buffer_gt(sorted_buffer_gt const&) = delete; + sorted_buffer_gt& operator=(sorted_buffer_gt const&) = delete; + + ~sorted_buffer_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + inline element_t const& top() const noexcept { return elements_[size_ - 1]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (size_) + std::memcpy(new_elements, elements_, 
size_ * sizeof(element_t)); + if (elements_) + allocator.deallocate(elements_, capacity_); + + elements_ = new_elements; + capacity_ = new_capacity; + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + std::size_t to_move = size_ - slot; + element_t* source = elements_ + size_ - 1; + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_++; + } + + /** + * @return `true` if the entry was added, `false` if it wasn't relevant enough. + */ + inline bool insert(element_t&& element, std::size_t limit) noexcept { + std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + if (slot == limit) + return false; + std::size_t to_move = size_ - slot - (size_ == limit); + element_t* source = elements_ + size_ - 1 - (size_ == limit); + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_ += size_ != limit; + return true; + } + + inline element_t pop() noexcept { + size_--; + element_t result = elements_[size_]; + elements_[size_].~element_t(); + return result; + } + + void sort_ascending() noexcept {} + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } + + private: + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } +}; + +#if defined(USEARCH_DEFINED_WINDOWS) +#pragma pack(push, 1) // Pack struct elements on 1-byte alignment +#endif + +/** + * @brief Five-byte integer type to address node clouds with over 4B entries. 
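 *
 * Stores the value in exactly 5 bytes (see the `static_assert` below), trading a little
 * per-access arithmetic for a ~37% per-slot memory saving over `uint64_t` once the index
 * grows past the 2^32 entries addressable by `uint32_t`.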
+ * + * @note Avoid usage in 32bit environment + */ +class usearch_pack_m uint40_t { + unsigned char octets[5]; + + inline uint40_t& broadcast(unsigned char c) { + std::memset(octets, c, 5); + return *this; + } + + public: + inline uint40_t() noexcept { broadcast(0); } + inline uint40_t(std::uint32_t n) noexcept { std::memcpy(&octets[1], &n, 4); } + +#ifdef USEARCH_64BIT_ENV + inline uint40_t(std::uint64_t n) noexcept { std::memcpy(octets, &n, 5); } +#endif + + uint40_t(uint40_t&&) = default; + uint40_t(uint40_t const&) = default; + uint40_t& operator=(uint40_t&&) = default; + uint40_t& operator=(uint40_t const&) = default; + +#if defined(USEARCH_DEFINED_CLANG) && defined(USEARCH_DEFINED_APPLE) + inline uint40_t(std::size_t n) noexcept { +#ifdef USEARCH_64BIT_ENV + std::memcpy(octets, &n, 5); +#else + std::memcpy(octets, &n, 4); +#endif + } +#endif + + inline operator std::size_t() const noexcept { + std::size_t result = 0; +#ifdef USEARCH_64BIT_ENV + std::memcpy(&result, octets, 5); +#else + std::memcpy(&result, octets + 1, 4); +#endif + return result; + } + + inline static uint40_t max() noexcept { return uint40_t{}.broadcast(0xFF); } + inline static uint40_t min() noexcept { return uint40_t{}.broadcast(0); } +}; + +#if defined(USEARCH_DEFINED_WINDOWS) +#pragma pack(pop) // Reset alignment to default +#endif + +static_assert(sizeof(uint40_t) == 5, "uint40_t must be exactly 5 bytes"); + +// clang-format off +template ::value>::type* = nullptr> key_at default_free_value() { return std::numeric_limits::max(); } +template ::value>::type* = nullptr> uint40_t default_free_value() { return uint40_t::max(); } +template ::value && !std::is_same::value>::type* = nullptr> key_at default_free_value() { return key_at(); } +// clang-format on + +template struct hash_gt { + std::size_t operator()(element_at const& element) const noexcept { return std::hash{}(element); } +}; + +template <> struct hash_gt { + std::size_t operator()(uint40_t const& element) const noexcept { return std::hash{}(element); } +}; + +/** + * @brief Minimalistic hash-set implementation to track visited nodes during graph traversal. + * + * It doesn't support deletion of separate objects, but supports `clear`-ing all at once. + * It expects `reserve` to be called ahead of all insertions, so no resizes are needed. + * It also assumes `0xFF...FF` slots to be unused, to simplify the design. + * It uses linear probing, the number of slots is always a power of two, and it uses linear-probing + * in case of bucket collisions. + */ +template , typename allocator_at = std::allocator> +class growing_hash_set_gt { + + using element_t = element_at; + using hasher_t = hasher_at; + + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + element_t* slots_{}; + /// @brief Number of slots. + std::size_t capacity_{}; + /// @brief Number of populated. 
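    /// @details Counts non-free slots; it only grows between `clear()`/`reset()` calls,
    /// since individual elements are never removed.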
+ std::size_t count_{}; + hasher_t hasher_{}; + + public: + growing_hash_set_gt() noexcept {} + ~growing_hash_set_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + std::size_t size() const noexcept { return count_; } + + void clear() noexcept { + if (slots_) + std::memset((void*)slots_, 0xFF, capacity_ * sizeof(element_t)); + count_ = 0; + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, capacity_ * sizeof(element_t)); + slots_ = nullptr; + capacity_ = 0; + count_ = 0; + } + + growing_hash_set_gt(std::size_t capacity) noexcept + : slots_((element_t*)allocator_t{}.allocate(ceil2(capacity) * sizeof(element_t))), + capacity_(slots_ ? ceil2(capacity) : 0u), count_(0u) { + clear(); + } + + growing_hash_set_gt(growing_hash_set_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + capacity_ = exchange(other.capacity_, 0); + count_ = exchange(other.count_, 0); + } + + growing_hash_set_gt& operator=(growing_hash_set_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(capacity_, other.capacity_); + std::swap(count_, other.count_); + return *this; + } + + growing_hash_set_gt(growing_hash_set_gt const&) = delete; + growing_hash_set_gt& operator=(growing_hash_set_gt const&) = delete; + + inline bool test(element_t const& elem) const noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + return false; + } + + /** + * + * @return Similar to `bitset_gt`, returns the previous value. + */ + inline bool set(element_t const& elem) noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + // Already exists + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + slots_[index] = elem; + ++count_; + return false; + } + + bool reserve(std::size_t new_capacity) noexcept { + new_capacity = (new_capacity * 5u) / 3u; + if (new_capacity <= capacity_) + return true; + + new_capacity = ceil2(new_capacity); + element_t* new_slots = (element_t*)allocator_t{}.allocate(new_capacity * sizeof(element_t)); + if (!new_slots) + return false; + + std::memset((void*)new_slots, 0xFF, new_capacity * sizeof(element_t)); + std::size_t new_count = count_; + if (count_) { + for (std::size_t old_index = 0; old_index != capacity_; ++old_index) { + if (slots_[old_index] == default_free_value()) + continue; + + std::size_t new_index = hasher_(slots_[old_index]) & (new_capacity - 1); + while (new_slots[new_index] != default_free_value()) + new_index = (new_index + 1) & (new_capacity - 1); + new_slots[new_index] = slots_[old_index]; + } + } + + reset(); + slots_ = new_slots; + capacity_ = new_capacity; + count_ = new_count; + return true; + } +}; + +/** + * @brief Basic single-threaded @b ring class, used for all kinds of task queues. 
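 *
 * A minimal usage sketch (hypothetical element type, not part of the library):
 *
 *     ring_gt<std::size_t> queue;
 *     if (queue.reserve(4)) {
 *         queue.push(10);
 *         queue.push(20);
 *         std::size_t task;
 *         while (queue.try_pop(task)) {
 *             // process `task`
 *         }
 *     }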
+ */ +template > // +class ring_gt { + public: + using element_t = element_at; + using allocator_t = allocator_at; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + using value_type = element_t; + + private: + element_t* elements_{}; + std::size_t capacity_{}; + std::size_t head_{}; + std::size_t tail_{}; + bool empty_{true}; + allocator_t allocator_{}; + + public: + explicit ring_gt(allocator_t const& alloc = allocator_t()) noexcept : allocator_(alloc) {} + + ring_gt(ring_gt const&) = delete; + ring_gt& operator=(ring_gt const&) = delete; + + ring_gt(ring_gt&& other) noexcept { swap(other); } + ring_gt& operator=(ring_gt&& other) noexcept { + swap(other); + return *this; + } + + void swap(ring_gt& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(capacity_, other.capacity_); + std::swap(head_, other.head_); + std::swap(tail_, other.tail_); + std::swap(empty_, other.empty_); + std::swap(allocator_, other.allocator_); + } + + ~ring_gt() noexcept { reset(); } + + bool empty() const noexcept { return empty_; } + size_t capacity() const noexcept { return capacity_; } + size_t size() const noexcept { + if (empty_) + return 0; + else if (head_ >= tail_) + return head_ - tail_; + else + return capacity_ - (tail_ - head_); + } + + void clear() noexcept { + head_ = 0; + tail_ = 0; + empty_ = true; + } + + void reset() noexcept { + if (elements_) + allocator_.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + head_ = 0; + tail_ = 0; + empty_ = true; + } + + bool reserve(std::size_t n) noexcept { + if (n < size()) + return false; // prevent data loss + if (n <= capacity()) + return true; + n = (std::max)(ceil2(n), 64u); + element_t* elements = allocator_.allocate(n); + if (!elements) + return false; + + std::size_t i = 0; + while (try_pop(elements[i])) + i++; + + reset(); + elements_ = elements; + capacity_ = n; + head_ = i; + tail_ = 0; + empty_ = (i == 0); + return true; + } + + void push(element_t const& value) noexcept { + elements_[head_] = value; + head_ = (head_ + 1) % capacity_; + empty_ = false; + } + + bool try_push(element_t const& value) noexcept { + if (head_ == tail_ && !empty_) + return false; // elements_ is full + + return push(value); + return true; + } + + bool try_pop(element_t& value) noexcept { + if (empty_) + return false; + + value = std::move(elements_[tail_]); + tail_ = (tail_ + 1) % capacity_; + empty_ = head_ == tail_; + return true; + } + + element_t const& operator[](std::size_t i) const noexcept { return elements_[(tail_ + i) % capacity_]; } +}; + +/// @brief Number of neighbors per graph node. +/// Defaults to 32 in FAISS and 16 in hnswlib. +/// > It is called `M` in the paper. +constexpr std::size_t default_connectivity() { return 16; } + +/// @brief Hyper-parameter controlling the quality of indexing. +/// Defaults to 40 in FAISS and 200 in hnswlib. +/// > It is called `efConstruction` in the paper. +constexpr std::size_t default_expansion_add() { return 128; } + +/// @brief Hyper-parameter controlling the quality of search. +/// Defaults to 16 in FAISS and 10 in hnswlib. +/// > It is called `ef` in the paper. +constexpr std::size_t default_expansion_search() { return 64; } + +constexpr std::size_t default_allocator_entry_bytes() { return 64; } + +/** + * @brief Configuration settings for the index construction. 
+ * Includes the main `::connectivity` parameter (`M` in the paper) + * and two expansion factors - for construction and search. + */ +struct index_config_t { + /// @brief Number of neighbors per graph node. + /// Defaults to 32 in FAISS and 16 in hnswlib. + /// > It is called `M` in the paper. + std::size_t connectivity = default_connectivity(); + + /// @brief Number of neighbors per graph node in base level graph. + /// Defaults to double of the other levels, so 64 in FAISS and 32 in hnswlib. + /// > It is called `M0` in the paper. + std::size_t connectivity_base = default_connectivity() * 2; + + inline index_config_t() = default; + inline index_config_t(std::size_t c) noexcept + : connectivity(c ? c : default_connectivity()), connectivity_base(c ? c * 2 : default_connectivity() * 2) {} + inline index_config_t(std::size_t c, std::size_t cb) noexcept + : connectivity(c), connectivity_base((std::max)(c, cb)) {} +}; + +struct index_limits_t { + std::size_t members = 0; + std::size_t threads_add = std::thread::hardware_concurrency(); + std::size_t threads_search = std::thread::hardware_concurrency(); + + inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) {} + inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) {} + inline std::size_t threads() const noexcept { return (std::max)(threads_add, threads_search); } + inline std::size_t concurrency() const noexcept { return (std::min)(threads_add, threads_search); } +}; + +struct index_update_config_t { + /// @brief Hyper-parameter controlling the quality of indexing. + /// Defaults to 40 in FAISS and 200 in hnswlib. + /// > It is called `efConstruction` in the paper. + std::size_t expansion = default_expansion_add(); + + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; +}; + +struct index_search_config_t { + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); + + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; + + /// @brief Brute-forces exhaustive search over all entries in the index. + bool exact = false; +}; + +struct index_cluster_config_t { + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); + + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; +}; + +struct index_copy_config_t {}; + +struct index_join_config_t { + /// @brief Controls maximum number of proposals per man during stable marriage. + std::size_t max_proposals = 0; + + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); + + /// @brief Brute-forces exhaustive search over all entries in the index. + bool exact = false; +}; + +/// @brief C++17 and newer version deprecate the `std::result_of` +template +using return_type_gt = +#if defined(USEARCH_DEFINED_CPP17) + typename std::invoke_result::type; +#else + typename std::result_of::type; +#endif + +/** + * @brief An example of what a USearch-compatible ad-hoc filter would look like. 
+ * + * A similar function object can be passed to search queries to further filter entries + * on their auxiliary properties, such as some categorical keys stored in an external DBMS. + */ +struct dummy_predicate_t { + template constexpr bool operator()(member_at&&) const noexcept { return true; } +}; + +/** + * @brief An example of what a USearch-compatible ad-hoc operation on in-flight entries. + * + * This kind of callbacks is used when the engine is being updated and you want to patch + * the entries, while their are still under locks - limiting concurrent access and providing + * consistency. + */ +struct dummy_callback_t { + template void operator()(member_at&&) const noexcept {} +}; + +/** + * @brief An example of what a USearch-compatible progress-bar should look like. + * + * This is particularly helpful when handling long-running tasks, like serialization, + * saving, and loading from disk, or index-level joins. + * The reporter checks return value to continue or stop the process, `false` means need to stop. + */ +struct dummy_progress_t { + inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept { return true; } +}; + +/** + * @brief An example of what a USearch-compatible values prefetching mechanism should look like. + * + * USearch is designed to handle very large datasets, that may not fir into RAM. Fetching from + * external memory is very expensive, so we've added a pre-fetching mechanism, that accepts + * multiple objects at once, to cache in RAM ahead of the computation. + * The received iterators support both `get_slot` and `get_key` operations. + * An example usage may look like this: + * + * template + * inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept { + * for (; begin != end; ++begin) + * io_uring_prefetch(offset_in_file(get_key(begin))); + * } + */ +struct dummy_prefetch_t { + template + inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept {} +}; + +/** + * @brief An example of what a USearch-compatible executor (thread-pool) should look like. + * + * It's expected to have `parallel(callback)` API to schedule one task per thread; + * an identical `fixed(count, callback)` and `dynamic(count, callback)` overloads that also accepts + * the number of tasks, and somehow schedules them between threads; as well as `size()` to + * determine the number of available threads. + */ +struct dummy_executor_t { + dummy_executor_t() noexcept {} + std::size_t size() const noexcept { return 1; } + + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { + for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) + thread_aware_function(0, task_idx); + } + + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { + for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) + if (!thread_aware_function(0, task_idx)) + break; + } + + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept { + thread_aware_function(0); + } +}; + +/** + * @brief An example of what a USearch-compatible key-to-key mapping should look like. + * + * This is particularly helpful for "Semantic Joins", where we map entries of one collection + * to entries of another. In asymmetric setups, where A -> B is needed, but B -> A is not, + * this can be passed to minimize memory usage. 
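 *
 * Any type supporting `mapping[key] = key` assignments fits this slot; for small joins a
 * plain `std::unordered_map` keyed and valued by the index key type (a hypothetical choice,
 * not something the library requires) exposes the same interface.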
+ */ +struct dummy_key_to_key_mapping_t { + struct member_ref_t { + template member_ref_t& operator=(key_at&&) noexcept { return *this; } + }; + template member_ref_t operator[](key_at&&) const noexcept { return {}; } +}; + +/** + * @brief Checks if the provided object has a dummy type, emulating an interface, + * but performing no real computation. + */ +template static constexpr bool is_dummy() { + using object_t = typename std::remove_all_extents::type; + return std::is_same::type, dummy_predicate_t>::value || // + std::is_same::type, dummy_callback_t>::value || // + std::is_same::type, dummy_progress_t>::value || // + std::is_same::type, dummy_prefetch_t>::value || // + std::is_same::type, dummy_executor_t>::value || // + std::is_same::type, dummy_key_to_key_mapping_t>::value; +} + +template struct has_reset_gt { + static_assert(std::integral_constant::value, "Second template parameter needs to be of function type."); +}; + +template +struct has_reset_gt { + private: + template + static constexpr auto check(at*) -> + typename std::is_same().reset(std::declval()...)), return_at>::type; + template static constexpr std::false_type check(...); + + typedef decltype(check(0)) type; + + public: + static constexpr bool value = type::value; +}; + +/** + * @brief Checks if a certain class has a member function called `reset`. + */ +template constexpr bool has_reset() { return has_reset_gt::value; } + +struct serialization_result_t { + error_t error; + + explicit operator bool() const noexcept { return !error; } + serialization_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b outputs. + * + * This class raises no exceptions and corresponds errors through `serialization_result_t`. + * The class automatically closes the file when the object is destroyed. + */ +class output_file_t { + char const* path_ = nullptr; + std::FILE* file_ = nullptr; + + public: + output_file_t(char const* path) noexcept : path_(path) {} + ~output_file_t() noexcept { close(); } + output_file_t(output_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + output_file_t& operator=(output_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "wb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t write(void const* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t written = std::fwrite(begin, length, 1, file_); + if (length && !written) + return result.failed(std::strerror(errno)); + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } +}; + +/** + * @brief Smart-pointer wrapping the LibC @b `FILE` for binary files @b inputs. + * + * This class raises no exceptions and corresponds errors through `serialization_result_t`. + * The class automatically closes the file when the object is destroyed. 
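 *
 * A minimal round-trip sketch pairing it with `output_file_t` (hypothetical path and
 * payload, error handling elided):
 *
 *     std::uint64_t magic = 42;
 *     output_file_t out("index.bin");
 *     bool saved = out.open_if_not() && out.write(&magic, sizeof(magic));
 *     input_file_t in("index.bin");
 *     bool restored = in.open_if_not() && in.read(&magic, sizeof(magic));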
+ */ +class input_file_t { + char const* path_ = nullptr; + std::FILE* file_ = nullptr; + + public: + input_file_t(char const* path) noexcept : path_(path) {} + ~input_file_t() noexcept { close(); } + input_file_t(input_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + input_file_t& operator=(input_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "rb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t read(void* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t read = std::fread(begin, length, 1, file_); + if (length && !read) + return result.failed(std::feof(file_) ? "End of file reached!" : std::strerror(errno)); + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } + + explicit operator bool() const noexcept { return file_; } + bool seek_to(std::size_t progress) noexcept { + return std::fseek(file_, static_cast(progress), SEEK_SET) == 0; + } + bool seek_to_end() noexcept { return std::fseek(file_, 0L, SEEK_END) == 0; } + bool infer_progress(std::size_t& progress) noexcept { + long int result = std::ftell(file_); + if (result == -1L) + return false; + progress = static_cast(result); + return true; + } +}; + +/** + * @brief Represents a memory-mapped file or a pre-allocated anonymous memory region. + * + * This class provides a convenient way to memory-map a file and access its contents as a block of + * memory. The class handles platform-specific memory-mapping operations on Windows, Linux, and MacOS. + * The class automatically closes the file when the object is destroyed. + */ +class memory_mapped_file_t { + char const* path_{}; /**< The path to the file to be memory-mapped. */ + void* ptr_{}; /**< A pointer to the memory-mapping. */ + size_t length_{}; /**< The length of the memory-mapped file in bytes. */ + +#if defined(USEARCH_DEFINED_WINDOWS) + HANDLE file_handle_{}; /**< The file handle on Windows. */ + HANDLE mapping_handle_{}; /**< The mapping handle on Windows. */ +#else + int file_descriptor_{}; /**< The file descriptor on Linux and MacOS. 
*/ +#endif + + public: + explicit operator bool() const noexcept { return ptr_ != nullptr; } + byte_t* data() noexcept { return reinterpret_cast(ptr_); } + byte_t const* data() const noexcept { return reinterpret_cast(ptr_); } + std::size_t size() const noexcept { return static_cast(length_); } + + memory_mapped_file_t() noexcept {} + memory_mapped_file_t(char const* path) noexcept : path_(path) {} + ~memory_mapped_file_t() noexcept { close(); } + memory_mapped_file_t(memory_mapped_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), ptr_(exchange(other.ptr_, nullptr)), + length_(exchange(other.length_, 0)), +#if defined(USEARCH_DEFINED_WINDOWS) + file_handle_(exchange(other.file_handle_, nullptr)), mapping_handle_(exchange(other.mapping_handle_, nullptr)) +#else + file_descriptor_(exchange(other.file_descriptor_, 0)) +#endif + { + } + + memory_mapped_file_t(byte_t* data, std::size_t length) noexcept : ptr_(data), length_(length) {} + + memory_mapped_file_t& operator=(memory_mapped_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(ptr_, other.ptr_); + std::swap(length_, other.length_); +#if defined(USEARCH_DEFINED_WINDOWS) + std::swap(file_handle_, other.file_handle_); + std::swap(mapping_handle_, other.mapping_handle_); +#else + std::swap(file_descriptor_, other.file_descriptor_); +#endif + return *this; + } + + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!path_ || ptr_) + return result; + +#if defined(USEARCH_DEFINED_WINDOWS) + + HANDLE file_handle = + CreateFile(path_, GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (file_handle == INVALID_HANDLE_VALUE) + return result.failed("Opening file failed!"); + + std::size_t file_length = GetFileSize(file_handle, 0); + HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0); + if (mapping_handle == 0) { + CloseHandle(file_handle); + return result.failed("Mapping file failed!"); + } + + byte_t* file = (byte_t*)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_length); + if (file == 0) { + CloseHandle(mapping_handle); + CloseHandle(file_handle); + return result.failed("View the map failed!"); + } + file_handle_ = file_handle; + mapping_handle_ = mapping_handle; + ptr_ = file; + length_ = file_length; +#else + +#if defined(USEARCH_DEFINED_LINUX) + int descriptor = open(path_, O_RDONLY | O_NOATIME); +#else + int descriptor = open(path_, O_RDONLY); +#endif + if (descriptor < 0) + return result.failed(std::strerror(errno)); + + // Estimate the file size + struct stat file_stat; + int fstat_status = fstat(descriptor, &file_stat); + if (fstat_status < 0) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + + // Map the entire file + byte_t* file = (byte_t*)mmap(NULL, file_stat.st_size, PROT_READ, MAP_SHARED, descriptor, 0); + if (file == MAP_FAILED) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + file_descriptor_ = descriptor; + ptr_ = file; + length_ = file_stat.st_size; +#endif // Platform specific code + return result; + } + + void close() noexcept { + if (!path_) { + ptr_ = nullptr; + length_ = 0; + return; + } +#if defined(USEARCH_DEFINED_WINDOWS) + UnmapViewOfFile(ptr_); + CloseHandle(mapping_handle_); + CloseHandle(file_handle_); + mapping_handle_ = nullptr; + file_handle_ = nullptr; +#else + munmap(ptr_, length_); + ::close(file_descriptor_); + file_descriptor_ = 0; +#endif + ptr_ = nullptr; + length_ = 0; + } +}; + +struct index_serialized_header_t { 
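    // All fields are fixed-width 64-bit integers, so the header occupies the same number of
    // bytes on every supported platform.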
+ std::uint64_t size = 0; + std::uint64_t connectivity = 0; + std::uint64_t connectivity_base = 0; + std::uint64_t max_level = 0; + std::uint64_t entry_slot = 0; +}; + +using default_key_t = std::uint64_t; +using default_slot_t = std::uint32_t; +using default_distance_t = float; + +template struct member_gt { + key_at key; + std::size_t slot; +}; + +template inline std::size_t get_slot(member_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_gt const& m) noexcept { return m.key; } + +template struct member_cref_gt { + misaligned_ref_gt key; + std::size_t slot; +}; + +template inline std::size_t get_slot(member_cref_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_cref_gt const& m) noexcept { return m.key; } + +template struct member_ref_gt { + misaligned_ref_gt key; + std::size_t slot; + + inline operator member_cref_gt() const noexcept { return {key.ptr(), slot}; } +}; + +template inline std::size_t get_slot(member_ref_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_ref_gt const& m) noexcept { return m.key; } + +/** + * @brief Approximate Nearest Neighbors Search @b index-structure using the + * Hierarchical Navigable Small World @b (HNSW) graphs algorithm. + * If classical containers store @b Key->Value mappings, this one can + * be seen as a network of keys, accelerating approximate @b Value~>Key visited_members. + * + * Unlike most implementations, this one is generic anc can be used for any search, + * not just within equi-dimensional vectors. Examples range from texts to similar Chess + * positions. + * + * @tparam key_at + * The type of primary objects stored in the index. + * The values, to which those map, are not managed by the same index structure. + * + * @tparam compressed_slot_at + * The smallest unsigned integer type to address indexed elements. + * It is used internally to maximize space-efficiency and is generally + * up-casted to @b `std::size_t` in public interfaces. + * Can be a built-in @b `uint32_t`, `uint64_t`, or our custom @b `uint40_t`. + * Which makes the most sense for 4B+ entry indexes. + * + * @tparam dynamic_allocator_at + * Dynamic memory allocator for temporary buffers, visits indicators, and + * priority queues, needed during construction and traversals of graphs. + * The allocated buffers may be uninitialized. + * + * @tparam tape_allocator_at + * Potentially different memory allocator for primary allocations of nodes and vectors. + * It would never `deallocate` separate entries, and would only free all the space at once. + * The allocated buffers may be uninitialized. + * + * @section Features + * + * - Thread-safe for concurrent construction, search, and updates. + * - Doesn't allocate new threads, and reuses the ones its called from. + * - Allows storing value externally, managing just the similarity index. + * - Joins. + + * @section Usage + * + * @subsection Exceptions + * + * None of the methods throw exceptions in the "Release" compilation mode. + * It may only `throw` if your memory ::dynamic_allocator_at or ::metric_at isn't + * safe to copy. + * + * @subsection Serialization + * + * When serialized, doesn't include any additional metadata. + * It is just the multi-level proximity-graph. You may want to store metadata about + * the used metric and key types somewhere else. + * + * @section Implementation Details + * + * Like every HNSW implementation, USearch builds levels of "Proximity Graphs". 
+ * Every added vector forms a node in one or more levels of the graph. + * Every node is present in the base level. Every following level contains a smaller + * fraction of nodes. During search, the operation starts with the smaller levels + * and zooms-in on every following iteration of larger graph traversals. + * + * Just one memory allocation is performed regardless of the number of levels. + * The adjacency lists across all levels are concatenated into that single buffer. + * That buffer starts with a "head", that stores the metadata, such as the + * tallest "level" of the graph that it belongs to, the external "key", and the + * number of "dimensions" in the vector. + * + * @section Metrics, Predicates and Callbacks + * + * + * @section Smart References and Iterators + * + * - `member_citerator_t` and `member_iterator_t` have only slots, no indirections. + * + * - `member_cref_t` and `member_ref_t` contains the `slot` and a reference + * to the key. So it passes through 1 level of visited_members in `nodes_`. + * Retrieving the key via `get_key` will cause fetching yet another cache line. + * + * - `member_gt` contains an already prefetched copy of the key. + * + */ +template , // + typename tape_allocator_at = dynamic_allocator_at> // +class index_gt { + public: + using distance_t = distance_at; + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using dynamic_allocator_t = dynamic_allocator_at; + using tape_allocator_t = tape_allocator_at; + static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); + + using member_cref_t = member_cref_gt; + using member_ref_t = member_ref_gt; + + template class member_iterator_gt { + using ref_t = ref_at; + using index_t = index_at; + + friend class index_gt; + member_iterator_gt() noexcept {} + member_iterator_gt(index_t* index, std::size_t slot) noexcept : index_(index), slot_(slot) {} + + index_t* index_{}; + std::size_t slot_{}; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = ref_t; + using difference_type = std::ptrdiff_t; + using pointer = void; + using reference = ref_t; + + reference operator*() const noexcept { return {index_->node_at_(slot_).key(), slot_}; } + vector_key_t key() const noexcept { return index_->node_at_(slot_).key(); } + + friend inline std::size_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } + friend inline vector_key_t get_key(member_iterator_gt const& it) noexcept { return it.key(); } + + member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, slot_ + 1); } + member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, slot_ - 1); } + member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, slot_ + d); } + member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, slot_ - d); } + + // clang-format off + member_iterator_gt& operator++() noexcept { slot_ += 1; return *this; } + member_iterator_gt& operator--() noexcept { slot_ -= 1; return *this; } + member_iterator_gt& operator+=(difference_type d) noexcept { slot_ += d; return *this; } + member_iterator_gt& operator-=(difference_type d) noexcept { slot_ -= d; return *this; } + bool operator==(member_iterator_gt const& other) const noexcept { return index_ == other.index_ && slot_ == other.slot_; } + bool operator!=(member_iterator_gt const& other) const noexcept { return index_ != 
other.index_ || slot_ != other.slot_; } + // clang-format on + }; + + using member_iterator_t = member_iterator_gt; + using member_citerator_t = member_iterator_gt; + + // STL compatibility: + using value_type = vector_key_t; + using allocator_type = dynamic_allocator_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = member_ref_t; + using const_reference = member_cref_t; + using pointer = void; + using const_pointer = void; + using iterator = member_iterator_t; + using const_iterator = member_citerator_t; + using reverse_iterator = std::reverse_iterator; + using reverse_const_iterator = std::reverse_iterator; + + using dynamic_allocator_traits_t = std::allocator_traits; + using byte_t = typename dynamic_allocator_t::value_type; + static_assert( // + sizeof(byte_t) == 1, // + "Primary allocator must allocate separate addressable bytes"); + + using tape_allocator_traits_t = std::allocator_traits; + static_assert( // + sizeof(typename tape_allocator_traits_t::value_type) == 1, // + "Tape allocator must allocate separate addressable bytes"); + + private: + /** + * @brief Integer for the number of node neighbors at a specific level of the + * multi-level graph. It's selected to be `std::uint32_t` to improve the + * alignment in most common cases. + */ + using neighbors_count_t = std::uint32_t; + using level_t = std::int16_t; + + /** + * @brief How many bytes of memory are needed to form the "head" of the node. + */ + static constexpr std::size_t node_head_bytes_() { return sizeof(vector_key_t) + sizeof(level_t); } + + using nodes_mutexes_t = bitset_gt; + + using visits_hash_set_t = growing_hash_set_gt, dynamic_allocator_t>; + + struct precomputed_constants_t { + double inverse_log_connectivity{}; + std::size_t neighbors_bytes{}; + std::size_t neighbors_base_bytes{}; + }; + /// @brief A space-efficient internal data-structure used in graph traversal queues. + struct candidate_t { + distance_t distance; + compressed_slot_t slot; + inline bool operator<(candidate_t other) const noexcept { return distance < other.distance; } + }; + + using candidates_view_t = span_gt; + using candidates_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using top_candidates_t = sorted_buffer_gt, candidates_allocator_t>; + using next_candidates_t = max_heap_gt, candidates_allocator_t>; + + /** + * @brief A loosely-structured handle for every node. One such node is created for every member. + * To minimize memory usage and maximize the number of entries per cache-line, it only + * stores to pointers. The internal tape starts with a `vector_key_t` @b key, then + * a `level_t` for the number of graph @b levels in which this member appears, + * then the { `neighbors_count_t`, `compressed_slot_t`, `compressed_slot_t` ... } sequences + * for @b each-level. 
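 *
 *        Rough tape layout for a member that reached level L (byte sizes come from
 *        `node_head_bytes_()` and `precomputed_constants_t`):
 *
 *            [ vector_key_t key ][ level_t L ]              // head
 *            [ neighbors_count_t n ][ slot ... slot ]       // level 0, up to connectivity_base
 *            [ neighbors_count_t n ][ slot ... slot ]       // level 1, up to connectivity
 *            ...                                            // one block per level, up to L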
+ */ + class node_t { + byte_t* tape_{}; + + public: + explicit node_t(byte_t* tape) noexcept : tape_(tape) {} + byte_t* tape() const noexcept { return tape_; } + byte_t* neighbors_tape() const noexcept { return tape_ + node_head_bytes_(); } + explicit operator bool() const noexcept { return tape_; } + + node_t() = default; + node_t(node_t const&) = default; + node_t& operator=(node_t const&) = default; + + misaligned_ref_gt ckey() const noexcept { return {tape_}; } + misaligned_ref_gt key() const noexcept { return {tape_}; } + misaligned_ref_gt level() const noexcept { return {tape_ + sizeof(vector_key_t)}; } + + void key(vector_key_t v) noexcept { return misaligned_store(tape_, v); } + void level(level_t v) noexcept { return misaligned_store(tape_ + sizeof(vector_key_t), v); } + }; + + static_assert(std::is_trivially_copy_constructible::value, "Nodes must be light!"); + static_assert(std::is_trivially_destructible::value, "Nodes must be light!"); + + /** + * @brief A slice of the node's tape, containing a the list of neighbors + * for a node in a single graph level. It's pre-allocated to fit + * as many neighbors "slots", as may be needed at the target level, + * and starts with a single integer `neighbors_count_t` counter. + */ + class neighbors_ref_t { + byte_t* tape_; + + static constexpr std::size_t shift(std::size_t i = 0) { + return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; + } + + public: + neighbors_ref_t(byte_t* tape) noexcept : tape_(tape) {} + misaligned_ptr_gt begin() noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() noexcept { return begin() + size(); } + misaligned_ptr_gt begin() const noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() const noexcept { return begin() + size(); } + compressed_slot_t operator[](std::size_t i) const noexcept { + return misaligned_load(tape_ + shift(i)); + } + std::size_t size() const noexcept { return misaligned_load(tape_); } + void clear() noexcept { + neighbors_count_t n = misaligned_load(tape_); + std::memset(tape_, 0, shift(n)); + // misaligned_store(tape_, 0); + } + void push_back(compressed_slot_t slot) noexcept { + neighbors_count_t n = misaligned_load(tape_); + misaligned_store(tape_ + shift(n), slot); + misaligned_store(tape_, n + 1); + } + }; + + /** + * @brief A package of all kinds of temporary data-structures, that the threads + * would reuse to process requests. Similar to having all of those as + * separate `thread_local` global variables. 
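 *        One `context_t` is kept per thread slot requested through `index_limits_t`, so the
 *        `thread` field of the per-call configs doubles as an index into this pool.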
+ */ + struct usearch_align_m context_t { + top_candidates_t top_candidates{}; + next_candidates_t next_candidates{}; + visits_hash_set_t visits{}; + std::default_random_engine level_generator{}; + std::size_t iteration_cycles{}; + std::size_t computed_distances_count{}; + + template // + inline distance_t measure(value_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances_count++; + return metric(first, second); + } + + template // + inline distance_t measure(entry_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances_count++; + return metric(first, second); + } + }; + + index_config_t config_{}; + index_limits_t limits_{}; + + mutable dynamic_allocator_t dynamic_allocator_{}; + tape_allocator_t tape_allocator_{}; + + precomputed_constants_t pre_{}; + memory_mapped_file_t viewed_file_{}; + + /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. + usearch_align_m mutable std::atomic nodes_capacity_{}; + + /// @brief Number of "slots" already storing non-null nodes. + usearch_align_m mutable std::atomic nodes_count_{}; + + /// @brief Controls access to `max_level_` and `entry_slot_`. + /// If any thread is updating those values, no other threads can `add()` or `search()`. + std::mutex global_mutex_{}; + + /// @brief The level of the top-most graph in the index. Grows as the logarithm of size, starts from zero. + level_t max_level_{}; + + /// @brief The slot in which the only node of the top-level graph is stored. + std::size_t entry_slot_{}; + + using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief C-style array of `node_t` smart-pointers. + buffer_gt nodes_{}; + + /// @brief Mutex, that limits concurrent access to `nodes_`. + mutable nodes_mutexes_t nodes_mutexes_{}; + + using contexts_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief Array of thread-specific buffers for temporary data. + mutable buffer_gt contexts_{}; + + public: + std::size_t connectivity() const noexcept { return config_.connectivity; } + std::size_t capacity() const noexcept { return nodes_capacity_; } + std::size_t size() const noexcept { return nodes_count_; } + std::size_t max_level() const noexcept { return nodes_count_ ? static_cast(max_level_) : 0; } + index_config_t const& config() const noexcept { return config_; } + index_limits_t const& limits() const noexcept { return limits_; } + bool is_immutable() const noexcept { return bool(viewed_file_); } + + /** + * @section Exceptions + * Doesn't throw, unless the ::metric's and ::allocators's throw on copy-construction. + */ + explicit index_gt( // + index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept + : config_(config), limits_(0, 0), dynamic_allocator_(std::move(dynamic_allocator)), + tape_allocator_(std::move(tape_allocator)), pre_(precompute_(config)), nodes_count_(0u), max_level_(-1), + entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} + + /** + * @brief Clones the structure with the same hyper-parameters, but without contents. 
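 *        Unlike `copy()`, the returned index starts empty and shares no nodes with the
 *        original; call `reserve()` on it before adding members.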
+ */ + index_gt fork() noexcept { return index_gt{config_, dynamic_allocator_, tape_allocator_}; } + + ~index_gt() noexcept { reset(); } + + index_gt(index_gt&& other) noexcept { swap(other); } + + index_gt& operator=(index_gt&& other) noexcept { + swap(other); + return *this; + } + + struct copy_result_t { + error_t error; + index_gt index; + + explicit operator bool() const noexcept { return !error; } + copy_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + copy_result_t copy(index_copy_config_t config = {}) const noexcept { + copy_result_t result; + index_gt& other = result.index; + other = index_gt(config_, dynamic_allocator_, tape_allocator_); + if (!other.reserve(limits_)) + return result.failed("Failed to reserve the contexts"); + + // Now all is left - is to allocate new `node_t` instances and populate + // the `other.nodes_` array into it. + for (std::size_t i = 0; i != nodes_count_; ++i) + other.nodes_[i] = other.node_make_copy_(node_bytes_(nodes_[i])); + + other.nodes_count_ = nodes_count_.load(); + other.max_level_ = max_level_; + other.entry_slot_ = entry_slot_; + + // This controls nothing for now :) + (void)config; + return result; + } + + member_citerator_t cbegin() const noexcept { return {this, 0}; } + member_citerator_t cend() const noexcept { return {this, size()}; } + member_citerator_t begin() const noexcept { return {this, 0}; } + member_citerator_t end() const noexcept { return {this, size()}; } + member_iterator_t begin() noexcept { return {this, 0}; } + member_iterator_t end() noexcept { return {this, size()}; } + + member_ref_t at(std::size_t slot) noexcept { return {nodes_[slot].key(), slot}; } + member_cref_t at(std::size_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } + member_iterator_t iterator_at(std::size_t slot) noexcept { return {this, slot}; } + member_citerator_t citerator_at(std::size_t slot) const noexcept { return {this, slot}; } + + dynamic_allocator_t const& dynamic_allocator() const noexcept { return dynamic_allocator_; } + tape_allocator_t const& tape_allocator() const noexcept { return tape_allocator_; } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma region Adjusting Configuration +#endif + + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() noexcept { + if (!has_reset()) { + std::size_t n = nodes_count_; + for (std::size_t i = 0; i != n; ++i) + node_free_(i); + } else + tape_allocator_.deallocate(nullptr, 0); + nodes_count_ = 0; + max_level_ = -1; + entry_slot_ = 0u; + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() noexcept { + clear(); + + nodes_ = {}; + contexts_ = {}; + nodes_mutexes_ = {}; + limits_ = index_limits_t{0, 0}; + nodes_capacity_ = 0; + viewed_file_ = memory_mapped_file_t{}; + tape_allocator_ = {}; + } + + /** + * @brief Swaps the underlying memory buffers and thread contexts. 
+ */ + void swap(index_gt& other) noexcept { + std::swap(config_, other.config_); + std::swap(limits_, other.limits_); + std::swap(dynamic_allocator_, other.dynamic_allocator_); + std::swap(tape_allocator_, other.tape_allocator_); + std::swap(pre_, other.pre_); + std::swap(viewed_file_, other.viewed_file_); + std::swap(max_level_, other.max_level_); + std::swap(entry_slot_, other.entry_slot_); + std::swap(nodes_, other.nodes_); + std::swap(nodes_mutexes_, other.nodes_mutexes_); + std::swap(contexts_, other.contexts_); + + // Non-atomic parts. + std::size_t capacity_copy = nodes_capacity_; + std::size_t count_copy = nodes_count_; + nodes_capacity_ = other.nodes_capacity_.load(); + nodes_count_ = other.nodes_count_.load(); + other.nodes_capacity_ = capacity_copy; + other.nodes_count_ = count_copy; + } + + /** + * @brief Increases the `capacity()` of the index to allow adding more vectors. + * @return `true` on success, `false` on memory allocation errors. + */ + bool reserve(index_limits_t limits) usearch_noexcept_m { + + if (limits.threads_add <= limits_.threads_add // + && limits.threads_search <= limits_.threads_search // + && limits.members <= limits_.members) + return true; + + nodes_mutexes_t new_mutexes(limits.members); + buffer_gt new_nodes(limits.members); + buffer_gt new_contexts(limits.threads()); + if (!new_nodes || !new_contexts || !new_mutexes) + return false; + + // Move the nodes info, and deallocate previous buffers. + if (nodes_) + std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); + + limits_ = limits; + nodes_capacity_ = limits.members; + nodes_ = std::move(new_nodes); + contexts_ = std::move(new_contexts); + nodes_mutexes_ = std::move(new_mutexes); + return true; + } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion + +#pragma region Construction and Search +#endif + + struct add_result_t { + error_t error{}; + std::size_t new_size{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + std::size_t slot{}; + + explicit operator bool() const noexcept { return !error; } + add_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /// @brief Describes a matched search result, augmenting `member_cref_t` + /// contents with `distance` to the query object. + struct match_t { + member_cref_t member; + distance_t distance; + + inline match_t() noexcept : member({nullptr, 0}), distance(std::numeric_limits::max()) {} + + inline match_t(member_cref_t member, distance_t distance) noexcept : member(member), distance(distance) {} + + inline match_t(match_t&& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t(match_t const& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t& operator=(match_t const& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + + inline match_t& operator=(match_t&& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + }; + + class search_result_t { + node_t const* nodes_{}; + top_candidates_t const* top_{}; + + friend class index_gt; + inline search_result_t(index_gt const& index, top_candidates_t& top) noexcept + : nodes_(index.nodes_), top_(&top) {} + + public: + /** @brief Number of search results found. 
*/ + std::size_t count{}; + /** @brief Number of graph nodes traversed. */ + std::size_t visited_members{}; + /** @brief Number of times the distances were computed. */ + std::size_t computed_distances{}; + error_t error{}; + + inline search_result_t() noexcept {} + inline search_result_t(search_result_t&&) = default; + inline search_result_t& operator=(search_result_t&&) = default; + + explicit operator bool() const noexcept { return !error; } + search_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + inline operator std::size_t() const noexcept { return count; } + inline std::size_t size() const noexcept { return count; } + inline bool empty() const noexcept { return !count; } + inline match_t operator[](std::size_t i) const noexcept { return at(i); } + inline match_t front() const noexcept { return at(0); } + inline match_t back() const noexcept { return at(count - 1); } + inline bool contains(vector_key_t key) const noexcept { + for (std::size_t i = 0; i != count; ++i) + if (at(i).member.key == key) + return true; + return false; + } + inline match_t at(std::size_t i) const noexcept { + candidate_t const* top_ordered = top_->data(); + candidate_t candidate = top_ordered[i]; + node_t node = nodes_[candidate.slot]; + return {member_cref_t{node.ckey(), candidate.slot}, candidate.distance}; + } + inline std::size_t merge_into( // + vector_key_t* keys, distance_t* distances, // + std::size_t old_count, std::size_t max_count) const noexcept { + + std::size_t merged_count = old_count; + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + distance_t* merged_end = distances + merged_count; + std::size_t offset = std::lower_bound(distances, merged_end, result.distance) - distances; + if (offset == max_count) + continue; + + std::size_t count_worse = merged_count - offset - (max_count == merged_count); + std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(vector_key_t)); + std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); + keys[offset] = result.member.key; + distances[offset] = result.distance; + merged_count += merged_count != max_count; + } + return merged_count; + } + inline std::size_t dump_to(vector_key_t* keys, distance_t* distances) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + distances[i] = result.distance; + } + return count; + } + inline std::size_t dump_to(vector_key_t* keys) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + } + return count; + } + }; + + struct cluster_result_t { + error_t error{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + match_t cluster{}; + + explicit operator bool() const noexcept { return !error; } + cluster_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Inserts a new entry into the index. Thread-safe. Supports @b heterogeneous lookups. + * Expects needed capacity to be reserved ahead of time: `size() < capacity()`. + * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. 
+ * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * Where any possible `entry_at` has both two interfaces: `std::size_t slot()`, `vector_key_t key()`. + * + * @param[in] key External identifier/name/descriptor for the new entry. + * @param[in] value Content that will be compared against other entries to index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t add( // + vector_key_t key, value_at&& value, metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + add_result_t result; + if (is_immutable()) + return result.failed("Can't add to an immutable index"); + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + // Determining how much memory to allocate for the node depends on the target level + std::unique_lock new_level_lock(global_mutex_); + level_t max_level_copy = max_level_; // Copy under lock + std::size_t entry_idx_copy = entry_slot_; // Copy under lock + level_t target_level = choose_random_level_(context.level_generator); + + // Make sure we are not overflowing + std::size_t capacity = nodes_capacity_.load(); + std::size_t new_slot = nodes_count_.fetch_add(1); + if (new_slot >= capacity) { + nodes_count_.fetch_sub(1); + return result.failed("Reserve capacity ahead of insertions!"); + } + + // Allocate the neighbors + node_t node = node_make_(key, target_level); + if (!node) { + nodes_count_.fetch_sub(1); + return result.failed("Out of memory!"); + } + if (target_level <= max_level_copy) + new_level_lock.unlock(); + + nodes_[new_slot] = node; + result.new_size = new_slot + 1; + result.slot = new_slot; + callback(at(new_slot)); + node_lock_t new_lock = node_lock_(new_slot); + + // Do nothing for the first element + if (!new_slot) { + entry_slot_ = new_slot; + max_level_ = target_level; + return result; + } + + // Pull stats + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + connect_node_across_levels_( // + value, metric, prefetch, // + new_slot, entry_idx_copy, max_level_copy, target_level, // + config, context); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + // Updating the entry point if needed + if (target_level > max_level_copy) { + entry_slot_ 
= new_slot; + max_level_ = target_level; + } + return result; + } + + /** + * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. + * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * For any possible `entry_at` following interfaces will work: + * - `std::size_t get_slot(entry_at const &)` + * - `vector_key_t get_key(entry_at const &)` + * + * @param[in] iterator Iterator pointing to an existing entry to be replaced. + * @param[in] key External identifier/name/descriptor for the entry. + * @param[in] value Content that will be compared against other entries in the index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t update( // + member_iterator_t iterator, // + vector_key_t key, // + value_at&& value, // + metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); + add_result_t result; + std::size_t old_slot = iterator.slot_; + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + node_lock_t new_lock = node_lock_(old_slot); + node_t node = node_at_(old_slot); + + level_t node_level = node.level(); + span_bytes_t node_bytes = node_bytes_(node); + std::memset(node_bytes.data(), 0, node_bytes.size()); + node.level(node_level); + + // Pull stats + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + connect_node_across_levels_( // + value, metric, prefetch, // + old_slot, entry_slot_, max_level_, node_level, // + config, context); + node.key(key); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.slot = old_slot; + + callback(at(old_slot)); + return result; + } + + /** + * @brief Searches for the closest elements to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. 
+ * @param[in] wanted The upper bound for the number of results to return. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. + */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + search_result_t search( // + value_at&& query, // + std::size_t wanted, // + metric_at&& metric, // + index_search_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const noexcept { + + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + search_result_t result{*this, top}; + if (!nodes_count_) + return result; + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + if (config.exact) { + if (!top.reserve(wanted)) + return result.failed("Out of memory!"); + search_exact_(query, metric, predicate, wanted, context); + } else { + next_candidates_t& next = context.next_candidates; + std::size_t expansion = (std::max)(config.expansion, wanted); + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + if (!top.reserve(expansion)) + return result.failed("Out of memory!"); + + std::size_t closest_slot = search_for_one_(query, metric, prefetch, entry_slot_, max_level_, 0, context); + + // For bottom layer we need a more optimized procedure + if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) + return result.failed("Out of memory!"); + } + + top.sort_ascending(); + top.shrink(wanted); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.count = top.size(); + return result; + } + + /** + * @brief Identifies the closest cluster to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] level The index level to target. Higher means lower resolution. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. 
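+ *
+ * Illustrative usage sketch (an assumption for documentation purposes, not taken from the
+ * upstream sources); it presumes an index instance named `index` and the same `query` and
+ * `metric` callables that are passed to `search()`:
+ *
+ *     cluster_result_t result = index.cluster(query, 1, metric);
+ *     // On success, `result.cluster.member` and `result.cluster.distance`
+ *     // describe the closest entry at the requested level.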
+ */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + cluster_result_t cluster( // + value_at&& query, // + std::size_t level, // + metric_at&& metric, // + index_cluster_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const noexcept { + + context_t& context = contexts_[config.thread]; + cluster_result_t result; + if (!nodes_count_) + return result.failed("No clusters to identify"); + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + next_candidates_t& next = context.next_candidates; + std::size_t expansion = config.expansion; + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + + result.cluster.member = at(search_for_one_(query, metric, prefetch, entry_slot_, max_level_, + static_cast(level - 1), context)); + result.cluster.distance = context.measure(query, result.cluster.member, metric); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + (void)predicate; + return result; + } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion + +#pragma region Metadata +#endif + + struct stats_t { + std::size_t nodes{}; + std::size_t edges{}; + std::size_t max_edges{}; + std::size_t allocated_bytes{}; + }; + + stats_t stats() const noexcept { + stats_t result{}; + + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + std::size_t max_edges = node.level() * config_.connectivity + config_.connectivity_base; + std::size_t edges = 0; + for (level_t level = 0; level <= node.level(); ++level) + edges += neighbors_(node, level).size(); + + ++result.nodes; + result.allocated_bytes += node_bytes_(node).size(); + result.edges += edges; + result.max_edges += max_edges; + } + return result; + } + + stats_t stats(std::size_t level) const noexcept { + stats_t result{}; + + std::size_t neighbors_bytes = !level ? pre_.neighbors_base_bytes : pre_.neighbors_bytes; + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + if (static_cast(node.level()) < level) + continue; + + ++result.nodes; + result.edges += neighbors_(node, level).size(); + result.allocated_bytes += node_head_bytes_() + neighbors_bytes; + } + + std::size_t max_edges_per_node = level ? 
config_.connectivity_base : config_.connectivity; + result.max_edges = result.nodes * max_edges_per_node; + return result; + } + + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const noexcept { + + std::size_t head_bytes = node_head_bytes_(); + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + + stats_per_level[0].nodes++; + stats_per_level[0].edges += neighbors_(node, 0).size(); + stats_per_level[0].allocated_bytes += pre_.neighbors_base_bytes + head_bytes; + + level_t node_level = static_cast(node.level()); + for (level_t l = 1; l <= (std::min)(node_level, static_cast(max_level)); ++l) { + stats_per_level[l].nodes++; + stats_per_level[l].edges += neighbors_(node, l).size(); + stats_per_level[l].allocated_bytes += pre_.neighbors_bytes; + } + } + + // The `max_edges` parameter can be inferred from `nodes` + stats_per_level[0].max_edges = stats_per_level[0].nodes * config_.connectivity_base; + for (std::size_t l = 1; l <= max_level; ++l) + stats_per_level[l].max_edges = stats_per_level[l].nodes * config_.connectivity; + + // Aggregate stats across levels + stats_t result{}; + for (std::size_t l = 0; l <= max_level; ++l) + result.nodes += stats_per_level[l].nodes, // + result.edges += stats_per_level[l].edges, // + result.allocated_bytes += stats_per_level[l].allocated_bytes, // + result.max_edges += stats_per_level[l].max_edges; // + + return result; + } + + /** + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. + */ + std::size_t memory_usage(std::size_t allocator_entry_bytes = default_allocator_entry_bytes()) const noexcept { + std::size_t total = 0; + if (!viewed_file_) { + stats_t s = stats(); + total += s.allocated_bytes; + total += s.nodes * allocator_entry_bytes; + } + + // Temporary data-structures, proportional to the number of nodes: + total += limits_.members * sizeof(node_t) + allocator_entry_bytes; + + // Temporary data-structures, proportional to the number of threads: + total += limits_.threads() * sizeof(context_t) + allocator_entry_bytes * 3; + return total; + } + + std::size_t memory_usage_per_node(level_t level) const noexcept { return node_bytes_(level); } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion + +#pragma region Serialization +#endif + + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length() const noexcept { + std::size_t neighbors_length = 0; + for (std::size_t i = 0; i != size(); ++i) + neighbors_length += node_bytes_(node_at_(i).level()) + sizeof(level_t); + return sizeof(index_serialized_header_t) + neighbors_length; + } + + /** + * @brief Saves serialized binary index representation to a stream. 
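+ *
+ * The `output` argument is a callable of the form `bool(void* buffer, std::size_t length)`,
+ * whose `false` return value is treated as a write failure, and `progress` is a callable of
+ * the form `bool(std::size_t processed, std::size_t total)`. A minimal sketch (an
+ * illustration, not from the upstream sources), assuming an index instance `index` and a
+ * `std::vector<byte_t>` named `blob`:
+ *
+ *     index.save_to_stream([&](void* buffer, std::size_t length) {
+ *         blob.insert(blob.end(), (byte_t*)buffer, (byte_t*)buffer + length);
+ *         return true;
+ *     });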
+ */ + template + serialization_result_t save_to_stream(output_callback_at&& output, progress_at&& progress = {}) const noexcept { + + serialization_result_t result; + + // Export some basic metadata + index_serialized_header_t header; + header.size = nodes_count_; + header.connectivity = config_.connectivity; + header.connectivity_base = config_.connectivity_base; + header.max_level = max_level_; + header.entry_slot = entry_slot_; + if (!output(&header, sizeof(header))) + return result.failed("Failed to serialize the header into stream"); + + // Progress status + std::size_t processed = 0; + std::size_t const total = 2 * header.size; + + // Export the number of levels per node + // That is both enough to estimate the overall memory consumption, + // and to be able to estimate the offsets of every entry in the file. + for (std::size_t i = 0; i != header.size; ++i) { + node_t node = node_at_(i); + level_t level = node.level(); + if (!output(&level, sizeof(level))) + return result.failed("Failed to serialize into stream"); + if (!progress(++processed, total)) + return result.failed("Terminated by user"); + } + + // After that dump the nodes themselves + for (std::size_t i = 0; i != header.size; ++i) { + span_bytes_t node_bytes = node_bytes_(node_at_(i)); + if (!output(node_bytes.data(), node_bytes.size())) + return result.failed("Failed to serialize into stream"); + if (!progress(++processed, total)) + return result.failed("Terminated by user"); + } + + return {}; + } + + /** + * @brief Symmetric to `save_from_stream`, pulls data from a stream. + */ + template + serialization_result_t load_from_stream(input_callback_at&& input, progress_at&& progress = {}) noexcept { + + serialization_result_t result; + + // Remove previously stored objects + reset(); + + // Pull basic metadata + index_serialized_header_t header; + if (!input(&header, sizeof(header))) + return result.failed("Failed to pull the header from the stream"); + + // We are loading an empty index, no more work to do + if (!header.size) { + reset(); + return result; + } + + // Allocate some dynamic memory to read all the levels + using levels_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt levels(header.size); + if (!levels) + return result.failed("Out of memory"); + if (!input(levels, header.size * sizeof(level_t))) + return result.failed("Failed to pull nodes levels from the stream"); + + // Submit metadata + config_.connectivity = header.connectivity; + config_.connectivity_base = header.connectivity_base; + pre_ = precompute_(config_); + index_limits_t limits; + limits.members = header.size; + if (!reserve(limits)) { + reset(); + return result.failed("Out of memory"); + } + nodes_count_ = header.size; + max_level_ = static_cast(header.max_level); + entry_slot_ = static_cast(header.entry_slot); + + // Load the nodes + for (std::size_t i = 0; i != header.size; ++i) { + span_bytes_t node_bytes = node_malloc_(levels[i]); + if (!input(node_bytes.data(), node_bytes.size())) { + reset(); + return result.failed("Failed to pull nodes from the stream"); + } + nodes_[i] = node_t{node_bytes.data()}; + if (!progress(i + 1, header.size)) + return result.failed("Terminated by user"); + } + return {}; + } + + template + serialization_result_t save(char const* file_path, progress_at&& progress = {}) const noexcept { + return save(output_file_t(file_path), std::forward(progress)); + } + + template + serialization_result_t load(char const* file_path, progress_at&& progress = {}) noexcept { + return 
load(input_file_t(file_path), std::forward(progress)); + } + + /** + * @brief Saves serialized binary index representation to a file, generally on disk. + */ + template + serialization_result_t save(output_file_t file, progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) + return stream_result; + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(input_file_t file, progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) + return stream_result; + return io_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. 
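+ *
+ * A minimal sketch (an illustration, not from the upstream sources), assuming an index
+ * instance `index` and an already-constructed `memory_mapped_file_t` named `file` that
+ * points at a previously saved index:
+ *
+ *     serialization_result_t result = index.view(std::move(file));
+ *     if (!result)
+ *         return; // handle result.error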
+ */ + template + serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + // Remove previously stored objects + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Pull basic metadata + index_serialized_header_t header; + if (file.size() - offset < sizeof(header)) + return result.failed("File is corrupted and lacks a header"); + std::memcpy(&header, file.data() + offset, sizeof(header)); + + if (!header.size) { + reset(); + return result; + } + + // Precompute offsets of every node, but before that we need to update the configs + // This could have been done with `std::exclusive_scan`, but it's only available from C++17. + using offsets_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt offsets(header.size); + if (!offsets) + return result.failed("Out of memory"); + + config_.connectivity = header.connectivity; + config_.connectivity_base = header.connectivity_base; + pre_ = precompute_(config_); + misaligned_ptr_gt levels{(byte_t*)file.data() + offset + sizeof(header)}; + offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size; + for (std::size_t i = 1; i < header.size; ++i) + offsets[i] = offsets[i - 1] + node_bytes_(levels[i - 1]); + + std::size_t total_bytes = offsets[header.size - 1] + node_bytes_(levels[header.size - 1]); + if (file.size() < total_bytes) { + reset(); + return result.failed("File is corrupted and can't fit all the nodes"); + } + + // Submit metadata and reserve memory + index_limits_t limits; + limits.members = header.size; + if (!reserve(limits)) { + reset(); + return result.failed("Out of memory"); + } + nodes_count_ = header.size; + max_level_ = static_cast(header.max_level); + entry_slot_ = static_cast(header.entry_slot); + + // Rapidly address all the nodes + for (std::size_t i = 0; i != header.size; ++i) { + nodes_[i] = node_t{(byte_t*)file.data() + offsets[i]}; + if (!progress(i + 1, header.size)) + return result.failed("Terminated by user"); + } + viewed_file_ = std::move(file); + return {}; + } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion +#endif + + /** + * @brief Performs compaction on the whole HNSW index, purging some entries + * and links to them, while also generating a more efficient mapping, + * putting the more frequently used entries closer together. + * + * + * Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ + template + void compact( // + values_at&& values, // + metric_at&& metric, // + slot_transition_at&& slot_transition, // + + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}, // + prefetch_at&& prefetch = prefetch_at{}) noexcept { + + // Export all the keys, slots, and levels. + // Partition them with the predicate. + // Sort the allowed entries in descending order of their level. + // Create a new array mapping old slots to the new ones (INT_MAX for deleted items). 
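+        // Illustrative call-site sketch (not taken from the upstream sources): assuming an
+        // index instance `index`, a []-subscriptable accessor `values` addressable by member
+        // iterators, and the same `metric` used during insertion, a compaction call could
+        // look like:
+        //
+        //     index.compact(values, metric, [&](auto key, auto old_slot, auto new_slot) {
+        //         // Update any external key-to-slot bookkeeping here.
+        //     });
+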
+ struct slot_level_t { + compressed_slot_t old_slot; + compressed_slot_t cluster; + level_t level; + }; + using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt slots_and_levels(size()); + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + std::size_t const total = 3 * slots_and_levels.size(); + + // For every bottom level node, determine its parent cluster + executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) { + context_t& context = contexts_[thread_idx]; + std::size_t cluster = search_for_one_( // + values[citerator_at(old_slot)], // + metric, prefetch, // + entry_slot_, max_level_, 0, context); + slots_and_levels[old_slot] = { // + static_cast(old_slot), // + static_cast(cluster), // + node_at_(old_slot).level()}; + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), total); + return do_tasks.load(); + }); + if (!do_tasks.load()) + return; + + // Where the actual permutation happens: + std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const& a, slot_level_t const& b) { + return a.level == b.level ? a.cluster < b.cluster : a.level > b.level; + }); + + using size_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt old_slot_to_new(slots_and_levels.size()); + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) + old_slot_to_new[slots_and_levels[new_slot].old_slot] = new_slot; + + // Erase all the incoming links + buffer_gt reordered_nodes(slots_and_levels.size()); + tape_allocator_t reordered_tape; + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + node_t old_node = node_at_(old_slot); + + std::size_t node_bytes = node_bytes_(old_node.level()); + byte_t* new_data = (byte_t*)reordered_tape.allocate(node_bytes); + node_t new_node{new_data}; + std::memcpy(new_data, old_node.tape(), node_bytes); + + for (level_t level = 0; level <= old_node.level(); ++level) + for (misaligned_ref_gt neighbor : neighbors_(new_node, level)) + neighbor = static_cast(old_slot_to_new[compressed_slot_t(neighbor)]); + + reordered_nodes[new_slot] = new_node; + if (!progress(++processed, total)) + return; + } + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + slot_transition(node_at_(old_slot).ckey(), // + static_cast(old_slot), // + static_cast(new_slot)); + if (!progress(++processed, total)) + return; + } + + nodes_ = std::move(reordered_nodes); + tape_allocator_ = std::move(reordered_tape); + entry_slot_ = old_slot_to_new[entry_slot_]; + } + + /** + * @brief Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. 
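+ *
+ * A minimal call sketch (an illustration, not from the upstream sources), assuming an
+ * index instance `index` and an application-side `is_banned(key)` predicate:
+ *
+ *     index.isolate([&](auto const& member) { return !is_banned(member.key); });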
+ */ + template < // + typename allow_member_at = dummy_predicate_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + void isolate( // + allow_member_at&& allow_member, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + + // Erase all the incoming links + std::size_t nodes_count = size(); + executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) { + node_t node = node_at_(node_idx); + for (level_t level = 0; level <= node.level(); ++level) { + neighbors_ref_t neighbors = neighbors_(node, level); + std::size_t old_size = neighbors.size(); + neighbors.clear(); + for (std::size_t i = 0; i != old_size; ++i) { + compressed_slot_t neighbor_slot = neighbors[i]; + node_t neighbor = node_at_(neighbor_slot); + if (allow_member(member_cref_t{neighbor.ckey(), neighbor_slot})) + neighbors.push_back(neighbor_slot); + } + } + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), nodes_count); + return do_tasks.load(); + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(processed.load(), nodes_count); + } + + private: + inline static precomputed_constants_t precompute_(index_config_t const& config) noexcept { + precomputed_constants_t pre; + pre.inverse_log_connectivity = 1.0 / std::log(static_cast(config.connectivity)); + pre.neighbors_bytes = config.connectivity * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + pre.neighbors_base_bytes = config.connectivity_base * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + return pre; + } + + using span_bytes_t = span_gt; + + inline span_bytes_t node_bytes_(node_t node) const noexcept { return {node.tape(), node_bytes_(node.level())}; } + inline std::size_t node_bytes_(level_t level) const noexcept { + return node_head_bytes_() + node_neighbors_bytes_(level); + } + inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { return node_neighbors_bytes_(node.level()); } + inline std::size_t node_neighbors_bytes_(level_t level) const noexcept { + return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level; + } + + span_bytes_t node_malloc_(level_t level) noexcept { + std::size_t node_bytes = node_bytes_(level); + byte_t* data = (byte_t*)tape_allocator_.allocate(node_bytes); + return data ? 
span_bytes_t{data, node_bytes} : span_bytes_t{}; + } + + node_t node_make_(vector_key_t key, level_t level) noexcept { + span_bytes_t node_bytes = node_malloc_(level); + if (!node_bytes) + return {}; + + std::memset(node_bytes.data(), 0, node_bytes.size()); + node_t node{(byte_t*)node_bytes.data()}; + node.key(key); + node.level(level); + return node; + } + + node_t node_make_copy_(span_bytes_t old_bytes) noexcept { + byte_t* data = (byte_t*)tape_allocator_.allocate(old_bytes.size()); + if (!data) + return {}; + std::memcpy(data, old_bytes.data(), old_bytes.size()); + return node_t{data}; + } + + void node_free_(std::size_t idx) noexcept { + if (viewed_file_) + return; + + node_t& node = nodes_[idx]; + tape_allocator_.deallocate(node.tape(), node_bytes_(node).size()); + node = node_t{}; + } + + inline node_t node_at_(std::size_t idx) const noexcept { return nodes_[idx]; } + inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { return {node.neighbors_tape()}; } + + inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { + return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; + } + + inline neighbors_ref_t neighbors_(node_t node, level_t level) const noexcept { + return level ? neighbors_non_base_(node, level) : neighbors_base_(node); + } + + struct node_lock_t { + nodes_mutexes_t& mutexes; + std::size_t slot; + inline ~node_lock_t() noexcept { mutexes.atomic_reset(slot); } + }; + + inline node_lock_t node_lock_(std::size_t slot) const noexcept { + while (nodes_mutexes_.atomic_set(slot)) + ; + return {nodes_mutexes_, slot}; + } + + template + void connect_node_across_levels_( // + value_at&& value, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t node_slot, std::size_t entry_slot, level_t max_level, level_t target_level, // + index_update_config_t const& config, context_t& context) usearch_noexcept_m { + + // Go down the level, tracking only the closest match + std::size_t closest_slot = search_for_one_( // + value, metric, prefetch, // + entry_slot, max_level, target_level, context); + + // From `target_level` down perform proper extensive search + for (level_t level = (std::min)(target_level, max_level); level >= 0; --level) { + // TODO: Handle out of memory conditions + search_to_insert_(value, metric, prefetch, closest_slot, node_slot, level, config.expansion, context); + closest_slot = connect_new_node_(metric, node_slot, level, context); + reconnect_neighbor_nodes_(metric, node_slot, value, level, context); + } + } + + template + std::size_t connect_new_node_( // + metric_at&& metric, std::size_t new_slot, level_t level, context_t& context) usearch_noexcept_m { + + node_t new_node = node_at_(new_slot); + top_candidates_t& top = context.top_candidates; + + // Outgoing links from `new_slot`: + neighbors_ref_t new_neighbors = neighbors_(new_node, level); + { + usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); + candidates_view_t top_view = refine_(metric, config_.connectivity, top, context); + + for (std::size_t idx = 0; idx != top_view.size(); idx++) { + usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); + usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); + new_neighbors.push_back(top_view[idx].slot); + } + } + + return new_neighbors[0]; + } + + template + void reconnect_neighbor_nodes_( // + metric_at&& metric, std::size_t new_slot, value_at&& value, level_t level, + context_t& 
context) usearch_noexcept_m { + + node_t new_node = node_at_(new_slot); + top_candidates_t& top = context.top_candidates; + neighbors_ref_t new_neighbors = neighbors_(new_node, level); + + // Reverse links from the neighbors: + std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; + for (compressed_slot_t close_slot : new_neighbors) { + if (close_slot == new_slot) + continue; + node_lock_t close_lock = node_lock_(close_slot); + node_t close_node = node_at_(close_slot); + + neighbors_ref_t close_header = neighbors_(close_node, level); + usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption"); + usearch_assert_m(close_slot != new_slot, "Self-loops are impossible"); + usearch_assert_m(level <= close_node.level(), "Linking to missing level"); + + // If `new_slot` is already present in the neighboring connections of `close_slot` + // then no need to modify any connections or run the heuristics. + if (close_header.size() < connectivity_max) { + close_header.push_back(static_cast(new_slot)); + continue; + } + + // To fit a new connection we need to drop an existing one. + top.clear(); + usearch_assert_m((top.reserve(close_header.size() + 1)), "The memory must have been reserved in `add`"); + top.insert_reserved( + {context.measure(value, citerator_at(close_slot), metric), static_cast(new_slot)}); + for (compressed_slot_t successor_slot : close_header) + top.insert_reserved( + {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); + + // Export the results: + close_header.clear(); + candidates_view_t top_view = refine_(metric, connectivity_max, top, context); + for (std::size_t idx = 0; idx != top_view.size(); idx++) + close_header.push_back(top_view[idx].slot); + } + } + + level_t choose_random_level_(std::default_random_engine& level_generator) const noexcept { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -std::log(distribution(level_generator)) * pre_.inverse_log_connectivity; + return (level_t)r; + } + + struct candidates_range_t; + class candidates_iterator_t { + friend struct candidates_range_t; + + index_gt const& index_; + neighbors_ref_t neighbors_; + visits_hash_set_t& visits_; + std::size_t current_; + + candidates_iterator_t& skip_missing() noexcept { + if (!visits_.size()) + return *this; + while (current_ != neighbors_.size()) { + compressed_slot_t neighbor_slot = neighbors_[current_]; + if (visits_.test(neighbor_slot)) + current_++; + else + break; + } + return *this; + } + + public: + using element_t = compressed_slot_t; + using iterator_category = std::forward_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + reference operator*() const noexcept { return slot(); } + candidates_iterator_t(index_gt const& index, neighbors_ref_t neighbors, visits_hash_set_t& visits, + std::size_t progress) noexcept + : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) {} + candidates_iterator_t operator++(int) noexcept { + return candidates_iterator_t(index_, visits_, neighbors_, current_ + 1).skip_missing(); + } + candidates_iterator_t& operator++() noexcept { + ++current_; + skip_missing(); + return *this; + } + bool operator==(candidates_iterator_t const& other) noexcept { return current_ == other.current_; } + bool operator!=(candidates_iterator_t const& other) noexcept { return current_ != other.current_; } + + 
vector_key_t key() const noexcept { return index_->node_at_(slot()).key(); } + compressed_slot_t slot() const noexcept { return neighbors_[current_]; } + friend inline std::size_t get_slot(candidates_iterator_t const& it) noexcept { return it.slot(); } + friend inline vector_key_t get_key(candidates_iterator_t const& it) noexcept { return it.key(); } + }; + + struct candidates_range_t { + index_gt const& index; + neighbors_ref_t neighbors; + visits_hash_set_t& visits; + + candidates_iterator_t begin() const noexcept { + return candidates_iterator_t{index, neighbors, visits, 0}.skip_missing(); + } + candidates_iterator_t end() const noexcept { return {index, neighbors, visits, neighbors.size()}; } + }; + + template + std::size_t search_for_one_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept { + + visits_hash_set_t& visits = context.visits; + visits.clear(); + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(closest_slot), citerator_at(closest_slot + 1)); + + distance_t closest_dist = context.measure(query, citerator_at(closest_slot), metric); + for (level_t level = begin_level; level > end_level; --level) { + bool changed; + do { + changed = false; + node_lock_t closest_lock = node_lock_(closest_slot); + neighbors_ref_t closest_neighbors = neighbors_non_base_(node_at_(closest_slot), level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, closest_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Actual traversal + for (compressed_slot_t candidate_slot : closest_neighbors) { + distance_t candidate_dist = context.measure(query, citerator_at(candidate_slot), metric); + if (candidate_dist < closest_dist) { + closest_dist = candidate_dist; + closest_slot = candidate_slot; + changed = true; + } + } + context.iteration_cycles++; + } while (changed); + } + return closest_slot; + } + + /** + * @brief Traverses a layer of a graph, to find the best place to insert a new node. + * Locks the nodes in the process, assuming other threads are updating neighbors lists. + * @return `true` if procedure succeeded, `false` if run out of memory. 
+ */ + template + bool search_to_insert_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t start_slot, std::size_t new_slot, level_t level, std::size_t top_limit, + context_t& context) noexcept { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + + visits.clear(); + next.clear(); + top.clear(); + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + next.insert_reserved({-radius, static_cast(start_slot)}); + top.insert_reserved({radius, static_cast(start_slot)}); + visits.set(static_cast(start_slot)); + + while (!next.empty()) { + + candidate_t candidacy = next.top(); + if ((-candidacy.distance) > radius && top.size() == top_limit) + break; + + next.pop(); + context.iteration_cycles++; + + compressed_slot_t candidate_slot = candidacy.slot; + if (new_slot == candidate_slot) + continue; + node_t candidate_ref = node_at_(candidate_slot); + node_lock_t candidate_lock = node_lock_(candidate_slot); + neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + // node_lock_t successor_lock = node_lock_(successor_slot); + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + // This will automatically evict poor matches: + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + return true; + } + + /** + * @brief Traverses the @b base layer of a graph, to find a close match. + * Doesn't lock any nodes, assuming read-only simultaneous access. + * @return `true` if procedure succeeded, `false` if run out of memory. 
+ */ + template + bool search_to_find_in_base_( // + value_at&& query, metric_at&& metric, predicate_at&& predicate, prefetch_at&& prefetch, // + std::size_t start_slot, std::size_t expansion, context_t& context) const noexcept { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + std::size_t const top_limit = expansion; + + visits.clear(); + next.clear(); + top.clear(); + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + next.insert_reserved({-radius, static_cast(start_slot)}); + visits.set(static_cast(start_slot)); + + // Don't populate the top list if the predicate is not satisfied + if (is_dummy() || predicate(member_cref_t{node_at_(start_slot).ckey(), start_slot})) + top.insert_reserved({radius, static_cast(start_slot)}); + + while (!next.empty()) { + + candidate_t candidate = next.top(); + if ((-candidate.distance) > radius) + break; + + next.pop(); + context.iteration_cycles++; + + neighbors_ref_t candidate_neighbors = neighbors_base_(node_at_(candidate.slot)); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + if (is_dummy() || + predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + + return true; + } + + /** + * @brief Iterates through all members, without actually touching the index. + */ + template + void search_exact_( // + value_at&& query, metric_at&& metric, predicate_at&& predicate, // + std::size_t count, context_t& context) const noexcept { + + top_candidates_t& top = context.top_candidates; + top.clear(); + top.reserve(count); + for (std::size_t i = 0; i != size(); ++i) { + if (!is_dummy()) + if (!predicate(at(i))) + continue; + + distance_t distance = context.measure(query, citerator_at(i), metric); + top.insert(candidate_t{distance, static_cast(i)}, count); + } + } + + /** + * @brief This algorithm from the original paper implements a heuristic, + * that massively reduces the number of connections a point has, + * to keep only the neighbors, that are from each other. + */ + template + candidates_view_t refine_( // + metric_at&& metric, // + std::size_t needed, top_candidates_t& top, context_t& context) const noexcept { + + top.sort_ascending(); + candidate_t* top_data = top.data(); + std::size_t const top_count = top.size(); + if (top_count < needed) + return {top_data, top_count}; + + std::size_t submitted_count = 1; + std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`. 
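+        // Note: in the loop below a candidate is accepted only if it is no closer to any
+        // already-accepted candidate than it is to the query; otherwise it is treated as
+        // redundant and skipped. Accepted candidates are compacted to the front of `top_data`.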
+ while (submitted_count < needed && consumed_count < top_count) { + candidate_t candidate = top_data[consumed_count]; + bool good = true; + for (std::size_t idx = 0; idx < submitted_count; idx++) { + candidate_t submitted = top_data[idx]; + distance_t inter_result_dist = context.measure( // + citerator_at(candidate.slot), // + citerator_at(submitted.slot), // + metric); + if (inter_result_dist < candidate.distance) { + good = false; + break; + } + } + + if (good) { + top_data[submitted_count] = top_data[consumed_count]; + submitted_count++; + } + consumed_count++; + } + + top.shrink(submitted_count); + return {top_data, submitted_count}; + } +}; + +struct join_result_t { + error_t error{}; + std::size_t intersection_size{}; + std::size_t engagements{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + join_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets + * to perform fast one-to-one matching between two large collections + * of vectors, using approximate nearest neighbors search. + * + * @param[inout] man_to_woman Container to map ::men keys to ::women. + * @param[inout] woman_to_man Container to map ::women keys to ::men. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ +template < // + + typename men_at, // + typename women_at, // + typename men_values_at, // + typename women_values_at, // + typename men_metric_at, // + typename women_metric_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > +static join_result_t join( // + men_at const& men, // + women_at const& women, // + men_values_at const& men_values, // + women_values_at const& women_values, // + men_metric_at&& men_metric, // + women_metric_at&& women_metric, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + if (women.size() < men.size()) + return unum::usearch::join( // + women, men, // + women_values, men_values, // + std::forward(women_metric), std::forward(men_metric), // + + config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); + + join_result_t result; + + // Sanity checks and argument validation: + if (&men == &women) + return result.failed("Can't join with itself, consider copying"); + + if (config.max_proposals == 0) + config.max_proposals = std::log(men.size()) + executor.size(); + + using proposals_count_t = std::uint16_t; + config.max_proposals = (std::min)(men.size(), config.max_proposals); + + using distance_t = typename men_at::distance_t; + using dynamic_allocator_traits_t = typename men_at::dynamic_allocator_traits_t; + using man_key_t = typename men_at::vector_key_t; + using woman_key_t = typename women_at::vector_key_t; + + // Use the `compressed_slot_t` type of the larger collection + using compressed_slot_t = typename women_at::compressed_slot_t; + using compressed_slot_allocator_t = typename dynamic_allocator_traits_t::template 
rebind_alloc; + using proposals_count_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + // Create an atomic queue, as a ring structure, from/to which + // free men will be added/pulled. + std::mutex free_men_mutex{}; + ring_gt free_men; + free_men.reserve(men.size()); + for (std::size_t i = 0; i != men.size(); ++i) + free_men.push(static_cast(i)); + + // We are gonna need some temporary memory. + buffer_gt proposal_counts(men.size()); + buffer_gt man_to_woman_slots(men.size()); + buffer_gt woman_to_man_slots(women.size()); + if (!proposal_counts || !man_to_woman_slots || !woman_to_man_slots) + return result.failed("Can't temporary mappings"); + + compressed_slot_t missing_slot; + std::memset((void*)&missing_slot, 0xFF, sizeof(compressed_slot_t)); + std::memset((void*)man_to_woman_slots.data(), 0xFF, sizeof(compressed_slot_t) * men.size()); + std::memset((void*)woman_to_man_slots.data(), 0xFF, sizeof(compressed_slot_t) * women.size()); + std::memset(proposal_counts.data(), 0, sizeof(proposals_count_t) * men.size()); + + // Define locks, to limit concurrent accesses to `man_to_woman_slots` and `woman_to_man_slots`. + bitset_t men_locks(men.size()), women_locks(women.size()); + if (!men_locks || !women_locks) + return result.failed("Can't allocate locks"); + + std::atomic rounds{0}; + std::atomic engagements{0}; + std::atomic computed_distances{0}; + std::atomic visited_members{0}; + std::atomic atomic_error{nullptr}; + + // Concurrently process all the men + executor.parallel([&](std::size_t thread_idx) { + index_search_config_t search_config; + search_config.expansion = config.expansion; + search_config.exact = config.exact; + search_config.thread = thread_idx; + compressed_slot_t free_man_slot; + + // While there exist a free man who still has a woman to propose to. + while (!atomic_error.load(std::memory_order_relaxed)) { + std::size_t passed_rounds = 0; + std::size_t total_rounds = 0; + { + std::unique_lock pop_lock(free_men_mutex); + if (!free_men.try_pop(free_man_slot)) + // Primary exit path, we have exhausted the list of candidates + break; + passed_rounds = ++rounds; + total_rounds = passed_rounds + free_men.size(); + } + if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) { + atomic_error.store("Terminated by user"); + break; + } + while (men_locks.atomic_set(free_man_slot)) + ; + + proposals_count_t& free_man_proposals = proposal_counts[free_man_slot]; + if (free_man_proposals >= config.max_proposals) + continue; + + // Find the closest woman, to whom this man hasn't proposed yet. 
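+            // Note: after the increment below `free_man_proposals` equals the number of
+            // proposals made so far, so the search requests that many results and
+            // `candidates.back()` is the next-closest woman this man has not proposed to yet
+            // (assuming the top-k results are returned in a stable, sorted order).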
+ ++free_man_proposals; + auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config); + visited_members += candidates.visited_members; + computed_distances += candidates.computed_distances; + if (!candidates) { + atomic_error = candidates.error.release(); + break; + } + + auto match = candidates.back(); + auto woman = match.member; + while (women_locks.atomic_set(woman.slot)) + ; + + compressed_slot_t husband_slot = woman_to_man_slots[woman.slot]; + bool woman_is_free = husband_slot == missing_slot; + if (woman_is_free) { + // Engagement + man_to_woman_slots[free_man_slot] = woman.slot; + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + } else { + distance_t distance_from_husband = women_metric(women_values[woman.slot], men_values[husband_slot]); + distance_t distance_from_candidate = match.distance; + if (distance_from_husband > distance_from_candidate) { + // Break-up + while (men_locks.atomic_set(husband_slot)) + ; + man_to_woman_slots[husband_slot] = missing_slot; + men_locks.atomic_reset(husband_slot); + + // New Engagement + man_to_woman_slots[free_man_slot] = woman.slot; + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + + std::unique_lock push_lock(free_men_mutex); + free_men.push(husband_slot); + } else { + std::unique_lock push_lock(free_men_mutex); + free_men.push(free_man_slot); + } + } + + men_locks.atomic_reset(free_man_slot); + women_locks.atomic_reset(woman.slot); + } + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Export the "slots" into keys: + std::size_t intersection_size = 0; + for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { + compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; + if (woman_slot != missing_slot) { + man_key_t man = men.at(man_slot).key; + woman_key_t woman = women.at(woman_slot).key; + man_to_woman[man] = woman; + woman_to_man[woman] = man; + intersection_size++; + } + } + + // Export stats + result.engagements = engagements; + result.intersection_size = intersection_size; + result.computed_distances = computed_distances; + result.visited_members = visited_members; + return result; +} + +} // namespace usearch +} // namespace unum + +#endif diff --git a/src/inline-thirdparty/usearch/usearch/index_dense.hpp b/src/inline-thirdparty/usearch/usearch/index_dense.hpp new file mode 100644 index 000000000000..58851829d691 --- /dev/null +++ b/src/inline-thirdparty/usearch/usearch/index_dense.hpp @@ -0,0 +1,2022 @@ +#pragma once +#include // `aligned_alloc` + +#include // `std::function` +#include // `std::iota` +#include // `std::thread` +#include // `std::vector` + +#include +#include + +#if defined(USEARCH_DEFINED_CPP17) +#include // `std::shared_mutex` +#endif + +namespace unum { +namespace usearch { + +template class index_dense_gt; + +/** + * @brief The "magic" sequence helps infer the type of the file. + * USearch indexes start with the "usearch" string. + */ +constexpr char const* default_magic() { return "usearch"; } + +using index_dense_head_buffer_t = byte_t[64]; + +static_assert(sizeof(index_dense_head_buffer_t) == 64, "File header should be exactly 64 bytes"); + +/** + * @brief Serialized binary representations of the USearch index start with metadata. + * Metadata is parsed into a `index_dense_head_t`, containing the USearch package version, + * and the properties of the index. + * + * It uses: 13 bytes for file versioning, 22 bytes for structural information = 35 bytes. 
+ * The following 24 bytes contain binary size of the graph, of the vectors, and the checksum, + * leaving 5 bytes at the end vacant. + */ +struct index_dense_head_t { + + // Versioning: + using magic_t = char[7]; + using version_t = std::uint16_t; + + // Versioning: 7 + 2 * 3 = 13 bytes + char const* magic; + misaligned_ref_gt version_major; + misaligned_ref_gt version_minor; + misaligned_ref_gt version_patch; + + // Structural: 4 * 3 = 12 bytes + misaligned_ref_gt kind_metric; + misaligned_ref_gt kind_scalar; + misaligned_ref_gt kind_key; + misaligned_ref_gt kind_compressed_slot; + + // Population: 8 * 3 = 24 bytes + misaligned_ref_gt count_present; + misaligned_ref_gt count_deleted; + misaligned_ref_gt dimensions; + misaligned_ref_gt multi; + + index_dense_head_t(byte_t* ptr) noexcept + : magic((char const*)exchange(ptr, ptr + sizeof(magic_t))), // + version_major(exchange(ptr, ptr + sizeof(version_t))), // + version_minor(exchange(ptr, ptr + sizeof(version_t))), // + version_patch(exchange(ptr, ptr + sizeof(version_t))), // + kind_metric(exchange(ptr, ptr + sizeof(metric_kind_t))), // + kind_scalar(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + kind_key(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + kind_compressed_slot(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + count_present(exchange(ptr, ptr + sizeof(std::uint64_t))), // + count_deleted(exchange(ptr, ptr + sizeof(std::uint64_t))), // + dimensions(exchange(ptr, ptr + sizeof(std::uint64_t))), // + multi(exchange(ptr, ptr + sizeof(bool))) {} +}; + +struct index_dense_head_result_t { + + index_dense_head_buffer_t buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_head_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +struct index_dense_config_t : public index_config_t { + std::size_t expansion_add = default_expansion_add(); + std::size_t expansion_search = default_expansion_search(); + bool exclude_vectors = false; + bool multi = false; + + /** + * @brief Allows you to reduce RAM consumption by avoiding + * reverse-indexing keys-to-vectors, and only keeping + * the vectors-to-keys mappings. + * + * ! This configuration parameter doesn't affect the serialized file, + * ! and is not preserved between runs. Makes sense for small vector + * ! representations that fit ina single cache line. + */ + bool enable_key_lookups = true; + + index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} + + index_dense_config_t(std::size_t c = default_connectivity(), std::size_t ea = default_expansion_add(), + std::size_t es = default_expansion_search()) noexcept + : index_config_t(c), expansion_add(ea ? ea : default_expansion_add()), + expansion_search(es ? 
es : default_expansion_search()) {} +}; + +struct index_dense_clustering_config_t { + std::size_t min_clusters = 0; + std::size_t max_clusters = 0; + enum mode_t { + merge_smallest_k, + merge_closest_k, + } mode = merge_smallest_k; +}; + +struct index_dense_serialization_config_t { + bool exclude_vectors = false; + bool use_64_bit_dimensions = false; +}; + +struct index_dense_copy_config_t : public index_copy_config_t { + bool force_vector_copy = true; + + index_dense_copy_config_t() = default; + index_dense_copy_config_t(index_copy_config_t base) noexcept : index_copy_config_t(base) {} +}; + +struct index_dense_metadata_result_t { + index_dense_serialization_config_t config; + index_dense_head_buffer_t head_buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_metadata_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + index_dense_metadata_result_t() noexcept : config(), head_buffer(), head(head_buffer), error() {} + + index_dense_metadata_result_t(index_dense_metadata_result_t&& other) noexcept + : config(), head_buffer(), head(head_buffer), error(std::move(other.error)) { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + } + + index_dense_metadata_result_t& operator=(index_dense_metadata_result_t&& other) noexcept { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + error = std::move(other.error); + return *this; + } +}; + +/** + * @brief Extracts metadata from a pre-constructed index on disk, + * without loading it or mapping the whole binary file. + */ +inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* file_path) noexcept { + index_dense_metadata_result_t result; + std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); + if (!file) + return result.failed(std::strerror(errno)); + + // Read the header + std::size_t read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + if (std::fseek(file.get(), 0L, SEEK_END) != 0) + return result.failed("Can't infer file size"); + + // Check if it starts with 32-bit + std::size_t const file_size = std::ftell(file.get()); + + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { + if (std::fseek(file.get(), static_cast(offset_if_u32), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? 
"End of file reached!" : std::strerror(errno)); + + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { + if (std::fseek(file.get(), static_cast(offset_if_u64), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + // Check if it starts with 64-bit + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); +} + +/** + * @brief Extracts metadata from a pre-constructed index serialized into an in-memory buffer. + */ +inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_mapped_file_t file, + std::size_t offset = 0) noexcept { + index_dense_metadata_result_t result; + + // Read the header + if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) + return result.failed("End of file reached!"); + + byte_t* const file_data = file.data() + offset; + std::size_t const file_size = file.size() - offset; + std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + // Check if it starts with 32-bit + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { + std::memcpy(&result.head_buffer, file_data + offset_if_u32, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { + std::memcpy(&result.head_buffer, file_data + offset_if_u64, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); +} + +/** + * @brief Oversimplified type-punned index for equidimensional vectors + * with automatic @b down-casting, hardware-specific @b SIMD metrics, + * and ability to @b remove existing vectors, common in Semantic Caching + * applications. + * + * @section Serialization + * + * The serialized binary form of `index_dense_gt` is made up of three parts: + * 1. 
Binary matrix, aka the `.bbin` part, + * 2. Metadata about used metrics, number of used vs free slots, + * 3. The HNSW index in a binary form. + * The first (1.) generally starts with 2 integers - number of rows (vectors) and @b single-byte columns. + * The second (2.) starts with @b "usearch"-magic-string, used to infer the file type on open. + * The third (3.) is implemented by the underlying `index_gt` class. + */ +template // +class index_dense_gt { + public: + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using distance_t = distance_punned_t; + using metric_t = metric_punned_t; + + using member_ref_t = member_ref_gt; + using member_cref_t = member_cref_gt; + + using head_t = index_dense_head_t; + using head_buffer_t = index_dense_head_buffer_t; + using head_result_t = index_dense_head_result_t; + + using serialization_config_t = index_dense_serialization_config_t; + + using dynamic_allocator_t = aligned_allocator_gt; + using tape_allocator_t = memory_mapping_allocator_gt<64>; + + private: + /// @brief Schema: input buffer, bytes in input buffer, output buffer. + using cast_t = std::function; + /// @brief Punned index. + using index_t = index_gt< // + distance_t, vector_key_t, compressed_slot_t, // + dynamic_allocator_t, tape_allocator_t>; + using index_allocator_t = aligned_allocator_gt; + + using member_iterator_t = typename index_t::member_iterator_t; + using member_citerator_t = typename index_t::member_citerator_t; + + /// @brief Punned metric object. + class metric_proxy_t { + index_dense_gt const* index_ = nullptr; + + public: + metric_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + + inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); } + + inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept { + return f(v(a), v(b)); + } + + inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); } + + inline byte_t const* v(member_cref_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline byte_t const* v(member_citerator_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return index_->metric_(a, b); } + }; + + index_dense_config_t config_; + index_t* typed_ = nullptr; + + mutable std::vector cast_buffer_; + struct casts_t { + cast_t from_b1x8; + cast_t from_i8; + cast_t from_f16; + cast_t from_f32; + cast_t from_f64; + + cast_t to_b1x8; + cast_t to_i8; + cast_t to_f16; + cast_t to_f32; + cast_t to_f64; + } casts_; + + /// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks. + metric_t metric_; + + using vectors_tape_allocator_t = memory_mapping_allocator_gt<8>; + /// @brief Allocator for the copied vectors, aligned to widest double-precision scalars. + vectors_tape_allocator_t vectors_tape_allocator_; + + /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. 
+ mutable std::vector vectors_lookup_; + + /// @brief Originally forms and array of integers [0, threads], marking all + mutable std::vector available_threads_; + + /// @brief Mutex, controlling concurrent access to `available_threads_`. + mutable std::mutex available_threads_mutex_; + +#if defined(USEARCH_DEFINED_CPP17) + using shared_mutex_t = std::shared_mutex; +#else + using shared_mutex_t = unfair_shared_mutex_t; +#endif + using shared_lock_t = shared_lock_gt; + using unique_lock_t = std::unique_lock; + + struct key_and_slot_t { + vector_key_t key; + compressed_slot_t slot; + + bool any_slot() const { return slot == default_free_value(); } + static key_and_slot_t any_slot(vector_key_t key) { return {key, default_free_value()}; } + }; + + struct lookup_key_hash_t { + using is_transparent = void; + std::size_t operator()(key_and_slot_t const& k) const noexcept { return std::hash{}(k.key); } + std::size_t operator()(vector_key_t const& k) const noexcept { return std::hash{}(k); } + }; + + struct lookup_key_same_t { + using is_transparent = void; + bool operator()(key_and_slot_t const& a, vector_key_t const& b) const noexcept { return a.key == b; } + bool operator()(vector_key_t const& a, key_and_slot_t const& b) const noexcept { return a == b.key; } + bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept { return a.key == b.key; } + }; + + /// @brief Multi-Map from keys to IDs, and allocated vectors. + flat_hash_multi_set_gt slot_lookup_; + + /// @brief Mutex, controlling concurrent access to `slot_lookup_`. + mutable shared_mutex_t slot_lookup_mutex_; + + /// @brief Ring-shaped queue of deleted entries, to be reused on future insertions. + ring_gt free_keys_; + + /// @brief Mutex, controlling concurrent access to `free_keys_`. + mutable std::mutex free_keys_mutex_; + + /// @brief A constant for the reserved key value, used to mark deleted entries. + vector_key_t free_key_ = default_free_value(); + + public: + using search_result_t = typename index_t::search_result_t; + using cluster_result_t = typename index_t::cluster_result_t; + using add_result_t = typename index_t::add_result_t; + using stats_t = typename index_t::stats_t; + using match_t = typename index_t::match_t; + + index_dense_gt() = default; + index_dense_gt(index_dense_gt&& other) + : config_(std::move(other.config_)), + + typed_(exchange(other.typed_, nullptr)), // + cast_buffer_(std::move(other.cast_buffer_)), // + casts_(std::move(other.casts_)), // + metric_(std::move(other.metric_)), // + + vectors_tape_allocator_(std::move(other.vectors_tape_allocator_)), // + vectors_lookup_(std::move(other.vectors_lookup_)), // + + available_threads_(std::move(other.available_threads_)), // + slot_lookup_(std::move(other.slot_lookup_)), // + free_keys_(std::move(other.free_keys_)), // + free_key_(std::move(other.free_key_)) {} // + + index_dense_gt& operator=(index_dense_gt&& other) { + swap(other); + return *this; + } + + /** + * @brief Swaps the contents of this index with another index. + * @param other The other index to swap with. 
+ */ + void swap(index_dense_gt& other) { + std::swap(config_, other.config_); + + std::swap(typed_, other.typed_); + std::swap(cast_buffer_, other.cast_buffer_); + std::swap(casts_, other.casts_); + std::swap(metric_, other.metric_); + + std::swap(vectors_tape_allocator_, other.vectors_tape_allocator_); + std::swap(vectors_lookup_, other.vectors_lookup_); + + std::swap(available_threads_, other.available_threads_); + std::swap(slot_lookup_, other.slot_lookup_); + std::swap(free_keys_, other.free_keys_); + std::swap(free_key_, other.free_key_); + } + + ~index_dense_gt() { + if (typed_) + typed_->~index_t(); + index_allocator_t{}.deallocate(typed_, 1); + typed_ = nullptr; + } + + /** + * @brief Constructs an instance of ::index_dense_gt. + * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. + * @param[in] config The index configuration (optional). + * @param[in] free_key The key used for freed vectors (optional). + * @return An instance of ::index_dense_gt. + */ + static index_dense_gt make( // + metric_t metric, // + index_dense_config_t config = {}, // + vector_key_t free_key = default_free_value()) { + + scalar_kind_t scalar_kind = metric.scalar_kind(); + std::size_t hardware_threads = std::thread::hardware_concurrency(); + + index_dense_gt result; + result.config_ = config; + result.cast_buffer_.resize(hardware_threads * metric.bytes_per_vector()); + result.casts_ = make_casts_(scalar_kind); + result.metric_ = metric; + result.free_key_ = free_key; + + // Fill the thread IDs. + result.available_threads_.resize(hardware_threads); + std::iota(result.available_threads_.begin(), result.available_threads_.end(), 0ul); + + // Available since C11, but only C++17, so we use the C version. + index_t* raw = index_allocator_t{}.allocate(1); + new (raw) index_t(config); + result.typed_ = raw; + return result; + } + + static index_dense_gt make(char const* path, bool view = false) { + index_dense_metadata_result_t meta = index_dense_metadata_from_path(path); + if (!meta) + return {}; + metric_punned_t metric(meta.head.dimensions, meta.head.kind_metric, meta.head.kind_scalar); + index_dense_gt result = make(metric); + if (!result) + return result; + if (view) + result.view(path); + else + result.load(path); + return result; + } + + explicit operator bool() const { return typed_; } + std::size_t connectivity() const { return typed_->connectivity(); } + std::size_t size() const { return typed_->size() - free_keys_.size(); } + std::size_t capacity() const { return typed_->capacity(); } + std::size_t max_level() const noexcept { return typed_->max_level(); } + index_dense_config_t const& config() const { return config_; } + index_limits_t const& limits() const { return typed_->limits(); } + bool multi() const { return config_.multi; } + + // The metric and its properties + metric_t const& metric() const { return metric_; } + void change_metric(metric_t metric) { metric_ = std::move(metric); } + + scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); } + std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); } + std::size_t scalar_words() const noexcept { return metric_.scalar_words(); } + std::size_t dimensions() const noexcept { return metric_.dimensions(); } + + // Fetching and changing search criteria + std::size_t expansion_add() const { return config_.expansion_add; } + std::size_t expansion_search() const { return config_.expansion_search; } + void change_expansion_add(std::size_t n) { config_.expansion_add = n; } + void 
change_expansion_search(std::size_t n) { config_.expansion_search = n; } + + member_citerator_t cbegin() const { return typed_->cbegin(); } + member_citerator_t cend() const { return typed_->cend(); } + member_citerator_t begin() const { return typed_->begin(); } + member_citerator_t end() const { return typed_->end(); } + member_iterator_t begin() { return typed_->begin(); } + member_iterator_t end() { return typed_->end(); } + + stats_t stats() const { return typed_->stats(); } + stats_t stats(std::size_t level) const { return typed_->stats(level); } + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const { + return typed_->stats(stats_per_level, max_level); + } + + dynamic_allocator_t const& allocator() const { return typed_->dynamic_allocator(); } + vector_key_t const& free_key() const { return free_key_; } + + /** + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. + */ + std::size_t memory_usage() const { + return // + typed_->memory_usage(0) + // + typed_->tape_allocator().total_wasted() + // + typed_->tape_allocator().total_reserved() + // + vectors_tape_allocator_.total_allocated(); + } + + static constexpr std::size_t any_thread() { return std::numeric_limits::max(); } + static constexpr distance_t infinite_distance() { return std::numeric_limits::max(); } + + struct aggregated_distances_t { + std::size_t count = 0; + distance_t mean = infinite_distance(); + distance_t min = infinite_distance(); + distance_t max = infinite_distance(); + }; + + // clang-format off + add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_b1x8); } + add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_i8); } + add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f16); } + add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f32); } + add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f64); } + + search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8); } + search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8); } + search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16); } + search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32); } + 
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64); } + + template search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_b1x8); } + template search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_i8); } + template search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f16); } + template search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f32); } + template search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f64); } + + std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); } + std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); } + std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f16); } + std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f32); } + std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f64); } + + cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_b1x8); } + cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_i8); } + cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f16); } + cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f32); } + cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f64); } + + aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_b1x8); } + aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_i8); } + aggregated_distances_t 
distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f16); } + aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f32); } + aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f64); } + // clang-format on + + /** + * @brief Computes the distance between two managed entities. + * If either key maps into more than one vector, will aggregate results + * exporting the mean, maximum, and minimum values. + */ + aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const { + shared_lock_t lock(slot_lookup_mutex_); + aggregated_distances_t result; + if (!multi()) { + auto a_it = slot_lookup_.find(key_and_slot_t::any_slot(a)); + auto b_it = slot_lookup_.find(key_and_slot_t::any_slot(b)); + bool a_missing = a_it == slot_lookup_.end(); + bool b_missing = b_it == slot_lookup_.end(); + if (a_missing || b_missing) + return result; + + key_and_slot_t a_key_and_slot = *a_it; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; + key_and_slot_t b_key_and_slot = *b_it; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean = result.min = result.max = a_b_distance; + result.count = 1; + return result; + } + + auto a_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(a)); + auto b_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(b)); + bool a_missing = a_range.first == a_range.second; + bool b_missing = b_range.first == b_range.second; + if (a_missing || b_missing) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (a_range.first != a_range.second) { + key_and_slot_t a_key_and_slot = *a_range.first; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; + while (b_range.first != b_range.second) { + key_and_slot_t b_key_and_slot = *b_range.first; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++b_range.first; + } + ++a_range.first; + } + + result.mean /= result.count; + return result; + } + + /** + * @brief Identifies a node in a given `level`, that is the closest to the `key`. + */ + cluster_result_t cluster(vector_key_t key, std::size_t level, std::size_t thread = any_thread()) const { + + // Check if such `key` is even present. + shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + cluster_result_t result; + if (key_range.first == key_range.second) + return result.failed("Key missing!"); + + index_cluster_config_t cluster_config; + thread_lock_t lock = thread_lock_(thread); + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + metric_proxy_t metric{*this}; + auto allow = [free_key_ = this->free_key_](member_cref_t const& member) noexcept { + return member.key != free_key_; + }; + + // Find the closest cluster for any vector under that key. 
+ while (key_range.first != key_range.second) { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const* vector_data = vectors_lookup_[key_and_slot.slot]; + cluster_result_t new_result = typed_->cluster(vector_data, level, metric, cluster_config, allow); + if (!new_result) + return new_result; + if (new_result.cluster.distance < result.cluster.distance) + result = std::move(new_result); + + ++key_range.first; + } + return result; + } + + /** + * @brief Reserves memory for the index and the keyed lookup. + * @return `true` if the memory reservation was successful, `false` otherwise. + */ + bool reserve(index_limits_t limits) { + { + unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.reserve(limits.members); + vectors_lookup_.resize(limits.members); + } + return typed_->reserve(limits); + } + + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + typed_->clear(); + slot_lookup_.clear(); + vectors_lookup_.clear(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + std::unique_lock available_threads_lock(available_threads_mutex_); + typed_->reset(); + slot_lookup_.clear(); + vectors_lookup_.clear(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + + // Reset the thread IDs. + available_threads_.resize(std::thread::hardware_concurrency()); + std::iota(available_threads_.begin(), available_threads_.end(), 0ul); + } + + /** + * @brief Saves serialized binary index representation to a stream. 
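+     *
+     * A minimal usage sketch, assuming `index` is an already-populated instance of this
+     * class and `bytes` is a caller-side buffer; both names are illustrative only. The
+     * output callback receives a pointer and a byte count and returns `true` on success:
+     *
+     *     std::vector<char> bytes;
+     *     auto status = index.save_to_stream([&](void const* chunk, std::size_t length) {
+     *         auto begin = static_cast<char const*>(chunk);
+     *         bytes.insert(bytes.end(), begin, begin + length); // append this chunk
+     *         return true;                                      // report success
+     *     });
+     *     // `status` converts to `false` on failure; `status.error` holds the message.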
+ */ + template + serialization_result_t save_to_stream(output_callback_at&& output, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to put the vectors into the same file + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } else { + std::uint64_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + + // Dump the vectors one after another + for (std::uint64_t i = 0; i != matrix_rows; ++i) { + byte_t* vector = vectors_lookup_[i]; + if (!output(vector, matrix_cols)) + return result.failed("Failed to serialize into stream"); + } + } + + // Augment metadata + { + index_dense_head_buffer_t buffer; + std::memset(buffer, 0, sizeof(buffer)); + index_dense_head_t head{buffer}; + std::memcpy(buffer, default_magic(), std::strlen(default_magic())); + + // Describe software version + using version_t = index_dense_head_t::version_t; + head.version_major = static_cast(USEARCH_VERSION_MAJOR); + head.version_minor = static_cast(USEARCH_VERSION_MINOR); + head.version_patch = static_cast(USEARCH_VERSION_PATCH); + + // Describes types used + head.kind_metric = metric_.metric_kind(); + head.kind_scalar = metric_.scalar_kind(); + head.kind_key = unum::usearch::scalar_kind(); + head.kind_compressed_slot = unum::usearch::scalar_kind(); + + head.count_present = size(); + head.count_deleted = typed_->size() - size(); + head.dimensions = dimensions(); + head.multi = multi(); + + if (!output(&buffer, sizeof(buffer))) + return result.failed("Failed to serialize into stream"); + } + + // Save the actual proximity graph + return typed_->save_to_stream(std::forward(output), std::forward(progress)); + } + + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length(serialization_config_t config = {}) const noexcept { + std::size_t dimensions_length = 0; + std::size_t matrix_length = 0; + if (!config.exclude_vectors) { + dimensions_length = config.use_64_bit_dimensions ? sizeof(std::uint64_t) * 2 : sizeof(std::uint32_t) * 2; + matrix_length = typed_->size() * metric_.bytes_per_vector(); + } + return dimensions_length + matrix_length + sizeof(index_dense_head_buffer_t) + typed_->serialized_length(); + } + + /** + * @brief Parses the index from file to RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. 
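+     *
+     * A matching deserialization sketch, assuming `bytes` holds the output of the
+     * `save_to_stream` example above and `cursor` is a caller-side read offset; both
+     * names are illustrative. The input callback must fill the destination with exactly
+     * `length` bytes and return `true`, or return `false` to abort the load:
+     *
+     *     std::size_t cursor = 0;
+     *     auto status = index.load_from_stream([&](void* destination, std::size_t length) {
+     *         if (cursor + length > bytes.size())
+     *             return false;                               // not enough data left
+     *         std::memcpy(destination, bytes.data() + cursor, length);
+     *         cursor += length;
+     *         return true;
+     *     });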
+ */ + template + serialization_result_t load_from_stream(input_callback_at&& input, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + + // Discard all previous memory allocations of `vectors_tape_allocator_` + reset(); + + // Infer the new index size + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to load the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 32-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } else { + std::uint64_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 64-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + // Load the vectors one after another + vectors_lookup_.resize(matrix_rows); + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) { + byte_t* vector = vectors_tape_allocator_.allocate(matrix_cols); + if (!input(vector, matrix_cols)) + return result.failed("Failed to read vectors"); + vectors_lookup_[slot] = vector; + } + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (!input(buffer, sizeof(buffer))) + return result.failed("Failed to read the index "); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); + } + + // Pull the actual proximity graph + result = typed_->load_from_stream(std::forward(input), std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + + reindex_keys_(); + return result; + } + + /** + * @brief Parses the index from file, without loading it into RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. 
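+     *
+     * A brief sketch using the path-based convenience overload that `make(path, view)`
+     * above relies on; the file name is illustrative only:
+     *
+     *     auto status = index.view("index.usearch"); // maps the file, vectors stay on disk
+     *     // `status` converts to `false` on failure.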
+ */ + template + serialization_result_t view(memory_mapped_file_t file, // + std::size_t offset = 0, serialization_config_t config = {}, // + progress_at&& progress = {}) { + + // Discard all previous memory allocations of `vectors_tape_allocator_` + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Infer the new index size + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + span_punned_t vectors_buffer; + + // We may not want to fetch the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } else { + std::uint64_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } + vectors_buffer = {file.data() + offset, static_cast(matrix_rows * matrix_cols)}; + offset += vectors_buffer.size(); + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (file.size() - offset < sizeof(buffer)) + return result.failed("File is corrupted and lacks a header"); + + std::memcpy(buffer, file.data() + offset, sizeof(buffer)); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); + offset += sizeof(buffer); + } + + // Pull the actual proximity graph + result = typed_->view(std::move(file), offset, std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + + // Address the vectors + vectors_lookup_.resize(matrix_rows); + if (!config.exclude_vectors) + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) + vectors_lookup_[slot] = (byte_t*)vectors_buffer.data() + matrix_cols * slot; + + reindex_keys_(); + return result; + } + + /** + * @brief Saves the index to a file. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for exports. + * @return Outcome descriptor explicitly convertible to boolean. 
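+     *
+     * The simplest invocation goes through the path-based overload declared further
+     * below, which handles opening the file; the file name is illustrative only:
+     *
+     *     auto status = index.save("index.usearch");
+     *     // `status` converts to `false` on failure; `status.error` holds the message.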
+ */ + template + serialization_result_t save(output_file_t file, serialization_config_t config = {}, + progress_at&& progress = {}) const { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const* buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + config, std::forward(progress)); + + if (!stream_result) { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t save(memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + config, std::forward(progress)); + + return stream_result; + } + + /** + * @brief Parses the index from file to RAM. + * @param[in] path The path to the file. + * @param[in] config Configuration parameters for imports. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at&& progress = {}) { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + config, std::forward(progress)); + + if (!stream_result) { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t load(memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + config, std::forward(progress)); + + return stream_result; + } + + template + serialization_result_t save(char const* file_path, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + return save(output_file_t(file_path), config, std::forward(progress)); + } + + template + serialization_result_t load(char const* file_path, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + return load(input_file_t(file_path), config, std::forward(progress)); + } + + /** + * @brief Checks if a vector with specified key is present. + * @return `true` if the key is present in the index, `false` otherwise. 
+ */ + bool contains(vector_key_t key) const { + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.contains(key_and_slot_t::any_slot(key)); + } + + /** + * @brief Count the number of vectors with specified key present. + * @return Zero if nothing is found, a positive integer otherwise. + */ + std::size_t count(vector_key_t key) const { + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.count(key_and_slot_t::any_slot(key)); + } + + struct labeling_result_t { + error_t error{}; + std::size_t completed{}; + + explicit operator bool() const noexcept { return !error; } + labeling_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Removes an entry with the specified key from the index. + * @param[in] key The key of the entry to remove. + * @return The ::labeling_result_t indicating the result of the removal operation. + * If the removal was successful, `result.completed` will be `true`. + * If the key was not found in the index, `result.completed` will be `false`. + * If an error occurred during the removal operation, `result.error` will contain an error message. + */ + labeling_result_t remove(vector_key_t key) { + labeling_result_t result; + + unique_lock_t lookup_lock(slot_lookup_mutex_); + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + if (matching_slots.first == matching_slots.second) + return result; + + // Grow the removed entries ring, if needed + std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second); + std::unique_lock free_lock(free_keys_mutex_); + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + } + slot_lookup_.erase(key); + result.completed = matching_count; + + return result; + } + + /** + * @brief Removes multiple entries with the specified keys from the index. + * @param[in] keys_begin The beginning of the keys range. + * @param[in] keys_end The ending of the keys range. + * @return The ::labeling_result_t indicating the result of the removal operation. + * `result.completed` will contain the number of keys that were successfully removed. + * `result.error` will contain an error message if an error occurred during the removal operation. 
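+     *
+     * A minimal sketch of a batched removal, assuming the keys below were previously
+     * added; the container name and key values are illustrative only:
+     *
+     *     std::vector<vector_key_t> stale_keys = {3, 5, 8};
+     *     auto removed = index.remove(stale_keys.begin(), stale_keys.end());
+     *     // `removed.completed` counts the entries that were actually dropped.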
+ */ + template + labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) { + + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + std::unique_lock free_lock(free_keys_mutex_); + // Grow the removed entries ring, if needed + std::size_t matching_count = 0; + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) + matching_count += slot_lookup_.count(key_and_slot_t::any_slot(*keys_it)); + + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // Remove them one-by-one + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) { + vector_key_t key = *keys_it; + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + matching_count = 0; + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + ++matching_count; + } + + slot_lookup_.erase(key); + result.completed += matching_count; + } + + return result; + } + + /** + * @brief Renames an entry with the specified key to a new key. + * @param[in] from The current key of the entry to rename. + * @param[in] to The new key to assign to the entry. + * @return The ::labeling_result_t indicating the result of the rename operation. + * If the rename was successful, `result.completed` will be `true`. + * If the entry with the current key was not found, `result.completed` will be `false`. + */ + labeling_result_t rename(vector_key_t from, vector_key_t to) { + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + + if (!multi() && slot_lookup_.contains(key_and_slot_t::any_slot(to))) + return result.failed("Renaming impossible, the key is already in use"); + + // The `from` may map to multiple entries + while (true) { + key_and_slot_t key_and_slot_removed; + if (!slot_lookup_.pop_first(key_and_slot_t::any_slot(from), key_and_slot_removed)) + break; + + key_and_slot_t key_and_slot_replacing{to, key_and_slot_removed.slot}; + slot_lookup_.try_emplace(key_and_slot_replacing); // This can't fail + typed_->at(key_and_slot_removed.slot).key = to; + ++result.completed; + } + + return result; + } + + /** + * @brief Exports a range of keys for the vectors present in the index. + * @param[out] keys Pointer to the array where the keys will be exported. + * @param[in] offset The number of keys to skip. Useful for pagination. + * @param[in] limit The maximum number of keys to export, that can fit in ::keys. + */ + void export_keys(vector_key_t* keys, std::size_t offset, std::size_t limit) const { + shared_lock_t lock(slot_lookup_mutex_); + offset = (std::min)(offset, slot_lookup_.size()); + slot_lookup_.for_each([&](key_and_slot_t const& key_and_slot) { + if (offset) + // Skip the first `offset` entries + --offset; + else if (limit) { + *keys = key_and_slot.key; + ++keys; + --limit; + } + }); + } + + struct copy_result_t { + index_dense_gt index; + error_t error; + + explicit operator bool() const noexcept { return !error; } + copy_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Copies the ::index_dense_gt @b with all the data in it. + * @param config The copy configuration (optional). 
+ * @return A copy of the ::index_dense_gt instance. + */ + copy_result_t copy(index_dense_copy_config_t config = {}) const { + copy_result_t result = fork(); + if (!result) + return result; + + auto typed_result = typed_->copy(config); + if (!typed_result) + return result.failed(std::move(typed_result.error)); + + // Export the free (removed) slot numbers + index_dense_gt& copy = result.index; + if (!copy.free_keys_.reserve(free_keys_.size())) + return result.failed(std::move(typed_result.error)); + for (std::size_t i = 0; i != free_keys_.size(); ++i) + copy.free_keys_.push(free_keys_[i]); + + // Allocate buffers and move the vectors themselves + if (!config.force_vector_copy && copy.config_.exclude_vectors) + copy.vectors_lookup_ = vectors_lookup_; + else { + copy.vectors_lookup_.resize(vectors_lookup_.size()); + for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + copy.vectors_lookup_[slot] = copy.vectors_tape_allocator_.allocate(copy.metric_.bytes_per_vector()); + if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.end(), nullptr)) + return result.failed("Out of memory!"); + for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + std::memcpy(copy.vectors_lookup_[slot], vectors_lookup_[slot], metric_.bytes_per_vector()); + } + + copy.slot_lookup_ = slot_lookup_; + *copy.typed_ = std::move(typed_result.index); + return result; + } + + /** + * @brief Copies the ::index_dense_gt model @b without any data. + * @return A similarly configured ::index_dense_gt instance. + */ + copy_result_t fork() const { + copy_result_t result; + index_dense_gt& other = result.index; + + other.config_ = config_; + other.cast_buffer_ = cast_buffer_; + other.casts_ = casts_; + + other.metric_ = metric_; + other.available_threads_ = available_threads_; + other.free_key_ = free_key_; + + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Can't allocate the index"); + + new (raw) index_t(config()); + other.typed_ = raw; + return result; + } + + struct compaction_result_t { + error_t error{}; + std::size_t pruned_edges{}; + + explicit operator bool() const noexcept { return !error; } + compaction_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. 
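+     *
+     * A short sketch, assuming some entries were previously removed with `remove()` and
+     * that the defaulted `dummy_executor_t` is acceptable (single-threaded pruning):
+     *
+     *     auto compaction = index.isolate();
+     *     // `compaction.pruned_edges` counts the links dropped to freed entries.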
+ */ + template + compaction_result_t isolate(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + compaction_result_t result; + std::atomic pruned_edges; + auto disallow = [&](member_cref_t const& member) noexcept { + bool freed = member.key == free_key_; + pruned_edges += freed; + return freed; + }; + typed_->isolate(disallow, std::forward(executor), std::forward(progress)); + result.pruned_edges = pruned_edges; + return result; + } + + class values_proxy_t { + index_dense_gt const* index_; + + public: + values_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + byte_t const* operator[](compressed_slot_t slot) const noexcept { return index_->vectors_lookup_[slot]; } + byte_t const* operator[](member_citerator_t it) const noexcept { return index_->vectors_lookup_[get_slot(it)]; } + }; + + /** + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. + */ + template + compaction_result_t compact(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + compaction_result_t result; + + std::vector new_vectors_lookup(vectors_lookup_.size()); + vectors_tape_allocator_t new_vectors_allocator; + + auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { + byte_t* new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); + byte_t* old_vector = vectors_lookup_[old_slot]; + std::memcpy(new_vector, old_vector, metric_.bytes_per_vector()); + new_vectors_lookup[new_slot] = new_vector; + }; + typed_->compact(values_proxy_t{*this}, metric_proxy_t{*this}, track_slot_change, + std::forward(executor), std::forward(progress)); + vectors_lookup_ = std::move(new_vectors_lookup); + vectors_tape_allocator_ = std::move(new_vectors_allocator); + return result; + } + + template < // + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + join_result_t join( // + index_dense_gt const& women, // + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) const { + + index_dense_gt const& men = *this; + return unum::usearch::join( // + *men.typed_, *women.typed_, // + values_proxy_t{men}, values_proxy_t{women}, // + metric_proxy_t{men}, metric_proxy_t{women}, // + config, // + std::forward(man_to_woman), // + std::forward(woman_to_man), // + std::forward(executor), // + std::forward(progress)); + } + + struct clustering_result_t { + error_t error{}; + std::size_t clusters{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + clustering_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Implements clustering, 
classifying the given objects (vectors of member keys) + * into a given number of clusters. + * + * @param[in] queries_begin Iterator pointing to the first query. + * @param[in] queries_end Iterator pointing to the last query. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + * @param[in] config Configuration parameters for clustering. + * + * @param[out] cluster_keys Pointer to the array where the cluster keys will be exported. + * @param[out] cluster_distances Pointer to the array where the distances to those centroids will be exported. + */ + template < // + typename queries_iterator_at, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + clustering_result_t cluster( // + queries_iterator_at queries_begin, // + queries_iterator_at queries_end, // + index_dense_clustering_config_t config, // + vector_key_t* cluster_keys, // + distance_t* cluster_distances, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) { + + std::size_t const queries_count = queries_end - queries_begin; + + // Find the first level (top -> down) that has enough nodes to exceed `config.min_clusters`. + std::size_t level = max_level(); + if (config.min_clusters) { + for (; level > 1; --level) { + if (stats(level).nodes > config.min_clusters) + break; + } + } else + level = 1, config.max_clusters = stats(1).nodes, config.min_clusters = 2; + + clustering_result_t result; + if (max_level() < 2) + return result.failed("Index too small to cluster!"); + + // A structure used to track the popularity of a specific cluster + struct cluster_t { + vector_key_t centroid; + vector_key_t merged_into; + std::size_t popularity; + byte_t* vector; + }; + + auto centroid_id = [](cluster_t const& a, cluster_t const& b) { return a.centroid < b.centroid; }; + auto higher_popularity = [](cluster_t const& a, cluster_t const& b) { return a.popularity > b.popularity; }; + + std::atomic visited_members(0); + std::atomic computed_distances(0); + std::atomic atomic_error{nullptr}; + + using dynamic_allocator_traits_t = std::allocator_traits; + using clusters_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt clusters(queries_count); + if (!clusters) + return result.failed("Out of memory!"); + + map_to_clusters: + // Concurrently perform search until a certain depth + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + auto result = cluster(queries_begin[query_idx], level, thread_idx); + if (!result) { + atomic_error = result.error.release(); + return false; + } + + cluster_keys[query_idx] = result.cluster.member.key; + cluster_distances[query_idx] = result.cluster.distance; + + // Export in case we need to refine afterwards + clusters[query_idx].centroid = result.cluster.member.key; + clusters[query_idx].vector = vectors_lookup_[result.cluster.member.slot]; + clusters[query_idx].merged_into = free_key(); + clusters[query_idx].popularity = 1; + + visited_members += result.visited_members; + computed_distances += result.computed_distances; + return true; + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Now once we have identified the closest clusters, + // we can try reducing their quantity, refining + std::sort(clusters.begin(), clusters.end(), centroid_id); + + // Transform into run-length encoding, computing the number of unique clusters + std::size_t unique_clusters = 0; + { + 
std::size_t last_idx = 0; + for (std::size_t current_idx = 1; current_idx != clusters.size(); ++current_idx) { + if (clusters[last_idx].centroid == clusters[current_idx].centroid) { + clusters[last_idx].popularity++; + } else { + last_idx++; + clusters[last_idx] = clusters[current_idx]; + } + } + unique_clusters = last_idx + 1; + } + + // In some cases the queries may be co-located, all mapping into the same cluster on that + // level. In that case we refine the granularity and dive deeper into clusters: + if (unique_clusters < config.min_clusters && level > 1) { + level--; + goto map_to_clusters; + } + + std::sort(clusters.data(), clusters.data() + unique_clusters, higher_popularity); + + // If clusters are too numerous, merge the ones that are too close to each other. + std::size_t merge_cycles = 0; + merge_nearby_clusters: + if (unique_clusters > config.max_clusters) { + + cluster_t& merge_source = clusters[unique_clusters - 1]; + std::size_t merge_target_idx = 0; + distance_t merge_distance = std::numeric_limits::max(); + + for (std::size_t candidate_idx = 0; candidate_idx + 1 < unique_clusters; ++candidate_idx) { + distance_t distance = metric_(merge_source.vector, clusters[candidate_idx].vector); + if (distance < merge_distance) { + merge_distance = distance; + merge_target_idx = candidate_idx; + } + } + + merge_source.merged_into = clusters[merge_target_idx].centroid; + clusters[merge_target_idx].popularity += exchange(merge_source.popularity, 0); + + // The target object may have to be swapped a few times to get to optimal position. + while (merge_target_idx && + clusters[merge_target_idx - 1].popularity < clusters[merge_target_idx].popularity) + std::swap(clusters[merge_target_idx - 1], clusters[merge_target_idx]), --merge_target_idx; + + unique_clusters--; + merge_cycles++; + goto merge_nearby_clusters; + } + + // Replace evicted clusters + if (merge_cycles) { + // Sort dropped clusters by name to accelerate future lookups + auto clusters_end = clusters.data() + config.max_clusters + merge_cycles; + std::sort(clusters.data(), clusters_end, centroid_id); + + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + vector_key_t& cluster_key = cluster_keys[query_idx]; + distance_t& cluster_distance = cluster_distances[query_idx]; + + // Recursively trace replacements of that cluster + while (true) { + // To avoid implementing heterogeneous comparisons, lets wrap the `cluster_key` + cluster_t updated_cluster; + updated_cluster.centroid = cluster_key; + updated_cluster = *std::lower_bound(clusters.data(), clusters_end, updated_cluster, centroid_id); + if (updated_cluster.merged_into == free_key()) + break; + cluster_key = updated_cluster.merged_into; + } + + cluster_distance = distance_between(cluster_key, queries_begin[query_idx], thread_idx).mean; + return true; + }); + } + + result.computed_distances = computed_distances; + result.visited_members = visited_members; + + (void)progress; + return result; + } + + private: + struct thread_lock_t { + index_dense_gt const& parent; + std::size_t thread_id; + bool engaged; + + ~thread_lock_t() { + if (engaged) + parent.thread_unlock_(thread_id); + } + }; + + thread_lock_t thread_lock_(std::size_t thread_id) const { + if (thread_id != any_thread()) + return {*this, thread_id, false}; + + available_threads_mutex_.lock(); + thread_id = available_threads_.back(); + available_threads_.pop_back(); + available_threads_mutex_.unlock(); + return {*this, thread_id, true}; + } + + void thread_unlock_(std::size_t 
thread_id) const { + available_threads_mutex_.lock(); + available_threads_.push_back(thread_id); + available_threads_mutex_.unlock(); + } + + template + add_result_t add_( // + vector_key_t key, scalar_at const* vector, // + std::size_t thread, bool force_vector_copy, cast_t const& cast) { + + if (!multi() && contains(key)) + return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + bool copy_vector = !config_.exclude_vectors || force_vector_copy; + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data, copy_vector = true; + } + + // Check if there are some removed entries, whose nodes we can reuse + compressed_slot_t free_slot = default_free_value(); + { + std::unique_lock lock(free_keys_mutex_); + free_keys_.try_pop(free_slot); + } + + // Perform the insertion or the update + bool reuse_node = free_slot != default_free_value(); + auto on_success = [&](member_ref_t member) { + unique_lock_t slot_lock(slot_lookup_mutex_); + slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); + if (copy_vector) { + if (!reuse_node) + vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); + std::memcpy(vectors_lookup_[member.slot], vector_data, metric_.bytes_per_vector()); + } else + vectors_lookup_[member.slot] = (byte_t*)vector_data; + }; + + index_update_config_t update_config; + update_config.thread = lock.thread_id; + update_config.expansion = config_.expansion_add; + + metric_proxy_t metric{*this}; + return reuse_node // + ? 
typed_->update(typed_->iterator_at(free_slot), key, vector_data, metric, update_config, on_success) + : typed_->add(key, vector_data, metric, update_config, on_success); + } + + template + search_result_t search_(scalar_at const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread, + bool exact, cast_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_search_config_t search_config; + search_config.thread = lock.thread_id; + search_config.expansion = config_.expansion_search; + search_config.exact = exact; + + if (std::is_same::type, dummy_predicate_t>::value) { + auto allow = [free_key_ = this->free_key_](member_cref_t const& member) noexcept { + return member.key != free_key_; + }; + return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + } else { + auto allow = [free_key_ = this->free_key_, &predicate](member_cref_t const& member) noexcept { + return member.key != free_key_ && predicate(member.key); + }; + return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + } + } + + template + cluster_result_t cluster_( // + scalar_at const* vector, std::size_t level, // + std::size_t thread, cast_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_cluster_config_t cluster_config; + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + + auto allow = [free_key_ = this->free_key_](member_cref_t const& member) noexcept { + return member.key != free_key_; + }; + return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); + } + + template + aggregated_distances_t distance_between_( // + vector_key_t key, scalar_at const* vector, // + std::size_t thread, cast_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + // Check if such `key` is even present. 
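+        // In a multi-key index several vectors may be stored under the same `key`, so the loop
+        // below aggregates the distance from the query `vector` to every such entry: the number
+        // of matches in `count`, plus their `min`, `max`, and arithmetic `mean`. If the key is
+        // absent, a default-constructed result is returned.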
+ shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + aggregated_distances_t result; + if (key_range.first == key_range.second) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (key_range.first != key_range.second) { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const* a_vector = vectors_lookup_[key_and_slot.slot]; + byte_t const* b_vector = vector_data; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++key_range.first; + } + + result.mean /= result.count; + return result; + } + + void reindex_keys_() { + + // Estimate number of entries first + std::size_t count_total = typed_->size(); + std::size_t count_removed = 0; + for (std::size_t i = 0; i != count_total; ++i) { + member_cref_t member = typed_->at(i); + count_removed += member.key == free_key_; + } + + if (!count_removed && !config_.enable_key_lookups) + return; + + // Pull entries from the underlying `typed_` into either + // into `slot_lookup_`, or `free_keys_` if they are unused. + unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.clear(); + if (config_.enable_key_lookups) + slot_lookup_.reserve(count_total - count_removed); + free_keys_.clear(); + free_keys_.reserve(count_removed); + for (std::size_t i = 0; i != typed_->size(); ++i) { + member_cref_t member = typed_->at(i); + if (member.key == free_key_) + free_keys_.push(static_cast(i)); + else if (config_.enable_key_lookups) + slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), static_cast(i)}); + } + } + + template + std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const& cast) const { + + if (!multi()) { + compressed_slot_t slot; + // Find the matching ID + { + shared_lock_t lock(slot_lookup_mutex_); + auto it = slot_lookup_.find(key_and_slot_t::any_slot(key)); + if (it == slot_lookup_.end()) + return false; + slot = (*it).slot; + } + // Export the entry + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + bool casted = cast(punned_vector, dimensions(), (byte_t*)reconstructed); + if (!casted) + std::memcpy(reconstructed, punned_vector, metric_.bytes_per_vector()); + return true; + } else { + shared_lock_t lock(slot_lookup_mutex_); + auto equal_range_pair = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + std::size_t count_exported = 0; + for (auto begin = equal_range_pair.first; + begin != equal_range_pair.second && count_exported != vectors_limit; ++begin, ++count_exported) { + // + compressed_slot_t slot = (*begin).slot; + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + byte_t* reconstructed_vector = (byte_t*)reconstructed + metric_.bytes_per_vector() * count_exported; + bool casted = cast(punned_vector, dimensions(), reconstructed_vector); + if (!casted) + std::memcpy(reconstructed_vector, punned_vector, metric_.bytes_per_vector()); + } + return count_exported; + } + } + + template static casts_t make_casts_() { + casts_t result; + + result.from_b1x8 = cast_gt{}; + result.from_i8 = cast_gt{}; + result.from_f16 = cast_gt{}; + result.from_f32 = cast_gt{}; + result.from_f64 = cast_gt{}; + + result.to_b1x8 = cast_gt{}; + result.to_i8 = cast_gt{}; + result.to_f16 = cast_gt{}; + result.to_f32 = 
cast_gt{}; + result.to_f64 = cast_gt{}; + + return result; + } + + static casts_t make_casts_(scalar_kind_t scalar_kind) { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return make_casts_(); + case scalar_kind_t::f32_k: return make_casts_(); + case scalar_kind_t::f16_k: return make_casts_(); + case scalar_kind_t::i8_k: return make_casts_(); + case scalar_kind_t::b1x8_k: return make_casts_(); + default: return {}; + } + } +}; + +using index_dense_t = index_dense_gt<>; +using index_dense_big_t = index_dense_gt; + +/** + * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets + * to perform fast one-to-one matching between two large collections + * of vectors, using approximate nearest neighbors search. + * + * @param[inout] man_to_woman Container to map ::first keys to ::second. + * @param[inout] woman_to_man Container to map ::second keys to ::first. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ +template < // + + typename men_key_at, // + typename women_key_at, // + typename men_slot_at, // + typename women_slot_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > +static join_result_t join( // + index_dense_gt const& men, // + index_dense_gt const& women, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + return men.join( // + women, config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); +} + +} // namespace usearch +} // namespace unum diff --git a/src/inline-thirdparty/usearch/usearch/index_plugins.hpp b/src/inline-thirdparty/usearch/usearch/index_plugins.hpp new file mode 100644 index 000000000000..a5539a3a6b04 --- /dev/null +++ b/src/inline-thirdparty/usearch/usearch/index_plugins.hpp @@ -0,0 +1,2317 @@ +#pragma once +#define __STDC_WANT_IEC_60559_TYPES_EXT__ +#include // `_Float16` +#include // `aligned_alloc` + +#include // `std::strncmp` +#include // `std::iota` +#include // `std::thread` +#include // `std::vector` + +#include // `std::atomic` +#include // `std::thread` + +#include // `expected_gt` and macros + +#if !defined(USEARCH_USE_OPENMP) +#define USEARCH_USE_OPENMP 0 +#endif + +#if USEARCH_USE_OPENMP +#include // `omp_get_num_threads()` +#endif + +#if defined(USEARCH_DEFINED_LINUX) +#include // `getauxval()` +#endif + +#if !defined(USEARCH_USE_FP16LIB) +#if defined(__AVX512F__) +#define USEARCH_USE_FP16LIB 0 +#elif defined(USEARCH_DEFINED_ARM) +#include // `__fp16` +#define USEARCH_USE_FP16LIB 0 +#else +#define USEARCH_USE_FP16LIB 1 +#endif +#endif + +#if USEARCH_USE_FP16LIB +#include +#endif + +#if !defined(USEARCH_USE_SIMSIMD) +#define USEARCH_USE_SIMSIMD 0 +#endif + +#if USEARCH_USE_SIMSIMD +// Propagate the `f16` settings +#define SIMSIMD_NATIVE_F16 !USEARCH_USE_FP16LIB +#define SIMSIMD_DYNAMIC_DISPATCH 0 +// No problem, if some of the functions are unused or undefined +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma warning(push) +#pragma warning(disable : 4101) +#include +#pragma warning(pop) +#pragma GCC diagnostic pop +#endif + +namespace unum { +namespace 
usearch { + +using u40_t = uint40_t; +enum b1x8_t : unsigned char {}; + +struct uuid_t { + std::uint8_t octets[16]; +}; + +class f16_bits_t; +class i8_converted_t; + +#if !USEARCH_USE_FP16LIB +#if defined(USEARCH_DEFINED_ARM) +using f16_native_t = __fp16; +#else +using f16_native_t = _Float16; +#endif +using f16_t = f16_native_t; +#else +using f16_native_t = void; +using f16_t = f16_bits_t; +#endif + +using f64_t = double; +using f32_t = float; + +using u64_t = std::uint64_t; +using u32_t = std::uint32_t; +using u16_t = std::uint16_t; +using u8_t = std::uint8_t; + +using i64_t = std::int64_t; +using i32_t = std::int32_t; +using i16_t = std::int16_t; +using i8_t = std::int8_t; + +enum class metric_kind_t : std::uint8_t { + unknown_k = 0, + // Classics: + ip_k = 'i', + cos_k = 'c', + l2sq_k = 'e', + + // Custom: + pearson_k = 'p', + haversine_k = 'h', + divergence_k = 'd', + + // Sets: + jaccard_k = 'j', + hamming_k = 'b', + tanimoto_k = 't', + sorensen_k = 's', +}; + +enum class scalar_kind_t : std::uint8_t { + unknown_k = 0, + // Custom: + b1x8_k = 1, + u40_k = 2, + uuid_k = 3, + // Common: + f64_k = 10, + f32_k = 11, + f16_k = 12, + f8_k = 13, + // Common Integral: + u64_k = 14, + u32_k = 15, + u16_k = 16, + u8_k = 17, + i64_k = 20, + i32_k = 21, + i16_k = 22, + i8_k = 23, +}; + +enum class prefetching_kind_t { + none_k, + cpu_k, + io_uring_k, +}; + +template scalar_kind_t scalar_kind() noexcept { + if (std::is_same()) + return scalar_kind_t::b1x8_k; + if (std::is_same()) + return scalar_kind_t::u40_k; + if (std::is_same()) + return scalar_kind_t::uuid_k; + if (std::is_same()) + return scalar_kind_t::f64_k; + if (std::is_same()) + return scalar_kind_t::f32_k; + if (std::is_same()) + return scalar_kind_t::f16_k; + if (std::is_same()) + return scalar_kind_t::i8_k; + if (std::is_same()) + return scalar_kind_t::u64_k; + if (std::is_same()) + return scalar_kind_t::u32_k; + if (std::is_same()) + return scalar_kind_t::u16_k; + if (std::is_same()) + return scalar_kind_t::u8_k; + if (std::is_same()) + return scalar_kind_t::i64_k; + if (std::is_same()) + return scalar_kind_t::i32_k; + if (std::is_same()) + return scalar_kind_t::i16_k; + if (std::is_same()) + return scalar_kind_t::i8_k; + return scalar_kind_t::unknown_k; +} + +template at angle_to_radians(at angle) noexcept { return angle * at(3.14159265358979323846) / at(180); } + +template at square(at value) noexcept { return value * value; } + +template inline at clamp(at v, at lo, at hi, compare_at comp) noexcept { + return comp(v, lo) ? lo : comp(hi, v) ? 
hi : v; +} +template inline at clamp(at v, at lo, at hi) noexcept { + return usearch::clamp(v, lo, hi, std::less{}); +} + +inline bool str_equals(char const* begin, std::size_t len, char const* other_begin) noexcept { + std::size_t other_len = std::strlen(other_begin); + return len == other_len && std::strncmp(begin, other_begin, len) == 0; +} + +inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::b1x8_k: return 1; + default: return 0; + } +} + +inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::b1x8_k: return 8; + default: return 0; + } +} + +inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::f32_k: return "f32"; + case scalar_kind_t::f16_k: return "f16"; + case scalar_kind_t::f64_k: return "f64"; + case scalar_kind_t::i8_k: return "i8"; + case scalar_kind_t::b1x8_k: return "b1x8"; + default: return ""; + } +} + +inline char const* metric_kind_name(metric_kind_t metric) noexcept { + switch (metric) { + case metric_kind_t::unknown_k: return "unknown"; + case metric_kind_t::ip_k: return "ip"; + case metric_kind_t::cos_k: return "cos"; + case metric_kind_t::l2sq_k: return "l2sq"; + case metric_kind_t::pearson_k: return "pearson"; + case metric_kind_t::haversine_k: return "haversine"; + case metric_kind_t::divergence_k: return "divergence"; + case metric_kind_t::jaccard_k: return "jaccard"; + case metric_kind_t::hamming_k: return "hamming"; + case metric_kind_t::tanimoto_k: return "tanimoto"; + case metric_kind_t::sorensen_k: return "sorensen"; + } + return ""; +} +inline expected_gt scalar_kind_from_name(char const* name, std::size_t len) { + expected_gt parsed; + if (str_equals(name, len, "f32")) + parsed.result = scalar_kind_t::f32_k; + else if (str_equals(name, len, "f64")) + parsed.result = scalar_kind_t::f64_k; + else if (str_equals(name, len, "f16")) + parsed.result = scalar_kind_t::f16_k; + else if (str_equals(name, len, "i8")) + parsed.result = scalar_kind_t::i8_k; + else + parsed.failed("Unknown type, choose: f32, f16, f64, i8"); + return parsed; +} + +inline expected_gt scalar_kind_from_name(char const* name) { + return scalar_kind_from_name(name, std::strlen(name)); +} + +inline expected_gt metric_from_name(char const* name, std::size_t len) { + expected_gt parsed; + if (str_equals(name, len, "l2sq") || str_equals(name, len, "euclidean_sq")) { + parsed.result = metric_kind_t::l2sq_k; + } else if (str_equals(name, len, "ip") || str_equals(name, len, "inner") || str_equals(name, len, "dot")) { + parsed.result = metric_kind_t::ip_k; + } else if (str_equals(name, len, "cos") || str_equals(name, len, "angular")) { + parsed.result = metric_kind_t::cos_k; + } else if (str_equals(name, len, "haversine")) { + parsed.result = metric_kind_t::haversine_k; + } else if (str_equals(name, len, "divergence")) { + parsed.result = metric_kind_t::divergence_k; + } else if (str_equals(name, len, "pearson")) { + parsed.result = metric_kind_t::pearson_k; + } else if (str_equals(name, len, "hamming")) { + parsed.result = metric_kind_t::hamming_k; + } else if (str_equals(name, len, 
"tanimoto")) { + parsed.result = metric_kind_t::tanimoto_k; + } else if (str_equals(name, len, "sorensen")) { + parsed.result = metric_kind_t::sorensen_k; + } else + parsed.failed("Unknown distance, choose: l2sq, ip, cos, haversine, divergence, jaccard, pearson, hamming, " + "tanimoto, sorensen"); + return parsed; +} + +inline expected_gt metric_from_name(char const* name) { + return metric_from_name(name, std::strlen(name)); +} + +inline float f16_to_f32(std::uint16_t u16) noexcept { +#if !USEARCH_USE_FP16LIB + f16_native_t f16; + std::memcpy(&f16, &u16, sizeof(std::uint16_t)); + return float(f16); +#else + return fp16_ieee_to_fp32_value(u16); +#endif +} + +inline std::uint16_t f32_to_f16(float f32) noexcept { +#if !USEARCH_USE_FP16LIB + f16_native_t f16 = f16_native_t(f32); + std::uint16_t u16; + std::memcpy(&u16, &f16, sizeof(std::uint16_t)); + return u16; +#else + return fp16_ieee_from_fp32_value(f32); +#endif +} + +/** + * @brief Numeric type for the IEEE 754 half-precision floating point. + * If hardware support isn't available, falls back to a hardware + * agnostic in-software implementation. + */ +class f16_bits_t { + std::uint16_t uint16_{}; + + public: + inline f16_bits_t() noexcept : uint16_(0) {} + inline f16_bits_t(f16_bits_t&&) = default; + inline f16_bits_t& operator=(f16_bits_t&&) = default; + inline f16_bits_t(f16_bits_t const&) = default; + inline f16_bits_t& operator=(f16_bits_t const&) = default; + + inline operator float() const noexcept { return f16_to_f32(uint16_); } + inline explicit operator bool() const noexcept { return f16_to_f32(uint16_) > 0.5f; } + + inline f16_bits_t(i8_converted_t) noexcept; + inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(static_cast(v))) {} + + inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; } + inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; } + inline f16_bits_t operator*(f16_bits_t other) const noexcept { return {float(*this) * float(other)}; } + inline f16_bits_t operator/(f16_bits_t other) const noexcept { return {float(*this) / float(other)}; } + inline f16_bits_t operator+(float other) const noexcept { return {float(*this) + other}; } + inline f16_bits_t operator-(float other) const noexcept { return {float(*this) - other}; } + inline f16_bits_t operator*(float other) const noexcept { return {float(*this) * other}; } + inline f16_bits_t operator/(float other) const noexcept { return {float(*this) / other}; } + inline f16_bits_t operator+(double other) const noexcept { return {float(*this) + other}; } + inline f16_bits_t operator-(double other) const noexcept { return {float(*this) - other}; } + inline f16_bits_t operator*(double other) const noexcept { return {float(*this) * other}; } + inline f16_bits_t operator/(double other) const noexcept { return {float(*this) / other}; } + + inline f16_bits_t& operator+=(float v) noexcept { + uint16_ = f32_to_f16(v + f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator-=(float v) noexcept { + uint16_ = f32_to_f16(v - f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator*=(float v) noexcept { + uint16_ = f32_to_f16(v * f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator/=(float v) noexcept { + uint16_ = f32_to_f16(v / f16_to_f32(uint16_)); + return *this; + } +}; + +/** + * @brief An 
STL-based executor or a "thread-pool" for parallel execution. + * Isn't efficient for small batches, as it recreates the threads on every call. + */ +class executor_stl_t { + std::size_t threads_count_{}; + + struct jthread_t { + std::thread native_; + + jthread_t() = default; + jthread_t(jthread_t&&) = default; + jthread_t(jthread_t const&) = delete; + template jthread_t(callable_at&& func) : native_([=]() { func(); }) {} + + ~jthread_t() { + if (native_.joinable()) + native_.join(); + } + }; + + public: + /** + * @param threads_count The number of threads to be used for parallel execution. + */ + executor_stl_t(std::size_t threads_count = 0) noexcept + : threads_count_(threads_count ? threads_count : std::thread::hardware_concurrency()) {} + + /** + * @return Maximum number of threads available to the executor. + */ + std::size_t size() const noexcept { return threads_count_; } + + /** + * @brief Executes a fixed number of tasks using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + std::vector threads_pool; + std::size_t tasks_per_thread = tasks; + std::size_t threads_count = (std::min)(threads_count_, tasks); + if (threads_count > 1) { + tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); + for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { + threads_pool.emplace_back([=]() { + for (std::size_t task_idx = thread_idx * tasks_per_thread; + task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread); ++task_idx) + thread_aware_function(thread_idx, task_idx); + }); + } + } + for (std::size_t task_idx = 0; task_idx < (std::min)(tasks, tasks_per_thread); ++task_idx) + thread_aware_function(0, task_idx); + } + + /** + * @brief Executes limited number of tasks using the specified thread-aware function. + * @param tasks The upper bound on the number of tasks. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + std::vector threads_pool; + std::size_t tasks_per_thread = tasks; + std::size_t threads_count = (std::min)(threads_count_, tasks); + std::atomic_bool stop{false}; + if (threads_count > 1) { + tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); + for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { + threads_pool.emplace_back([=, &stop]() { + for (std::size_t task_idx = thread_idx * tasks_per_thread; + task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread) && + !stop.load(std::memory_order_relaxed); + ++task_idx) + if (!thread_aware_function(thread_idx, task_idx)) + stop.store(true, std::memory_order_relaxed); + }); + } + } + for (std::size_t task_idx = 0; + task_idx < (std::min)(tasks, tasks_per_thread) && !stop.load(std::memory_order_relaxed); ++task_idx) + if (!thread_aware_function(0, task_idx)) + stop.store(true, std::memory_order_relaxed); + } + + /** + * @brief Saturates every available thread with the given workload, until they finish. 
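+     * A minimal usage sketch (illustrative only): every thread invokes the callback exactly once
+     * with its own zero-based index, so a caller can reduce into per-thread slots, e.g.
+     *
+     *     executor_stl_t pool(4);
+     *     pool.parallel([&](std::size_t thread_idx) { partial_sums[thread_idx] = sum_chunk(thread_idx); });
+     *
+     * where `partial_sums` and `sum_chunk` are hypothetical names on the caller's side.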
+ * @param thread_aware_function The thread-aware function to be called for each thread index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { + if (threads_count_ == 1) + return thread_aware_function(0); + std::vector threads_pool; + for (std::size_t thread_idx = 1; thread_idx < threads_count_; ++thread_idx) + threads_pool.emplace_back([=]() { thread_aware_function(thread_idx); }); + thread_aware_function(0); + } +}; + +#if USEARCH_USE_OPENMP + +/** + * @brief An OpenMP-based executor or a "thread-pool" for parallel execution. + * Is the preferred implementation, when available, and maximum performance is needed. + */ +class executor_openmp_t { + public: + /** + * @param threads_count The number of threads to be used for parallel execution. + */ + executor_openmp_t(std::size_t threads_count = 0) noexcept { + omp_set_num_threads(static_cast(threads_count ? threads_count : std::thread::hardware_concurrency())); + } + + /** + * @return Maximum number of threads available to the executor. + */ + std::size_t size() const noexcept { return omp_get_max_threads(); } + + /** + * @brief Executes tasks in bulk using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { +#pragma omp parallel for schedule(dynamic, 1) + for (std::size_t i = 0; i != tasks; ++i) { + thread_aware_function(omp_get_thread_num(), i); + } + } + + /** + * @brief Executes tasks in bulk using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + // OpenMP cancellation points are not yet available on most platforms, and require + // the `OMP_CANCELLATION` environment variable to be set. + // http://jakascorner.com/blog/2016/08/omp-cancel.html + // if (omp_get_cancellation()) { + // #pragma omp parallel for schedule(dynamic, 1) + // for (std::size_t i = 0; i != tasks; ++i) { + // #pragma omp cancellation point for + // if (!thread_aware_function(omp_get_thread_num(), i)) { + // #pragma omp cancel for + // } + // } + // } + std::atomic_bool stop{false}; +#pragma omp parallel for schedule(dynamic, 1) shared(stop) + for (std::size_t i = 0; i != tasks; ++i) { + if (!stop.load(std::memory_order_relaxed) && !thread_aware_function(omp_get_thread_num(), i)) + stop.store(true, std::memory_order_relaxed); + } + } + + /** + * @brief Saturates every available thread with the given workload, until they finish. + * @param thread_aware_function The thread-aware function to be called for each thread index. + * @throws If an exception occurs during execution of the thread-aware function. 
+ */ + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { +#pragma omp parallel + { thread_aware_function(omp_get_thread_num()); } + } +}; + +using executor_default_t = executor_openmp_t; + +#else + +using executor_default_t = executor_stl_t; + +#endif + +/** + * @brief Uses OS-specific APIs for aligned memory allocations. + */ +template // +class aligned_allocator_gt { + public: + using value_type = element_at; + using size_type = std::size_t; + using pointer = element_at*; + using const_pointer = element_at const*; + template struct rebind { + using other = aligned_allocator_gt; + }; + + constexpr std::size_t alignment() const { return alignment_ak; } + + pointer allocate(size_type length) const { + std::size_t length_bytes = alignment_ak * divide_round_up(length * sizeof(value_type)); + std::size_t alignment = alignment_ak; + // void* result = nullptr; + // int status = posix_memalign(&result, alignment, length_bytes); + // return status == 0 ? (pointer)result : nullptr; +#if defined(USEARCH_DEFINED_WINDOWS) + return (pointer)_aligned_malloc(length_bytes, alignment); +#else + return (pointer)aligned_alloc(alignment, length_bytes); +#endif + } + + void deallocate(pointer begin, size_type) const { +#if defined(USEARCH_DEFINED_WINDOWS) + _aligned_free(begin); +#else + free(begin); +#endif + } +}; + +using aligned_allocator_t = aligned_allocator_gt<>; + +class page_allocator_t { + public: + static constexpr std::size_t page_size() { return 4096; } + + /** + * @brief Allocates an @b uninitialized block of memory of the specified size. + * @param count_bytes The number of bytes to allocate. + * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. + */ + byte_t* allocate(std::size_t count_bytes) const noexcept { + count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); +#if defined(USEARCH_DEFINED_WINDOWS) + return (byte_t*)(::VirtualAlloc(NULL, count_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)); +#else + return (byte_t*)mmap(NULL, count_bytes, PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); +#endif + } + + void deallocate(byte_t* page_pointer, std::size_t count_bytes) const noexcept { +#if defined(USEARCH_DEFINED_WINDOWS) + ::VirtualFree(page_pointer, 0, MEM_RELEASE); +#else + count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); + munmap(page_pointer, count_bytes); +#endif + } +}; + +/** + * @brief Memory-mapping allocator designed for "alloc many, free at once" usage patterns. + * @b Thread-safe, @b except constructors and destructors. + * + * Using this memory allocator won't affect your overall speed much, as that is not the bottleneck. + * However, it can drastically improve memory usage especially for huge indexes of small vectors. + */ +template class memory_mapping_allocator_gt { + + static constexpr std::size_t min_capacity() { return 1024 * 1024 * 4; } + static constexpr std::size_t capacity_multiplier() { return 2; } + static constexpr std::size_t head_size() { + /// Pointer to the the previous arena and the size of the current one. 
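+        /// In other words, every arena begins with a small aligned header: the first sizeof(byte_t*)
+        /// bytes hold the address of the previously mapped arena (or nullptr), and the following
+        /// sizeof(std::size_t) bytes hold this arena's capacity, which is what `reset()` reads back
+        /// when unwinding and unmapping the whole chain.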
+ return divide_round_up(sizeof(byte_t*) + sizeof(std::size_t)) * alignment_ak; + } + + std::mutex mutex_; + byte_t* last_arena_ = nullptr; + std::size_t last_usage_ = head_size(); + std::size_t last_capacity_ = min_capacity(); + std::size_t wasted_space_ = 0; + + public: + using value_type = byte_t; + using size_type = std::size_t; + using pointer = byte_t*; + using const_pointer = byte_t const*; + + memory_mapping_allocator_gt() = default; + memory_mapping_allocator_gt(memory_mapping_allocator_gt&& other) noexcept + : last_arena_(exchange(other.last_arena_, nullptr)), last_usage_(exchange(other.last_usage_, 0)), + last_capacity_(exchange(other.last_capacity_, 0)), wasted_space_(exchange(other.wasted_space_, 0)) {} + + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt&& other) noexcept { + std::swap(last_arena_, other.last_arena_); + std::swap(last_usage_, other.last_usage_); + std::swap(last_capacity_, other.last_capacity_); + std::swap(wasted_space_, other.wasted_space_); + return *this; + } + + ~memory_mapping_allocator_gt() noexcept { reset(); } + + /** + * @brief Discards all previously allocated memory buffers. + */ + void reset() noexcept { + byte_t* last_arena = last_arena_; + while (last_arena) { + byte_t* previous_arena = nullptr; + std::memcpy(&previous_arena, last_arena, sizeof(byte_t*)); + std::size_t last_cap = 0; + std::memcpy(&last_cap, last_arena + sizeof(byte_t*), sizeof(std::size_t)); + page_allocator_t{}.deallocate(last_arena, last_cap); + last_arena = previous_arena; + } + + // Clear the references: + last_arena_ = nullptr; + last_usage_ = head_size(); + last_capacity_ = min_capacity(); + wasted_space_ = 0; + } + + /** + * @brief Copy constructor. + * @note This is a no-op copy constructor since the allocator is not copyable. + */ + memory_mapping_allocator_gt(memory_mapping_allocator_gt const&) noexcept {} + + /** + * @brief Copy assignment operator. + * @note This is a no-op copy assignment operator since the allocator is not copyable. + * @return Reference to the allocator after the assignment. + */ + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt const&) noexcept { + reset(); + return *this; + } + + /** + * @brief Allocates an @b uninitialized block of memory of the specified size. + * @param count_bytes The number of bytes to allocate. + * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. + */ + inline byte_t* allocate(std::size_t count_bytes) noexcept { + std::size_t extended_bytes = divide_round_up(count_bytes) * alignment_ak; + std::unique_lock lock(mutex_); + if (!last_arena_ || (last_usage_ + extended_bytes >= last_capacity_)) { + std::size_t new_cap = (std::max)(last_capacity_, ceil2(extended_bytes)) * capacity_multiplier(); + byte_t* new_arena = page_allocator_t{}.allocate(new_cap); + if (!new_arena) + return nullptr; + std::memcpy(new_arena, &last_arena_, sizeof(byte_t*)); + std::memcpy(new_arena + sizeof(byte_t*), &new_cap, sizeof(std::size_t)); + + wasted_space_ += total_reserved(); + last_arena_ = new_arena; + last_capacity_ = new_cap; + last_usage_ = head_size(); + } + + wasted_space_ += extended_bytes - count_bytes; + return last_arena_ + exchange(last_usage_, last_usage_ + extended_bytes); + } + + /** + * @brief Returns the amount of memory used by the allocator across all arenas. + * @return The amount of space in bytes. 
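+     * Note that this is an estimate rather than an exact figure: the capacities of earlier arenas
+     * are reconstructed by repeatedly dividing the latest capacity by `capacity_multiplier()` down
+     * to `min_capacity()`, instead of walking the arena headers themselves.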
+ */ + std::size_t total_allocated() const noexcept { + if (!last_arena_) + return 0; + std::size_t total_used = 0; + std::size_t last_capacity = last_capacity_; + do { + total_used += last_capacity; + last_capacity /= capacity_multiplier(); + } while (last_capacity >= min_capacity()); + return total_used; + } + + /** + * @brief Returns the amount of wasted space due to alignment. + * @return The amount of wasted space in bytes. + */ + std::size_t total_wasted() const noexcept { return wasted_space_; } + + /** + * @brief Returns the amount of remaining memory already reserved but not yet used. + * @return The amount of reserved memory in bytes. + */ + std::size_t total_reserved() const noexcept { return last_arena_ ? last_capacity_ - last_usage_ : 0; } + + /** + * @warning The very first memory de-allocation discards all the arenas! + */ + void deallocate(byte_t* = nullptr, std::size_t = 0) noexcept { reset(); } +}; + +using memory_mapping_allocator_t = memory_mapping_allocator_gt<>; + +/** + * @brief C++11 userspace implementation of an oversimplified `std::shared_mutex`, + * that assumes rare interleaving of shared and unique locks. It's not fair, + * but requires only a single 32-bit atomic integer to work. + */ +class unfair_shared_mutex_t { + /** Any positive integer describes the number of concurrent readers */ + enum state_t : std::int32_t { + idle_k = 0, + writing_k = -1, + }; + std::atomic state_{idle_k}; + + public: + inline void lock() noexcept { + std::int32_t raw; + relock: + raw = idle_k; + if (!state_.compare_exchange_weak(raw, writing_k, std::memory_order_acquire, std::memory_order_relaxed)) { + std::this_thread::yield(); + goto relock; + } + } + + inline void unlock() noexcept { state_.store(idle_k, std::memory_order_release); } + + inline void lock_shared() noexcept { + std::int32_t raw; + relock_shared: + raw = state_.load(std::memory_order_acquire); + // Spin while it's uniquely locked + if (raw == writing_k) { + std::this_thread::yield(); + goto relock_shared; + } + // Try incrementing the counter + if (!state_.compare_exchange_weak(raw, raw + 1, std::memory_order_acquire, std::memory_order_relaxed)) { + std::this_thread::yield(); + goto relock_shared; + } + } + + inline void unlock_shared() noexcept { state_.fetch_sub(1, std::memory_order_release); } + + /** + * @brief Try upgrades the current `lock_shared()` to a unique `lock()` state. + */ + inline bool try_escalate() noexcept { + std::int32_t one_read = 1; + return state_.compare_exchange_weak(one_read, writing_k, std::memory_order_acquire, std::memory_order_relaxed); + } + + /** + * @brief Escalates current lock potentially loosing control in the middle. + * It's a shortcut for `try_escalate`-`unlock_shared`-`lock` trio. + */ + inline void unsafe_escalate() noexcept { + if (!try_escalate()) { + unlock_shared(); + lock(); + } + } + + /** + * @brief Upgrades the current `lock_shared()` to a unique `lock()` state. + */ + inline void escalate() noexcept { + while (!try_escalate()) + std::this_thread::yield(); + } + + /** + * @brief De-escalation of a previously escalated state. 
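+     * A minimal usage sketch (illustrative only), for some shared `unfair_shared_mutex_t mutex`:
+     *
+     *     mutex.lock_shared();
+     *     if (needs_write) { mutex.escalate(); write(); mutex.de_escalate(); }
+     *     mutex.unlock_shared();
+     *
+     * where `needs_write` and `write()` stand in for caller-side logic.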
+ */ + inline void de_escalate() noexcept { + std::int32_t one_read = 1; + state_.store(one_read, std::memory_order_release); + } +}; + +template class shared_lock_gt { + mutex_at& mutex_; + + public: + inline explicit shared_lock_gt(mutex_at& m) noexcept : mutex_(m) { mutex_.lock_shared(); } + inline ~shared_lock_gt() noexcept { mutex_.unlock_shared(); } +}; + +/** + * @brief Utility class used to cast arrays of one scalar type to another, + * avoiding unnecessary conversions. + */ +template struct cast_gt { + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + from_scalar_at const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + auto converter = [](from_scalar_at from) { return to_scalar_at(from); }; + std::transform(typed_input, typed_input + dim, typed_output, converter); + return true; + } +}; + +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } +}; + +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } +}; + +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } +}; + +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } +}; + +template <> struct cast_gt { + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } +}; + +template struct cast_gt { + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + from_scalar_at const* typed_input = reinterpret_cast(input); + unsigned char* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + // Converting from scalar types to boolean isn't trivial and depends on the type. + // The most common case is to consider all positive values as `true` and all others as `false`. + // - `bool(0.00001f)` converts to 1 + // - `bool(-0.00001f)` converts to 1 + // - `bool(0)` converts to 0 + // - `bool(-0)` converts to 0 + // - `bool(std::numeric_limits::infinity())` converts to 1 + // - `bool(std::numeric_limits::epsilon())` converts to 1 + // - `bool(std::numeric_limits::signaling_NaN())` converts to 1 + // - `bool(std::numeric_limits::denorm_min())` converts to 1 + typed_output[i / CHAR_BIT] |= bool(typed_input[i] > 0) ? (128 >> (i & (CHAR_BIT - 1))) : 0; + return true; + } +}; + +template struct cast_gt { + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + unsigned char const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + // We can't entirely reconstruct the original scalar type from a boolean. + // The simplest variant would be to map set bits to ones, and unset bits to zeros. + typed_output[i] = bool(typed_input[i / CHAR_BIT] & (128 >> (i & (CHAR_BIT - 1)))); + return true; + } +}; + +/** + * @brief Numeric type for uniformly-distributed floating point + * values within [-1,1] range, quantized to integers [-100,100]. + */ +class i8_converted_t { + std::int8_t int8_{}; + + public: + constexpr static f32_t divisor_k = 100.f; + constexpr static std::int8_t min_k = -100; + constexpr static std::int8_t max_k = 100; + + inline i8_converted_t() noexcept : int8_(0) {} + inline i8_converted_t(bool v) noexcept : int8_(v ? 
max_k : 0) {} + + inline i8_converted_t(i8_converted_t&&) = default; + inline i8_converted_t& operator=(i8_converted_t&&) = default; + inline i8_converted_t(i8_converted_t const&) = default; + inline i8_converted_t& operator=(i8_converted_t const&) = default; + + inline operator f16_t() const noexcept { return static_cast(f32_t(int8_) / divisor_k); } + inline operator f32_t() const noexcept { return f32_t(int8_) / divisor_k; } + inline operator f64_t() const noexcept { return f64_t(int8_) / divisor_k; } + inline explicit operator bool() const noexcept { return int8_ > (max_k / 2); } + inline explicit operator std::int8_t() const noexcept { return int8_; } + inline explicit operator std::int16_t() const noexcept { return int8_; } + inline explicit operator std::int32_t() const noexcept { return int8_; } + inline explicit operator std::int64_t() const noexcept { return int8_; } + + inline i8_converted_t(f16_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + inline i8_converted_t(f32_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + inline i8_converted_t(f64_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} +}; + +f16_bits_t::f16_bits_t(i8_converted_t v) noexcept : uint16_(f32_to_f16(v)) {} + +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; + +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; +template <> struct cast_gt : public cast_gt {}; + +/** + * @brief Inner (Dot) Product distance. + */ +template struct metric_ip_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) + ab += result_t(a[i]) * result_t(b[i]); + return 1 - ab; + } +}; + +/** + * @brief Cosine (Angular) distance. + * Identical to the Inner Product of normalized vectors. + * Unless you are running on an tiny embedded platform, this metric + * is recommended over `::metric_ip_gt` for low-precision scalars. + */ +template struct metric_cos_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}, a2{}, b2{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab, a2, b2) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab += ai * bi, a2 += square(ai), b2 += square(bi); + } + + result_t result_if_zero[2][2]; + result_if_zero[0][0] = 1 - ab / (std::sqrt(a2) * std::sqrt(b2)); + result_if_zero[0][1] = result_if_zero[1][0] = 1; + result_if_zero[1][1] = 0; + return result_if_zero[a2 == 0][b2 == 0]; + } +}; + +/** + * @brief Squared Euclidean (L2) distance. + * Square root is avoided at the end, as it won't affect the ordering. 
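+ * For instance, with a = {1, 2} and b = {4, 6} the returned value is (1-4)^2 + (2-6)^2 = 25,
+ * whereas the true Euclidean distance would be 5; both induce the same nearest-neighbor ordering.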
+ */ +template struct metric_l2sq_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab_deltas_sq{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab_deltas_sq) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab_deltas_sq += square(ai - bi); + } + return ab_deltas_sq; + } +}; + +/** + * @brief Hamming distance computes the number of differing bits in + * two arrays of integers. An example would be a textual document, + * tokenized and hashed into a fixed-capacity bitset. + */ +template struct metric_hamming_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Hamming distance requires unsigned integral words"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t matches{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : matches) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != words; ++i) + matches += std::bitset(a[i] ^ b[i]).count(); + return matches; + } +}; + +/** + * @brief Tanimoto distance is the intersection over bitwise union. + * Often used in chemistry and biology to compare molecular fingerprints. + */ +template struct metric_tanimoto_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Tanimoto distance requires unsigned integral words"); + static_assert(std::is_floating_point::value, "Tanimoto distance will be a fraction"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t and_count{}; + result_t or_count{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : and_count, or_count) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + or_count += std::bitset(a[i] | b[i]).count(); + } + return 1 - result_t(and_count) / or_count; + } +}; + +/** + * @brief Sorensen-Dice or F1 distance is the intersection over bitwise union. + * Often used in chemistry and biology to compare molecular fingerprints. 
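+ * For instance, for single-word fingerprints a = 0b1100 and b = 0b1010 the intersection has one
+ * set bit and the operands have two set bits each, giving a distance of 1 - 2*1/(2+2) = 0.5.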
+ */ +template struct metric_sorensen_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Sorensen-Dice distance requires unsigned integral words"); + static_assert(std::is_floating_point::value, "Sorensen-Dice distance will be a fraction"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t and_count{}; + result_t any_count{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : and_count, any_count) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + any_count += std::bitset(a[i]).count() + std::bitset(b[i]).count(); + } + return 1 - 2 * result_t(and_count) / any_count; + } +}; + +/** + * @brief Counts the number of matching elements in two unique sorted sets. + * Can be used to compute the similarity between two textual documents + * using the IDs of tokens present in them. + * Similar to `metric_tanimoto_gt` for dense representations. + */ +template struct metric_jaccard_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert(!std::is_floating_point::value, "Jaccard distance requires integral scalars"); + + inline result_t operator()( // + scalar_t const* a, scalar_t const* b, std::size_t a_length, std::size_t b_length) const noexcept { + result_t intersection{}; + std::size_t i{}; + std::size_t j{}; + while (i != a_length && j != b_length) { + intersection += a[i] == b[j]; + i += a[i] < b[j]; + j += a[i] >= b[j]; + } + return 1 - intersection / (a_length + b_length - intersection); + } +}; + +/** + * @brief Measures Pearson Correlation between two sequences in a single pass. + */ +template struct metric_pearson_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + // The correlation coefficient can't be defined for one or zero-dimensional data. + if (dim <= 1) + return 0; + // Conventional Pearson Correlation Coefficient definiton subtracts the mean value of each + // sequence from each element, before dividing them. WikiPedia article suggests a convenient + // single-pass algorithm for calculating sample correlations, though depending on the numbers + // involved, it can sometimes be numerically unstable. + result_t a_sum{}, b_sum{}, ab_sum{}; + result_t a_sq_sum{}, b_sq_sum{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : a_sum, b_sum, ab_sum, a_sq_sum, b_sq_sum) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + a_sum += ai; + b_sum += bi; + ab_sum += ai * bi; + a_sq_sum += ai * ai; + b_sq_sum += bi * bi; + } + result_t denom = (dim * a_sq_sum - a_sum * a_sum) * (dim * b_sq_sum - b_sum * b_sum); + if (denom == 0) + return 0; + result_t corr = dim * ab_sum - a_sum * b_sum; + denom = std::sqrt(denom); + // The normal Pearson correlation value is between -1 and 1, but we are looking for a distance. + // So instead of returning `corr / denom`, we return `1 - corr / denom`. 
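+        // As a consequence, perfectly correlated sequences map to a distance of 0, uncorrelated
+        // ones to roughly 1, and perfectly anti-correlated ones to 2.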
+ return 1 - corr / denom; + } +}; + +/** + * @brief Measures Jensen-Shannon Divergence between two probability distributions. + */ +template struct metric_divergence_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* p, scalar_t const* q, std::size_t dim) const noexcept { + result_t kld_pm{}, kld_qm{}; + result_t epsilon = std::numeric_limits::epsilon(); +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : kld_pm, kld_qm) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t pi = static_cast(p[i]); + result_t qi = static_cast(q[i]); + result_t mi = (pi + qi) / 2 + epsilon; + kld_pm += pi * std::log((pi + epsilon) / mi); + kld_qm += qi * std::log((qi + epsilon) / mi); + } + return (kld_pm + kld_qm) / 2; + } +}; + +struct cos_i8_t { + using scalar_t = i8_t; + using result_t = f32_t; + + inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { + std::int32_t ab{}, a2{}, b2{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab, a2, b2) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; i++) { + std::int16_t ai{a[i]}; + std::int16_t bi{b[i]}; + ab += ai * bi; + a2 += square(ai); + b2 += square(bi); + } + result_t a2f = std::sqrt(static_cast(a2)); + result_t b2f = std::sqrt(static_cast(b2)); + return (ab != 0) ? (1.f - ab / (a2f * b2f)) : 0; + } +}; + +struct l2sq_i8_t { + using scalar_t = i8_t; + using result_t = f32_t; + + inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { + std::int32_t ab_deltas_sq{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab_deltas_sq) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; i++) + ab_deltas_sq += square(std::int16_t(a[i]) - std::int16_t(b[i])); + return static_cast(ab_deltas_sq); + } +}; + +/** + * @brief Haversine distance for the shortest distance between two nodes on + * the surface of a 3D sphere, defined with latitude and longitude. + */ +template struct metric_haversine_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert(!std::is_integral::value, "Latitude and longitude must be floating-node"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t = 2) const noexcept { + result_t lat_a = a[0], lon_a = a[1]; + result_t lat_b = b[0], lon_b = b[1]; + + result_t lat_delta = angle_to_radians(lat_b - lat_a) / 2; + result_t lon_delta = angle_to_radians(lon_b - lon_a) / 2; + + result_t converted_lat_a = angle_to_radians(lat_a); + result_t converted_lat_b = angle_to_radians(lat_b); + + result_t x = square(std::sin(lat_delta)) + // + std::cos(converted_lat_a) * std::cos(converted_lat_b) * square(std::sin(lon_delta)); + + return 2 * std::asin(std::sqrt(x)); + } +}; + +using distance_punned_t = float; +using span_punned_t = span_gt; + +/** + * @brief The signature of the user-defined function. + * Can be just two array pointers, precompiled for a specific array length, + * or include one or two array sizes as 64-bit unsigned integers. 
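+ * Roughly, the three shapes correspond to the following function-pointer sketches, where the
+ * argument and state types are shown loosely (the implementation passes them as pointer-sized
+ * integers):
+ *
+ *     array_array_k:       distance_punned_t (*)(void const* a, void const* b);
+ *     array_array_size_k:  distance_punned_t (*)(void const* a, void const* b, std::size_t words);
+ *     array_array_state_k: distance_punned_t (*)(void const* a, void const* b, std::uintptr_t state);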
+ */ +enum class metric_punned_signature_t { + array_array_k = 0, + array_array_size_k, + array_array_state_k, +}; + +/** + * @brief Type-punned metric class, which unlike STL's `std::function` avoids any memory allocations. + * It also provides additional APIs to check, if SIMD hardware-acceleration is available. + * Wraps the `simsimd_metric_punned_t` when available. The auto-vectorized backend otherwise. + */ +class metric_punned_t { + public: + using scalar_t = byte_t; + using result_t = distance_punned_t; + + private: + /// In the generalized function API all the are arguments are pointer-sized. + using uptr_t = std::size_t; + /// Distance function that takes two arrays and returns a scalar. + using metric_array_array_t = result_t (*)(uptr_t, uptr_t); + /// Distance function that takes two arrays and their length and returns a scalar. + using metric_array_array_size_t = result_t (*)(uptr_t, uptr_t, uptr_t); + /// Distance function that takes two arrays and some callback state and returns a scalar. + using metric_array_array_state_t = result_t (*)(uptr_t, uptr_t, uptr_t); + /// Distance function callback, like `metric_array_array_size_t`, but depends on member variables. + using metric_rounted_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const; + + metric_rounted_t metric_routed_ = nullptr; + uptr_t metric_ptr_ = 0; + uptr_t metric_third_arg_ = 0; + + std::size_t dimensions_ = 0; + metric_kind_t metric_kind_ = metric_kind_t::unknown_k; + scalar_kind_t scalar_kind_ = scalar_kind_t::unknown_k; + +#if USEARCH_USE_SIMSIMD + simsimd_capability_t isa_kind_ = simsimd_cap_serial_k; +#endif + + public: + /** + * @brief Computes the distance between two vectors of fixed length. + * + * ! This is the only relevant function in the object. Everything else is just dynamic dispatch logic. + */ + inline result_t operator()(byte_t const* a, byte_t const* b) const noexcept { + return (this->*metric_routed_)(reinterpret_cast(a), reinterpret_cast(b)); + } + + inline metric_punned_t() noexcept = default; + inline metric_punned_t(metric_punned_t const&) noexcept = default; + inline metric_punned_t& operator=(metric_punned_t const&) noexcept = default; + + inline metric_punned_t(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k, + scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept + : metric_punned_t(builtin(dimensions, metric_kind, scalar_kind)) {} + + inline metric_punned_t(std::size_t dimensions, std::uintptr_t metric_uintptr, metric_punned_signature_t signature, + metric_kind_t metric_kind, scalar_kind_t scalar_kind) noexcept + : metric_punned_t(stateless(dimensions, metric_uintptr, signature, metric_kind, scalar_kind)) {} + + /** + * @brief Creates a metric of a natively supported kind, choosing the best + * available backend internally or from SimSIMD. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t builtin(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k, + scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept { + metric_punned_t metric; + metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = 0; + metric.metric_third_arg_ = + scalar_kind == scalar_kind_t::b1x8_k ? 
divide_round_up(dimensions) : dimensions; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + +#if USEARCH_USE_SIMSIMD + if (!metric.configure_with_simsimd()) + metric.configure_with_autovec(); +#else + metric.configure_with_autovec(); +#endif + + return metric; + } + + /** + * @brief Creates a metric using the provided function pointer for a stateless metric. + * So the provided ::metric_uintptr is a pointer to a function that takes two arrays + * and returns a scalar. If the ::signature is metric_punned_signature_t::array_array_size_k, + * then the third argument is the number of scalar words in the input vectors. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_uintptr The function pointer to the metric function. + * @param signature The signature of the metric function. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t stateless(std::size_t dimensions, std::uintptr_t metric_uintptr, + metric_punned_signature_t signature, metric_kind_t metric_kind, + scalar_kind_t scalar_kind) noexcept { + metric_punned_t metric; + metric.metric_routed_ = signature == metric_punned_signature_t::array_array_k + ? &metric_punned_t::invoke_array_array + : &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = metric_uintptr; + metric.metric_third_arg_ = + scalar_kind == scalar_kind_t::b1x8_k ? divide_round_up(dimensions) : dimensions; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + return metric; + } + + /** + * @brief Creates a metric using the provided function pointer for a statefull metric. + * The third argument is the state that will be passed to the metric function. + * + * @param metric_uintptr The function pointer to the metric function. + * @param metric_state The state to pass to the metric function. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t statefull(std::uintptr_t metric_uintptr, std::uintptr_t metric_state, + metric_kind_t metric_kind = metric_kind_t::unknown_k, + scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept { + metric_punned_t metric; + metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = metric_uintptr; + metric.metric_third_arg_ = metric_state; + metric.dimensions_ = 0; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + return metric; + } + + inline std::size_t dimensions() const noexcept { return dimensions_; } + inline metric_kind_t metric_kind() const noexcept { return metric_kind_; } + inline scalar_kind_t scalar_kind() const noexcept { return scalar_kind_; } + inline explicit operator bool() const noexcept { return metric_routed_ && metric_ptr_; } + + /** + * @brief Checks fi we've failed to initialized the metric with provided arguments. + * + * It's different from `operator bool()` when it comes to explicitly uninitialized metrics. + * It's a common case, where a NULL state is created only to be overwritten later, when + * we recover an old index state from a file or a network. 
+ */ + inline bool missing() const noexcept { return !bool(*this) && metric_kind_ != metric_kind_t::unknown_k; } + + inline char const* isa_name() const noexcept { + if (!*this) + return "uninitialized"; + +#if USEARCH_USE_SIMSIMD + switch (isa_kind_) { + case simsimd_cap_serial_k: return "serial"; + case simsimd_cap_neon_k: return "neon"; + case simsimd_cap_sve_k: return "sve"; + case simsimd_cap_haswell_k: return "haswell"; + case simsimd_cap_skylake_k: return "skylake"; + case simsimd_cap_ice_k: return "ice"; + case simsimd_cap_sapphire_k: return "sapphire"; + default: return "unknown"; + } +#endif + return "serial"; + } + + inline std::size_t bytes_per_vector() const noexcept { + return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_)); + } + + inline std::size_t scalar_words() const noexcept { + return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_), bits_per_scalar_word(scalar_kind_)); + } + + private: +#if USEARCH_USE_SIMSIMD + bool configure_with_simsimd(simsimd_capability_t simd_caps) noexcept { + simsimd_metric_kind_t kind = simsimd_metric_unknown_k; + simsimd_datatype_t datatype = simsimd_datatype_unknown_k; + simsimd_capability_t allowed = simsimd_cap_any_k; + switch (metric_kind_) { + case metric_kind_t::ip_k: kind = simsimd_metric_dot_k; break; + case metric_kind_t::cos_k: kind = simsimd_metric_cos_k; break; + case metric_kind_t::l2sq_k: kind = simsimd_metric_l2sq_k; break; + case metric_kind_t::hamming_k: kind = simsimd_metric_hamming_k; break; + case metric_kind_t::tanimoto_k: kind = simsimd_metric_jaccard_k; break; + case metric_kind_t::jaccard_k: kind = simsimd_metric_jaccard_k; break; + default: break; + } + switch (scalar_kind_) { + case scalar_kind_t::f32_k: datatype = simsimd_datatype_f32_k; break; + case scalar_kind_t::f64_k: datatype = simsimd_datatype_f64_k; break; + case scalar_kind_t::f16_k: datatype = simsimd_datatype_f16_k; break; + case scalar_kind_t::i8_k: datatype = simsimd_datatype_i8_k; break; + case scalar_kind_t::b1x8_k: datatype = simsimd_datatype_b8_k; break; + default: break; + } + simsimd_metric_punned_t simd_metric = NULL; + simsimd_capability_t simd_kind = simsimd_cap_any_k; + simsimd_find_metric_punned(kind, datatype, simd_caps, allowed, &simd_metric, &simd_kind); + if (simd_metric == nullptr) + return false; + + std::memcpy(&metric_ptr_, &simd_metric, sizeof(simd_metric)); + metric_routed_ = metric_kind_ == metric_kind_t::ip_k + ? reinterpret_cast(&metric_punned_t::invoke_simsimd_reverse) + : reinterpret_cast(&metric_punned_t::invoke_simsimd); + isa_kind_ = simd_kind; + return true; + } + bool configure_with_simsimd() noexcept { + static simsimd_capability_t static_capabilities = simsimd_capabilities(); + return configure_with_simsimd(static_capabilities); + } + result_t invoke_simsimd(uptr_t a, uptr_t b) const noexcept { + simsimd_distance_t result; + // Here `reinterpret_cast` raises warning... we know what we are doing! 
+ auto function_pointer = (simsimd_metric_punned_t)(metric_ptr_); + function_pointer(reinterpret_cast(a), reinterpret_cast(b), metric_third_arg_, + &result); + return (result_t)result; + } + result_t invoke_simsimd_reverse(uptr_t a, uptr_t b) const noexcept { return 1 - invoke_simsimd(a, b); } +#else + bool configure_with_simsimd() noexcept { return false; } +#endif + result_t invoke_array_array_third(uptr_t a, uptr_t b) const noexcept { + auto function_pointer = (metric_array_array_size_t)(metric_ptr_); + result_t result = function_pointer(a, b, metric_third_arg_); + return result; + } + result_t invoke_array_array(uptr_t a, uptr_t b) const noexcept { + auto function_pointer = (metric_array_array_t)(metric_ptr_); + result_t result = function_pointer(a, b); + return result; + } + void configure_with_autovec() noexcept { + switch (metric_kind_) { + case metric_kind_t::ip_k: { + switch (scalar_kind_) { + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::cos_k: { + switch (scalar_kind_) { + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::l2sq_k: { + switch (scalar_kind_) { + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::pearson_k: { + switch (scalar_kind_) { + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::haversine_k: { + switch (scalar_kind_) { + case scalar_kind_t::f16_k: + metric_ptr_ = (uptr_t)&equidimensional_>; + break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::divergence_k: { + switch (scalar_kind_) { + case scalar_kind_t::f16_k: + metric_ptr_ = (uptr_t)&equidimensional_>; + break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::jaccard_k: // Equivalent to Tanimoto + case metric_kind_t::tanimoto_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case metric_kind_t::hamming_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case metric_kind_t::sorensen_k: metric_ptr_ = (uptr_t)&equidimensional_>; 
break; + default: return; + } + } + + template + inline static result_t equidimensional_(uptr_t a, uptr_t b, uptr_t a_dimensions) noexcept { + using scalar_t = typename typed_at::scalar_t; + return static_cast(typed_at{}((scalar_t const*)a, (scalar_t const*)b, a_dimensions)); + } +}; + +/** + * @brief View over a potentially-strided memory buffer, containing a row-major matrix. + */ +template // +class vectors_view_gt { + using scalar_t = scalar_at; + + scalar_t const* begin_{}; + std::size_t dimensions_{}; + std::size_t count_{}; + std::size_t stride_bytes_{}; + + public: + vectors_view_gt() noexcept = default; + vectors_view_gt(vectors_view_gt const&) noexcept = default; + vectors_view_gt& operator=(vectors_view_gt const&) noexcept = default; + + vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count = 1) noexcept + : vectors_view_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {} + + vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept + : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) {} + + explicit operator bool() const noexcept { return begin_; } + std::size_t size() const noexcept { return count_; } + std::size_t dimensions() const noexcept { return dimensions_; } + std::size_t stride() const noexcept { return stride_bytes_; } + scalar_t const* data() const noexcept { return begin_; } + scalar_t const* at(std::size_t i) const noexcept { + return reinterpret_cast(reinterpret_cast(begin_) + i * stride_bytes_); + } +}; + +struct exact_offset_and_distance_t { + u32_t offset; + f32_t distance; +}; + +using exact_search_results_t = vectors_view_gt; + +/** + * @brief Helper-structure for exact search operations. + * Perfect if you have @b <1M vectors and @b <100 queries per call. + * + * Uses a 3-step procedure to minimize: + * - cache-misses on vector lookups, + * - multi-threaded contention on concurrent writes. + */ +class exact_search_t { + + inline static bool smaller_distance(exact_offset_and_distance_t a, exact_offset_and_distance_t b) noexcept { + return a.distance < b.distance; + } + + using keys_and_distances_t = buffer_gt; + keys_and_distances_t keys_and_distances; + + public: + template + exact_search_results_t operator()( // + vectors_view_gt dataset, vectors_view_gt queries, // + std::size_t wanted, metric_punned_t const& metric, // + executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + return operator()( // + metric, // + reinterpret_cast(dataset.data()), dataset.size(), dataset.stride(), // + reinterpret_cast(queries.data()), queries.size(), queries.stride(), // + wanted, executor, progress); + } + + template + exact_search_results_t operator()( // + byte_t const* dataset_data, std::size_t dataset_count, std::size_t dataset_stride, // + byte_t const* queries_data, std::size_t queries_count, std::size_t queries_stride, // + std::size_t wanted, metric_punned_t const& metric, executor_at&& executor = executor_at{}, + progress_at&& progress = progress_at{}) { + + // Allocate temporary memory to store the distance matrix + // Previous version didn't need temporary memory, but the performance was much lower. + // In the new design we keep two buffers - original and transposed, as in-place transpositions + // of non-rectangular matrixes is expensive. 
+ std::size_t tasks_count = dataset_count * queries_count; + if (keys_and_distances.size() < tasks_count * 2) + keys_and_distances = keys_and_distances_t(tasks_count * 2); + if (keys_and_distances.size() < tasks_count * 2) + return {}; + + exact_offset_and_distance_t* keys_and_distances_per_dataset = keys_and_distances.data(); + exact_offset_and_distance_t* keys_and_distances_per_query = keys_and_distances_per_dataset + tasks_count; + + // §1. Compute distances in a data-parallel fashion + std::atomic processed{0}; + executor.dynamic(dataset_count, [&](std::size_t thread_idx, std::size_t dataset_idx) { + byte_t const* dataset = dataset_data + dataset_idx * dataset_stride; + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + byte_t const* query = queries_data + query_idx * queries_stride; + auto distance = metric(dataset, query); + std::size_t task_idx = queries_count * dataset_idx + query_idx; + keys_and_distances_per_dataset[task_idx].offset = static_cast(dataset_idx); + keys_and_distances_per_dataset[task_idx].distance = static_cast(distance); + } + + // It's more efficient in this case to report progress from a single thread + processed += queries_count; + if (thread_idx == 0) + if (!progress(processed.load(), tasks_count)) + return false; + return true; + }); + if (processed.load() != tasks_count) + return {}; + + // §2. Transpose in a single thread to avoid contention writing into the same memory buffers + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + for (std::size_t dataset_idx = 0; dataset_idx != dataset_count; ++dataset_idx) { + std::size_t from_idx = queries_count * dataset_idx + query_idx; + std::size_t to_idx = dataset_count * query_idx + dataset_idx; + keys_and_distances_per_query[to_idx] = keys_and_distances_per_dataset[from_idx]; + } + } + + // §3. Partial-sort every query result + executor.fixed(queries_count, [&](std::size_t, std::size_t query_idx) { + auto start = keys_and_distances_per_query + dataset_count * query_idx; + if (wanted > 1) { + // TODO: Consider alternative sorting approaches + // radix_sort(start, start + dataset_count, wanted); + // std::sort(start, start + dataset_count, &smaller_distance); + std::partial_sort(start, start + wanted, start + dataset_count, &smaller_distance); + } else { + auto min_it = std::min_element(start, start + dataset_count, &smaller_distance); + if (min_it != start) + std::swap(*min_it, *start); + } + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(tasks_count, tasks_count); + return {keys_and_distances_per_query, wanted, queries_count, + dataset_count * sizeof(exact_offset_and_distance_t)}; + } +}; + +/** + * @brief C++11 Multi-Hash-Set with Linear Probing. + * + * - Allows multiple equivalent values, + * - Supports transparent hashing and equality operator. + * - Doesn't throw exceptions, if forbidden. + * - Doesn't need reserving a value for deletions. + * + * @section Layout + * + * For every slot we store 2 extra bits for 3 possible states: empty, populated, or deleted. + * With linear probing the hashes at the end of the populated region will spill into its first half. 
+ */ +template > +class flat_hash_multi_set_gt { + public: + using element_t = element_at; + using hash_t = hash_at; + using equals_t = equals_at; + using allocator_t = allocator_at; + + static constexpr std::size_t slots_per_bucket() { return 64; } + static constexpr std::size_t bytes_per_bucket() { + return slots_per_bucket() * sizeof(element_t) + sizeof(bucket_header_t); + } + + private: + struct bucket_header_t { + std::uint64_t populated{}; + std::uint64_t deleted{}; + }; + char* data_ = nullptr; + std::size_t buckets_ = 0; + std::size_t populated_slots_ = 0; + /// @brief Number of slots + std::size_t capacity_slots_ = 0; + + struct slot_ref_t { + bucket_header_t& header; + std::uint64_t mask; + element_t& element; + }; + + slot_ref_t slot_ref(char* data, std::size_t slot_index) const noexcept { + std::size_t bucket_index = slot_index / slots_per_bucket(); + std::size_t in_bucket_index = slot_index % slots_per_bucket(); + auto bucket_pointer = data + bytes_per_bucket() * bucket_index; + auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; + return { + *reinterpret_cast(bucket_pointer), + static_cast(1ull) << in_bucket_index, + *reinterpret_cast(slot_pointer), + }; + } + + slot_ref_t slot_ref(std::size_t slot_index) const noexcept { return slot_ref(data_, slot_index); } + + bool populate_slot(slot_ref_t slot, element_t const& new_element) { + if (slot.header.populated & slot.mask) { + slot.element = new_element; + slot.header.deleted &= ~slot.mask; + return false; + } else { + new (&slot.element) element_t(new_element); + slot.header.populated |= slot.mask; + return true; + } + } + + public: + std::size_t size() const noexcept { return populated_slots_; } + std::size_t capacity() const noexcept { return capacity_slots_; } + + flat_hash_multi_set_gt() noexcept {} + ~flat_hash_multi_set_gt() noexcept { reset(); } + + flat_hash_multi_set_gt(flat_hash_multi_set_gt const& other) { + + // On Windows allocating a zero-size array would fail + if (!other.buckets_) { + reset(); + return; + } + + // Allocate new memory + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); + if (!data_) + throw std::bad_alloc(); + + // Copy metadata + buckets_ = other.buckets_; + populated_slots_ = other.populated_slots_; + capacity_slots_ = other.capacity_slots_; + + // Initialize new buckets to empty + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + + // Copy elements and bucket headers + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = other.slot_ref(i); + if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { + slot_ref_t new_slot = slot_ref(i); + populate_slot(new_slot, old_slot.element); + } + } + } + + flat_hash_multi_set_gt& operator=(flat_hash_multi_set_gt const& other) { + + // On Windows allocating a zero-size array would fail + if (!other.buckets_) { + reset(); + return *this; + } + + // Handle self-assignment + if (this == &other) + return *this; + + // Clear existing data + clear(); + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + + // Allocate new memory + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); + if (!data_) + throw std::bad_alloc(); + + // Copy metadata + buckets_ = other.buckets_; + populated_slots_ = other.populated_slots_; + capacity_slots_ = other.capacity_slots_; + + // Initialize new buckets to empty + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + + // Copy elements and 
bucket headers + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = other.slot_ref(i); + if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { + slot_ref_t new_slot = slot_ref(i); + populate_slot(new_slot, old_slot.element); + } + } + + return *this; + } + + void clear() noexcept { + // Call the destructors + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t slot = slot_ref(i); + if ((slot.header.populated & slot.mask) & (~slot.header.deleted & slot.mask)) + slot.element.~element_t(); + } + + // Reset populated slots count + if (data_) + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + populated_slots_ = 0; + } + + void reset() noexcept { + clear(); // Clear all elements + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + buckets_ = 0; + populated_slots_ = 0; + capacity_slots_ = 0; + } + + bool try_reserve(std::size_t capacity) noexcept { + if (capacity * 3u <= capacity_slots_ * 2u) + return true; + + // Calculate new sizes + std::size_t new_slots = ceil2((capacity * 3ul) / 2ul); + std::size_t new_buckets = divide_round_up(new_slots); + new_slots = new_buckets * slots_per_bucket(); // This must be a power of two! + std::size_t new_bytes = new_buckets * bytes_per_bucket(); + + // Allocate new memory + char* new_data = (char*)allocator_t{}.allocate(new_bytes); + if (!new_data) + return false; + + // Initialize new buckets to empty + std::memset(new_data, 0, new_bytes); + + // Rehash and copy existing elements to new_data + hash_t hasher; + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = slot_ref(i); + if ((~old_slot.header.populated & old_slot.mask) | (old_slot.header.deleted & old_slot.mask)) + continue; + + // Rehash + std::size_t hash_value = hasher(old_slot.element); + std::size_t new_slot_index = hash_value & (new_slots - 1); + + // Linear probing to find an empty slot in new_data + while (true) { + slot_ref_t new_slot = slot_ref(new_data, new_slot_index); + if (!(new_slot.header.populated & new_slot.mask) || (new_slot.header.deleted & new_slot.mask)) { + populate_slot(new_slot, std::move(old_slot.element)); + new_slot.header.populated |= new_slot.mask; + break; + } + new_slot_index = (new_slot_index + 1) & (new_slots - 1); + } + } + + // Deallocate old data and update pointers and sizes + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + data_ = new_data; + buckets_ = new_buckets; + capacity_slots_ = new_slots; + + return true; + } + + template class equal_iterator_gt { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = element_t*; + using reference = element_t&; + + equal_iterator_gt(std::size_t index, flat_hash_multi_set_gt* parent, query_at const& query, + equals_t const& equals) + : index_(index), parent_(parent), query_(query), equals_(equals) {} + + // Pre-increment + equal_iterator_gt& operator++() { + do { + index_ = (index_ + 1) & (parent_->capacity_slots_ - 1); + } while (!equals_(parent_->slot_ref(index_).element, query_) && + (parent_->slot_ref(index_).header.populated & parent_->slot_ref(index_).mask)); + return *this; + } + + equal_iterator_gt operator++(int) { + equal_iterator_gt temp = *this; + ++(*this); + return temp; + } + + reference operator*() { return parent_->slot_ref(index_).element; } + pointer operator->() { return &parent_->slot_ref(index_).element; } + bool 
operator!=(equal_iterator_gt const& other) const { return !(*this == other); } + bool operator==(equal_iterator_gt const& other) const { + return index_ == other.index_ && parent_ == other.parent_; + } + + private: + std::size_t index_; + flat_hash_multi_set_gt* parent_; + query_at query_; // Store the query object + equals_t equals_; // Store the equals functor + }; + + /** + * @brief Returns an iterator range of all elements matching the given query. + * + * Technically, the second iterator points to the first empty slot after a + * range of equal values and non-equal values with similar hashes. + */ + template + std::pair, equal_iterator_gt> + equal_range(query_at const& query) const noexcept { + + equals_t equals; + auto this_ptr = const_cast(this); + auto end = equal_iterator_gt(capacity_slots_, this_ptr, query, equals); + if (!capacity_slots_) + return {end, end}; + + hash_t hasher; + std::size_t hash_value = hasher(query); + std::size_t first_equal_index = hash_value & (capacity_slots_ - 1); + std::size_t const start_index = first_equal_index; + + // Linear probing to find the first equal element + do { + slot_ref_t slot = slot_ref(first_equal_index); + if (slot.header.populated & ~slot.header.deleted & slot.mask) { + if (equals(slot.element, query)) + break; + } + // Stop if we find an empty slot + else if (~slot.header.populated & slot.mask) + return {end, end}; + + // Move to the next slot + first_equal_index = (first_equal_index + 1) & (capacity_slots_ - 1); + } while (first_equal_index != start_index); + + // If no matching element was found, return end iterators + if (first_equal_index == capacity_slots_) + return {end, end}; + + // Start from the first matching element and find the end of the populated range + std::size_t first_empty_index = first_equal_index; + do { + first_empty_index = (first_empty_index + 1) & (capacity_slots_ - 1); + slot_ref_t slot = slot_ref(first_empty_index); + + // If we find an empty slot, this is our end + if (~slot.header.populated & slot.mask) + break; + } while (first_empty_index != start_index); + + return {equal_iterator_gt(first_equal_index, this_ptr, query, equals), + equal_iterator_gt(first_empty_index, this_ptr, query, equals)}; + } + + template bool pop_first(similar_at&& query, element_t& popped_value) noexcept { + + if (!capacity_slots_) + return false; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { + // Found a match, mark as deleted + slot.header.deleted |= slot.mask; + --populated_slots_; + popped_value = slot.element; + return true; // Successfully removed + } + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return false; // No match found + } + + template std::size_t erase(similar_at&& query) noexcept { + + if (!capacity_slots_) + return 0; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t 
const start_index = slot_index; // To detect loop in probing + std::size_t count = 0; // Count of elements removed + + // Linear probing to find all matches + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { + // Found a match, mark as deleted + slot.header.deleted |= slot.mask; + --populated_slots_; + ++count; // Increment count of elements removed + } + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return count; // Return the number of elements removed + } + + template element_t const* find(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return nullptr; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) + return &slot.element; // Found a match, return pointer to the element + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return nullptr; // No match found + } + + element_t const* end() const noexcept { return nullptr; } + + template void for_each(func_at&& func) const { + for (std::size_t bucket_index = 0; bucket_index < buckets_; ++bucket_index) { + auto bucket_pointer = data_ + bytes_per_bucket() * bucket_index; + bucket_header_t& header = *reinterpret_cast(bucket_pointer); + std::uint64_t populated = header.populated; + std::uint64_t deleted = header.deleted; + + // Iterate through slots in the bucket + for (std::size_t in_bucket_index = 0; in_bucket_index < slots_per_bucket(); ++in_bucket_index) { + std::uint64_t mask = std::uint64_t(1ull) << in_bucket_index; + + // Check if the slot is populated and not deleted + if ((populated & ~deleted) & mask) { + auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; + element_t const& element = *reinterpret_cast(slot_pointer); + func(element); + } + } + } + } + + template std::size_t count(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return 0; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + std::size_t start_index = slot_index; // To detect loop in probing + std::size_t count = 0; + + // Linear probing to find the range + do { + slot_ref_t slot = slot_ref(slot_index); + if ((slot.header.populated & slot.mask) && (~slot.header.deleted & slot.mask)) { + if (equals(slot.element, query)) + ++count; + } else if (~slot.header.populated & slot.mask) { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } while (slot_index != start_index); + + return count; + } + + template bool contains(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return false; + + hash_t hasher; + equals_t equals; + 
std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) + return true; // Found a match, exit early + } else + // Stop if we find an empty slot + break; + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } while (slot_index != start_index); + + return false; // No match found + } + + void reserve(std::size_t capacity) { + if (!try_reserve(capacity)) + throw std::bad_alloc(); + } + + bool try_emplace(element_t const& element) noexcept { + // Check if we need to resize + if (populated_slots_ * 3u >= capacity_slots_ * 2u) + if (!try_reserve(populated_slots_ + 1)) + return false; + + hash_t hasher; + std::size_t hash_value = hasher(element); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + + // Linear probing + while (true) { + slot_ref_t slot = slot_ref(slot_index); + if ((~slot.header.populated & slot.mask) | (slot.header.deleted & slot.mask)) { + // Found an empty or deleted slot + populate_slot(slot, element); + ++populated_slots_; + return true; + } + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } + } +}; + +} // namespace usearch +} // namespace unum diff --git a/src/yb/docdb/CMakeLists.txt b/src/yb/docdb/CMakeLists.txt index c051e9cdba6d..7c9b96f4f3e0 100644 --- a/src/yb/docdb/CMakeLists.txt +++ b/src/yb/docdb/CMakeLists.txt @@ -142,6 +142,7 @@ ADD_YB_TEST(scan_choices-test) ADD_YB_TEST(shared_lock_manager-test) ADD_YB_TEST(consensus_frontier-test) ADD_YB_TEST(compaction_file_filter-test) +ADD_YB_TEST(usearch_vector_index-test) if(YB_BUILD_FUZZ_TARGETS) # A library with common code shared between DocDB fuzz tests. diff --git a/src/yb/docdb/usearch_vector_index-test.cc b/src/yb/docdb/usearch_vector_index-test.cc new file mode 100644 index 000000000000..c48489c0763e --- /dev/null +++ b/src/yb/docdb/usearch_vector_index-test.cc @@ -0,0 +1,152 @@ +// Copyright (c) YugabyteDB, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations +// under the License. 
+//
+
+#pragma GCC diagnostic push
+
+#ifdef __clang__
+// For https://gist.githubusercontent.com/mbautin/87278fc41654c6c74cf7232960364c95/raw
+#pragma GCC diagnostic ignored "-Wpass-failed"
+
+#if __clang_major__ == 14
+// For https://gist.githubusercontent.com/mbautin/7856257553a1d41734b1cec7c73a0fb4/raw
+#pragma GCC diagnostic ignored "-Wambiguous-reversed-operator"
+#endif
+#endif  // __clang__
+
+#include "usearch/index.hpp"
+#include "usearch/index_dense.hpp"
+
+#pragma GCC diagnostic pop
+
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "yb/util/logging.h"
+#include "yb/util/monotime.h"
+#include "yb/util/random_util.h"
+#include "yb/util/test_thread_holder.h"
+#include "yb/util/test_util.h"
+#include "yb/util/tsan_util.h"
+
+// Helper function to generate random vectors
+template <typename CoordinateType, typename Distribution>
+std::vector<CoordinateType> GenerateRandomVector(size_t dimensions, Distribution& dis) {
+  std::vector<CoordinateType> vec(dimensions);
+  for (auto& v : vec) {
+    v = static_cast<CoordinateType>(dis(yb::ThreadLocalRandom()));
+  }
+  return vec;
+}
+
+namespace yb {
+
+class UsearchVectorIndexTest : public YBTest {
+};
+
+void ReportPerf(
+    const char* verb_past, size_t num_items, const char* item_description_plural,
+    size_t dimensions, int64_t elapsed_usec, size_t num_threads) {
+  LOG(INFO) << verb_past << " " << num_items << " " << item_description_plural << " with "
+            << dimensions << " dimensions in " << (elapsed_usec / 1000.0) << " ms "
+            << "(" << (elapsed_usec * 1.0 / num_items) << " usec per vector, or "
+            << (elapsed_usec * 1.0 / num_items / dimensions) << " usec per coordinate), "
+            << "using " << num_threads << " threads";
+}
+
+TEST_F(UsearchVectorIndexTest, CreateAndQuery) {
+  using namespace unum::usearch;
+
+  // Create a metric and index
+  const size_t kDimensions = 96;
+  metric_punned_t metric(kDimensions, metric_kind_t::l2sq_k, scalar_kind_t::f32_k);
+
+  // Generate and add vectors to the index
+  const size_t kNumVectors = ReleaseVsDebugVsAsanVsTsan(100000, 20000, 15000, 10000);
+  const size_t kNumIndexingThreads = 4;
+
+  std::uniform_real_distribution<> uniform_distrib(0, 1);
+
+  std::string index_path;
+  {
+    TestThreadHolder indexing_thread_holder;
+    index_dense_t index = index_dense_t::make(metric);
+    index.reserve(kNumVectors);
+    auto load_start_time = MonoTime::Now();
+    CountDownLatch latch(kNumIndexingThreads);
+    std::atomic<size_t> num_vectors_inserted{0};
+    for (size_t thread_index = 0; thread_index < kNumIndexingThreads; ++thread_index) {
+      indexing_thread_holder.AddThreadFunctor(
+          [&num_vectors_inserted, &index, &latch, &uniform_distrib]() {
+            size_t vector_id;
+            while ((vector_id = num_vectors_inserted.fetch_add(1)) < kNumVectors) {
+              auto vec = GenerateRandomVector<float>(kDimensions, uniform_distrib);
+              ASSERT_TRUE(index.add(vector_id, vec.data()));
+            }
+            latch.CountDown();
+          });
+    }
+    latch.Wait();
+    auto load_elapsed_usec = (MonoTime::Now() - load_start_time).ToMicroseconds();
+    ReportPerf("Indexed", kNumVectors, "vectors", kDimensions, load_elapsed_usec,
+               kNumIndexingThreads);
+
+    // Save the index to a file
+    index_path = GetTestDataDirectory() + "/hnsw_index.usearch";
+    ASSERT_TRUE(index.save(index_path.c_str()));
+  }
+  auto file_size = ASSERT_RESULT(Env::Default()->GetFileSize(index_path));
+  LOG(INFO) << "Index saved to " << index_path;
+  LOG(INFO) << "Index file size: " << file_size << " bytes "
+            << "(" << (file_size / 1024.0 / 1024.0) << " MiB), "
+            << (file_size * 1.0 / kNumVectors) << " average bytes per vector, "
+            << (file_size * 1.0 / (kNumVectors * kDimensions)) << " average bytes per coordinate";
+
+  // Load the index from the file
+  auto index_load_start_time = MonoTime::Now();
+  index_dense_t loaded_index = index_dense_t::make(index_path.c_str(), /* load= */ true);
+  ASSERT_TRUE(loaded_index);  // Ensure loading was successful
+  auto index_load_elapsed_usec = (MonoTime::Now() - index_load_start_time).ToMicroseconds();
+  LOG(INFO) << "Index loaded from " << index_path << " in " << (index_load_elapsed_usec / 1000.0)
+            << " ms";
+
+  auto query_start_time = MonoTime::Now();
+  const size_t kNumQueryThreads = 4;
+  const size_t kNumQueries = ReleaseVsDebugVsAsanVsTsan(100000, 30000, 40000, 10000);
+  const size_t kNumResultsPerQuery = 10;
+  CountDownLatch latch(kNumQueryThreads);
+  std::atomic<size_t> num_vectors_queried{0};
+  TestThreadHolder query_thread_holder;
+  for (size_t thread_index = 0; thread_index < kNumQueryThreads; ++thread_index) {
+    query_thread_holder.AddThreadFunctor(
+        [&num_vectors_queried, &loaded_index, &latch, &uniform_distrib, kNumResultsPerQuery]() {
+          // Perform searches on the loaded index
+          while (num_vectors_queried.fetch_add(1) < kNumQueries) {
+            auto query_vec = GenerateRandomVector<float>(kDimensions, uniform_distrib);
+            auto results = loaded_index.search(query_vec.data(), kNumResultsPerQuery);
+            ASSERT_LE(results.size(), kNumResultsPerQuery);
+          }
+          latch.CountDown();
+        });
+  }
+  latch.Wait();
+  auto query_elapsed_usec = (MonoTime::Now() - query_start_time).ToMicroseconds();
+  ReportPerf("Performed", kNumQueries, "queries", kDimensions, query_elapsed_usec,
+             kNumQueryThreads);
+}
+
+}  // namespace yb
diff --git a/src/yb/util/tsan_util.h b/src/yb/util/tsan_util.h
index 66a615e2d0b4..812b55258c81 100644
--- a/src/yb/util/tsan_util.h
+++ b/src/yb/util/tsan_util.h
@@ -49,6 +49,20 @@ constexpr T RegularBuildVsDebugVsSanitizers(
 #endif
 }
 
+template <class T>
+constexpr T ReleaseVsDebugVsAsanVsTsan(
+    T release_build_value, T debug_build_value, T asan_value, T tsan_value) {
+#if defined(THREAD_SANITIZER)
+  return tsan_value;
+#elif defined(ADDRESS_SANITIZER)
+  return asan_value;
+#elif defined(NDEBUG)
+  return release_build_value;
+#else
+  return debug_build_value;
+#endif
+}
+
 constexpr bool IsSanitizer() {
   return RegularBuildVsSanitizers(false, true);
 }
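
The imported index_plugins.hpp also accepts user-defined distance functions through metric_punned_t::stateless and the metric_punned_signature_t enum described above, while the new test only exercises the built-in L2-squared metric. The following is a minimal sketch of that extension point, not part of this patch: it assumes the src/inline-thirdparty include paths added in CMakeLists.txt, and L1Distance, kDims, and the key values are hypothetical names chosen for illustration.

```cpp
// Minimal sketch (assumed usage, not part of this patch): plugging a custom,
// stateless distance function into usearch through the type-punned metric.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "usearch/index_dense.hpp"

namespace {

// Matches metric_punned_signature_t::array_array_size_k: two pointer-sized
// vector addresses plus the number of scalar words, returning a float distance.
float L1Distance(std::size_t a_addr, std::size_t b_addr, std::size_t num_words) {
  auto const* a = reinterpret_cast<float const*>(a_addr);
  auto const* b = reinterpret_cast<float const*>(b_addr);
  float sum = 0.0f;
  for (std::size_t i = 0; i != num_words; ++i) {
    sum += std::fabs(a[i] - b[i]);
  }
  return sum;
}

}  // namespace

int main() {
  using namespace unum::usearch;
  constexpr std::size_t kDims = 4;  // Hypothetical dimensionality.

  // The punned metric stores the raw function pointer; for f32 vectors the
  // third argument passed at call time is the number of dimensions.
  auto metric = metric_punned_t::stateless(
      kDims, reinterpret_cast<std::uintptr_t>(&L1Distance),
      metric_punned_signature_t::array_array_size_k,
      metric_kind_t::unknown_k, scalar_kind_t::f32_k);

  auto index = index_dense_t::make(metric);
  index.reserve(2);

  std::vector<float> a{1, 2, 3, 4};
  std::vector<float> b{1, 2, 3, 8};
  index.add(/* key= */ 1, a.data());
  index.add(/* key= */ 2, b.data());

  // Key 1 should be the closest match for `a` itself (L1 distance 0).
  auto results = index.search(a.data(), /* wanted= */ 1);
  return results.size() == 1 ? 0 : 1;
}
```

Because metric_punned_t keeps a raw function pointer rather than a std::function, as its documentation above notes, this hook adds no allocations on the search path.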