11#pragma once
22
3+ #include " knn-defines.h"
34#include " knn-enumerator.h"
45
56#include < ydb/library/yql/public/udf/udf_helpers.h>
1112using namespace NYql ;
1213using namespace NYql ::NUdf;
1314
14- enum EFormat : ui8 {
15- FloatVector = 1
16- };
17-
18- static constexpr size_t HeaderLen = sizeof (ui8);
19-
20- class TFloatVectorSerializer {
15+ template <typename T, EFormat Format>
16+ class TKnnVectorSerializer {
2117public:
2218 static TUnboxedValue Serialize (const IValueBuilder* valueBuilder, const TUnboxedValue x) {
2319 auto serialize = [&x] (IOutputStream& outStream) {
24- EnumerateVector (x, [&outStream] (float element) { outStream.Write (&element, sizeof (float )); });
25- const EFormat format = EFormat::FloatVector;
20+ EnumerateVector (x, [&outStream] (float floatElement) {
21+ T element = static_cast <T>(floatElement);
22+ outStream.Write (&element, sizeof (T));
23+ });
24+ const EFormat format = Format;
2625 outStream.Write (&format, HeaderLen);
2726 };
2827
2928 if (x.HasFastListLength ()) {
30- auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () * sizeof (float ));
29+ auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () * sizeof (T ));
3130 auto strRef = str.AsStringRef ();
3231 TMemoryOutput memoryOutput (strRef.Data (), strRef.Size ());
3332
@@ -46,45 +45,115 @@ class TFloatVectorSerializer {
4645 const char * buf = str.Data ();
4746 const size_t len = str.Size () - HeaderLen;
4847
49- if (len % sizeof (float ) != 0 )
48+ if (Y_UNLIKELY ( len % sizeof (T ) != 0 ))
5049 return {};
5150
52- const ui32 count = len / sizeof (float );
51+ const ui32 count = len / sizeof (T );
5352
5453 TUnboxedValue* items = nullptr ;
5554 auto res = valueBuilder->NewArray (count, items);
5655
5756 TMemoryInput inStr (buf, len);
5857 for (ui32 i = 0 ; i < count; ++i) {
59- float element;
60- if (inStr.Read (&element, sizeof (float )) != sizeof (float ))
58+ T element;
59+ if (Y_UNLIKELY ( inStr.Read (&element, sizeof (T )) != sizeof (T) ))
6160 return {};
62- *items++ = TUnboxedValuePod{element};
61+ *items++ = TUnboxedValuePod{static_cast < float >( element) };
6362 }
6463
6564 return res.Release ();
6665 }
6766
68- static const TArrayRef<const float > GetArray (const TStringRef& str) {
67+ static const TArrayRef<const T > GetArray (const TStringRef& str) {
6968 const char * buf = str.Data ();
7069 const size_t len = str.Size () - HeaderLen;
7170
72- if (len % sizeof (float ) != 0 )
71+ if (Y_UNLIKELY ( len % sizeof (T ) != 0 ))
7372 return {};
7473
75- const ui32 count = len / sizeof (float );
74+ const ui32 count = len / sizeof (T );
7675
77- return MakeArrayRef (reinterpret_cast <const float *>(buf), count);
76+ return MakeArrayRef (reinterpret_cast <const T *>(buf), count);
7877 }
7978};
8079
80+ // Encode all positive floats as bit 1, negative floats as bit 0.
81+ // So 1024 float vector is serialized in 1024/8=128 bytes.
82+ // Place all bits in ui64. So, only vector sizes divisible by 64 are supported.
83+ // Max vector lenght is 32767.
84+ class TKnnBitVectorSerializer {
85+ public:
86+ static TUnboxedValue Serialize (const IValueBuilder* valueBuilder, const TUnboxedValue x) {
87+ auto serialize = [&x] (IOutputStream& outStream) {
88+ ui64 accumulator = 0 ;
89+ ui8 filledBits = 0 ;
90+ ui64 lenght = 0 ;
91+
92+ EnumerateVector (x, [&] (float element) {
93+ if (element > 0 )
94+ accumulator |= 1ll << filledBits;
95+
96+ ++filledBits;
97+ if (filledBits == 64 ) {
98+ outStream.Write (&accumulator, sizeof (ui64));
99+ lenght++;
100+ accumulator = 0 ;
101+ filledBits = 0 ;
102+ }
103+ });
104+
105+ // only vector sizes divisible by 64 are supported
106+ if (Y_UNLIKELY (filledBits))
107+ return false ;
108+
109+ // max vector lenght is 32767
110+ if (Y_UNLIKELY (lenght > UINT16_MAX))
111+ return false ;
112+
113+ const EFormat format = EFormat::BitVector;
114+ outStream.Write (&format, HeaderLen);
115+
116+ return true ;
117+ };
118+
119+ if (x.HasFastListLength ()) {
120+ auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () / 8 );
121+ auto strRef = str.AsStringRef ();
122+ TMemoryOutput memoryOutput (strRef.Data (), strRef.Size ());
81123
82- class TSerializerFacade {
124+ if (Y_UNLIKELY (!serialize (memoryOutput)))
125+ return {};
126+
127+ return str;
128+ } else {
129+ TString str;
130+ TStringOutput stringOutput (str);
131+
132+ if (Y_UNLIKELY (!serialize (stringOutput)))
133+ return {};
134+
135+ return valueBuilder->NewString (str);
136+ }
137+ }
138+
139+ static const TArrayRef<const ui64> GetArray64 (const TStringRef& str) {
140+ const char * buf = str.Data ();
141+ const size_t len = (str.Size () - HeaderLen) / sizeof (ui64);
142+
143+ return MakeArrayRef (reinterpret_cast <const ui64*>(buf), len);
144+ }
145+ };
146+
147+ class TKnnSerializerFacade {
83148public:
84149 static TUnboxedValue Serialize (EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
85150 switch (format) {
86151 case EFormat::FloatVector:
87- return TFloatVectorSerializer::Serialize (valueBuilder, x);
152+ return TKnnVectorSerializer<float , EFormat::FloatVector>::Serialize (valueBuilder, x);
153+ case EFormat::ByteVector:
154+ return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Serialize (valueBuilder, x);
155+ case EFormat::BitVector:
156+ return TKnnBitVectorSerializer::Serialize (valueBuilder, x);
88157 default :
89158 return {};
90159 }
@@ -97,20 +166,29 @@ class TSerializerFacade {
97166 const ui8 format = str.Data ()[str.Size () - HeaderLen];
98167 switch (format) {
99168 case EFormat::FloatVector:
100- return TFloatVectorSerializer::Deserialize (valueBuilder, str);
169+ return TKnnVectorSerializer<float , EFormat::FloatVector>::Deserialize (valueBuilder, str);
170+ case EFormat::ByteVector:
171+ return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Deserialize (valueBuilder, str);
172+ case EFormat::BitVector:
173+ return {};
101174 default :
102175 return {};
103176 }
104177 }
105178
106- static const TArrayRef<const float > GetArray (const TStringRef& str) {
107- if (str.Size () == 0 )
179+ template <typename T>
180+ static const TArrayRef<const T> GetArray (const TStringRef& str) {
181+ if (Y_UNLIKELY (str.Size () == 0 ))
108182 return {};
109183
110184 const ui8 format = str.Data ()[str.Size () - HeaderLen];
111185 switch (format) {
112186 case EFormat::FloatVector:
113- return TFloatVectorSerializer::GetArray (str);
187+ return TKnnVectorSerializer<T, EFormat::FloatVector>::GetArray (str);
188+ case EFormat::ByteVector:
189+ return TKnnVectorSerializer<T, EFormat::ByteVector>::GetArray (str);
190+ case EFormat::BitVector:
191+ return {};
114192 default :
115193 return {};
116194 }
0 commit comments