Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
d51bd49
First version
azevaykin Apr 27, 2024
812b332
Alice test
azevaykin Apr 28, 2024
4b914c2
ToBitString
azevaykin Apr 29, 2024
d6a6520
topk & seed params
azevaykin Apr 30, 2024
973e2d2
cache friendly
azevaykin Apr 30, 2024
41ced45
distanceThreshold
azevaykin May 1, 2024
1167c45
ToBinaryString with format
azevaykin May 4, 2024
5924d25
ManhattenDistance
azevaykin May 4, 2024
c5ce63f
template TKnnVectorSerializer
azevaykin May 4, 2024
e2d2cc5
FloatByteVector
azevaykin May 4, 2024
36b345f
Do not store format for bitstring
azevaykin May 4, 2024
714d5b6
Manhattan
azevaykin May 4, 2024
93e3274
Revert "Do not store format for bitstring"
azevaykin May 5, 2024
9f47f08
Remove bit indexes
azevaykin May 5, 2024
9661dc0
floatbyte -> float
azevaykin May 12, 2024
67c6533
Wrong format fix
azevaykin May 12, 2024
3f2b174
ManhattanDistance for float&byte vectors
azevaykin May 14, 2024
c26cfea
Euclidean distance
azevaykin May 14, 2024
530346d
ManhattanDistance returns float, so bit vector is no longer limited b…
azevaykin May 14, 2024
7e41fc5
Remove docker files
azevaykin May 17, 2024
f447320
Apply suggestions from code review
azevaykin May 17, 2024
cd4a529
Uint8Vector
azevaykin May 17, 2024
c233555
remove vector2.empty()
azevaykin May 17, 2024
d0327be
style
azevaykin May 17, 2024
684c480
const auto
azevaykin May 17, 2024
960bc8a
First step
MBkkt May 20, 2024
ab10a4f
Second step
MBkkt May 20, 2024
6f9c3ab
Format
MBkkt May 20, 2024
14d154a
Fix
MBkkt May 20, 2024
78bf150
Add new tests
MBkkt May 20, 2024
bf41d9b
Apply review suggestions and add more tests
MBkkt May 21, 2024
0f05a2a
Apply review suggestions
MBkkt May 21, 2024
845796f
Apply review suggestions
MBkkt May 21, 2024
5dc24c5
Apply review suggestions
MBkkt May 21, 2024
fdd26c7
Apply review suggestions
MBkkt May 21, 2024
31a0ba2
Merge pull request #3 from MBkkt/BitIndex
azevaykin May 22, 2024
a8f385e
Add another tests
MBkkt May 22, 2024
7bc2057
Add another test
MBkkt May 22, 2024
3b95f33
BitSerialization tests
azevaykin May 23, 2024
ef87329
Change format arg to funcname
MBkkt May 23, 2024
f8e34ed
Some improvements
MBkkt May 23, 2024
d0affd3
WIP
MBkkt May 24, 2024
7a6a046
WIP
MBkkt May 24, 2024
e54af87
WIP
MBkkt May 27, 2024
48b03d2
Review suggestions
MBkkt May 27, 2024
797c698
Back to optional :(
MBkkt May 27, 2024
75e82f8
Merge pull request #5 from MBkkt/BitIndex
azevaykin May 27, 2024
456ce3e
Fix typo
MBkkt May 27, 2024
01c1992
Merge pull request #6 from MBkkt/BitIndex
azevaykin May 27, 2024
9ee300f
Fix null forwarding and optional unpacking for functions twhich use t…
MBkkt May 28, 2024
59805d3
Merge pull request #7 from MBkkt/BitIndex
azevaykin May 28, 2024
b412c9e
better
MBkkt May 31, 2024
c80646b
Merge pull request #9 from MBkkt/BitIndex
azevaykin May 31, 2024
301dace
temporary change
MBkkt May 31, 2024
df732fb
temporary change (#10)
MBkkt May 31, 2024
3872187
better
MBkkt May 31, 2024
59c407f
Merge branch 'BitIndex' into BitIndex
MBkkt May 31, 2024
704927e
Bit index (#11)
MBkkt May 31, 2024
f286bd7
add overloads for ToBinaryStringBit
MBkkt Jun 3, 2024
36f534f
improvements
MBkkt Jun 3, 2024
6b460b0
add overloads for ToBinaryStringBit
azevaykin Jun 3, 2024
a5c605c
Fix review comments
MBkkt Jun 3, 2024
6580ac5
English grammar
azevaykin Jun 3, 2024
8065fae
Canonization
azevaykin Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-defines.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once

#include "util/system/types.h"

enum EFormat: ui8 {
FloatVector = 1, // 4-byte per element
Uint8Vector = 2, // 1-byte per element, better than Int8 for positive-only Float
Int8Vector = 3, // 1-byte per element
BitVector = 10, // 1-bit per element
};

template <typename T>
struct TTypeToFormat;

template <>
struct TTypeToFormat<float> {
static constexpr auto Format = EFormat::FloatVector;
};

template <>
struct TTypeToFormat<i8> {
static constexpr auto Format = EFormat::Int8Vector;
};

template <>
struct TTypeToFormat<ui8> {
static constexpr auto Format = EFormat::Uint8Vector;
};

template <>
struct TTypeToFormat<bool> {
static constexpr auto Format = EFormat::BitVector;
};

template <typename T>
inline constexpr auto Format = TTypeToFormat<T>::Format;
inline constexpr auto HeaderLen = sizeof(ui8);
234 changes: 234 additions & 0 deletions ydb/library/yql/udfs/common/knn/knn-distance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
#pragma once

#include "knn-defines.h"
#include "knn-serializer.h"

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <library/cpp/dot_product/dot_product.h>
#include <library/cpp/l1_distance/l1_distance.h>
#include <library/cpp/l2_distance/l2_distance.h>
#include <util/generic/array_ref.h>
#include <util/generic/buffer.h>
#include <util/stream/format.h>

#include <bit>

using namespace NYql;
using namespace NYql::NUdf;

inline void BitVectorHandleShort(ui64 byteLen, const ui64* v1, const ui64* v2, auto&& op) {
Y_ASSERT(0 < byteLen);
Y_ASSERT(byteLen < sizeof(ui64));
ui64 d1 = 0;
ui64 d2 = 0;
// TODO manual switch for [1..7]?
std::memcpy(&d1, v1, byteLen);
std::memcpy(&d2, v2, byteLen);
op(d1, d2);
}

inline void BitVectorHandleTail(ui64 byteLen, const ui64* v1, const ui64* v2, auto&& op) {
if (Y_LIKELY(byteLen == 0)) // fast-path for aligned case
return;
Y_ASSERT(byteLen < sizeof(ui64));
const auto unneededBytes = sizeof(ui64) - byteLen;
const auto* r1 = reinterpret_cast<const char*>(v1) - unneededBytes;
const auto* r2 = reinterpret_cast<const char*>(v2) - unneededBytes;
ui64 d1, d2; // unaligned loads
std::memcpy(&d1, r1, sizeof(ui64));
std::memcpy(&d2, r2, sizeof(ui64));
ui64 mask = 0;
// big endian: 0 1 2 3 4 5 6 7 | 0 1 2 3 | 0 1 | 0 | 0 => needs to zero high bits
// little endian: 7 6 5 4 3 2 1 0 | 3 2 1 0 | 1 0 | 0 | 0 => needs to zero low bits
if constexpr (std::endian::native == std::endian::big) {
mask = (ui64{1} << (byteLen * 8)) - 1;
} else {
mask = ~((ui64{1} << (unneededBytes * 8)) - 1);
}
op(d1 & mask, d2 & mask);
}

inline void BitVectorHandleOp(ui64 bitLen, const ui64* v1, const ui64* v2, auto&& op) {
if (Y_UNLIKELY(bitLen == 0))
return;
auto byteLen = (bitLen + 7) / 8;
const auto wordLen = byteLen / sizeof(ui64);
if (Y_LIKELY(wordLen == 0)) // fast-path for short case
return BitVectorHandleShort(byteLen, v1, v2, op);
byteLen %= sizeof(ui64);
for (const auto* end = v1 + wordLen; v1 != end; ++v1, ++v2) {
op(*v1, *v2);
}
BitVectorHandleTail(byteLen, v1, v2, op);
}

using TDistanceResult = std::optional<float>;

template <typename Func>
inline TDistanceResult VectorFuncImpl(const auto* v1, const auto* v2, auto len1, auto len2, Func&& func) {
if (Y_UNLIKELY(len1 != len2))
return {};
return {func(v1, v2, len1)};
}

template <typename T, typename Func>
inline auto VectorFunc(const TStringRef& str1, const TStringRef& str2, Func&& func) {
const TArrayRef<const T> v1 = TKnnVectorSerializer<T>::GetArray(str1);
const TArrayRef<const T> v2 = TKnnVectorSerializer<T>::GetArray(str2);
return VectorFuncImpl(v1.data(), v2.data(), v1.size(), v2.size(), std::forward<Func>(func));
}

template <typename Func>
inline auto BitVectorFunc(const TStringRef& str1, const TStringRef& str2, Func&& func) {
auto [v1, bitLen1] = TKnnSerializerFacade::GetBitArray(str1);
auto [v2, bitLen2] = TKnnSerializerFacade::GetBitArray(str2);
return VectorFuncImpl(v1, v2, bitLen1, bitLen2, std::forward<Func>(func));
}

inline TDistanceResult KnnManhattanDistance(const TStringRef& str1, const TStringRef& str2) {
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
if (Y_UNLIKELY(format1 != format2))
return {};

switch (format1) {
case EFormat::FloatVector:
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
return ::L1Distance(v1, v2, len);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
return ::L1Distance(v1, v2, len);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
return ::L1Distance(v1, v2, len);
});
case EFormat::BitVector:
return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
ui64 ret = 0;
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ret += std::popcount(d1 ^ d2);
});
return ret;
});
default:
return {};
}
}

inline TDistanceResult KnnEuclideanDistance(const TStringRef& str1, const TStringRef& str2) {
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
if (Y_UNLIKELY(format1 != format2))
return {};

switch (format1) {
case EFormat::FloatVector:
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
return ::L2Distance(v1, v2, len);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
return ::L2Distance(v1, v2, len);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
return ::L2Distance(v1, v2, len);
});
case EFormat::BitVector:
return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
ui64 ret = 0;
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ret += std::popcount(d1 ^ d2);
});
return NPrivate::NL2Distance::L2DistanceSqrt(ret);
});
default:
return {};
}
}

inline TDistanceResult KnnDotProduct(const TStringRef& str1, const TStringRef& str2) {
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
if (Y_UNLIKELY(format1 != format2))
return {};

switch (format1) {
case EFormat::FloatVector:
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
return ::DotProduct(v1, v2, len);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
return ::DotProduct(v1, v2, len);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
return ::DotProduct(v1, v2, len);
});
case EFormat::BitVector:
return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
ui64 ret = 0;
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ret += std::popcount(d1 & d2);
});
return ret;
});
default:
return {};
}
}

inline TDistanceResult KnnCosineSimilarity(const TStringRef& str1, const TStringRef& str2) {
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
if (Y_UNLIKELY(format1 != format2))
return {};

auto compute = [](auto ll, float lr, auto rr) {
const float norm = std::sqrt(ll * rr);
const float cosine = norm != 0 ? lr / norm : 1;
return cosine;
};

switch (format1) {
case EFormat::FloatVector:
return VectorFunc<float>(str1, str2, [&](const float* v1, const float* v2, size_t len) {
const auto res = ::TriWayDotProduct(v1, v2, len);
return compute(res.LL, res.LR, res.RR);
});
case EFormat::Int8Vector:
return VectorFunc<i8>(str1, str2, [&](const i8* v1, const i8* v2, size_t len) {
// TODO We can optimize it if we will iterate over both vector at the same time, look to the float implementation
const i64 ll = ::DotProduct(v1, v1, len);
const i64 lr = ::DotProduct(v1, v2, len);
const i64 rr = ::DotProduct(v2, v2, len);
return compute(ll, lr, rr);
});
case EFormat::Uint8Vector:
return VectorFunc<ui8>(str1, str2, [&](const ui8* v1, const ui8* v2, size_t len) {
// TODO We can optimize it if we will iterate over both vector at the same time, look to the float implementation
const ui64 ll = ::DotProduct(v1, v1, len);
const ui64 lr = ::DotProduct(v1, v2, len);
const ui64 rr = ::DotProduct(v2, v2, len);
return compute(ll, lr, rr);
});
case EFormat::BitVector:
return BitVectorFunc(str1, str2, [&](const ui64* v1, const ui64* v2, ui64 bitLen) {
ui64 ll = 0;
ui64 rr = 0;
ui64 lr = 0;
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
ll += std::popcount(d1);
rr += std::popcount(d2);
lr += std::popcount(d1 & d2);
});
return compute(ll, lr, rr);
});
default:
return {};
}
}
84 changes: 6 additions & 78 deletions ydb/library/yql/udfs/common/knn/knn-enumerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,23 @@

#include <ydb/library/yql/public/udf/udf_helpers.h>

#include <util/generic/buffer.h>
#include <util/stream/format.h>
#include <util/generic/array_ref.h>

using namespace NYql;
using namespace NYql::NUdf;

template <typename TCallback>
template <typename T, typename TCallback>
void EnumerateVector(const TUnboxedValuePod vector, TCallback&& callback) {
const auto elements = vector.GetElements();
const auto* elements = vector.GetElements();
if (elements) {
const auto size = vector.GetListLength();

for (ui32 i = 0; i < size; ++i) {
callback(elements[i].Get<float>());
for (auto& value : TArrayRef{elements, vector.GetListLength()}) {
callback(value.Get<T>());
}
} else {
TUnboxedValue value;
const auto it = vector.GetListIterator();
while (it.Next(value)) {
callback(value.Get<float>());
callback(value.Get<T>());
}
}
}

template <typename TCallback>
bool EnumerateVectors(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2, TCallback&& callback) {

auto enumerateBothSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2, const TUnboxedValue* elements2) {
const auto size1 = vector1.GetListLength();
const auto size2 = vector2.GetListLength();

// Length mismatch
if (size1 != size2)
return false;

for (ui32 i = 0; i < size1; ++i) {
callback(elements1[i].Get<float>(), elements2[i].Get<float>());
}

return true;
};

auto enumerateOneSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2) {
const auto size = vector1.GetListLength();
ui32 idx = 0;
TUnboxedValue value;
const auto it = vector2.GetListIterator();

while (it.Next(value)) {
callback(elements1[idx++].Get<float>(), value.Get<float>());
}

// Length mismatch
if (it.Next(value) || idx != size)
return false;

return true;
};

auto enumerateNoSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
TUnboxedValue value1, value2;
const auto it1 = vector1.GetListIterator();
const auto it2 = vector2.GetListIterator();
for (; it1.Next(value1) && it2.Next(value2);) {
callback(value1.Get<float>(), value2.Get<float>());
}

// Length mismatch
if (it1.Next(value1) || it2.Next(value2))
return false;

return true;
};

const auto elements1 = vector1.GetElements();
const auto elements2 = vector2.GetElements();
if (elements1 && elements2) {
if (!enumerateBothSized(vector1, elements1, vector2, elements2))
return false;
} else if (elements1) {
if (!enumerateOneSized(vector1, elements1, vector2))
return false;
} else if (elements2) {
if (!enumerateOneSized(vector2, elements2, vector1))
return false;
} else {
if (!enumerateNoSized(vector1, vector2))
return false;
}

return true;
}
Loading