Skip to content

Commit 03ecb11

Browse files
azevaykinMBkkt
andauthored
New serialization formats in Knn UDF (#4445)
Co-authored-by: Valerii Mironov <mbkkt@ydb.tech>
1 parent 0b60307 commit 03ecb11

File tree

48 files changed

+20870
-292
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+20870
-292
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#pragma once
2+
3+
#include "util/system/types.h"
4+
5+
enum EFormat: ui8 {
6+
FloatVector = 1, // 4-byte per element
7+
Uint8Vector = 2, // 1-byte per element, better than Int8 for positive-only Float
8+
Int8Vector = 3, // 1-byte per element
9+
BitVector = 10, // 1-bit per element
10+
};
11+
12+
template <typename T>
13+
struct TTypeToFormat;
14+
15+
template <>
16+
struct TTypeToFormat<float> {
17+
static constexpr auto Format = EFormat::FloatVector;
18+
};
19+
20+
template <>
21+
struct TTypeToFormat<i8> {
22+
static constexpr auto Format = EFormat::Int8Vector;
23+
};
24+
25+
template <>
26+
struct TTypeToFormat<ui8> {
27+
static constexpr auto Format = EFormat::Uint8Vector;
28+
};
29+
30+
template <>
31+
struct TTypeToFormat<bool> {
32+
static constexpr auto Format = EFormat::BitVector;
33+
};
34+
35+
template <typename T>
36+
inline constexpr auto Format = TTypeToFormat<T>::Format;
37+
inline constexpr auto HeaderLen = sizeof(ui8);
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
#pragma once
2+
3+
#include "knn-defines.h"
4+
#include "knn-serializer.h"
5+
6+
#include <ydb/library/yql/public/udf/udf_helpers.h>
7+
8+
#include <library/cpp/dot_product/dot_product.h>
9+
#include <library/cpp/l1_distance/l1_distance.h>
10+
#include <library/cpp/l2_distance/l2_distance.h>
11+
#include <util/generic/array_ref.h>
12+
#include <util/generic/buffer.h>
13+
#include <util/stream/format.h>
14+
15+
#include <bit>
16+
17+
using namespace NYql;
18+
using namespace NYql::NUdf;
19+
20+
inline void BitVectorHandleShort(ui64 byteLen, const ui64* v1, const ui64* v2, auto&& op) {
21+
Y_ASSERT(0 < byteLen);
22+
Y_ASSERT(byteLen < sizeof(ui64));
23+
ui64 d1 = 0;
24+
ui64 d2 = 0;
25+
// TODO manual switch for [1..7]?
26+
std::memcpy(&d1, v1, byteLen);
27+
std::memcpy(&d2, v2, byteLen);
28+
op(d1, d2);
29+
}
30+
31+
inline void BitVectorHandleTail(ui64 byteLen, const ui64* v1, const ui64* v2, auto&& op) {
32+
if (Y_LIKELY(byteLen == 0)) // fast-path for aligned case
33+
return;
34+
Y_ASSERT(byteLen < sizeof(ui64));
35+
const auto unneededBytes = sizeof(ui64) - byteLen;
36+
const auto* r1 = reinterpret_cast<const char*>(v1) - unneededBytes;
37+
const auto* r2 = reinterpret_cast<const char*>(v2) - unneededBytes;
38+
ui64 d1, d2; // unaligned loads
39+
std::memcpy(&d1, r1, sizeof(ui64));
40+
std::memcpy(&d2, r2, sizeof(ui64));
41+
ui64 mask = 0;
42+
// big endian: 0 1 2 3 4 5 6 7 | 0 1 2 3 | 0 1 | 0 | 0 => needs to zero high bits
43+
// little endian: 7 6 5 4 3 2 1 0 | 3 2 1 0 | 1 0 | 0 | 0 => needs to zero low bits
44+
if constexpr (std::endian::native == std::endian::big) {
45+
mask = (ui64{1} << (byteLen * 8)) - 1;
46+
} else {
47+
mask = ~((ui64{1} << (unneededBytes * 8)) - 1);
48+
}
49+
op(d1 & mask, d2 & mask);
50+
}
51+
52+
inline void BitVectorHandleOp(ui64 bitLen, const ui64* v1, const ui64* v2, auto&& op) {
53+
if (Y_UNLIKELY(bitLen == 0))
54+
return;
55+
auto byteLen = (bitLen + 7) / 8;
56+
const auto wordLen = byteLen / sizeof(ui64);
57+
if (Y_LIKELY(wordLen == 0)) // fast-path for short case
58+
return BitVectorHandleShort(byteLen, v1, v2, op);
59+
byteLen %= sizeof(ui64);
60+
for (const auto* end = v1 + wordLen; v1 != end; ++v1, ++v2) {
61+
op(*v1, *v2);
62+
}
63+
BitVectorHandleTail(byteLen, v1, v2, op);
64+
}
65+
66+
using TDistanceResult = std::optional<float>;
67+
68+
template <typename Func>
69+
inline TDistanceResult VectorFuncImpl(const auto* v1, const auto* v2, auto len1, auto len2, Func&& func) {
70+
if (Y_UNLIKELY(len1 != len2))
71+
return {};
72+
return {func(v1, v2, len1)};
73+
}
74+
75+
template <typename T, typename Func>
76+
inline auto VectorFunc(const TStringRef& str1, const TStringRef& str2, Func&& func) {
77+
const TArrayRef<const T> v1 = TKnnVectorSerializer<T>::GetArray(str1);
78+
const TArrayRef<const T> v2 = TKnnVectorSerializer<T>::GetArray(str2);
79+
return VectorFuncImpl(v1.data(), v2.data(), v1.size(), v2.size(), std::forward<Func>(func));
80+
}
81+
82+
template <typename Func>
83+
inline auto BitVectorFunc(const TStringRef& str1, const TStringRef& str2, Func&& func) {
84+
auto [v1, bitLen1] = TKnnSerializerFacade::GetBitArray(str1);
85+
auto [v2, bitLen2] = TKnnSerializerFacade::GetBitArray(str2);
86+
return VectorFuncImpl(v1, v2, bitLen1, bitLen2, std::forward<Func>(func));
87+
}
88+
89+
inline TDistanceResult KnnManhattanDistance(const TStringRef& str1, const TStringRef& str2) {
90+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
91+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
92+
if (Y_UNLIKELY(format1 != format2))
93+
return {};
94+
95+
switch (format1) {
96+
case EFormat::FloatVector:
97+
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
98+
return ::L1Distance(v1, v2, len);
99+
});
100+
case EFormat::Int8Vector:
101+
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
102+
return ::L1Distance(v1, v2, len);
103+
});
104+
case EFormat::Uint8Vector:
105+
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
106+
return ::L1Distance(v1, v2, len);
107+
});
108+
case EFormat::BitVector:
109+
return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
110+
ui64 ret = 0;
111+
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
112+
ret += std::popcount(d1 ^ d2);
113+
});
114+
return ret;
115+
});
116+
default:
117+
return {};
118+
}
119+
}
120+
121+
inline TDistanceResult KnnEuclideanDistance(const TStringRef& str1, const TStringRef& str2) {
122+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
123+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
124+
if (Y_UNLIKELY(format1 != format2))
125+
return {};
126+
127+
switch (format1) {
128+
case EFormat::FloatVector:
129+
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
130+
return ::L2Distance(v1, v2, len);
131+
});
132+
case EFormat::Int8Vector:
133+
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
134+
return ::L2Distance(v1, v2, len);
135+
});
136+
case EFormat::Uint8Vector:
137+
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
138+
return ::L2Distance(v1, v2, len);
139+
});
140+
case EFormat::BitVector:
141+
return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
142+
ui64 ret = 0;
143+
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
144+
ret += std::popcount(d1 ^ d2);
145+
});
146+
return NPrivate::NL2Distance::L2DistanceSqrt(ret);
147+
});
148+
default:
149+
return {};
150+
}
151+
}
152+
153+
inline TDistanceResult KnnDotProduct(const TStringRef& str1, const TStringRef& str2) {
154+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
155+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
156+
if (Y_UNLIKELY(format1 != format2))
157+
return {};
158+
159+
switch (format1) {
160+
case EFormat::FloatVector:
161+
return VectorFunc<float>(str1, str2, [](const float* v1, const float* v2, size_t len) {
162+
return ::DotProduct(v1, v2, len);
163+
});
164+
case EFormat::Int8Vector:
165+
return VectorFunc<i8>(str1, str2, [](const i8* v1, const i8* v2, size_t len) {
166+
return ::DotProduct(v1, v2, len);
167+
});
168+
case EFormat::Uint8Vector:
169+
return VectorFunc<ui8>(str1, str2, [](const ui8* v1, const ui8* v2, size_t len) {
170+
return ::DotProduct(v1, v2, len);
171+
});
172+
case EFormat::BitVector:
173+
return BitVectorFunc(str1, str2, [](const ui64* v1, const ui64* v2, ui64 bitLen) {
174+
ui64 ret = 0;
175+
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
176+
ret += std::popcount(d1 & d2);
177+
});
178+
return ret;
179+
});
180+
default:
181+
return {};
182+
}
183+
}
184+
185+
inline TDistanceResult KnnCosineSimilarity(const TStringRef& str1, const TStringRef& str2) {
186+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
187+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
188+
if (Y_UNLIKELY(format1 != format2))
189+
return {};
190+
191+
auto compute = [](auto ll, float lr, auto rr) {
192+
const float norm = std::sqrt(ll * rr);
193+
const float cosine = norm != 0 ? lr / norm : 1;
194+
return cosine;
195+
};
196+
197+
switch (format1) {
198+
case EFormat::FloatVector:
199+
return VectorFunc<float>(str1, str2, [&](const float* v1, const float* v2, size_t len) {
200+
const auto res = ::TriWayDotProduct(v1, v2, len);
201+
return compute(res.LL, res.LR, res.RR);
202+
});
203+
case EFormat::Int8Vector:
204+
return VectorFunc<i8>(str1, str2, [&](const i8* v1, const i8* v2, size_t len) {
205+
// TODO We can optimize it if we will iterate over both vector at the same time, look to the float implementation
206+
const i64 ll = ::DotProduct(v1, v1, len);
207+
const i64 lr = ::DotProduct(v1, v2, len);
208+
const i64 rr = ::DotProduct(v2, v2, len);
209+
return compute(ll, lr, rr);
210+
});
211+
case EFormat::Uint8Vector:
212+
return VectorFunc<ui8>(str1, str2, [&](const ui8* v1, const ui8* v2, size_t len) {
213+
// TODO We can optimize it if we will iterate over both vector at the same time, look to the float implementation
214+
const ui64 ll = ::DotProduct(v1, v1, len);
215+
const ui64 lr = ::DotProduct(v1, v2, len);
216+
const ui64 rr = ::DotProduct(v2, v2, len);
217+
return compute(ll, lr, rr);
218+
});
219+
case EFormat::BitVector:
220+
return BitVectorFunc(str1, str2, [&](const ui64* v1, const ui64* v2, ui64 bitLen) {
221+
ui64 ll = 0;
222+
ui64 rr = 0;
223+
ui64 lr = 0;
224+
BitVectorHandleOp(bitLen, v1, v2, [&](ui64 d1, ui64 d2) {
225+
ll += std::popcount(d1);
226+
rr += std::popcount(d2);
227+
lr += std::popcount(d1 & d2);
228+
});
229+
return compute(ll, lr, rr);
230+
});
231+
default:
232+
return {};
233+
}
234+
}

ydb/library/yql/udfs/common/knn/knn-enumerator.h

Lines changed: 6 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,95 +2,23 @@
22

33
#include <ydb/library/yql/public/udf/udf_helpers.h>
44

5-
#include <util/generic/buffer.h>
6-
#include <util/stream/format.h>
5+
#include <util/generic/array_ref.h>
76

87
using namespace NYql;
98
using namespace NYql::NUdf;
109

11-
template <typename TCallback>
10+
template <typename T, typename TCallback>
1211
void EnumerateVector(const TUnboxedValuePod vector, TCallback&& callback) {
13-
const auto elements = vector.GetElements();
12+
const auto* elements = vector.GetElements();
1413
if (elements) {
15-
const auto size = vector.GetListLength();
16-
17-
for (ui32 i = 0; i < size; ++i) {
18-
callback(elements[i].Get<float>());
14+
for (auto& value : TArrayRef{elements, vector.GetListLength()}) {
15+
callback(value.Get<T>());
1916
}
2017
} else {
2118
TUnboxedValue value;
2219
const auto it = vector.GetListIterator();
2320
while (it.Next(value)) {
24-
callback(value.Get<float>());
21+
callback(value.Get<T>());
2522
}
2623
}
2724
}
28-
29-
template <typename TCallback>
30-
bool EnumerateVectors(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2, TCallback&& callback) {
31-
32-
auto enumerateBothSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2, const TUnboxedValue* elements2) {
33-
const auto size1 = vector1.GetListLength();
34-
const auto size2 = vector2.GetListLength();
35-
36-
// Length mismatch
37-
if (size1 != size2)
38-
return false;
39-
40-
for (ui32 i = 0; i < size1; ++i) {
41-
callback(elements1[i].Get<float>(), elements2[i].Get<float>());
42-
}
43-
44-
return true;
45-
};
46-
47-
auto enumerateOneSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2) {
48-
const auto size = vector1.GetListLength();
49-
ui32 idx = 0;
50-
TUnboxedValue value;
51-
const auto it = vector2.GetListIterator();
52-
53-
while (it.Next(value)) {
54-
callback(elements1[idx++].Get<float>(), value.Get<float>());
55-
}
56-
57-
// Length mismatch
58-
if (it.Next(value) || idx != size)
59-
return false;
60-
61-
return true;
62-
};
63-
64-
auto enumerateNoSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
65-
TUnboxedValue value1, value2;
66-
const auto it1 = vector1.GetListIterator();
67-
const auto it2 = vector2.GetListIterator();
68-
for (; it1.Next(value1) && it2.Next(value2);) {
69-
callback(value1.Get<float>(), value2.Get<float>());
70-
}
71-
72-
// Length mismatch
73-
if (it1.Next(value1) || it2.Next(value2))
74-
return false;
75-
76-
return true;
77-
};
78-
79-
const auto elements1 = vector1.GetElements();
80-
const auto elements2 = vector2.GetElements();
81-
if (elements1 && elements2) {
82-
if (!enumerateBothSized(vector1, elements1, vector2, elements2))
83-
return false;
84-
} else if (elements1) {
85-
if (!enumerateOneSized(vector1, elements1, vector2))
86-
return false;
87-
} else if (elements2) {
88-
if (!enumerateOneSized(vector2, elements2, vector1))
89-
return false;
90-
} else {
91-
if (!enumerateNoSized(vector1, vector2))
92-
return false;
93-
}
94-
95-
return true;
96-
}

0 commit comments

Comments
 (0)