Skip to content

Commit 9411e2f

Browse files
authored
Merge 530346d into a81cee8
2 parents a81cee8 + 530346d commit 9411e2f

File tree

18 files changed

+1005
-55
lines changed

18 files changed

+1005
-55
lines changed

packages.json

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[
2+
{
3+
"branch": "BitIndex",
4+
"package_name": "ydb",
5+
"package_path": "ydb/tools/ydbd_slice/image/pkg.json",
6+
"package_version": "azevaykin-latest",
7+
"revision": "6aecd9a9d6e9fab15f80e8340b1f0defc418750a",
8+
"sandbox_task_id": 0,
9+
"svn_revision": "-1",
10+
"package_full_name": "ydb.azevaykin-latest",
11+
"docker_image": "cr.yandex/crpbo4q9lbgkn85vr1rm/ydb:azevaykin-latest",
12+
"name": "ydb",
13+
"version": "azevaykin-latest",
14+
"path": "/home-big/azevaykin/github/ydb.azevaykin-latest.tar.gz",
15+
"debug_path": null
16+
}
17+
]

ydb.azevaykin-latest.tar.gz

125 Bytes
Binary file not shown.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include "util/system/types.h"
4+
5+
enum EFormat : ui8 {
6+
FloatVector = 1, // 4-byte per element
7+
ByteVector = 2, // 1-byte per element
8+
BitVector = 10 // 1-bit per element
9+
};
10+
11+
static constexpr size_t HeaderLen = sizeof(ui8);
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
#pragma once
2+
3+
#include "knn-defines.h"
4+
#include "knn-serializer.h"
5+
6+
#include <ydb/library/yql/public/udf/udf_helpers.h>
7+
8+
#include <library/cpp/dot_product/dot_product.h>
9+
#include <library/cpp/l1_distance/l1_distance.h>
10+
#include <library/cpp/l2_distance/l2_distance.h>
11+
#include <util/generic/array_ref.h>
12+
#include <util/generic/buffer.h>
13+
#include <util/stream/format.h>
14+
15+
using namespace NYql;
16+
using namespace NYql::NUdf;
17+
18+
static std::optional<float> KnnManhattanDistance(const TStringRef& str1, const TStringRef& str2) {
19+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
20+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
21+
22+
if (Y_UNLIKELY(format1 != format2))
23+
return {};
24+
25+
switch (format1) {
26+
case EFormat::FloatVector: {
27+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
28+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
29+
30+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
31+
return {};
32+
33+
return ::L1Distance(vector1.data(), vector2.data(), vector1.size());
34+
}
35+
case EFormat::ByteVector: {
36+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
37+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
38+
39+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
40+
return {};
41+
42+
return ::L1Distance(vector1.data(), vector2.data(), vector1.size());
43+
}
44+
case EFormat::BitVector: {
45+
const TArrayRef<const ui64> vector1 = TKnnBitVectorSerializer::GetArray64(str1);
46+
const TArrayRef<const ui64> vector2 = TKnnBitVectorSerializer::GetArray64(str2);
47+
48+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector1.size() > UINT16_MAX))
49+
return {};
50+
51+
ui64 ret = 0;
52+
for (size_t i = 0; i < vector1.size(); ++i)
53+
ret += __builtin_popcountll(vector1[i] ^ vector2[i]);
54+
return ret;
55+
}
56+
default:
57+
return {};
58+
}
59+
}
60+
61+
static std::optional<float> KnnEuclideanDistance(const TStringRef& str1, const TStringRef& str2) {
62+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
63+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
64+
65+
if (Y_UNLIKELY(format1 != format2))
66+
return {};
67+
68+
switch (format1) {
69+
case EFormat::FloatVector: {
70+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
71+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
72+
73+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
74+
return {};
75+
76+
return ::L2Distance(vector1.data(), vector2.data(), vector1.size());
77+
}
78+
case EFormat::ByteVector: {
79+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
80+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
81+
82+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
83+
return {};
84+
85+
return ::L2Distance(vector1.data(), vector2.data(), vector1.size());
86+
}
87+
case EFormat::BitVector: {
88+
const TArrayRef<const ui64> vector1 = TKnnBitVectorSerializer::GetArray64(str1);
89+
const TArrayRef<const ui64> vector2 = TKnnBitVectorSerializer::GetArray64(str2);
90+
91+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector1.size() > UINT16_MAX))
92+
return {};
93+
94+
ui64 ret = 0;
95+
for (size_t i = 0; i < vector1.size(); ++i)
96+
ret += __builtin_popcountll(vector1[i] ^ vector2[i]);
97+
return NPrivate::NL2Distance::L2DistanceSqrt(ret);
98+
}
99+
default:
100+
return {};
101+
}
102+
}
103+
104+
static std::optional<float> KnnDotProduct(const TStringRef& str1, const TStringRef& str2) {
105+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
106+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
107+
108+
if (Y_UNLIKELY(format1 != format2))
109+
return {};
110+
111+
switch (format1) {
112+
case EFormat::FloatVector: {
113+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
114+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
115+
116+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
117+
return {};
118+
119+
return ::DotProduct(vector1.data(), vector2.data(), vector1.size());
120+
}
121+
case EFormat::ByteVector: {
122+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
123+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
124+
125+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
126+
return {};
127+
128+
return ::DotProduct(vector1.data(), vector2.data(), vector1.size());
129+
}
130+
default:
131+
return {};
132+
}
133+
}
134+
135+
static std::optional<TTriWayDotProduct<float>> KnnTriWayDotProduct(const TStringRef& str1, const TStringRef& str2) {
136+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
137+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
138+
139+
if (Y_UNLIKELY(format1 != format2))
140+
return {};
141+
142+
switch (format1) {
143+
case EFormat::FloatVector: {
144+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
145+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
146+
147+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
148+
return {};
149+
150+
return ::TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
151+
}
152+
case EFormat::ByteVector: {
153+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
154+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
155+
156+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
157+
return {};
158+
159+
TTriWayDotProduct<float> result;
160+
result.LL = ::DotProduct(vector1.data(), vector1.data(), vector1.size());
161+
result.LR = ::DotProduct(vector1.data(), vector2.data(), vector1.size());
162+
result.RR = ::DotProduct(vector2.data(), vector2.data(), vector1.size());
163+
return result;
164+
}
165+
default:
166+
return {};
167+
}
168+
}

ydb/library/yql/udfs/common/knn/knn-serializer.h

Lines changed: 96 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "knn-defines.h"
34
#include "knn-enumerator.h"
45

56
#include <ydb/library/yql/public/udf/udf_helpers.h>
@@ -11,23 +12,21 @@
1112
using namespace NYql;
1213
using namespace NYql::NUdf;
1314

14-
enum EFormat : ui8 {
15-
FloatVector = 1
16-
};
17-
18-
static constexpr size_t HeaderLen = sizeof(ui8);
19-
20-
class TFloatVectorSerializer {
15+
template<typename T, EFormat Format>
16+
class TKnnVectorSerializer {
2117
public:
2218
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
2319
auto serialize = [&x] (IOutputStream& outStream) {
24-
EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
25-
const EFormat format = EFormat::FloatVector;
20+
EnumerateVector(x, [&outStream] (float floatElement) {
21+
T element = static_cast<T>(floatElement);
22+
outStream.Write(&element, sizeof(T));
23+
});
24+
const EFormat format = Format;
2625
outStream.Write(&format, HeaderLen);
2726
};
2827

2928
if (x.HasFastListLength()) {
30-
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() * sizeof(float));
29+
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() * sizeof(T));
3130
auto strRef = str.AsStringRef();
3231
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
3332

@@ -46,45 +45,108 @@ class TFloatVectorSerializer {
4645
const char* buf = str.Data();
4746
const size_t len = str.Size() - HeaderLen;
4847

49-
if (len % sizeof(float) != 0)
48+
if (Y_UNLIKELY(len % sizeof(T) != 0))
5049
return {};
5150

52-
const ui32 count = len / sizeof(float);
51+
const ui32 count = len / sizeof(T);
5352

5453
TUnboxedValue* items = nullptr;
5554
auto res = valueBuilder->NewArray(count, items);
5655

5756
TMemoryInput inStr(buf, len);
5857
for (ui32 i = 0; i < count; ++i) {
59-
float element;
60-
if (inStr.Read(&element, sizeof(float)) != sizeof(float))
58+
T element;
59+
if (Y_UNLIKELY(inStr.Read(&element, sizeof(T)) != sizeof(T)))
6160
return {};
62-
*items++ = TUnboxedValuePod{element};
61+
*items++ = TUnboxedValuePod{static_cast<float>(element)};
6362
}
6463

6564
return res.Release();
6665
}
6766

68-
static const TArrayRef<const float> GetArray(const TStringRef& str) {
67+
static const TArrayRef<const T> GetArray(const TStringRef& str) {
6968
const char* buf = str.Data();
7069
const size_t len = str.Size() - HeaderLen;
7170

72-
if (len % sizeof(float) != 0)
71+
if (Y_UNLIKELY(len % sizeof(T) != 0))
7372
return {};
7473

75-
const ui32 count = len / sizeof(float);
74+
const ui32 count = len / sizeof(T);
7675

77-
return MakeArrayRef(reinterpret_cast<const float*>(buf), count);
76+
return MakeArrayRef(reinterpret_cast<const T*>(buf), count);
7877
}
7978
};
8079

80+
// Encode all positive floats as bit 1, negative floats as bit 0.
81+
// So 1024 float vector is serialized in 1024/8=128 bytes.
82+
// Place all bits in ui64. So, only vector sizes divisible by 64 are supported.
83+
class TKnnBitVectorSerializer {
84+
public:
85+
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
86+
auto serialize = [&x] (IOutputStream& outStream) {
87+
ui64 accumulator = 0;
88+
ui8 filledBits = 0;
89+
90+
EnumerateVector(x, [&] (float element) {
91+
if (element > 0)
92+
accumulator |= 1ll << filledBits;
93+
94+
++filledBits;
95+
if (filledBits == 64) {
96+
outStream.Write(&accumulator, sizeof(ui64));
97+
accumulator = 0;
98+
filledBits = 0;
99+
}
100+
});
101+
102+
// only vector sizes divisible by 64 are supported
103+
if (Y_UNLIKELY(filledBits))
104+
return false;
105+
106+
const EFormat format = EFormat::BitVector;
107+
outStream.Write(&format, HeaderLen);
108+
109+
return true;
110+
};
111+
112+
if (x.HasFastListLength()) {
113+
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() / 8);
114+
auto strRef = str.AsStringRef();
115+
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
81116

82-
class TSerializerFacade {
117+
if (Y_UNLIKELY(!serialize(memoryOutput)))
118+
return {};
119+
120+
return str;
121+
} else {
122+
TString str;
123+
TStringOutput stringOutput(str);
124+
125+
if (Y_UNLIKELY(!serialize(stringOutput)))
126+
return {};
127+
128+
return valueBuilder->NewString(str);
129+
}
130+
}
131+
132+
static const TArrayRef<const ui64> GetArray64(const TStringRef& str) {
133+
const char* buf = str.Data();
134+
const size_t len = (str.Size() - HeaderLen) / sizeof(ui64);
135+
136+
return MakeArrayRef(reinterpret_cast<const ui64*>(buf), len);
137+
}
138+
};
139+
140+
class TKnnSerializerFacade {
83141
public:
84142
static TUnboxedValue Serialize(EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
85143
switch (format) {
86144
case EFormat::FloatVector:
87-
return TFloatVectorSerializer::Serialize(valueBuilder, x);
145+
return TKnnVectorSerializer<float, EFormat::FloatVector>::Serialize(valueBuilder, x);
146+
case EFormat::ByteVector:
147+
return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Serialize(valueBuilder, x);
148+
case EFormat::BitVector:
149+
return TKnnBitVectorSerializer::Serialize(valueBuilder, x);
88150
default:
89151
return {};
90152
}
@@ -97,20 +159,29 @@ class TSerializerFacade {
97159
const ui8 format = str.Data()[str.Size() - HeaderLen];
98160
switch (format) {
99161
case EFormat::FloatVector:
100-
return TFloatVectorSerializer::Deserialize(valueBuilder, str);
162+
return TKnnVectorSerializer<float, EFormat::FloatVector>::Deserialize(valueBuilder, str);
163+
case EFormat::ByteVector:
164+
return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Deserialize(valueBuilder, str);
165+
case EFormat::BitVector:
166+
return {};
101167
default:
102168
return {};
103169
}
104170
}
105171

106-
static const TArrayRef<const float> GetArray(const TStringRef& str) {
107-
if (str.Size() == 0)
172+
template<typename T>
173+
static const TArrayRef<const T> GetArray(const TStringRef& str) {
174+
if (Y_UNLIKELY(str.Size() == 0))
108175
return {};
109176

110177
const ui8 format = str.Data()[str.Size() - HeaderLen];
111178
switch (format) {
112179
case EFormat::FloatVector:
113-
return TFloatVectorSerializer::GetArray(str);
180+
return TKnnVectorSerializer<T, EFormat::FloatVector>::GetArray(str);
181+
case EFormat::ByteVector:
182+
return TKnnVectorSerializer<T, EFormat::ByteVector>::GetArray(str);
183+
case EFormat::BitVector:
184+
return {};
114185
default:
115186
return {};
116187
}

0 commit comments

Comments
 (0)