Skip to content

Commit c26cfea

Browse files
committed
Euclidean distance
1 parent 3f2b174 commit c26cfea

File tree

6 files changed

+269
-1
lines changed

6 files changed

+269
-1
lines changed

ydb/library/yql/udfs/common/knn/knn-distance.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include <library/cpp/dot_product/dot_product.h>
99
#include <library/cpp/l1_distance/l1_distance.h>
10+
#include <library/cpp/l2_distance/l2_distance.h>
1011
#include <util/generic/array_ref.h>
1112
#include <util/generic/buffer.h>
1213
#include <util/stream/format.h>
@@ -57,6 +58,49 @@ static std::optional<float> KnnManhattanDistance(const TStringRef& str1, const T
5758
}
5859
}
5960

61+
static std::optional<float> KnnEuclideanDistance(const TStringRef& str1, const TStringRef& str2) {
62+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
63+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
64+
65+
if (Y_UNLIKELY(format1 != format2))
66+
return {};
67+
68+
switch (format1) {
69+
case EFormat::FloatVector: {
70+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
71+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
72+
73+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
74+
return {};
75+
76+
return ::L2Distance(vector1.data(), vector2.data(), vector1.size());
77+
}
78+
case EFormat::ByteVector: {
79+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
80+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
81+
82+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
83+
return {};
84+
85+
return ::L2Distance(vector1.data(), vector2.data(), vector1.size());
86+
}
87+
case EFormat::BitVector: {
88+
const TArrayRef<const ui64> vector1 = TKnnBitVectorSerializer::GetArray64(str1);
89+
const TArrayRef<const ui64> vector2 = TKnnBitVectorSerializer::GetArray64(str2);
90+
91+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector1.size() > UINT16_MAX))
92+
return {};
93+
94+
ui16 ret = 0;
95+
for (size_t i = 0; i < vector1.size(); ++i)
96+
ret += __builtin_popcountll(vector1[i] ^ vector2[i]);
97+
return NPrivate::NL2Distance::L2DistanceSqrt<ui64>(ret);
98+
}
99+
default:
100+
return {};
101+
}
102+
}
103+
60104
static std::optional<float> KnnDotProduct(const TStringRef& str1, const TStringRef& str2) {
61105
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
62106
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];

ydb/library/yql/udfs/common/knn/knn.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,24 @@ SIMPLE_STRICT_UDF(TManhattanDistance, TOptional<float>(TAutoMap<const char*>, TA
8383
return TUnboxedValuePod{ret.value()};
8484
}
8585

86+
SIMPLE_STRICT_UDF(TEuclideanDistance, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
87+
Y_UNUSED(valueBuilder);
88+
89+
const auto ret = KnnEuclideanDistance(args[0].AsStringRef(), args[1].AsStringRef());
90+
if (Y_UNLIKELY(!ret))
91+
return {};
92+
93+
return TUnboxedValuePod{ret.value()};
94+
}
95+
8696
SIMPLE_MODULE(TKnnModule,
8797
TFromBinaryString,
8898
TToBinaryString,
8999
TInnerProductSimilarity,
90100
TCosineSimilarity,
91101
TCosineDistance,
92-
TManhattanDistance
102+
TManhattanDistance,
103+
TEuclideanDistance
93104
)
94105

95106
REGISTER_MODULES(TKnnModule)

ydb/library/yql/udfs/common/knn/test/canondata/result.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
"uri": "file://test.test_DeserializationError_/results.txt"
1515
}
1616
],
17+
"test.test[EuclideanDistance]": [
18+
{
19+
"uri": "file://test.test_EuclideanDistance_/results.txt"
20+
}
21+
],
1722
"test.test[FloatByteSerialization]": [
1823
{
1924
"uri": "file://test.test_FloatByteSerialization_/results.txt"
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
[
2+
{
3+
"Write" = [
4+
{
5+
"Type" = [
6+
"ListType";
7+
[
8+
"StructType";
9+
[
10+
[
11+
"column0";
12+
[
13+
"OptionalType";
14+
[
15+
"DataType";
16+
"Float"
17+
]
18+
]
19+
]
20+
]
21+
]
22+
];
23+
"Data" = [
24+
[
25+
[
26+
"5.196152"
27+
]
28+
]
29+
]
30+
}
31+
]
32+
};
33+
{
34+
"Write" = [
35+
{
36+
"Type" = [
37+
"ListType";
38+
[
39+
"StructType";
40+
[
41+
[
42+
"column0";
43+
[
44+
"OptionalType";
45+
[
46+
"DataType";
47+
"Float"
48+
]
49+
]
50+
]
51+
]
52+
]
53+
];
54+
"Data" = [
55+
[
56+
[
57+
"5"
58+
]
59+
]
60+
]
61+
}
62+
]
63+
};
64+
{
65+
"Write" = [
66+
{
67+
"Type" = [
68+
"ListType";
69+
[
70+
"StructType";
71+
[
72+
[
73+
"column0";
74+
[
75+
"OptionalType";
76+
[
77+
"DataType";
78+
"Float"
79+
]
80+
]
81+
]
82+
]
83+
]
84+
];
85+
"Data" = [
86+
[
87+
#
88+
]
89+
]
90+
}
91+
]
92+
};
93+
{
94+
"Write" = [
95+
{
96+
"Type" = [
97+
"ListType";
98+
[
99+
"StructType";
100+
[
101+
[
102+
"column0";
103+
[
104+
"OptionalType";
105+
[
106+
"DataType";
107+
"Float"
108+
]
109+
]
110+
]
111+
]
112+
]
113+
];
114+
"Data" = [
115+
[
116+
[
117+
"8"
118+
]
119+
]
120+
]
121+
}
122+
]
123+
};
124+
{
125+
"Write" = [
126+
{
127+
"Type" = [
128+
"ListType";
129+
[
130+
"StructType";
131+
[
132+
[
133+
"column0";
134+
[
135+
"OptionalType";
136+
[
137+
"DataType";
138+
"Float"
139+
]
140+
]
141+
]
142+
]
143+
]
144+
];
145+
"Data" = [
146+
[
147+
[
148+
"8"
149+
]
150+
]
151+
]
152+
}
153+
]
154+
};
155+
{
156+
"Write" = [
157+
{
158+
"Type" = [
159+
"ListType";
160+
[
161+
"StructType";
162+
[
163+
[
164+
"column0";
165+
[
166+
"OptionalType";
167+
[
168+
"DataType";
169+
"Float"
170+
]
171+
]
172+
]
173+
]
174+
]
175+
];
176+
"Data" = [
177+
[
178+
[
179+
"5"
180+
]
181+
]
182+
]
183+
}
184+
]
185+
}
186+
]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
--float vector
2+
$float_vector1 = Knn::ToBinaryString([1.0f, 2.0f, 3.0f]);
3+
$float_vector2 = Knn::ToBinaryString([4.0f, 5.0f, 6.0f]);
4+
select Knn::EuclideanDistance($float_vector1, $float_vector2);
5+
6+
--byte vector
7+
$byte_vector1 = Knn::ToBinaryString([1.0f, 2.0f, 3.0f], "byte");
8+
$byte_vector2 = Knn::ToBinaryString([4.0f, 5.0f, 6.0f], "byte");
9+
select Knn::EuclideanDistance($byte_vector1, $byte_vector2);
10+
11+
--bit vector
12+
$bitvector_positive = Knn::ToBinaryString(ListReplicate(1.0f, 64), "bit");
13+
$bitvector_positive_double_size = Knn::ToBinaryString(ListReplicate(1.0f, 128), "bit");
14+
$bitvector_negative = Knn::ToBinaryString(ListReplicate(-1.0f, 64), "bit");
15+
$bitvector_negative_and_positive = Knn::ToBinaryString(ListFromRange(-63.0f, 64.1f), "bit");
16+
$bitvector_negative_and_positive_striped = Knn::ToBinaryString(ListFlatten(ListReplicate([-1.0f, 1.0f], 32)), "bit");
17+
18+
select Knn::EuclideanDistance($bitvector_positive, $bitvector_positive_double_size);
19+
select Knn::EuclideanDistance($bitvector_positive, $bitvector_negative);
20+
select Knn::EuclideanDistance($bitvector_positive_double_size, $bitvector_negative_and_positive);
21+
select Knn::EuclideanDistance($bitvector_positive, $bitvector_negative_and_positive_striped);

ydb/library/yql/udfs/common/knn/ya.make

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ SRCS(
1313
PEERDIR(
1414
library/cpp/dot_product
1515
library/cpp/l1_distance
16+
library/cpp/l2_distance
1617
)
1718

1819

0 commit comments

Comments
 (0)